**This article is a null result**

**Backpropagation 4 - \_ueaj 1**

Well shit, here we go again.

# Why are we here again

![[trinity.excalidraw.png]]

In the [[Machine Learning/Pretraining/BPTT got hands|last article]] we discussed why we must invent surrogate gradient methods if we want to dethrone the transformer. This matters for 2 important reasons:

- lowering the cost of inference
- allowing sequential composition across the sequence length

(No, we will not solve continual learning with just any linear-time architecture; it is a vastly more complex problem than this.)

The first point is fairly obvious: lowering the cost of inference is a noble and important goal. That said, even if we do one day succeed in producing a viable surrogate gradient method, I don't think we should make non-hybrid models with it. My prior is that exact recall is actually a good inductive bias for real-world text: while most text is written sequentially, longer contexts are organized around global relationships.

The second point is the bigger issue. The transformer as it stands operates in parallel: every token attends to every other token independently. Tokens don't interfere with each other the way they do in, for example, an LSTM. This limits the model's ability to "comprehend" or "synthesize" information; the only way it can do so is by relying on model *depth* to compose information sequentially. This works fine at, say, ~64k context length, but at a million tokens of context you can't meaningfully process the context with model depth alone. You need to be able to compute *across* the whole sequence.

# Convergence Criteria

![[convergence.excalidraw.png]]

One minor prerequisite for understanding this post and its reasoning: convergence criteria. If you've ever taken a machine learning class, you'll have learned about convergence criteria for optimizers. The basic one is "if the cosine similarity between your gradient estimate and the true gradient is > 0, it'll converge... eventually". There are others too (the Wolfe conditions, iirc), but it's been 3(?) years so I don't remember the details.

In my head, a surrogate gradient is just a really fancy optimizer: instead of applying increasingly complex transformations to momentum buffers, it just makes up the answer via some approximation. As long as the cosine similarity condition holds, it should work, since the theorem doesn't distinguish between the two. I haven't seen anyone prove this, but I decided to try using the metric to see whether it had any predictive power. Unfortunately, it doesn't seem to, and I don't know why.

# Round 2

The fundamental principle behind the new surrogate method works like this:

![[forward1.excalidraw.png]]

First, we walk forward through the sequence as usual; on this pass we can actually compute $dQ$ exactly. Additionally, we accumulate the gradient for the hidden state along the sequence length.

![[pass2.excalidraw.png]]

When we reach the end, this buffer is essentially an accumulated gradient for the whole sequence, and importantly, if you *subtract* the first token's state gradient, you get the gradient for the rest of the sequence. We accumulate this buffer in fp32 so it holds up over longer contexts; otherwise floating-point error would make the approach unviable.

This essentially gives you a first-order approximation of the hidden-state gradient for the rest of the sequence. Because it is first-order, it doesn't teach the model sequential composition directly, but it will teach it over multiple batches. This is because we take the gradient w.r.t. the hidden state at the given token: once the model is setting the state in a particular way, it later learns how to enhance that state further.

Anyways, from this approximate gradient you can derive dk and dv as whatever moves the current state in the direction of all the future gradients.
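To see why the subtract-as-you-go trick is reasonable, here's a minimal JAX sketch on a toy memory that is *not* the TTT layer from this post: plain additive linear attention, S_t = S_{t-1} + v_t k_t^T and o_t = S_t q_t. The helper names `forward`, `surrogate_kv_grads` and `exact_kv_grads` are hypothetical. Because this update's Jacobian w.r.t. the previous state is the identity, the first-order surrogate is exact here and can be checked against `jax.vjp`; adding a `decay` factor below 1 (a stand-in assumption for the weight decay in the real update) breaks that exactness.

```python
import jax
import jax.numpy as jnp


def forward(k, v, q, s0, decay=1.0):
    """Toy memory (an assumption, not the post's layer):
    S_t = decay * S_{t-1} + v_t k_t^T, then o_t = S_t q_t."""
    def step(s, x):
        kt, vt, qt = x
        s = decay * s + jnp.outer(vt, kt)
        return s, s @ qt
    _, o = jax.lax.scan(step, s0, (k, v, q))
    return o


def surrogate_kv_grads(k, v, q, do):
    """Two-pass surrogate that treats every state-to-state Jacobian as the identity."""
    # Pass 1: each token's direct state gradient do_t q_t^T, summed over the sequence in fp32.
    direct = jnp.einsum("ti,tj->tij", do, q).astype(jnp.float32)  # (T, d_v, d_k)
    total = direct.sum(axis=0)

    # Pass 2: walk forward; accum = total minus all earlier tokens = sum over j >= t.
    def step(accum, x):
        d_t, k_t, v_t = x
        dk_t = accum.T @ v_t  # dL/dk_t, since S += v_t k_t^T
        dv_t = accum @ k_t    # dL/dv_t
        return accum - d_t, (dk_t, dv_t)

    _, (dk, dv) = jax.lax.scan(step, total, (direct, k, v))
    return dk, dv


def exact_kv_grads(k, v, q, do, decay=1.0):
    """Reference gradients by differentiating straight through the scan."""
    s0 = jnp.zeros((v.shape[-1], k.shape[-1]))
    _, vjp_fn = jax.vjp(lambda k_, v_: forward(k_, v_, q, s0, decay), k, v)
    return vjp_fn(do)


kk, kv, kq, kd = jax.random.split(jax.random.PRNGKey(0), 4)
T, d_k, d_v = 16, 8, 4
k = jax.random.normal(kk, (T, d_k))
v = jax.random.normal(kv, (T, d_v))
q = jax.random.normal(kq, (T, d_k))
do = jax.random.normal(kd, (T, d_v))

dk_s, dv_s = surrogate_kv_grads(k, v, q, do)
for decay in (1.0, 0.5):
    dk_e, dv_e = exact_kv_grads(k, v, q, do, decay)
    print(decay, jnp.allclose(dk_s, dk_e, atol=1e-3), jnp.allclose(dv_s, dv_e, atol=1e-3))
```

With the post's actual update (a gradient step with weight decay and tanh clipping), the state Jacobian is only close to the identity, and that is exactly the gap the first-order argument has to cover.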
This kinda sounds like nonsense, until you remember that each individual token imparts a very tiny amount of information onto the hidden state, so the first-order approximation is fine. The more you scale up the hidden state, the less interference there is, and in fact this is what we observe!

**Dimension 768**

![[plotsw768.png]]

**Dimension 256**

![[plotsw256.png]]

Both hidden states are really big and the sequence length is fairly small, so the difference is hard to see here, but it's there, I promise.

I want to highlight how non-trivial this kind of research is. Knowing why and how scale will simply fix the problem is an informal skill; it's just intuition and taste, and I think that's one of the most important things to have as a researcher. Ilya has talked about this.

I also made a surrogate for the gradient w.r.t. the input state (dstate): I batched all the Qs and dOs, computed the gradient w.r.t. the input state as if every query read it directly, and figured that was probably good enough. That's the plot in the corner. It kind of works, probably well enough to chain it with regular backprop if needed.

The second optimization goes the other way: cut the sequence into blocks, do true backprop within each block, and use the surrogate method across blocks. The block boundaries are the vertical lines in the plots. This marginally improves accuracy, but importantly it teaches the network true sequential composition across the blocks, which helps accelerate learning.

# Results

Unfortunately, it doesn't work :(

![[pretrain_bptt2.png]]

The purple curve is the baseline with no attention mechanism at all; as you can see, it bottoms out at around 6.8. If our attention mechanism can push the loss below that threshold, it is actually making use of the context. Unfortunately, it just seems to diverge.
This may just be some kind of instability, even with all the norms, but I need to build infrastructure that lets me iterate on this rapidly; right now it's slow as fuck. I also checked whether the q, k, v gradient *magnitudes* were correct, not just the cosine similarity, and they were, so I'm not sure what's wrong here.

# Conclusion

I don't know why this failed. I think I'll take a break from surrogate methods until I can form better intuitions about them. I've certainly gotten much better at this, and it's clear to me they are a fundamental need ([[BPTT got hands|BPTT blog]]).

I don't view TTTs or any linear-time architecture as a viable path to continual learning (I think that's a vastly more complex problem), but I would like to lower the cost of inference as much as I can. For the record, even if we make TTT training tractable, I don't think we should make non-hybrid models with it.

It was a good project. This was the first time I tried using the dot-product metric / relying on convergence criteria, which is what I was posting about before. Unfortunately, it doesn't seem to work in practice. To be clear, this wasn't the best I can do: I have more "intuition-complete" ideas too, but this one was easier to test and could act as a signal for whether to put effort into the harder ones. It didn't work, and I don't know why, so I'll move on to other things and think about it in the background.

If all your ideas succeed, you aren't at your frontier. Push harder 🫡

# Code

The code is on my GitHub [here](https://github.com/Ueaj-Kerman/macrogpt-jax/tree/surrogate-ttt-v3); the important stuff is [here](https://github.com/Ueaj-Kerman/macrogpt-jax/tree/surrogate-ttt-v3/ueaj/model/ttt).
Here's the most relevant part of the code, the part that actually implements the surrogate in JAX:

```python
import functools

import jax
import jax.numpy as jnp


def make_update_fn(fwd_fn, n_iters, wd, lr):
    """Create the state update function (gradient descent on reconstruction loss)."""
    def update_fn(state, k, v):
        for _ in range(n_iters):
            v_pred, dstate_fn = jax.vjp(lambda state: fwd_fn(state, k), state)
            dv = v - v_pred
            dstate, = dstate_fn(dv)
            # SGD with weight decay and tanh gradient clipping
            # jax.tree.map_with_path(lambda k, v: print(f"{k}: {v.dtype}"), dstate)
            state = jax.tree.map(lambda a, b: (1 - wd * lr) * a + lr * jax.nn.tanh(b), state, dstate)
        return state
    return update_fn


def make_scan_fn(fwd_fn, n_iters, wd, lr):
    """Create scan function: update state first, then query."""
    update_fn = make_update_fn(fwd_fn, n_iters, wd, lr)

    def scan_fn(state, x):
        k, v, q = x
        # Update state first (based on k, v)
        new_state = update_fn(state, k, v)
        # Then query the updated state
        o = fwd_fn(new_state, q)
        return new_state, o

    return scan_fn


def ttt(fwd_fn, surrogate=True, n_iters=1, wd=.1, lr=.005, block_size=None):
    """Create TTT layer with optional surrogate gradients.

    The forward pass:
    1. Update state using (k, v) via gradient descent
    2. Query updated state with q to produce output

    Returns: (output, final_state) tuple

    The surrogate backward pass uses a two-pass approach for k/v gradients,
    and a batched computation for the init_state gradient.
    """
    fwd_scan = make_scan_fn(fwd_fn, n_iters, wd, lr)
    update_fn = make_update_fn(fwd_fn, n_iters, wd, lr)
    update_fn_ref = update_fn

    reshape = lambda v: v
    unshape = lambda v: v
    if block_size is not None:
        assert block_size > 0, "block_size must be positive"

        def reshape_fn(v):
            assert v.shape[0] % block_size == 0, f"seq len ({v.shape[0]}) must be a multiple of block_size ({block_size})"
            return v.reshape((v.shape[0] // block_size, block_size) + v.shape[1:])

        def unshape_fn(v):
            assert v.shape[1] == block_size, f"tensor block size ({v.shape[1]}) must be equal to block_size ({block_size})"
            return v.reshape((v.shape[0] * block_size,) + v.shape[2:])

        reshape = lambda v: jax.tree.map(reshape_fn, v)
        unshape = lambda v: jax.tree.map(unshape_fn, v)

        # fwd_fn = jax.vmap(fwd_fn, in_axes=(None, 0))
        # Each outer scan step now processes a whole block with an inner scan
        fwd_scan = functools.partial(jax.lax.scan, fwd_scan)

        def update_wrapped(state, k, v):
            return jax.lax.scan(lambda state, kv: (update_fn_ref(state, *kv), None), state, (k, v))[0]
        update_fn = update_wrapped

    if not surrogate:
        fwd_scan = jax.remat(fwd_scan, policy=jax.checkpoint_policies.nothing_saveable)

    def _ttt_fwd(k: jax.Array, v: jax.Array, q: jax.Array, state):
        final_state, o = jax.lax.scan(fwd_scan, state, reshape((k, v, q)))
        o = unshape(o)
        return (o, final_state), (k, v, q, state, final_state)

    def _ttt(k: jax.Array, v: jax.Array, q: jax.Array, state):
        return _ttt_fwd(k, v, q, state)[0]

    if not surrogate:
        return jax.vmap(_ttt, in_axes=(0, 0, 0, None))

    # Reference to single-token scan function (before any blocking wrappers)
    fwd_scan_single = make_scan_fn(fwd_fn, n_iters, wd, lr)

    def _ttt_bwd(res, g):
        do, d_final_state = g  # Unpack cotangent tuple
        k, v, q, init_state, final_state = res
        native_dtype = jax.tree.leaves(init_state)[0].dtype

        # ==========================================
        # PASS 1: Forward through sequence, accumulating dState sum in fp32
        # Computes dstate_i = d(output_i)/d(state_i) for each token.
        # ==========================================
        def pass1_token_scan(carry, x):
            """Accumulate dstate for a single token."""
            state, dstate_accum = carry
            ki, vi, qi, doi = x
            # has_aux: differentiate the output, carry the updated state along as aux
            _, state_vjp_fn, state = jax.vjp(
                lambda s: tuple(reversed(fwd_scan_single(s, (ki, vi, qi)))),
                state, has_aux=True
            )
            dstate, = state_vjp_fn(doi)
            dstate_accum = jax.tree.map(
                lambda acc, ds: acc + ds.astype(jnp.float32), dstate_accum, dstate
            )
            return (state, dstate_accum), None

        # Initialize with d_final_state (gradient flowing from next block)
        init_accum = jax.tree.map(lambda x: x.astype(jnp.float32), d_final_state)
        (_, total_dstate), _ = jax.lax.scan(
            pass1_token_scan, (init_state, init_accum), (k, v, q, do)
        )

        # ==========================================
        # PASS 2: Distribute gradients to k/v with EXACT direct + SURROGATE indirect
        #
        # For token i, the gradient has two components:
        # 1. DIRECT: output_i → k_i (single-step, computed exactly)
        # 2. INDIRECT: output_j → k_i for j > i (multi-step, surrogate approx)
        #
        # Key insight: Direct path should NOT go through the extra J_upd factor
        # that the surrogate dstate accumulation introduces.
        # ==========================================

        # Token-by-token surrogate VJP (for block_size=None case)
        def compute_kvq_grads_surrogate(si, ki, vi, qi, doi, accum_i):
            """Single VJP that combines direct (from do) and indirect (from accum) gradients."""
            accum_native = jax.tree.map(lambda x: x.astype(native_dtype), accum_i)
            _, scan_vjp = jax.vjp(
                lambda s, k, v, q: fwd_scan_single(s, (k, v, q)),
                si, ki, vi, qi
            )
            _, dk, dv, dq = scan_vjp((accum_native, doi))
            return dk, dv, dq

        if block_size is None:
            # No blocking: process all tokens sequentially with surrogate
            def pass2_token_scan(carry, x):
                """Compute combined gradients via single VJP."""
                state, accum = carry
                ki, vi, qi, doi = x

                # This token's own dstate contribution (the same term accumulated in pass 1)
                _, state_vjp_fn = jax.vjp(
                    lambda s: fwd_fn(update_fn_ref(s, ki, vi), qi), state
                )
                dstate_this, = state_vjp_fn(doi)

                # Remove it BEFORE use: the direct path is already covered exactly by doi,
                # so the new_state cotangent must carry future tokens only (as in the blocked path)
                accum_future = jax.tree.map(
                    lambda acc, ds: acc - ds.astype(jnp.float32), accum, dstate_this
                )
                dk, dv, dq = compute_kvq_grads_surrogate(state, ki, vi, qi, doi, accum_future)

                new_state = update_fn_ref(state, ki, vi)
                return (new_state, accum_future), (dk, dv, dq)

            (_, _), (dk, dv, dq) = jax.lax.scan(
                pass2_token_scan, (init_state, total_dstate), (k, v, q, do)
            )
        else:
            # With blocking: EXACT within-block gradients + SURROGATE inter-block
            def pass2_block_exact(carry, x):
                """Process block with exact within-block gradients.

                Key insight: Use standard VJP through the entire block for exact
                within-block gradients. Only use surrogate for d_state propagation
                between blocks.

                IMPORTANT: The inter-block accumulator must EXCLUDE this block's
                contribution before we use it. Otherwise we double-count within-block
                gradients (once via exact VJP, once via surrogate accumulator).
                """
                state, accum_inter_block = carry
                k_block, v_block, q_block, do_block = x

                # First, compute this block's dstate contribution (for accumulator update)
                def compute_dstate(si, ki, vi, qi, doi):
                    _, state_vjp = jax.vjp(
                        lambda s: fwd_fn(update_fn_ref(s, ki, vi), qi), si
                    )
                    dstate_i, = state_vjp(doi)
                    return dstate_i

                # Collect the pre-update state seen by each token in the block
                def collect_states(s, kv):
                    ki, vi = kv
                    new_s = update_fn_ref(s, ki, vi)
                    return new_s, s

                final_state, states = jax.lax.scan(collect_states, state, (k_block, v_block))
                dstates = jax.vmap(compute_dstate)(states, k_block, v_block, q_block, do_block)

                # Subtract this block's contribution BEFORE using the accumulator
                total_dstate_block = jax.tree.map(
                    lambda ds: ds.astype(jnp.float32).sum(axis=0), dstates
                )
                accum_future_only = jax.tree.map(
                    lambda acc, ds: acc - ds, accum_inter_block, total_dstate_block
                )

                # VJP of entire block forward pass (gives EXACT within-block gradients)
                def block_forward(s, k, v, q):
                    return jax.lax.scan(fwd_scan_single, s, (k, v, q))

                _, block_vjp = jax.vjp(
                    block_forward, state, k_block, v_block, q_block
                )

                # Cotangent: (d_final_state from FUTURE blocks only, d_outputs)
                accum_native = jax.tree.map(lambda x: x.astype(native_dtype), accum_future_only)
                _, dk_block, dv_block, dq_block = block_vjp((accum_native, do_block))

                return (final_state, accum_future_only), (dk_block, dv_block, dq_block)

            (_, _), (dk, dv, dq) = jax.lax.scan(
                pass2_block_exact, (init_state, total_dstate), reshape((k, v, q, do))
            )
            dk, dv, dq = unshape((dk, dv, dq))

        # Translate d_final_state to d_init_state via batched update VJP
        @jax.jit
        def batch_update(state):
            states = jax.vmap(update_fn_ref, in_axes=(None, 0, 0), out_axes=0)(state, k, v)
            return jax.tree.map(lambda v: v.sum(axis=0), states)

        _, update_vjp_fn = jax.vjp(batch_update, init_state)
        dstate_from_final, = update_vjp_fn(d_final_state)

        # Gradient from queries via batched fwd VJP (every query reads init_state directly)
        _, fwd_vjp_fn = jax.vjp(lambda s: jax.vmap(fwd_fn, in_axes=(None, 0))(s, q), init_state)
        dstate_from_queries, = fwd_vjp_fn(do)

        # Combine both contributions
        dstate = jax.tree.map(
            lambda a, b: a + b, dstate_from_queries, dstate_from_final
        )
        # dstate = d_final_state
        return dk, dv, dq, dstate

    _ttt = jax.custom_vjp(_ttt)
    _ttt.defvjp(_ttt_fwd, _ttt_bwd)
    return jax.vmap(_ttt, in_axes=(0, 0, 0, None))
```
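The post leans on two checks when comparing surrogate and exact gradients: cosine similarity (the convergence-criteria argument) and, in the results, gradient magnitude. It doesn't show how those numbers were computed, so here's a minimal sketch of both over arbitrary gradient pytrees. The helper names `grad_agreement` and `per_token_cosine` are hypothetical, not from the repo.

```python
import jax
import jax.numpy as jnp


def grad_agreement(g_true, g_approx):
    """Cosine similarity and norm ratio between two gradient pytrees (flattened, fp32)."""
    a = jnp.concatenate([x.ravel() for x in jax.tree.leaves(g_true)]).astype(jnp.float32)
    b = jnp.concatenate([x.ravel() for x in jax.tree.leaves(g_approx)]).astype(jnp.float32)
    na, nb = jnp.linalg.norm(a), jnp.linalg.norm(b)
    cos = jnp.dot(a, b) / (na * nb + 1e-12)  # > 0: stepping against g_approx still descends (to first order)
    ratio = nb / (na + 1e-12)                # 1.0: same overall magnitude as the true gradient
    return cos, ratio


def per_token_cosine(g_true, g_approx):
    """Cosine similarity per row for (T, d) gradients such as dk or dv."""
    num = jnp.sum(g_true * g_approx, axis=-1)
    den = jnp.linalg.norm(g_true, axis=-1) * jnp.linalg.norm(g_approx, axis=-1)
    return num / (den + 1e-12)
```

One design note: the classical descent-direction result (Zoutendijk's theorem, used together with Wolfe step-size conditions) needs the cosine to stay bounded away from zero rather than merely positive, so the per-token minimum is arguably worth tracking alongside the mean.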