Skip to content

Commit

Permalink
Add erfinv and erfcinv for Float16 and generalize logerfc and…
Browse files Browse the repository at this point in the history
… `logerfcx` (#372)

* Add `erfinv` and `erfcinv` for `Float16`

* Generalize `logerfc` and `logerfcx`

* Add tests

* Update version number

* Update test/erf.jl

* Fix test

* Simplify branch for `abs(x) >= 1` in `_erfinv`

* Fix incomplete function

* Simplify implementation

* Apply suggested alternative

---------

Co-authored-by: Viral B. Shah <[email protected]>
  • Loading branch information
devmotion and ViralBShah authored May 7, 2024
1 parent e08ff8d commit 124915f
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 80 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "SpecialFunctions"
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
version = "2.3.1"
version = "2.4.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
89 changes: 65 additions & 24 deletions src/erf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -253,13 +253,10 @@ erfinv(x::Real) = _erfinv(float(x))

function _erfinv(x::Float64)
a = abs(x)
if a >= 1.0
if x == 1.0
return Inf
elseif x == -1.0
return -Inf
end
if a > 1.0
throw(DomainError(a, "`abs(x)` cannot be greater than 1."))
elseif a == 1.0
return copysign(Inf, x)
elseif a <= 0.75 # Table 17 in Blair et al.
t = x*x - 0.5625
return x * @horner(t, 0.16030_49558_44066_229311e2,
Expand Down Expand Up @@ -321,13 +318,10 @@ end

function _erfinv(x::Float32)
a = abs(x)
if a >= 1.0f0
if x == 1.0f0
return Inf32
elseif x == -1.0f0
return -Inf32
end
if a > 1f0
throw(DomainError(a, "`abs(x)` cannot be greater than 1."))
elseif a == 1f0
return copysign(Inf32, x)
elseif a <= 0.75f0 # Table 10 in Blair et al.
t = x*x - 0.5625f0
return x * @horner(t, -0.13095_99674_22f2,
Expand Down Expand Up @@ -362,6 +356,42 @@ function _erfinv(x::Float32)
end
end

function _erfinv(x::Float16)
a = abs(x)
if a > Float16(1)
throw(DomainError(a, "`abs(x)` cannot be greater than 1."))
elseif a == Float16(1)
return copysign(Inf16, x)
else
# Perform calculations with `Float32`
x32 = Float32(x)
a32 = Float32(a)
if a32 <= 0.75f0
# Simpler and more accurate alternative to Table 7 in Blair et al.
# Ref: https://github.com/JuliaMath/SpecialFunctions.jl/pull/372#discussion_r1592832735
t = muladd(-6.73815f1, x32, 1f0) / muladd(-4.18798f0, x32, 4.54263f0)
y = copysign(muladd(0.88622695f0, x32, t), x32)
elseif a32 <= 0.9375f0 # Table 26 in Blair et al.
t = x32^2 - 0.87890625f0
y = x32 * @horner(t, 0.10178_950f1,
-0.32827_601f1) /
@horner(t, 0.72455_99f0,
-0.33871_553f1,
0.1f1)
else
# Simpler alternative to Table 47 in Blair et al.
# because of the reduced accuracy requirement
# (it turns out that this branch only covers 128 values).
# Note that the use of log(1-x) rather than log1p is intentional since it will be
# slightly faster and 1-x is exact.
# Ref: https://github.com/JuliaMath/SpecialFunctions.jl/pull/372#discussion_r1592710586
t = sqrt(-log(1-a32))
y = copysign(@horner(t, -0.429159f0, 1.04868f0), x32)
end
return Float16(y)
end
end

function _erfinv(y::BigFloat)
xfloat = erfinv(Float64(y))
if isfinite(xfloat)
Expand Down Expand Up @@ -482,6 +512,25 @@ function _erfcinv(y::Float32)
end
end

function _erfcinv(y::Float16)
if y > Float16(0.0625)
return erfinv(Float16(1) - y)
elseif y <= Float16(0)
if y == Float16(0)
return Inf16
end
throw(DomainError(y, "`y` must be nonnegative."))
else # Table 47 in Blair et al.
t = 1.0f0 / sqrt(-log(Float32(y)))
x = @horner(t, 0.98650_088f0,
0.92601_777f0) /
(t * @horner(t, 0.98424_719f0,
0.10074_7432f0,
0.1f0))
return Float16(x)
end
end

function _erfcinv(y::BigFloat)
yfloat = Float64(y)
xfloat = erfcinv(yfloat)
Expand Down Expand Up @@ -526,13 +575,9 @@ See also: [`erfcx(x)`](@ref erfcx).
# Implementation
Based on the [`erfc(x)`](@ref erfc) and [`erfcx(x)`](@ref erfcx) functions.
Currently only implemented for `Float32`, `Float64`, and `BigFloat`.
"""
logerfc(x::Real) = _logerfc(float(x))

function _logerfc(x::Union{Float32, Float64, BigFloat})
# Don't include Float16 in the Union, otherwise logerfc would currently work for x <= 0.0, but not x > 0.0
if x > 0.0
function logerfc(x::Real)
if x > zero(x)
return log(erfcx(x)) - x^2
else
return log(erfc(x))
Expand All @@ -557,13 +602,9 @@ See also: [`erfcx(x)`](@ref erfcx).
# Implementation
Based on the [`erfc(x)`](@ref erfc) and [`erfcx(x)`](@ref erfcx) functions.
Currently only implemented for `Float32`, `Float64`, and `BigFloat`.
"""
logerfcx(x::Real) = _logerfcx(float(x))

function _logerfcx(x::Union{Float32, Float64, BigFloat})
# Don't include Float16 in the Union, otherwise logerfc would currently work for x <= 0.0, but not x > 0.0
if x < 0.0
function logerfcx(x::Real)
if x < zero(x)
return log(erfc(x)) + x^2
else
return log(erfcx(x))
Expand Down
93 changes: 38 additions & 55 deletions test/erf.jl
Original file line number Diff line number Diff line change
@@ -1,65 +1,48 @@
@testset "error functions" begin
@testset "real argument" begin
@test erf(Float16(1)) 0.84270079294971486934 rtol=2*eps(Float16)
@test erf(Float32(1)) 0.84270079294971486934 rtol=2*eps(Float32)
@test erf(Float64(1)) 0.84270079294971486934 rtol=2*eps(Float64)

@test erfc(Float16(1)) 0.15729920705028513066 rtol=2*eps(Float16)
@test erfc(Float32(1)) 0.15729920705028513066 rtol=2*eps(Float32)
@test erfc(Float64(1)) 0.15729920705028513066 rtol=2*eps(Float64)

@test erfcx(Float16(1)) 0.42758357615580700442 rtol=2*eps(Float16)
@test erfcx(Float32(1)) 0.42758357615580700442 rtol=2*eps(Float32)
@test erfcx(Float64(1)) 0.42758357615580700442 rtol=2*eps(Float64)

@test_throws MethodError logerfc(Float16(1))
@test_throws MethodError logerfc(Float16(-1))
@test logerfc(Float32(-100)) 0.6931471805599453 rtol=2*eps(Float32)
@test logerfc(Float64(-100)) 0.6931471805599453 rtol=2*eps(Float64)
@test logerfc(Float32(1000)) -1.0000074801207219e6 rtol=2*eps(Float32)
@test logerfc(Float64(1000)) -1.0000074801207219e6 rtol=2*eps(Float64)
@test logerfc(1000) -1.0000074801207219e6 rtol=2*eps(Float32)
@test logerfc(Float32(10000)) log(erfc(BigFloat(10000, precision=100))) rtol=2*eps(Float32)
@test logerfc(Float64(10000)) log(erfc(BigFloat(10000, precision=100))) rtol=2*eps(Float64)

@test_throws MethodError logerfcx(Float16(1))
@test_throws MethodError logerfcx(Float16(-1))
@test iszero(logerfcx(0))
@test logerfcx(Float32(1)) -0.849605509933248248576017509499 rtol=2eps(Float32)
@test logerfcx(Float64(1)) -0.849605509933248248576017509499 rtol=2eps(Float32)
@test logerfcx(Float32(-1)) 1.61123231767807049464268192445 rtol=2eps(Float32)
@test logerfcx(Float64(-1)) 1.61123231767807049464268192445 rtol=2eps(Float32)
@test logerfcx(Float32(-100)) 10000.6931471805599453094172321 rtol=2eps(Float32)
@test logerfcx(Float64(-100)) 10000.6931471805599453094172321 rtol=2eps(Float64)
@test logerfcx(Float32(100)) -5.17758512266433257046678208395 rtol=2eps(Float32)
@test logerfcx(Float64(100)) -5.17758512266433257046678208395 rtol=2eps(Float64)
@test logerfcx(Float32(-1000)) 1.00000069314718055994530941723e6 rtol=2eps(Float32)
@test logerfcx(Float64(-1000)) 1.00000069314718055994530941723e6 rtol=2eps(Float64)
@test logerfcx(Float32(1000)) -7.48012072190621214066734919080 rtol=2eps(Float32)
@test logerfcx(Float64(1000)) -7.48012072190621214066734919080 rtol=2eps(Float64)

@test erfi(Float16(1)) 1.6504257587975428760 rtol=2*eps(Float16)
@test erfi(Float32(1)) 1.6504257587975428760 rtol=2*eps(Float32)
@test erfi(Float64(1)) 1.6504257587975428760 rtol=2*eps(Float64)
for T in (Float16, Float32, Float64)
@test @inferred(erf(T(1))) isa T
@test erf(T(1)) T(0.84270079294971486934) rtol=2*eps(T)

@test erfinv(Integer(0)) == 0 == erfinv(0//1)
@test_throws MethodError erfinv(Float16(1))
@test erfinv(Float32(0.84270079294971486934)) 1 rtol=2*eps(Float32)
@test erfinv(Float64(0.84270079294971486934)) 1 rtol=2*eps(Float64)
@test @inferred(erfc(T(1))) isa T
@test erfc(T(1)) T(0.15729920705028513066) rtol=2*eps(T)

@test erfcinv(Integer(1)) == 0 == erfcinv(1//1)
@test_throws MethodError erfcinv(Float16(1))
@test erfcinv(Float32(0.15729920705028513066)) 1 rtol=2*eps(Float32)
@test erfcinv(Float64(0.15729920705028513066)) 1 rtol=2*eps(Float64)
@test @inferred(erfcx(T(1))) isa T
@test erfcx(T(1)) T(0.42758357615580700442) rtol=2*eps(T)

@test @inferred(logerfc(T(1))) isa T
@test logerfc(T(-100)) T(0.6931471805599453) rtol=2*eps(T)
@test logerfc(T(1000)) T(-1.0000074801207219e6) rtol=2*eps(T)
@test logerfc(T(10000)) T(log(erfc(BigFloat(10000, precision=100)))) rtol=2*eps(T)

@test @inferred(logerfcx(T(1))) isa T
@test logerfcx(T(1)) T(-0.849605509933248248576017509499) rtol=2eps(T)
@test logerfcx(T(-1)) T(1.61123231767807049464268192445) rtol=2eps(T)
@test logerfcx(T(-100)) T(10000.6931471805599453094172321) rtol=2eps(T)
@test logerfcx(T(100)) T(-5.17758512266433257046678208395) rtol=2eps(T)
@test logerfcx(T(-1000)) T(1.00000069314718055994530941723e6) rtol=2eps(T)
@test logerfcx(T(1000)) T(-7.48012072190621214066734919080) rtol=2eps(T)

@test @inferred(erfi(T(1))) isa T
@test erfi(T(1)) T(1.6504257587975428760) rtol=2*eps(T)

@test @inferred(erfinv(T(1))) isa T
@test erfinv(T(0.84270079294971486934)) 1 rtol=2*eps(T)

@test dawson(Float16(1)) 0.53807950691276841914 rtol=2*eps(Float16)
@test dawson(Float32(1)) 0.53807950691276841914 rtol=2*eps(Float32)
@test dawson(Float64(1)) 0.53807950691276841914 rtol=2*eps(Float64)
@test @inferred(erfcinv(T(1))) isa T
@test erfcinv(T(0.15729920705028513066)) 1 rtol=2*eps(T)

@test @inferred(dawson(T(1))) isa T
@test dawson(T(1)) T(0.53807950691276841914) rtol=2*eps(T)

@test @inferred(faddeeva(T(1))) isa Complex{T}
@test faddeeva(T(1)) 0.36787944117144233402+0.60715770584139372446im rtol=2*eps(T)
end

@test logerfc(1000) -1.0000074801207219e6 rtol=2*eps(Float32)
@test erfinv(Integer(0)) == 0 == erfinv(0//1)
@test erfcinv(Integer(1)) == 0 == erfcinv(1//1)
@test faddeeva(0) == faddeeva(0//1) == 1
@test faddeeva(Float16(1)) 0.36787944117144233402+0.60715770584139372446im rtol=2*eps(Float16)
@test faddeeva(Float32(1)) 0.36787944117144233402+0.60715770584139372446im rtol=2*eps(Float32)
@test faddeeva(Float64(1)) 0.36787944117144233402+0.60715770584139372446im rtol=2*eps(Float64)
end

@testset "complex arguments" begin
Expand Down

2 comments on commit 124915f

@devmotion
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/106370

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v2.4.0 -m "<description of version>" 124915fce203925b69fa1a295a2ab3025cbe3f3c
git push origin v2.4.0

Please sign in to comment.