# Source code for tud_lbm.pipeline.runner

"""Functional ``lax.scan``-based simulation runner for TUD-LBM.

Provides :func:`run` — a thin wrapper around ``jax.lax.scan`` that
replaces the Python ``for``-loop in the legacy time stepper.

Two execution modes are supported:

**In-memory trajectory** (default)
    ``lax.scan`` returns the full trajectory; snapshots are selected
    after the scan completes via *save_interval*.  Simple, but the full
    trajectory is held in device memory, so it is memory-heavy for long runs.

**Streaming I/O** (when *io_handler* is supplied)
    Uses ``jax.debug.callback`` to write snapshots to disk at the
    requested *save_interval*.  Device memory stays constant because
    no trajectory is accumulated.  The returned *trajectory* is
    ``None``.

.. note::

   Host callbacks are non-differentiable.  The streaming I/O path
   is intended for forward simulations only.  If you need gradients
   through the time loop, use the in-memory trajectory mode.

Usage::

    from setup.simulation_setup import build_setup
    from tud_lbm.pipeline.runner import run, init_state

    # In-memory (short runs / debugging)
    setup = build_setup(config)
    state = init_state(setup)
    final, trajectory = run(setup, state, nt=1000)

    # Streaming to disk (production)
    from tud_lbm.io import SimulationIO
    io = SimulationIO(base_dir=config.results_dir,
                      config=config.to_dict(),
                      simulation_name=config.simulation_name)
    final, _ = run(setup, state, nt=config.nt,
                   save_interval=config.save_interval,
                   io_handler=io)
"""

from __future__ import annotations
from pathlib import Path
from typing import TYPE_CHECKING
import jax
import jax.numpy as jnp
from tud_lbm.pipeline.state.state import State

if TYPE_CHECKING:
    from setup import SimulationSetup
    from tud_lbm.config import SimulationConfig
    from tud_lbm.io import SimulationIO


# ── State initialisation ─────────────────────────────────────────────


def _t_from_snapshot(config: SimulationConfig) -> jnp.ndarray:
    """Infer the starting timestep from an ``init_from_file`` snapshot name.

    Expected filename format is ``timestep_{N}.npz``. Any non-conforming
    name, parse failure, or non-file-based init type falls back to ``0``.
    """
    if config.init_type != "init_from_file" or config.init_dir is None:
        return jnp.array(0)

    stem = Path(config.init_dir).stem
    if not stem.startswith("timestep_"):
        return jnp.array(0)

    step_str = stem.removeprefix("timestep_")
    if not step_str.isdigit():
        return jnp.array(0)

    return jnp.array(int(step_str))


def init_state(
    setup: SimulationSetup,
    *,
    f: jnp.ndarray | None = None,
    init_kwargs: dict | None = None,
) -> State:
    """Build the starting :class:`State` for *setup*.

    When *f* is omitted, the populations are produced by
    ``setup.initial_f_fn`` (constructed at setup time from ``init_type``):
    the rest equilibrium ``f_i = w_i`` for ``"standard"``, or a tanh
    density profile at equilibrium for the multiphase variants.

    The remaining state arrays are zero-initialised so that the pytree
    layout is fixed across ``lax.scan`` iterations -- JAX requires the
    scan carry to keep an identical structure.

    Args:
        setup: :class:`~setup.simulation_setup.SimulationSetup`.
        f: Optional pre-computed initial distribution function, shape
            ``(nx, ny, nz, q, 1)``.
        init_kwargs: Extra keyword arguments forwarded to the
            initialisation function (e.g. ``density``, ``rho_l``,
            ``rho_v``, ``interface_width``, ``npz_path``).

    Returns:
        A :class:`State` ready to be passed to :func:`run`.
    """
    from tud_lbm.pipeline.state import build_extra_state
    from tud_lbm.pipeline.state import build_optional_fields

    lattice = setup.lattice
    nx, ny, nz = setup.grid_shape[:3]

    populations = setup.initial_f_fn(init_kwargs) if f is None else f
    # Density is the sum of the populations over the velocity axis.
    density = jnp.sum(populations, axis=-2, keepdims=True)
    velocity = jnp.zeros((nx, ny, nz, 1, lattice.d))
    start_t = _t_from_snapshot(setup.config)

    force, force_ext = build_optional_fields(setup, nx, ny, nz, lattice.d)
    extra_state = build_extra_state(setup)

    return State(
        f=populations,
        rho=density,
        u=velocity,
        t=start_t,
        force=force,
        force_ext=force_ext,
        **extra_state,
    )
