Abstraction used in JAX to handle nested container structures (e.g. lists, tuples, and dicts whose leaves are arrays).
...
We can also operate on whole pytrees, e.g. through:
- ...
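
As one illustration of operating on a whole pytree (not necessarily one of the items listed above), here is a minimal sketch using `jax.tree_util.tree_map`, which applies a function to every leaf and returns a pytree with the same structure. The `params` dictionary and its contents are made up for this example.

```python
import jax
import jax.numpy as jnp

# Hypothetical pytree: a dict holding an array and a list of leaves
params = {"w": jnp.ones((2,)), "b": [jnp.zeros((1,)), 3.0]}

# Apply the same function to every leaf; the nesting is preserved
doubled = jax.tree_util.tree_map(lambda x: 2 * x, params)

# Flatten the result into a plain list of its leaves
leaves = jax.tree_util.tree_leaves(doubled)
```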
---
#### Partial Function Evaluation using `jax.tree_util.Partial`
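`jax.tree_util.Partial` replicates the partial function evaluation of `functools.partial` in a way that is compatible with [[- JAX -|JAX transformations]]: a `Partial` object is itself a pytree, so it can be passed as an argument to transformed functions.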
```python
import jax.numpy as jnp
from jax.tree_util import Partial

# Bind the first argument of jnp.add to 1
add_one = Partial(jnp.add, 1)
add_one(2)
```
```bash
Array(3, dtype=int32, weak_type=True)
```
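
To see what compatibility with transformations means in practice, the sketch below passes a `Partial` object as an argument to a `jax.jit`-compiled function. The helper name `call_func` is chosen only for this example. A plain `functools.partial` object is not a valid JAX argument type, so it would be expected to fail in the same position.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial

@jax.jit
def call_func(f, x):
    # f arrives as a pytree: its bound arguments are leaves,
    # and the wrapped function is part of the static structure
    return f(x)

add_one = Partial(jnp.add, 1)

# Works under jit because Partial is registered as a pytree
result = call_func(add_one, jnp.arange(3))
```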
---