from __future__ import annotations
import re
from dataclasses import dataclass
from pathlib import Path
import jax.numpy as jnp
import numpy as np
@dataclass(frozen=True)
class VMECBoundary:
    """Fourier representation of a VMEC boundary surface.

    VMEC convention (general form; for a stellarator-symmetric boundary the
    RBS and ZBC terms vanish, leaving only RBC and ZBS)::

        R(θ,φ) = Σ [RBC(n,m) cos(mθ - n*nfp*φ) + RBS(n,m) sin(mθ - n*nfp*φ)]
        Z(θ,φ) = Σ [ZBC(n,m) cos(mθ - n*nfp*φ) + ZBS(n,m) sin(mθ - n*nfp*φ)]

    Here φ is the full-torus toroidal angle on [0, 2π), so the nfp factor
    appears explicitly in the phase. Entry ``k`` of every per-mode array
    belongs to mode ``(m[k], n[k])``.

    Raises:
        ValueError: on construction, if ``nfp < 1`` or the six per-mode arrays
            are not all 1-D with the same length.
    """

    # NOTE(review): with frozen=True the generated __eq__/__hash__ operate on
    # the ndarray fields, so `==` between distinct instances holding
    # multi-element arrays raises ValueError and hash() raises TypeError.
    nfp: int  # number of field periods
    m: np.ndarray  # (Nmodes,) poloidal mode numbers
    n: np.ndarray  # (Nmodes,) toroidal mode numbers (multiplied by nfp in the phase)
    rbc: np.ndarray  # (Nmodes,) cosine coefficients of R
    rbs: np.ndarray  # (Nmodes,) sine coefficients of R
    zbc: np.ndarray  # (Nmodes,) cosine coefficients of Z
    zbs: np.ndarray  # (Nmodes,) sine coefficients of Z

    def __post_init__(self) -> None:
        # Validate only (no coercion), so callers may still store JAX arrays.
        if self.nfp < 1:
            raise ValueError(f"VMECBoundary.nfp must be >= 1, got {self.nfp!r}")
        shapes = {
            name: np.shape(getattr(self, name))
            for name in ("m", "n", "rbc", "rbs", "zbc", "zbs")
        }
        if len(set(shapes.values())) != 1 or len(shapes["m"]) != 1:
            raise ValueError(
                f"VMECBoundary per-mode arrays must be 1-D with equal lengths, got shapes {shapes}"
            )


# Fortran namelist names are case-insensitive, so both patterns ignore case.
# The look-behinds keep matches from starting inside a longer identifier.
_NFP_RE = re.compile(r"(?i)(?<![A-Za-z0-9_])NFP[ \t]*=[ \t]*([+-]?\d+)")
_COEF_RE = re.compile(
    r"(?<![A-Za-z0-9_])(?P<name>RBC|RBS|ZBC|ZBS)\(\s*(?P<n>[+-]?\d+)\s*,\s*(?P<m>[+-]?\d+)\s*\)\s*=\s*(?P<val>[+-]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eEdD][+-]?\d+)?)",
    re.IGNORECASE,
)


def _strip_fortran_comments(text: str) -> str:
    """Return *text* with Fortran ``!`` comments removed from every line.

    A ``!`` inside a single- or double-quoted string is kept, so a value such
    as ``MGRID_FILE = 'a!b'`` does not hide assignments later on that line.
    """
    kept = []
    for line in text.splitlines():
        quote = None
        cut = len(line)
        for i, ch in enumerate(line):
            if quote is not None:
                if ch == quote:
                    # Closing quote; a doubled quote ('') simply re-opens next char.
                    quote = None
            elif ch in "'\"":
                quote = ch
            elif ch == "!":
                cut = i
                break
        kept.append(line[:cut])
    return "\n".join(kept)


