<10^{-9}$.

# Theory

## The Quantization Barrier

Consider a concrete example in the FP8E4M3 format and examine what happens when we try to add a small update to an FP8 parameter.

**Example:** Take the FP8E4M3 value:

$$\theta = 1.5 \times 2^{-1} = 0.75$$

In FP8E4M3 binary representation this is $s=0$, $e=6$ (biased), $m=100_2$ (i.e., $m=4$).

The next representable FP8E4M3 value is:

$$\theta_{\text{next}} = 1.625 \times 2^{-1} = 0.8125$$

This means the smallest change we can represent at $\theta$ is:

$$\Delta_{\text{min}} = 0.8125 - 0.75 = 0.0625$$

Now suppose we have a small gradient update:

$$\delta = 0.03$$

When we compute $\theta + \delta = 0.75 + 0.03 = 0.78$ and round to FP8E4M3, we must choose between:

- $0.75$ (distance: $0.03$)
- $0.8125$ (distance: $0.0325$)

Since $0.03 < 0.0325$, the result rounds back to $0.75$, and the update vanishes entirely:

$$\text{round}_{\text{FP8}}(0.75 + 0.03) = 0.75$$

More generally, with round-to-nearest, any update smaller than half a step, $|\delta| < \Delta_{\text{min}}/2 = 0.03125$, is lost at this value of $\theta$.

## Collective Precision

The learning rate can be interpreted in one of two ways: as scaling the magnitude of the overall update, or as scaling the magnitude of each parameter's individual update. In exact arithmetic these are the same thing:

$$\|\gamma \delta\| = \gamma \|\delta\|$$

In low precision, however, they are not. Because of the rounding issue above, a small learning rate can zero out a large fraction of the update's elements. Some collective properties of the tensor, such as its magnitude, nonetheless have higher precision than the individual elements: because the norm aggregates many elements, it can be adjusted in much finer increments than any single element can. This property can be exploited to derive an update algorithm that attempts to preserve the magnitude of the overall update rather than the magnitude of each individual parameter's update.

## Algorithm

To mitigate this issue, we can simply binary search over the scale applied to the update: add the update at the current candidate scale, round to the parameter dtype, measure the magnitude of the change actually realized, and adjust the scale until that magnitude matches the target.
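The quantization barrier that this search works around can be reproduced with JAX's `jnp.float8_e4m3fn` dtype. The sketch below is a minimal check, assuming the float32-to-FP8 cast rounds to nearest as in the worked example: it repeats the scalar example, then measures the fraction of a random FP8 tensor that a small naive update leaves unchanged. `round_add`, the tensor shape, and the learning rate `1e-3` are illustrative choices for this sketch, not part of the algorithm or the experiments below.

```python
import jax
import jax.numpy as jnp

FP8 = jnp.float8_e4m3fn


def round_add(theta: jax.Array, delta) -> jax.Array:
    # Hypothetical helper: add in float32, then round the sum back to FP8
    # (the cast is assumed to round to nearest, as in the worked example)
    return (theta.astype(jnp.float32) + delta).astype(FP8)


theta = jnp.asarray(0.75, dtype=FP8)
print(float(round_add(theta, 0.03).astype(jnp.float32)))  # 0.75: the update vanishes
print(float(round_add(theta, 0.04).astype(jnp.float32)))  # 0.8125: one full step of Delta_min

# Fraction of elements of a random FP8 tensor that a small naive update leaves unchanged
param_key, update_key = jax.random.split(jax.random.PRNGKey(0))
params = jax.random.normal(param_key, (1024, 1024), jnp.float32).astype(FP8)
update = jax.random.normal(update_key, (1024, 1024), jnp.float32)
new_params = round_add(params, 1e-3 * update)
unchanged = jnp.mean(new_params.astype(jnp.float32) == params.astype(jnp.float32))
print(float(unchanged))
```

At a learning rate this small, nearly all elements should round back to their old values, which is exactly the loss of update magnitude the binary search is meant to recover.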
```python
import jax
import jax.numpy as jnp


def sq_norm(x):
    # Squared L2 norm, accumulated in float32
    return jnp.sum(jnp.square(x.astype(jnp.float32)))


def precision_aware_update(
    params: jax.Array,
    update: jax.Array,
    target_lr: float | jax.Array,
    rtol: float | None = 1e-2,
    max_iters: int = 32,
):
    """
    Apply an update scaled so that the norm of the change actually realized
    after rounding matches the norm of target_lr * update, despite precision limits.

    Args:
        params: Current (low-precision) parameters
        update: Normalized update direction (e.g., from Muon)
        target_lr: Desired learning rate; the target change is ||target_lr * update||
        rtol: Relative tolerance on the realized squared norm; None disables early exit
        max_iters: Maximum binary search iterations

    Returns:
        new_params, actual_norm, target_norm (squared norms of the realized and target changes)
    """
    # Binary search for the scaling factor alpha
    alpha_low = jnp.float32(0.0)
    alpha_high = jnp.float32(1.0)  # Assuming a normalized update

    # Squared norm of the full-precision update: ||lr * u||^2 = lr^2 * ||u||^2
    target_norm = jnp.square(target_lr) * sq_norm(update)
    old_params = params.astype(jnp.float32)

    for _ in range(max_iters):
        alpha = (alpha_low + alpha_high) / 2

        # Apply the update, round back to the parameter dtype, and measure the actual change
        new_params = (params.astype(update.dtype) + alpha * update).astype(params.dtype)
        actual_norm = sq_norm(new_params.astype(jnp.float32) - old_params)

        # Check if we're close enough
        relative_error = jnp.abs(actual_norm - target_norm) / target_norm

        # Early exit (only works outside jit; pass rtol=None when jitting)
        if rtol is not None and relative_error < rtol:
            break

        # Realized change too small -> raise alpha, otherwise lower it
        alpha_low = jnp.where(actual_norm < target_norm, alpha, alpha_low)
        alpha_high = jnp.where(actual_norm < target_norm, alpha_high, alpha)

    return new_params, actual_norm, target_norm
```

# Empirics

## Precision

First, let's generate two random tensors: one in FP8 representing the parameters and one in FP32 representing the gradient.
```python
import jax
import jax.numpy as jnp

# Split the key so the parameters and the update are independent draws
param_key, update_key = jax.random.split(jax.random.PRNGKey(0))
params = jax.random.normal(param_key, (1024, 1024), jnp.float32).astype(jnp.float8_e4m3fn)
update = jax.random.normal(update_key, (1024, 1024), jnp.float32)
```

Next, compute the error in the update's magnitude and the average per-parameter error, both measured against a full-precision update:

```python
lr = 1e-4  # example learning rate

full_prec = params.astype(jnp.float32) + lr * update
# Naive: compute the sum, then round straight back to the parameter dtype
naive_prec = (params.astype(jnp.float32) + lr * update).astype(params.dtype)
low_prec, _, _ = precision_aware_update(params, update, lr)
stoc_prec = stochastic_update(params, update, lr)  # stochastic-rounding baseline

# Magnitude error: norm of the full-precision update minus norm of the change actually applied
mag_err = jnp.linalg.norm(lr * update) - jnp.linalg.norm(low_prec.astype(jnp.float32) - params.astype(jnp.float32))
# Average per-parameter error against the full-precision result
param_err = jnp.mean(jnp.abs(low_prec.astype(jnp.float32) - full_prec))
```

![[low_prec_fp8.png]]

![[low_prec_fp16.png]]

![[low_prec_bf16.png]]

**Naive Addition**: Zeros out rapidly as the learning rate falls below the minimum at which updates survive rounding.

**Stochastic Rounding**: Works for longer, but eventually the magnitudes stop aligning and the error rises.

**Iterative**: Our algorithm. It is far more stable than the other methods and keeps working at much smaller learning rates, without increasing the per-parameter error or inducing noise in the model.

## Pretraining

However, although the norm and the error rate look better, that doesn't necessarily mean the method works. In this example we pretrain a 200M-parameter bf16 Llama-like model for ~60M tokens on FineWeb with a learning rate of $10^{-7}$. Comparing FP32 master weights against the iterative algorithm, the iterative version takes much longer to converge, and its weights move more slowly overall, likely due to its sparser updates.

![[low_prec_debug.png]]

# Conclusion

Truthfully, this is a failure. As training continues, the low-precision version falls further and further behind and fails to meaningfully update the parameters. I understand what is wrong here, and frankly it's not worth fixing: in real-world distributed training scenarios, with techniques like ZeRO, storing an extra set of master weights is fine. I'm not really interested in small-scale training either, so I will leave it here.
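The Empirics comparison relies on a `stochastic_update` baseline whose implementation isn't shown in the post. For anyone reproducing that comparison, here is a minimal sketch of one common approach for bf16 parameters, not necessarily the one behind the plots: compute the step in float32, add uniform random bits to the 16 low-order bits that bf16 drops, then truncate, so each element rounds away from zero with probability equal to its fractional distance to the next bf16 value. `stochastic_round_bf16`, the extra `key` argument, and the bf16-only scope are assumptions of this sketch (FP8 targets need a different approach), and inf/NaN inputs are not handled.

```python
import jax
import jax.numpy as jnp


def stochastic_round_bf16(x: jax.Array, key: jax.Array) -> jax.Array:
    """Hypothetical helper: stochastically round float32 values to bfloat16."""
    bits = jax.lax.bitcast_convert_type(x.astype(jnp.float32), jnp.uint32)
    # bf16 keeps the top 16 bits of a float32; draw noise for the 16 bits it drops
    noise = jax.random.bits(key, x.shape, jnp.uint32) & jnp.uint32(0xFFFF)
    # Adding the noise carries into the kept bits with probability equal to the
    # dropped fraction; masking then truncates toward zero
    rounded = (bits + noise) & jnp.uint32(0xFFFF0000)
    return jax.lax.bitcast_convert_type(rounded, jnp.float32).astype(jnp.bfloat16)


def stochastic_update(params, update, lr, key):
    # Hypothetical baseline: exact float32 step, then stochastic rounding to bf16
    target = params.astype(jnp.float32) + lr * update.astype(jnp.float32)
    return stochastic_round_bf16(target, key)


# Usage: a step of 2**-10 is below half the bf16 spacing (2**-7) at 1.0
params = jnp.ones((1 << 17,), jnp.bfloat16)
update = jnp.ones((1 << 17,), jnp.float32)
lr = 2.0**-10
naive = (params.astype(jnp.float32) + lr * update).astype(jnp.bfloat16)
stoch = stochastic_update(params, update, lr, jax.random.PRNGKey(0))
print(float(naive.astype(jnp.float32).mean()))  # 1.0: every element rounds back
print(float(stoch.astype(jnp.float32).mean()))  # close to 1 + 2**-10 on average
```

Because JAX randomness is functional, this sketch takes an explicit PRNG key, unlike the three-argument call in the Empirics snippet.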