Source code for tud_lbm.config.adapter_base

"""Base class for configuration file adapters."""

from __future__ import annotations
import dataclasses
import importlib
from abc import ABC
from abc import abstractmethod
from collections import defaultdict
from pathlib import Path
from typing import Any
from tud_lbm.config.simulation_config import CONFIG_SECTION
from tud_lbm.config.simulation_config import SimulationConfig


class ConfigAdapter(ABC):
    """Translate between a configuration file and :class:`SimulationConfig`.

    Subclasses implement :meth:`load_raw`, :meth:`load` and :meth:`save` for
    one file format. :meth:`build_sections` provides the shared,
    format-agnostic sectioned layout of a configuration.
    """

    @abstractmethod
    def load_raw(self, path: str) -> dict[str, Any]:
        """Read *path* and return the raw configuration mapping.

        In contrast to :meth:`load`, no :class:`SimulationConfig` is
        instantiated; the raw config dict is returned as-is. Useful for
        detecting array fields before expansion.

        Args:
            path: Filesystem path to a configuration file.

        Returns:
            Raw config dictionary.
        """
        ...

    @abstractmethod
    def load(self, path: str) -> SimulationConfig:
        """Read *path* and build a :class:`SimulationConfig` from it.

        Args:
            path: Filesystem path to a configuration file.

        Returns:
            A validated :class:`SimulationConfig`.
        """
        ...

    @abstractmethod
    def save(self, config: SimulationConfig, path: str) -> None:
        """Write *config* to *path*.

        Args:
            config: Configuration to save.
            path: Filesystem path where to save the configuration.
        """
        ...

    @staticmethod
    def _serialize_safe(value: Any) -> Any:  # noqa: ANN401
        """Return *value* with tuples turned into lists, at every nesting level.

        Dicts and lists are rebuilt with their items processed recursively;
        any other value is returned unchanged.
        """
        if isinstance(value, dict):
            return {
                name: ConfigAdapter._serialize_safe(item)
                for name, item in value.items()
            }
        if isinstance(value, (list, tuple)):
            return [ConfigAdapter._serialize_safe(item) for item in value]
        return value

    @classmethod
    def build_sections(cls, config: SimulationConfig) -> dict[str, Any]:
        """Group the fields of *config* into a format-agnostic nested dict.

        Each field is routed to the section named by its ``CONFIG_SECTION``
        metadata entry, falling back to ``"simulation_type"``.

        * Fields whose value is ``None`` are omitted.
        * Fields routed to the ``"identity"`` or ``"extra"`` sections are
          omitted.
        * Fields routed to ``"multiphase"`` are omitted unless
          ``"multiphase"`` occurs in ``config.sim_type``.
        * A dict-valued field is merged key by key into its section instead
          of being stored under the field name.
        * ``config.sim_type`` is stored as ``"type"`` in the
          ``"simulation_type"`` section, followed by the entries of
          ``config.extra``.

        Args:
            config: Configuration to lay out.

        Returns:
            A dict with ``"simulation_type"`` first, then every other
            non-empty section in alphabetical order.
        """
        default_section = "simulation_type"

        # Field name -> section name, taken from the dataclass field metadata.
        section_of: dict[str, str] = {}
        for spec in dataclasses.fields(SimulationConfig):
            section_of[spec.name] = spec.metadata.get(CONFIG_SECTION, default_section)

        kind = config.sim_type
        excluded = {"identity", "extra"}
        grouped: dict[str, dict[str, Any]] = defaultdict(dict)

        for name, raw in dataclasses.asdict(config).items():
            target = section_of.get(name, default_section)
            if (
                raw is None
                or target in excluded
                or (target == "multiphase" and "multiphase" not in kind)
            ):
                continue
            cleaned = cls._serialize_safe(raw)
            if isinstance(raw, dict):
                # Dict-valued fields are flattened into their section.
                grouped[target].update(cleaned)
            else:
                grouped[target][name] = cleaned

        head = grouped[default_section]
        head["type"] = kind
        for extra_key, extra_value in (config.extra or {}).items():
            head[extra_key] = cls._serialize_safe(extra_value)

        # "simulation_type" leads; remaining sections follow sorted, empties dropped.
        result: dict[str, Any] = {default_section: grouped.pop(default_section, {})}
        for section_name in sorted(grouped):
            contents = grouped[section_name]
            if contents:
                result[section_name] = contents
        return result
# Registry of supported configuration formats: file suffix (with leading dot,
# lowercase because lookups use the lowercased suffix) -> fully qualified
# "package.module.ClassName" of the adapter class. Classes are named by dotted
# string so the adapter module is only imported when get_adapter() resolves it.
_ADAPTER_MAP: dict[str, str] = { ".toml": "tud_lbm.config.adapter_toml.TomlAdapter", }
def get_adapter(path: str) -> ConfigAdapter:
    """Pick and instantiate the adapter matching the file extension of *path*.

    The suffix of *path* is lowercased and looked up in ``_ADAPTER_MAP``; the
    dotted class name found there is imported and called with no arguments.

    Args:
        path: Filesystem path whose suffix selects the adapter.

    Returns:
        A new instance of the adapter class registered for the suffix.

    Raises:
        ValueError: If no adapter is registered for the suffix.
    """
    suffix = Path(path).suffix.lower()
    dotted_name = _ADAPTER_MAP.get(suffix)
    if dotted_name:
        # Split "pkg.module.Class" into the module to import and the class name.
        module_name, attribute = dotted_name.rsplit(".", 1)
        adapter_cls = getattr(importlib.import_module(module_name), attribute)
        return adapter_cls()

    supported = ", ".join(sorted(_ADAPTER_MAP))
    message = f"Unsupported extension '{suffix}'. Supported: {supported}"
    raise ValueError(message)