Strongly-typed structured configuration in Hydra
At Helsing we use Hydra and OmegaConf for configuration management in machine learning evaluation and training codebases. Hydra allows us to decompose our configuration into modules and override specific components via configuration files and command line arguments, thus enabling both experimentation and production workflows.
We are also big fans of static typing and make pervasive use of dataclasses and type hinting, not only for code but also to verify the integrity of configuration files and parameters. This article sketches a few shortcomings of OmegaConf’s structured config capabilities and explains how we overcome them with a custom deserializer based on the databind library. Code examples for this post are available on GitHub: https://github.com/uschi2000/hydra-config.
Custom type deserialization. OmegaConf supports deserialization of simple types, but it’s not possible to register deserializers for custom types. This is limiting because it’s convenient (and more type-safe) to use non-str types for complex identifiers (like AWS ARNs) or tokens (like JWTs).
Collection duck typing. OmegaConf’s structured config mechanism employs duck typing to pretend, for example, that a ListConfig is a list. The duck typing mechanism is suitable for simple use cases, but turns out to be brittle when combined with downstream libraries like neptune.ai that expect a real list and not something that merely quacks like a list; we frequently found ourselves writing list(config.tags), which is obviously quite fragile.
Union types. Our internal ML platform supports different data sources, such as datasets loaded from a local file system or datasets streamed whilst training from cloud-based blob storage. We like to express such configuration as a type-safe union à la DataSource = LocalStorageDataSource | BlobDataSource, but OmegaConf only supports unions of primitive types.
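The collection duck-typing problem can be illustrated without OmegaConf installed. The sketch below uses the standard library's UserList as a stand-in for ListConfig (an assumption made for illustration: both behave like a list without being a subclass of list) and json.dumps as a stand-in for a downstream library that insists on a real list.

```python
import json
from collections import UserList

# Stand-in for a list-like config container: UserList supports indexing,
# len() and iteration, but it is not a subclass of list.
tags = UserList(["a", "b"])


def is_real_list(value: object) -> bool:
    """The kind of check a strict downstream library might perform."""
    return isinstance(value, list)


def try_dump(value: object) -> str:
    """Serialize with json, reporting a failure instead of raising."""
    try:
        return json.dumps(value)
    except TypeError as error:
        return f"TypeError: {error}"


print(tags[0], len(tags))     # list-like access works
print(is_real_list(tags))     # the strict check fails
print(try_dump(tags))         # json refuses the list-like object
print(try_dump(list(tags)))   # the explicit list(...) workaround succeeds
```

Every call site that hands the value to such a library needs the explicit list(...) conversion, which is easy to forget.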
Custom deserialization with databind
Suneeta Mall has previously described a similar list of shortcomings and a solution using Pydantic types as a drop-in replacement for vanilla dataclasses. We were after an even simpler solution, one that is maybe a little more lightweight than Pydantic, and that is super simple to use for our developers. A true pit of success.
The idea is simple:
- Use vanilla Hydra/OmegaConf to assemble a DictConfig object from configuration files and overrides.
- Turn the DictConfig object into a user-defined Config dataclass (or of course nested dataclasses) using databind. Notably, databind supports custom deserializers and union types.
- Tie everything together with a custom @hydra_main decorator, as a drop-in replacement for the default Hydra decorator.
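To make the second step concrete, here is a minimal standard-library sketch of what "turning a dictionary into nested dataclasses" involves. It is not databind and does not reflect its implementation; the Optimizer and TrainConfig classes and the from_dict helper are hypothetical names introduced for illustration, and a plain dict stands in for the composed configuration.

```python
from dataclasses import dataclass, fields, is_dataclass
from typing import Any, get_type_hints


@dataclass(frozen=True)
class Optimizer:  # hypothetical example type
    name: str
    lr: float


@dataclass(frozen=True)
class TrainConfig:  # hypothetical example type
    run_name: str
    optimizer: Optimizer


def from_dict(cls: type, raw: dict[str, Any]) -> Any:
    """Recursively build the dataclass `cls` from a plain dictionary."""
    hints = get_type_hints(cls)
    unknown = set(raw) - {f.name for f in fields(cls)}
    if unknown:
        raise TypeError(f"unknown keys for {cls.__name__}: {sorted(unknown)}")
    kwargs: dict[str, Any] = {}
    for key, value in raw.items():
        target = hints[key]
        if is_dataclass(target) and isinstance(value, dict):
            # Nested dataclass: recurse into the sub-dictionary.
            kwargs[key] = from_dict(target, value)
        elif isinstance(target, type) and not isinstance(value, target):
            raise TypeError(f"{cls.__name__}.{key}: expected {target.__name__}")
        else:
            kwargs[key] = value
    # Missing keys surface here as a TypeError from the dataclass constructor.
    return cls(**kwargs)


config = from_dict(
    TrainConfig,
    {"run_name": "baseline", "optimizer": {"name": "adam", "lr": 0.001}},
)
print(config)
```

A library such as databind adds what this sketch leaves out: collections, unions, custom converters and useful error locations.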
The following code snippets are taken from a self-contained repository at https://github.com/uschi2000/hydra-config.
Databind deserializer
Here’s an example databind deserializer with support for a custom Token data type:
from dataclasses import dataclass
from typing import Any
from uuid import UUID

import hydra
from databind.core import Converter, Context, ObjectMapper
from databind.json import JsonModule
from omegaconf import DictConfig
from typeapi import TypeHint, ClassTypeHint


@dataclass(frozen=True)
class Token:
    uid: UUID
    key: str

    def __str__(self) -> str:
        return f"{self.uid}:{self.key}"


class TokenConverter(Converter):
    """A databind converter for Token objects."""

    def convert(self, ctx: Context) -> Any:
        # Signal "not handled" for any type other than Token.
        if not isinstance(ctx.datatype, ClassTypeHint) or not issubclass(
            ctx.datatype.type, Token
        ):
            raise NotImplementedError
        if ctx.direction.is_serialize():
            return str(ctx.value)
        elif ctx.direction.is_deserialize():
            # Tokens are written as "<uuid>:<key>" in configuration files.
            if isinstance(ctx.value, str):
                parts = ctx.value.split(":")
                return Token(UUID(parts[0]), parts[1])
            raise NotImplementedError
        else:
            raise Exception("Invalid Context direction, this is a bug")


class ConfigParser:
    mapper: ObjectMapper[Any, Any] = ObjectMapper()
    mapper.module.register(TokenConverter())
    mapper.module.register(JsonModule())

    @staticmethod
    def parse(config: DictConfig, type_hint: TypeHint) -> Any:
        """
        Parses the given input (typically a dictionary loaded from OmegaConf.load())
        into a typed configuration object.
        """
        return ConfigParser.mapper.deserialize(config, type_hint)
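The string format used by the converter can be exercised on its own, without databind. The sketch below repeats the Token dataclass and adds a hypothetical parse_token helper (not part of the original code) that inverts Token.__str__. It uses str.partition rather than split(":") on the assumption that a key might itself contain a colon, in which case taking only parts[1] would silently drop the remainder.

```python
from dataclasses import dataclass
from uuid import UUID


@dataclass(frozen=True)
class Token:
    uid: UUID
    key: str

    def __str__(self) -> str:
        return f"{self.uid}:{self.key}"


def parse_token(text: str) -> Token:
    """Hypothetical helper: parse '<uuid>:<key>', the inverse of Token.__str__."""
    # partition splits at the first colon only, so the key may contain colons.
    uid, separator, key = text.partition(":")
    if not separator:
        raise ValueError(f"expected '<uuid>:<key>', got {text!r}")
    # UUID(...) raises ValueError for a malformed identifier.
    return Token(UUID(uid), key)


raw = "7a72f169-f8c3-4b3e-8041-021a62a2d87f:my_token"
token = parse_token(raw)
print(token.key)
print(str(token) == raw)  # serialization and parsing round-trip
```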
Hydra decorator
We can then wrap the default Hydra @hydra_main decorator such that it converts an inbound DictConfig into a structured Config dataclass; the latter is then injected into the main function:
import functools
import os
from inspect import signature
from pathlib import Path
from typing import Any, Callable


def hydra_main2(config_path: Path) -> Callable[[Callable[[Any], Any]], Any]:
    """
    An alternative hydra_main decorator that deserializes the OmegaConf object
    into a dataclass using the databind library.
    """

    def main_decorator(
        task_function: Callable[[Any], None]
    ) -> Callable[[Any | None], None]:
        @functools.wraps(task_function)
        def hydra_main(raw_config: Any) -> Any:
            # Converts the given DictConfig into a Config object with the type
            # specified by the main method (i.e., typically a project-defined `Config` class).
            parameters = signature(task_function).parameters
            if parameters.get("config") is None:
                raise Exception(
                    "@hydra_main2 method must have first parameter of the form `config: Config`"
                )
            config_class = parameters["config"].annotation
            config = ConfigParser.parse(raw_config, TypeHint(config_class))
            return task_function(config)

        def decorated_main(_config: Any | None = None) -> Any:
            # Positional arguments: config_path, config_name, version_base.
            hydra_decorator = hydra.main(os.fspath(config_path), "config", "1.3")
            return (hydra_decorator(hydra_main))()

        return decorated_main

    return main_decorator
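The part of the decorator that discovers the target type from the function signature can be tried in isolation. The following sketch uses only the standard library: inject_config and AppConfig are hypothetical names, and a plain dictionary stands in for the configuration that Hydra would supply. It uses typing.get_type_hints instead of reading Parameter.annotation directly, because the raw annotation is a string when the module uses from __future__ import annotations.

```python
import functools
from dataclasses import dataclass
from inspect import signature
from typing import Any, Callable, get_type_hints


@dataclass(frozen=True)
class AppConfig:  # hypothetical example type
    name: str


def inject_config(raw: dict[str, Any]) -> Callable[[Callable[..., Any]], Callable[[], Any]]:
    """Hypothetical decorator: build the annotated config type from `raw` and pass it in."""

    def decorator(func: Callable[..., Any]) -> Callable[[], Any]:
        if "config" not in signature(func).parameters:
            raise TypeError("decorated function needs a `config` parameter")
        # get_type_hints resolves string annotations into real classes.
        config_class = get_type_hints(func).get("config")
        if config_class is None:
            raise TypeError("the `config` parameter needs a type annotation")

        @functools.wraps(func)
        def wrapper() -> Any:
            return func(config_class(**raw))

        return wrapper

    return decorator


@inject_config({"name": "foo"})
def main(config: AppConfig) -> str:
    return config.name


print(main())
```

Performing the signature check at decoration time, as here, reports a missing or unannotated parameter when the module is imported rather than when the program runs.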
Using the decorator
The @hydra_main2 decorator is a drop-in replacement for the default Hydra decorator, as this test demonstrates:
def test_decorator_parses_config() -> None:
    # hydra reads command line arguments as config overrides, so we need
    # to remove the pytest arguments in order to not confuse hydra.
    sys.argv = sys.argv[:1]

    @hydra_main2(config_path=Path(__file__).parent / "config")
    def main(config: Config) -> None:
        assert config.name == "foo"
        assert config.answer == 42
        assert config.tags == ["a", "b"]
        assert isinstance(config.tags, list)
        assert config.inputs["default_local_dataset"] == LocalDataset(Path("/data"))
        assert config.inputs["default_remote_dataset"] == RemoteDataset(
            "https://foo/bar",
            Token(UUID("7a72f169-f8c3-4b3e-8041-021a62a2d87f"), "my_token"),
        )
        assert config.inputs["extra_local_dataset"] == LocalDataset(Path("/data_extra"))

    # Invoke the decorated function so that the assertions above actually run.
    main()
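Assigning to sys.argv inside a test changes it for every test that runs afterwards. One possible alternative, sketched here with the standard library's unittest.mock, restores the original arguments when the block exits; run_with_clean_argv is a hypothetical helper name.

```python
import sys
from typing import Any, Callable
from unittest import mock


def run_with_clean_argv(func: Callable[[], Any]) -> Any:
    """Call `func` with sys.argv reduced to the program name, then restore it."""
    with mock.patch.object(sys, "argv", sys.argv[:1]):
        return func()


original = list(sys.argv)
seen = run_with_clean_argv(lambda: list(sys.argv))
print(seen == original[:1])   # only the program name was visible inside
print(sys.argv == original)   # the original arguments are back afterwards
```

In a pytest suite, the built-in monkeypatch fixture offers the same restore-on-exit behaviour via monkeypatch.setattr(sys, "argv", ...).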
Note that the Config object supports union types, for example:
@dataclass(frozen=True)
class RemoteDataset:
    url: str
    token: Token


@dataclass(frozen=True)
class LocalDataset:
    path: Path


InputTypes = RemoteDataset | LocalDataset
InputConfigurations = dict[str, InputTypes]


@dataclass(frozen=True)
class Config:
    name: str
    answer: int
    tags: list[str]
    inputs: InputConfigurations
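For intuition about what deserializing such a union requires, the sketch below resolves an untagged union by matching dictionary keys against the fields of each member. This is an illustrative strategy only, and databind's actual resolution logic may differ. The resolve_union helper is hypothetical, the token and path fields are simplified to str to keep the example self-contained, and typing.Union is used instead of the | syntax so that the code also runs on Python versions before 3.10.

```python
from dataclasses import dataclass, fields
from typing import Any, Union, get_args


@dataclass(frozen=True)
class RemoteDataset:
    url: str
    token: str  # simplified; the article uses the custom Token type


@dataclass(frozen=True)
class LocalDataset:
    path: str  # simplified; the article uses pathlib.Path


InputTypes = Union[RemoteDataset, LocalDataset]


def resolve_union(union: Any, raw: dict[str, Any]) -> Any:
    """Instantiate the first union member whose field names equal the keys of `raw`."""
    for member in get_args(union):
        if {f.name for f in fields(member)} == set(raw):
            return member(**raw)
    raise TypeError(f"no union member matches keys {sorted(raw)}")


inputs = {
    "default_local_dataset": {"path": "/data"},
    "default_remote_dataset": {"url": "https://foo/bar", "token": "my_token"},
}
resolved = {name: resolve_union(InputTypes, raw) for name, raw in inputs.items()}
print(resolved)
```

Key matching only works when the members have distinct field sets; unions whose members share the same fields need an explicit discriminator.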
Conclusion
This blog post demonstrates how to combine Hydra with databind deserialization in order to work around a few limitations of OmegaConf: the lack of custom deserializers, of vanilla collection types, and of union types beyond primitives. We are using this mechanism internally at Helsing and are pretty happy with it, so far at least :) If you like this approach, please let us know and we’d be happy to discuss how to contribute it to Hydra proper or provide the alternative Hydra decorator as a library.
Authors
Scott Stevenson and Robert Fink