#### Interesting
- https://huggingface.co/blog/afmck/jax-tutorial
- https://kidger.site/thoughts/torch2jax/

---
#### Pure Functional Paradigm
>[!info] Pure Functions in JAX
>A function is said to be pure if it meets all of the following conditions:
>- **All inputs come in through the parameters** (no reads of global or external state).
>- **All outputs are returned from the function** (no side effects such as mutating globals).
>- Given the same inputs, **it always returns the same results**. This is why JAX uses [[JAX Pseudorandom Numbers|special pseudorandom numbers]].

---
In JAX, many [[- Python Programming Language -|Python]] control structures have compiled, functional equivalents so that they work with [[jax.jit - Just-in-Time Compilation with JAX|JIT]], autodiff and vectorization.

**Control Flow Primitives**

| Concept | JAX Function | Python/NumPy Equivalent | Description |
| --- | --- | --- | --- |
| Conditional (`if`) | `jax.lax.cond` | `if cond:` / `else:` | Traceable two-way branch; both branches must return the same shapes and dtypes |
| Multi-way branch | `jax.lax.switch` | `if-elif` / `match-case` | Branches by integer index into a list of functions (the index is clamped to range) |
| Fixed-range loop | `jax.lax.fori_loop` | `for i in range(start, stop):` | Compiled counted loop with a carry; not unrolled by default; reverse-differentiable when the bounds are static |
| Stateful loop (scan) | `jax.lax.scan` | `for x in xs:` with state | Loop with a carry that also stacks per-step outputs (RNN-style) |
| While loop | `jax.lax.while_loop` | `while cond:` | Compiled while loop with carry state; not reverse-mode differentiable |
| Vectorized map | `jax.vmap` | `np.vectorize`, list comprehensions | Maps a function over an array axis (leading axis by default, set via `in_axes`) |
| JIT-compiled function | `jax.jit` | - | Traces and compiles the entire function ([[XLA\|XLA]] backend) |

**Reduction & Collection Primitives**

|Concept|JAX Function|Python/NumPy Equivalent|Description|
|---|---|---|---|
|Sum|`jax.numpy.sum`|`np.sum`|Sum over array axes|
|Product|`jax.numpy.prod`|`np.prod`|Product over axes|
|Mean|`jax.numpy.mean`|`np.mean`|Average over axes|
|Max/Min|`jax.numpy.max`, `jax.numpy.min`|`np.max`, `np.min`|Maximum/minimum over axes (elementwise versions: `jnp.maximum`, `jnp.minimum`)|
|General reduction|`jax.lax.reduce`|`functools.reduce`, ufunc `.reduce` (e.g. `np.add.reduce`)|Reduction over axes with a custom binary op and an initial value|
|Cumulative sum|`jax.numpy.cumsum`|`np.cumsum`|Running sum along an axis|
|Map (with state)|`jax.lax.scan`|`for` loop with accumulation|RNN-style loop, builds stacked output|
|Parallel sum/mean|`jax.lax.psum`, `jax.lax.pmean`|N/A (collective op)|Sum/mean across a named mapped axis (`vmap`/`pmap`); works under `vmap` on a single device|
|Axis naming|`axis_name` in `vmap`/`pmap`|N/A|Names the mapped axis so collective operations can refer to it|

---
```dataview
TABLE file.mday AS "Last Modified", length(file.inlinks) AS "Links In", length(file.outlinks) AS "Links Out"
FROM #JAX
SORT length(file.inlinks) DESC
```
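A minimal sketch of why the purity rules matter under `jax.jit`, assuming a recent JAX version. `impure_add`, `pure_add` and `counter` are hypothetical names for illustration. The global update in the impure function only runs while JAX traces it, so the compiled version keeps reusing the value captured during the first trace; the same-key/split pattern shows how explicit PRNG keys keep random sampling pure.

```python
import jax
import jax.numpy as jnp

counter = 0  # hypothetical global state

def impure_add(x):
    # Impure: reads and mutates a global instead of taking it as a parameter.
    global counter
    counter += 1
    return x + counter

def pure_add(x, offset):
    # Pure: every input is a parameter and the result is returned.
    return x + offset

jit_impure = jax.jit(impure_add)
a = jit_impure(jnp.array(1.0))  # traced here: counter is bumped once and baked in as a constant
b = jit_impure(jnp.array(1.0))  # cache hit: the Python body does not run again, so a == b

jit_pure = jax.jit(pure_add)
c = jit_pure(jnp.array(1.0), jnp.array(2.0))  # safe to compile: depends only on its arguments

# Explicit PRNG keys: the same key always gives the same numbers,
# so fresh randomness requires splitting the key.
key = jax.random.PRNGKey(0)
x1 = jax.random.normal(key, (3,))
x2 = jax.random.normal(key, (3,))
key, subkey = jax.random.split(key)
x3 = jax.random.normal(subkey, (3,))
```

Calling plain `impure_add` repeatedly would return a different value each time, which is exactly the behavior `jit` cannot preserve.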
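The control-flow primitives from the first table can be exercised as below. This is a sketch assuming a recent JAX; the helper names `abs_via_cond`, `apply_op`, `sum_range`, `cumsum_scan` and `next_pow2` are hypothetical. Each primitive takes functions plus an explicit operand or carry instead of mutating Python variables, which is what lets it be traced once and compiled.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def abs_via_cond(x):
    # lax.cond(pred, true_fun, false_fun, operand): both branches return the same shape/dtype.
    return lax.cond(x > 0, lambda v: v, lambda v: -v, x)

def apply_op(index, x):
    # lax.switch picks branches[index]; all branches must agree on output type.
    branches = [lambda v: v + 1.0, lambda v: v * 2.0, lambda v: v ** 2]
    return lax.switch(index, branches, x)

def sum_range(n):
    # lax.fori_loop(lower, upper, body_fun(i, carry), init_carry): sums 0..n-1.
    return lax.fori_loop(0, n, lambda i, acc: acc + i, jnp.int32(0))

def cumsum_scan(xs):
    # lax.scan: step(carry, x) returns (next_carry, per-step output to stack).
    def step(carry, x):
        new = carry + x
        return new, new
    final, stacked = lax.scan(step, jnp.zeros_like(xs[0]), xs)
    return final, stacked

def next_pow2(n):
    # lax.while_loop(cond_fun, body_fun, init_carry): doubles p until p >= n.
    return lax.while_loop(lambda p: p < n, lambda p: p * 2, jnp.int32(1))
```

For example, `cumsum_scan(jnp.array([1.0, 2.0, 3.0]))` returns the final carry `6.0` together with the stacked running sums `[1.0, 3.0, 6.0]`.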
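A sketch of the vectorization and reduction rows, assuming a recent JAX; `dot`, `normalize` and the axis name `"batch"` are illustrative choices, not required names. It shows `jax.vmap` batching a single-vector function, `jax.lax.psum` reducing across a named `vmap` axis on one device, and `jax.lax.reduce` with a binary op that here reproduces `jnp.max` along an axis.

```python
import jax
import jax.numpy as jnp
from jax import lax

def dot(a, b):
    # Written for single vectors; vmap adds the batch dimension.
    return jnp.sum(a * b)

# Map over axis 0 of both arguments, then compile the batched function.
batched_dot = jax.jit(jax.vmap(dot, in_axes=(0, 0)))
xs = jnp.arange(6.0).reshape(2, 3)
ys = jnp.ones((2, 3))
dots = batched_dot(xs, ys)  # one dot product per row

def normalize(x):
    # psum sums x over every element mapped along the axis named "batch".
    return x / lax.psum(x, axis_name="batch")

shares = jax.vmap(normalize, axis_name="batch")(jnp.array([1.0, 3.0]))

# lax.reduce(operands, init_values, computation, dimensions)
m = jnp.array([[1.0, 5.0], [7.0, 2.0]])
row_max = lax.reduce(m, jnp.array(-jnp.inf, dtype=m.dtype), lax.max, (1,))
running = jnp.cumsum(jnp.array([1.0, 2.0, 3.0]))
```

The same `axis_name` mechanism is what `pmap` uses when the mapped axis is spread across devices.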