feat. [@tensorqt](https://x.com/tensorqt)
So @tensorqt and I have been working on a new optimizer. Unfortunately it turns out another group had the same idea but for adam ([AdEMAmix](https://arxiv.org/abs/2409.03137), Pagliardini et al., 2024). However, the way in which we derived this new optimizer is conceptually very different, and thus has some key differences which I believe make it a superior alternative (more so than just swapping adam out for [Muon](https://kellerjordan.github.io/posts/muon/), Jordan 2024).
# Introduction
In [[Hyperbolic Space]] we showed that the knowledge content in LLMs might have a multiscale hierarchical structure. Interestingly, this property was rediscovered in mechanistic interpretability research. A paper showed that LLM features were better represented as [hierarchical ones](https://arxiv.org/abs/2506.10920), and another blog found that [SAE features exhibit hyperbolic structures](https://guydar.substack.com/p/mini-report-sae-features-exhibit).
Philosophically, this agrees with my outlook on the world, which is that higher order effects and categories are actually extremely important and everywhere, and pretty much all meaningfully complex systems have multiscale properties. Even more important is that the attention to these higher order properties is what governs generalization, as you are peering into the deep underlying trends between causes which will hold for farther out of distribution and not simply cause-to-cause patterns which may only generalize for a few samples OOD.
Personally, while my short term memory is abysmal, my long term memory recall is really good as I spend a lot of time introspecting and thinking to myself. I remember scenes from my childhood in incredible detail, with a particularly good ability to recall by relevance. This enables me to ask and answer higher order questions about my life like "what governs my attraction to various traits?", "what do I care about?", "what personalities am I composed of?", "am I normal?".
In the same way that the higher order properties of the world are emergent of the lower order properties, the higher order observations and concepts *inside my head* are emergent from the lower order observations. I can remember lessons from earlier in my life for longer, giving me more data to work with when observing subtle trends through my life but I also have some non-zero short term and medium term memory.
So let's test this idea. Luckily I have a strange box in my house that I can shove billions of words into and have it speak back to me if I'm smart enough. So let's convert these insights into LLM optimizer technology.
### Related Work
I also found that another group discovered these multiscale features in loss landscapes, entirely independently of the hyperbolic/hierarchical/mech-interp research. They designed an algorithm called [MrSGD](https://arxiv.org/pdf/2402.03021), though it is computationally complex and rather odious to implement and use. Our optimizer can kinda sorta be thought of as an efficient approximation of MrSGD, updated with the preconditioned momentum paradigm.
It seems that this same concept has come up in many separate fields and that there is likely some truth to the intuition. They used a bunch of fancy math that frankly I don't really understand but whatever.
# Methodology
**First let's set up our analogs**:
- To take the place of the brain we will use a transformer with the usual Transformer++ recipe (modern architectural improvements like RMSNorm, SwiGLU/GeGLU, RoPE) but with an activation function after the soft attention. 150M parameters, configuration can be found at the bottom of this paper.
- To take the place of whatever algorithm the brain uses to learn from its environment, we will of course use backpropagation ([[Associative Transformer Local Learning Rule|one day we will slay the wretched beast]]).
- To take the place of the various degrees of long-short term memory in the brain, we will use multiple momentum buffers instead of just one. In this optimizer we used 3.
- To take the place of whatever process learns from these memories and encodes these insights, we will use the [Muon](https://kellerjordan.github.io/posts/muon/) orthogonalizer, which will amplify subtle algorithms contained in the gradient that generalize across the effective batch size set by the various momentum parameters.
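
For anyone who hasn't looked inside the Muon orthogonalizer, here is a minimal pure-Python sketch of the quintic Newton-Schulz iteration it is built on, using the same coefficients as the implementation at the bottom of this page. The helper names (`matmul`, `transpose`, `newton_schulz`) and the toy 2x2 "gradient" are made up for illustration. The real version runs batched in JAX, and this sketch skips details such as handling tall matrices.

```python
import math

# Newton-Schulz coefficients from Muon (same values as in the Implementation section).
A_COEF, B_COEF, C_COEF = 3.4445, -4.7750, 2.0315


def matmul(p, q):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(p[i][k] * q[k][j] for k in range(len(q))) for j in range(len(q[0]))]
            for i in range(len(p))]


def transpose(p):
    return [list(col) for col in zip(*p)]


def newton_schulz(g, ns_steps=5, eps=1e-8):
    """Push the singular values of `g` towards 1 without computing an SVD."""
    # Normalise by the Frobenius norm so every singular value starts in (0, 1].
    norm = math.sqrt(sum(v * v for row in g for v in row))
    x = [[v / (norm + eps) for v in row] for row in g]
    for _ in range(ns_steps):
        a = matmul(x, transpose(x))  # A = X X^T
        aa = matmul(a, a)            # A^2
        b = [[B_COEF * a[i][j] + C_COEF * aa[i][j] for j in range(len(a))]
             for i in range(len(a))]  # B = b A + c A^2
        bx = matmul(b, x)
        x = [[A_COEF * x[i][j] + bx[i][j] for j in range(len(x[0]))]
             for i in range(len(x))]  # X = a X + B X
    return x


# Hypothetical toy gradient; its singular values are about 3.6 and 1.4.
grad = [[3.0, 1.0], [1.0, 2.0]]
ortho = newton_schulz(grad)

# With these coefficients the singular values settle near 1 rather than exactly
# at 1, so the Gram matrix is only approximately the identity.
gram = matmul(ortho, transpose(ortho))
print(gram)
```

The point of the sketch is that both singular values end up at roughly the same size, so the small direction in the gradient gets as much say in the update as the dominant one.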
From other work and experiments the following things were found to be helpful and were used in the custom pretraining script:
1. For a given EMA $\beta_i$, scale the learning rate according to its "gradient half-life" (strictly speaking the EMA's effective averaging window), which can be calculated via $t_\text{half}=\frac{1}{1-\beta_i}$, $\eta_i=\eta_0/t_\text{half}$.
    1. My model for this is that the more the model has changed since a gradient was taken, the less relevant that gradient is to the current model, so it makes sense to lower the learning rate.
    2. I suspect that this scaling can be relaxed to $\eta_i=\eta_0/\sqrt{t_\text{half}}$ due to the variance reduction from increased averaging, however I have not tested this hypothesis.
2. FP32 master weights are a necessity.
3. Normalize your input embeddings, but only when they exceed a certain magnitude.
    1. Normalizing them always slows learning.
4. Use recentered RMSNorm ( $\sqrt{d} (1+\vec s) \vec x / ||\vec x||$ ) when using weight decay.
5. Use the lion optimizer for norms, adam for input embeddings, adamw for lm_head and muon/multiscale-muon/adamw for tensor weights.
    1. This is probably a pain in the ass if you don't spend several days making a modular optimizer configuration system but I did.
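
To put numbers on the first tip, here is a small sketch that evaluates $t_\text{half}$ and both learning-rate scalings for the three multiscale momentum coefficients used in the experiments. The base learning rate `eta0 = 1.0` is a made-up value for illustration, and the `half_life` column is the textbook half-life of an EMA, included only for comparison with the $1/(1-\beta)$ quantity used in this post.

```python
import math

# Momentum coefficients from the Experiment section; eta0 is a hypothetical base lr.
betas = [0.96875, 0.9921875, 0.998046875]
eta0 = 1.0

rows = []
for beta in betas:
    window = 1.0 / (1.0 - beta)                  # t_half as defined above: 32, 128, 512
    half_life = math.log(0.5) / math.log(beta)   # steps until a gradient's weight halves
    rows.append({
        "beta": beta,
        "window": window,
        "half_life": half_life,
        "lr_linear": eta0 / window,              # eta_0 / t_half, the rule used here
        "lr_sqrt": eta0 / math.sqrt(window),     # the untested square-root variant
    })

for row in rows:
    print(row)
```

The three coefficients were evidently picked so that the windows are powers of two (32, 128 and 512 steps), which makes the linear rule shrink the slowest buffer's learning rate by 16x relative to the fastest, versus only 4x under the square-root variant.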
The full implementation of the optimizer can be found at the bottom of this page.
## Experiment
For the hparam sweep, all the momentums were kept fixed. We're assuming they're pretty close to optimal because they come from the modded-nanogpt script, though the token counts are different so maybe not. Also, as you will see in a moment, we're mostly looking for the specific behavior that shows the multiscale effect is being exhibited, like in AdEMAmix.
- warmup for 10M tokens
- $\beta_\text{muon}=.95875$, $\vec \beta_\text{mscale}= [0.96875, 0.9921875, 0.998046875]$
- $\eta_\text{muon}=\eta_0$, $\eta_\text{mscale}=0.8333\eta_0$
To get the update sizes to match up, we divide the learning rates of all the modules by $\sqrt{\sum \eta_i}$ and by an additional coefficient based on the observed update size. Muon in this case was implemented simply as a special case of mscale with 1 momentum buffer, so it uses an EMA rather than heavy-ball momentum.
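
As a rough illustration of that normalization, the sketch below mirrors how `base_scales` is computed in `multiscale_momentum` in the Implementation section: each buffer is weighted by $(1-\beta_i)/(1-\beta_\text{min})$ and the weights are then divided by the square root of their sum. The function name `buffer_scales` is made up here, and the additional coefficient based on observed update size is not modeled.

```python
import math


def buffer_scales(betas, eps=1e-6):
    """Per-buffer update weights, mirroring `base_scales` in `multiscale_momentum`."""
    fastest = min(betas)
    # 1.0 for the fastest buffer, proportionally less for slower buffers.
    raw = [(1.0 - b) / (1.0 - fastest) for b in betas]
    norm = math.sqrt(sum(raw) + eps)
    return [r / norm for r in raw]


mscale = buffer_scales([0.96875, 0.9921875, 0.998046875])
single = buffer_scales([0.95875])  # muon as the 1-buffer special case

print(mscale)  # approximately [0.873, 0.218, 0.055]
print(single)  # approximately [1.0]
```

With a single buffer the weight is essentially 1, which is why muon falls out of the same code path without any extra handling.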
For the sweep, coarse-grained lrs were swept for 100M tokens for muon, adam and multiscale. Adam didn't compare to muon or multiscale at the 100M token range, so it was dropped from future sweeps. Once the best coarse lr was found, we ran a much tighter sweep around it. In the end the lowest lr came out on top for both optimizers, so clearly the ideal learning rate changes past a certain point. In total, there were 32 sweeps.
![[Pasted image 20250814165439.png]]
These initial sweeps were done on a modified [modded-nanogpt](https://github.com/KellerJordan/modded-nanogpt) repository, and in general multiscale does better at lower learning rates. The remaining runs here were done in a custom pretraining script that used the aforementioned tips.
In the [AdEMAmix paper](https://arxiv.org/abs/2409.03137) they note that after training on a single example, the multi-buffered solution retains that knowledge and keeps a lower loss on that batch for longer.
![[Pasted image 20250813011550.png]]
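
One way to build intuition for that retention effect is to look at how long a single gradient survives inside each EMA buffer. The toy sketch below feeds one unit gradient followed by zeros into each of the three buffers and measures what fraction of the original contribution is left 511 steps later. It is a back-of-the-envelope model under my own assumptions, not a reproduction of the AdEMAmix experiment, and `impulse_response` is a made-up helper name.

```python
betas = [0.96875, 0.9921875, 0.998046875]


def impulse_response(beta, steps):
    """Feed one unit gradient, then zeros, into an EMA buffer and record its value."""
    m, trace = 0.0, []
    for t in range(steps):
        g = 1.0 if t == 0 else 0.0
        m = beta * m + (1.0 - beta) * g
        trace.append(m)
    return trace


steps = 512
traces = {beta: impulse_response(beta, steps) for beta in betas}

# Fraction of the original contribution still in the buffer after steps - 1 more updates.
retained = {beta: traces[beta][-1] / traces[beta][0] for beta in betas}
for beta, frac in retained.items():
    print(f"beta={beta}: {frac:.2e} of the gradient is left")
```

The fastest buffer has forgotten the example almost completely by then, while the slowest one still holds roughly a third of it, so a batch seen once keeps nudging the weights for hundreds of steps.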
At around 5.5 billion tokens into the run a major drop in the test loss is observed from a batch very similar to the test data, after which the multiscale optimizer starts pulling significantly ahead in wallclock descent speed. This is a similar effect to what was observed in the AdEMAmix paper, which suggests this is actually working as intended.
![[Pasted image 20250813011823.png]]
You can see in the graph that the multiscale optimizer is obviously slower per iteration in wallclock terms, which is why it took longer to complete. On distributed systems, though, I think making a zero-redundancy orthogonalization isn't too hard (in JAX anyways), and even as it stands the optimizer is wallclock faster per loss.
![[Pasted image 20250813011913.png]]
To compute the overall token efficiency improvement we heavily averaged the graph until we got something more or less monotonically decreasing, then took the multiplicative factor in time between the two curves at matching loss. Overall it comes out to about +11%, though optimizers can also have qualitative differences beyond pure loss reduction.
![[Pasted image 20250813113908.png]]
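
The token efficiency calculation can be sketched roughly as follows, assuming the curves have already been smoothed until they decrease monotonically. Pick a target loss, find how many tokens each run needs to reach it, and take the ratio. The curves below are synthetic power laws built so that the candidate is exactly 1.25x faster; they are not the runs from this post, and `tokens_to_reach` is a made-up helper name.

```python
def tokens_to_reach(tokens, losses, target):
    """Tokens at which a decreasing loss curve first crosses `target` (linear interpolation)."""
    for i in range(1, len(losses)):
        if losses[i] <= target <= losses[i - 1]:
            frac = (losses[i - 1] - target) / (losses[i - 1] - losses[i])
            return tokens[i - 1] + frac * (tokens[i] - tokens[i - 1])
    raise ValueError("target loss is never reached")


# Synthetic curves, NOT real data: the candidate is constructed to be 1.25x faster.
tokens = [float(t) for t in range(1, 2001)]
baseline = [5.0 / (t ** 0.25) for t in tokens]
candidate = [5.0 / ((1.25 * t) ** 0.25) for t in tokens]

target = 1.0
speedup = tokens_to_reach(tokens, baseline, target) / tokens_to_reach(tokens, candidate, target)
print(f"token efficiency: {speedup:.3f}x")
```

On real curves the factor will vary with the target loss you pick, so it is worth checking a few targets before quoting a single number.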
Another interesting observation is that although the update size of muon is marginally higher than that of multiscale, the norms of the parameters increase faster for multiscale. This suggests there is less noise in the updates, i.e. earlier gradients are being overridden less often.
## Conclusion
While this doesn't necessarily prove any of my opinions on philosophy and intelligence, the fact that an existing optimizer with more rigorous methodology has proven the same principle makes me more willing to accept this idea. It also means the opportunity cost to my world model of making a more rigorous version of this experiment over simply moving on to the next one favors the latter.
### Future steps
Initially we intended to schedule these momentum coefficients manually, but honestly it's not really worth the added complexity. We first tried loads of different options with a single momentum buffer, none of which really worked until we came up with the multi-buffered solution. We tried various ideas to check for grokking, and while we did find momentum schedules that worked on synthetic problems, the continual hierarchical construction needed for language learning was better served by multiple buffers.
Automated momentum scheduling is definitely still on the todo list, it shouldn't be too hard?
# Implementation
```python
import functools
from functools import lru_cache
from typing import Sequence, Callable, Literal, Optional, Any
import optax
from optax import GradientTransformation, TransformInitFn, TransformUpdateFn
from jax import numpy as jnp
import jax
from flax import nnx


# stability, optional!
def clip_outliers(z: float | int = 2):
    def update_fn(updates, params=None):
        sigma = jax.tree.map(lambda u: jnp.sqrt(jnp.mean(jnp.square(u), dtype=jnp.float32)).astype(u.dtype), updates)
        updates = jax.tree.map(lambda u, sigma: jnp.clip(u, -z * sigma, z * sigma), updates, sigma)
        return updates
    return optax.stateless(update_fn)


def cast_to(dtype: jnp.dtype | None) -> GradientTransformation:
    return optax.stateless(
        lambda updates, params: jax.tree.map(
            lambda u, p: u.astype(dtype if dtype is not None else p.dtype),
            updates,
            params
        )
    )


def lerp(a, b, alpha):
    return a + alpha * (b - a)


def leading_multiply(array, target):
    return array.reshape((-1,) + (1,) * (len(target.shape) - 1)) * target


def multiscale_momentum(
    mom_coeffs: Sequence[float],
    accumulate: float | Sequence[float] | None = None,
    preconditioner: Callable[[jax.Array], jax.Array] = lambda x: jnp.sign(x),
    dtype=jnp.float32,
    cooldown_frac: float = 0.
) -> GradientTransformation:
    mom_coeffs = jnp.array(mom_coeffs)
    if accumulate is None:
        accumulate = jnp.ones_like(mom_coeffs)
    else:
        accumulate = jnp.array(accumulate)
    base_scales = ((1 - mom_coeffs) / (1 - jnp.min(mom_coeffs)))
    base_scales = base_scales / jnp.sqrt(base_scales.sum() + 1e-6)
    cooldown = (1 - jnp.max(mom_coeffs)) / (1 - mom_coeffs)
    scales = jnp.where(cooldown_frac > cooldown, 0, base_scales)

    def init_fn(params):
        return jax.tree.map(lambda x: jnp.zeros((len(mom_coeffs),) + x.shape, dtype=dtype), params)

    def update_fn(grad, state, params=None):
        state = jax.tree.map(
            lambda grad, mom: (
                leading_multiply(accumulate * (1 - mom_coeffs), grad[None, ...])
                + leading_multiply(mom_coeffs, mom)
            ).astype(dtype),
            grad,
            state
        )
        update = jax.tree.map(
            lambda grad, mom: jnp.sum(
                leading_multiply(
                    scales,
                    preconditioner(
                        leading_multiply(1 - mom_coeffs, grad[None, ...]) + leading_multiply(mom_coeffs, mom)
                    )
                ),
                axis=0
            ),
            grad,
            state
        )
        return update, state

    return GradientTransformation(init_fn, update_fn)


def scale_by_muonP(base_lr: float):
    """
    Scale lr by muP (maximal update parametrization)
    """
    return optax.stateless(
        lambda updates, params: jax.tree.map(
            lambda u, p: -max(u.shape[-1] / u.shape[-2], 1) * base_lr * u,
            updates,
            params
        )
    )


# -3 (base)
# -4
def multiscale_muon(
    lr=.125,
    warmup_frac=1.,
    wd: Optional[float] = None,
    method: Literal['muon', 'optimal'] = 'muon',
    dtype=jnp.float32
):
    return optax.chain(
        clip_outliers(2),
        multiscale_momentum(
            [0.96875, 0.9921875, 0.998046875],
            preconditioner=functools.partial(orthogonalize, method=method),
            accumulate=[1., warmup_frac, warmup_frac ** 2],
            dtype=dtype
        ),
        optax.add_decayed_weights(wd) if wd else optax.identity(),
        scale_by_muonP(lr),
        cast_to(None)
    )


def muon(lr=.125, wd: Optional[float] = None, beta=.95875, method: Literal['muon', 'optimal'] = 'muon'):
    return optax.chain(
        clip_outliers(2),
        multiscale_momentum(
            [beta],
            preconditioner=functools.partial(orthogonalize, method=method)
        ),
        optax.add_decayed_weights(wd) if wd else optax.identity(),
        scale_by_muonP(lr),
        cast_to(None)
    )


def orthogonalize(
    x,
    ns_steps=5,
    eps=1e-8,
    method: Literal['muon', 'optimal'] = 'muon',  # removed optimal
    convergence_speed: Optional[float] = 1e-3
):
    if method == 'muon':
        coeffs = (3.4445, -4.7750, 2.0315)  # Newton-Schulz coefficients from Muon
    else:
        raise ValueError(f'Unknown value for `method`: {method}')
    coeffs = jnp.array(coeffs)
    x_shape = x.shape
    if x.ndim <= 1:
        raise ValueError(f'Input must have shape (m, n), got {x.shape}')
    elif x.ndim == 2:
        x = x[None, ...]
    elif x.ndim > 3:
        x = x.reshape((-1,) + x.shape[-2:])
    return jax.vmap(
        functools.partial(
            optax.contrib._muon.orthogonalize_via_newton_schulz,
            ns_steps=ns_steps,
            eps=eps,
            ns_coeffs=coeffs
        ),
        in_axes=0
    )(x).reshape(x_shape)
```
## References
1. **AdEMAmix**: Pagliardini, M., Ablin, P., & Grangier, D. (2024). The AdEMAMix Optimizer: Better, Faster, Older. arXiv preprint arXiv:2409.03137. https://arxiv.org/abs/2409.03137
2. **Muon**: Jordan, K. (2024). Muon: An optimizer for hidden layers in neural networks. https://kellerjordan.github.io/posts/muon/
3. **Modded-NanoGPT**: Jordan, K. (2024). NanoGPT (124M) in 3 minutes. GitHub repository. https://github.com/KellerJordan/modded-nanogpt
4. **Hierarchical Features in LLMs**: The paper showing that LLM features were better represented as hierarchical ones (2025). arXiv:2506.10920. https://arxiv.org/abs/2506.10920
5. **SAE Features and Hyperbolic Structures**: Dar, G. (2024). Mini Report: SAE features exhibit hyperbolic structures. https://guydar.substack.com/p/mini-report-sae-features-exhibit