Source code for tud_lbm.pipeline.io_callbacks

"""Host-callback-based I/O for TUD-LBM.

Provides utilities for saving simulation snapshots during a
``jax.lax.scan`` time loop **without breaking the JIT trace**.

Two strategies are supported:

1.  **Post-hoc slicing** (default in :func:`~runner.run.run`):
    ``lax.scan`` returns the full trajectory; snapshots are selected
    after the scan completes.  Simple but memory-heavy for long runs.

2.  **Debug callback** (:func:`make_save_callback`):
    Uses ``jax.debug.callback`` (ordered) to write data to disk at
    specific steps, keeping device memory constant.  Suitable for
    production runs.

Usage (strategy 2)::

    from tud_lbm.pipeline.io_callbacks import make_save_callback

    # Outside jit:
    io_handler = SimulationIO(...)
    do_save = make_save_callback(
        io_handler, save_interval=100, skip_interval=0,
    )

    # Inside scan body:
    def scan_body(state, _t):
        new_state = step_fn(state)
        do_save(new_state, new_state.t)
        return new_state, None
"""

from __future__ import annotations
from typing import TYPE_CHECKING
import jax
import numpy as np

if TYPE_CHECKING:
    from collections.abc import Callable
    from tud_lbm.io import SimulationIO
    from tud_lbm.pipeline.state import State


def _state_to_numpy(state: State, fields: tuple | None = None, t: int | None = None) -> dict:
    """Convert a :class:`~tud_lbm.pipeline.state.State` pytree to a NumPy dict.

    Every non-``None`` entry of ``state._asdict()`` that exposes a ``shape``
    attribute is converted with :func:`numpy.asarray`.  If ``state`` has a
    non-``None`` ``wetting`` attribute, every entry of ``wetting._asdict()``
    is added as well, regardless of ``fields``.

    NaN values are *reported*, not raised: a message naming the affected
    fields is printed and the data is still returned.

    Args:
        state: The state object to convert.
        fields: Optional tuple of field names to include. If None, includes
            all array fields.  Has no effect on the wetting entries.
        t: Optional timestep, used only in the NaN message.

    Returns:
        Dict mapping field name to NumPy array.
    """
    data = {
        name: np.asarray(value)
        for name, value in state._asdict().items()
        if value is not None and hasattr(value, "shape") and (fields is None or name in fields)
    }

    # Persist wetting scalars explicitly so plotting/analysis can read them from snapshots.
    # Note: wetting fields (ca_left, ca_right, cll_left, cll_right) are ALWAYS saved
    # because they are needed for analysis plots. They are scalar-valued metadata, not
    # large field arrays, so the storage overhead is negligible. Users can control which
    # analysis plots are created via config.plot_fields in the [output] section.
    wetting = getattr(state, "wetting", None)
    if wetting is not None:
        for key, value in wetting._asdict().items():
            data[key] = np.asarray(value)

    # Only inexact (float/complex) dtypes can hold NaN.  Guarding on the dtype
    # skips a pointless scan of integer/bool arrays and avoids the TypeError
    # that np.isnan raises for dtypes it does not support (e.g. strings).
    bad = [
        name
        for name, array in data.items()
        if np.issubdtype(array.dtype, np.inexact) and np.isnan(array).any()
    ]
    if bad:
        # Everything here is already concrete NumPy/Python data on the host,
        # so a plain print suffices; no JAX tracing machinery is needed.
        print(f"NaNs detected at t={t} in fields: {bad}")
    return data


def make_save_callback(
    io_handler: SimulationIO,
    save_interval: int,
    skip_interval: int = 0,
    save_fields: tuple | None = None,
) -> Callable:
    """Create a snapshot-writing hook for a ``lax.scan`` body.

    The returned ``do_save(state, t)`` wraps a ``jax.lax.cond`` whose
    predicate is ``(t > skip_interval) & (t % save_interval == 0)``.  When the
    predicate holds, an ordered ``jax.debug.callback`` hands ``state`` and
    ``t`` to a host function that converts the state with
    ``_state_to_numpy`` and calls ``io_handler.save_data_step``.  Otherwise
    nothing is executed on the host.

    Args:
        io_handler: A :class:`~tud_lbm.io.SimulationIO` instance; its
            ``save_data_step(step, data)`` method receives each snapshot.
        save_interval: A save is triggered when ``t`` is a multiple of this.
        skip_interval: No save is triggered unless ``t`` is strictly greater
            than this value.
        save_fields: Optional tuple of field names forwarded to
            ``_state_to_numpy`` to restrict which fields are saved.

    Returns:
        ``do_save(state, t)`` — returns ``None``; writing to disk is a side
        effect.

    Example::

        do_save = make_save_callback(io_handler, save_interval=100)

        def scan_body(state, _t):
            new_state = step_fn(state)
            do_save(new_state, new_state.t)
            return new_state, None
    """

    def _write_snapshot(state: State, t: int) -> None:
        """Host side: convert the state and pass it to the I/O handler."""
        # The interval test already ran on-device; here ``t`` is concrete.
        step = int(t)
        snapshot = _state_to_numpy(state, fields=save_fields, t=step)
        io_handler.save_data_step(step, snapshot)

    def _emit(current_state, step):
        # ordered=True keeps the host writes in program order.
        return jax.debug.callback(_write_snapshot, current_state, step, ordered=True)

    def _skip(_current_state, _step):
        return None

    def do_save(state: State, t: int) -> None:
        past_skip = t > skip_interval
        on_interval = t % save_interval == 0
        jax.lax.cond(past_skip & on_interval, _emit, _skip, state, t)

    return do_save