Hellas

Deterministic transcendental functions

Background

We want our inference process to be deterministic: for a given model checkpoint, the same input should produce the same output, bit for bit, on every platform. This requires careful control of which optimizations the CPU/Metal/CUDA/HIP compilers are allowed to make, which platform APIs and primitives can be relied on, and how operations and reductions are ordered during execution. Ordering is a notorious source of nondeterminism because rounding makes floating point addition non-associative, but it matters even in pure low precision integer arithmetic because of saturation and clamping.
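
As a small illustration of the ordering problem, the sketch below adds the same numbers in two different orders using ordinary Python floats (binary64). The helper `ordered_sum` is a name made up for this example; a plain loop is used so that the addition order is explicit.

```python
# Regrouping a sum changes where rounding happens, and so changes the result.
left = (0.1 + 0.2) + 0.3
right = 0.1 + (0.2 + 0.3)
print(left == right)  # False

def ordered_sum(values):
    """Add the values strictly left to right, rounding after every step."""
    total = 0.0
    for v in values:
        total += v
    return total

# The same four numbers reduced in two different orders.
values = [1e16, 1.0, -1e16, 1.0]
forward = ordered_sum(values)
backward = ordered_sum(list(reversed(values)))
print(forward, backward)  # 1.0 0.0
```

A parallel reduction on a GPU is free to pick either order unless the implementation pins it down, which is why the reduction order has to be part of the specification.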

Since all current realistic workloads contain floating point operations (even quantized models keep a few sensitive layers in high precision), we will not discuss integer arithmetic here.

For deterministic execution all the component building blocks of the computation graph must themselves be deterministic. While in a modern transformer model the majority of these are matmuls, there are a few blocks using transcendental functions too:

  • softmax relies on exp and is present in most self-attention and output layers
  • positional embeddings use sin and cos
  • the softplus function in Mamba-like layers is defined as ln(1+e^x)

The problem is that these functions are not guaranteed to produce bitwise identical results on different platforms. Run on a machine with a CUDA device, this snippet will likely print False for all four functions.

import torch

# Create tensor on CPU
a = torch.rand(200)

for f in torch.sin, torch.cos, torch.log, torch.exp:
    print(torch.equal(f(a), f(a.to('cuda:0')).to('cpu')))

This post will look at how to make these transcendental functions have reproducible outputs.

Floating point notions

IEEE 754 - The IEEE Standard for Floating-Point Arithmetic, originally published in 1985, most recently updated in 2019. It defines the floating point formats, operations, rounding modes and exceptions that a conforming implementation should support.

Floating point formats - the bit representations of floating point numbers at different precisions. IEEE 754 describes binary16, binary32, binary64, binary128 and binary256 along with a few decimal formats. binary32 and binary64 used to be called single and double in the original text and correspond to the float and double C types. For deep learning only binary16 and binary32 are relevant.

Rounding modes - the standard describes five rounding modes: toward −∞, toward ∞, toward 0, round to nearest ties away from 0, round to nearest ties to even. The last one is the most commonly used in deep learning. For determinism one should pick it and stick to it.

Correctly rounded - a floating point operation that behaves as if it computed the exact real result and then rounded it according to the chosen rounding mode. This makes it deterministic.

Faithfully rounded - a floating point operation where the returned result is one of the two floating-point numbers neighbouring the exact result. Two faithfully rounded implementations may return different neighbours for the same input, so faithful rounding alone does not give determinism, but the implementation is usually faster than correctly rounded alternatives.
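
The difference between the two notions can be checked mechanically when the exact result is known as a rational number. The sketch below does this for 1/3 in binary64; `is_faithfully_rounded` and `is_correctly_rounded` are hypothetical helper names, and the second one assumes round to nearest without examining how ties are broken. It needs Python 3.9 or newer for `math.nextafter`.

```python
import math
from fractions import Fraction

def is_faithfully_rounded(y: float, exact: Fraction) -> bool:
    """True if y is the exact value or one of the two floats bracketing it."""
    fy = Fraction(y)
    if fy == exact:
        return True
    if fy < exact:
        return Fraction(math.nextafter(y, math.inf)) > exact
    return Fraction(math.nextafter(y, -math.inf)) < exact

def is_correctly_rounded(y: float, exact: Fraction) -> bool:
    """True if neither neighbouring float is strictly closer to the exact value."""
    err = abs(Fraction(y) - exact)
    below = Fraction(math.nextafter(y, -math.inf))
    above = Fraction(math.nextafter(y, math.inf))
    return err <= abs(below - exact) and err <= abs(above - exact)

exact = Fraction(1, 3)
nearest = 1 / 3                            # IEEE 754 division is correctly rounded
other = math.nextafter(nearest, math.inf)  # the float on the other side of 1/3

print(is_faithfully_rounded(nearest, exact), is_correctly_rounded(nearest, exact))
print(is_faithfully_rounded(other, exact), is_correctly_rounded(other, exact))
```

Both floats are acceptable answers for a faithfully rounded division, but only one of them is the correctly rounded answer, which is why two faithful implementations can disagree bitwise.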

Transcendental functions - functions that are not algebraic, meaning they do not satisfy any polynomial equation whose coefficients are themselves polynomials in the argument. They can be trigonometric, exponential, hyperbolic and their inverses, for example sin, cos, exp, ln, tanh.

Elementary functions - in floating point and approximation theory literature these refer to the transcendental and also some algebraic functions like sqrt(1-x^2), basic functions that are not primitive arithmetic operations but are commonly used in numerical computations.

Required operations in IEEE 754: +, -, *, /, sqrt, and starting with the 2008 update FMA (fused multiply-add). The standard requires these to be correctly rounded.

Recommended operations in IEEE 754: most of the elementary functions are recommended but not required to be part of an IEEE 754 compliant implementation.

Implementing transcendental functions

Since these are not required operations in IEEE 754, most libraries do not provide correctly rounded implementations of them. If they did, determinism would be a solved problem for transcendentals.

These functions are usually provided by the platform libm or its equivalents: independent implementations that ship as part of glibc on Linux, LLVM libc and the CUDA and HIP runtime libraries, each tailored to specific platforms with different trade-offs between performance and accuracy. Their outputs are not bitwise identical because of differences in the approximation method used, rounding choices and other implementation details. The CUDA API even has two variants: the very fast intrinsics like __sinf and __cosf, which translate to hardware instructions run on the GPU's Special Function Unit, and the slower but more accurate standard functions like sinf and cosf.

We need reproducible, reasonably performant and reasonably accurate implementations of transcendental functions for deep learning workloads. The reproducibility requirement rules out relying on platform APIs. Using a correctly rounded implementation from projects like CORE-MATH or RLibm, or a portable library like SLEEF, as a starting point is one possibility, but these are more generic and complex than necessary for deep learning, also target high precision scientific computing, and more often than not need porting to non-CPU platforms. The remaining option is to implement the functions from scratch, specifically for our use case.

Polynomial approximation

Transcendental functions do not have exact formulas in terms of primitive operations, so they are usually implemented using approximation methods like polynomial expansions and/or lookup tables. According to the Weierstrass approximation theorem, any function that is continuous on a closed interval can be approximated there to arbitrary precision by a polynomial, and raising the degree allows a better approximation. Briefly, to approximate a given function on a given input range, a function-specific polynomial is picked according to the desired accuracy, then each call to the function evaluates this fixed polynomial on the input. For example, a straightforward and naive implementation of sin and cos would be the sum of the first few terms of their respective Taylor series (called Maclaurin series in this particular case, where the expansion is taken around 0):

\[ \sin(x) = \sum_{n=0}^{\infty} \frac{(-1)^n x^{2n+1}}{(2n+1)!} \]

and

\[ \cos(x) = \sum_{n=0}^{\infty} \frac{(-1)^n x^{2n}}{(2n)!} \]

import math
import torch

def sin_taylor(x):
    return x - x**3/math.factorial(3) + x**5/math.factorial(5) - x**7/math.factorial(7)

def cos_taylor(x):
    return 1 - x**2/math.factorial(2) + x**4/math.factorial(4) - x**6/math.factorial(6)

inputs = torch.linspace(-math.pi/4, math.pi/4, 100)

print(torch.allclose(inputs.sin(), sin_taylor(inputs)))
print(torch.allclose(inputs.cos(), cos_taylor(inputs)))
True
True

If the input interval is small and close to zero, such an approximation can be acceptable, but changing the interval in the script above to [-π/2, π/2] makes the checks fail, and extra terms of the Taylor expansion are needed to keep the errors under control. Production implementations that have to deal with large input ranges need another approach, one that involves better polynomials and mapping the inputs to a smaller range.
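
To put rough numbers on this, the sketch below measures the largest absolute error of the four-term Taylor polynomial for sin on both intervals. It works in binary64 with the standard library only and uses `math.sin` as the reference, which is an assumption that is harmless at this error scale; `max_abs_error` is a helper name invented for the example.

```python
import math

def sin_taylor(x: float) -> float:
    """First four terms of the Taylor series of sin around 0."""
    return x - x**3 / 6 + x**5 / 120 - x**7 / 5040

def max_abs_error(approx, reference, lo: float, hi: float, samples: int = 2001) -> float:
    """Largest |approx(x) - reference(x)| over evenly spaced samples in [lo, hi]."""
    worst = 0.0
    for i in range(samples):
        x = lo + (hi - lo) * i / (samples - 1)
        worst = max(worst, abs(approx(x) - reference(x)))
    return worst

err_quarter = max_abs_error(sin_taylor, math.sin, -math.pi / 4, math.pi / 4)
err_half = max_abs_error(sin_taylor, math.sin, -math.pi / 2, math.pi / 2)
print(f"[-pi/4, pi/4]: {err_quarter:.1e}")
print(f"[-pi/2, pi/2]: {err_half:.1e}")
```

Doubling the width of the interval increases the worst-case error by more than two orders of magnitude, and the error is always largest at the ends of the interval.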

Minimax polynomial

Because the Taylor polynomial becomes increasingly inaccurate as the input moves away from zero, even as more terms of the expansion are used, a so-called minimax polynomial is the usual choice. At the expense of a slightly larger error close to zero, it provides uniform accuracy over the entire input range. It is called minimax because it minimizes the maximum error over the input range. Where the Taylor polynomial approximation error shoots up as we move away from zero, the minimax polynomial error is a small, bounded, sine-like oscillation. The theory behind computing such a polynomial involves Chebyshev nodes, Lagrange interpolation and the Remez algorithm, which is the standard iterative method of producing the coefficients.

Depending on the input range and the accuracy requirements, there are multiple possible minimax polynomials for a given function. The coefficients can be lifted from existing production libraries or computed from scratch using the Remez algorithm. One popular implementation is in the Sollya project, which is both a library and small scripting language for safe floating-point code development.

Here are invocations of Sollya from the command line to compute the minimax polynomials for sin(x) and cos(x):

echo "fpminimax(sin(x), [|1,3,5,7|], [|SG...|], [-pi/4, pi/4], absolute);" | sollya
x * (1 + x^2 * (-0.166666507720947265625 + x^2 * (8.331983350217342376708984375e-3 + x^2 * (-1.94961365195922553539276123046875e-4))))
echo "fpminimax(cos(x), [|0,2,4,6|], [|SG...|], [-pi/4, pi/4], absolute);" | sollya
1 + x^2 * (-0.49999892711639404296875 + x^2 * (4.16561998426914215087890625e-2 + x^2 * (-1.35968066751956939697265625e-3)))

We passed the [-π/4, π/4] range on which to approximate, and the degrees of the polynomials to use. Since sin is an odd function (f(-x) = -f(x)), and cos is an even function (f(-x) = f(x)), their minimax polynomials also have non-zero coefficients only for odd and even powers of x respectively, but they are slightly different from the corresponding Taylor coefficients.
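
The effect of those slightly different coefficients can be measured directly. The sketch below compares the worst-case error of the degree 7 Taylor polynomial with the Sollya polynomial printed above on [-π/4, π/4]. Both are evaluated in binary64 so that mostly the approximation error is visible, and `math.sin` is assumed to be accurate enough to serve as the reference; `worst_error` is a name made up for the example.

```python
import math

def sin_taylor(x: float) -> float:
    """Degree 7 Taylor polynomial of sin around 0."""
    return x - x**3 / 6 + x**5 / 120 - x**7 / 5040

def sin_minimax(x: float) -> float:
    """Degree 7 polynomial with the Sollya coefficients quoted above."""
    c3 = -0.166666507720947265625
    c5 = 8.331983350217342376708984375e-3
    c7 = -1.94961365195922553539276123046875e-4
    x2 = x * x
    return x * (1 + x2 * (c3 + x2 * (c5 + x2 * c7)))

def worst_error(approx, samples: int = 4001) -> float:
    """Largest absolute error against math.sin on [-pi/4, pi/4]."""
    lo, hi = -math.pi / 4, math.pi / 4
    xs = (lo + (hi - lo) * i / (samples - 1) for i in range(samples))
    return max(abs(approx(x) - math.sin(x)) for x in xs)

taylor_err = worst_error(sin_taylor)
minimax_err = worst_error(sin_minimax)
print(f"Taylor:  {taylor_err:.1e}")
print(f"minimax: {minimax_err:.1e}")
```

With the same number of terms, the minimax coefficients spread the error across the interval instead of concentrating it at the endpoints, so the worst case comes out much smaller.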

The output of Sollya is an expression that can be used to evaluate the polynomial for any input value on the given interval. It is of the form \[C0 + x * (C1 + x* (C2 + x * (...)))\] so that the polynomial can be evaluated with fewer multiplications than the naive Python Taylor expressions above. This is known as Horner's scheme. There's a parallelized version known as Estrin's scheme, but for such low degree polynomials as used in most transcendental function implementations with only 4-5 terms, it is rarely justified. On the other hand one frequent optimization is using FMA (fused multiply-add) instead of explicit multiplication and addition operations because the fused operation will only require a single rounding operation instead of two, yielding better precision. FMA is present on most modern CPUs and GPUs so it is recommended to be used consistently in a deterministic implementation.
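
The single rounding of FMA is easy to observe. Python only gained `math.fma` in version 3.13, so the sketch below emulates a binary64 FMA with exact rational arithmetic followed by one conversion to float; `fma_emulated` is a hypothetical helper for illustration, not how a real implementation would do it.

```python
from fractions import Fraction

def fma_emulated(a: float, b: float, c: float) -> float:
    """Compute a*b + c exactly as a rational, then round once to binary64."""
    return float(Fraction(a) * Fraction(b) + Fraction(c))

a = 1.0 + 2.0**-30
b = 1.0 - 2.0**-30
c = -1.0

# Exact product is 1 - 2**-60, which rounds to 1.0 before the addition.
unfused = a * b + c
# The fused version keeps the exact product and rounds only the final sum.
fused = fma_emulated(a, b, c)

print(unfused)  # 0.0
print(fused)    # -8.673617379884035e-19, which is -2**-60
```

Because the two variants can differ in the last bits, a deterministic implementation has to use the same one on every platform, and compilers must be prevented from contracting `a*b + c` into an FMA on their own.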

Here are examples of sin and cos approximations using the minimax polynomials computed above, evaluated using Horner's scheme, one using explicit multiplication and addition operations and the other using FMA:

def sin_approx(x):
    """
    Approximates sin(x) using a minimax polynomial of degree 7 on the reduced interval [-π/4, π/4].
    """
    C1 = 1
    C3 = -0.166666507720947265625
    C5 = 8.331983350217342376708984375e-3
    C7 = -1.94961365195922553539276123046875e-4

    # Horner's scheme for sin
    return x * (C1 + x*x*(C3 + x*x*(C5 + x*x*C7)))

def cos_approx(x):
    """
    Approximates cos(x) using a minimax polynomial of degree 6 on the reduced interval [-π/4, π/4].
    """

    C0 = 1
    C2 = -0.49999892711639404296875
    C4 = 4.16561998426914215087890625e-2
    C6 = -1.35968066751956939697265625e-3

    # Horner's scheme for cos using FMA (math.fma needs Python 3.13+ and a scalar x)
    # return C0 + x*x*(C2 + x*x*(C4 + x*x*C6))
    x2 = x * x
    c = math.fma(x2, C6, C4)
    c = math.fma(x2, c, C2)
    c = math.fma(x2, c, C0)

    return c

Range reduction

Even with minimax polynomials it is impractical to approximate a function over a large input range. To maintain accuracy the degree of the polynomial needs to increase as the range increases, slowing down computation and causing representation issues if coefficients become too large or small. One way around this is piecewise approximation, where the input range is divided into smaller intervals and a different polynomial is used for each one, but this makes the code more complicated.

The standard approach is range reduction:

  • find a small fixed range for the input values where the approximation of the function we are interested in is good enough with small degree polynomials
  • find an algebraic relation that expresses the function using itself called on only input values from this small fixed range
  • implement the function by translating inputs to the small range, approximate the output on that range, and compute the final result based on this reduced approximation by doing an inverse translation

Since they are related periodic functions, sin(x) and cos(x) can be expressed as ±sin(xr) or ±cos(xr) where xr is in the [-π/4, π/4] range.

Let \(x = q\frac{\pi}{2} + r\) where \(r \in \left[-\frac{\pi}{4}, \frac{\pi}{4}\right]\) and \(q \in \mathbb{Z}\). Then, depending on \(q \bmod 4\):

\[\sin(x) = \begin{cases} \sin(r) & q \equiv 0 \pmod{4} \\ \cos(r) & q \equiv 1 \pmod{4} \\ -\sin(r) & q \equiv 2 \pmod{4} \\ -\cos(r) & q \equiv 3 \pmod{4} \end{cases}\]

def reduced_sincos(x):
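    # index of the nearest multiple of π/2 (x is a Python scalar here)
    q = round(x / (math.pi/2))
    # direct subtraction causing catastrophic cancellation
    # xr = x - q * (math.pi/2)
    # use Cody-Waite subtraction instead
    xr = cody_waite_subtract(x, q)
    sin = sin_approx(xr)
    cos = cos_approx(xr)
    # only the quadrant q mod 4 decides which value and sign to return
    match q % 4: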
        case 0:
            return sin, cos
        case 1:
            return cos, -sin
        case 2:
            return -sin, -cos
        case 3:
            return -cos, sin

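The naive reduction xr = x - q * (math.pi/2) can suffer from catastrophic cancellation: the product q * (math.pi/2) is rounded, and for large x that rounding error is comparable to the small difference we want to keep. Most implementations therefore use Cody-Waite reduction: π/2 is expressed as a sum of constants of decreasing magnitude, the leading ones having few enough significant bits that their product with q is exact, and instead of a single subtraction these constants are subtracted one after another.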

def cody_waite_subtract(x, q):
    # P0 + P1 = π/2
    P0 = float.fromhex("0x1.92p+0")
    P1 = float.fromhex("0x1.fb54442d18p-12")
    return (x - q*P0) - q*P1

It is better to express constants as hexadecimal or binary literals to avoid any possible ambiguity in parsing and bit representation of decimal literals. Cody-Waite reduction is a generic approach, used regardless of the reduction interval - for example [-π/2, π/2] is used for approximating tan. When working with double precision there are variants that express the sum using 3 or 4 constants instead of just the 2 here.
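
The exactness that Cody-Waite reduction relies on can be verified with rational arithmetic. The sketch below reuses the two constants from `cody_waite_subtract` and checks, for one sample input in binary64, that the first subtraction introduces no rounding error while the naive product does. The sample value 1000.0 is an arbitrary choice for illustration.

```python
import math
from fractions import Fraction

P0 = float.fromhex("0x1.92p+0")           # 1.5703125, only 8 significant bits
P1 = float.fromhex("0x1.fb54442d18p-12")  # the remaining bits of double π/2
HALF_PI = math.pi / 2

x = 1000.0
q = round(x / HALF_PI)  # 637

# Together the two constants reproduce the double closest to π/2.
print(P0 + P1 == HALF_PI)  # True

# The naive product q * (π/2) needs more than 53 bits, so it gets rounded.
naive_product_exact = Fraction(q * HALF_PI) == q * Fraction(HALF_PI)

# q * P0 fits easily, and the subtraction from x is exact as well.
first_step_exact = Fraction(x - q * P0) == Fraction(x) - q * Fraction(P0)

print(naive_product_exact, first_step_exact)  # False True
print(x - q * P0)  # -0.2890625
```

Only the last, smallest term is rounded, and by then the intermediate result is already close to the final reduced argument, so the rounding error is small relative to it.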

Approximating exponential and logarithm

The same principles apply for log and exp as for the trigonometric functions: find a reduction interval and a formula to map arbitrarily large inputs to that interval, then use polynomial approximation on it. Unlike for the periodic trigonometric functions where these approximated values can be readily used, here we need a reconstruction step to map the values on the small interval back to the full range.

Exponential

For exp we reduce to the interval [0, log(2)) and use a minimax polynomial whose coefficients are computed using Sollya. Any exp value can be expressed as \[ e^x = 2^k e^{x - k \log 2} \] where k = floor(x/log(2)), so that the reduced argument x - k log(2) falls in [0, log(2)).

echo "fpminimax(exp(x), [|0,1,2,3|], [|SG...|], [0, log(2)], absolute);" | sollya
0.9998929500579833984375 + x * (1.0047757625579833984375 + x * (0.4669305980205535888671875 + x * 0.23783318698406219482421875))

def reduced_exp(x):
    # floor (not round) keeps xr in [0, log(2)), the interval the polynomial was fitted on
    k = math.floor(x/math.log(2))
    xr = x - k*math.log(2)
    C0 = 0.9998929500579833984375
    C1 = 1.0047757625579833984375
    C2 = 0.4669305980205535888671875
    C3 = 0.23783318698406219482421875
    p = C0 + xr * (C1 + xr * (C2 + xr * C3))
    # reconstruction step: scale the reduced result by 2^k
    return p * (2**k)

Logarithm

The natural logarithm of a number can be expressed using the mantissa and exponent of its float representation. Writing a positive float as x = m · 2^e, with mantissa m and integer exponent e:

\[ \ln(x) = \ln(m) + e \ln(2) \]

If we use frexp, the mantissa is normalized to [0.5, 1), so that is the reduction range we look for minimax coefficients on:

echo "fpminimax(log(x), [|0,1,2,3|], [|SG...|], [0.5, 1], absolute);" | sollya
-2.1859228610992431640625 + x * (4.22526264190673828125 + x * (-2.9164140224456787109375 + x * 0.877515852451324462890625))

def reduced_log(x):
    assert x > 0
    m,e = math.frexp(x)
    C0 = -2.1859228610992431640625
    C1 = 4.22526264190673828125
    C2 = -2.9164140224456787109375
    C3 = 0.877515852451324462890625
    rl = C0 + m * (C1 + m * (C2 + m * C3))

    return rl + e*math.log(2)
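
As a quick sanity check of the decomposition that the code above relies on, the sketch below splits one sample value with `math.frexp`, confirms the identity, and measures how far the degree 3 polynomial is from `math.log` on the reduced interval. `log_poly` is a hypothetical name for the polynomial part, and `math.log` is assumed to be an adequate reference at this error scale.

```python
import math

def log_poly(m: float) -> float:
    """Degree 3 polynomial for ln(m) on [0.5, 1), Sollya coefficients from above."""
    c0 = -2.1859228610992431640625
    c1 = 4.22526264190673828125
    c2 = -2.9164140224456787109375
    c3 = 0.877515852451324462890625
    return c0 + m * (c1 + m * (c2 + m * c3))

# frexp returns m in [0.5, 1) and an integer e with x == m * 2**e.
m, e = math.frexp(10.0)
print(m, e)  # 0.625 4

# ln(x) = ln(m) + e * ln(2), up to rounding.
split = math.log(m) + e * math.log(2)
print(math.isclose(split, math.log(10.0), rel_tol=1e-12))  # True

# Worst absolute error of the polynomial over samples of the reduced interval.
worst = max(abs(log_poly(0.5 + i / 2000) - math.log(0.5 + i / 2000)) for i in range(1000))
print(f"{worst:.1e}")
```

A cubic is a coarse fit here. A production implementation would likely use a higher degree or a narrower reduction interval to reach the accuracy needed for binary32 outputs.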

Zeros, NaNs and infinities

These edge values need explicit handling because NaN representation can vary across platforms. We should pick one valid NaN bit pattern out of the several available and use it consistently. The checks for NaN and infinity should come first, before any other computation, so that range reduction and approximation work on valid inputs only. The Python snippets above do not include these checks.

  • For sin and cos, Inf input should be treated as NaN and return NaN.
  • For sin, if the input is ±0 return the same sign 0.
  • For cos and exp, if the input is ±0 return 1 directly.
  • For exp, if the input is +Inf return +Inf, and if the input is -Inf return +0.
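
The rules above can be sketched as thin wrappers that run before any range reduction. In the code below `checked_sin`, `checked_exp` and `CANONICAL_NAN` are hypothetical names, the quiet NaN bit pattern is just one possible choice for binary64, and `math.sin` and `math.exp` stand in for the deterministic kernels, which would be passed in instead.

```python
import math
import struct

# One fixed quiet NaN bit pattern for binary64, chosen here only as an example.
CANONICAL_NAN = struct.unpack("<d", struct.pack("<Q", 0x7FF8000000000000))[0]

def checked_sin(x: float, kernel=math.sin) -> float:
    """Handle NaN, infinities and signed zero, then defer to the kernel."""
    if math.isnan(x) or math.isinf(x):
        return CANONICAL_NAN
    if x == 0.0:
        return x  # returns the input itself, so -0.0 stays -0.0
    return kernel(x)

def checked_exp(x: float, kernel=math.exp) -> float:
    """Handle NaN, infinities and zero, then defer to the kernel."""
    if math.isnan(x):
        return CANONICAL_NAN
    if x == math.inf:
        return math.inf
    if x == -math.inf:
        return 0.0
    if x == 0.0:
        return 1.0
    return kernel(x)

print(checked_sin(-0.0), checked_sin(math.inf))      # -0.0 nan
print(checked_exp(-math.inf), checked_exp(math.inf))  # 0.0 inf
```

Returning the input for a zero argument is what preserves the sign, since `-0.0 == 0.0` is true but the two values have different bit patterns.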

Deep learning specific considerations

It makes sense to also implement a function that computes both sine and cosine at the same time. They share the reduction stage and often both values are needed for the same input anyway, as in the case of positional embeddings.

Plain table lookup without polynomial approximation is a good option when the range of possible inputs is known to be small and fixed such as in FP8 or a subset of BF16, although these are not really used for positional embeddings due to loss of accuracy at longer contexts.

The Cody-Waite method cannot accurately compute the reduced argument for sin and cos when inputs are very large (> 2^20). In those cases Payne-Hanek reduction is used instead, but for positional embeddings we're fine with the simpler method.

We can get away with using interval reduction and Taylor series expansion of four terms for sin, cos and exp for running LLMs, but since minimax polynomials can get the same or better accuracy with fewer terms, we prefer to use them instead.

Conclusion

For determinism we must pick a rounding mode, decide whether or not to use FMA, choose the range reduction method, the polynomial approximation method and its coefficients, and the evaluation scheme, and then implement the algorithm in the same way on all target platforms. These choices depend on the input range and the accuracy requirements, and are best validated by benchmarking the options for speed and compliance. There is a wide range of options for each of them, but a safe default is FMA, round to nearest ties to even, a minimax polynomial generated by Sollya and Horner evaluation.

References

Nvidia article on floating point

Correctly Rounded Evaluation of a Function: Why, How, and at What Cost?

Elementary Functions: Algorithms and Implementation, a book by Jean-Michel Muller