# Source code for tud_lbm.operators.boundary

"""Boundary condition operators — composite builder.

Public API: build_bc(), BCMasks, build_bc_masks

Implementation modules (_bounce_back.py, _periodic.py, _symmetry.py) are internal.

Example:
    from tud_lbm.operators.boundary import build_bc, build_bc_masks

    bc_fn = build_bc({"top": "symmetry", "bottom": "bounce-back"}, lattice)
    masks = build_bc_masks((64, 64))
    f = bc_fn(f_stream, f_col, masks)
"""

from __future__ import annotations
import warnings
from typing import TYPE_CHECKING
from typing import Any
from typing import NamedTuple
import jax.numpy as jnp
from tud_lbm.operators.boundary import _bounce_back as _bb  # noqa: F401
from tud_lbm.operators.boundary import _periodic as _per  # noqa: F401
from tud_lbm.operators.boundary import _symmetry as _sym  # noqa: F401
from tud_lbm.registry import get_operators

if TYPE_CHECKING:
    from collections.abc import Callable
    from tud_lbm.lattice.lattice import Lattice
    from tud_lbm.operators.protocols import BoundaryOperator


class BCMasks(NamedTuple):
    """Pre-computed boundary-condition masks — valid JAX pytree.

    Every mask that is present is a boolean ``jax.Array`` of shape
    ``(nx, ny, nz, 1, 1)`` that is ``True`` on the corresponding edge
    row/column/plane.  The z-face masks ``front`` and ``back`` are optional:
    they default to ``None`` and :func:`build_bc_masks` leaves them ``None``
    for grids with ``nz == 1`` (i.e. 2D problems).

    Attributes:
        top: Mask for the top boundary (y = ny-1).
        bottom: Mask for the bottom boundary (y = 0).
        left: Mask for the left boundary (x = 0).
        right: Mask for the right boundary (x = nx-1).
        front: Mask for the front boundary (z = nz-1), or ``None``.
        back: Mask for the back boundary (z = 0), or ``None``.
    """

    top: jnp.ndarray  # True where y == ny-1
    bottom: jnp.ndarray  # True where y == 0
    left: jnp.ndarray  # True where x == 0
    right: jnp.ndarray  # True where x == nx-1
    # z-faces are optional so 2D setups can omit them entirely.
    front: jnp.ndarray | None = None  # True where z == nz-1
    back: jnp.ndarray | None = None  # True where z == 0
def build_bc_masks(
    grid_shape: tuple[int, ...],
) -> BCMasks:
    """Construct pre-computed boundary-condition masks.

    Each mask is a boolean array of shape ``(nx, ny, nz, 1, 1)`` that is
    ``True`` on the corresponding edge row/column/plane.  A 2-entry
    ``grid_shape`` ``(nx, ny)`` is treated as ``(nx, ny, 1)``; entries beyond
    the third are ignored.

    Args:
        grid_shape: Spatial dimensions ``(nx, ny)`` or ``(nx, ny, nz)``.

    Returns:
        A :class:`BCMasks` NamedTuple.  ``top``/``bottom``/``left``/``right``
        are always arrays; ``front``/``back`` are arrays only when ``nz > 1``
        and ``None`` otherwise.

    Raises:
        ValueError: If *grid_shape* has fewer than two entries.
    """
    if len(grid_shape) < 2:  # noqa: PLR2004
        msg = f"grid_shape must have at least 2 dimensions, got {tuple(grid_shape)!r}"
        raise ValueError(msg)

    # Pad 2D shapes with a singleton z-dimension so all masks are 5D.
    nx, ny, nz = grid_shape[:3] if len(grid_shape) >= 3 else (*grid_shape[:2], 1)  # noqa: PLR2004
    mask_shape = (nx, ny, nz, 1, 1)

    def _face(axis: int, index: int) -> jnp.ndarray:
        # Boolean mask that is True on plane ``index`` along spatial ``axis``
        # (0 = x, 1 = y, 2 = z); the trailing two singleton axes are untouched.
        selector = [slice(None)] * 3
        selector[axis] = index
        return jnp.zeros(mask_shape, dtype=bool).at[tuple(selector)].set(True)

    # y-faces (top/bottom) and x-faces (left/right) always exist.
    top = _face(1, -1)
    bottom = _face(1, 0)
    left = _face(0, 0)
    right = _face(0, -1)

    # z-faces (front/back) only make sense for a genuinely 3D grid.
    has_z_faces = nz > 1
    front = _face(2, -1) if has_z_faces else None
    back = _face(2, 0) if has_z_faces else None

    return BCMasks(top=top, bottom=bottom, left=left, right=right, front=front, back=back)
