As we would expect, there are many interesting structural properties to the matrices in neural networks. Recently, a lot of attention has been brought to the singular values in particular, which oddly follow a power law.

![[Pasted image 20250502094953.png]]

Typically this is just chalked up to an interesting, perhaps exploitable property: maybe a measure of generalization or of how fully trained a model is. However, there is a much more interesting connection here, which is the relationship between NN matrices, hierarchical networks and hyperbolic geometries.

# Theory

## The Singular Value Distribution

The power law distribution of fully trained NNs contrasts with that of purely random matrices, whose singular values follow a roughly linear slope by index. The overall probability distribution of these values is given by the [Marchenko-Pastur](https://en.wikipedia.org/wiki/Marchenko%E2%80%93Pastur_distribution) distribution.

![[Pasted image 20250606113854.png]]

As you can see, the variance/deviation from this trend is extremely tiny, hardly visible. However, NN weights, momenta and gradients all start out with Marchenko-Pastur distributions but gradually become power law distributed as training progresses. This is cool because it allows us to directly observe the "phase change" that occurs in LLMs and other deep neural networks. But it raises an interesting question: **why** are they power law distributed? In theory there should be no underlying structure common to all the various tasks we train neural networks on, yet across virtually every real world environment, NNs reliably converge to power law distributed svals.

*So... what does it mean for the momentum and parameters to have power law distributed svals?*

## Spectral Graph Theory

Those familiar with [spectral graph theory](https://en.wikipedia.org/wiki/Spectral_graph_theory) will know that a heavy-tailed decay of the singular values (svals) of a graph's adjacency, Laplacian or incidence matrix hints at the possibility of the graph being scale-free and possibly hierarchically organized. Hierarchical networks are graphs for which there is no single clustering "scale": every cluster is made of smaller clusters that are themselves made of smaller clusters, and so on.

Intuitively, it would make sense that most of the world's data is organized in such a manner. After all, the neurons in our [brain form hierarchical networks](https://arxiv.org/abs/1004.3153), and so do [social networks](https://arxiv.org/abs/0811.0484), [weblinks and IP networks](https://arxiv.org/abs/cond-mat/0206130), food chains, etc. Even the very fabric of our spacetime is hyperbolic (we'll get into it in a second). It seems that virtually all intelligent systems converge on hierarchical organization. Why would LLMs be different?

What this also hints at is the possibility that knowledge in LLMs is not really stored at the individual-neuron level, but that there is a deeper underlying hierarchical structure to what models learn.
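To make the "heavy-tailed svals hint at hierarchy" idea concrete, here's a minimal, self-contained sketch (JAX, to match the appendix code). It builds a toy self-similar weighted adjacency matrix out of Kronecker powers of a small seed (the trick behind Kronecker/R-MAT graph generators) and compares its sval decay to a plain Gaussian matrix of the same size and norm. The seed values and sizes are arbitrary choices for illustration, not anything measured from a real model.

```python
import jax, jax.numpy as jnp

# Toy "hierarchical" matrix: Kronecker powers of a small seed are self-similar,
# so every cluster is made of smaller copies of the same structure.
seed = jnp.array([[0.9, 0.5],
                  [0.5, 0.1]])
A_hier = seed
for _ in range(5):          # 2^6 = 64 nodes
    A_hier = jnp.kron(A_hier, seed)

# Baseline: Gaussian random matrix with the same shape and Frobenius norm.
A_rand = jax.random.normal(jax.random.PRNGKey(0), A_hier.shape)
A_rand *= jnp.linalg.norm(A_hier) / jnp.linalg.norm(A_rand)

def top_svals(A, k=10):
    s = jnp.linalg.svd(A, compute_uv=False)
    return s[:k] / s[0]     # singular values, normalized by the largest

print("hierarchical:", top_svals(A_hier))  # heavy-tailed: drops off fast with rank
print("random:      ", top_svals(A_rand))  # flat bulk: decays slowly across ranks
```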
## Hyperbolic Geometries

![[Pasted image 20250606122014.png]]

But how do we know for certain whether LLMs really are forming hierarchical graphs in their weights? Well, one approach is to modify the geometry in which they operate. Traditionally, LLMs operate in Euclidean space, where the scale of the space doesn't change with the coordinates. However, there are other kinds of geometries. For example, the geometry on the surface of a sphere, which is the geometry your GPS operates in.

Importantly, three geometries in particular (flat, hyperbolic and spherical) can all be used to embed different kinds of graphs. One can lay out a 'regular' graph with no meaningful macroscopic structure on a flat geometry, spherical geometries can encode loops and closed graphs, and **hyperbolic geometries** can encode **trees and hierarchical networks**.

If we train neural networks to operate in hyperbolic space instead of Euclidean space, we should expect them to perform better on tasks involving hierarchical data, especially at smaller scales. And in fact, we do observe this on [embedding tasks](https://arxiv.org/abs/1705.08039), [translation](https://arxiv.org/abs/2405.15481v1), [GNNs](https://arxiv.org/abs/1910.12933) and so on.

![[Pasted image 20250606134400.png]]
*From "[Sparse Spectral Training and Inference on Euclidean and Hyperbolic Neural Networks](https://arxiv.org/abs/2405.15481v1)"*

## Statistics for Hierarchical Knowledge

An interesting statistic one can calculate on adjacency matrices is the effective degree. It's kinda like the average degree of a node in the graph. I won't really delve into how exactly the formula for this was devised; you'll just have to trust me and o3 that it works.

$ k_\text{eff} = m \dfrac{||W^T W||_F^2}{||W||_F^4} $
*where $m$ is your output dimension*

What's interesting about this formula is that you can apply it to matrices with negative weights and to directed graphs, which is exactly what we have in the case of NN matrices. Here are some $k_\text{eff}$ values for common matrices:

| Matrix                  | $k_{\text{eff}}$ |
| ----------------------- | ---------------- |
| Zero                    | 0                |
| Identity                | 1                |
| Random Normal (500x500) | 2                |
| 3-regular               | 3                |
| Star (n=100)            | 100              |
| Path (n=100)            | 1.5              |

For random matrices the expected $k_\text{eff}$ is given by the expression $1+m/n$, where $n$ is the input dimension and $m$ is the output. As we'd expect, if we project from a small number of inputs to a large number of outputs, the expected degree goes up. For example, for a 4x expansion ratio, $\mathbb{E}[k_\text{eff}] = 5$.

Importantly, the variance of this metric is extremely low for the matrix dimensions we use in deep learning, which means we can use it as a way to detect structure. If the calculated $k_\text{eff}$ is even a little above the expectation, then there is definitely some degree of clustering going on.
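If you want to sanity-check the table and the random baseline yourself, here's a minimal sketch. Two conventions are assumed in order to land on the numbers above: the star is directed (the hub row points at every leaf) and the path is undirected. I also use a plain fp32 version of the formula here rather than the more careful bf16 one in the appendix.

```python
import jax, jax.numpy as jnp

def k_eff(W):
    # plain fp32 version of k_eff = m * ||W^T W||_F^2 / ||W||_F^4
    m = W.shape[0]
    f2 = jnp.sum(jnp.square(W))
    s4 = jnp.sum(jnp.square(W.T @ W))
    return m * s4 / f2**2

n = 100
identity = jnp.eye(n)
star = jnp.zeros((n, n)).at[0, 1:].set(1.0)   # hub (row 0) -> every leaf
path = jnp.eye(n, k=1) + jnp.eye(n, k=-1)     # node i <-> node i+1
rand = jax.random.normal(jax.random.PRNGKey(0), (500, 500))

print("identity:", k_eff(identity))  # 1.0
print("star:    ", k_eff(star))      # ~100: the hub touches everything
print("path:    ", k_eff(path))      # ~1.5
print("random:  ", k_eff(rand))      # ~2 = 1 + m/n for a square matrix
```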
# Empirics

## Effective Degree of Fully Trained Models

For this example we'll look at the weights of `Llama 3.1 8b base`. We'll plot $k_\text{eff}/\mathbb{E}[k_\text{eff}]$ for each matrix across the layers to see how far above random the matrices in the model are clustered.

![[Pasted image 20250606125651.png]]

As we can see, *every* matrix in a fully trained model is reliably above the pure random baseline. Additionally, there are a few interesting patterns in this data that we can make sense of.

**Non-Linearities increase $k_\text{eff}$**: Notice that the matrices that sit directly *before* a major non-linearity (key, query, gate) all have high $k_\text{eff}$. Having a non-linearity allows the following matrix (o, v, down) to convert the cluster assignments into residual stream deltas. So, if any of the matrices are going to be clustering algorithms, it would be the ones that sit before the non-linearities.

**MHA increases $k_\text{eff}$**: Notice that the **query** matrices have a very high effective degree. Intuitively, because in MHA the same key can be attended to by multiple queries, we'd expect the same "node" in the input to attend to multiple "nodes" in the query output. After all, a central hypothesis behind MHA is that multiple similar queries to the same key are still more useful than one query per key. Additionally, the output matrix, which has the task of converting all the head outputs into a residual stream delta, is the only matrix that breaks from the trend of matrices preceding non-linearities. This is probably because of MHA.

$k_\text{eff}$ **in last layers**: Intuitively, the final layers of the model should also exhibit significantly higher degrees than previous layers, as those layers map many possible latent nodes onto very few vocab nodes.

## Proof of Hierarchical Organization

On their own, power law distributed svals do not prove hierarchical organization. However, there are a few more tests we can perform. For example, if the power law svals are in fact due to the hierarchical organization of the matrix, then shuffling the elements in the matrix should return all statistics back to the random baseline and destroy any previously observed structure.

![[Pasted image 20250606153915.png]]
![[Pasted image 20250606160218.png]]

As we can see, shuffling the values returns all the statistical properties back to their random baseline. As far as I know, this is definitive evidence for the internal structure being hierarchical.

Just to be absolutely certain, we can perform something called a hyperbolic SVD on our matrix and directly measure the curvature implied by the matrix:

![[Pasted image 20250606161219.png]]

As we can see, we have a negative curvature, showing the geometry is in fact hyperbolic!

# Implications

**Mechanistic Interpretability**

This has some interesting implications for mechinterp research: the hierarchical nature of knowledge in the model hints at a deeper underlying structure to how the model thinks, one more fundamental than the activations of individual neurons. In theory we should observe transformer circuits forming hierarchical networks, where many neurons in the same matrix serve related but different functions. This is particularly important if we ever want to *comprehend* how these transformers actually work; hierarchically organized patterns are much easier for humans to understand than just giant lists of observed patterns.

**Architecture**

I'm relatively bearish on hyperbolic transformers: numerical instability, combined with the fact that scale tends to wash out gains from these kinds of changes, makes them difficult to justify. In fact, in SST they already observed the advantage diminishing quickly with scale. However, I think it has other important implications.

**Pretraining**

Wouldn't you like to know, weather boy?

# Conclusion

So, transformers organize knowledge hierarchically, or perhaps the knowledge of our world is inherently hierarchical. Additionally, we have methods to measure the degree of this organization.

**... How does this organization evolve during pretraining?** (Coming soon)

**... Then the gradient must be a re-clustering... so what?** (Coming soon)

# Appendix

### Bi-Directional Effective Degree

If you want the bidirectional effective degree for whatever reason, our adjacency matrix becomes:

$A = \begin{bmatrix} 0 & M \\ M^T & 0 \end{bmatrix}$

*Therefore* $\mathbb{E}[k_\text{eff}] = 1+\dfrac{n^2+m^2}{2nm}$

It's kinda cool that this interpretation also explains why backprop uses the transpose.
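As a quick check of that closed form, here's a sketch that builds the block adjacency explicitly, feeds it through the same $k_\text{eff}$ formula, and compares against $1+\frac{n^2+m^2}{2nm}$ for a random Gaussian $M$ (the shapes are arbitrary).

```python
import jax, jax.numpy as jnp

def k_eff(W):
    m = W.shape[0]
    return m * jnp.sum(jnp.square(W.T @ W)) / jnp.sum(jnp.square(W)) ** 2

def k_eff_bidir(M):
    # effective degree of the symmetric block adjacency A = [[0, M], [M^T, 0]]
    m, n = M.shape
    A = jnp.block([[jnp.zeros((m, m)), M],
                   [M.T, jnp.zeros((n, n))]])
    return k_eff(A)

m, n = 1024, 256
M = jax.random.normal(jax.random.PRNGKey(0), (m, n))

print("measured:", k_eff_bidir(M))               # ~3.13
print("expected:", 1 + (n**2 + m**2) / (2*n*m))  # 3.125
```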
## Code (JAX)

### Effective Degree

```python
import jax, jax.numpy as jnp

@jax.jit
def k_eff(W):  # careful with this one
    m = W.shape[0]
    # squaring number only requires slightly more exponent bits
    W = W.astype(jnp.bfloat16)
    # need more mantissa for accum
    f2 = jnp.sum(jnp.square(W), dtype=jnp.float32)
    s4 = jnp.sum(jnp.square(W.T @ W), dtype=jnp.float32)
    return m * s4 / (f2**2)
```

### Hyperbolic SVD

```python
import jax, jax.numpy as jnp, optax

def ldot(x, y):  # Lorentzian product
    return -x[..., 0]*y[..., 0] + jnp.sum(x[..., 1:]*y[..., 1:], -1)

def project(x, eps=1e-8):  # ℍ_{-1} projection (x₀ > 0)
    n = jnp.where(ldot(x, x) >= -eps, -eps, ldot(x, x))
    x = x / jnp.sqrt(-n)[..., None]
    return jnp.where(x[..., :1] < 0, -x, x)

def _init(k, shape):
    z = jax.random.normal(k, shape)
    z = jnp.concatenate([jnp.ones(shape[:-1] + (1,)), 0.01*z[..., 1:]], -1)
    return project(z)

def hyperbolic_svd(W, d=2, steps=4000, lr=1e-3, key=jax.random.PRNGKey(0)):
    m, n = W.shape
    # make target strictly > 0 (=> in range of -⟨U,V⟩_L R²)
    w_min = jnp.min(W)
    shift = -w_min + 0.1 if w_min <= 0 else 0.0  # ensure positivity
    scale = jnp.max(W + shift)
    T = (W + shift) / scale + 1.0  # T ∈ [1, 2]

    k1, k2 = jax.random.split(key)
    U, V = _init(k1, (m, d+1)), _init(k2, (n, d+1))
    log_R = jnp.zeros(())  # R ≈ 1

    opt = optax.chain(optax.clip_by_global_norm(1.), optax.adam(lr))
    state = opt.init((U, V, log_R))

    def loss(p):
        U, V, lR = p
        R = jnp.exp(jnp.clip(lR, -4.0, 4.0))  # 0.018 ≤ R ≤ 54.6
        pred = R**2 * (-ldot(U[:, None, :], V[None, :, :]))
        return jnp.mean((pred - T) ** 2)

    @jax.jit
    def step(p, s):
        val, grads = jax.value_and_grad(loss)(p)
        upd, s = opt.update(grads, s, p)
        p = jax.tree_util.tree_map(lambda a, b: a + b, p, upd)
        p = (project(p[0]), project(p[1]), jnp.clip(p[2], -4.0, 4.0))
        return p, s, val

    params = (U, V, log_R)
    for _ in range(steps):
        params, state, _ = step(params, state)

    U, V, lR = params
    R = jnp.exp(lR)
    curvature = -1.0 / (R**2)
    rmse = jnp.sqrt(loss(params)) * scale  # back-scale distortion
    return U, V, curvature, rmse


# ---- example usage ----
key = jax.random.PRNGKey(42)
W = tensor_collections['mlp.gate'][0]
W0 = W / jnp.max(jnp.abs(W))  # normalize before fitting
(U, V, curv, dist) = hyperbolic_svd(W0, d=16, steps=4000)
print("Original: curvature =", float(curv), "| distortion =", float(dist))
```
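### Shuffle Baseline

For completeness, a minimal sketch of the shuffle test from the empirics section: permute every entry of a weight matrix and check that $k_\text{eff}$ collapses back towards the $1+m/n$ baseline. This isn't the exact script behind the plots above, just a small version that reuses `k_eff` from this section (and the same `tensor_collections` lookup as the example usage).

```python
import jax

def shuffled_k_eff(W, key=jax.random.PRNGKey(0)):
    # k_eff after randomly permuting all entries of W (destroys any structure)
    flat = jax.random.permutation(key, W.reshape(-1))
    return k_eff(flat.reshape(W.shape))

W = tensor_collections['mlp.gate'][0]          # same example matrix as above
m, n = W.shape

print("trained :", float(k_eff(W)))            # reliably above baseline
print("shuffled:", float(shuffled_k_eff(W)))   # ~ 1 + m/n
print("baseline:", 1 + m / n)
```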