This is a short low effort blog bc I'm doing like 4 projects rn, but this has always itched me and I finally got around to it.
It is often said that deep learning works because the redundancy in its representations is actually its strength: it lets the model explore many possible algorithms at once without being constrained to a restricted manifold. This is great and all, but we still have to constrain the model somehow, and better architectures (constraints) produce better results. How good an architecture is often comes down to the [[Machine Learning/Theory/Anything WILL work#Machine Epistemology|inductive bias]] it places on the model.
# Attention to detail
![[channelwise.excalidraw.png]]
However, just due to accumulated tech debt and inertia, there are a number of little implicit inductive biases in modern models that we haven't done the work to remove. For example, a channelwise RMSNorm scale biases the model towards "axis-aligned" solutions, where specific *elements* of the residual stream encode semantically important axes. Adam places the same inductive bias on the model, since it normalizes each element by its own second moment, and that normalization depends on the choice of coordinate basis.
Actually, if you look at a transformer and ignore the RMSNorm scale and Adam, the entire thing is rotation invariant. You can rotate every single matrix that reads from the residual stream (the input matrices and the LM head) and every single matrix that writes to it (the residual stream output matrices and the embeddings) by the same orthogonal matrix, and get the exact same model up to precision issues. This is great! The channelwise inductive bias is completely arbitrary and not very good, so why are we artificially constraining our models?
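Here's a tiny pure-Python sanity check of the rotation argument, using a 2-D stand-in for the residual stream and made-up numbers. The helpers `rotate`, `rmsnorm`, `read`, and `close` are hypothetical illustrations, not MaxText code. It checks three things: a matrix reading from the stream gives the same result when both it and the stream are rotated, a scalar-scale RMSNorm commutes with the rotation, and a per-channel scale does not.

```python
import math


def rotate(x, theta):
    """Rotate a 2-D vector by angle theta (an orthogonal change of basis)."""
    c, s = math.cos(theta), math.sin(theta)
    return [c * x[0] - s * x[1], s * x[0] + c * x[1]]


def rms(x):
    return math.sqrt(sum(v * v for v in x) / len(x))


def rmsnorm(x, scale):
    """RMSNorm with either a scalar scale or a per-channel list of scales."""
    r = rms(x)
    if isinstance(scale, list):
        return [g * v / r for g, v in zip(scale, x)]
    return [scale * v / r for v in x]


def read(W, x):
    """A matrix reading from the residual stream: each row dotted with x."""
    return [sum(w * v for w, v in zip(row, x)) for row in W]


def close(a, b, tol=1e-9):
    return all(abs(p - q) < tol for p, q in zip(a, b))


theta = 0.7
x = [3.0, 1.0]
W = [[1.0, 2.0], [-0.5, 4.0]]

# Rotate the stream and every row of the reading matrix by the same Q,
# i.e. x -> Q x and W -> W Q^T.
x_rot = rotate(x, theta)
W_rot = [rotate(row, theta) for row in W]

# 1) Reading matrices don't care about the basis: (W Q^T)(Q x) == W x.
print(close(read(W_rot, x_rot), read(W, x)))                       # True

# 2) Scalar-scale RMSNorm commutes with the rotation (RMS is a norm).
print(close(rmsnorm(x_rot, 1.5), rotate(rmsnorm(x, 1.5), theta)))  # True

# 3) A per-channel scale does not: it singles out the coordinate axes.
g = [2.0, 0.5]
print(close(rmsnorm(x_rot, g), rotate(rmsnorm(x, g), theta)))      # False
```

The only piece that breaks the symmetry is the per-channel `g`, which is exactly the channelwise bias described above.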
# The fix
The fix is twofold:
1. Replace the per-channel RMSNorm scale with a single shared scalar.
2. Replace Adam with a *row-wise* second-moment estimator instead of a per-element one.
## Scalar RMSNorm
Besides switching to a single scale, we make one more minor change to the RMSNorm module: the scale is recentered so that weight decay doesn't cause issues. In these runs, both the per-channel and the single-scale variants were recentered.
$$
y_i = \frac{\gamma_i \, x_i}{\mathrm{RMS}(\vec x)} \quad\rightarrow\quad \vec y = \frac{\gamma \, \vec x}{\mathrm{RMS}(\vec x)}
$$
In terms of expressivity, the RMSNorm scale is completely redundant, since it can be folded into the input side of the linear layer that follows it. The optimization dynamics still make it do *something*, somehow.
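A minimal sketch of the scalar RMSNorm next to the per-channel baseline, in plain Python. The post doesn't spell out what "recentered" means, so this assumes the common reading: the learned parameter is stored as an offset from 1 (effective scale `1 + gamma`, with `gamma` initialized to 0). The class names and `eps` value are hypothetical.

```python
import math


class ScalarRMSNorm:
    """RMSNorm with ONE shared learned scale instead of one per channel.

    Assumption: "recentered" = scale parametrized as (1 + gamma) with
    gamma initialized to 0, so weight decay pulls the scale toward 1.
    """

    def __init__(self, eps=1e-6):
        self.gamma = 0.0  # the single learnable scalar (offset from 1)
        self.eps = eps

    def __call__(self, x):
        rms = math.sqrt(sum(v * v for v in x) / len(x) + self.eps)
        scale = 1.0 + self.gamma
        return [scale * v / rms for v in x]


class ChannelRMSNorm:
    """Per-channel baseline, recentered the same way: scale_i = 1 + gamma_i."""

    def __init__(self, dim, eps=1e-6):
        self.gamma = [0.0] * dim  # one learnable offset per channel
        self.eps = eps

    def __call__(self, x):
        rms = math.sqrt(sum(v * v for v in x) / len(x) + self.eps)
        return [(1.0 + g) * v / rms for g, v in zip(self.gamma, x)]


norm = ScalarRMSNorm()
print(norm([3.0, 4.0]))  # unit-RMS output, up to eps
```

With this parametrization, decoupled weight decay on `gamma` shrinks it toward 0, i.e. the effective scale toward 1, instead of dragging the norm's output toward zero. The scalar version is also just the per-channel one with every `gamma_i` tied together.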
## Row Adam
We matched both the update size and the weight decay between Adam and our row-wise variant. Implementing it is pretty simple: replace the per-element second-moment estimator with this:
$$
v_t = \beta_2 v_{t-1} + (1-\beta_2)\,\frac{\lVert g_t \rVert_2^2}{d}
$$
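where $g_t$ is one row of the gradient and $d$ is the row size. Since $v_t$ is exactly the mean over the row of Adam's per-element second moment, the update stays on the same scale as Adam's.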
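A minimal pure-Python sketch of one Row AdamW step next to a standard AdamW step. It assumes the first moment stays per-element as in Adam, standard bias correction, and decoupled (AdamW-style) weight decay; only the second moment is per-row. The function names and default betas are placeholders, not the settings used in these runs.

```python
import math


def adamw_step(W, G, m, v, t, lr, beta1=0.9, beta2=0.999, eps=1e-8, wd=0.0):
    """Standard per-element AdamW step on a matrix W (list of rows), in place."""
    for i, (w_row, g_row) in enumerate(zip(W, G)):
        for j, g in enumerate(g_row):
            m[i][j] = beta1 * m[i][j] + (1 - beta1) * g
            v[i][j] = beta2 * v[i][j] + (1 - beta2) * g * g
            m_hat = m[i][j] / (1 - beta1 ** t)
            v_hat = v[i][j] / (1 - beta2 ** t)
            # Decoupled weight decay uses the pre-update weight.
            w_row[j] -= lr * (m_hat / (math.sqrt(v_hat) + eps) + wd * w_row[j])


def row_adamw_step(W, G, m, v, t, lr, beta1=0.9, beta2=0.999, eps=1e-8, wd=0.0):
    """Row-wise variant: m is per-element, v holds ONE float per row.

    t is the 1-based step count used for bias correction.
    """
    for i, (w_row, g_row) in enumerate(zip(W, G)):
        d = len(g_row)
        # v_t = beta2 * v_{t-1} + (1 - beta2) * ||g_t||^2 / d for this row.
        v[i] = beta2 * v[i] + (1 - beta2) * sum(g * g for g in g_row) / d
        denom = math.sqrt(v[i] / (1 - beta2 ** t)) + eps
        for j, g in enumerate(g_row):
            m[i][j] = beta1 * m[i][j] + (1 - beta1) * g
            m_hat = m[i][j] / (1 - beta1 ** t)
            w_row[j] -= lr * (m_hat / denom + wd * w_row[j])
```

When every element in a row has the same gradient magnitude, the two steps coincide. Otherwise Row Adam rescales each row by a single number, so the direction of the momentum within a row is preserved. Assuming each row lives in the residual-stream basis (e.g. one row per token in the embedding table), rotating that basis leaves every row norm, and hence the denominator, unchanged, which is the axis-free behavior the post is after.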
![[rowadamlrsweep.png]]
And as you can see, with a WSD (warmup-stable-decay) schedule and no hparam tuning, Row Adam beats AdamW pretty comfortably. Also, the way MaxText (which this was developed in) initializes weights makes this comparison a little more apples-to-apples than usual, hence the nicely aligned loss curves. The single scale improves loss by ~0.02, Row Adam improves it by ~0.07, and combining the two nets a ~0.11 loss improvement.
## Hparams
The hyperparameters were set as follows:
```python
from math import sqrt

def lr_groups(beta, num_layers, model_d, base_lr=0.1):
    beta_lr = base_lr * (1 - beta)                # q, k, v, up, gate, rms scale, embeds
    residual_lr = beta_lr / sqrt(2 * num_layers)  # o_proj, down
    mu_lr = beta_lr / sqrt(model_d)               # lm_head
    return beta_lr, residual_lr, mu_lr
```
1. Embeddings were initialized from N(0, 1).
2. Output projections and the LM head were zero-initialized.
3. All other matrices used Muon.
4. The LM head and embedding used AdamW or Row Adam, depending on the run.
5. The RMSNorm scale used AdamW.
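For reference, the LR formulas and the list above can be folded into one lookup. This is only a sketch under several assumptions: the parameter-group names are hypothetical (MaxText's real names differ), "output projections" is read as `o_proj` and `down`, `beta` is whatever `beta` is in the LR formulas (the post doesn't say), and `init` is `None` wherever the post doesn't state an initialization.

```python
import math

# Hypothetical parameter-group names; the real MaxText names differ.
INPUT_MATS = {"q", "k", "v", "up", "gate"}
# Assumption: "output projections" = the residual-stream output matrices.
RESIDUAL_OUT_MATS = {"o_proj", "down"}


def param_setup(name, beta, num_layers, model_d, base_lr=0.1, row_adam=True):
    """Return (optimizer, learning rate, init) for one parameter group."""
    beta_lr = base_lr * (1 - beta)
    residual_lr = beta_lr / math.sqrt(2 * num_layers)
    mu_lr = beta_lr / math.sqrt(model_d)
    # The LM head and embedding switch between the two compared optimizers.
    head_opt = "rowadam" if row_adam else "adamw"

    if name == "embed":
        return head_opt, beta_lr, "normal(0, 1)"
    if name == "lm_head":
        return head_opt, mu_lr, "zeros"
    if name == "rms_scale":
        return "adamw", beta_lr, None  # init not stated in the post
    if name in RESIDUAL_OUT_MATS:
        return "muon", residual_lr, "zeros"
    if name in INPUT_MATS:
        return "muon", beta_lr, None  # init not stated in the post
    raise ValueError(f"unknown parameter group: {name}")


for group in ["embed", "q", "down", "lm_head", "rms_scale"]:
    print(group, param_setup(group, beta=0.9, num_layers=8, model_d=64))
```

The numbers in the usage loop (`beta=0.9`, 8 layers, width 64) are made up purely to exercise the function.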