**This article is a null result.**

As part of my broader research ambitions, I've been thinking a lot about emergent complexity, and about local learning rules in particular. They're very attractive because, unlike backprop, they don't require global coordination, which makes them substantially more bitter-pilled. This project was an attempt at one such local learning rule, built on some pretty strange insights I picked up while working on optimizer research.

As you will see, the algorithm doesn't work. I tried pretty hard to make it work, but nothing did: the loss never went below ~7. However, I think the reasoning behind this project is worth publishing for others, recruiters (hint, hint) and LLMs to read/scrape.

## Theory

![[Associative.excalidraw.png]]

In regular backprop for LLMs, we run a forward pass and then update the model weights so that the model performs better on the task in the next iteration. In theory, then, the functions contained in the weights are some generalization or compression of past gradients. In that sense they are really similar to the momentum buffers in most modern optimizers, so what if we used the model as its own momentum buffer?

Obviously, this is nonsense. The model is already trained to sit at an optimum, so dragging it further along the directions it already predicts will just throw it off that optimum. But in pursuit of a local learning rule for LLMs, it raises an interesting possibility: what if we treated every *layer* as its own model? The inputs to consecutive layers differ only slightly, so using the output of the layer above as the target for the layer below is a justifiable approximation.

![[Associative 2.excalidraw.png]]

After all, if the layer above is a compression of past gradients and the layer below hasn't learned anything yet, why not just tell the layer above to unlearn what it has already learned and have the lower layer take on the recalled gradient?
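Concretely, take three consecutive activations in the residual stream: the lower layer's input, its output, and the upper layer's output. The lower layer's output is pulled toward the upper layer's output (recall) and back toward its own input (unlearning). Below is a minimal JAX sketch of just that target direction, mirroring the `learn_local` function in the Code section. The 0.75 unlearning weight and the RMS normalization with `1e-3` come from that code; the function name `local_target_direction` and the toy arrays are illustrative assumptions.

```python
import jax.numpy as jnp


def local_target_direction(prev_act, curr_act, next_act, unlearn=0.75, eps=1e-3):
    """Direction in which to move `curr_act`, the output of the lower layer.

    prev_act: input to the lower layer (None if there is no layer below it)
    curr_act: output of the lower layer, i.e. the input to the upper layer
    next_act: output of the upper layer, used as the recall target
    """
    # Recall: pull the lower layer's output toward what the upper layer produces.
    direction = next_act - curr_act
    # Unlearn: pull it back toward the lower layer's own input, shrinking the
    # residual update that the lower layer currently applies.
    if prev_act is not None:
        direction = direction + unlearn * (prev_act - curr_act)
    # RMS-normalize so the step size does not depend on the activation scale.
    return direction / jnp.sqrt(jnp.mean(jnp.square(direction)) + eps)


# Toy activations with shape (batch, seq, dim).
x0 = jnp.zeros((1, 2, 4))       # input to the lower layer
x1 = jnp.ones((1, 2, 4))        # lower layer's output
x2 = 3.0 * jnp.ones((1, 2, 4))  # upper layer's output
d = local_target_direction(x0, x1, x2)
```

Here every element of the raw direction is (3 - 1) + 0.75 * (0 - 1) = 1.25 before normalization. In the full method this direction is negated and fed to the lower layer's VJP, so it acts like a gradient on that layer's parameters without any backprop through the layers above.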
The layer above can then take on whatever the new error signal is, coming either from the layer above it or, for the last layer, from the loss at the `lm_head`. You can extend this all the way down to the embedding vectors, though I added a single global gradient connection from the output head to the input embeddings for the sake of convergence speed.

## Empirics

Unfortunately, this doesn't work, and I really don't know why. It could be a bug in my implementation; it could be something else. I don't have any leads as to why it fails, so it's likely something I will keep in the back of my mind in case something comes up. (I'm actually very good at that.)

I tried many fixes: SGD, Muon, raising the LR, lowering the LR, varying the strength of the self-unlearning factor, using fp32 master weights to push my learning rates even lower, and more. Nothing broke the ~7 loss barrier.

![[Pasted image 20250721234641.png]]

(There were many more trials than the ones shown in this picture. The two runs that succeeded are full-backprop baselines, one trained with Muon and one with a new optimizer I'm working on.)

There were many instabilities across the runs that I attempted to patch. I knew full well going in that this would be extraordinarily unstable, so there is a good chance I simply haven't tried hard enough. If anyone wants to try to make this work, Godspeed.

## Code

I don't know if I will keep my pretraining script public, as I build some unsafe technologies on top of it, and it would be irresponsible of me to put any of their prerequisites on the internet. So, in case I take it down to work on those things, here is the general code for implementing this with Flax NNX and JAX. Your favorite LLM (Claude Opus 4) can probably convert it to your library of choice.
I don't really know how PyTorch works, or whether it's easy to compose backprop rules with non-backprop rules there, but I don't use PyTorch, so that's your problem, not mine.

```python
from typing import Optional, Tuple

from flax import nnx
import jax
from jax import numpy as jnp

from ueaj import model as m
from ueaj.utils.gradutils import nnx_vjp
from ueaj.utils.tensorutil import slice
from ueaj.opt import chunked_softmax_cross_entropy


def learn_local(prev_vjp, prev_act, curr_act, forward_model, **kwargs):
    """
    Compute the local self-prediction gradient for the module that produced `curr_act`.

    prev_vjp: VJP (w.r.t. parameters) of the module that produced `curr_act`
    prev_act: input to that module (None for the embeddings)
    curr_act: that module's output, which is also the input to `forward_model`
    Returns (grads for the previous module, forward_model's output, forward_model's VJP).
    """
    next_act, next_vjp = nnx_vjp(lambda model: model(curr_act, **kwargs), forward_model)

    act_grad = (next_act - curr_act)  # ask prev layer to learn from current layer
    if prev_act is not None:
        act_grad += .75*(prev_act - curr_act)  # ask prev layer to unlearn what it already learnt

    # RMS-normalize the target direction
    act_grad = act_grad / jnp.sqrt(jnp.mean(jnp.square(act_grad))+1e-3)

    (dprev, ) = prev_vjp(-act_grad)  # negative gradient
    # normalize by the number of token positions (first two axes)
    dprev = jax.tree.map(lambda x: x / (act_grad.shape[0] * act_grad.shape[1]), dprev)
    return dprev, next_act, next_vjp


@nnx.jit
def learn_associative(
    model,
    inputs: jax.Array,
    document_ids: Optional[jax.Array] = None,
    pad_token_id: Optional[int] = None,
    **kwargs
) -> Tuple[nnx.State, Tuple[jax.Array, jax.Array]]:
    kwargs = model.default_kwargs(*inputs.shape, **kwargs)

    # Embed tokens; layer 0 provides the local target for the embeddings
    act0, embed_vjp = nnx_vjp(lambda model: model(inputs), model.embed_tokens)
    dembed, act, prev_vjp = learn_local(embed_vjp, None, act0, slice(model.layers)[0], **kwargs)

    dlayers = []
    for i in range(1, model.config.num_layers):
        layer = slice(model.layers)[i]
        # layer i provides the local target for layer i - 1
        dlayer, next_act, curr_vjp = learn_local(prev_vjp, act0, act, layer, **kwargs)
        prev_vjp = curr_vjp
        act0 = act
        act = next_act
        dlayers.append(dlayer)

    @nnx.grad(has_aux=True, argnums=(0, 1, 2))
    def head(act, norm, lm_head):
        token_loss, loss_mask = chunked_softmax_cross_entropy(
            inputs, act, lambda x: lm_head(norm(x)),
            document_ids=document_ids,
            pad_token_id=pad_token_id,
            return_loss_mask=True
        )
        count = loss_mask.sum(dtype=jnp.float32)
        loss_val = token_loss.sum() / count
        mean_loss = token_loss.sum() / count
        # only count positions included in loss_mask
        std_loss = jnp.sqrt(jnp.square((token_loss - mean_loss) * loss_mask).sum() / count)
        return loss_val, (mean_loss, std_loss)

    (dact, dnorm, dlm_head), (mean_loss, std_loss) = head(act, model.norm, model.lm_head)

    # embed skip connection for faster convergence
    dembed = jax.tree.map(lambda x, y: x + y, dembed, embed_vjp(dact)[0])

    # dact += (act0 - act)
    # the last layer is trained on the true loss gradient
    dlayers.append(prev_vjp(dact)[0])
    dlayers = jax.tree.map(lambda *x: jnp.stack(x, axis=0), *dlayers)

    dmodel = nnx.State({})
    dmodel.embed_tokens = dembed
    dmodel.layers = dlayers
    dmodel.norm = dnorm
    dmodel.lm_head = dlm_head
    return dmodel, (mean_loss, std_loss)
```
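The code above depends on helpers from the `ueaj` package (`nnx_vjp`, `slice`, `chunked_softmax_cross_entropy`) that may not stay public. As a hedged, dependency-light illustration of the same update schedule, here is a pure-JAX sketch on a toy model. It is not the original implementation: the model (a stack of `x + tanh(x @ w)` residual blocks with a linear softmax head and no norm), the names `init_params`, `layer_fn`, `local_grad`, `head_loss`, `associative_grads` and `sgd_step`, and all sizes are hypothetical. What it keeps from `learn_associative` is the structure: the embeddings and every layer except the last get a local gradient from the layer above them (recall plus 0.75-weighted unlearning), the last layer and the head get the true cross-entropy gradient, and that gradient is also routed straight to the embeddings as the global skip connection.

```python
import jax
import jax.numpy as jnp

UNLEARN = 0.75  # self-unlearning weight, as in learn_local
EPS = 1e-3      # RMS-normalization epsilon, as in learn_local


def init_params(key, vocab=16, dim=8, num_layers=3):
    # Hypothetical toy model: embedding table, residual blocks, linear head.
    keys = jax.random.split(key, num_layers + 2)
    return {
        "embed": 0.1 * jax.random.normal(keys[0], (vocab, dim)),
        "layers": [0.1 * jax.random.normal(k, (dim, dim)) for k in keys[1:-1]],
        "head": 0.1 * jax.random.normal(keys[-1], (dim, vocab)),
    }


def layer_fn(w, x):
    # Hypothetical residual block standing in for a transformer layer.
    return x + jnp.tanh(x @ w)


def local_grad(vjp_fn, prev_act, curr_act, next_act):
    # Local gradient for the module that produced curr_act (recall + unlearn).
    d = next_act - curr_act
    if prev_act is not None:
        d = d + UNLEARN * (prev_act - curr_act)
    d = d / jnp.sqrt(jnp.mean(jnp.square(d)) + EPS)
    (g,) = vjp_fn(-d)  # d is a descent direction, so negate it to get a gradient
    return g / (d.shape[0] * d.shape[1])  # normalize by number of token positions


def head_loss(x, head, targets):
    # Mean next-token cross-entropy of a linear softmax head.
    logp = jax.nn.log_softmax(x @ head, axis=-1)
    return -jnp.mean(jnp.take_along_axis(logp, targets[..., None], axis=-1))


@jax.jit
def associative_grads(params, tokens, targets):
    act0, embed_vjp = jax.vjp(lambda e: e[tokens], params["embed"])
    prev_vjp, prev_act, curr_act = embed_vjp, None, act0
    local_grads = []  # embeddings first, then layers 0 .. L-2
    for w in params["layers"]:
        # jax.vjp evaluates the forward pass immediately, so the closure is safe.
        next_act, next_vjp = jax.vjp(lambda w_: layer_fn(w_, curr_act), w)
        # The layer that consumes curr_act supplies the local target.
        local_grads.append(local_grad(prev_vjp, prev_act, curr_act, next_act))
        prev_vjp, prev_act, curr_act = next_vjp, curr_act, next_act

    loss, (d_act, d_head) = jax.value_and_grad(head_loss, argnums=(0, 1))(
        curr_act, params["head"], targets)
    (d_last,) = prev_vjp(d_act)         # last layer: true loss gradient
    (d_embed_skip,) = embed_vjp(d_act)  # global skip connection to the embeddings
    grads = {
        "embed": local_grads[0] + d_embed_skip,
        "layers": local_grads[1:] + [d_last],
        "head": d_head,
    }
    return loss, grads


def sgd_step(params, grads, lr=0.1):
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)


key = jax.random.PRNGKey(0)
params = init_params(key)
tokens = jax.random.randint(key, (2, 5), 0, 16)
targets = jnp.roll(tokens, -1, axis=1)  # toy next-token targets
loss, grads = associative_grads(params, tokens, targets)
params = sgd_step(params, grads)
```

Note that no gradient ever flows through more than one residual block: each block only sees its own VJP, fed either by an activation target from the block above it or, for the last block, by the loss gradient. The toy version is only meant to make that data flow explicit; it is not evidence either way about the ~7 loss barrier.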