tud_lbm.pipeline.io_callbacks
=============================

.. py:module:: tud_lbm.pipeline.io_callbacks

.. autoapi-nested-parse::

   Host-callback-based I/O for TUD-LBM.

   Provides utilities for saving simulation snapshots during a
   ``jax.lax.scan`` time loop **without breaking the JIT trace**.

   Two strategies are supported:

   1. **Post-hoc slicing** (default in :func:`~runner.run.run`):
      ``lax.scan`` returns the full trajectory, and snapshots are selected
      after the scan completes. This is simple, but device memory grows
      with the number of steps, which makes it costly for long runs.

   2. **Debug callback** (:func:`make_save_callback`):
      Uses an ordered ``jax.debug.callback`` to write data to disk only at
      the requested steps, so device memory stays constant. Suitable for
      production runs.

   Usage (strategy 2)::

       from tud_lbm.io import SimulationIO
       from tud_lbm.pipeline.io_callbacks import make_save_callback

       # Outside jit:
       io_handler = SimulationIO(...)
       do_save = make_save_callback(
           io_handler, save_interval=100, skip_interval=0,
       )

       # Inside scan body (step_fn advances the state by one time step):
       def scan_body(state, _t):
           new_state = step_fn(state)
           do_save(new_state, new_state.t)
           return new_state, None

Functions
---------

.. autoapisummary::

   tud_lbm.pipeline.io_callbacks.make_save_callback

Module Contents
---------------

.. py:function:: make_save_callback(io_handler: tud_lbm.io.SimulationIO, save_interval: int, skip_interval: int = 0, save_fields: tuple | None = None) -> collections.abc.Callable

   Build an I/O callback for use inside a ``lax.scan`` body.

   Returns a callable ``do_save(state, t)`` to be called from the scan
   body. The interval check is performed on-device via ``jax.lax.cond``,
   so the host callback only executes at steps that satisfy the save
   conditions.

   :param io_handler: A :class:`~tud_lbm.io.SimulationIO` instance.
   :param save_interval: Number of time steps between saves.
   :param skip_interval: Number of initial time steps to skip before the
                         first save.
   :param save_fields: Optional tuple of field names to save.

   :returns: A callable ``do_save(state, t)`` that returns nothing and
             writes to disk as a side effect.
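
   The on-device interval check can be pictured with the minimal sketch
   below. It is an illustration, not the TUD-LBM source:
   ``make_save_callback_sketch`` and ``save_fn`` are hypothetical names,
   ``save_fn(state, t)`` stands in for whatever write method
   :class:`~tud_lbm.io.SimulationIO` actually exposes, ``save_fields`` is
   omitted, and the save rule (``t >= skip_interval`` and
   ``(t - skip_interval) % save_interval == 0``) is an assumption::

       import jax
       import jax.numpy as jnp


       def make_save_callback_sketch(save_fn, save_interval, skip_interval=0):
           """Hypothetical interval-gated saver (illustration only)."""

           def _host_save(state, t):
               # Executes on the host with concrete NumPy values, outside the trace.
               save_fn(state, int(t))

           def _save(state, t):
               # ordered=True keeps host calls in program (time-step) order.
               jax.debug.callback(_host_save, state, t, ordered=True)

           def _skip(state, t):
               # No-op branch: nothing leaves the device on non-save steps.
               return None

           def do_save(state, t):
               # Assumed save rule, evaluated on-device as a traced boolean.
               is_save_step = (t >= skip_interval) & (
                   (t - skip_interval) % save_interval == 0
               )
               jax.lax.cond(is_save_step, _save, _skip, state, t)

           return do_save


       # Usage with a toy state: a dict holding a field array and a step counter.
       saved_steps = []
       do_save = make_save_callback_sketch(
           lambda state, t: saved_steps.append(t), save_interval=3, skip_interval=2
       )


       def scan_body(state, _t):
           new_state = {"f": state["f"] + 1.0, "t": state["t"] + 1}
           do_save(new_state, new_state["t"])
           return new_state, None


       init_state = {"f": jnp.zeros(4), "t": jnp.array(0, dtype=jnp.int32)}
       run = jax.jit(lambda s: jax.lax.scan(scan_body, s, None, length=10))
       final_state, _ = run(init_state)
       jax.effects_barrier()  # wait until all pending host callbacks have run
       print(saved_steps)  # [2, 5, 8]

   Because the scan body returns ``None`` as its per-step output, no
   trajectory is stacked on the device; only the steps selected by the
   predicate are copied to the host.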

   Example::

       do_save = make_save_callback(io_handler, save_interval=100)

       def scan_body(state, _t):
           new_state = step_fn(state)
           do_save(new_state, new_state.t)
           return new_state, None
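
   For contrast, strategy 1 (post-hoc slicing) needs no callback at all.
   The sketch below is hypothetical: ``step_fn`` is a toy update on a plain
   array and ``run_post_hoc`` is an illustrative name. The snapshot
   selection assumes a save at every step ``t`` with ``t >= skip_interval``
   and ``(t - skip_interval) % save_interval == 0``; the rule actually used
   by :func:`~runner.run.run` is not stated here::

       import jax
       import jax.numpy as jnp


       def step_fn(f):
           # Hypothetical toy update standing in for one simulation time step.
           return f + 1.0


       def run_post_hoc(f0, n_steps, save_interval, skip_interval=0):
           def scan_body(f, _t):
               new_f = step_fn(f)
               # Returning new_f as the per-step output makes lax.scan stack every
               # state into an array of shape (n_steps, *f0.shape) on the device.
               return new_f, new_f

           _, trajectory = jax.lax.scan(scan_body, f0, None, length=n_steps)

           # Row i holds the state after step t = i + 1; select snapshots afterwards.
           steps = jnp.arange(1, n_steps + 1)
           mask = (steps >= skip_interval) & ((steps - skip_interval) % save_interval == 0)
           return steps[mask], trajectory[mask]


       saved_steps, snapshots = run_post_hoc(
           jnp.zeros(4), n_steps=10, save_interval=3, skip_interval=2
       )
       print(saved_steps.tolist())  # [2, 5, 8]
       print(snapshots.shape)  # (3, 4)

   The full ``trajectory`` array exists on the device before slicing, which
   is why this approach scales poorly with run length compared with the
   callback-based saver.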