tud_lbm.pipeline.io_callbacks

Host-callback-based I/O for TUD-LBM.

Provides utilities for saving simulation snapshots during a jax.lax.scan time loop without breaking the JIT trace.

Two strategies are supported:

  1. Post-hoc slicing (default in run()): lax.scan returns the full trajectory, and snapshots are selected after the scan completes. This is simple, but device memory grows linearly with the number of steps, which becomes prohibitive for long runs.

  2. Debug callback (make_save_callback()): uses jax.debug.callback with ordered=True to write data to disk at selected steps. Device memory therefore stays constant regardless of run length. Suitable for production runs.
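To make the memory trade-off of strategy 1 concrete, the sketch below emits every state as a lax.scan output and slices the stacked trajectory afterwards. It is a minimal sketch under stated assumptions, not the library's run(). The function name run_posthoc and the toy update state + 1.0 are hypothetical stand-ins for the real LBM step. The point is that traj holds all n_steps states at once, even though only a few are kept.

```python
import jax
import jax.numpy as jnp


def run_posthoc(state0, n_steps, save_interval):
    """Hypothetical strategy-1 sketch: keep the whole trajectory, slice afterwards."""

    def scan_body(state, _):
        new_state = state + 1.0  # toy stand-in for the real LBM step (assumption)
        return new_state, new_state  # emit every state as a scan output

    # traj has shape (n_steps, *state0.shape): memory grows linearly with n_steps.
    # traj[i] is the state after step i + 1.
    final_state, traj = jax.lax.scan(scan_body, state0, None, length=n_steps)
    # Keep only the states after steps save_interval, 2 * save_interval, ...
    snapshots = traj[save_interval - 1 :: save_interval]
    return final_state, snapshots


final_state, snapshots = run_posthoc(jnp.zeros(2), n_steps=10, save_interval=3)
print(snapshots.shape)  # (3, 2)
print(snapshots[:, 0].tolist())  # [3.0, 6.0, 9.0]
```

Only 3 of the 10 stacked states survive the slice, but all 10 were held on the device. Strategy 2 avoids this by never returning per-step outputs from the scan.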

Usage (strategy 2):

from tud_lbm.io import SimulationIO
from tud_lbm.pipeline.io_callbacks import make_save_callback

# Outside jit:
io_handler = SimulationIO(...)
do_save = make_save_callback(
    io_handler, save_interval=100, skip_interval=0,
)

# Inside scan body:
def scan_body(state, _t):
    new_state = step_fn(state)        # advance one LBM step
    do_save(new_state, new_state.t)   # writes only on save steps
    return new_state, None

Functions

make_save_callback(io_handler, save_interval[, skip_interval, save_fields]) → collections.abc.Callable

Build an I/O callback for use inside a lax.scan body.

Module Contents

tud_lbm.pipeline.io_callbacks.make_save_callback(io_handler: tud_lbm.io.SimulationIO, save_interval: int, skip_interval: int = 0, save_fields: tuple | None = None) → collections.abc.Callable

Build an I/O callback for use inside a lax.scan body.

Returns a callable do_save(state, t) to be called inside the scan body. The interval check is evaluated on-device via jax.lax.cond, so the host callback runs only on steps that satisfy the save conditions.
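The mechanism described above (an on-device predicate selecting a lax.cond branch that contains an ordered jax.debug.callback) can be sketched as follows. This is a hypothetical re-implementation for illustration, not the library's code. The name make_save_callback_sketch and the write_fn argument (standing in for SimulationIO) are assumptions. The trigger rule (first save at t == skip_interval, then every save_interval steps) is also an assumed reading of the parameters.

```python
import jax
import jax.numpy as jnp
import numpy as np


def make_save_callback_sketch(write_fn, save_interval, skip_interval=0):
    """Hypothetical stand-in for make_save_callback.

    write_fn(t, state) replaces SimulationIO (assumption) and runs on the host.
    """

    def _host_write(t, state):
        # Executed on the host with concrete arrays, never tracers.
        write_fn(int(t), {k: np.asarray(v) for k, v in state.items()})

    def _save(operands):
        # ordered=True keeps writes in step order across scan iterations.
        jax.debug.callback(_host_write, *operands, ordered=True)

    def _skip(operands):
        return None  # nothing to do on non-save steps

    def do_save(state, t):
        # ASSUMED trigger rule: first save at t == skip_interval,
        # then every save_interval steps after that.
        is_save_step = (t >= skip_interval) & ((t - skip_interval) % save_interval == 0)
        jax.lax.cond(is_save_step, _save, _skip, (t, state))

    return do_save


saved = []


def record(t, state):
    saved.append((t, float(state["rho"].sum())))


do_save = make_save_callback_sketch(record, save_interval=3, skip_interval=2)


@jax.jit
def run(state):
    def scan_body(carry, _):
        state, t = carry
        new_state = {"rho": state["rho"] + 1.0}  # toy stand-in for step_fn
        t = t + 1
        do_save(new_state, t)
        return (new_state, t), None

    (final_state, _), _ = jax.lax.scan(scan_body, (state, jnp.int32(0)), None, length=10)
    return final_state


final_state = run({"rho": jnp.zeros(4)})
jax.effects_barrier()  # wait for all pending host callbacks
print(saved)  # [(2, 8.0), (5, 20.0), (8, 32.0)]
```

Because the predicate is computed on the device, the host is entered only on the three save steps. The scan carries no per-step outputs, so device memory does not grow with the number of steps. Calling jax.effects_barrier() before inspecting saved ensures every queued callback has finished.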

Parameters:
  • io_handler – A tud_lbm.io.SimulationIO instance that performs the actual writes.

  • save_interval – Number of steps between consecutive saves.

  • skip_interval – Number of initial steps to skip before the first save (default 0).

  • save_fields – Optional tuple of field names to save.

Returns:

do_save(state, t), a callable with no return value that writes to disk as a side effect.

Example:

do_save = make_save_callback(io_handler, save_interval=100)

def scan_body(state, _t):
    new_state = step_fn(state)
    do_save(new_state, new_state.t)
    return new_state, None