from __future__ import annotations
from dataclasses import dataclass
from typing import Callable, Dict, Tuple
import jax
import jax.numpy as jnp
import optax
from jaxopt import LBFGS
from .biot_savart import biot_savart_surface
from .poisson import solve_current_potential, surface_current_from_potential
from .sources import deposit_current_sources
from .torus import TorusSurface
[docs]
@jax.tree_util.register_pytree_node_class
@dataclass(frozen=True)
class SourceParams:
    """Unconstrained electrode parameters, usable directly as a JAX pytree.

    All three fields are leaves, one entry per electrode. ``currents_raw``
    is not yet projected onto the zero-sum subspace; downstream code applies
    ``enforce_net_zero`` to it before use.
    """

    theta_src: jnp.ndarray  # (Ns,) electrode theta coordinates on the torus
    phi_src: jnp.ndarray  # (Ns,) electrode phi coordinates on the torus
    currents_raw: jnp.ndarray  # (Ns,) raw injected currents (before net-zero projection)

    def tree_flatten(self):
        """Return ``(leaves, aux_data)``; this node carries no static aux data."""
        leaves = (self.theta_src, self.phi_src, self.currents_raw)
        return leaves, None

    @classmethod
    def tree_unflatten(cls, aux, children):
        """Rebuild an instance from the leaves of :meth:`tree_flatten` (``aux`` is ignored)."""
        theta, phi, raw = children
        return cls(theta_src=theta, phi_src=phi, currents_raw=raw)
[docs]
def enforce_net_zero(currents_raw: jnp.ndarray) -> jnp.ndarray:
    """Subtract the mean so that the returned currents sum to zero.

    This is the orthogonal projection of ``currents_raw`` onto the subspace
    ``Σ I_i = 0``. It enforces charge conservation for the injected currents.
    """
    mean_current = jnp.mean(currents_raw)
    return currents_raw - mean_current
[docs]
def forward_B(
    surface: TorusSurface,
    params: SourceParams,
    *,
    eval_points: jnp.ndarray,
    sigma_theta: float,
    sigma_phi: float,
    sigma_s: float = 1.0,
    current_scale: float = 1.0,
    tol: float = 1e-10,
    maxiter: int = 2_000,
    biot_savart_eps: float = 1e-9,
    use_preconditioner: bool = False,
) -> jnp.ndarray:
    """Compute the magnetic field B at ``eval_points`` produced by the electrodes.

    First solves the surface electrode problem with :func:`surface_solution`
    to obtain the surface current K. Then evaluates the field of K with
    ``biot_savart_surface``.

    Args:
        surface: torus surface carrying the current.
        params: electrode positions and raw (pre net-zero) currents.
        eval_points: points where B is evaluated (presumably shape (N, 3)
            Cartesian coordinates — TODO confirm against biot_savart_surface).
        sigma_theta, sigma_phi: forwarded to the source deposition
            (presumably the source widths in theta/phi).
        sigma_s: surface conductivity of the electrode model.
        current_scale: multiplier applied to the net-zero currents.
        tol, maxiter: settings forwarded to the surface potential solver.
        biot_savart_eps: forwarded as ``eps`` to ``biot_savart_surface``.
        use_preconditioner: forwarded to the surface potential solver. The
            default ``False`` matches what ``surface_solution`` uses when the
            flag is omitted.

    Returns:
        B at every evaluation point, as returned by ``biot_savart_surface``.

    Raises:
        ValueError: propagated from :func:`surface_solution` when it rejects
            ``sigma_s`` (e.g. ``sigma_s <= 0``).
    """
    # Only the surface current K is needed for the field; the other solution
    # pieces (currents, source density, potential) are discarded.
    *_, K = surface_solution(
        surface,
        params,
        sigma_theta=sigma_theta,
        sigma_phi=sigma_phi,
        sigma_s=sigma_s,
        current_scale=current_scale,
        tol=tol,
        maxiter=maxiter,
        use_preconditioner=use_preconditioner,
    )
    return biot_savart_surface(surface, K, eval_points, eps=biot_savart_eps)
