This is not just another paper highlight, though it does focus on the paper. The goal is to extract the conceptual theory from the paper and theorize about how it can be extended. I saw this paper posted on Twitter a while back and wanted to get to it, and now I have. I've partially replicated it (training is for later), and I now understand the class of algorithms it belongs to. What makes this paper significant to me is that it unlocks a new abstraction, a way of doing things I hadn't thought about before.

# A new abstraction

To explain it very simply: you can take an unparallelizable function, take an nth-order approximation of it at each token given some initial state estimate, and accumulate those approximations across the sequence length. Essentially, this ignores how changing the state in one subset would have changed the outputs of the subsets in between, and instead directly imparts the change in state onto every future state. However, thanks to some fancy math nonsense, if you repeat this process multiple times it will eventually fully model the interdependence across the sequence.

![[long context.excalidraw.png]]

Importantly, directly imparting the change skips how the state would evolve if it actually did pass through the other parts of the sequence. This means we can trade the number of sequential hops a model can make across the sequence against parallelism/efficiency. That's great, because in practice long-context behavior is not fully connected; it is usually sparse. The relationships are spread far apart in the sequence length but have finite depth.

### TL;DR

We can extract two important consequences from this paper:

- We can compose this method with regular backprop or other surrogate-gradient methods to make more precise tradeoffs
- We can smoothly control the parallelism-vs-sequential-computation tradeoff of most recurrent architectures, rather than having to pick one or the other (transformer, DeltaNet, TransNormer vs. LSTM, etc.)
  without sacrificing MFU

# The paper

## The Math

![[para_mam.png]]

The paper uses Newton iteration, a second-order method that spends $O(n^2)$ space per hidden state to produce a better approximation of the function, allowing it to converge in fewer iterations. If your transition function doesn't mix within the hidden state (i.e. it has a diagonal transition matrix), then your Jacobian only needs $O(n)$ space.

Importantly, this function has a specific form that lets you add two functions together by simply summing their parameters:

$$ f(\theta_a, x) + f(\theta_b, x) = f(\theta_a + \theta_b, x) $$

In other words, the function is additive in its parameters. This means you can add together two approximations of how the state will change across two subsets of the sequence, and get a new approximation of how the state will change across both subsets combined.

## Jacobi Iteration

![[para_hidden.png]]

This concept can be extended or simplified: you can take higher- or lower-order approximations of the function to propagate this information across the sequence length. This lets you trade memory consumption for convergence rate, which matters for larger hidden states.

![[para_iter.png]]

Using Jacobi iteration (first order), you simply provide the state delta: given the current state estimate, here's how it ends up at the end of the subset. You then accumulate that information across the sequence length in parallel, which gives some approximation of how the state will end up at each timestep. With that approximation as the new starting point, you repeat the procedure, iteratively refining the prediction. The longer the sequence and the more complex the interdependence between previous and next states, the more iterations it takes to converge.
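To make the fixed-point picture concrete, here is a minimal NumPy sketch of plain Jacobi-style iteration over an RNN recurrence. The `jacobi_rnn` helper and the toy `cell` are my own illustration, not the paper's implementation: start every state from a crude guess, recompute all states in parallel from the previous sweep's estimates, and repeat.

```python
import numpy as np

def jacobi_rnn(f, h0, xs, num_iters):
    """Approximate the sequential recurrence h_t = f(h_{t-1}, x_t) by
    fixed-point iteration: each sweep recomputes all T states at once
    from the previous sweep's estimates, so every sweep is fully
    parallel over t, and information moves at least one timestep per
    sweep."""
    T = len(xs)
    hs = np.tile(h0, (T, 1))                      # crude guess: h_t = h0
    for _ in range(num_iters):
        prev = np.vstack([h0[None, :], hs[:-1]])  # last sweep's h_{t-1}
        hs = f(prev, xs)                          # all timesteps at once
    return hs

# Toy cell with a diagonal (non-mixing) transition, the case where the
# Jacobian only needs O(n) space.
def cell(h_prev, x):
    return np.tanh(0.5 * h_prev + x)

rng = np.random.default_rng(0)
T, d = 16, 4
xs = rng.normal(size=(T, d))
h0 = np.zeros(d)

# Ground truth from the ordinary sequential scan.
h, seq = h0, []
for t in range(T):
    h = cell(h, xs[t])
    seq.append(h)
seq = np.stack(seq)

# T sweeps are always enough; a weakly coupled cell needs far fewer.
approx = jacobi_rnn(cell, h0, xs, num_iters=T)
assert np.allclose(approx, seq)
```

Since each sweep pushes exact information at least one step down the sequence, the iteration count directly bounds the sequential depth being modeled; the paper's Newton variant spends $O(n^2)$ memory per state to converge in far fewer sweeps.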
Something like TTT would probably converge fairly quickly: even though the TTT update depends on the previous state, the pieces of knowledge stored in the MLP are relatively independent of each other. It would only take a long time to converge if the state flip-flopped, with significant writes and overwrites.

# Beyond ParaRNN

Three things make this particularly important:

1. This is a very easily generalizable abstraction in JAX, and the link to the implementation is in the appendix.
   - The backward pass is also still a non-parallelizable function that can be approximated; it just runs across the sequence in reverse. This makes it very clean to implement.
2. This is a *composable* abstraction.
3. This allows us to smoothly trade off between parallelism and sequential computation.

## Composability

That second point is extremely important and extremely valuable: it means you can mix regular backprop or other surrogate-gradient methods with ParaRNN. This is a really valuable property for an abstraction, because it lets you combine it with other abstractions/methods to push further along the Pareto frontier than you normally could.

Remember, at each timestep you get an estimate of the initial/final error signal for each subset. You can simply use that as the incoming error signal for the next subset when doing regular backprop. This means you can choose a subset length that's just enough to saturate the parallelism of your GPU, find the number of iterations it takes to refine it, and then do regular sequential backprop across blocks of your sequence for improved accuracy. You can even combine it with other surrogate-gradient methods if you want.

You can also do it the other way around: do backprop in parallel across independent subsets, then run ParaRNN across the blocks. Since there are fewer blocks per sequence, it takes fewer steps to converge.
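The forward-pass analogue of that blockwise idea can be sketched in a few lines of NumPy (hypothetical helper names, not the paper's code): run an exact scan inside each block, which is independent across blocks once each block's entry state is estimated, and iterate only over the block-entry states. With B blocks there are only B unknowns, so plain fixed-point iteration needs at most B sweeps instead of one per token.

```python
import numpy as np

def block_jacobi_rnn(f, h0, xs, block_size, num_iters):
    """Hybrid scheme: exact sequential scans inside each block (these
    are independent across blocks, hence parallelizable) and Jacobi-
    style refinement only over the B = T/block_size block-entry
    states."""
    T = len(xs)
    B = T // block_size
    entries = np.tile(h0, (B, 1))        # guessed entry state per block
    for _ in range(num_iters):
        exits = []
        for b in range(B):               # independent given `entries`
            h = entries[b]
            for t in range(b * block_size, (b + 1) * block_size):
                h = f(h, xs[t])
            exits.append(h)
        # next block's entry is the previous block's exit
        entries = np.vstack([h0[None, :], np.stack(exits)[:-1]])
    # materialize all T states from the refined entry states
    out = []
    for b in range(B):
        h = entries[b]
        for t in range(b * block_size, (b + 1) * block_size):
            h = f(h, xs[t])
            out.append(h)
    return np.stack(out)

def cell(h_prev, x):
    return np.tanh(0.5 * h_prev + x)

rng = np.random.default_rng(1)
T, d = 16, 4
xs = rng.normal(size=(T, d))
h0 = np.zeros(d)

h, seq = h0, []
for t in range(T):
    h = cell(h, xs[t])
    seq.append(h)
seq = np.stack(seq)

# 4 blocks => at most 4 sweeps to be exact, vs 16 for per-token Jacobi
assert np.allclose(block_jacobi_rnn(cell, h0, xs, 4, 4), seq)
```

This is only a toy: the inner loops would be a vectorized scan in JAX, and the refinement over entries could itself be a Newton step, but it shows why fewer blocks means fewer iterations to converge.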
If someone got [[BPTT got hands|surrogate gradients on TTT]] to work, you could saturate the parallelism of the GPU (without enormous hidden states) while using less memory, unlocking longer sequence lengths.

## The Pareto frontier

Fundamentally, the idea of this paper is that we can compute a non-parallelizable expression in parallel, and then, with some associatively composable error signal, propagate interdependent information across the sequence. Since this interdependent propagation is also parallel, it can't fully capture true sequential computation either, but if we assume that *most* of the computation runs in parallel, the approximation holds well. This means we can trade parallelizability against true sequential computation for basically any architecture.

![[tradeoffs.excalidraw.png]]

*Previously we could only really choose between parallelizable and truly sequential/deep.*

By controlling the number of iterations, we can smoothly control the number of sequential compositions across the sequence length that we want to model. This matters a lot because the architectures we currently have don't strike a good balance between parallelism (transformer, TransNormer, DeltaNet) and depth (LSTM); now we can control it arbitrarily.

# Conclusion

This paper is really cool. The takeaway, imo, is not the particular implementation but the conceptual line of reasoning it unlocks. Newton iteration, Jacobi iteration, iterative neural approximations, etc. are all very powerful methods that can be mixed and matched with other ideas to push further along the Pareto frontier. Consider this another tool in your toolbox when designing new architectures and the like.
## Appendix

The Jacobi utility function can be found [here](https://github.com/Ueaj-Kerman/macrogpt-jax/blob/pararnn/ueaj/model/parallel_scan.py), and the Newton one can be found [here](https://github.com/Ueaj-Kerman/macrogpt-jax/blob/pararnn/ueaj/model/parallel_scan_newton.py).

```python
from ueaj.model.parallel_scan import parallel_scan

final_h, outputs = parallel_scan(
    rnn_cell,
    h0,
    inputs,
    num_iterations=15,  # typical
)
```

What's cool is that it works for any arbitrary `rnn_cell` function, b/c JAX is cool and awesome.

## Practical Guidelines

**Use Jacobi when:**

- Hidden dim < 2048
- Training on a single GPU
- Iteration cost << model forward-pass cost

**Use Newton when:**

- Hidden dim > 1024
- Multi-GPU setup with abundant memory
- Very long sequences, where needing 3 iterations instead of 15 matters

**Use Sequential when:**

- Sequence length < 512 (the overhead isn't worth it)
- Debugging (easier to reason about)