# Source code for tud_lbm.config.jax_config
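"""JAX configuration settings for TUD-LBM.

This module provides centralized JAX configuration options.
Import it and call `configure_jax()` at the start of your script to apply
the settings.

Note: importing this module sets the environment variable
``XLA_PYTHON_CLIENT_PREALLOCATE`` to ``"false"``, so JAX does not
pre-allocate GPU memory up front.
"""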
import os

# Prevent JAX from pre-allocating GPU memory. XLA reads this variable when it
# creates the GPU client, so it is set before importing jax (or any project
# module that might import jax) to guarantee it takes effect.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax  # noqa: E402  (must follow the env-var setup above)

from tud_lbm.config.config_overview import DISABLE_JIT  # noqa: E402
from tud_lbm.config.config_overview import ENABLE_X64  # noqa: E402
def configure_jax(
enable_x64: bool | None = None,
disable_jit: bool | None = None,
) -> None:
"""Configure JAX settings.
Call this function at the start of your script to apply JAX configuration.
If parameters are None, the module-level defaults are used.
Args:
enable_x64: Enable 64-bit precision. Defaults to module-level ENABLE_X64.
disable_jit: Disable JIT compilation for debugging. Defaults to module-level DISABLE_JIT.
"""
x64 = enable_x64 if enable_x64 is not None else ENABLE_X64
no_jit = disable_jit if disable_jit is not None else DISABLE_JIT
jax.config.update("jax_enable_x64", x64)
jax.config.update("jax_disable_jit", no_jit)