# Source code for tud_lbm.pipeline.setup

"""JAX-friendly SimulationSetup and ``build_setup`` factory.

:class:`SimulationSetup` is an immutable :class:`typing.NamedTuple`
(i.e. a valid JAX pytree) that holds everything needed by the jitted
step function.  It is built from a validated
:class:`~tud_lbm.config.simulation_config.SimulationConfig` via
:func:`build_setup`.

Design rationale
~~~~~~~~~~~~~~~~
* **Closed over, not ``static_argnums``:**  ``SimulationSetup`` is
  captured by ``functools.partial`` (or a closure) when constructing
  the scan body.  JAX treats it as a compile-time constant and caches
  the compilation.  If the setup changes, a new compilation is
  triggered — the correct behaviour.
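* **No stateful operator objects:**  Apart from the ``config``
  reference, fields are Python scalars, tuples (including the
  extra-state plugin tuple), ``jax.Array`` values, pre-built containers
  (boundary masks, force setup, multiphase parameters) and plain
  callables (pre-built operator closures).  Operators are stored as
  resolved functions rather than as operator *objects*.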

Usage::

    from tud_lbm.config.simulation_config import SimulationConfig
    from tud_lbm.pipeline.setup import build_setup

    cfg = SimulationConfig(grid_shape=(64, 64), tau=0.8)
    setup = build_setup(cfg)
"""

from __future__ import annotations
from typing import TYPE_CHECKING
from typing import NamedTuple
from typing import cast
from tud_lbm.lattice.lattice import Lattice
from tud_lbm.lattice.lattice import build_lattice
from tud_lbm.operators.differential import build_diff_ops
from tud_lbm.operators.force import ForceSetup
from tud_lbm.operators.force import build_forces

if TYPE_CHECKING:
    import jax.numpy as jnp
    from tud_lbm.config.simulation_config import SimulationConfig
    from tud_lbm.operators.boundary import BCMasks
    from tud_lbm.operators.macroscopic import MultiphaseParams
    from tud_lbm.operators.protocols import BoundaryOperator
    from tud_lbm.operators.protocols import CollisionOperator
    from tud_lbm.operators.protocols import DifferentialOperator
    from tud_lbm.operators.protocols import EquilibriumOperator
    from tud_lbm.operators.protocols import ExtraStatePlugin
    from tud_lbm.operators.protocols import HysteresisOperator
    from tud_lbm.operators.protocols import InitialPopulationOperator
    from tud_lbm.operators.protocols import MacroscopicOperator
    from tud_lbm.operators.protocols import StepOperator
    from tud_lbm.operators.protocols import StreamingOperator


class SimulationSetup(NamedTuple):
    """Immutable operator container — only what the jitted step needs.

    ``SimulationSetup`` stores built artifacts (operators, masks, closures)
    and physics scalars that the step function reads at JIT time.  IO and
    initialisation metadata live on the original
    :class:`~tud_lbm.config.simulation_config.SimulationConfig`, accessible
    via the :attr:`config` reference.

    Only ``config``, ``lattice``, ``grid_shape``, ``tau`` and
    ``collision_scheme`` are required; every other field defaults to
    ``None`` (``()`` for ``extra_state_plugins``).

    Attributes:
        config: The original
            :class:`~tud_lbm.config.simulation_config.SimulationConfig` —
            for IO, init, and any metadata that does not enter the JIT
            boundary.
        lattice: The :class:`~tud_lbm.lattice.lattice.Lattice` pytree.
        grid_shape: Spatial dimensions, e.g. ``(64, 64)``.
        tau: Relaxation time (> 0.5).
        collision_scheme: Name of the collision model (``"bgk"`` /
            ``"mrt"``).
        k_diag: MRT relaxation rates (``None`` for BGK).
        bc_masks: Pre-computed boundary-condition masks
            (:class:`~tud_lbm.operators.boundary.BCMasks`).
        forces: Pre-built force setup
            (:class:`~tud_lbm.operators.force.ForceSetup`) containing specs
            and the source-term callable, or ``None`` if no forces are
            active.
        multiphase_params: Multiphase parameters
            (:class:`~tud_lbm.operators.macroscopic.MultiphaseParams`);
            ``None`` for single-phase runs.
        gradient_standard: Standard gradient ``∇μ`` (chemical potential).
            Always used for the chemical-potential gradient; never
            wetting-corrected.
        gradient_density: Density gradient ``∇ρ`` used in the source term.
            Wetting-corrected when applicable.
        laplacian_density: Laplacian of density ``∇²ρ`` used in the
            chemical-potential computation.  Wetting-corrected when
            applicable.
        gradient_density_wetting: Parametric density gradient for
            hysteresis optimisation, populated for hysteresis runs.
            Signature: ``(grid, phi_l, phi_r, d_rho_l, d_rho_r) -> result``.
            ``None`` for non-hysteresis cases.
        laplacian_density_wetting: Parametric Laplacian of density for
            hysteresis optimisation, populated for hysteresis runs.
            Signature: ``(grid, phi_l, phi_r, d_rho_l, d_rho_r) -> result``.
            ``None`` for non-hysteresis cases.
        step_fn: The unbound step operator resolved from the registry,
            implementing
            :class:`~tud_lbm.operators.protocols.StepOperator`.
            Signature: ``(setup, state) -> state_next``.
        wetting_fn: The hysteresis operator for updating wetting state,
            implementing
            :class:`~tud_lbm.operators.protocols.HysteresisOperator`.
            Built when ``hysteresis_config`` is present; ``None`` otherwise.
        extra_state_plugins: Active plugin tuple used to initialise and
            update operation-specific extra state (e.g. electric potential,
            wetting state).
        collision_fn: Pre-built collision operator, resolved at setup time.
        equilibrium_fn: Pre-built equilibrium operator, resolved at setup
            time.
        macroscopic_fn: Pre-built macroscopic operator, resolved at setup
            time.
        streaming_fn: Pre-built streaming operator, resolved at setup time.
        bc_fn: Pre-built boundary-condition operator, resolved at setup time.
        initial_f_fn: Pre-built initial population function, created at
            setup time.
    """

    # ── Core references ──
    config: SimulationConfig
    lattice: Lattice

    # ── Physics scalars (read at runtime inside JIT) ──
    grid_shape: tuple[int, ...]
    tau: float
    collision_scheme: str
    k_diag: tuple[float, ...] | None = None

    # ── Pre-built operators ──
    bc_masks: BCMasks | None = None
    forces: ForceSetup | None = None
    multiphase_params: MultiphaseParams | None = None

    # ── Differential operator closures (pre-built) ──
    gradient_standard: DifferentialOperator | None = None
    gradient_density: DifferentialOperator | None = None
    laplacian_density: DifferentialOperator | None = None
    gradient_density_wetting: DifferentialOperator | None = None
    laplacian_density_wetting: DifferentialOperator | None = None

    # ── Step function (unbound: (setup, state) -> State) ──
    step_fn: StepOperator | None = None
    wetting_fn: HysteresisOperator | None = None
    # Immutable default: an empty tuple, never a shared mutable container.
    extra_state_plugins: tuple[ExtraStatePlugin, ...] = ()

    # ── Pre-built operator closures (resolved at setup time) ──
    collision_fn: CollisionOperator | None = None
    equilibrium_fn: EquilibriumOperator | None = None
    macroscopic_fn: MacroscopicOperator[..., tuple[jnp.ndarray, ...]] | None = None
    streaming_fn: StreamingOperator | None = None
    bc_fn: BoundaryOperator | None = None
    initial_f_fn: InitialPopulationOperator[..., jnp.ndarray] | None = None
# ── Main factory ─────────────────────────────────────────────────────
def build_setup(config: SimulationConfig) -> SimulationSetup:
    """Construct a JAX-friendly :class:`SimulationSetup` from a config.

    Every operator is resolved once here, so the jitted step function only
    reads pre-built callables and scalars from the returned tuple.

    Args:
        config: A validated :class:`SimulationConfig`.

    Returns:
        An immutable :class:`SimulationSetup` NamedTuple ready for the
        jitted step function.

    Raises:
        ValueError: If wetting configuration is present but ``sim_type`` is
            not a multiphase type (i.e. does not contain ``"multiphase"``).
    """
    # ── Validation: wetting config requires multiphase sim_type ──
    # Performed before the deferred imports so an invalid config fails fast.
    if config.wetting_config is not None and "multiphase" not in config.sim_type:
        msg = (
            f"Wetting configuration present but sim_type is '{config.sim_type}'. "
            "Wetting requires sim_type = 'multiphase'. "
            "Wetting is an addon to multiphase simulations, detected by the presence of [wetting] config."
        )
        raise ValueError(
            msg,
        )

    # Import here to avoid circular import issues at module level
    from tud_lbm.operators.boundary import build_bc
    from tud_lbm.operators.boundary import build_bc_masks
    from tud_lbm.operators.collision import build_collision_fn
    from tud_lbm.operators.equilibrium import build_equilibrium_fn
    from tud_lbm.operators.initialise import build_initialise_fn
    from tud_lbm.operators.macroscopic import build_macroscopic_fn
    from tud_lbm.operators.macroscopic import build_multiphase_params
    from tud_lbm.operators.step import build_step_fn
    from tud_lbm.operators.streaming import build_streaming_fn
    from tud_lbm.operators.wetting import build_wetting_fn
    from tud_lbm.registry import get_operators

    # Normalised once and shared by every builder below, the stored setup
    # field, and the initialiser closure, so all consumers see the same
    # hashable tuple.
    grid_shape = tuple(config.grid_shape)
    # Any sim_type containing "multiphase" (e.g. the hysteresis variants)
    # enables multiphase parameters and the EOS-aware macroscopic operator.
    is_multiphase = "multiphase" in config.sim_type

    lattice = build_lattice(config.lattice_type)
    bc_masks = build_bc_masks(grid_shape)

    # Build multiphase params if applicable (multiphase runs with optional wetting)
    mp_params = build_multiphase_params(config) if is_multiphase else None

    # Build force specs
    forces = build_forces(config, grid_shape, lattice)

    # Build differential operators
    # (5-tuple: standard, density, laplacian, raw_density, raw_laplacian)
    (
        gradient_standard,
        gradient_density,
        laplacian_density,
        gradient_density_wetting,
        laplacian_density_wetting,
    ) = build_diff_ops(config, mp_params, lattice)

    # Resolve step operator from registry
    step_fn = build_step_fn(config.sim_type)

    # Build operator closures (pre-resolved at setup time)
    collision_fn = build_collision_fn(config.collision_scheme)
    equilibrium_fn = build_equilibrium_fn("wb")
    streaming_fn = build_streaming_fn("standard")
    if is_multiphase:
        macroscopic_fn = build_macroscopic_fn(mp_params.eos)  # EOS-aware for multiphase
    else:
        macroscopic_fn = build_macroscopic_fn("standard")  # single-phase
    bc_fn = build_bc(config.bc_config, lattice)

    # Build wetting function for hysteresis-capable runs.
    wetting_fn = None
    if config.hysteresis_config is not None:
        wetting_scheme = (
            "chemical_step_hysteresis"
            if config.sim_type == "multiphase_hysteresis_chemical_step"
            else "hysteresis"
        )
        wetting_fn = build_wetting_fn(wetting_scheme)

    # Registered extra-state plugins in deterministic (name-sorted) order,
    # keeping only those that report themselves active for this config.
    registered_plugins = (
        cast("ExtraStatePlugin", entry.target)
        for entry in sorted(get_operators("extra_state").values(), key=lambda e: e.name)
    )
    extra_state_plugins = tuple(plugin for plugin in registered_plugins if plugin.is_active(config))

    def _initial_f_fn(init_kwargs: dict | None = None) -> jnp.ndarray:
        """Build the initial populations for this setup.

        Keyword precedence, lowest to highest: multiphase densities and
        interface width, ``config.initialisation``, then ``init_kwargs``.
        For ``init_from_file`` runs, ``npz_path`` falls back to
        ``config.init_dir`` when no other source supplied it.
        """
        kw: dict = {}
        if mp_params is not None:
            kw.update(rho_l=mp_params.rho_l, rho_v=mp_params.rho_v, interface_width=mp_params.interface_width)
        # Guard against an unset initialisation section (None).
        kw.update(config.initialisation or {})
        if init_kwargs:
            kw.update(init_kwargs)
        if config.init_type == "init_from_file" and "npz_path" not in kw and config.init_dir is not None:
            kw["npz_path"] = config.init_dir
        # Pass the same normalised tuple stored on the setup, not the raw
        # config value (which may be a list).
        return build_initialise_fn(config.init_type)(grid_shape, lattice, **kw)

    return SimulationSetup(
        config=config,
        lattice=lattice,
        grid_shape=grid_shape,
        tau=config.tau,
        collision_scheme=config.collision_scheme,
        k_diag=config.k_diag,
        bc_masks=bc_masks,
        forces=forces,
        multiphase_params=mp_params,
        gradient_standard=gradient_standard,
        gradient_density=gradient_density,
        laplacian_density=laplacian_density,
        gradient_density_wetting=gradient_density_wetting,
        laplacian_density_wetting=laplacian_density_wetting,
        step_fn=step_fn,
        wetting_fn=wetting_fn,
        extra_state_plugins=extra_state_plugins,
        collision_fn=collision_fn,
        equilibrium_fn=equilibrium_fn,
        macroscopic_fn=macroscopic_fn,
        streaming_fn=streaming_fn,
        bc_fn=bc_fn,
        initial_f_fn=_initial_f_fn,
    )