tud_lbm.pipeline.runner
Functional lax.scan-based simulation runner for TUD-LBM.
Provides run() — a thin wrapper around jax.lax.scan that
replaces the Python for-loop in the legacy time stepper.
Two execution modes are supported:
- In-memory trajectory (default): lax.scan returns the full trajectory; snapshots are selected via save_interval after the scan completes. Simple, but memory-heavy for long runs.
- Streaming I/O (when io_handler is supplied): uses jax.debug.callback to write snapshots to disk at the requested save_interval. Device memory stays constant because no trajectory is accumulated. The returned trajectory is None.
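The two modes can be contrasted with a toy scalar update standing in for the LBM step. This is a minimal sketch under stated assumptions, not the runner's code: step is a hypothetical placeholder for the collide-and-stream update, the in-memory selection uses traj[save_interval - 1::save_interval] (one plausible convention that yields nt / save_interval snapshots; the runner's exact indexing may differ), and a Python list stands in for the disk writer.

```python
import jax
import jax.numpy as jnp
import numpy as np


def step(f):
    # Hypothetical stand-in for one LBM time step: relax f halfway toward 1.0.
    return f + 0.5 * (1.0 - f)


def run_in_memory(f0, nt, save_interval=1):
    def body(f, _):
        f_new = step(f)
        return f_new, f_new  # (carry, per-step output that lax.scan stacks)

    final, traj = jax.lax.scan(body, f0, xs=None, length=nt)
    # Snapshot selection after the scan: keep steps save_interval, 2*save_interval, ...
    return final, traj[save_interval - 1::save_interval]


def run_streaming(f0, nt, save_interval, sink):
    def write(t, f):
        # Host-side writer; a real handler would write f to disk here.
        sink.append((int(t), np.asarray(f)))

    def body(carry, _):
        t, f = carry
        f, t = step(f), t + 1
        # Only save steps trigger a device-to-host transfer.
        jax.lax.cond(t % save_interval == 0,
                     lambda: jax.debug.callback(write, t, f),
                     lambda: None)
        return (t, f), None  # no per-step output, so nothing accumulates on device

    (_, final), _ = jax.lax.scan(body, (jnp.int32(0), f0), xs=None, length=nt)
    jax.effects_barrier()  # wait for pending host callbacks to finish
    return final, None
```

In the streaming variant the scan emits no per-step output, which is why its memory footprint does not grow with nt.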
Note
Host callbacks are non-differentiable. The streaming I/O path is intended for forward simulations only. If you need gradients through the time loop, use the in-memory trajectory mode.
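Gradients flow through the in-memory path because jax.lax.scan is differentiable end to end. A minimal sketch, assuming a hypothetical scalar relaxation parameter omega and a toy update in place of the real collision operator; the loss is taken over the whole returned trajectory, so every step contributes to the gradient.

```python
import jax
import jax.numpy as jnp


def trajectory_loss(omega, f0, nt=5):
    def body(f, _):
        # Toy relaxation toward 1.0 with hypothetical rate omega.
        f_new = f + omega * (1.0 - f)
        return f_new, f_new

    _, traj = jax.lax.scan(body, f0, xs=None, length=nt)
    # Loss over the stacked in-memory trajectory (leading axis = time).
    return jnp.sum(traj ** 2)


grad_omega = jax.grad(trajectory_loss)(0.5, jnp.zeros(4))
```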
Usage:

    from setup.simulation_setup import build_setup
    from tud_lbm.pipeline.runner import run, init_state

    # In-memory (short runs / debugging); config is the simulation configuration object
    setup = build_setup(config)
    state = init_state(setup)
    final, trajectory = run(setup, state, nt=1000)

    # Streaming to disk (production)
    from tud_lbm.io import SimulationIO

    sim_io = SimulationIO(base_dir=config.results_dir,
                          config=config.to_dict(),
                          simulation_name=config.simulation_name)
    final, _ = run(setup, state, nt=config.nt,
                   save_interval=config.save_interval,
                   io_handler=sim_io)
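Functions
- init_state(setup, *[, f, init_kwargs]): Create an initial State for the given setup.
- run(setup, initial_state[, nt, save_interval, io_handler, skip_interval, save_fields]): Run nt steps via jax.lax.scan.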
Module Contents
- tud_lbm.pipeline.runner.init_state(setup: setup.SimulationSetup, *, f: jax.numpy.ndarray | None = None, init_kwargs: dict | None = None) -> tud_lbm.pipeline.state.state.State
  Create an initial State for the given setup.
  If f is not supplied, the population distribution is initialised using setup.initial_f_fn, which is built at setup time from init_type. For "standard" this is the rest equilibrium (f_i = w_i); for multiphase types it produces a tanh density profile at equilibrium. The remaining state arrays are initialised with zeros so that the pytree structure stays constant across lax.scan iterations (JAX requires an identical carry structure).
  - Parameters:
    - setup – SimulationSetup.
    - f – Optional pre-computed initial distribution function, shape (nx, ny, nz, q, 1).
    - init_kwargs – Extra keyword arguments forwarded to the initialisation function (e.g. density, rho_l, rho_v, interface_width, npz_path).
  - Returns:
    A State ready to be passed to run().
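The rest-equilibrium initialisation and the zero-filled placeholders can be sketched as follows. This assumes a D2Q9 lattice (so nz = 1 and q = 9) and a hypothetical three-field ToyState whose rho and u shapes are illustrative only; the real State and lattice come from SimulationSetup. The scan body rewrites rho each step, which only works because rho already exists as a zero array of the same shape and dtype: if it started as None, lax.scan would reject the mismatched carry structure.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

# D2Q9 weights (assumption: the actual lattice is configured in SimulationSetup).
W_D2Q9 = jnp.array([4 / 9] + [1 / 9] * 4 + [1 / 36] * 4)


class ToyState(NamedTuple):
    # Hypothetical subset of State; field shapes are illustrative.
    f: jnp.ndarray    # (nx, ny, nz, q, 1)
    rho: jnp.ndarray  # (nx, ny, nz, 1, 1)
    u: jnp.ndarray    # (nx, ny, nz, 1, d)


def toy_init_state(nx, ny, nz=1, w=W_D2Q9, d=2):
    q = w.shape[0]
    # Rest equilibrium: f_i = w_i at every node (rho = 1, u = 0).
    f = jnp.broadcast_to(w.reshape(1, 1, 1, q, 1), (nx, ny, nz, q, 1))
    # Zeros rather than None, so the scan carry has a fixed structure.
    rho = jnp.zeros((nx, ny, nz, 1, 1))
    u = jnp.zeros((nx, ny, nz, 1, d))
    return ToyState(f=f, rho=rho, u=u)


def body(state, _):
    # Recompute density as the zeroth moment; shapes and dtypes are unchanged.
    rho = jnp.sum(state.f, axis=3, keepdims=True)
    return state._replace(rho=rho), None


state0 = toy_init_state(4, 3)
final, _ = jax.lax.scan(body, state0, xs=None, length=2)
```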
- tud_lbm.pipeline.runner.run(setup: setup.SimulationSetup, initial_state: tud_lbm.pipeline.state.state.State, nt: int | None = None, save_interval: int = 1, io_handler: tud_lbm.io.SimulationIO | None = None, skip_interval: int = 0, save_fields: tuple[str, ...] | None = None) -> tuple[tud_lbm.pipeline.state.state.State, tud_lbm.pipeline.state.state.State | None]
  Run nt steps via jax.lax.scan.
  - Parameters:
    - setup – SimulationSetup.
    - initial_state – Starting State.
    - nt – Number of time steps. Defaults to setup.config.nt.
    - save_interval – Snapshot frequency (default: 1 = every step).
    - io_handler – Optional SimulationIO. When supplied, snapshots are streamed to disk via host callbacks and the returned trajectory is None.
    - skip_interval – Number of initial steps to skip before saving (only used with io_handler).
    - save_fields – Subset of field names to write, e.g. ("rho", "u"). None means all fields. Only used with io_handler.
  - Returns:
    (final_state, trajectory)
    - Without io_handler: trajectory is a State pytree with a leading axis of length nt (nt / save_interval when save_interval > 1).
    - With io_handler: trajectory is None; the data has been written to io_handler.data_dir.
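How skip_interval, save_interval and save_fields might combine in the streaming path is sketched below. The save rule used here (save when t > skip_interval and t is a multiple of save_interval) is an assumption, as are SnapshotState and the list-based writer; the real SimulationIO API is not reproduced. Because save_fields is a static Python tuple, the field subset is chosen at trace time, so unselected arrays never leave the device.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class SnapshotState(NamedTuple):
    # Hypothetical stand-in for State.
    rho: jnp.ndarray
    u: jnp.ndarray
    f: jnp.ndarray


def select_fields(state, save_fields):
    # Resolved at trace time: only the requested arrays are sent to the host.
    fields = state._asdict()
    names = fields.keys() if save_fields is None else save_fields
    return {name: fields[name] for name in names}


def stream(state0, nt, save_interval, skip_interval, save_fields, written):
    def write(t, payload):
        # Stand-in for io_handler: record the step and the field names received.
        written.append((int(t), sorted(payload)))

    def body(carry, _):
        t, state = carry
        state = state._replace(rho=state.rho + 1.0)  # hypothetical update
        t = t + 1
        # Assumed rule: past the skip window and on a save_interval boundary.
        do_save = (t > skip_interval) & (t % save_interval == 0)
        payload = select_fields(state, save_fields)
        jax.lax.cond(do_save,
                     lambda: jax.debug.callback(write, t, payload),
                     lambda: None)
        return (t, state), None

    (_, final), _ = jax.lax.scan(body, (jnp.int32(0), state0), xs=None, length=nt)
    jax.effects_barrier()  # wait for pending host callbacks before returning
    return final, None
```

With nt=10, save_interval=2 and skip_interval=4, this sketch writes steps 6, 8 and 10.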