#### Interesting
- https://huggingface.co/blog/afmck/jax-tutorial
- https://kidger.site/thoughts/torch2jax/
---
#### Pure Functional Paradigm
>[!info] Pure Functions in JAX
>If a function satisfies the following conditions, it is said to be pure:
>- **All inputs come in through its parameters** (no reads from global or external state).
>- **All outputs are returned from the function** (no side effects such as printing or mutating globals).
>- Given the same inputs, the **results are always the same**. This is why JAX uses [[JAX Pseudorandom Numbers|special pseudorandom numbers]] with explicitly passed keys.
---
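
A minimal sketch of why purity matters under `jax.jit`: JAX traces the function once per input shape/dtype and caches the compiled result, so a global read is frozen at trace time and a `print` only fires during tracing. The names `scale`, `impure_scale` and `pure_scale` are hypothetical, chosen for illustration.

```python
import jax
import jax.numpy as jnp

scale = 2.0  # global state: NOT passed in as a parameter

@jax.jit
def impure_scale(x):
    print("tracing impure_scale")  # side effect: runs only while JAX traces
    return x * scale  # the global's value is baked in as a constant at trace time

@jax.jit
def pure_scale(x, s):
    return x * s  # every input arrives through the parameters

x = jnp.arange(3.0)
first = impure_scale(x)      # traces (prints once); values [0, 2, 4]
scale = 10.0                 # mutate the global
second = impure_scale(x)     # cached trace: no print, still [0, 2, 4]
third = pure_scale(x, 10.0)  # [0, 10, 20], as expected
```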
In JAX, many [[- Python Programming Language -|Python]] control structures have compiled, functional equivalents that work with [[jax.jit - Just-in-Time Compilation with JAX|JIT]], autodiff, and vectorization. Plain Python control flow runs at trace time, so it cannot branch or loop on the values of traced arrays.
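
A small sketch of the failure mode these equivalents avoid: a Python `if` on a traced value raises under `jax.jit` (it works fine without `jit`), while `jax.lax.cond` compiles. For elementwise choices on arrays, `jnp.where` is often simpler. The `relu_*` names are hypothetical, for illustration only.

```python
import jax
import jax.numpy as jnp
from jax import lax

def relu_python(x):
    # Python `if` needs a concrete bool; under jit, `x > 0` is an abstract tracer
    if x > 0:
        return x
    return jnp.zeros_like(x)

def relu_lax(x):
    # lax.cond(pred, true_fun, false_fun, *operands): both branches are traced
    # and must return the same shapes and dtypes
    return lax.cond(x > 0, lambda v: v, jnp.zeros_like, x)

python_if_failed = False
try:
    jax.jit(relu_python)(jnp.array(3.0))
except jax.errors.ConcretizationTypeError:
    # recent JAX raises TracerBoolConversionError, a subclass of this error
    python_if_failed = True

relu_jit = jax.jit(relu_lax)
pos = relu_jit(jnp.array(3.0))   # 3.0
neg = relu_jit(jnp.array(-2.0))  # 0.0
```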
**Control Flow Primitives**
| Concept | JAX Function | Python/Numpy Equivalent | Description |
| --------------------- | -------------------- | ----------------------------------- | ----------------------------------------------- |
| Conditional (`if`)    | `jax.lax.cond`       | `if cond:`                          | Traces both branches, runs one (both under `vmap`) |
| Multi-way branch | `jax.lax.switch` | `if-elif` / `switch-case` | Branches by index into a list of funcs |
| Fixed-range loop      | `jax.lax.fori_loop`  | `for i in range(start, stop):`      | Compiled loop over an index range (not unrolled by default) |
| Stateful loop (scan) | `jax.lax.scan` | `for x in xs:` with state | Recurrent-style loop with carry + stack |
| While loop            | `jax.lax.while_loop` | `while cond:`                       | Compiled while loop with carry state; not reverse-mode differentiable |
| Vectorized map        | `jax.vmap`           | `np.vectorize`, list comprehensions | Maps a function over an array axis (axis 0 by default, set with `in_axes`) |
| JIT-compiled function | `jax.jit` | - | Compiles entire function ([[XLA\|XLA]] backend) |
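
A compact sketch of the primitives in the table above, one tiny function each. The function names (`apply_branch`, `sum_of_squares`, `running_sum`, `count_halvings`, `batched_dot`) are hypothetical; note that loop carries must keep the same dtype across iterations, hence the explicit `jnp.int32(0)` / `jnp.zeros(...)` initial values.

```python
import jax
import jax.numpy as jnp
from jax import lax

# lax.switch(index, branches, *operands): index is clamped to [0, len(branches) - 1]
branches = [lambda v: v + 1.0, lambda v: v * 2.0, lambda v: -v]

def apply_branch(i, v):
    return lax.switch(i, branches, v)

# lax.fori_loop(lower, upper, body_fun, init_val): body_fun(i, carry) -> carry
def sum_of_squares(n):
    return lax.fori_loop(0, n, lambda i, acc: acc + i * i, jnp.int32(0))

# lax.scan(f, init, xs): f(carry, x) -> (new_carry, y); the ys are stacked on axis 0
def running_sum(xs):
    def step(carry, x):
        new = carry + x
        return new, new
    return lax.scan(step, jnp.zeros((), dtype=xs.dtype), xs)

# lax.while_loop(cond_fun, body_fun, init_val): repeat body while cond_fun(carry)
def count_halvings(x):
    def cond(state):
        value, _ = state
        return value > 1.0
    def body(state):
        value, steps = state
        return value / 2.0, steps + 1
    _, steps = lax.while_loop(cond, body, (x, jnp.int32(0)))
    return steps

# jax.vmap: write the per-example function, map it over axis 0 of both inputs
def dot(a, b):
    return jnp.dot(a, b)

batched_dot = jax.vmap(dot)  # (B, D), (B, D) -> (B,)

# jax.jit: compile a whole function (here with a traced loop bound)
sum_of_squares_jit = jax.jit(sum_of_squares)
```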
**Reduction & Collection Primitives**
|Concept|JAX Function|Python/Numpy Equivalent|Description|
|---|---|---|---|
|Sum|`jax.numpy.sum`|`np.sum`|Sum over array dimensions|
|Product|`jax.numpy.prod`|`np.prod`|Product over axes|
|Mean|`jax.numpy.mean`|`np.mean`|Average across axes|
|Max/Min|`jax.numpy.max/min`|`np.max`, `np.min`|Maximum/minimum over axes (elementwise is `jnp.maximum/minimum`)|
|General reduction|`jax.lax.reduce`|`np.ufunc.reduce` (e.g. `np.add.reduce`), `functools.reduce`|Reduction with a custom binary op and init value|
|Cumulative sum|`jax.numpy.cumsum`|`np.cumsum`|Cumulative sum|
|Map (with state)|`jax.lax.scan`|`for loop` with accumulation|RNN-style loop, builds stacked output|
|Parallel sum/mean|`jax.lax.psum/pmean`|N/A|Sum/mean across a named mapped axis (`vmap`/`pmap`)|
|Axis naming|`axis_name` in `vmap/pmap`|N/A|Tags for collective operations|
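
A short sketch of the reductions above on a 2×3 array, including the difference between `jnp.max` (reduces over an axis) and `jnp.maximum` (elementwise), a `lax.reduce` call, and `psum`/`pmean` over an `axis_name` given to `vmap`. The variable and function names and the axis name `"batch"` are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0]])

col_sums = jnp.sum(x, axis=0)           # [5, 7, 9]
row_prods = jnp.prod(x, axis=1)         # [6, 120]
overall_mean = jnp.mean(x)              # 3.5
row_max = jnp.max(x, axis=1)            # reduction over an axis: [3, 6]
pairwise_max = jnp.maximum(x[0], x[1])  # elementwise instead: [4, 5, 6]
running = jnp.cumsum(x[0])              # [1, 3, 6]

# lax.reduce(operands, init_values, computation, dimensions)
neg_inf = jnp.array(-jnp.inf, dtype=x.dtype)
row_max_lax = lax.reduce(x, neg_inf, lax.max, (1,))  # same values as row_max

# Collectives over a named axis: inside vmap, each element sees the whole batch
def share_of_total(v):
    return v / lax.psum(v, axis_name="batch")

def center(v):
    return v - lax.pmean(v, axis_name="batch")

batch = jnp.array([1.0, 3.0])
shares = jax.vmap(share_of_total, axis_name="batch")(batch)  # [0.25, 0.75]
centered = jax.vmap(center, axis_name="batch")(batch)        # [-1, 1]
```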
---
```dataview
TABLE
file.mday AS "Last Modified",
length(file.inlinks) AS "Links In",
length(file.outlinks) AS "Links Out"
FROM #JAX
SORT length(file.inlinks) desc
```