def _get_bc_dispatch() -> dict[str, Callable]:
    """Return a mapping from BC type name to its implementing function.

    The mapping is rebuilt on every call from the global operator registry
    (category ``"boundary_condition"``); each registry entry is unwrapped to
    its ``target`` callable.
    """
    registered = get_operators("boundary_condition")
    dispatch = {}
    for bc_name, registry_entry in registered.items():
        dispatch[bc_name] = registry_entry.target
    return dispatch
def build_bc(
    bc_config: dict[str, Any] | None,
    lattice: Lattice,
) -> BoundaryOperator:
    """Build a composite BC closure from a bc_config dict.

    The returned function applies boundary conditions in the fixed order
    bottom → top → left → right.  Edges that are absent from *bc_config* or
    mapped to ``"periodic"`` are no-ops.  A BC type that is not registered
    under ``"boundary_condition"`` is also treated as a no-op, but a
    :class:`UserWarning` is emitted so that a misspelled BC name does not
    silently fall back to periodic behaviour.  Only the four edges listed
    above are read from *bc_config*; any other keys (e.g. ``"front"`` /
    ``"back"``) are ignored by this builder.

    Args:
        bc_config: Mapping ``{edge: bc_type, ...}``, e.g.
            ``{"top": "symmetry", "bottom": "bounce-back",
            "left": "periodic", "right": "periodic"}``.
            ``None`` means all-periodic.
        lattice: :class:`~tud_lbm.lattice.lattice.Lattice` forwarded to every
            per-edge BC function.

    Returns:
        A callable satisfying the BoundaryOperator protocol,
        ``bc_fn(f_stream, f_col, bc_masks) → f``.

    Warns:
        UserWarning: Once per edge whose BC type is neither ``"periodic"``
            nor a registered boundary-condition name.
    """
    if bc_config is None:
        bc_config = {}

    bc_dispatch = _get_bc_dispatch()

    # Resolve the (edge, bc_fn) pairs once at build time so the returned
    # closure is a fixed sequence of calls.  Periodic edges need no
    # correction here and are skipped.
    edge_order = ("bottom", "top", "left", "right")
    resolved = []
    for edge in edge_order:
        bc_type = bc_config.get(edge, "periodic")
        if bc_type == "periodic":
            continue
        fn = bc_dispatch.get(bc_type)
        if fn is None:
            # Keep the historical no-op behaviour, but make it visible.
            warnings.warn(
                f"Unknown boundary condition type {bc_type!r} for edge "
                f"{edge!r}; treating it as periodic (no-op). "
                f"Registered types: {sorted(bc_dispatch)}",
                UserWarning,
                stacklevel=2,
            )
            continue
        resolved.append((edge, fn))
    ops = tuple(resolved)

    def bc_fn(
        f_streamed: jnp.ndarray,
        f_collision: jnp.ndarray,
        bc_masks: Any,  # noqa: ANN401, ARG001
    ) -> jnp.ndarray:
        """Apply all boundary conditions in sequence.

        Args:
            f_streamed: Post-streaming populations.
            f_collision: Post-collision populations, passed unchanged to
                every per-edge BC function.
            bc_masks: :class:`BCMasks` instance (currently unused — the
                per-edge functions slice by index; masks are reserved for a
                future vectorised implementation).

        Returns:
            Populations with boundary conditions applied.
        """
        f = f_streamed
        for edge_name, edge_fn in ops:
            f = edge_fn(f, f_collision, lattice, edge_name)
        return f

    return bc_fn
# Public surface of this package; the _bounce_back/_periodic/_symmetry
# implementation modules are intentionally not re-exported.
__all__ = [
    "BCMasks",
    "build_bc",
    "build_bc_masks",
]