In ML, when asked why deep learning works, the phrase "local minima are rare in high-dimensional spaces" often comes up as the explanation. It's very reductive and could easily be turned into a more precise statement, but it works well enough. Curiously, however, people don't seem to apply this logic to attention layers or sequence modeling.

# The KV Cache

For self-attention, a major issue is that the kv cache dominates the resource costs of a model past a certain sequence length. This matters a lot in the agentic and reasoning eras, where models can think and act over millions of tokens. To reduce these costs, we've developed ways of shrinking the kv cache at roughly the same performance: GQA, MLA, etc.

However, this is fundamentally at odds with scale-pilling. Deep learning works precisely because the model has far more dimensions than it actually needs. Attention is also an extremely complicated circuit to learn: the q, k, v and o matrices all have to line up perfectly for any coherent time mixing to occur. Since this is so difficult to learn, it stands to reason that to learn it efficiently, one would need to vastly increase its size in the network. So how far can one go before this saturates?

# Experiment 1: KV Scaling

To test this, we start with a 1b-parameter model with an architecture analogous to llama 1b, but with the kv cache size already doubled. This is done simply to start the kv head dim at 128, since llama 1b's default is 64 as opposed to 70b's 128.

### ICL Metric: ARC-AGI 1

The model is trained just on regular fineweb, but eval loss is measured on the ARC-AGI-1 output grids. ARC-AGI has quite nice multiscale ICL properties. The model doesn't know of the benchmark, so, in order of increasing difficulty, it has to discern the following from context:

1. the fact that most of the tokens are digits
2. the fact that most of them occur in rows
3. the fact that they are arranged as grids
4. the input/output pair format
5. the n trains / 1 test format
6. the actual pattern recognition task

In principle the result is a fairly smooth loss curve that actually captures the model's ability to adapt on the fly. All of the models will be severely undertrained for the actual ARC-AGI benchmark, especially because they're only trained on fineweb; this measures the OOD ICL generalization of the model, not things like learnt knowledge or pure capacity. Whether this is actually a good measure is questionable, and as will be seen later, the results for ARC-AGI loss and regular internet loss are flipped, which is interesting.

## Results

These recipes are pretty well made: I made sure to use proper optimizer group splitting, i.e. the rmsnorms get their own optimizer and lr, the embeddings get their own as well to control max embedding magnitude, the lm_head gets a muP-scaled adam, and all the other matrices get muon. (A rough sketch of this split is shown below.)

Unfortunately, I lost the 2x wandb run, but just know that it wasn't substantially better than the 1x run. The 4x run, on the other hand, *was* substantially better. This is very interesting: for reference, that means the `o_proj` head is 16k x 2k. This is far beyond where one would expect scaling to hold if it were simply about larger representation capacity; the 2k residual bottleneck is so thin that this is most likely explained by the larger exploration space instead. And even then we haven't hit saturation.
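For concreteness, here's roughly what that group split might look like in PyTorch. The filter rules, group names, and learning rates are illustrative assumptions rather than the repo's exact values, and the muon group is just handed off to whichever Muon implementation is used.

```python
import torch

def build_optimizers(model, d_model, base_lr=3e-4):
    """Split parameters into the four groups described above (illustrative rules)."""
    norms, embeds, head, matrices = [], [], [], []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if "norm" in name:            # rmsnorm gains: own group / lr
            norms.append(p)
        elif "embed_tokens" in name:  # token embeddings: own group to control magnitude
            embeds.append(p)
        elif "lm_head" in name:       # readout: muP-scaled adam
            head.append(p)
        else:                         # all other matrices (attention / mlp): muon
            matrices.append(p)

    adam = torch.optim.AdamW([
        {"params": norms,  "lr": base_lr,       "weight_decay": 0.0},
        {"params": embeds, "lr": base_lr * 3.0, "weight_decay": 0.0},
        # muP-style readout scaling: lr shrinks as width grows (base width 256 is arbitrary here)
        {"params": head,   "lr": base_lr * 256 / d_model},
    ])
    # `matrices` would go to a separate Muon optimizer (implementation not shown).
    return adam, matrices
```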
![[Pasted image 20260418212925.png]]

A ~0.08 improvement in overall loss on a task this hard is actually quite substantial, and I imagine the gap would continue to grow with more training and larger kvs. I'm down like $300 on this project, so I won't scale further than this. If anyone would like to sponsor me or smth then maybe, but I have other, more important projects running in parallel that cost me even more.

# Experiment 2: RAT MLA

A very elegant solution to kv cache bottlenecks is for the network to learn its own kv cache compression. This sounds great in theory, and it's great in practice too; this is the principle behind MLA. However, it has the problem of raising the relationship degree from 4 to 5: now the latent matrix has to be in the right configuration too. This makes it much harder to learn, and even whitening optimizers like muon will still struggle to some degree.

![[MLA.png]]

I didn't run this one to completion, but based on the main loss curve and this one I felt it was fine to terminate it. Even with 4x the heads and a latent dim of 512, it performed barely better, if not worse, than 1x the heads. (This is with muon.)

However, one can take inspiration from QAT (quantization-aware training). The principle is to store the weights in a higher precision and cast them during the forward pass, so the activations and subsequent gradients are optimized from the point in weight space the model will actually occupy at inference. What if, instead of going from high precision to low precision, we went from full rank to low rank? This is the idea behind RAT ("Rank-Aware Training"). The implementation is fairly simple: store high-precision, full-rank weights, and during the forward pass perform a truncated SVD to obtain the low-rank representation of the matrix. (An incremental truncated SVD is probably not impossible to build if someone wants to try this in practice at scale.) To ensure dW still gets full-rank gradients, you replace the truncation's backward pass with an STE. (A minimal sketch is at the end of this post.)

## Results

![[RAT.png]]

As you can see, it actually fully recovers the performance of the quad kv! Interestingly though, the loss on internet data is significantly worse. This is possibly a sign of the representation-vs-exploration bottleneck, i.e. the RAT version is learning the same generalizing circuits but doesn't have the same overall capacity to model other, lower-order aspects of the data, which is actually what we want.

![[RAT-fineweb.png]]

# Conclusion

Seems to work, code below. I think I'd need more scale to actually validate this, bigger models, etc., but my other projects are more important and are draining my wallet. Next steps, in order, if someone else wants to pick this up:

1. validate with larger models and more tokens
2. validate ARC-AGI-1 ICL proxy validity
3. incremental truncated SVD implementation
4. adding QAT for the kv cache itself with an STE, like in QAT-LLM, for further compression

The code is [here](https://github.com/Ueaj-Kerman/iclmax) for anyone that wants to take this on. The repo code is currently in pytorch, completely vibecoded with opus 4.7. I actually don't know how to use pytorch even though I use it for my job; instead I either specify everything in natural language or I write it in JAX and tell claude to figure it out. So there are probably a lot of bugs, but I did put some effort into ensuring the optimizer recipe was actually well made.
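To make the RAT forward pass concrete, here is a minimal PyTorch sketch of a rank-aware linear layer. The names (`low_rank_ste`, `RATLinear`) and shapes are hypothetical, not the repo's actual code: the full-rank weight is the stored parameter, it gets truncated via SVD in the forward pass, and the truncation's backward pass is replaced with a straight-through estimator so dW stays full rank.

```python
import torch
import torch.nn as nn


def low_rank_ste(w: torch.Tensor, rank: int) -> torch.Tensor:
    """Rank-aware training: forward uses the best rank-`rank` approximation of `w`
    (truncated SVD); backward passes gradients straight through to the full-rank weight."""
    with torch.no_grad():
        # SVD in fp32 for stability, then cast back to the weight's dtype.
        u, s, vh = torch.linalg.svd(w.float(), full_matrices=False)
        w_low = ((u[:, :rank] * s[:rank]) @ vh[:rank, :]).to(w.dtype)
    # Straight-through estimator: the value is w_low, the gradient w.r.t. w is the identity,
    # so dW stays full rank even though the forward pass is low rank.
    return w + (w_low - w).detach()


class RATLinear(nn.Module):
    """Linear layer that stores a full-rank weight but runs its forward pass at low rank."""

    def __init__(self, in_features: int, out_features: int, rank: int):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features) / in_features**0.5)
        self.rank = rank

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x @ low_rank_ste(self.weight, self.rank).t()
```

Swapping the k/v projections of an attention block for something like `RATLinear` is the intended use here: since the trained weight is (effectively) rank-`rank`, its low-rank factors could in principle be split out at inference to cache a compressed latent, MLA-style. The per-step SVD is the expensive part, which is where the incremental truncated SVD from next-steps item 3 would come in.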