Abstraction used in JAX to handle various kinds of structures.

...

We can also operate on whole pytrees, e.g. through

- ...

---

#### Partial Function Evaluation using `jax.tree_util.Partial`

The functionality of partial function evaluation via `functools.partial` can be replicated in a way that is compatible with [[- JAX -|JAX transformations]].

```python
import jax.numpy as jnp
from jax.tree_util import Partial

# Bind the first argument of jnp.add to 1
add_one = Partial(jnp.add, 1)
add_one(2)
```

```bash
Array(3, dtype=int32, weak_type=True)
```

---
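
A minimal sketch of what this compatibility can look like in practice. The helper name `apply_fn` is hypothetical and not part of the note. A `Partial` object is itself a pytree: its bound arguments are leaves, and the wrapped function is kept as static metadata. This lets it be passed as an ordinary argument to a `jax.jit`-compiled function, whereas a `functools.partial` object would typically be rejected as an invalid JAX type.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial, tree_leaves


@jax.jit
def apply_fn(f, x):
    # `f` arrives as a pytree; its bound arguments are traced like any other input
    return f(x)


add_one = Partial(jnp.add, 1)

# The bound argument is the only leaf of the Partial pytree
leaves = tree_leaves(add_one)

# Passing the Partial through a jitted function works as for any pytree argument
result = apply_fn(add_one, 2)
```

Here `leaves` holds the single bound argument `1`, and `result` holds the same value as calling `add_one(2)` directly.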