"""Dynamic simulation state for TUD-LBM.

Provides :class:`State` and :class:`WettingState` — the pytree-compatible
carry objects used inside ``jax.lax.scan``.

Also provides the state helper functions :func:`build_optional_fields`,
:func:`build_extra_state`, and :func:`update_extra_state`, which orchestrate
state composition.

Public API::

    from tud_lbm.pipeline.state import (
        State,
        WettingState,
        build_extra_state,
        build_optional_fields,
        update_extra_state,
    )
"""
from typing import Any
import jax.numpy as jnp
from tud_lbm.pipeline.setup import SimulationSetup
from tud_lbm.pipeline.state._extra_state import _build_extra_state
from tud_lbm.pipeline.state._extra_state import _update_extra_state
from tud_lbm.pipeline.state._optional_fields import _build_optional_fields
from tud_lbm.pipeline.state.state import State
from tud_lbm.pipeline.state.state import WettingState
# Names exported as this package's public API (also what ``from ... import *``
# exposes). ``build_extra_state`` and ``update_extra_state`` are not defined
# above this point; they are presumably public wrappers defined later in this
# module around ``_build_extra_state`` / ``_update_extra_state`` — confirm.
__all__ = [
    "State",
    "WettingState",
    "build_extra_state",
    "build_optional_fields",
    "update_extra_state",
]
def build_optional_fields(
    setup: SimulationSetup, nx: int, ny: int, nz: int, d: int
) -> tuple[jnp.ndarray | None, jnp.ndarray | None]:
    """Build the optional ``force`` and ``force_ext`` fields.

    Returns zero-filled arrays for fields that will be written by the
    step function, and ``None`` for fields that remain unused.

    JAX's ``lax.scan`` requires the pytree structure of its carry to be
    constant across iterations, so any field that transitions from
    ``None`` to an array must start as zeros.

    Args:
        setup: :class:`~tud_lbm.pipeline.setup.SimulationSetup` describing
            the run.
        nx: Grid size in x.
        ny: Grid size in y.
        nz: Grid size in z.
        d: Lattice dimension (e.g. 2 for D2Q9).

    Returns:
        A tuple ``(force, force_ext)``:

        * ``force``: Zeros ``(nx, ny, nz, 1, d)`` for multiphase runs,
          ``None`` otherwise.
        * ``force_ext``: Zeros ``(nx, ny, nz, 1, d)`` when forces are active,
          ``None`` otherwise.

    See Also:
        :func:`tud_lbm.pipeline.state._optional_fields._build_optional_fields`
    """
    # Public entry point; the actual construction is delegated unchanged to
    # the private implementation in ``_optional_fields``.
    return _build_optional_fields(setup, nx, ny, nz, d)