"""Wetting hysteresis optimisation — pure functions.
Ported from :class:`update_timestep.UpdateMultiphaseHysteresis`.
The legacy class stores mutable wetting parameters on ``self`` and
uses ``@partial(jit, static_argnums=(0,))`` which causes JIT cache
bloat. This module replaces it with pure functions that operate on
the :class:`~state.state.WettingState` NamedTuple carried through
``jax.lax.scan``.
All inner optimisation loops use ``optax`` + ``jax.lax.while_loop``,
are bounded by a maximum iteration count, and are fully jittable.
Design
~~~~~~
``update_wetting_state`` is the top-level entry point. It:
1. Measures contact angles and contact-line locations from ``rho_t_plus1``.
2. Checks whether each side is inside the hysteresis window.
3. Builds per-side objectives (CLL-pin or CA-target, selected via
``jnp.where``).
4. Runs **two** sequential ``jax.lax.while_loop`` optimisations —
one for each side — masking the other side's parameters.
5. Returns an updated :class:`WettingState` — no mutation.
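The inner ``evaluate_fn`` closure (defined in ``_update_wetting_state_impl``)
runs the supplied ``trial_step_fn`` for a single LBM step with trial wetting
parameters and measures CA/CLL on the result, so that ``jax.value_and_grad``
can differentiate through it.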
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import jax
import jax.numpy as jnp
from tud_lbm.config.config_overview import DEBUG_FLAG
from tud_lbm.operators.wetting._contact_angle import compute_contact_angle
from tud_lbm.operators.wetting._contact_line import compute_contact_line_location
from tud_lbm.operators.wetting._params import WettingParams
from tud_lbm.registry import wetting_operator
if TYPE_CHECKING:
import types
from collections.abc import Callable
from tud_lbm.pipeline.setup import SimulationSetup
from tud_lbm.pipeline.state.state import WettingState
# ── Helpers ──────────────────────────────────────────────────────────
# Values that the *inactive* wetting parameter of a side is reset to once
# the directional split (phi vs. d_rho) has been decided for that side.
_D_RHO_NEUTRAL: float = 0.0
_PHI_NEUTRAL: float = 1.0
def _import_optax() -> types.ModuleType:
    """Return the ``optax`` module, or fail with an actionable install hint.

    ``optax`` is an optional dependency that only the hysteresis wetting
    operators need, so it is imported lazily instead of at module import time.

    Raises:
        ImportError: If ``optax`` is not installed. The original error is
            chained as ``__cause__``.
    """
    try:
        import optax
    except ImportError as exc:
        hint = "The 'optax' package is required for hysteresis wetting.\nInstall it with: pip install optax"
        raise ImportError(hint) from exc
    else:
        return optax
def _phi_is_active(
    in_window: jnp.ndarray,
    above_window: jnp.ndarray,
    forward_drift: jnp.ndarray,
) -> jnp.ndarray:
    """Decide whether ``phi`` (rather than ``d_rho``) is the knob to turn for one side.

    ``phi`` makes the wall *more* wetting (lowers the contact angle). It is
    selected when:

    * the contact angle is above the window (``above_window``) and has to be
      pulled back down towards ``ca_advancing``; or
    * the contact angle is inside the window but the contact line is drifting
      backward (``in_window`` and not ``forward_drift``). Increasing
      wettability resists that recession.

    Every other case selects ``d_rho``, which makes the wall *less* wetting:

    * inside the window with the contact line advancing; or
    * below the window (neither ``in_window`` nor ``above_window``), where the
      contact angle has to be raised back towards ``ca_receding``.

    Args:
        in_window: Boolean scalar, ``ca_receding <= CA <= ca_advancing``.
        above_window: Boolean scalar, ``CA > ca_advancing``.
        forward_drift: Boolean scalar. True when the trial-step CLL moved in
            the advancing direction relative to the stored CLL. For the right
            side this means ``cll_trial > cll_stored``; for the left side it
            means ``cll_trial < cll_stored``.

    Returns:
        Boolean JAX scalar. True selects ``phi``; False selects ``d_rho``.
    """
    receding_inside_window = in_window & ~forward_drift
    return receding_inside_window | above_window
def _clamp_params(params: WettingParams) -> WettingParams:
    """Project wetting parameters back onto their admissible box.

    Both ``phi_*`` fields are limited to ``[1.0, 1.5]``. Both ``d_rho_*``
    fields are limited to ``[0.0, 0.3]``. These ranges are meant to cover
    physically realistic wetting conditions.

    NOTE(review): if this clamp is ever placed *inside* a differentiated
    function, ``jnp.clip`` passes no gradient for inputs outside its range.
    A parameter held at a limit would then receive no signal pushing it
    further in that direction.
    """
    phi_lo, phi_hi = 1.0, 1.5
    d_rho_lo, d_rho_hi = 0.0, 0.3
    clamped_phi_left = jnp.clip(params.phi_left, phi_lo, phi_hi)
    clamped_phi_right = jnp.clip(params.phi_right, phi_lo, phi_hi)
    clamped_d_rho_left = jnp.clip(params.d_rho_left, d_rho_lo, d_rho_hi)
    clamped_d_rho_right = jnp.clip(params.d_rho_right, d_rho_lo, d_rho_hi)
    return WettingParams(
        phi_left=clamped_phi_left,
        phi_right=clamped_phi_right,
        d_rho_left=clamped_d_rho_left,
        d_rho_right=clamped_d_rho_right,
    )
def _huber(residual: jnp.ndarray, delta: float) -> jnp.ndarray:
    """Huber penalty of a non-negative ``residual`` with transition point ``delta``.

    The penalty is quadratic (``0.5 * r**2``) for ``r < delta`` and linear
    (``delta * (r - 0.5 * delta)``) otherwise. The two pieces meet with equal
    value and slope at ``r == delta``, so the gradient is smooth near zero and
    bounded for large errors.
    """
    return jnp.where(residual < delta, 0.5 * residual**2, delta * (residual - 0.5 * delta))


def _cost_cll(cll_target: jnp.ndarray, cll_current: jnp.ndarray, *, delta: float = 0.5) -> jnp.ndarray:
    """Huber loss for CLL pinning on ``|cll_target - cll_current|``.

    ``delta`` is in the same units as the CLL. The default 0.5 preserves the
    original hard-coded value.
    """
    return _huber(jnp.abs(cll_target - cll_current), delta)


def _cost_ca(ca_target: jnp.ndarray, ca_current: jnp.ndarray, *, delta: float = 5.0) -> jnp.ndarray:
    """Huber loss for CA targeting on ``|ca_target - ca_current|``.

    ``delta`` is in degrees (default 5°).
    """
    return _huber(jnp.abs(ca_target - ca_current), delta)


def _cost_above(ca_adv: jnp.ndarray, ca_current: jnp.ndarray, *, delta: float = 5.0) -> jnp.ndarray:
    """One-sided Huber loss that penalises only CA values above ``ca_adv``.

    The loss is zero for ``ca_current <= ca_adv``. ``delta`` is in degrees
    (default 5°).
    """
    return _huber(jnp.maximum(ca_current - ca_adv, 0.0), delta)


def _cost_below(ca_rec: jnp.ndarray, ca_current: jnp.ndarray, *, delta: float = 5.0) -> jnp.ndarray:
    """One-sided Huber loss that penalises only CA values below ``ca_rec``.

    The loss is zero for ``ca_current >= ca_rec``. ``delta`` is in degrees
    (default 5°).
    """
    return _huber(jnp.maximum(ca_rec - ca_current, 0.0), delta)
# WettingParams field names; every gradient mask keeps exactly one of them.
_WETTING_FIELDS = ("phi_left", "phi_right", "d_rho_left", "d_rho_right")


def _keep_only(g: WettingParams, keep: str) -> WettingParams:
    """Return a copy of the gradient tree ``g`` with every field except ``keep`` zeroed."""
    masked = {}
    for field_name in _WETTING_FIELDS:
        value = getattr(g, field_name)
        masked[field_name] = value if field_name == keep else jnp.zeros_like(value)
    return WettingParams(**masked)


def _mask_left_d_rho(g: WettingParams) -> WettingParams:
    """Keep only the ``d_rho_left`` gradient."""
    return _keep_only(g, "d_rho_left")


def _mask_left_phi(g: WettingParams) -> WettingParams:
    """Keep only the ``phi_left`` gradient."""
    return _keep_only(g, "phi_left")


def _mask_right_d_rho(g: WettingParams) -> WettingParams:
    """Keep only the ``d_rho_right`` gradient."""
    return _keep_only(g, "d_rho_right")


def _mask_right_phi(g: WettingParams) -> WettingParams:
    """Keep only the ``phi_right`` gradient."""
    return _keep_only(g, "phi_right")
# ── Generic optimisation routine ─────────────────────────────────────
def _optimise_single_param(
    objective_fn: Callable[[WettingParams], jnp.ndarray],
    initial_params: WettingParams,
    grad_mask_fn: Callable[[WettingParams], WettingParams],
    optimiser: object,
    max_iterations: int,
    *,
    loss_tol: float | None = None,
) -> tuple[WettingParams, jnp.ndarray]:
    """Run an ``optax`` optimisation loop with masked gradients.

    Each iteration does four things:

    1. Evaluates ``objective_fn`` and its gradient.
    2. Zeroes every gradient component except the target parameter(s) via
       ``grad_mask_fn``.
    3. Applies the optimiser update.
    4. Projects the result with :func:`_clamp_params`.

    The loop is a ``jax.lax.while_loop`` and is therefore jittable.

    The loop stops after ``max_iterations`` steps. If ``loss_tol`` is given,
    it also stops as soon as the most recently evaluated loss is
    ``<= loss_tol``. That loss belongs to the parameters *before* the latest
    update, so at most one extra step is taken. With the default
    ``loss_tol=None`` the loop always runs the full ``max_iterations``
    steps.

    Args:
        objective_fn: ``params -> scalar_loss``.
        initial_params: Starting :class:`WettingParams`.
        grad_mask_fn: ``grads -> grads``. Zeroes out all but the target
            parameter(s).
        optimiser: An ``optax`` optimiser instance (``init`` / ``update``).
        max_iterations: Maximum number of inner steps.
        loss_tol: Optional early-exit threshold on the loss. ``None`` (the
            default) disables early exit.

    Returns:
        ``(final_params, final_loss)``. ``final_loss`` is the last loss the
        loop evaluated. When no step was taken, it is the loss of
        ``initial_params``. Otherwise it is the loss of the parameters
        *before* the final update; the post-update loss is not recomputed, to
        avoid an extra objective evaluation.
    """
    # Lazy import of the optional dependency, with a clear install hint.
    optax = _import_optax()
    loss_and_grad = jax.value_and_grad(objective_fn)
    opt_state = optimiser.init(initial_params)
    # Evaluated once so the carry has a concrete loss value/dtype before the first step.
    initial_loss = objective_fn(initial_params)

    def cond_fn(carry: tuple) -> jnp.ndarray:
        _params, _opt_state, loss, iteration = carry
        keep_going = iteration < max_iterations
        if loss_tol is not None:  # static Python branch: decided at trace time
            keep_going = keep_going & (loss > loss_tol)
        return keep_going

    def body_fn(carry: tuple) -> tuple:
        params, state, _loss, iteration = carry
        loss, grads = loss_and_grad(params)
        grads = grad_mask_fn(grads)
        updates, new_state = optimiser.update(grads, state, params)
        # The clamp is a projection after the update; it is not differentiated through.
        new_params = _clamp_params(optax.apply_updates(params, updates))
        return (new_params, new_state, loss, iteration + 1)

    init_carry = (initial_params, opt_state, initial_loss, jnp.array(0))
    final_params, _opt_state, final_loss, _iters = jax.lax.while_loop(cond_fn, body_fn, init_carry)
    return final_params, final_loss
# ── Top-level entry point ────────────────────────────────────────────
def _get_hysteresis_window_chemical_step(setup: SimulationSetup, cll: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Return ``(ca_advancing, ca_receding)`` for a contact line at ``cll``.

    The chemical step sits at
    ``chemical_step_config["chemical_step_location"] * config.grid_shape[0]``.
    Presumably the location is a fraction of the first grid axis; confirm
    against the config docs.

    A contact line strictly before the step (``cll < step_x``) gets the
    ``*_pre_step`` window. A contact line at or beyond the step gets the
    ``*_post_step`` window.

    The selection uses ``jnp.where`` rather than ``jax.lax.cond``. Both windows
    are plain config scalars, so there is no work to skip. ``jnp.where`` also
    promotes mixed int/float config values to a common dtype; ``lax.cond``
    instead rejects branch outputs whose dtypes differ.
    """
    step_cfg = setup.config.chemical_step_config
    step_x = step_cfg["chemical_step_location"] * setup.config.grid_shape[0]
    before_step = cll < step_x
    ca_advancing = jnp.where(before_step, step_cfg["ca_advancing_pre_step"], step_cfg["ca_advancing_post_step"])
    ca_receding = jnp.where(before_step, step_cfg["ca_receding_pre_step"], step_cfg["ca_receding_post_step"])
    return ca_advancing, ca_receding
def _update_wetting_state_impl(
    wetting: WettingState,
    rho_t_plus1: jnp.ndarray,
    setup: SimulationSetup,
    trial_step_fn: Callable[[WettingParams], tuple[jnp.ndarray, jnp.ndarray]],
    *,
    ca_adv_left: jnp.ndarray,
    ca_rec_left: jnp.ndarray,
    ca_adv_right: jnp.ndarray,
    ca_rec_right: jnp.ndarray,
) -> WettingState:
    """Shared implementation of the hysteresis wetting update.

    Steps:

    1. Measure the contact angle (CA) and contact-line location (CLL) on both
       sides from ``rho_t_plus1``.
    2. Classify each side as inside / above / below its hysteresis window.
       Use :func:`_phi_is_active` to pick that side's active parameter
       (``phi`` or ``d_rho``).
    3. Reset each side's inactive parameter to its neutral value. Optimise
       the active one with :func:`_optimise_single_param`, left side first,
       then right side.
    4. If ``phi`` was selected for a side but ended at (or below) its neutral
       value, the optimiser could not make use of it. In that case,
       re-optimise ``d_rho`` for that side, warm-started from the stored
       ``wetting.d_rho_*``.
    5. Return a new :class:`WettingState` holding the optimised parameters and
       the measured CAs. While a side is inside its window, its stored CLL is
       kept (pinned). Otherwise the measured CLL is stored.

    Args:
        wetting: Current wetting state (stored parameters and CLLs).
        rho_t_plus1: Post-step density field.
        setup: Simulation setup providing ``multiphase_params`` and
            ``config.hysteresis_config``.
        trial_step_fn: ``(WettingParams) -> (f_out, rho_out)``. Evaluates one
            physics step with trial wetting parameters.
        ca_adv_left, ca_rec_left: Advancing / receding CA bounds, left side.
        ca_adv_right, ca_rec_right: Advancing / receding CA bounds, right side.

    Returns:
        Updated :class:`WettingState` (a new instance via ``_replace``).
    """
    mp = setup.multiphase_params
    rho_mean = jnp.array(0.5 * (mp.rho_l + mp.rho_v))
    ca_left_tplus1, ca_right_tplus1 = compute_contact_angle(rho_t_plus1, rho_mean)
    cll_left_tplus1, cll_right_tplus1 = compute_contact_line_location(
        rho_t_plus1,
        ca_left_tplus1,
        ca_right_tplus1,
        rho_mean,
    )

    # "Forward" = advancing: the right CL moves to larger CLL, the left CL to smaller CLL.
    forward_drift_right = cll_right_tplus1 > wetting.cll_right
    forward_drift_left = cll_left_tplus1 < wetting.cll_left
    in_window_left = (ca_left_tplus1 >= ca_rec_left) & (ca_left_tplus1 <= ca_adv_left)
    in_window_right = (ca_right_tplus1 >= ca_rec_right) & (ca_right_tplus1 <= ca_adv_right)
    above_window_left = ca_left_tplus1 > ca_adv_left
    above_window_right = ca_right_tplus1 > ca_adv_right
    phi_active_right = _phi_is_active(in_window_right, above_window_right, forward_drift_right)
    phi_active_left = _phi_is_active(in_window_left, above_window_left, forward_drift_left)

    hc = setup.config.hysteresis_config
    # One optimiser (and learning rate) is shared by both sides. The larger
    # "above window" rate is used when either side exceeds ca_advancing.
    lr = jnp.where(
        above_window_right | above_window_left,
        hc.get("learning_rate_above", 0.05),
        hc.get("learning_rate", 0.01),
    )
    # A truthy ``max_iterations_above`` overrides ``max_iterations``,
    # independently of the window state.
    max_iter = hc.get("max_iterations_above") or hc.get("max_iterations", 50)

    # Directional split: the inactive parameter of each side starts at neutral.
    params = WettingParams(
        phi_left=jnp.where(phi_active_left, wetting.phi_left, _PHI_NEUTRAL),
        phi_right=jnp.where(phi_active_right, wetting.phi_right, _PHI_NEUTRAL),
        d_rho_left=jnp.where(phi_active_left, _D_RHO_NEUTRAL, wetting.d_rho_left),
        d_rho_right=jnp.where(phi_active_right, _D_RHO_NEUTRAL, wetting.d_rho_right),
    )
    optax = _import_optax()
    optimiser = optax.adam(lr)

    def evaluate_fn(params: WettingParams) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
        """Run one trial step and measure ``(ca_l, ca_r, cll_l, cll_r)`` on its density."""
        _, rho_out = trial_step_fn(params)
        ca_l, ca_r = compute_contact_angle(rho_out, rho_mean)
        cll_l, cll_r = compute_contact_line_location(rho_out, ca_l, ca_r, rho_mean)
        return ca_l, ca_r, cll_l, cll_r

    def left_objective(p: WettingParams) -> jnp.ndarray:
        """Pin the CLL inside the window; otherwise pull the CA back to the exceeded bound."""
        ca_l, _, cll_l, _ = evaluate_fn(p)
        cost_in = _cost_cll(wetting.cll_left, cll_l)
        cost_below = _cost_below(ca_rec_left, ca_l)
        cost_above = _cost_above(ca_adv_left, ca_l)
        return jnp.where(in_window_left, cost_in, jnp.where(above_window_left, cost_above, cost_below))

    def right_objective(p: WettingParams) -> jnp.ndarray:
        """Right-side counterpart of ``left_objective``."""
        _, ca_r, _, cll_r = evaluate_fn(p)
        cost_in = _cost_cll(wetting.cll_right, cll_r)
        cost_below = _cost_below(ca_rec_right, ca_r)
        cost_above = _cost_above(ca_adv_right, ca_r)
        return jnp.where(in_window_right, cost_in, jnp.where(above_window_right, cost_above, cost_below))

    def _opt_left(p: WettingParams) -> WettingParams:
        return jax.lax.cond(
            phi_active_left,
            lambda pp: _optimise_single_param(left_objective, pp, _mask_left_phi, optimiser, max_iter)[0],
            lambda pp: _optimise_single_param(left_objective, pp, _mask_left_d_rho, optimiser, max_iter)[0],
            p,
        )

    def _opt_right(p: WettingParams) -> WettingParams:
        return jax.lax.cond(
            phi_active_right,
            lambda pp: _optimise_single_param(right_objective, pp, _mask_right_phi, optimiser, max_iter)[0],
            lambda pp: _optimise_single_param(right_objective, pp, _mask_right_d_rho, optimiser, max_iter)[0],
            p,
        )

    new_params = _opt_right(_opt_left(params))

    # Fallback: if the phi path was selected but phi stayed neutral, retry
    # with d_rho warm-started from the stored accumulated value.
    def _fallback_d_rho_left(p: WettingParams) -> WettingParams:
        fallback = WettingParams(
            # full_like keeps dtype/shape identical to the identity branch of lax.cond.
            phi_left=jnp.full_like(p.phi_left, _PHI_NEUTRAL),
            phi_right=p.phi_right,
            d_rho_left=wetting.d_rho_left,
            d_rho_right=p.d_rho_right,
        )
        return _optimise_single_param(left_objective, fallback, _mask_left_d_rho, optimiser, max_iter)[0]

    def _fallback_d_rho_right(p: WettingParams) -> WettingParams:
        fallback = WettingParams(
            phi_left=p.phi_left,
            phi_right=jnp.full_like(p.phi_right, _PHI_NEUTRAL),
            d_rho_left=p.d_rho_left,
            d_rho_right=wetting.d_rho_right,
        )
        return _optimise_single_param(right_objective, fallback, _mask_right_d_rho, optimiser, max_iter)[0]

    # ``_clamp_params`` keeps phi >= 1.0 == _PHI_NEUTRAL after every step. So
    # "phi stayed neutral" must be tested with ``<=``; a strict ``<`` could
    # never fire and the fallback would be dead code. ``<=`` matches the
    # "engaged = phi > neutral" definition in the debug output below.
    final_params = jax.lax.cond(
        phi_active_left & (new_params.phi_left <= _PHI_NEUTRAL),
        _fallback_d_rho_left,
        lambda p: p,
        new_params,
    )
    final_params = jax.lax.cond(
        phi_active_right & (final_params.phi_right <= _PHI_NEUTRAL),
        _fallback_d_rho_right,
        lambda p: p,
        final_params,
    )

    if DEBUG_FLAG:
        phi_engaged_right = phi_active_right & (final_params.phi_right > _PHI_NEUTRAL)
        fallback_right = phi_active_right & ~phi_engaged_right
        phi_engaged_left = phi_active_left & (final_params.phi_left > _PHI_NEUTRAL)
        fallback_left = phi_active_left & ~phi_engaged_left
        # mode: 0=d_rho(normal), 1=phi(engaged), 2=d_rho(fallback)
        mode_right = jnp.where(phi_engaged_right, jnp.array(1), jnp.where(fallback_right, jnp.array(2), jnp.array(0)))
        mode_left = jnp.where(phi_engaged_left, jnp.array(1), jnp.where(fallback_left, jnp.array(2), jnp.array(0)))
        jax.debug.print(
            "\n[R] CA={ca:.3f}° (adv={ca_adv:.1f}° rec={ca_rec:.1f}°) | CLL={cll:.3f} | "
            "mode={mode}(0=d_rho,1=phi,2=fb) | "
            "phi: {phi:.6f} | "
            "d_rho: {d_rho:.6f} | "
            "loss={loss:.3e}",
            ca=ca_right_tplus1,
            ca_adv=ca_adv_right,
            ca_rec=ca_rec_right,
            cll=cll_right_tplus1,
            mode=mode_right,
            phi=final_params.phi_right,
            d_rho=final_params.d_rho_right,
            loss=right_objective(final_params),
        )
        jax.debug.print(
            "[L] CA={ca:.3f}° (adv={ca_adv:.1f}° rec={ca_rec:.1f}°) | CLL={cll:.3f} | "
            "mode={mode}(0=d_rho,1=phi,2=fb) | "
            "phi: {phi:.6f} | "
            "d_rho: {d_rho:.6f} | "
            "loss={loss:.3e}",
            ca=ca_left_tplus1,
            ca_adv=ca_adv_left,
            ca_rec=ca_rec_left,
            cll=cll_left_tplus1,
            mode=mode_left,
            phi=final_params.phi_left,
            d_rho=final_params.d_rho_left,
            loss=left_objective(final_params),
        )

    return wetting._replace(
        phi_left=final_params.phi_left,
        phi_right=final_params.phi_right,
        d_rho_left=final_params.d_rho_left,
        d_rho_right=final_params.d_rho_right,
        ca_left=ca_left_tplus1,
        ca_right=ca_right_tplus1,
        # Pinned (in-window) sides keep their stored CLL; others adopt the measured one.
        cll_left=jnp.where(in_window_left, wetting.cll_left, cll_left_tplus1),
        cll_right=jnp.where(in_window_right, wetting.cll_right, cll_right_tplus1),
    )
@wetting_operator(name="hysteresis")
def update_wetting_state(
    wetting: WettingState,
    rho_t_plus1: jnp.ndarray,
    setup: SimulationSetup,
    *,
    trial_step_fn: Callable[[WettingParams], tuple[jnp.ndarray, jnp.ndarray]] | None = None,
) -> WettingState:
    """Pure JAX update of wetting / hysteresis parameters.

    This replaces the mutable
    :class:`~update_timestep.UpdateMultiphaseHysteresis.__call__` method. It
    operates entirely on the :class:`WettingState` NamedTuple and returns a
    new instance, with no side effects.

    Each side (left and right) is optimised in its own
    ``jax.lax.while_loop``, with the other side's parameters masked out. This
    gives each side clean gradients and clean Adam state. Each inner
    iteration evaluates one trial step (plus its gradient), and the two sides
    run separate loops.

    The hysteresis regime is selected strictly from the current contact
    angle. The side is pinned when ``ca_receding <= ca <= ca_advancing``;
    otherwise it targets the exceeded bound. A single global window
    (``hysteresis_config["ca_advancing"]`` / ``["ca_receding"]``) is used for
    both sides.

    Args:
        wetting: Current :class:`WettingState`.
        rho_t_plus1: Density field, shape ``(nx, ny, nz, 1, 1)``.
        setup: :class:`~setup.simulation_setup.SimulationSetup`
            (closed-over, not traced).
        trial_step_fn: Callable ``(WettingParams) -> (f_out, rho_out)``. It
            evaluates a single multiphase physics pass with trial wetting
            parameters. Required: the ``None`` default exists only for
            signature compatibility.

    Returns:
        Updated :class:`WettingState`.

    Raises:
        TypeError: If ``trial_step_fn`` is not provided.
    """
    if trial_step_fn is None:
        # Fail early with a clear message. Without this check the failure is
        # an opaque "'NoneType' object is not callable" deep inside the trace.
        msg = "update_wetting_state requires a 'trial_step_fn' callable (WettingParams) -> (f_out, rho_out)."
        raise TypeError(msg)
    hc = setup.config.hysteresis_config
    ca_adv = hc["ca_advancing"]
    ca_rec = hc["ca_receding"]
    return _update_wetting_state_impl(
        wetting,
        rho_t_plus1,
        setup,
        trial_step_fn,
        ca_adv_left=ca_adv,
        ca_rec_left=ca_rec,
        ca_adv_right=ca_adv,
        ca_rec_right=ca_rec,
    )
@wetting_operator(name="chemical_step_hysteresis")
def update_wetting_state_chemical_step(
    wetting: WettingState,
    rho_t_plus1: jnp.ndarray,
    setup: SimulationSetup,
    *,
    trial_step_fn: Callable[[WettingParams], tuple[jnp.ndarray, jnp.ndarray]],
) -> WettingState:
    """Hysteresis update where CA targets are determined per side by chemical-step position.

    This is identical to ``update_wetting_state`` except for how the
    ``(ca_advancing, ca_receding)`` window is chosen. Each side gets its own
    window, selected by comparing that side's contact-line location against
    the chemical-step position, instead of a single global window.

    The CLL used for this selection is the one *measured* on ``rho_t_plus1``,
    not the stored ``wetting.cll_left`` / ``wetting.cll_right``.

    Args:
        wetting: Current WettingState.
        rho_t_plus1: Post-step density field.
        setup: SimulationSetup. Must carry ``chemical_step_config`` with the
            keys ``chemical_step_location``, ``ca_advancing_pre_step``,
            ``ca_receding_pre_step``, ``ca_advancing_post_step`` and
            ``ca_receding_post_step``.
        trial_step_fn: Callable ``(WettingParams) -> (f_out, rho_out)``,
            provided by the step function via partial application.

    Returns:
        Updated WettingState with optimised wetting parameters and measured CA/CLL.
    """
    mp = setup.multiphase_params
    rho_mean = jnp.array(0.5 * (mp.rho_l + mp.rho_v))
    # 1. Measure current contact angles and contact-line locations. Only the
    #    CLLs are needed here; the shared implementation repeats the
    #    measurement itself.
    ca_left_tplus1, ca_right_tplus1 = compute_contact_angle(rho_t_plus1, rho_mean)
    cll_left_tplus1, cll_right_tplus1 = compute_contact_line_location(
        rho_t_plus1,
        ca_left_tplus1,
        ca_right_tplus1,
        rho_mean,
    )
    # 2. Use the currently measured CLL to select each side's pre/post-step window.
    ca_adv_left, ca_rec_left = _get_hysteresis_window_chemical_step(setup, cll_left_tplus1)
    ca_adv_right, ca_rec_right = _get_hysteresis_window_chemical_step(setup, cll_right_tplus1)
    return _update_wetting_state_impl(
        wetting,
        rho_t_plus1,
        setup,
        trial_step_fn,
        ca_adv_left=ca_adv_left,
        ca_rec_left=ca_rec_left,
        ca_adv_right=ca_adv_right,
        ca_rec_right=ca_rec_right,
    )