[docs]
def surface_solution(
    surface: TorusSurface,
    params: SourceParams,
    *,
    sigma_theta: float,
    sigma_phi: float,
    sigma_s: float = 1.0,
    current_scale: float = 1.0,
    tol: float = 1e-10,
    maxiter: int = 2_000,
    use_preconditioner: bool = False,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Solve the surface electrode model for the given electrode parameters.

    Steps:

    1. Project the raw currents to zero net current and scale them.
    2. Deposit them as a source density ``s`` on the surface.
    3. Solve ``-σ_s Δ_s V = s`` for the potential ``V``.
    4. Form the surface current ``K = -σ_s ∇_s V``.

    Args:
        surface: torus surface on which the problem is solved.
        params: electrode positions and raw currents.
        sigma_theta, sigma_phi: forwarded to ``deposit_current_sources``
            (presumably the source widths in theta/phi).
        sigma_s: surface conductivity. Must be a finite positive number.
        current_scale: multiplier applied after the net-zero projection.
        tol, maxiter, use_preconditioner: forwarded to
            ``solve_current_potential``.

    Returns:
        ``(currents, source_density, potential, surface_current)``.

    Raises:
        ValueError: if ``sigma_s`` is not finite and strictly positive.
    """
    # Validate before doing any deposition/solve work. The chained comparison
    # is False for NaN (every comparison with NaN is False), so NaN is rejected.
    # A plain `sigma_s <= 0` test would let NaN through. Infinity is rejected
    # too, since s / inf == 0 and -inf * grad(0) would presumably give NaN
    # currents.
    sigma_s_eff = float(sigma_s)
    if not 0.0 < sigma_s_eff < float("inf"):
        raise ValueError("sigma_s must be finite and > 0 for the electrode model.")

    currents = current_scale * enforce_net_zero(params.currents_raw)
    s = deposit_current_sources(
        surface,
        theta_src=params.theta_src,
        phi_src=params.phi_src,
        currents=currents,
        sigma_theta=sigma_theta,
        sigma_phi=sigma_phi,
    )
    # Electrode model: -σ_s Δ_s V = s, K = -σ_s ∇_s V.
    # For uniform σ_s, the resulting K is independent of σ_s (only V rescales),
    # but we keep σ_s here to match the documented equations and allow future
    # extensions (e.g. nonuniform conductivity).
    V, _ = solve_current_potential(
        surface,
        s / sigma_s_eff,
        tol=tol,
        maxiter=maxiter,
        use_preconditioner=use_preconditioner,
    )
    K = surface_current_from_potential(surface, V, sigma_s=sigma_s)
    return currents, s, V, K
[docs]
def make_helical_axis_points(
    *,
    R_axis: float,
    n_points: int,
    dtype=jnp.float64,
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Sample the torus major circle (radius ``R_axis``, plane z = 0) plus a local frame.

    Despite the function name, the points lie on a planar circle, not a helix.

    Args:
        R_axis: radius of the sampled circle.
        n_points: number of equally spaced samples of φ in [0, 2π) (the
            endpoint 2π is excluded).
        dtype: dtype of the angles and basis vectors. Note that JAX downcasts
            float64 to float32 unless x64 mode is enabled.

    Returns:
        ``(phi, points, e_r, e_phi, e_z)``. ``phi`` has shape (n_points,).
        The other four have shape (n_points, 3): the circle points and the
        cylindrical unit vectors e_r, e_phi, e_z at each point.
    """
    phi = jnp.linspace(0.0, 2 * jnp.pi, n_points, endpoint=False, dtype=dtype)
    cos_phi = jnp.cos(phi)
    sin_phi = jnp.sin(phi)
    zeros = jnp.zeros_like(phi)
    points = jnp.stack([R_axis * cos_phi, R_axis * sin_phi, zeros], axis=-1)
    e_r = jnp.stack([cos_phi, sin_phi, zeros], axis=-1)
    e_phi = jnp.stack([-sin_phi, cos_phi, zeros], axis=-1)
    e_z = jnp.tile(jnp.array([0.0, 0.0, 1.0], dtype=dtype), (n_points, 1))
    return phi, points, e_r, e_phi, e_z
[docs]
def optimize_sources(
    surface: TorusSurface,
    *,
    init: SourceParams,
    eval_points: jnp.ndarray,
    B_target: jnp.ndarray,
    B_scale: float = 1.0,
    sigma_theta: float,
    sigma_phi: float,
    n_steps: int,
    lr: float,
    reg_currents: float = 1e-6,
    reg_positions: float = 0.0,
    callback: Callable[[int, Dict[str, float]], None] | None = None,
    return_history: bool = False,
) -> SourceParams | tuple[SourceParams, Dict[str, list[float]]]:
    """Optimize electrode params with Adam to match a target B field at given points.

    The minimized loss is::

        mean_over_points(|(B - B_target) / B_scale|^2)
        + reg_currents * mean(I^2)
        + reg_positions * (mean(theta_src^2) + mean(phi_src^2))

    Here ``I = enforce_net_zero(currents_raw)``. The position penalty acts on
    the raw angle values (no wrapping) and is only added when
    ``reg_positions != 0``.

    Args:
        surface: torus surface carrying the electrode currents.
        init: starting electrode parameters.
        eval_points, B_target: evaluation points and target field there.
        B_scale: normalization applied to the field error.
        sigma_theta, sigma_phi: forwarded to ``forward_B``.
        n_steps: number of Adam steps.
        lr: Adam learning rate.
        reg_currents, reg_positions: regularization weights (see above).
        callback: optional ``callback(step_index, metrics)``. ``metrics`` maps
            "loss", "loss_B", "loss_reg" and "currents_rms" to Python floats
            evaluated at the parameters *before* that step's update.
        return_history: if True, also return per-step lists of the same
            metrics.

    Returns:
        The optimized ``SourceParams``, or ``(params, history)`` when
        ``return_history`` is True.
    """
    metric_names = ("loss", "loss_B", "loss_reg", "currents_rms")

    def loss_fn(p: SourceParams) -> tuple[jnp.ndarray, Dict[str, jnp.ndarray]]:
        B = forward_B(
            surface,
            p,
            eval_points=eval_points,
            sigma_theta=sigma_theta,
            sigma_phi=sigma_phi,
        )
        err = (B - B_target) / B_scale
        loss_B = jnp.mean(jnp.sum(err * err, axis=-1))
        currents = enforce_net_zero(p.currents_raw)
        currents_rms = jnp.sqrt(jnp.mean(currents * currents))
        loss_reg = reg_currents * jnp.mean(currents * currents)
        if reg_positions != 0.0:
            loss_reg = loss_reg + reg_positions * (
                jnp.mean(p.theta_src * p.theta_src) + jnp.mean(p.phi_src * p.phi_src)
            )
        loss = loss_B + loss_reg
        aux = {"loss": loss, "loss_B": loss_B, "loss_reg": loss_reg, "currents_rms": currents_rms}
        return loss, aux

    opt = optax.adam(lr)

    @jax.jit
    def step(p, opt_state):
        # The scalar loss is also present in aux["loss"]; only the gradient is needed here.
        (_, aux), grads = jax.value_and_grad(loss_fn, has_aux=True)(p)
        updates, opt_state = opt.update(grads, opt_state, p)
        return optax.apply_updates(p, updates), opt_state, aux

    params = init
    state = opt.init(init)
    history: Dict[str, list[float]] = {name: [] for name in metric_names}
    for k in range(n_steps):
        params, state, aux = step(params, state)
        if callback is None and not return_history:
            # Nobody consumes the metrics: skip the host transfer so steps can
            # keep dispatching asynchronously.
            continue
        # One batched device->host transfer instead of one blocking float() per metric.
        host_aux = jax.device_get(aux)
        metrics = {name: float(host_aux[name]) for name in metric_names}
        if return_history:
            for name in metric_names:
                history[name].append(metrics[name])
        if callback is not None:
            callback(k, metrics)
    if return_history:
        return params, history
    return params
[docs]
def optimize_sources_lbfgs(
    surface: TorusSurface,
    *,
    init: SourceParams,
    eval_points: jnp.ndarray,
    B_target: jnp.ndarray,
    B_scale: float = 1.0,
    sigma_theta: float,
    sigma_phi: float,
    maxiter: int,
    tol: float = 1e-9,
    reg_currents: float = 1e-6,
    reg_positions: float = 0.0,
    callback: Callable[[int, Dict[str, float]], None] | None = None,
    return_history: bool = False,
) -> SourceParams | tuple[SourceParams, Dict[str, list[float]]]:
    """Optimize electrode params using L-BFGS (often faster than Adam for small problems).

    The minimized loss is::

        mean_over_points(|(B - B_target) / B_scale|^2)
        + reg_currents * mean(I^2)
        + reg_positions * (mean(theta_src^2) + mean(phi_src^2))

    Here ``I = enforce_net_zero(currents_raw)``. The position penalty acts on
    the raw angle values (no wrapping) and is only added when
    ``reg_positions != 0``.

    Iterations are driven manually so the callback and history can be
    recorded. The loop stops early once the solver's ``state.error`` drops to
    ``tol`` or below, and otherwise runs ``maxiter`` updates.

    Args:
        surface: torus surface carrying the electrode currents.
        init: starting electrode parameters.
        eval_points, B_target: evaluation points and target field there.
        B_scale: normalization applied to the field error.
        sigma_theta, sigma_phi: forwarded to ``forward_B``.
        maxiter: maximum number of L-BFGS updates.
        tol: stopping tolerance on ``state.error``.
        reg_currents, reg_positions: regularization weights (see above).
        callback: optional ``callback(step_index, metrics)``. ``metrics`` maps
            "loss", "loss_B", "loss_reg" and "currents_rms" to Python floats
            taken from the solver state's aux after that update.
        return_history: if True, also return per-step lists of the same
            metrics.

    Returns:
        The optimized ``SourceParams``, or ``(params, history)`` when
        ``return_history`` is True.
    """
    metric_names = ("loss", "loss_B", "loss_reg", "currents_rms")

    def loss_fn(p: SourceParams) -> tuple[jnp.ndarray, Dict[str, jnp.ndarray]]:
        B = forward_B(
            surface,
            p,
            eval_points=eval_points,
            sigma_theta=sigma_theta,
            sigma_phi=sigma_phi,
        )
        err = (B - B_target) / B_scale
        loss_B = jnp.mean(jnp.sum(err * err, axis=-1))
        currents = enforce_net_zero(p.currents_raw)
        currents_rms = jnp.sqrt(jnp.mean(currents * currents))
        loss_reg = reg_currents * jnp.mean(currents * currents)
        if reg_positions != 0.0:
            loss_reg = loss_reg + reg_positions * (
                jnp.mean(p.theta_src * p.theta_src) + jnp.mean(p.phi_src * p.phi_src)
            )
        loss = loss_B + loss_reg
        aux = {"loss": loss, "loss_B": loss_B, "loss_reg": loss_reg, "currents_rms": currents_rms}
        return loss, aux

    n_iter = int(maxiter)
    tol_f = float(tol)
    solver = LBFGS(fun=loss_fn, has_aux=True, maxiter=n_iter, tol=tol_f, jit=True)
    params = init
    state = solver.init_state(params)
    history: Dict[str, list[float]] = {name: [] for name in metric_names}
    for k in range(n_iter):
        params, state = solver.update(params, state)
        if callback is not None or return_history:
            # One batched device->host transfer instead of one blocking float() per metric.
            host_aux = jax.device_get(state.aux)
            metrics = {name: float(host_aux[name]) for name in metric_names}
            if return_history:
                for name in metric_names:
                    history[name].append(metrics[name])
            if callback is not None:
                callback(k, metrics)
        if float(state.error) <= tol_f:
            break
    if return_history:
        return params, history
    return params