tud_lbm.pipeline.io_callbacks
Host-callback-based I/O for TUD-LBM.
Provides utilities for saving simulation snapshots during a
jax.lax.scan time loop without breaking the JIT trace.
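Here is why a dedicated mechanism is needed. Inside `lax.scan` the step index `t` is an abstract tracer, so ordinary Python control flow on it cannot be evaluated at trace time. The toy sketch below uses a hypothetical scan body, not TUD-LBM code. It shows that the naive approach of branching on `t` in Python fails with a concretization error before any step runs.

```python
import jax
import jax.numpy as jnp


def naive_scan_body(state, t):
    # Hypothetical toy update standing in for a real LBM step.
    new_state = state + 1.0
    # Python `if` on a traced value: JAX cannot decide this branch while
    # tracing, so the whole scan fails before anything is written to disk.
    if t % 100 == 0:
        print("would save here")
    return new_state, None


def try_naive_loop():
    """Return the name of the tracing error, or "no error" if none occurs."""
    try:
        jax.lax.scan(naive_scan_body, jnp.zeros(4), jnp.arange(300))
    except jax.errors.ConcretizationTypeError as err:
        return type(err).__name__
    return "no error"


print(try_naive_loop())
```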
Two strategies are supported:
1. Post-hoc slicing (default in run()): lax.scan returns the full trajectory, and snapshots are selected after the scan completes. This is simple, but device memory grows linearly with the number of steps, which becomes prohibitive for long runs.
2. Debug callback (save_snapshot_callback()): uses jax.debug.callback with ordered=True to write data to disk at specific steps, so device memory stays constant regardless of run length. Suitable for production runs.
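Strategy 1 can be sketched in plain JAX as follows. This is a minimal illustration under stated assumptions, not the code behind run(). `step_fn` is a hypothetical toy update that relaxes towards the mean and then shifts one cell, standing in for a real LBM step. Because the scan body emits every new state as its per-step output, `lax.scan` stacks all `n_steps` states on the device before any slicing happens.

```python
import jax
import jax.numpy as jnp


def step_fn(f):
    # Hypothetical toy step: relax towards the mean, then shift one cell
    # along axis 0 with periodic wrap-around.
    f = f + 0.5 * (jnp.mean(f) - f)
    return jnp.roll(f, shift=1, axis=0)


def run_full_trajectory(f0, n_steps):
    def scan_body(f, _):
        f_new = step_fn(f)
        # Returning f_new as the per-step output makes scan keep every state.
        return f_new, f_new

    _, traj = jax.lax.scan(scan_body, f0, xs=None, length=n_steps)
    return traj  # shape (n_steps, *f0.shape): grows linearly with n_steps


f0 = jnp.zeros((64, 64)).at[0, 0].set(1.0)
traj = run_full_trajectory(f0, n_steps=500)

# Post-hoc selection: traj[i] is the state after step i + 1, so this keeps
# the states after steps 100, 200, ..., 500.
save_interval = 100
snapshots = traj[save_interval - 1 :: save_interval]
print(traj.shape, snapshots.shape)
```

Only five snapshots are kept in the end, but all 500 states had to fit in device memory first. That is the cost that strategy 2 avoids.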
Usage (strategy 2):
from tud_lbm.io import SimulationIO
from tud_lbm.pipeline.io_callbacks import make_save_callback
# Outside jit:
io_handler = SimulationIO(...)
do_save = make_save_callback(
    io_handler, save_interval=100, skip_interval=0,
)
# Inside scan body:
def scan_body(state, t):
    new_state = step_fn(state)
    do_save(new_state, t)  # writes to disk only on steps selected by the intervals
    return new_state, None
Functions
make_save_callback(io_handler, save_interval[, ...]) | Build an I/O callback for use inside a lax.scan body.
Module Contents
- tud_lbm.pipeline.io_callbacks.make_save_callback(io_handler: tud_lbm.io.SimulationIO, save_interval: int, skip_interval: int = 0, save_fields: tuple | None = None) -> collections.abc.Callable
Build an I/O callback for use inside a lax.scan body.
Returns a callable do_save(state, t) that can be placed in the scan body. The interval check runs on-device via jax.lax.cond, so the host callback executes only on steps that meet the save conditions.
- Parameters:
io_handler – A SimulationIO instance.
save_interval – Number of steps between consecutive saves.
skip_interval – Number of initial steps to skip before the first save.
save_fields – Optional tuple of field names to save.
- Returns:
A callable do_save(state, t). The callable itself returns nothing; it writes to disk as a side effect.
Example:
do_save = make_save_callback(io_handler, save_interval=100)
def scan_body(state, t):
    new_state = step_fn(state)
    do_save(new_state, t)
    return new_state, None
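The mechanism described above combines an on-device jax.lax.cond with an ordered jax.debug.callback. It can be approximated as follows. This is a hypothetical sketch, not the TUD-LBM implementation. `RecordingIO` and its `save(step, fields)` method stand in for SimulationIO, whose real interface is not documented here. The save rule (first save at `t == skip_interval`, then every `save_interval` steps) is an assumption about how skip_interval is interpreted.

```python
import jax
import jax.numpy as jnp
import numpy as np


class RecordingIO:
    """Hypothetical stand-in for SimulationIO: keeps snapshots in memory."""

    def __init__(self):
        self.saved = {}

    def save(self, step, fields):
        self.saved[step] = fields


def make_save_callback_sketch(io_handler, save_interval, skip_interval=0,
                              save_fields=None):
    """Hypothetical re-creation of the documented behaviour."""

    def _host_write(state, t):
        # Runs on the host with concrete arrays, so real file I/O is allowed.
        fields = {k: np.asarray(v) for k, v in state.items()
                  if save_fields is None or k in save_fields}
        io_handler.save(int(t), fields)

    def do_save(state, t):
        # Assumed save rule, evaluated on-device so it traces cleanly.
        should_save = (t >= skip_interval) & ((t - skip_interval) % save_interval == 0)
        jax.lax.cond(
            should_save,
            lambda s, step: jax.debug.callback(_host_write, s, step, ordered=True),
            lambda s, step: None,  # no-op branch: nothing leaves the device
            state, t,
        )

    return do_save


def step_fn(state):
    # Hypothetical toy update for two fields.
    return {"rho": state["rho"] * 0.99, "u": state["u"] + 1.0}


io = RecordingIO()
do_save = make_save_callback_sketch(io, save_interval=100, skip_interval=50,
                                    save_fields=("rho",))


def scan_body(state, t):
    new_state = step_fn(state)
    do_save(new_state, t)
    return new_state, None


init = {"rho": jnp.ones(8), "u": jnp.zeros(8)}
final, _ = jax.jit(lambda s: jax.lax.scan(scan_body, s, jnp.arange(400)))(init)
jax.effects_barrier()  # wait for any pending host callbacks to finish
print(sorted(io.saved))  # [50, 150, 250, 350]
```

Because the predicate is computed with jnp operations and branched on with lax.cond, the scan traces cleanly under jit, and only the host function touches Python objects. The jax.effects_barrier() call is there because host callbacks may still be in flight when the jitted call returns.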