I love writing good code and I like spending exorbitant amounts of time on the little details. So here's an article on the objectively perfect configuration system for Python.

# Introduction

We want a system that balances specificity with readability and re-use. Maximal specificity means manually defining every sub-module configuration in one large file, with local variables for any shared or tied values (e.g. the model dimension of the attn and mlp blocks). However, this is very bad for re-use: if we want to make a new version, we need to copy the whole file.

**If I'm doing research and my code needs constant refactoring to accommodate the weird shit I do, then I need a better configuration system.**

# Theory

## Prior Work

We could use a flat configuration file where the tied relationships between the sub-configs are encoded in utility methods, but this has poor specificity. If I want to untie something, I can't without rewriting the configuration method. For example, I might want to make the down projection heads bf16 but the up projections fp8.

We could use a hierarchical configuration, but then we have loads of redundancy where we have to specify the `model_d` over and over. The more default methods we use, the more bloated the code becomes (trust me, I tried).

## Functional Hierarchical Configs

What we want is some way for the value-tying relationships to be easily modified. Value-tying relationships are really nice to encode in functions because then I don't need to reinvent programmatic configs from first principles like everyone else. We can easily encode hierarchical configurations in functions like this:

```python
def llama(
    model_d,
    embed=Embed.create,
    norm=Norm.create
):
    return LlamaModel(
        model_d,
        embed=embed(256, model_d),
        lm_head=norm(model_d),
        # ... remaining sub-modules
    )

# llama() calls embed(vocab_size, model_d), so each replacement lambda takes
# the same two arguments and ignores the default vocab size.
model_v1 = llama(1024, embed=lambda _, md: Embed.create(65536, md))
model_v2 = llama(
    1024,
    embed=lambda _, md: Embed.create(
        65536, md,
        init=lambda md, vd, rng: normal(stddev=2, shape=(md, vd), rng=rng),
    ),
)
```

However, that is somewhat cumbersome when just replacing the args to a function (e.g. to change the dtype, activation function, matmul precision, etc.). It's also annoying if we want to replace the arguments or implementation of nested functions. In particular the `lambda ...` part is irritating.

## Legible Functional Hierarchical Configuration

To alleviate this issue we can use some Python metaprogramming. We can wrap the function with a decorator that unpacks a special marker type; the marker says to replace the default function with the same function but with some of its arguments replaced. Much like a `functools.partial`, but readable.

```python
import jax.numpy as jnp
from flax import nnx

from configurator import config, override

@config
def llama(
    model_d,
    vocab_size,
    num_layers=12,
    embed=embed,  # default sub-module factories, defined elsewhere
    norm=norm,
    # ... more sub-module factories
):
    return LlamaModel(embed=embed(...))  # other sub-modules elided

@config
def fp8_norm(model_d, recentering=None, initializer=None):  # parameters inferred from the override below
    ...

# example configuration
llama_fn = llama.override(
    # functools.partial equivalent
    vocab_size=65535,
    # override regular args
    num_layers=24,
    # override args without changing the function
    embed=override(dtype=jnp.bfloat16),
    # override args and function
    norm=fp8_norm.override(
        recentering="recenter",  # nesting works too
        initializer=nnx.initializers.zeros
    )
)

# call the function
model = llama_fn(1024)
```

# Implementation

This configuration system actually works really well. The research I do is usually diverse and creative, which calls for frequent rewrites and requires maximum code re-usability. These configs are perfect for that.

