The underlying theory is similar to that of $\mu$P: to maintain a constant variance in the activations, **the weights of the matrix must scale as** $1/\sqrt{in\_d}$.
**Activation Scaling** The reason to control the scale of our activations is that all the other components of our models (floating point formats, activation functions, softmaxes, and so on) assume a fixed scale. It's easier to control the scale of our activations than the scale of everything else, and as we'll see, it's very easy to control.
**Theory Matters** The importance of activation variance has been known for a long time; in fact there is a [quite old paper](https://arxiv.org/abs/1706.02515) that, based on this idea, accurately predicted the properties of a good activation function years in advance. Since then there have been countless published studies of activation functions that didn't match the properties predicted by the paper, and as predicted, they were all worse. The *swish* activation function, found automatically via search, was published that same year.
Applying theory and a conceptual understanding of *why* things work is really important to doing good machine learning research: it lets us make better use of the resources available to us.
## Meat
Ok, so *why* should the weights in a matrix scale as $1/\sqrt{in\_d}$?
Well, given that **we want to control the variance** of the activations for other reasons, we can assume constant-variance inputs in our analysis. We'll see that scaling the weights keeps the variance constant: if we start with constant-variance inputs, and the output of each matrix is constant variance, then (aside from activation fns, attn, etc.) the input to each matrix will be constant variance too.
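This is easy to check empirically. Below is a minimal numpy sketch (my own illustration, not from any particular codebase): a stack of linear layers with weights drawn at std $1/\sqrt{d}$ keeps the activation scale roughly constant across depth, while unscaled weights blow it up. The width, depth, and seed are arbitrary choices.

```python
import numpy as np

rng = np.random.default_rng(0)
d, depth = 1024, 10
x = rng.standard_normal(d)  # unit-variance input

# Weights drawn with std 1/sqrt(d): activation scale stays roughly constant.
h = x
for _ in range(depth):
    h = (rng.standard_normal((d, d)) / np.sqrt(d)) @ h
print(h.std())  # stays near 1.0

# Weights drawn with std 1: scale multiplies by ~sqrt(d) every layer.
h = x
for _ in range(depth):
    h = rng.standard_normal((d, d)) @ h
print(h.std())  # astronomically large, roughly d**(depth/2)
```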
### Conceptual
When we multiply our weight matrix by our input vector, we are doing a series of multiplies and adds. The additions increase the variance of our activations, so we need to scale by some function of the number of inputs to compensate.
Intuitively, we shouldn't just divide by the number of inputs (i.e. average): we'd expect a bunch of random additions, some positive, some negative, to cancel each other out to some degree. It turns out the expected scale of the sum happens to be $\sqrt{in\_d}$.
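A quick numeric check of that cancellation (a throwaway numpy sketch; the widths and trial count are arbitrary): the standard deviation of a sum of $d$ i.i.d zero-mean draws grows as $\sqrt{d}$, not $d$.

```python
import numpy as np

rng = np.random.default_rng(0)
trials = 2_000
for d in (16, 256, 4096):
    # Sum d i.i.d N(0, 1) draws, repeated over many independent trials.
    sums = rng.standard_normal((trials, d)).sum(axis=1)
    print(d, sums.std(), np.sqrt(d))  # empirical std tracks sqrt(d)
```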
I find the CLT intuitive, so for me the proof below is sufficient reason why it's a square root specifically. I don't really know of a more direct interpretation, though.
### Proof
**Define** $\vec x \sim N_d(0, \sigma^2_x)$ as our input vector, with i.i.d entries
**Define** $\vec \theta \sim N_d(0, \sigma^2_\theta)$ same deal as our input vector
*Matrix multiplications can be represented as vectorized dot-products, so we will show the scaling for a single neuron. (This is also why the weights don't scale relative to the $out\_d$.)*
**Define** $y = \vec \theta^T \vec x$ where $y$ is our output neuron (a scalar)
**Then** the expected distribution of our output, via the central limit theorem, is $N(0, d\,\sigma^2_x \sigma^2_\theta)$:
$
\vec \theta^T \vec x = \sum_i^d \theta_i x_i = \sum_i^d N(0, \sigma^2_x) \times N(0, \sigma^2_\theta)
$
As you can see, our output neuron's expected activation is the summation of $d$ terms, each a product of independent zero-mean variables with variance $\sigma^2 = \sigma^2_x \sigma^2_\theta$. (A product of normals is not itself normal, but the CLT only needs i.i.d terms with finite variance.)
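The variance-of-a-product step can also be verified numerically (a hypothetical numpy sketch; the $\sigma$ values are arbitrary illustrative choices): for independent zero-mean variables, the variances multiply.

```python
import numpy as np

rng = np.random.default_rng(0)
sigma_x, sigma_t = 0.5, 2.0  # arbitrary illustrative scales
x = rng.normal(0.0, sigma_x, 1_000_000)
theta = rng.normal(0.0, sigma_t, 1_000_000)
# Var(theta_i * x_i) = sigma_x^2 * sigma_t^2 for independent zero-mean vars.
print((x * theta).var())  # ~ (0.5**2) * (2.0**2) = 1.0
```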
**Recall** the central limit theorem
$
\bar X_n = \dfrac{\sum^n_i X_i}{n}
$
$
\lim_{n\rightarrow \infty} \sqrt{n} (\bar X_n - \mu) \approx N(0, \sigma^2)
$
**Assume** $\mu=0$
**Then** we can move the $\sqrt{n}$ term into the definition of $\bar X_n$ and partially cancel the denominator
$
\lim_{n \rightarrow \infty} \dfrac{\sum^n_i X_i}{\sqrt{n}} = N(0, \sigma^2)
$
**Then** we can multiply both sides by $\sqrt{n}$, showing that a summation of i.i.d variables scales $O(\sqrt{n})$
$
\lim_{n \rightarrow \infty} \sum^n_i X_i = \sqrt{n} \times N(0, \sigma^2) = N(0, n\sigma^2)
$
**Assume** $d$ is large (in LLMs, $d$ is usually $\sim 10^4$)
**Then** we can apply the above result to our sum: to retain constant variance, we must divide by the square root of the number of summands.
$
\vec \theta^T \vec x = \sqrt{d} \times N(0, \sigma^2_x\sigma^2_\theta) = N(0, d\,\sigma^2_x\sigma^2_\theta)
$
$
\left(\dfrac{\vec \theta}{\sqrt{d}}\right)^T \vec x = N(0, \sigma^2_x\sigma^2_\theta)
$
And there we have it: by scaling the weights by $1/\sqrt{in\_d}$, the output variance becomes the product of the variances of the input elements and the weight elements, and is no longer dependent on the size of our model.
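As a final sanity check (again a hypothetical numpy sketch, with arbitrary widths and trial count): dividing the weights by $\sqrt{d}$ makes the output neuron's variance independent of $d$.

```python
import numpy as np

rng = np.random.default_rng(0)
trials = 2_000
for d in (64, 512, 4096):
    x = rng.standard_normal((trials, d))                   # sigma_x = 1
    theta = rng.standard_normal((trials, d)) / np.sqrt(d)  # scaled weights
    y = (theta * x).sum(axis=1)  # one dot product per trial
    print(d, y.var())  # ~ sigma_x^2 * sigma_theta^2 = 1, for any d
```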
# Next
Now that we know the parameters should be scaled by $1/\sqrt{d}$...
**... We should scale the gradient by it too** [[Gradient Scaling]]
**... We should rescale our parameters in low precision** [[Precision Scaling]]