tud_lbm.pipeline.runner

Functional lax.scan-based simulation runner for TUD-LBM.

Provides run(), a thin wrapper around jax.lax.scan that replaces the Python for-loop of the legacy time stepper.

Two execution modes are supported:

In-memory trajectory (default)

lax.scan returns the full trajectory, and snapshots are selected via save_interval after the scan completes. This is simple but memory-heavy for long runs, because every step is held in device memory until the scan finishes.
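A minimal sketch of the in-memory mode, assuming a toy scalar update in place of a real LBM step. Here `step` is hypothetical, and the offset used to pick every save_interval-th step is an assumption; the runner may align snapshots differently.

```python
import jax
import jax.numpy as jnp

def step(rho, _):
    # Toy update standing in for one LBM time step (hypothetical).
    rho = rho * 0.5 + 1.0
    # (new carry, per-step output); scan stacks the outputs along axis 0.
    return rho, rho

nt, save_interval = 10, 5
final, trajectory = jax.lax.scan(step, jnp.zeros(3), None, length=nt)

# trajectory has shape (nt, 3): every step was kept in device memory.
# Keep steps 5 and 10 (1-based), i.e. indices 4 and 9 (assumed alignment).
snapshots = trajectory[save_interval - 1::save_interval]
```

Slicing after the scan is what makes this mode memory-heavy: the full (nt, ...) array must exist before any snapshot is discarded.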

Streaming I/O (when io_handler is supplied)

Uses jax.debug.callback to write snapshots to disk at the requested save_interval. Device memory stays constant because no trajectory is accumulated. The returned trajectory is None.
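The streaming pattern can be sketched as follows, assuming a toy scalar update and a plain dict standing in for SimulationIO. `host_write`, `written` and `make_runner` are hypothetical names, not TUD-LBM API. jax.lax.cond makes the device-to-host transfer happen only on snapshot steps, and the scan emits no per-step output, so nothing accumulates on the device.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Host-side store standing in for a SimulationIO handler (hypothetical).
written = {}

def host_write(t, rho):
    # Runs on the host: copy device values into plain NumPy and "save" them.
    written[int(t)] = np.asarray(rho)

def make_runner(nt, save_interval):
    def step(carry, _):
        t, rho = carry
        t, rho = t + 1, rho * 0.5 + 1.0  # toy update in place of an LBM step

        def write(operand):
            jax.debug.callback(host_write, *operand)

        # Cross to the host only on snapshot steps.
        jax.lax.cond(t % save_interval == 0, write, lambda operand: None, (t, rho))
        # Return None as the per-step output: no trajectory is built.
        return (t, rho), None

    @jax.jit
    def run(rho0):
        (t, rho), _ = jax.lax.scan(step, (jnp.int32(0), rho0), None, length=nt)
        return t, rho

    return run

run = make_runner(nt=10, save_interval=5)
t_final, rho_final = run(jnp.zeros(3))
jax.effects_barrier()  # wait until all pending host callbacks have run
print(sorted(written))
```

jax.effects_barrier() matters for short scripts: debug callbacks may still be in flight when the jitted function returns.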

Note

Host callbacks are non-differentiable. The streaming I/O path is intended for forward simulations only. If you need gradients through the time loop, use the in-memory trajectory mode.
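To illustrate why the in-memory mode supports gradients, here is a sketch that differentiates a loss over the whole trajectory with jax.grad. The relaxation rule, `simulate` and `loss` are toy assumptions, not the LBM collision operator.

```python
import jax
import jax.numpy as jnp

def simulate(omega, rho0, nt=20):
    def step(rho, _):
        # Toy relaxation toward 1.0 at rate omega (hypothetical physics).
        rho = rho + omega * (1.0 - rho)
        return rho, rho
    # No host callbacks: scan is fully differentiable end to end.
    return jax.lax.scan(step, rho0, None, length=nt)

def loss(omega):
    final, traj = simulate(omega, jnp.zeros(4))
    # Loss over the whole trajectory, which the in-memory mode returns.
    return jnp.sum(traj ** 2)

grad = jax.grad(loss)(0.1)
```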

Usage:

from setup.simulation_setup import build_setup
from tud_lbm.pipeline.runner import run, init_state

# In-memory (short runs / debugging); `config` is a loaded simulation config
setup = build_setup(config)
state = init_state(setup)
final, trajectory = run(setup, state, nt=1000)

# Streaming to disk (production)
from tud_lbm.io import SimulationIO
# Named io_handler to avoid shadowing the standard-library io module
io_handler = SimulationIO(base_dir=config.results_dir,
                          config=config.to_dict(),
                          simulation_name=config.simulation_name)
final, _ = run(setup, state, nt=config.nt,
               save_interval=config.save_interval,
               io_handler=io_handler)

Functions

init_state(setup, *[, f, init_kwargs]) → tud_lbm.pipeline.state.state.State

Create an initial State for the given setup.

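run(setup, initial_state[, nt, save_interval, io_handler, skip_interval, save_fields]) → tuple[tud_lbm.pipeline.state.state.State, tud_lbm.pipeline.state.state.State | None]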

Run nt steps via jax.lax.scan.

Module Contents

tud_lbm.pipeline.runner.init_state(setup: setup.SimulationSetup, *, f: jax.numpy.ndarray | None = None, init_kwargs: dict | None = None) → tud_lbm.pipeline.state.state.State

Create an initial State for the given setup.

If f is not supplied, the population distribution is initialised with setup.initial_f_fn, which is built at setup time from init_type. For "standard" this is the rest equilibrium (f_i = w_i); for multiphase types it produces a tanh density profile at equilibrium. The remaining state arrays are initialised with zeros so that the pytree structure stays constant across lax.scan iterations (JAX requires the carry to keep the same structure, shapes and dtypes on every iteration).
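The two initialisations described above can be sketched for a D2Q9 lattice, assuming the standard D2Q9 weights and a planar interface along x centred at nx/2. TUD-LBM's lattice, interface geometry and exact tanh convention may differ; `rest_equilibrium` and `tanh_profile_equilibrium` are hypothetical names. At rest (u = 0) the equilibrium reduces to f_i = w_i ρ, which gives f_i = w_i for ρ = 1.

```python
import jax.numpy as jnp

# Standard D2Q9 weights: rest, 4 axis-aligned, 4 diagonal directions.
W_D2Q9 = jnp.array([4 / 9] + [1 / 9] * 4 + [1 / 36] * 4)

def rest_equilibrium(nx, ny, nz, w):
    # f_i = w_i (rho = 1, u = 0), laid out as (nx, ny, nz, q, 1).
    q = w.shape[0]
    return jnp.broadcast_to(w[None, None, None, :, None], (nx, ny, nz, q, 1))

def tanh_profile_equilibrium(nx, ny, nz, w, rho_l, rho_v, interface_width):
    # Planar liquid/vapour interface along x at nx / 2 (assumed geometry).
    x = jnp.arange(nx)
    rho = 0.5 * (rho_l + rho_v) + 0.5 * (rho_l - rho_v) * jnp.tanh(
        2.0 * (x - nx / 2) / interface_width
    )
    # At rest the equilibrium is f_i = w_i * rho(x), broadcast over y and z.
    return (
        rho[:, None, None, None, None]
        * w[None, None, None, :, None]
        * jnp.ones((1, ny, nz, 1, 1))
    )
```

Summing either result over the q axis recovers the density, because the weights sum to 1.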

Parameters:
  • setup – SimulationSetup.

  • f – Optional pre-computed initial distribution function, shape (nx, ny, nz, q, 1).

  • init_kwargs – Extra keyword arguments forwarded to the initialisation function (e.g. density, rho_l, rho_v, interface_width, npz_path).

Returns:

A State ready to be passed to run().
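Why zero-initialised fields matter can be seen with a toy state. This is a sketch assuming a NamedTuple stand-in; `ToyState` and its fields are hypothetical and the real State may hold different arrays. Because `force` exists from step 0, the step function can return the same pytree it received, which lax.scan requires. Returning a State with a field missing or reshaped would make scan raise an error.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

class ToyState(NamedTuple):
    # Hypothetical stand-in for State; the real field set may differ.
    f: jnp.ndarray
    rho: jnp.ndarray
    force: jnp.ndarray  # zero-initialised so the structure never changes
    t: jnp.ndarray

def init_toy_state(nx, q=9):
    f = jnp.ones((nx, q)) / q
    return ToyState(f=f, rho=f.sum(-1), force=jnp.zeros((nx,)), t=jnp.int32(0))

def step(state, _):
    # Same fields, shapes and dtypes in and out, as lax.scan requires.
    new = state._replace(rho=state.f.sum(-1), t=state.t + 1)
    return new, None

final, _ = jax.lax.scan(step, init_toy_state(5), None, length=3)
```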

tud_lbm.pipeline.runner.run(setup: setup.SimulationSetup, initial_state: tud_lbm.pipeline.state.state.State, nt: int | None = None, save_interval: int = 1, io_handler: tud_lbm.io.SimulationIO | None = None, skip_interval: int = 0, save_fields: tuple[str, ...] | None = None) → tuple[tud_lbm.pipeline.state.state.State, tud_lbm.pipeline.state.state.State | None]

Run nt steps via jax.lax.scan.

Parameters:
  • setup – SimulationSetup.

  • initial_state – Starting State.

  • nt – Number of time steps. Defaults to setup.config.nt.

  • save_interval – Snapshot frequency in time steps (default: 1, i.e. every step).

  • io_handler – Optional SimulationIO. When supplied, snapshots are streamed to disk via host callbacks and the returned trajectory is None.

  • skip_interval – Absolute simulation-time threshold: steps with state.t <= skip_interval are not saved. For a fresh run this equals the number of initial steps to skip. For a run resumed via init_from_file, set it against the absolute timestep, not the number of new steps. Only used with io_handler.

  • save_fields – Subset of field names to write, e.g. ("rho", "u"). None means all fields. Only used with io_handler.
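The interplay of save_interval, skip_interval and save_fields can be sketched in plain Python. This assumes a snapshot is written when t is a multiple of save_interval and strictly greater than skip_interval; `should_save` and `select_fields` are hypothetical helpers, and the runner's exact predicate may differ. The resumed case shows why skip_interval is absolute: skipping 20 new steps after restarting at t = 500 means passing 520, not 20.

```python
def should_save(t, save_interval=1, skip_interval=0):
    # Hypothetical predicate: past the absolute threshold and on the interval.
    return t > skip_interval and t % save_interval == 0

def select_fields(fields, save_fields=None):
    # None means "write all fields"; otherwise keep only the requested names.
    if save_fields is None:
        return dict(fields)
    return {name: value for name, value in fields.items() if name in save_fields}

# Fresh run: t = 1..100, skip the first 20 steps, save every 10.
fresh = [t for t in range(1, 101) if should_save(t, 10, 20)]

# Resumed run continuing from t = 500 for 100 steps. skip_interval is absolute,
# so skipping the first 20 new steps means skip_interval = 520.
resumed = [t for t in range(501, 601) if should_save(t, 10, 520)]
```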

Returns:

(final_state, trajectory)

  • Without io_handler: trajectory is a State pytree in which every leaf has a leading time axis of length nt (or nt / save_interval when save_interval > 1).

  • With io_handler: trajectory is None; data has been written to io_handler.data_dir.
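Because the in-memory trajectory is a pytree whose leaves all share a leading time axis, a single snapshot can be pulled out with jax.tree_util.tree_map. This sketch uses a hypothetical two-field `ToyState` rather than the real State.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

class ToyState(NamedTuple):
    # Hypothetical stand-in for State.
    rho: jnp.ndarray
    t: jnp.ndarray

def step(s, _):
    s = ToyState(rho=s.rho + 1.0, t=s.t + 1)
    return s, s  # scan stacks each per-step State along a new leading axis

init = ToyState(rho=jnp.zeros((4, 4)), t=jnp.int32(0))
final, traj = jax.lax.scan(step, init, None, length=6)

# Every leaf carries the time axis first: traj.rho has shape (6, 4, 4).
# Index all leaves at once to recover the State after step 3.
snapshot_3 = jax.tree_util.tree_map(lambda leaf: leaf[2], traj)
```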