"""Validated, serialisable simulation configuration for TUD-LBM.

:class:`SimulationConfig` is a **frozen** Python dataclass used for
parsing, validation, and serialisation. It never enters a JIT boundary.

Usage::

    from config.simulation_config import SimulationConfig

    cfg = SimulationConfig(
        grid_shape=(128, 128, 1),
        tau=0.8,
        nt=5000,
        collision_scheme="bgk",
    )
"""
from __future__ import annotations
import dataclasses
from dataclasses import dataclass
from dataclasses import field
from typing import Any
from typing import Literal
from tud_lbm.config.config_overview import BASE_RESULTS_DIR
# Metadata key under which a field stores the name of the config section it
# belongs to (written by ``array_field``, read by ``get_fields_for_section``).
CONFIG_SECTION: str = "config_section"

# Metadata key flagging a field as eligible for array (sweep) expansion.
ARRAY_ELIGIBLE: str = "array_eligible"

# Metadata key flagging a dict field whose sub-keys may also carry sweep lists.
NESTED_SWEEPABLE: str = "nested_sweepable"

# Smallest number of entries accepted in ``grid_shape``.
MIN_GRID_DIMENSIONS: int = 2

# ``tau`` must be strictly greater than this value to pass validation.
MIN_TAU_VALUE: float = 0.5
def array_field(
    *,
    default: object = dataclasses.MISSING,
    default_factory: object = dataclasses.MISSING,
    section: str | None = None,
    nested_sweepable: bool = False,
    **kwargs: object,
) -> Any:
    """Field factory for array-eligible SimulationConfig fields.

    Wraps :func:`dataclasses.field` and tags the resulting field's metadata
    with ``ARRAY_ELIGIBLE`` (always), ``NESTED_SWEEPABLE`` (on request) and
    ``CONFIG_SECTION`` (when *section* is given).

    Args:
        default: Default value for the field.
        default_factory: Default factory function for the field.
        section: Config section name for serialisation routing.
        nested_sweepable: If ``True``, sub-keys inside this dict field
            will also be inspected for list values during Cartesian-product
            expansion (e.g. ``gravity_force``, ``wetting_config``).
        **kwargs: Additional keyword arguments passed to the field factory.
            A ``metadata`` mapping, if present, is copied and extended rather
            than mutated.

    Returns:
        A dataclass field with array-eligible metadata. The return type is
        annotated as ``Any`` so that ``name: int = array_field(...)`` style
        declarations type-check, mirroring how ``dataclasses.field`` is typed.

    Raises:
        ValueError: Propagated from :func:`dataclasses.field` when both
            *default* and *default_factory* are supplied.
    """
    # Copy the caller's mapping (tolerating an explicit ``metadata=None``) so
    # the tags added below never leak back into an object the caller owns.
    metadata = dict(kwargs.pop("metadata", None) or {})
    metadata[ARRAY_ELIGIBLE] = True
    if nested_sweepable:
        metadata[NESTED_SWEEPABLE] = True
    if section is not None:
        metadata[CONFIG_SECTION] = section
    return field(default=default, default_factory=default_factory, metadata=metadata, **kwargs)
