Source code for torus_solver.metrics
from __future__ import annotations
import jax.numpy as jnp
[docs]
def weighted_mean(x: jnp.ndarray, weights: jnp.ndarray, *, eps: float = 1e-30) -> jnp.ndarray:
    """Return the weighted average of ``x`` over all of its elements.

    The result is ``sum(weights * x) / (sum(weights) + eps)``; both sums run
    over every element, so the output is a scalar array.

    Args:
        x: Values to average.
        weights: Weights multiplied element-wise with ``x``.
        eps: Constant added to the total weight so that an all-zero weight
            array does not produce a division by zero.
    """
    values = jnp.asarray(x)
    w = jnp.asarray(weights)
    total_weight = jnp.sum(w) + eps
    weighted_total = jnp.sum(w * values)
    return weighted_total / total_weight
[docs]
def weighted_rms(x: jnp.ndarray, weights: jnp.ndarray, *, eps: float = 1e-30) -> jnp.ndarray:
    """Return the weighted root-mean-square of ``x``.

    This is the square root of the weighted mean of ``x * x``.

    Args:
        x: Values whose RMS is taken.
        weights: Weights passed through to ``weighted_mean``.
        eps: Forwarded to ``weighted_mean`` as its denominator offset.
    """
    values = jnp.asarray(x)
    w = jnp.asarray(weights)
    mean_square = weighted_mean(values * values, w, eps=eps)
    return jnp.sqrt(mean_square)
[docs]
def weighted_p_norm(x: jnp.ndarray, weights: jnp.ndarray, *, p: float, eps: float = 1e-30) -> jnp.ndarray:
    """Weighted p-norm proxy (p>=2) that interpolates between RMS (p=2) and max (p→∞).

    Evaluates ``(sum(w * |x|**p) / (sum(w) + eps)) ** (1 / p)``, but factors
    the largest magnitude out of ``|x|`` before raising to the power ``p``.
    The two forms are algebraically identical; the rescaled one keeps the
    intermediate ``|x|**p`` inside floating-point range, so small inputs no
    longer underflow to 0 and large inputs no longer overflow to inf when
    ``p`` is large.

    Args:
        x: Values whose norm is taken.
        weights: Weights passed through to ``weighted_mean``.
        p: Exponent of the norm.
        eps: Forwarded to ``weighted_mean`` as its denominator offset.

    Returns:
        Scalar array holding the weighted p-norm proxy.
    """
    # Function-scope import: the module itself only imports ``jax.numpy``.
    from jax import lax

    x = jnp.asarray(x)
    w = jnp.asarray(weights)
    mag = jnp.abs(x)

    # ``initial=0.0`` keeps empty inputs working (magnitudes are >= 0, so it
    # never changes a non-empty maximum).  The scale is a pure conditioning
    # factor: the expression is independent of it mathematically, so it is
    # held constant under differentiation to reproduce the gradient of the
    # unscaled formula.
    peak = lax.stop_gradient(jnp.max(mag, initial=0.0))
    # Avoid 0/0 when every entry of x is zero; the result is then 0 anyway.
    scale = jnp.where(peak > 0, peak, 1.0)

    mean = weighted_mean((mag / scale) ** p, w, eps=eps)
    return scale * mean ** (1.0 / p)
[docs]
def bn_over_B(B: jnp.ndarray, normals: jnp.ndarray, *, eps: float = 1e-30) -> jnp.ndarray:
    """Return the normal field component relative to the field magnitude.

    For each vector along the last axis this is ``(B · n) / (|B| + eps)``,
    where the dot product and the norm both reduce over ``axis=-1``.

    Args:
        B: Field vectors; the last axis holds the vector components.
        normals: Normal vectors multiplied element-wise with ``B``.
        eps: Constant added to ``|B|`` so a zero field does not divide by zero.
    """
    field = jnp.asarray(B)
    unit_dirs = jnp.asarray(normals)
    magnitude = jnp.linalg.norm(field, axis=-1)
    normal_component = jnp.sum(field * unit_dirs, axis=-1)
    return normal_component / (magnitude + eps)
[docs]
def bn_over_B_metrics(
    B: jnp.ndarray, normals: jnp.ndarray, weights: jnp.ndarray, *, eps: float = 1e-30
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Compute the B·n/|B| field together with two summary statistics.

    Args:
        B: Field vectors, forwarded to ``bn_over_B``.
        normals: Normal vectors, forwarded to ``bn_over_B``.
        weights: Weights used for the RMS only.
        eps: Offset forwarded to both ``bn_over_B`` and ``weighted_rms``.

    Returns:
        A tuple ``(ratio, rms, max_abs)``: the pointwise ratio, its weighted
        RMS, and the (unweighted) maximum of its absolute value.
    """
    normalized = bn_over_B(B, normals, eps=eps)
    peak = jnp.max(jnp.abs(normalized))
    root_mean_square = weighted_rms(normalized, weights, eps=eps)
    return normalized, root_mean_square, peak