def read_vmec_boundary(path: str | Path) -> VMECBoundary:
    """Parse a VMEC input file and return the boundary Fourier coefficients.

    Recognises scalar assignments ``RBC(n,m) = value`` (likewise RBS/ZBC/ZBS)
    and ``NFP = value``, case-insensitively, ignoring ``!`` comments. Fortran
    ``D`` exponents (``1.0D-2``) are accepted. If a coefficient is assigned
    more than once, the later value wins. Modes are sorted by ``(m, n)``; a
    coefficient that is absent for a listed mode is stored as 0.0.

    Args:
        path: Path of the VMEC input file.

    Returns:
        The parsed ``VMECBoundary``.

    Raises:
        ValueError: if NFP is missing or < 1, or no RBC/RBS/ZBC/ZBS entry exists.
    """
    path = Path(path)
    # Only ASCII tokens are parsed, so undecodable bytes (e.g. in comments)
    # are replaced instead of depending on the locale's default encoding.
    text = _strip_fortran_comments(path.read_text(encoding="utf-8", errors="replace"))

    nfp_match = _NFP_RE.search(text)
    if nfp_match is None:
        raise ValueError(f"Could not find NFP in VMEC input: {path}")
    nfp = int(nfp_match.group(1))
    if nfp < 1:
        raise ValueError(f"NFP must be >= 1 in VMEC input {path}, got {nfp}")

    coeffs: dict[str, dict[tuple[int, int], float]] = {k: {} for k in ("RBC", "RBS", "ZBC", "ZBS")}
    for match in _COEF_RE.finditer(text):
        name = match.group("name").upper()  # IGNORECASE may yield e.g. "rbc"
        key = (int(match.group("n")), int(match.group("m")))
        coeffs[name][key] = float(match.group("val").upper().replace("D", "E"))

    keys = set().union(*coeffs.values())
    if not keys:
        raise ValueError(f"No boundary coefficients (RBC/RBS/ZBC/ZBS) found in: {path}")

    # Sort by (m, n) for a deterministic mode order.
    keys_sorted = sorted(keys, key=lambda nm: (nm[1], nm[0]))
    n_arr = np.array([n for (n, _m) in keys_sorted], dtype=int)
    m_arr = np.array([m for (_n, m) in keys_sorted], dtype=int)

    def get(name: str) -> np.ndarray:
        return np.array([coeffs[name].get(k, 0.0) for k in keys_sorted], dtype=float)

    return VMECBoundary(
        nfp=nfp,
        m=m_arr,
        n=n_arr,
        rbc=get("RBC"),
        rbs=get("RBS"),
        zbc=get("ZBC"),
        zbs=get("ZBS"),
    )
