Source code for tud_lbm.config.array_expansion

"""Array parameter expansion for parallel simulations.

Detects array-valued parameters in configuration and expands them into
multiple :class:`SimulationConfig` objects for parallel execution.

Example usage::

    from config.adapter_toml import TomlAdapter
    from config.array_expansion import detect_array_fields, expand_config

    # Load config (arrays are preserved in raw form)
    adapter = TomlAdapter()
    raw_config = adapter.load("config_parallel.toml", allow_arrays=True)

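    # Detect which fields have arrays
    # (NOTE: detect_array_fields is currently a stub that always returns None)
    array_fields = detect_array_fields(raw_config)

    # Expand into multiple configs
    configs, metadata = expand_config(raw_config)
    # Returns: ([SimulationConfig(...), SimulationConfig(...), ...], ArrayParameterSet | None)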
"""

from __future__ import annotations
from dataclasses import dataclass
from itertools import product
from typing import TYPE_CHECKING
from typing import Any
from tud_lbm.config.simulation_config import SimulationConfig
from tud_lbm.config.simulation_config import get_array_eligible_fields
from tud_lbm.config.simulation_config import get_nested_sweepable_fields

if TYPE_CHECKING:
    from collections.abc import Iterator

# Number of entries in a scalar grid_shape value, i.e. an (nx, ny) pair.
GRID_SHAPE_TUPLE_LEN = 2
@dataclass(frozen=True)
class ArrayParameterSet:
    """Immutable summary of the swept (array-valued) parameters of a config.

    Attributes:
        field_names: Names of the fields that carried arrays. Nested
            sub-keys appear as dotted paths (``"parent.sub_key"``).
        array_values: For each swept field, the tuple of values it takes.
        total_combinations: How many configs the expansion produced.
    """

    # Keys of the swept axes (plain or dotted).
    field_names: frozenset[str]
    # Axis key -> values; a plain dict, so instances are frozen but not hashable.
    array_values: dict[str, tuple[Any, ...]]
    # Size of the Cartesian product over all axes.
    total_combinations: int
def detect_array_fields(_config: SimulationConfig) -> ArrayParameterSet | None:
    """Report the array-valued fields of a config (currently a stub).

    The argument is ignored and ``None`` is always returned: an expanded
    :class:`SimulationConfig` is expected to hold scalars only, so there is
    nothing to detect on it. The function exists for validation/inspection
    call sites.

    Args:
        _config: A :class:`SimulationConfig`; unused.

    Returns:
        Always ``None`` (meaning "no arrays present").
    """
    # Sweep metadata is produced during expansion of the raw dict, not here.
    return None
# ── grid_shape helpers ────────────────────────────────────────────────────────


def _is_scalar_grid_shape(value: Any) -> bool:  # noqa: ANN401
    """Return True if *value* is a single (nx, ny) integer tuple."""
    if not isinstance(value, tuple) or len(value) != GRID_SHAPE_TUPLE_LEN:
        return False
    return all(isinstance(x, int) for x in value)


def _is_scalar_grid_shape_list(value: Any) -> bool:  # noqa: ANN401
    """Return True if *value* is a flat list of ints (not yet converted).

    Note that an empty list also satisfies this check.
    """
    return isinstance(value, list) and all(isinstance(x, int) for x in value)


def _is_array_grid_shape(value: Any) -> bool:  # noqa: ANN401
    """Return True if *value* is a non-empty list/tuple whose first item is a tuple.

    Only the first element is inspected; the remaining elements are not
    validated here.
    """
    return isinstance(value, (list, tuple)) and bool(value) and isinstance(value[0], tuple)


def _is_array_value(value: Any, field_name: str = "") -> bool:  # noqa: ANN401
    """Check if a value should be treated as an array for expansion.

    ``grid_shape`` is special because its scalar form is itself a tuple:

    - ``(100, 100)`` → scalar grid_shape
    - ``[(64, 64), (128, 128)]`` → array of grid_shapes

    Args:
        value: Candidate value from the raw config.
        field_name: Name of the field the value belongs to.

    Returns:
        True when *value* is a sweep axis for *field_name*.
    """
    if field_name == "grid_shape":
        return _is_array_grid_shape(value)
    # str/bytes are never list/tuple instances, so no extra exclusion is needed.
    return isinstance(value, (list, tuple))


# ── nested-axes helpers ───────────────────────────────────────────────────────


def _collect_list_subkeys(field_name: str, nested: dict[str, Any]) -> dict[str, tuple[Any, ...]]:
    """Return dotted-path → values for every list-valued sub-key in *nested*.

    Only non-empty ``list`` values count; tuples and empty lists are skipped.

    Args:
        field_name: Parent field name (used to build dotted keys).
        nested: Sub-dict of a sweepable nested field.

    Returns:
        Mapping of ``"field_name.sub_key"`` to tuple of values.
    """
    return {f"{field_name}.{sk}": tuple(sv) for sk, sv in nested.items() if isinstance(sv, list) and sv}


def _extract_nested_array_axes(config_dict: dict[str, Any]) -> dict[str, tuple[Any, ...]]:
    """Return dotted-path → values for array sub-keys inside nested sweepable fields.

    For example, if ``config_dict["gravity_force"] = {"force_g": 5e-7,
    "inclination_angle_deg": [50, 60]}``, this returns
    ``{"gravity_force.inclination_angle_deg": (50, 60)}``.

    Args:
        config_dict: Raw configuration dict.

    Returns:
        Mapping of ``"field.sub_key"`` to tuple of values for every
        list-valued sub-key found under the fields reported by
        :func:`get_nested_sweepable_fields`.
    """
    axes: dict[str, tuple[Any, ...]] = {}
    for field_name in get_nested_sweepable_fields():
        nested = config_dict.get(field_name)
        if isinstance(nested, dict):
            axes.update(_collect_list_subkeys(field_name, nested))
    return axes


def _apply_dotted_key(result: dict[str, Any], key: str, value: Any) -> None:  # noqa: ANN401
    """Apply a single dotted-path key into *result* in-place.

    Splits ``"parent.sub_key"`` on the first dot and stores *value* under
    ``sub_key`` in a *copy* of ``result["parent"]``, so the nested dict shared
    with the base config is never mutated.

    Args:
        result: Config dict being built.
        key: Dotted key of the form ``"parent.sub_key"``.
        value: Value to assign at the sub-key.
    """
    parent, sub_key = key.split(".", 1)
    # A missing or falsy parent entry is replaced by a fresh dict.
    nested = dict(result.get(parent) or {})
    nested[sub_key] = value
    result[parent] = nested


def _apply_combo_to_dict(
    base_dict: dict[str, Any],
    combo_params: dict[str, Any],
) -> dict[str, Any]:
    """Return a copy of *base_dict* with *combo_params* applied.

    Top-level keys are set directly. Dotted-path keys (``"field.sub_key"``)
    are applied into the nested dict stored at ``base_dict["field"]``.

    Args:
        base_dict: Base scalar config dict (no list values).
        combo_params: Mapping of key (plain or dotted) → value for this
            combination.

    Returns:
        New dict suitable for ``SimulationConfig(**...)``.
    """
    result = dict(base_dict)
    for key, value in combo_params.items():
        if "." in key:
            _apply_dotted_key(result, key, value)
        else:
            result[key] = value
    return result


# ── expansion internals ───────────────────────────────────────────────────────


@dataclass(frozen=True)
class _ExpandResult:
    """Internal expansion artefacts shared by expand_config and enumerate_configs.

    Attributes:
        configs: Expanded list of :class:`SimulationConfig` objects.
        metadata: :class:`ArrayParameterSet` when arrays were found, else None.
        scalar_dict: Base scalar config dict (list values stripped).
        axis_keys: Ordered list of swept parameter names (plain or dotted).
        all_axes: Mapping from axis key to its tuple of values.
    """

    configs: list[SimulationConfig]
    metadata: ArrayParameterSet | None
    scalar_dict: dict[str, Any]
    axis_keys: list[str]
    all_axes: dict[str, tuple[Any, ...]]


def _collect_top_level_axes(
    config_dict: dict[str, Any],
    allow_arrays: bool,
) -> dict[str, tuple[Any, ...]]:
    """Collect and validate top-level array axes from *config_dict*.

    Iterates over array-eligible fields and returns those carrying array
    values. Raises :exc:`ValueError` for any array found when *allow_arrays*
    is False.

    Args:
        config_dict: Raw configuration dict.
        allow_arrays: If False, raise on any detected array.

    Returns:
        Mapping of eligible field name → tuple of values.

    Raises:
        ValueError: If an array is found and *allow_arrays* is False.
    """
    eligible = get_array_eligible_fields()
    top_level_axes: dict[str, tuple[Any, ...]] = {}
    for key, value in config_dict.items():
        if key not in eligible:
            continue
        if not _is_array_value(value, field_name=key):
            continue
        if not allow_arrays:
            msg = (
                f"Array values found for field '{key}' but allow_arrays=False. Use allow_arrays=True or flatten config."
            )
            raise ValueError(msg)
        # NOTE(review): an empty list becomes an empty axis here, which makes
        # the Cartesian product (and thus the expansion) empty, whereas empty
        # nested lists are ignored — confirm this asymmetry is intended.
        top_level_axes[key] = tuple(value)
    return top_level_axes


def _validate_nested_axes(
    nested_axes: dict[str, tuple[Any, ...]],
    allow_arrays: bool,
) -> None:
    """Raise :exc:`ValueError` if nested arrays are present but not allowed.

    Args:
        nested_axes: Mapping of dotted-path → values produced by
            :func:`_extract_nested_array_axes`.
        allow_arrays: If False and *nested_axes* is non-empty, raise.

    Raises:
        ValueError: If *nested_axes* is non-empty and *allow_arrays* is False.
    """
    if nested_axes and not allow_arrays:
        # Report the first offending dotted path only.
        first = next(iter(nested_axes))
        msg = f"Array values found for nested field '{first}' but allow_arrays=False."
        raise ValueError(msg)


def _strip_nested_field(
    field_name: str,
    nested_dict: dict[str, Any],
    nested_axes: dict[str, tuple[Any, ...]],
) -> dict[str, Any]:
    """Return *nested_dict* with the sub-keys that are swept axes removed.

    Args:
        field_name: Parent field name (used to build dotted keys).
        nested_dict: Sub-dict of a sweepable nested field.
        nested_axes: Mapping keyed by the dotted paths being swept; only key
            membership is consulted.

    Returns:
        Copy of *nested_dict* without the swept sub-keys.
    """
    return {sk: sv for sk, sv in nested_dict.items() if f"{field_name}.{sk}" not in nested_axes}


def _build_scalar_dict(
    config_dict: dict[str, Any],
    top_level_axes: dict[str, tuple[Any, ...]],
    nested_axes: dict[str, tuple[Any, ...]],
) -> dict[str, Any]:
    """Build a scalar base dict by stripping swept array values.

    Top-level keys that are axes are omitted entirely. Nested sweepable dicts
    have their swept sub-keys removed via :func:`_strip_nested_field`.

    Args:
        config_dict: Raw configuration dict.
        top_level_axes: Top-level fields identified as sweep axes.
        nested_axes: Dotted-path fields identified as nested sweep axes.

    Returns:
        Dict suitable for passing to ``SimulationConfig(**...)`` after a
        combo dict has been merged in.
    """
    # Loop-invariant: look the sweepable field names up once, not per key.
    sweepable = get_nested_sweepable_fields()
    scalar_dict: dict[str, Any] = {}
    for k, v in config_dict.items():
        if k in top_level_axes:
            continue
        if k in sweepable and isinstance(v, dict):
            scalar_dict[k] = _strip_nested_field(k, v, nested_axes)
        else:
            scalar_dict[k] = v
    return scalar_dict


def _build_configs_and_metadata(
    scalar_dict: dict[str, Any],
    axis_keys: list[str],
    all_axes: dict[str, tuple[Any, ...]],
) -> tuple[list[SimulationConfig], ArrayParameterSet]:
    """Build the Cartesian-product configs and accompanying metadata.

    Configs are produced in ``itertools.product`` order over the axes listed
    in *axis_keys* (last axis varies fastest).

    Args:
        scalar_dict: Base scalar config dict (no swept list values).
        axis_keys: Ordered list of axis names (plain or dotted).
        all_axes: Mapping from axis key to its tuple of values.

    Returns:
        Tuple of ``(configs, metadata)`` covering all parameter combinations.
    """
    axis_lists = [all_axes[k] for k in axis_keys]
    combinations = list(product(*axis_lists))
    configs: list[SimulationConfig] = [
        SimulationConfig(**_apply_combo_to_dict(scalar_dict, dict(zip(axis_keys, combo, strict=False))))
        for combo in combinations
    ]
    # The metadata shares the *all_axes* dict object with the caller (no copy).
    metadata = ArrayParameterSet(
        field_names=frozenset(axis_keys),
        array_values=all_axes,
        total_combinations=len(combinations),
    )
    return configs, metadata


def _build_expand_result(
    config_dict: dict[str, Any],
    *,
    allow_arrays: bool,
) -> _ExpandResult:
    """Core expansion logic shared by :func:`expand_config` and :func:`enumerate_configs`.

    Args:
        config_dict: Raw configuration dict.
        allow_arrays: If False, raise ValueError if arrays are found.

    Returns:
        :class:`_ExpandResult` with all expansion artefacts.

    Raises:
        ValueError: If arrays are found but *allow_arrays* is False.
    """
    top_level_axes = _collect_top_level_axes(config_dict, allow_arrays)
    nested_axes = _extract_nested_array_axes(config_dict)
    _validate_nested_axes(nested_axes, allow_arrays)

    # Top-level axes come first, then nested ones; this fixes the sweep order.
    all_axes: dict[str, tuple[Any, ...]] = {**top_level_axes, **nested_axes}

    if not all_axes:
        # Nothing to sweep: the raw dict maps to exactly one config.
        config = SimulationConfig(**config_dict)
        return _ExpandResult(
            configs=[config],
            metadata=None,
            scalar_dict=dict(config_dict),
            axis_keys=[],
            all_axes={},
        )

    scalar_dict = _build_scalar_dict(config_dict, top_level_axes, nested_axes)
    axis_keys = list(all_axes.keys())
    configs, metadata = _build_configs_and_metadata(scalar_dict, axis_keys, all_axes)
    return _ExpandResult(
        configs=configs,
        metadata=metadata,
        scalar_dict=scalar_dict,
        axis_keys=axis_keys,
        all_axes=all_axes,
    )


# ── public API ────────────────────────────────────────────────────────────────
def expand_config(
    config_dict: dict[str, Any],
    *,
    allow_arrays: bool = True,
) -> tuple[list[SimulationConfig], ArrayParameterSet | None]:
    """Expand a config dict with array fields into multiple configs.

    The expansion is a Cartesian product over:

    * Top-level ``ARRAY_ELIGIBLE_FIELDS`` that carry a list value.
    * Sub-keys inside ``NESTED_SWEEPABLE_FIELDS`` dicts (``gravity_force``,
      ``electric_force``, ``wetting_config``, ``hysteresis_config``) that
      carry a list value.

    Args:
        config_dict: Raw configuration dict (typically from TomlAdapter
            before SimulationConfig instantiation).
        allow_arrays: If False, raise ValueError if arrays are found.

    Returns:
        ``(configs, metadata)`` where:

        - *configs*: List of :class:`SimulationConfig` objects, one per
          combination. Without any arrays the list holds a single config.
        - *metadata*: :class:`ArrayParameterSet` if arrays were detected,
          else None.

    Raises:
        ValueError: If arrays are found but *allow_arrays* is False.
        TypeError: Not raised by the expansion code itself; presumably it can
            propagate from :class:`SimulationConfig` construction when values
            are incompatible — TODO confirm.
    """
    expansion = _build_expand_result(config_dict, allow_arrays=allow_arrays)
    configs, metadata = expansion.configs, expansion.metadata
    return configs, metadata
def _iter_combinations( result: _ExpandResult, ) -> Iterator[tuple[int, dict[str, Any], SimulationConfig]]: """Yield ``(index, parameters, config)`` for every axis combination in *result*. Args: result: Fully populated :class:`_ExpandResult` with at least one axis. Yields: ``(index, parameters, config)`` tuples. """ axis_lists = [result.all_axes[k] for k in result.axis_keys] for idx, combo in enumerate(product(*axis_lists)): parameters = dict(zip(result.axis_keys, combo, strict=False)) combo_dict = _apply_combo_to_dict(result.scalar_dict, parameters) yield idx, parameters, SimulationConfig(**combo_dict)
def enumerate_configs(
    config_dict: dict[str, Any],
    *,
    allow_arrays: bool = True,
) -> Iterator[tuple[int, dict[str, Any], SimulationConfig]]:
    """Yield (index, parameters, config) tuples for each expansion.

    *parameters* uses dotted-path keys for nested fields (e.g.
    ``"gravity_force.inclination_angle_deg"``). When the config has no array
    values, a single ``(0, {}, config)`` tuple is produced.

    This is a generator: nothing is expanded (and no error is raised) until
    the first item is requested.

    Args:
        config_dict: Raw configuration dict.
        allow_arrays: If False, raise ValueError if arrays are found.

    Yields:
        ``(index, parameters, config)`` tuples where:

        - *index*: 0-based position in the expansion.
        - *parameters*: Dict of parameter names and values for this combo.
        - *config*: The :class:`SimulationConfig`.
    """
    expansion = _build_expand_result(config_dict, allow_arrays=allow_arrays)
    if expansion.metadata is not None:
        # At least one swept axis: walk every combination.
        yield from _iter_combinations(expansion)
    else:
        # Scalar-only config: the expansion holds exactly one config.
        yield 0, {}, expansion.configs[0]