Faux Attention

One might contend that attention, as the term is used in transformers, is already misleading and an abuse of the term; calling this faux attention is even more egregious, bordering on meaningless. It makes for a succinct title, however.

Approximation the first

A quadratic approximation of $\exp(x)$ for nonnegative $x$ near zero is

$$f_+(x) = 1 + x + x^2/2.$$

Noting that $\exp(x) = 1/\exp(-x)$ gives us

$$f(x) = \begin{cases} 1/f_+(-x) & \text{when } x < 0 \\ f_+(x) & \text{otherwise} \end{cases} = \begin{cases} 1/(1 - x + x^2/2) & \text{when } x < 0 \\ 1 + x + x^2/2 & \text{otherwise.} \end{cases}$$

Noting that $e^x = (e^{x/a})^a$ gives us

$$g(x) = f(x/4)^4 = \left(f(x/4)^2\right)^2.$$

This can be seen here.

A straightforward implementation parallelizes well:

pub fn vexpps(x : &[f32; 16], y : &mut [f32; 16])
{
  for i in 0..16 {
    let rr = x[i] * 0.25;
    let sq = rr * rr * 0.5;
    let pos = (1.0 + rr) + sq;
    let neg = (1.0 - rr) + sq;
    let fst = if x[i] < 0.0 { neg.recip() } else { pos };
    let snd = fst * fst;
    y[i] = snd * snd;
  }
}
; rustc --edition=2024 --codegen opt-level=3 --codegen target-cpu=znver5
.CONST0: .long 0x3e800000
.CONST1: .long 0x3f000000
.CONST2: .long 0x3f800000
vexpps:
  vmovups zmm0, zmmword ptr [rdi]
  vbroadcastss  zmm2, dword ptr [rip + .CONST2]
  vxorps  xmm5, xmm5, xmm5
  vmulps  zmm1, zmm0, dword ptr [rip + .CONST0]{1to16}
  vcmpltps  k1, zmm0, zmm5
  vmulps  zmm3, zmm1, zmm1
  vmulps  zmm3, zmm3, dword ptr [rip + .CONST1]{1to16}
  vaddps  zmm4, zmm1, zmm2
  vsubps  zmm0, zmm2, zmm1
  vaddps  zmm4, zmm4, zmm3
  vaddps  zmm0, zmm0, zmm3
  vdivps  zmm4  {k1}, zmm2, zmm0
  vmulps  zmm0, zmm4, zmm4
  vmulps  zmm0, zmm0, zmm0
  vmovups zmmword ptr [rsi], zmm0
  vzeroupper
  ret

For comparison, the implementation of exp that I believe Rust uses can be seen here; the compiler does not vectorize it.

The approximation has middling accuracy:

 input        approx       actual      abs error    rel error
−7.000000    +0.002977    +0.000912    +0.002065    +2.264217
−6.000000    +0.005791    +0.002479    +0.003312    +1.336334
−5.000000    +0.011844    +0.006738    +0.005106    +0.757864
−4.000000    +0.025600    +0.018316    +0.007284    +0.397713
−3.000000    +0.058742    +0.049787    +0.008955    +0.179859
−2.000000    +0.143412    +0.135335    +0.008077    +0.059682
−1.000000    +0.371077    +0.367879    +0.003198    +0.008693
 0.000000    +1.000000    +1.000000     0.000000     0.000000
+1.000000    +2.694856    +2.718282    −0.023426    −0.008618
+2.000000    +6.972900    +7.389056    −0.416156    −0.056321
+3.000000   +17.023682   +20.085537    −3.061855    −0.152441
+4.000000   +39.062500   +54.598148   −15.535648    −0.284545
+5.000000   +84.428101  +148.413162   −63.985062    −0.431128
+6.000000  +172.676025  +403.428802  −230.752777    −0.571979
+7.000000  +335.955963 +1096.633179  −760.677246    −0.693648
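The table can be reproduced, at least approximately, with a small harness along the following lines. This is a sketch rather than the code used to produce the table: `exp_quadratic` is a hypothetical scalar restatement of the routine above, the reference value is `f32::exp`, and the printed layout only roughly matches.

```rust
/// Scalar restatement of the first approximation: g(x) = f(x/4)^4, where
/// f(x) = 1/f+(-x) for x < 0 and f+(x) = 1 + x + x^2/2 otherwise.
pub fn exp_quadratic(x: f32) -> f32 {
    let rr = x * 0.25;
    let sq = rr * rr * 0.5;
    let fst = if x < 0.0 { ((1.0 - rr) + sq).recip() } else { (1.0 + rr) + sq };
    let snd = fst * fst;
    snd * snd
}

/// One row per integer input in lo..=hi: [input, approx, actual, abs error, rel error],
/// using f32::exp as the reference value.
pub fn error_rows(approx: fn(f32) -> f32, lo: i32, hi: i32) -> Vec<[f32; 5]> {
    (lo..=hi)
        .map(|i| {
            let x = i as f32;
            let (a, e) = (approx(x), x.exp());
            [x, a, e, a - e, (a - e) / e]
        })
        .collect()
}

/// Prints the rows in a layout similar to the tables in the text.
pub fn print_table(rows: &[[f32; 5]]) {
    println!(" input        approx       actual      abs error    rel error");
    for r in rows {
        println!("{:+.6} {:+12.6} {:+12.6} {:+12.6} {:+12.6}", r[0], r[1], r[2], r[3], r[4]);
    }
}
```

Calling `print_table(&error_rows(exp_quadratic, -7, 7))` prints fifteen rows comparable to the table above; passing a different `fn(f32) -> f32` evaluates any of the other approximations the same way.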

Approximation the second

Adding a third-order term and reducing the range further, seen here,

$$f_+(x) = 1 + x + x^2/2 + x^3/6, \qquad g(x) = f(x/8)^8 = \left(\left(f(x/8)^2\right)^2\right)^2$$

improves this at the cost of three additional multiplications:

 input        approx       actual      abs error    rel error
−7.000000    +0.001006    +0.000912    +0.000095    +0.103716
−6.000000    +0.002628    +0.002479    +0.000149    +0.060299
−5.000000    +0.006951    +0.006738    +0.000213    +0.031565
−4.000000    +0.018574    +0.018316    +0.000259    +0.014124
−3.000000    +0.050031    +0.049787    +0.000244    +0.004906
−2.000000    +0.135480    +0.135335    +0.000145    +0.001068
−1.000000    +0.367906    +0.367879    +0.000027    +0.000073
 0.000000    +1.000000    +1.000000     0.000000     0.000000
+1.000000    +2.718082    +2.718282    −0.000200    −0.000074
+2.000000    +7.381175    +7.389056    −0.007882    −0.001067
+3.000000   +19.987476   +20.085537    −0.098061    −0.004882
+4.000000   +53.837746   +54.598148    −0.760403    −0.013927
+5.000000  +143.871826  +148.413162    −4.541336    −0.030599
+6.000000  +380.485840  +403.428802   −22.942963    −0.056870
+7.000000  +993.582458 +1096.633179  −103.050720    −0.093970
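The text gives no code for this variant. A minimal scalar sketch, assuming it mirrors the first routine, might look like the following; the name `exp_cubic` and the use of Horner's rule are my own choices, so its multiplication count may differ from the one stated above.

```rust
/// Scalar sketch of the second approximation (not code from the text):
/// f+(r) = 1 + r + r^2/2 + r^3/6 with r = |x|/8, then g(x) = f(x/8)^8.
pub fn exp_cubic(x: f32) -> f32 {
    let r = x.abs() * 0.125;
    // Horner form of 1 + r + r^2/2 + r^3/6.
    let p = 1.0 + r * (1.0 + r * (0.5 + r * (1.0 / 6.0)));
    // exp(-y) = 1/exp(y): take the reciprocal for negative inputs.
    let fst = if x < 0.0 { p.recip() } else { p };
    // Three squarings raise f(x/8) to the eighth power.
    let snd = fst * fst;
    let thd = snd * snd;
    thd * thd
}
```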

Approximation the third

However, we can do better if we are willing to abuse the format of IEEE 754 single-precision floating-point numbers. The following is from “A Fast, Compact Approximation of the Exponential Function” by Nicol Schraudolph, which can be read here. My thanks to dear Cosmo, who found this and implemented it.

pub fn vexpps(x : &[f32; 16], y : &mut [f32; 16])
{
  const F : u32 = f32::MANTISSA_DIGITS - 1;
  const A : f32 = (1 << F) as f32 / std::f32::consts::LN_2;
  const B : f32 = (127u32 << F) as f32;

  for i in 0..16 {
    let aff = x[i].mul_add(A, B);
    let rnd = unsafe { aff.to_int_unchecked::<u32>() };
    y[i] = f32::from_bits(rnd);
  }
}
; rustc --edition=2024 --codegen opt-level=3 --codegen target-cpu=znver5
.CONST0: .long 0x4b38aa3b
.CONST1: .long 0x4e7e0000
vexpps:
  vmovups      zmm1, zmmword ptr [rdi]
  vbroadcastss zmm0, dword ptr [rip + .CONST0]
  vfmadd213ps  zmm0, zmm1, dword ptr [rip + .CONST1]{1to16}
  vcvttps2udq  zmm0, zmm0
  vmovups      zmmword ptr [rsi], zmm0
  vzeroupper
  ret

 input        approx       actual      abs error    rel error
−7.000000    +0.000928    +0.000912    +0.000016    +0.017994
−6.000000    +0.002625    +0.002479    +0.000146    +0.058864
−5.000000    +0.006979    +0.006738    +0.000241    +0.035716
−4.000000    +0.019207    +0.018316    +0.000891    +0.048641
−3.000000    +0.052247    +0.049787    +0.002460    +0.049415
−2.000000    +0.139326    +0.135335    +0.003991    +0.029488
−1.000000    +0.389326    +0.367879    +0.021447    +0.058298
 0.000000    +1.000000    +1.000000     0.000000     0.000000
+1.000000    +2.885376    +2.718282    +0.167094    +0.061471
+2.000000    +7.541565    +7.389056    +0.152509    +0.020640
+3.000000   +21.249268   +20.085537    +1.163731    +0.057939
+4.000000   +56.665039   +54.598148    +2.066891    +0.037856
+5.000000  +155.324219  +148.413162    +6.911057    +0.046566
+6.000000  +423.980469  +403.428802   +20.551666    +0.050942
+7.000000 +1125.234375 +1096.633179   +28.601196    +0.026081

This is effectively a piecewise linear function, as can be seen here. Intuitively, this is because an IEEE 754 single-precision floating-point number is laid out as 1 sign bit (most significant), 8 exponent bits, and 23 fraction bits (least significant). There are 24 bits in the significand but only 23 bits are stored, because for normal numbers the most-significant bit is always 1 and can therefore be implicit. The routine computes $x \times 2^{23}/\log 2 + 127 \times 2^{23}$, an affine transformation of $x$, and reïnterprets the result as a floating-point number; as $x$ increases, the significand increases linearly; eventually, the fraction field overflows and carries into the exponent field, doubling the gradient. This repeats as $x$ continues to increase. We can make this precise:

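$$u = \lfloor (x/\log 2 + 127) \times 2^{23} \rfloor, \qquad s = 1 + (u \bmod 2^{23})/2^{23}, \qquad p = \lfloor u/2^{23} \rfloor - 127, \qquad f(x) = s \times 2^p$$

After simplifying (and ignoring the rounding of $u$ to an integer):

$$s = 1 + ((x/\log 2) \bmod 1), \qquad p = \lfloor x/\log 2 \rfloor$$

And so we have

$$f(x) = \left(1 + ((x/\log 2) \bmod 1)\right) \times \exp(\lfloor x/\log 2 \rfloor \times \log 2),$$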

as can be seen here.
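To check the derivation numerically, one can decode the output of the bit trick into its fields and compare them with the closed forms for $s$ and $p$. A sketch, assuming a safe scalar restatement of the routine; `exp_schraudolph`, `fields`, and `closed_form` are hypothetical helpers, and the saturating `as u32` cast stands in for `to_int_unchecked`.

```rust
const F: u32 = f32::MANTISSA_DIGITS - 1; // 23 stored fraction bits
const A: f32 = (1u32 << F) as f32 / std::f32::consts::LN_2;
const B: f32 = (127u32 << F) as f32;

/// Safe scalar restatement of the bit trick.
pub fn exp_schraudolph(x: f32) -> f32 {
    f32::from_bits(x.mul_add(A, B) as u32)
}

/// Decodes a positive normal f32 into (significand s in [1, 2), unbiased exponent p).
pub fn fields(y: f32) -> (f64, i32) {
    let bits = y.to_bits();
    let p = ((bits >> F) & 0xff) as i32 - 127;
    let s = 1.0 + (bits & ((1u32 << F) - 1)) as f64 / (1u32 << F) as f64;
    (s, p)
}

/// The closed forms: s = 1 + ((x / log 2) mod 1) and p = floor(x / log 2).
pub fn closed_form(x: f64) -> (f64, i32) {
    let t = x / std::f64::consts::LN_2;
    (1.0 + t.rem_euclid(1.0), t.floor() as i32)
}
```

The exponents agree exactly, while the significands differ only in the last few digits, because the affine step is rounded to the nearest representable f32 before the bits are reïnterpreted.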

Because $f$ is piecewise linear, its derivative, defined away from the breakpoints $x = k \log 2$, is a step function,

$$\frac{d}{dx} f(x) = \frac{1}{\log 2} \times \exp(\lfloor x/\log 2 \rfloor \times \log 2),$$

as can be seen here.
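A quick numerical check of the step-function derivative, as a sketch in f64 with hypothetical helper names: a central difference agrees with it away from the breakpoints $k \log 2$, and the slope doubles across each one.

```rust
const LN_2: f64 = std::f64::consts::LN_2;

/// The piecewise linear f(x) = (1 + ((x / log 2) mod 1)) * 2^floor(x / log 2).
pub fn f(x: f64) -> f64 {
    let t = x / LN_2;
    (1.0 + t.rem_euclid(1.0)) * t.floor().exp2()
}

/// Its step-function derivative, 2^floor(x / log 2) / log 2 (away from the breakpoints).
pub fn df(x: f64) -> f64 {
    (x / LN_2).floor().exp2() / LN_2
}

/// Central difference; exact up to rounding when [x - h, x + h] holds no breakpoint.
pub fn central_diff(x: f64, h: f64) -> f64 {
    (f(x + h) - f(x - h)) / (2.0 * h)
}
```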

Note that the error is not monotonic on either side of zero, and there are regions where even the quadratic approximation has smaller error. The first two approximations are usually better than this one near zero, but this one behaves well over the entire domain or, more precisely, over the preimage of the positive floating-point values that are neither subnormal nor infinite.
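That domain can be worked out from the exponent field: the output is a positive normal number when $x/\log 2 + 127$ lies in $[1, 255)$, that is, for $x$ in $[-126 \log 2, 128 \log 2) \approx [-87.34, 88.72)$. Below about $-127 \log 2 \approx -88.03$ the affine value is negative, and `to_int_unchecked::<u32>` in the routine above is then undefined behaviour, since Rust requires the truncated value to be representable in the target type. The sketch below (hypothetical helper names) checks these bounds with a safe saturating cast, which maps such inputs to zero instead.

```rust
const A: f32 = (1u32 << 23) as f32 / std::f32::consts::LN_2;
const B: f32 = (127u32 << 23) as f32;

/// Safe scalar form of the bit trick (saturating cast instead of to_int_unchecked).
pub fn exp_schraudolph(x: f32) -> f32 {
    f32::from_bits(x.mul_add(A, B) as u32)
}

/// Inputs whose biased exponent x / log 2 + 127 lies in [1, 255), i.e.
/// x in [-126 log 2, 128 log 2), give positive normal outputs.
pub fn normal_range() -> (f32, f32) {
    let ln2 = std::f32::consts::LN_2;
    (-126.0 * ln2, 128.0 * ln2)
}

/// True when the affine value x * A + B is negative, which happens below
/// about -127 log 2; `as u32` then saturates to 0.
pub fn affine_is_negative(x: f32) -> bool {
    x.mul_add(A, B) < 0.0
}
```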

Approximation the fourth

We can improve the accuracy of this approximation by again making use of the fact that $\exp(x) = 1/\exp(-x)$. If $f$ is the piecewise linear function, then

$$g(x) = \left(f(x) + 1/f(-x)\right)/2.$$

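This can be optimized to the following, which folds both halvings into the exponent bias: adding $(127 - 1) \times 2^{23}$ rather than $127 \times 2^{23}$ yields $f(x)/2$, and adding $(127 + 1) \times 2^{23}$ to $-x \times 2^{23}/\log 2$ yields $2f(-x)$, whose reciprocal is $1/(2f(-x))$, so a single addition produces $g(x)$. My thanks to Jonathan Hallström, who made this observation and implemented it. (The use of intrinsics is unfortunately necessary to coax the compiler into emitting vrcp14ps, which computes an approximate reciprocal with relative error below $2^{-14}$.)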

pub fn vexpps(x : &[f32; 16], y : &mut [f32; 16])
{
  unsafe {
    use std::arch::x86_64::{  // I think you mean “std::arch::amd64”
      __m512              as f32x16,
      _mm512_set1_ps      as broadcast,
      _mm512_fmadd_ps     as mul_add,
      _mm512_fnmadd_ps    as neg_mul_add,
      _mm512_cvtps_epi32  as convert,
      _mm512_castsi512_ps as transmute,
      _mm512_rcp14_ps     as recip,
      _mm512_add_ps       as add
    };
    const F  : u32 = f32::MANTISSA_DIGITS - 1;
    const A  : f32 = (1 << F) as f32 / std::f32::consts::LN_2;
    const B0 : f32 = ((127u32 - 1) << F) as f32;
    const B1 : f32 = ((127u32 + 1) << F) as f32;

    let a  = broadcast(A);
    let b0 = broadcast(B0);
    let b1 = broadcast(B1);

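    // This dereference assumes `x` is 64-byte aligned (it compiles to vmovaps,
    // which faults otherwise); `_mm512_loadu_ps(x.as_ptr())` has no such requirement.
    let xs = *(x.as_ptr() as *const f32x16);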

    let aff0 = mul_add(xs, a, b0);
    let rnd0 = transmute(convert(aff0));

    let aff1 = neg_mul_add(xs, a, b1);
    let rnd1 = transmute(convert(aff1));
    let inv1 = recip(rnd1);

    let avg = add(rnd0, inv1);
    *y = *(&raw const avg as *const [f32; 16]);
  }
}
; rustc --edition=2024 --codegen opt-level=3 --codegen target-cpu=znver5
.CONST0: .long 0x4b38aa3b
.CONST1: .long 0x4e7c0000
.CONST2: .long 0x4e800000
vexpps:
  vmovaps      zmm0, zmmword ptr [rdi]
  vbroadcastss zmm1, dword ptr [rip + .CONST0]
  vbroadcastss zmm2, dword ptr [rip + .CONST1]
  vfmadd231ps  zmm2, zmm0, zmm1
  vfnmadd213ps zmm1, zmm0, dword ptr [rip + .CONST2]{1to16}
  vcvtps2dq    zmm2, zmm2
  vcvtps2dq    zmm0, zmm1
  vrcp14ps     zmm0, zmm0
  vaddps       zmm0, zmm0, zmm2
  vmovups      zmmword ptr [rsi], zmm0
  vzeroupper
  ret

 input        approx       actual      abs error    rel error
−7.000000    +0.000909    +0.000912    −0.000003    −0.003707
−6.000000    +0.002492    +0.002479    +0.000013    +0.005186
−5.000000    +0.006708    +0.006738    −0.000030    −0.004385
−4.000000    +0.018427    +0.018316    +0.000111    +0.006081
−3.000000    +0.049653    +0.049787    −0.000134    −0.002685
−2.000000    +0.135962    +0.135335    +0.000627    +0.004634
−1.000000    +0.367951    +0.367879    +0.000072    +0.000196
 0.000000    +1.000000    +1.000000     0.000000     0.000000
+1.000000    +2.726982    +2.718282    +0.008700    +0.003201
+2.000000    +7.359528    +7.389056    −0.029529    −0.003996
+3.000000   +20.194458   +20.085537    +0.108921    +0.005423
+4.000000   +54.365723   +54.598148    −0.232426    −0.004257
+5.000000  +149.310547  +148.413162    +0.897385    +0.006047
+6.000000  +402.488281  +403.428802    −0.940521    −0.002331
+7.000000 +1101.234375 +1096.633179    +4.601196    +0.004196

Adjusting the constants can reduce the error a skosh further, but zero then no longer maps to one.
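On targets without AVX-512, the same approximation can be written in portable scalar Rust. A sketch, assuming exact division in place of vrcp14ps (so the last digits can differ slightly from the table) and `round_ties_even` to mimic the round-to-nearest conversion of vcvtps2dq; `exp_symmetric` is a hypothetical name.

```rust
const F: u32 = f32::MANTISSA_DIGITS - 1;
const A: f32 = (1u32 << F) as f32 / std::f32::consts::LN_2;
const B0: f32 = ((127u32 - 1) << F) as f32; // bias - 1: the bits decode to f(x) / 2
const B1: f32 = ((127u32 + 1) << F) as f32; // bias + 1: the bits decode to 2 f(-x)

/// Portable scalar sketch of g(x) = (f(x) + 1/f(-x)) / 2.
pub fn exp_symmetric(x: f32) -> f32 {
    // round_ties_even mimics vcvtps2dq under the default rounding mode;
    // the saturating `as u32` then reinterprets the integer as f32 bits.
    let half_f = f32::from_bits(x.mul_add(A, B0).round_ties_even() as u32);
    let twice_f_neg = f32::from_bits((-x).mul_add(A, B1).round_ties_even() as u32);
    // Exact reciprocal, where the AVX-512 routine uses the vrcp14ps estimate.
    half_f + twice_f_neg.recip()
}
```

Zero still maps exactly to one: the two halves decode to 0.5 and 2.0, and 0.5 + 1/2.0 = 1.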

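We can tidy up $g(x)$ and then find its derivative. We start by expanding $1/f(-x)$ like so, using $(-y) \bmod 1 = 1 - (y \bmod 1)$ and $\lfloor -y \rfloor = -1 - \lfloor y \rfloor$ for non-integer $y = x/\log 2$:

$$\left(1 + ((-x/\log 2) \bmod 1)\right)^{-1} \times \exp(\lfloor -x/\log 2 \rfloor \times \log 2)^{-1} = \left(2 - ((x/\log 2) \bmod 1)\right)^{-1} \times \exp(-\log 2 - \lfloor x/\log 2 \rfloor \times \log 2)^{-1}$$

Abbreviating $x_m = (x/\log 2) \bmod 1$ and $x_t = \lfloor x/\log 2 \rfloor$ and continuing to simplify:

$$= (2 - x_m)^{-1} \times \exp(\log 2 + x_t \times \log 2) = 2 \times (2 - x_m)^{-1} \times \exp(x_t \times \log 2)$$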

Then $g(x) = (f(x) + 1/f(-x))/2$ is

$$\tfrac{1}{2} \times (1 + x_m) \times \exp(x_t \times \log 2) + \tfrac{1}{2} \times 2 \times (2 - x_m)^{-1} \times \exp(x_t \times \log 2),$$

and factoring the common term,

$$g(x) = \left(\frac{1 + x_m}{2} + \frac{1}{2 - x_m}\right) \times \exp(x_t \times \log 2).$$

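Since $dx_m/dx = 1/\log 2$ away from the breakpoints $x = k \log 2$, the derivative of the first factor is

$$\frac{d}{dx}\left(\frac{1 + x_m}{2} + \frac{1}{2 - x_m}\right) = \frac{1}{2 \log 2} + \frac{1}{\log 2 \times (2 - x_m)^2} = \frac{1}{\log 2}\left(\frac{1}{2} + \frac{1}{(2 - x_m)^2}\right).$$

The second factor is constant between breakpoints, so its derivative is simply zero there.

The derivative of their product is then

$$\frac{d}{dx} g(x) = \left(\frac{1}{2} + \frac{1}{(2 - x_m)^2}\right) \times \exp(x_t \times \log 2)/\log 2,$$

which can be seen here; substituting $x_m$ and $x_t$ with their definitions, this is

$$\left(\frac{1}{2} + \frac{1}{\left(2 - ((x/\log 2) \bmod 1)\right)^2}\right) \times \exp(\lfloor x/\log 2 \rfloor \times \log 2)/\log 2.$$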
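The closed forms for $g$ and its derivative can be checked against each other numerically. A sketch in f64 (the helpers `parts`, `g`, and `dg` are hypothetical) that evaluates the idealized formulas rather than the f32 routine, and compares `dg` with a central difference at points away from the breakpoints:

```rust
const LN_2: f64 = std::f64::consts::LN_2;

/// (x_m, x_t) = ((x / log 2) mod 1, floor(x / log 2)).
fn parts(x: f64) -> (f64, f64) {
    let t = x / LN_2;
    (t.rem_euclid(1.0), t.floor())
}

/// g(x) = ((1 + x_m) / 2 + 1 / (2 - x_m)) * exp(x_t * log 2).
pub fn g(x: f64) -> f64 {
    let (xm, xt) = parts(x);
    ((1.0 + xm) / 2.0 + 1.0 / (2.0 - xm)) * xt.exp2()
}

/// g'(x) = (1/2 + 1 / (2 - x_m)^2) * exp(x_t * log 2) / log 2.
pub fn dg(x: f64) -> f64 {
    let (xm, xt) = parts(x);
    (0.5 + 1.0 / ((2.0 - xm) * (2.0 - xm))) * xt.exp2() / LN_2
}
```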

This is evidently continuous at $x \neq k \log 2$ for integer $k$; we will show that this derivative is also continuous at $x = k \log 2$, establishing that the derivative is continuous everywhere and that $g(x)$ is $C^1$ smooth.

First we consider the left-hand limit, writing $a$ for $k \log 2$:

$$\begin{aligned}
\lim_{x \to a^-} ((x/\log 2) \bmod 1) &= 1 \\
\lim_{x \to a^-} \lfloor x/\log 2 \rfloor &= k - 1 \\
\lim_{x \to a^-} \frac{d}{dx} g(x) &= \left(\frac{1}{2} + \frac{1}{(2 - 1)^2}\right) \times \exp((k - 1) \times \log 2)/\log 2 \\
&= \frac{3}{2 \log 2} \exp((k - 1) \times \log 2) \\
&= \frac{3}{2 \log 2} \times \frac{\exp(k \log 2)}{\exp(\log 2)} \\
&= \frac{3}{4 \log 2} \exp(k \log 2).
\end{aligned}$$

Now the right-hand limit:

$$\begin{aligned}
\lim_{x \to a^+} ((x/\log 2) \bmod 1) &= 0 \\
\lim_{x \to a^+} \lfloor x/\log 2 \rfloor &= k \\
\lim_{x \to a^+} \frac{d}{dx} g(x) &= \left(\frac{1}{2} + \frac{1}{(2 - 0)^2}\right) \times \exp(k \times \log 2)/\log 2 \\
&= \frac{3}{4 \log 2} \exp(k \log 2).
\end{aligned}$$

These are equal, so the limit exists, and it is equal to the value of the derivative formula at $k \log 2$, where $x_m = 0$ and $x_t = k$; so the derivative is continuous there.
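The argument also needs $g$ itself to be continuous at $a$; otherwise matching one-sided limits of the derivative would not imply that $g$ is differentiable there. Using the same one-sided limits of $x_m$ and $x_t$ in the factored form of $g$:

```latex
\begin{aligned}
\lim_{x \to a^-} g(x) &= \left(\frac{1 + 1}{2} + \frac{1}{2 - 1}\right) \times \exp((k - 1) \times \log 2) = 2 \times 2^{k-1} = 2^k \\
\lim_{x \to a^+} g(x) &= \left(\frac{1 + 0}{2} + \frac{1}{2 - 0}\right) \times \exp(k \times \log 2) = 2^k = g(a)
\end{aligned}
```

Since $g$ is continuous at $a$ and its derivative has the same limit from both sides, the mean value theorem gives that $g$ is differentiable at $a$ with $g'(a) = \frac{3}{4 \log 2} \exp(k \log 2)$.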

The Proposition

Won’t you propagate our gradients with me,
my love?

I was interested in finding an approximation of the exponential function so that I could perform interpolated arg max (“softmax”) without unreasonable overhead. I had in mind the following architecture.

The typical NNUE begins with two vectors, $i^w$ and $i^b$, of ones and zeroes which encode the positions of white’s and black’s pieces. There are then two sets of weights, the first-layer matrixes $W_1^s$ and $W_1^p$, that lead to the vectors $\sigma_1^w = W_1^s i^w + W_1^p i^b$ and $\sigma_1^b = W_1^p i^w + W_1^s i^b$. The $s$ and $p$ stand for same side and opposite side.

These are relabelled as $\sigma_1^m = \sigma_1^w$ and $\sigma_1^r = \sigma_1^b$ if white is the side to move and black is the side in repose, or vice versa, as $\sigma_1^m = \sigma_1^b$ and $\sigma_1^r = \sigma_1^w$, if black is the side to move and white is the side in repose.

Then an activation function $f_1$ is applied elementwise to yield the vectors $a_1^m = f_1(\sigma_1^m)$ and $a_1^r = f_1(\sigma_1^r)$. These are then reduced by pairwise multiplication: suppose the number of elements of $a_1^m$ or $a_1^r$ is $n_1$; then $x_1^m = a_1^m[1, \ldots, n_1/2] \odot a_1^m[n_1/2 + 1, \ldots, n_1]$ and $x_1^r = a_1^r[1, \ldots, n_1/2] \odot a_1^r[n_1/2 + 1, \ldots, n_1]$, where $\odot$ denotes elementwise multiplication. We then concatenate these to form $x_1 = x_1^m \Vert x_1^r$.
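The first layer can be sketched as follows. This is only an illustration under stated assumptions: dense `Vec<Vec<f32>>` matrices stand in for whatever sparse, incrementally updated accumulators a real NNUE uses, the helper names are hypothetical, and clipping to $[0, 1]$ is merely a placeholder for $f_1$, which the text leaves unspecified.

```rust
/// Dense matrix-vector product: each row of `w` dotted with `v`.
fn matvec(w: &[Vec<f32>], v: &[f32]) -> Vec<f32> {
    w.iter().map(|row| row.iter().zip(v).map(|(a, b)| a * b).sum::<f32>()).collect()
}

fn add(a: &[f32], b: &[f32]) -> Vec<f32> {
    a.iter().zip(b).map(|(x, y)| x + y).collect()
}

/// Pairwise multiplication: the first half of `a` times its second half, elementwise.
fn pairwise(a: &[f32]) -> Vec<f32> {
    let (lo, hi) = a.split_at(a.len() / 2);
    lo.iter().zip(hi).map(|(x, y)| x * y).collect()
}

/// Returns x_1 = x_1^m || x_1^r for the given inputs and side to move.
pub fn first_layer(
    w1s: &[Vec<f32>],
    w1p: &[Vec<f32>],
    iw: &[f32],
    ib: &[f32],
    white_to_move: bool,
) -> Vec<f32> {
    // Placeholder activation f_1 (an assumption): clip each element to [0, 1].
    let f1 = |v: Vec<f32>| -> Vec<f32> { v.into_iter().map(|s| s.clamp(0.0, 1.0)).collect() };
    let sigma_w = add(&matvec(w1s, iw), &matvec(w1p, ib));
    let sigma_b = add(&matvec(w1p, iw), &matvec(w1s, ib));
    // Relabel by side to move (m) and side in repose (r).
    let (sigma_m, sigma_r) = if white_to_move { (sigma_w, sigma_b) } else { (sigma_b, sigma_w) };
    let mut x1 = pairwise(&f1(sigma_m)); // x_1^m
    x1.extend(pairwise(&f1(sigma_r))); // concatenated with x_1^r
    x1
}
```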

Now the idea: there are two sets of weights, the second-layer matrixes $W_2^v$ and $W_2^k$, that lead to the vectors $\sigma_2^v = W_2^v x_1$ and $\sigma_2^k = W_2^k x_1$ of equal dimension.

Then exponentiation is performed elementwise on $\sigma_2^k$ to yield $a_2^k = \exp(\sigma_2^k)$, and this is then $\ell_1$-normalized so that we obtain $k_2 = a_2^k / \lVert a_2^k \rVert_1$ or $k_2 = s \, a_2^k$ where

$$s = \left(\sum_i a_2^k[i]\right)^{-1}$$

since each element of $a_2^k$ is positive. (This is the interpolated arg max.)

We then perform pairwise multiplication and obtain $x_2 = \sigma_2^v \odot k_2$.

Finally the elements of $x_2$ are summed (note that this can indeed simply be a sum rather than a weighted sum or dot product) and the output of the network is then

$$\sigma_3 = \sum_i x_2[i].$$

Here the interpolated arg max and pairwise multiplication function as “soft output heads”. This was Cosmo’s observation, although she used the common term “output buckets”, which I simply refuse to adopt. But one might instead insert additional layers after x 1 x 2 before the output.

One might also imagine applying an activation function elementwise to $\sigma_2^v$ before pairwise multiplication: $a_2^v = f_2(\sigma_2^v)$ and then $x_2 = a_2^v \odot k_2$. I would be curious to first see the result without $f_2$, however.
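The remainder of the proposed network, up to $\sigma_3$ and without $f_2$, might look like the following sketch. As before, the dense matrices and helper names are hypothetical. The softmax subtracts the maximum of $\sigma_2^k$ before exponentiating, which leaves $k_2$ unchanged mathematically but keeps the exponentials from overflowing; `f32::exp` could be replaced by one of the approximations above.

```rust
fn matvec(w: &[Vec<f32>], v: &[f32]) -> Vec<f32> {
    w.iter().map(|row| row.iter().zip(v).map(|(a, b)| a * b).sum::<f32>()).collect()
}

/// Interpolated arg max: k_2 = exp(sigma) / ||exp(sigma)||_1.
fn softmax(sigma: &[f32]) -> Vec<f32> {
    let max = sigma.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let a: Vec<f32> = sigma.iter().map(|&s| (s - max).exp()).collect();
    let s = 1.0 / a.iter().sum::<f32>(); // every element of `a` is positive
    a.iter().map(|&ai| s * ai).collect()
}

/// Soft output heads: sigma_3 = sum_i (W_2^v x_1)[i] * k_2[i].
pub fn soft_heads(w2v: &[Vec<f32>], w2k: &[Vec<f32>], x1: &[f32]) -> f32 {
    let sigma_v = matvec(w2v, x1);
    let k2 = softmax(&matvec(w2k, x1));
    sigma_v.iter().zip(&k2).map(|(v, k)| v * k).sum()
}
```

With equal keys every head receives the same weight and the output is the mean of $\sigma_2^v$; a single dominant key makes the output collapse to the corresponding element of $\sigma_2^v$.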