def vmec_boundary_RZ_and_derivatives(
    boundary: VMECBoundary, *, theta: jnp.ndarray, phi: jnp.ndarray
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Evaluate the boundary (R, Z) and its first angular derivatives on a grid.

    Args:
        boundary: Fourier coefficients of the surface.
        theta: Poloidal angles, shape (Nθ,).
        phi: Toroidal angles, shape (Nφ,), spanning [0, 2π) for the full torus.

    Returns:
        ``(R, Z, R_theta, R_phi, Z_theta, Z_phi)``, each of shape (Nθ, Nφ).
    """
    f64 = jnp.float64

    def per_mode(values):
        # Lift a per-mode 1-D array to (M, 1, 1) so it broadcasts over the grid.
        return jnp.asarray(values, dtype=f64)[:, None, None]

    # Broadcast layout: axis 0 = mode, axis 1 = theta, axis 2 = phi.
    th = jnp.asarray(theta, dtype=f64)[None, :, None]  # (1, Nθ, 1)
    ph = jnp.asarray(phi, dtype=f64)[None, None, :]  # (1, 1, Nφ)

    pol = per_mode(boundary.m)  # poloidal mode numbers m
    tor = per_mode(boundary.n) * float(boundary.nfp)  # full-torus wavenumber n * nfp

    phase = pol * th - tor * ph  # (M, Nθ, Nφ)
    cos_p, sin_p = jnp.cos(phase), jnp.sin(phase)

    rbc, rbs, zbc, zbs = (
        per_mode(coef) for coef in (boundary.rbc, boundary.rbs, boundary.zbc, boundary.zbs)
    )

    def mode_sum(terms):
        # Collapse the mode axis, leaving an (Nθ, Nφ) field.
        return jnp.sum(terms, axis=0)

    R = mode_sum(rbc * cos_p + rbs * sin_p)
    Z = mode_sum(zbc * cos_p + zbs * sin_p)

    # d/dθ cos(phase) = -m sin(phase);  d/dθ sin(phase) = +m cos(phase)
    R_theta = mode_sum((-pol) * rbc * sin_p + pol * rbs * cos_p)
    Z_theta = mode_sum((-pol) * zbc * sin_p + pol * zbs * cos_p)

    # d/dφ cos(phase) = +(n*nfp) sin(phase);  d/dφ sin(phase) = -(n*nfp) cos(phase)
    R_phi = mode_sum(tor * rbc * sin_p + (-tor) * rbs * cos_p)
    Z_phi = mode_sum(tor * zbc * sin_p + (-tor) * zbs * cos_p)

    return R, Z, R_theta, R_phi, Z_theta, Z_phi
def vmec_boundary_xyz_and_normals(
    boundary: VMECBoundary, *, theta: jnp.ndarray, phi: jnp.ndarray
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Return xyz and outward unit normal on the VMEC boundary.

    Args:
        boundary: Fourier coefficients of the surface.
        theta: Poloidal angles, shape (Nθ,).
        phi: Toroidal angles, shape (Nφ,); used directly as the geometric
            angle in x = R cos φ, y = R sin φ.

    Returns:
        ``(xyz, n_hat)``, each of shape (Nθ, Nφ, 3).

    The raw normal ``∂r/∂θ × ∂r/∂φ`` has in-plane part ``R·(-Z_θ, R_θ)``,
    the left normal of the (R, Z) cross-section. It therefore points inward
    when θ runs counter-clockwise in the (R, Z) plane (e.g. a circular torus
    with RBC(0,1) > 0 and ZBS(0,1) > 0). The orientation is detected once per
    boundary and the normal is flipped if needed, so ``n_hat`` is outward for
    either handedness.
    """
    R, Z, R_theta, R_phi, Z_theta, Z_phi = vmec_boundary_RZ_and_derivatives(boundary, theta=theta, phi=phi)

    phi2 = jnp.asarray(phi, dtype=jnp.float64)[None, :]  # (1, Nφ)
    c = jnp.cos(phi2)
    s = jnp.sin(phi2)
    xyz = jnp.stack([R * c, R * s, Z], axis=-1)

    # Cartesian tangents from cylindrical (R, φ, Z); the -R sin φ / +R cos φ
    # terms in r_phi come from the rotation of e_R with φ.
    r_theta = jnp.stack([R_theta * c, R_theta * s, Z_theta], axis=-1)
    r_phi = jnp.stack([R_phi * c - R * s, R_phi * s + R * c, Z_phi], axis=-1)
    normal = jnp.cross(r_theta, r_phi) * _outward_orientation_sign(boundary)
    # Small epsilon avoids division by zero at degenerate points.
    n_hat = normal / (jnp.linalg.norm(normal, axis=-1, keepdims=True) + 1e-30)
    return xyz, n_hat


def _outward_orientation_sign(boundary: VMECBoundary) -> jnp.ndarray:
    """Return +1.0 if ``r_θ × r_φ`` points outward for *boundary*, else -1.0.

    The sign comes from the signed area ∮ R dZ of the φ = 0 cross-section,
    which is positive for counter-clockwise θ (inward raw normal). The check
    uses its own θ grid, so it works for any caller-supplied evaluation grid.
    """
    # R * Z_theta has poloidal modes up to 2*m_max; more samples than that make
    # the periodic trapezoid rule exact.
    m_max = int(np.max(np.abs(np.asarray(boundary.m)), initial=0))
    n_probe = max(16, 4 * (m_max + 1))
    theta_probe = jnp.linspace(0.0, 2.0 * jnp.pi, n_probe, endpoint=False)
    R, _, _, _, Z_theta, _ = vmec_boundary_RZ_and_derivatives(
        boundary, theta=theta_probe, phi=jnp.zeros(1)
    )
    signed_area = jnp.mean(R[:, 0] * Z_theta[:, 0])  # = (1/2π) ∮ R dZ
    return jnp.where(signed_area > 0, -1.0, 1.0)