```python
"""
Simple configuration system for hierarchical function and class composition.

The @config decorator adds an .override() method to functions and classes.
- For functions: Returns the function with an added .override() method
- For classes: Returns the class with an added .override() class method

Example:
    @config
    def create_embed(vocab_size, d, init="normal"):
        return Embed(vocab_size, d, init)

    @config
    def create_fp8_embed(vocab_size, d, init="normal"):
        return ...

    @config
    class Llama:
        def __init__(self, model_d, vocab_size=32000, embed=create_embed):
            self.embed = embed(vocab_size, model_d)

    # Create configured factory function
    llama_3 = Llama.override(
        vocab_size=128256,
        # Override args of default function
        embed=override(init="kaiming"),
    )

    llama_fp8 = Llama.override(
        vocab_size=128256,
        # Replace the default function entirely
        embed=create_fp8_embed.override(init="kaiming"),
    )

    model = llama_3(4096)  # llama_3 is a function, not a class

Key points:
- @config adds .override() to functions and classes mutably!
- func.override(**kwargs) returns a new function with baked-in overrides
- Class.override(**kwargs) returns a factory function (not a class)
- override(**kwargs) is a special marker to override args of a default function
- Annotated[T, HasOverride] documents the added method; static type checkers
  ignore Annotated metadata, so they will not see .override() on decorated objects
- Chaining overrides copies & overwrites the overrides from the last

@Author: _ueaj
"""
import functools
import inspect
from functools import wraps, partial
from typing import Callable, TypeVar, Protocol, Any, ParamSpec, Type, overload, Annotated

P = ParamSpec('P')
T = TypeVar('T')
C = TypeVar('C')


# Simple protocol for the override method
class HasOverride(Protocol):
    """Protocol for types that have the override method."""

    @classmethod
    def override(cls, **overrides: Any) -> Callable[..., Any]: ...


@overload
def config(obj: Type[C]) -> Annotated[Type[C], HasOverride]: ...


@overload
def config(obj: Callable[P, T]) -> Annotated[Callable[P, T], HasOverride]: ...


def config(obj: C) -> Annotated[C, HasOverride]:
    """Add .override() method to a function or class.

    The returned object is annotated to indicate it has both:
    - Its original type/interface (C)
    - The HasOverride protocol (providing the .override() method)
    """
    if inspect.isclass(obj):
        return _config_class(obj)
    if callable(obj):
        return _config_function(obj)
    raise TypeError("@config can only be used on callables (functions or classes)")


class _ArgOverride:
    """Marker returned by override(); holds kwargs for a default callable.

    A dedicated class avoids mistaking ordinary tuple values (including the
    empty tuple) for the marker.
    """

    def __init__(self, kwargs):
        self.kwargs = kwargs


def override(**kwargs):
    """Special marker for overriding just the arguments of a default."""
    return _ArgOverride(kwargs)


def _apply_and_call(target, sig, overrides, args, kwargs):
    """Helper to bind arguments, apply overrides, and call the target.

    Note: overrides are applied after binding, so a plain override value
    replaces an argument passed at call time for the same parameter.
    """
    bound = sig.bind_partial(*args, **kwargs)
    bound.apply_defaults()

    varkw_name = next(
        (p.name for p in sig.parameters.values() if p.kind == inspect.Parameter.VAR_KEYWORD),
        None,
    )

    for name, value in overrides.items():
        if name in sig.parameters:
            if isinstance(value, _ArgOverride):
                # Re-configure whatever callable currently fills this parameter
                current = bound.arguments.get(name, sig.parameters[name].default)
                if hasattr(current, 'override'):
                    bound.arguments[name] = current.override(**value.kwargs)
                elif callable(current) and current is not inspect.Parameter.empty:
                    bound.arguments[name] = partial(current, **value.kwargs)
                else:
                    raise TypeError(
                        f"Argument '{name}' is not a configurable or callable, so its arguments cannot be overridden."
                    )
            else:
                bound.arguments[name] = value
        elif varkw_name:
            # Unknown names are forwarded through the target's **kwargs
            if varkw_name not in bound.arguments:
                bound.arguments[varkw_name] = {}
            bound.arguments[varkw_name][name] = value

    return target(*bound.args, **bound.kwargs)


def _validate_overrides(target_name, sig, overrides):
    """Checks if override keys are valid for the given signature."""
    accepts_kwargs = any(p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values())
    for name in overrides:
        if name not in sig.parameters and not accepts_kwargs:
            available = [k for k in sig.parameters.keys() if k != 'self']
            raise TypeError(
                f"Invalid override '{name}' for '{target_name}'. "
                f"Available parameters: {available}"
            )


def _override_method(func, /, **overrides: Any):
    # func is positional-only so that a parameter named 'func' can be overridden.
    if hasattr(func, 'overrides'):
        # Chained call: merge with the earlier overrides and unwrap to the original
        overrides = {**func.overrides, **overrides}
        func = func.__wrapped__

    sig = inspect.signature(func)
    _validate_overrides(func.__name__, sig, overrides)

    @wraps(func)
    def overridden(*args: P.args, **kwargs: P.kwargs) -> T:
        return _apply_and_call(func, sig, overrides, args, kwargs)

    overridden.overrides = overrides
    return config(overridden)


def _config_function(func: Callable[P, T]):
    """Adds .override() to a function."""
    func.override = functools.partial(_override_method, func)
    return func


def _config_class(cls):
    """Adds .override() to a class that returns a factory function.

    The class itself is modified in-place to add the override method.
    The Annotated return type of config() documents this.
    """
    def init(*args, **kwargs) -> C:
        return cls(*args, **kwargs)

    # Expose the constructor's signature and name. Without this the factory
    # looks like (*args, **kwargs), so override names are never validated and
    # override(...) markers are passed through to __init__ unresolved.
    init.__signature__ = inspect.signature(cls)
    init.__name__ = cls.__name__
    init.__qualname__ = cls.__qualname__

    # Add the override method
    cls.override = functools.partial(_override_method, init)
    return cls
```