def _normalize_sequence(value: object) -> tuple[Any, ...]:
"""Ensure value is a tuple."""
return tuple(value) if not isinstance(value, tuple) else value # type: ignore[arg-type]
def _first_if_list(value: object) -> object:
"""Return first element if value is a list, otherwise return value."""
if isinstance(value, list):
return value[0] if value else value
return value
def _validate_positive(value: object, name: str) -> None:
"""Validate that value is positive."""
if value is not None and value <= 0: # type: ignore[operator]
msg = f"{name} must be positive, got {value}"
raise ValueError(msg)
def _validate_nonnegative(value: object, name: str) -> None:
"""Validate that value is non-negative."""
if value is not None and value < 0: # type: ignore[operator]
msg = f"{name} must be non-negative, got {value}"
raise ValueError(msg)
def _valid_collision_schemes() -> set[str]:
"""Get valid collision scheme names. Returns empty set if operators not loaded."""
try:
import tud_lbm.operators.collision # noqa: F401
from tud_lbm.registry import get_operator_names
return get_operator_names("collision_models")
except (ImportError, KeyError):
return set() # Operators not yet loaded - skip validation
def _valid_eos() -> set[str]:
"""Get valid EOS names. Returns empty set if operators not loaded."""
try:
import tud_lbm.operators.macroscopic # noqa: F401
from tud_lbm.registry import get_operator_names
return get_operator_names("macroscopic") - {"standard"}
except (ImportError, KeyError):
return set() # Operators not yet loaded - skip validation
def _valid_lattices() -> set[str]:
"""Get valid lattice types. Returns empty set if operators not loaded."""
try:
import tud_lbm.lattice.lattice # noqa: F401
from tud_lbm.registry import get_operator_names
return get_operator_names("lattice")
except (ImportError, KeyError):
return set() # Operators not yet loaded - skip validation
@dataclass(frozen=True)
class SimulationConfig:
    """Validated, serialisable simulation configuration for TUD-LBM.

    This frozen dataclass is used for parsing, validation, and serialisation.
    It never enters a JIT boundary and serves as the configuration container
    for all simulation parameters including grid, collision, boundary conditions,
    and output settings.

    Construction runs :meth:`__post_init__`, which normalises inputs, fills in
    defaults, checks and promotes ``grid_shape`` to three dimensions, completes
    ``bc_config`` and validates the result.

    Raises:
        ValueError: If any validated parameter is out of range, inconsistent,
            or missing (see the individual ``_validate_*`` methods).
    """

    # Faces that always receive an entry in ``bc_config`` (not a dataclass
    # field: it carries no annotation).
    _BC_FACES = ("top", "bottom", "left", "right", "front", "back")

    # ── Simulation identity ──────────────────────────────────────
    sim_type: Literal[
        "single_phase",
        "multiphase",
        "multiphase_wetting",
        "multiphase_hysteresis",
        "multiphase_hysteresis_chemical_step",
    ] = field(
        default="single_phase",
        metadata={CONFIG_SECTION: "identity"},
    )
    simulation_name: str | None = None

    # ── Lattice & grid ───────────────────────────────────────────
    lattice_type: str = "D2Q9"
    grid_shape: tuple[int, ...] = array_field(default=(64, 64))

    # ── Time stepping ────────────────────────────────────────────
    nt: int = array_field(default=1000)
    tau: float = array_field(default=1.0)

    # ── Collision ────────────────────────────────────────────────
    collision_scheme: str = array_field(default="bgk")
    k_diag: tuple[float, ...] | None = array_field(default=None)

    # ── Boundary conditions (ONLY topology: which BC on which face) ──
    bc_config: dict[str, Any] | None = field(
        default=None,
        metadata={CONFIG_SECTION: "boundary_conditions"},
    )

    # ── Wetting model ──────────
    wetting_config: dict[str, Any] | None = array_field(default=None, section="wetting", nested_sweepable=True)

    # ── Hysteresis model ───────
    hysteresis_config: dict[str, Any] | None = array_field(default=None, section="hysteresis", nested_sweepable=True)

    # ── Chemical step model ───────
    chemical_step_config: dict[str, Any] | None = array_field(
        default=None, section="chemical_step", nested_sweepable=True
    )

    # ── Forces (each force is its own field, named by physics) ───
    gravity_force: dict[str, Any] | None = array_field(default=None, section="gravity_force", nested_sweepable=True)
    electric_force: dict[str, Any] | None = array_field(default=None, section="electric_force", nested_sweepable=True)
    gravity_masked_force: dict[str, Any] | None = array_field(
        default=None, section="gravity_masked_force", nested_sweepable=True
    )

    # ── Initialisation ───────────────────────────────────────────
    init_type: str = "standard"
    init_dir: str | None = None
    initialisation: dict[str, Any] = field(
        default_factory=dict,
        metadata={CONFIG_SECTION: "initialisation"},
    )

    # ── Output / IO ──────────────────────────────────────────────
    results_dir: str = field(default=BASE_RESULTS_DIR, metadata={CONFIG_SECTION: "output"})
    save_fields: list[str] | None = field(default=None, metadata={CONFIG_SECTION: "output"})
    plot_fields: list[str] | None = field(default=None, metadata={CONFIG_SECTION: "output"})
    animate_fields: list[str] | None = field(default=None, metadata={CONFIG_SECTION: "output"})
    output_dir: str | None = field(default=None, metadata={CONFIG_SECTION: "output"})

    # ── Multiphase ───────────────────────────────────────────────
    eos: str | None = array_field(default=None, section="multiphase")
    kappa: float | None = array_field(default=None, section="multiphase")
    rho_l: float | None = array_field(default=None, section="multiphase")
    rho_v: float | None = array_field(default=None, section="multiphase")
    interface_width: int | None = array_field(default=None, section="multiphase")
    g: float | None = array_field(default=None, section="multiphase")

    # ── Output cadence / format ──────────────────────────────────
    # These attributes are read by ``_normalize``, ``_apply_defaults`` and
    # ``_validate_time_steps`` and therefore have to exist on every instance.
    # They are declared after all pre-existing fields so the positional order
    # of those fields is unchanged.
    # NOTE(review): the defaults and the "output" section routing are
    # assumptions — confirm against the config loader.
    save_interval: int = field(default=0, metadata={CONFIG_SECTION: "output"})  # 0 -> derived as nt // 10
    skip_interval: int = field(default=0, metadata={CONFIG_SECTION: "output"})
    output_format: str | None = field(default=None, metadata={CONFIG_SECTION: "output"})

    # ── Extra / extensible ───────────────────────────────────────
    # Free-form keys that ``to_dict`` flattens into the top-level mapping.
    extra: dict[str, Any] = field(default_factory=dict)

    # Validation
    def __post_init__(self) -> None:
        """Validate and normalize configuration after initialization.

        Raises:
            ValueError: If any validation step rejects the configuration.
        """
        self._normalize()
        self._apply_defaults()
        # The rank check has to see the user-supplied shape: once the shape is
        # padded to 3-D it can no longer detect a shape that was too short.
        self._validate_grid_shape()
        self._make_grid_shape_3d()
        self._set_all_bcs()
        self._validate_common()
        if self.is_multiphase:
            self._validate_multiphase()

    def _normalize(self) -> None:
        """Coerce ``grid_shape`` to a tuple and ``output_format`` to a lower-case scalar."""
        # The dataclass is frozen, so attributes are rewritten through
        # ``object.__setattr__`` throughout this class.
        object.__setattr__(self, "grid_shape", _normalize_sequence(self.grid_shape))
        output_format = _first_if_list(self.output_format)
        if isinstance(output_format, str):
            output_format = output_format.lower()
        object.__setattr__(self, "output_format", output_format)

    def _apply_defaults(self) -> None:
        """Fill in values that are derived from other fields when left unset."""
        if self.save_interval == 0:
            object.__setattr__(self, "save_interval", self.nt // 10)
        if self.bc_config is None:
            object.__setattr__(self, "bc_config", dict.fromkeys(self._BC_FACES, "periodic"))
        if self.hysteresis_config is not None and self.wetting_config is None:
            object.__setattr__(
                self,
                "wetting_config",
                {
                    "phi_left": 1.0,
                    "phi_right": 1.0,
                    "d_rho_left": 0.0,
                    "d_rho_right": 0.0,
                },
            )

    def _make_grid_shape_3d(self) -> None:
        """Promote grid_shape to 3D by appending singleton dimensions."""
        _target_dims = 3
        missing = _target_dims - len(self.grid_shape)
        if missing > 0:
            object.__setattr__(self, "grid_shape", self.grid_shape + (1,) * missing)

    def _set_all_bcs(self) -> None:
        """Set missing BCs in bc_config to 'periodic'.

        A new dict is stored instead of updating the existing one, so a
        mapping passed in by the caller is never modified.
        """
        completed = dict(self.bc_config or {})
        for face in self._BC_FACES:
            completed.setdefault(face, "periodic")
        object.__setattr__(self, "bc_config", completed)

    def _validate_common(self) -> None:
        """Validate configuration parameters shared by all simulation types.

        ``grid_shape`` is validated separately in ``__post_init__`` because it
        must be checked before it is promoted to 3-D.
        """
        self._validate_lattice()
        self._validate_tau()
        self._validate_time_steps()
        self._validate_collision()
        self._validate_forces()
        self._validate_init()
        self._validate_save_fields()

    def _validate_forces(self) -> None:
        """Validate force configuration consistency."""
        if self.gravity_force is not None and self.gravity_masked_force is not None:
            msg = "Only one gravity force can be applied: set either gravity_force or gravity_masked_force, not both."
            raise ValueError(msg)

    def _validate_grid_shape(self) -> None:
        """Validate the number and sign of the grid_shape entries."""
        if len(self.grid_shape) < MIN_GRID_DIMENSIONS:
            msg = f"grid_shape must have at least {MIN_GRID_DIMENSIONS} dimensions, got {len(self.grid_shape)}"
            raise ValueError(msg)
        if any(d <= 0 for d in self.grid_shape):
            msg = f"All grid dimensions must be positive, got {self.grid_shape}"
            raise ValueError(msg)

    def _validate_lattice(self) -> None:
        """Validate lattice_type against the registry, if the registry is available."""
        valid = _valid_lattices()
        # An empty set means the registry could not be queried; the helper
        # documents that case as "skip validation".
        if valid and self.lattice_type not in valid:
            msg = f"lattice_type must be one of {valid}, got '{self.lattice_type}'"
            raise ValueError(msg)

    def _validate_tau(self) -> None:
        """Validate tau for stability."""
        if self.tau <= MIN_TAU_VALUE:
            msg = f"tau must be > {MIN_TAU_VALUE} for stability, got {self.tau}"
            raise ValueError(msg)

    def _validate_time_steps(self) -> None:
        """Validate time stepping parameters."""
        if self.nt <= 0:
            msg = f"nt must be positive, got {self.nt}"
            raise ValueError(msg)
        _validate_nonnegative(self.save_interval, "save_interval")
        _validate_nonnegative(self.skip_interval, "skip_interval")

    def _validate_collision(self) -> None:
        """Validate collision scheme and parameters."""
        valid_schemes = _valid_collision_schemes()
        # Empty set: registry unavailable, so the membership check is skipped.
        if valid_schemes and self.collision_scheme not in valid_schemes:
            msg = f"collision_scheme must be one of {sorted(valid_schemes)}, got '{self.collision_scheme}'"
            raise ValueError(msg)
        if self.collision_scheme == "mrt" and self.k_diag is None:
            msg = "k_diag must be provided when using MRT collision scheme"
            raise ValueError(msg)

    def _validate_init(self) -> None:
        """Validate initialization parameters."""
        if self.init_type == "init_from_file" and self.init_dir is None:
            msg = "init_dir must be provided when init_type is 'init_from_file'"
            raise ValueError(msg)

    def _validate_save_fields(self) -> None:
        """Validate that every entry of save_fields is a known field name."""
        if self.save_fields is None:
            return
        valid_fields = {"f", "rho", "u", "force", "force_ext", "h"}
        invalid = set(self.save_fields) - valid_fields
        if invalid:
            msg = f"Invalid save_fields: {invalid}. Valid fields: {valid_fields}"
            raise ValueError(msg)

    def _validate_multiphase(self) -> None:
        """Validate the parameters that multiphase simulations require."""
        required = ("kappa", "rho_l", "rho_v", "interface_width", "eos")
        for name in required:
            if getattr(self, name) is None:
                msg = f"'{name}' is required for multiphase simulations"
                raise ValueError(msg)
        _validate_positive(self.rho_l, "rho_l")
        _validate_positive(self.rho_v, "rho_v")
        if self.rho_l is not None and self.rho_v is not None and self.rho_l <= self.rho_v:
            msg = f"rho_l ({self.rho_l}) must be greater than rho_v ({self.rho_v})"
            raise ValueError(msg)
        _validate_positive(self.kappa, "kappa")
        _validate_positive(self.interface_width, "interface_width")
        valid_eos = _valid_eos()
        # Empty set: registry unavailable, so the membership check is skipped.
        if valid_eos and self.eos not in valid_eos:
            msg = f"eos must be one of {sorted(valid_eos)}, got '{self.eos}'"
            raise ValueError(msg)

    @property
    def is_single_phase(self) -> bool:
        """Check if simulation is single-phase."""
        return self.sim_type == "single_phase"

    @property
    def is_multiphase(self) -> bool:
        """Check if simulation is multiphase (``sim_type`` contains ``"multiphase"``)."""
        return "multiphase" in self.sim_type

    @property
    def force_enabled(self) -> bool:
        """Check if any field whose name ends in ``_force`` is populated."""
        return any(getattr(self, f.name) is not None for f in dataclasses.fields(self) if f.name.endswith("_force"))

    # Serialisation
    def to_dict(self) -> dict[str, Any]:
        """Convert configuration to dictionary format.

        Returns:
            A dict of all fields in which the entries of ``extra`` are
            flattened into the top level (overriding same-named fields) and
            ``sim_type`` is additionally exposed as ``"simulation_type"``.
        """
        d = dataclasses.asdict(self)
        extra = d.pop("extra", None) or {}
        d.update(extra)
        d["simulation_type"] = self.sim_type
        return d

    def __repr__(self) -> str:
        """Return a multi-line summary of the core simulation parameters."""
        return (
            f"SimulationConfig(\n"
            f" sim_type={self.sim_type!r},\n"
            f" grid_shape={self.grid_shape!r},\n"
            f" lattice_type={self.lattice_type!r},\n"
            f" tau={self.tau!r},\n"
            f" nt={self.nt!r},\n"
            f" collision_scheme={self.collision_scheme!r},\n"
            f" init_type={self.init_type!r},\n"
            f")"
        )
def get_array_eligible_fields() -> frozenset[str]:
    """Get names of fields eligible for array expansion in configuration sweeps.

    Returns:
        The names of all ``SimulationConfig`` dataclass fields whose metadata
        holds a truthy value under the ``ARRAY_ELIGIBLE`` key.
    """
    return frozenset(
        f.name for f in dataclasses.fields(SimulationConfig) if f.metadata.get(ARRAY_ELIGIBLE, False)
    )
def get_nested_sweepable_fields() -> frozenset[str]:
    """Return field names whose dict sub-keys may carry list sweep values.

    Returns:
        The names of all ``SimulationConfig`` dataclass fields whose metadata
        holds a truthy value under the ``NESTED_SWEEPABLE`` key.
    """
    return frozenset(
        f.name for f in dataclasses.fields(SimulationConfig) if f.metadata.get(NESTED_SWEEPABLE, False)
    )
def get_fields_for_section(section: str) -> frozenset[str]:
    """Get field names belonging to a specific configuration section.

    Args:
        section: Section name to match against each field's
            ``CONFIG_SECTION`` metadata entry.

    Returns:
        The names of all ``SimulationConfig`` dataclass fields routed to
        *section*. Fields that carry no ``CONFIG_SECTION`` metadata are
        treated as belonging to the ``"simulation_type"`` section.
    """
    return frozenset(
        f.name
        for f in dataclasses.fields(SimulationConfig)
        if f.metadata.get(CONFIG_SECTION, "simulation_type") == section
    )