I love writing good code and I like spending exorbitant amounts of time on the little details. So here's an article on the objectively perfect configuration system for Python.
# Introduction
We want a system that balances specificity with readability and re-use. Maximal specificity means manually defining every sub-module configuration in one large file and making local variables for any shared or tied values (e.g. the model dimension of the attn and mlp blocks). However, this is very bad for re-use: if we want to make a new version, we need to copy the whole file.
**If I'm doing research and my code needs constant refactoring to accommodate the weird shit I do, then I need a better configuration system.**
# Theory
## Prior Work
We could use a flat configuration file where the tied relationships between the sub-configs are encoded in utility methods, but this has poor specificity. If I want to untie something, I can't do it without rewriting the configuration method. For example, I might want the down projections in bf16 but the up projections in fp8.
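To make that limitation concrete, here is a minimal sketch of a flat config. The `FlatConfig` dataclass and its field and method names are hypothetical, made up for illustration. The tied values live in helper methods, so a single `dtype` field drives both projections:

```python
from dataclasses import dataclass


@dataclass
class FlatConfig:
    # Hypothetical flat config: every knob is a top-level field.
    model_d: int = 1024
    dtype: str = "bf16"

    def up_proj(self) -> dict:
        # Tied to the shared fields by construction.
        return {"in_d": self.model_d, "out_d": 4 * self.model_d, "dtype": self.dtype}

    def down_proj(self) -> dict:
        return {"in_d": 4 * self.model_d, "out_d": self.model_d, "dtype": self.dtype}


cfg = FlatConfig(dtype="fp8")
# Both projections change together; no field moves only one of them.
up, down = cfg.up_proj(), cfg.down_proj()
```

Getting bf16 down projections next to fp8 up projections here means adding a new field or rewriting `down_proj`.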
We could use a hierarchical configuration, but then we have loads of redundancy where we have to specify the `model_d` over and over. The more default methods we use the more bloated the code becomes (trust me, I tried).
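A small sketch of that redundancy, assuming hypothetical `AttnConfig`, `MlpConfig` and `BlockConfig` dataclasses (none of these names come from a real library):

```python
from dataclasses import dataclass


@dataclass
class AttnConfig:
    model_d: int
    num_heads: int = 8


@dataclass
class MlpConfig:
    model_d: int
    hidden_d: int = 4096


@dataclass
class BlockConfig:
    attn: AttnConfig
    mlp: MlpConfig


# model_d is spelled out once per sub-config...
block = BlockConfig(attn=AttnConfig(model_d=1024), mlp=MlpConfig(model_d=1024))

# ...and nothing stops the copies from drifting apart.
mismatched = BlockConfig(attn=AttnConfig(model_d=1024), mlp=MlpConfig(model_d=2048))
```

Every sub-config is fully specific, but the tie between the two `model_d` values exists only by convention.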
## Functional Hierarchical Configs
What we want is some way for the value-tying relationships to be easily modified. Value-tying relationships are really nice to encode in functions, because then I don't need to reinvent programmatic configs from first principles like everyone else.
We can easily encode hierarchical configurations in functions like this:
```python
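def llama(
    model_d,
    embed=Embed.create,
    norm=Norm.create,
):
    return LlamaModel(
        model_d,
        embed=embed(256, model_d),
        lm_head=norm(model_d),
        # ...
    )


# llama calls embed(vocab_size, model_d), so a replacement must accept both
# arguments even when it ignores the default vocab size.
model_v1 = llama(1024, embed=lambda _vocab, md: Embed.create(65536, md))
model_v2 = llama(
    1024,
    embed=lambda _vocab, md: Embed.create(
        65536,
        md,
        init=lambda md, vd, rng: normal(stddev=2, shape=(md, vd), rng=rng),
    ),
)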
```
However, that is somewhat cumbersome when we just want to replace the args to a function (e.g. to change the dtype, activation function, matmul precision, etc.). It's also annoying if we want to replace the arguments or implementation of nested functions. In particular, the `lambda ...` part is irritating.
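For comparison, this is roughly what the plain standard-library route looks like. The sketch assumes two hypothetical toy factories, `make_norm` and `make_block`, that return dicts instead of real modules. Changing one nested argument with `functools.partial` still means naming the default function again at the call site:

```python
from functools import partial


def make_norm(model_d, dtype="fp32", eps=1e-6):
    # Hypothetical stand-in for a real norm layer factory.
    return {"kind": "norm", "model_d": model_d, "dtype": dtype, "eps": eps}


def make_block(model_d, norm=make_norm):
    # The block ties model_d into whatever norm factory it is given.
    return {"kind": "block", "norm": norm(model_d)}


# To change only the dtype we must repeat which function is being configured.
bf16_block = partial(make_block, norm=partial(make_norm, dtype="bf16"))
block = bf16_block(1024)
```

If `make_block` later switches to a different default norm, this call site keeps passing `make_norm`, because the default had to be named explicitly.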
## Legible Functional Hierarchical Configuration
To alleviate this issue we can use some Python metaprogramming. We can wrap the function in a decorator that unpacks a special marker type. The marker says: keep the default function, but replace some of its arguments. It works much like `functools.partial`, but readable.
```python
import jax.numpy as jnp
from flax import nnx

from configurator import config, override


@config
def llama(
    model_d,
    vocab_size,
    num_layers=12,
    embed=embed,
    norm=norm,
    ...
):
    return LlamaModel(embed=embed(...), ...)


@config
def fp8_norm(...):
    ...


# example configuration
llama_fn = llama.override(
    # functools.partial equivalent
    vocab_size=65535,
    # override regular args
    num_layers=24,
    # override args without changing the function
    embed=override(dtype=jnp.bfloat16),
    # override args and function
    norm=fp8_norm.override(
        recentering="recenter",
        # nesting works too
        initializer=nnx.initializers.zeros,
    ),
)

# call the function
model = llama_fn(1024)
```
# Implementation
This configuration system works really well in practice. The research I do is usually diverse and creative, which calls for frequent rewrites and maximum code re-use, and these configs are a good fit for that. The full implementation is below.
```python
"""
Simple configuration system for hierarchical function and class composition.

The @config decorator adds an .override() method to functions and classes.
- For functions: Returns the function with an added .override() method
- For classes: Returns the class with an added .override() class method

Example:
    @config
    def create_embed(vocab_size, d, init="normal"):
        return Embed(vocab_size, d, init)

    @config
    def create_fp8_embed(vocab_size, d, init="normal"):
        return ...

    @config
    class Llama:
        def __init__(self, model_d, vocab_size=32000, embed=create_embed):
            self.embed = embed(vocab_size, model_d)

    # Create configured factory function
    llama_3 = Llama.override(
        vocab_size=128256,
        # Override args of default function
        embed=override(init="kaiming"),
    )

    llama_fp8 = Llama.override(
        vocab_size=128256,
        # Replace the default function entirely
        embed=create_fp8_embed.override(init="kaiming"),
    )

    model = llama_3(4096)  # llama_3 is a function, not a class

Key points:
- @config adds .override() to functions and classes mutably!
- func.override(**kwargs) returns a new function with baked-in overrides
- Class.override(**kwargs) returns a factory function (not a class)
- override(**kwargs) is a special marker to override args of a default function
- Annotated[T, HasOverride] documents the added method (static type checkers
  ignore Annotated metadata, so .override() is not checked statically)
- Chaining overrides copies & overwrites the overrides from the last

@Author: _ueaj
"""
import functools
import inspect
from functools import wraps, partial
from typing import Callable, TypeVar, Protocol, Any, ParamSpec, Type, overload, Annotated

P = ParamSpec('P')
T = TypeVar('T')
C = TypeVar('C')


# Simple protocol for the override method
class HasOverride(Protocol):
    """Protocol for types that have the override method."""

    @classmethod
    def override(cls, **overrides: Any) -> Callable[..., Any]: ...


@overload
def config(obj: Type[C]) -> Annotated[Type[C], HasOverride]:
    ...


@overload
def config(obj: Callable[P, T]) -> Annotated[Callable[P, T], HasOverride]:
    ...


def config(obj: C) -> Annotated[C, HasOverride]:
    """Add .override() method to a function or class.

    The returned object is annotated to indicate it has both:
    - Its original type/interface (C)
    - The HasOverride protocol (providing the .override() method)
    """
    if inspect.isclass(obj):
        return _config_class(obj)
    if callable(obj):
        return _config_function(obj)
    raise TypeError("@config can only be used on callables (functions or classes)")


def override(**kwargs):
    """Special marker for overriding just the arguments of a default."""
    return ('override', kwargs)


def _apply_and_call(target, sig, overrides, args, kwargs):
    """Helper to bind arguments, apply overrides, and call the target."""
    bound = sig.bind_partial(*args, **kwargs)
    bound.apply_defaults()
    varkw_name = next(
        (p.name for p in sig.parameters.values() if p.kind == inspect.Parameter.VAR_KEYWORD),
        None,
    )
    for name, value in overrides.items():
        if name in sig.parameters:
            # The length check keeps an empty tuple value from raising IndexError.
            if isinstance(value, tuple) and len(value) == 2 and value[0] == 'override':
                _, override_kwargs = value
                current = bound.arguments.get(name, sig.parameters[name].default)
                if hasattr(current, 'override'):
                    bound.arguments[name] = current.override(**override_kwargs)
                elif callable(current) and current is not inspect.Parameter.empty:
                    bound.arguments[name] = partial(current, **override_kwargs)
                else:
                    raise TypeError(
                        f"Argument '{name}' is not a configurable or callable, so its arguments cannot be overridden."
                    )
            else:
                bound.arguments[name] = value
        elif varkw_name:
            if varkw_name not in bound.arguments:
                bound.arguments[varkw_name] = {}
            bound.arguments[varkw_name][name] = value
    return target(*bound.args, **bound.kwargs)


def _validate_overrides(target_name, sig, overrides):
    """Checks if override keys are valid for the given signature."""
    accepts_kwargs = any(p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values())
    for name in overrides:
        if name not in sig.parameters and not accepts_kwargs:
            available = [k for k in sig.parameters.keys() if k != 'self']
            raise TypeError(
                f"Invalid override '{name}' for '{target_name}'. "
                f"Available parameters: {available}"
            )


def _override_method(func, **overrides: Any):
    if hasattr(func, 'overrides'):
        overrides = {**func.overrides, **overrides}
        func = func.__wrapped__
    sig = inspect.signature(func)
    _validate_overrides(func.__name__, sig, overrides)

    @wraps(func)
    def overridden(*args: P.args, **kwargs: P.kwargs) -> T:
        return _apply_and_call(func, sig, overrides, args, kwargs)

    overridden.overrides = overrides
    return config(overridden)
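

def _config_function(func: Callable[P, T]):
    """Adds .override() to a function."""
    func.override = functools.partial(_override_method, func)
    return func


def _config_class(cls):
    """Adds .override() to a class that returns a factory function.

    The class itself is modified in-place to add the override method.
    The Annotated return type of config() documents this for readers.
    """
    def init(*args, **kwargs) -> C:
        return cls(*args, **kwargs)

    # Expose the constructor's real signature and name. Without this, init looks
    # like (*args, **kwargs), so override keys are never validated and
    # override(...) markers reach __init__ as raw tuples.
    init.__signature__ = inspect.signature(cls)
    init.__name__ = cls.__name__

    # Add the override method; staticmethod keeps the partial from being bound
    cls.override = staticmethod(functools.partial(_override_method, init))
    return cls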
```