# ── Functional run ───────────────────────────────────────────────────
def run(
    setup: SimulationSetup,
    initial_state: State,
    nt: int | None = None,
    save_interval: int = 1,
    io_handler: SimulationIO | None = None,
    skip_interval: int = 0,
    save_fields: tuple[str, ...] | None = None,
) -> tuple[State, State | None]:
    """Run *nt* steps via ``jax.lax.scan``.

    Args:
        setup: :class:`~setup.simulation_setup.SimulationSetup`.
        initial_state: Starting
            :class:`~tud_lbm.pipeline.state.state.State`.
        nt: Number of time steps.  Defaults to ``setup.config.nt``.
        save_interval: Snapshot frequency (default: 1 = every step).
        io_handler: Optional :class:`~tud_lbm.io.SimulationIO`.  When
            supplied, snapshots are streamed to disk via host callbacks and
            the returned *trajectory* is ``None``.
        skip_interval: Absolute simulation-time threshold; steps with
            ``state.t <= skip_interval`` are not saved.  For a fresh run
            this equals the number of initial steps to skip.  For a resumed
            ``init_from_file`` run, compare against the *absolute*
            timestep, not the number of new steps (only used with
            *io_handler*).
        save_fields: Subset of field names to write, e.g. ``("rho", "u")``.
            ``None`` means all fields.  Only used with *io_handler*.

    Returns:
        ``(final_state, trajectory)``

        * **No io_handler** -- *trajectory* is a :class:`State` pytree with
          a leading axis of length *nt*, or ``ceil(nt / save_interval)``
          when ``save_interval > 1``.  Entry ``k`` is the state after step
          ``k * save_interval + 1`` (the first step of each interval).
        * **With io_handler** -- *trajectory* is ``None``; data has been
          written to ``io_handler.data_dir``.  After the scan the final
          state is additionally written via ``io_handler.save_data_step``.
    """
    if nt is None:
        nt = setup.config.nt

    # ── Streaming I/O mode ───────────────────────────────────────
    if io_handler is not None:
        from tud_lbm.pipeline.io_callbacks import _state_to_numpy
        from tud_lbm.pipeline.io_callbacks import make_save_callback

        do_save = make_save_callback(
            io_handler,
            save_interval=save_interval,
            skip_interval=skip_interval,
            save_fields=save_fields,
        )

        # No ``jax.jit`` needed: ``lax.scan`` traces and compiles the body.
        def scan_body_io(state: State, _t: int) -> tuple[State, None]:
            new_state = setup.step_fn(setup, state)
            do_save(new_state, new_state.t)
            return new_state, None

        final_state, _ = jax.lax.scan(
            scan_body_io,
            initial_state,
            jnp.arange(nt),
        )

        # Host callbacks can still be pending once the device result is
        # available; wait for them so every streamed snapshot has been
        # handled before the final write and before returning.
        jax.effects_barrier()

        final_t = int(final_state.t)
        io_handler.save_data_step(
            final_t,
            _state_to_numpy(final_state, fields=save_fields, t=final_t),
        )
        return final_state, None

    # ── In-memory trajectory mode ────────────────────────────────
    def scan_body(state: State, _t: int) -> tuple[State, State]:
        new_state = setup.step_fn(setup, state)
        return new_state, new_state

    final_state, trajectory = jax.lax.scan(
        scan_body,
        initial_state,
        jnp.arange(nt),
    )

    if save_interval > 1:
        # Static strided slice: selects the same entries as indexing with
        # ``arange(0, nt, save_interval)`` but without a gather.
        trajectory = jax.tree.map(lambda x: x[::save_interval], trajectory)

    return final_state, trajectory