Source code for torus_solver.targets

from __future__ import annotations

from dataclasses import dataclass

import jax.numpy as jnp

from .vmec import VMECBoundary, vmec_boundary_RZ_and_derivatives


@dataclass(init=True, repr=True, eq=True, frozen=True)
class FitResult:
    """Parameters of the rigid fit that places an (R, Z) surface inside a circular torus.

    The fit maps ``R -> R0 + scale_rho * (R + shift_R - R0)`` and ``Z -> scale_rho * Z``
    (see ``apply_fit_to_RZ_and_derivatives``).
    """

    # Radial offset added to every R so the surface's mean R lands on the torus major radius.
    shift_R: float
    # Factor applied to the distance from the torus axis circle (1.0 means no scaling).
    scale_rho: float
    # Largest distance from the torus axis circle before scaling ("_m" presumably = metres).
    rho_max_before_m: float
@dataclass(frozen=True)
class TargetSurface:
    """A target surface sampled on a tensor-product (theta, phi) grid.

    Instances are built by ``vmec_target_surface`` and ``circular_torus_target_surface``.
    """

    theta: jnp.ndarray  # (Nθ,) poloidal angle grid
    phi: jnp.ndarray  # (Nφ,) toroidal angle grid
    xyz: jnp.ndarray  # (Nθ,Nφ,3) Cartesian points
    # (Nθ,Nφ,3) unit normals along r_theta x r_phi. NOTE: these are NOT guaranteed to be
    # outward. For the circular torus (R = R0 + a cos θ, Z = a sin θ) they point toward the
    # torus axis (inward). For VMEC boundaries the orientation follows the theta direction
    # of the boundary parametrisation, so confirm the sign convention callers expect.
    normals: jnp.ndarray
    # (Nθ,Nφ) Jacobian |r_theta x r_phi|. It is not multiplied by the grid spacing, so
    # multiply by (2π/Nθ)(2π/Nφ) to obtain area elements on the uniform grids used here.
    weights: jnp.ndarray
    fit: FitResult  # shift/scale applied to place the surface inside the torus
def fit_RZ_surface_into_torus(
    *,
    R: jnp.ndarray,
    Z: jnp.ndarray,
    torus_R0: float,
    torus_a: float,
    fit_margin: float,
) -> FitResult:
    """Compute the shift/scale that places an (R, Z) surface inside a circular torus.

    First, the surface is shifted radially so that ``mean(R)`` equals ``torus_R0``.
    Next, the distances ``rho`` from the torus axis circle ``(R, Z) = (torus_R0, 0)``
    are measured. If the largest exceeds ``fit_margin * torus_a``, ``scale_rho`` is
    chosen so that the scaled maximum equals that limit. Otherwise ``scale_rho = 1``.

    Args:
        R, Z: sampled cylindrical coordinates of the surface (broadcast-compatible arrays).
        torus_R0: major radius of the enclosing torus.
        torus_a: minor radius of the enclosing torus; must be positive.
        fit_margin: fraction of ``torus_a`` the surface may reach; must be positive.

    Returns:
        FitResult holding the radial shift, the scale factor, and the maximum ``rho``
        measured before scaling.

    Raises:
        ValueError: if ``torus_a`` or ``fit_margin`` is not positive. A non-positive
            target would yield a zero or negative ``scale_rho``, which collapses or
            mirrors the surface. Also raised if the surface contains non-finite
            values, which would otherwise silently skip scaling.
    """
    import math

    # ``not x > 0`` also rejects NaN.
    if not torus_a > 0.0:
        raise ValueError(f"torus_a must be positive, got {torus_a!r}")
    if not fit_margin > 0.0:
        raise ValueError(f"fit_margin must be positive, got {fit_margin!r}")

    # Shift so mean(R) matches torus_R0.
    R_mean = jnp.mean(R)
    shift_R = float(torus_R0 - float(R_mean))

    # Distance of each shifted point from the torus axis circle; the tiny epsilon keeps
    # sqrt differentiable at rho = 0.
    R_shift = R + shift_R
    dR = R_shift - float(torus_R0)
    rho = jnp.sqrt(dR * dR + Z * Z + 1e-30)
    rho_max = float(jnp.max(rho))
    if not math.isfinite(rho_max):
        # A NaN would make the comparison below False and silently return scale 1.0.
        raise ValueError(f"surface contains non-finite values (max rho = {rho_max!r})")

    # Only shrink the surface; never enlarge it.
    target = float(fit_margin * torus_a)
    scale_rho = 1.0
    if rho_max > target:
        scale_rho = float(target / rho_max)
    return FitResult(shift_R=shift_R, scale_rho=scale_rho, rho_max_before_m=rho_max)
def apply_fit_to_RZ_and_derivatives(
    *,
    R: jnp.ndarray,
    Z: jnp.ndarray,
    R_theta: jnp.ndarray,
    R_phi: jnp.ndarray,
    Z_theta: jnp.ndarray,
    Z_phi: jnp.ndarray,
    torus_R0: float,
    fit: FitResult,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Map (R, Z) and their first derivatives through a ``FitResult``.

    Positions transform as ``R -> R0 + s * (R + shift_R - R0)`` and ``Z -> s * Z`` with
    ``s = fit.scale_rho``. The shift is constant, so every first derivative is simply
    multiplied by ``s``.

    Returns:
        ``(R, Z, R_theta, R_phi, Z_theta, Z_phi)`` after the fit, in that order.
    """
    s = float(fit.scale_rho)
    axis_R = float(torus_R0)
    offset_from_axis = (R + float(fit.shift_R)) - axis_R
    scaled_derivatives = tuple(s * d for d in (R_theta, R_phi, Z_theta, Z_phi))
    return (axis_R + s * offset_from_axis, s * Z) + scaled_derivatives
def RZ_and_derivatives_to_xyz_normals_weights(
    *,
    R: jnp.ndarray,
    Z: jnp.ndarray,
    R_theta: jnp.ndarray,
    R_phi: jnp.ndarray,
    Z_theta: jnp.ndarray,
    Z_phi: jnp.ndarray,
    phi: jnp.ndarray,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Build Cartesian points, unit normals, and area weights from (R, Z) data.

    The surface is ``r(theta, phi) = (R cos phi, R sin phi, Z)``. ``phi`` is
    broadcast along the LAST axis of the (R, Z) arrays. In this module's callers those
    arrays are (Nθ, Nφ) and ``phi`` is (Nφ,).

    Args:
        R, Z: cylindrical coordinates of the surface points.
        R_theta, R_phi, Z_theta, Z_phi: first partial derivatives of R and Z.
        phi: toroidal angles; its dtype is kept (it is no longer forced to float64).

    Returns:
        xyz: (..., 3) Cartesian points.
        normals: (..., 3) unit vectors along ``r_theta x r_phi``. Their orientation
            follows the theta direction. For ``R = R0 + a cos θ, Z = a sin θ`` they
            point toward the torus axis (inward).
        weights: Jacobian ``|r_theta x r_phi|`` (not multiplied by dθ dφ).
    """
    # Use phi's own dtype. Forcing float64 ignored the caller's dtype: a float32 grid was
    # silently promoted to float64 with x64 enabled, and JAX warned about truncation with
    # x64 disabled.
    phi_row = jnp.asarray(phi)[None, :]
    cos_phi = jnp.cos(phi_row)
    sin_phi = jnp.sin(phi_row)

    xyz = jnp.stack([R * cos_phi, R * sin_phi, Z], axis=-1)

    # Tangent vectors of r(theta, phi) = (R cos phi, R sin phi, Z).
    r_theta = jnp.stack([R_theta * cos_phi, R_theta * sin_phi, Z_theta], axis=-1)
    r_phi = jnp.stack(
        [R_phi * cos_phi - R * sin_phi, R_phi * sin_phi + R * cos_phi, Z_phi], axis=-1
    )

    n_vec = jnp.cross(r_theta, r_phi)
    n_norm = jnp.linalg.norm(n_vec, axis=-1)
    # Tiny epsilon keeps degenerate points (zero Jacobian) at n_hat = 0 instead of NaN.
    n_hat = n_vec / (n_norm[..., None] + 1e-30)
    return xyz, n_hat, n_norm
def vmec_target_surface(
    boundary: VMECBoundary,
    *,
    torus_R0: float,
    torus_a: float,
    fit_margin: float,
    n_theta: int,
    n_phi: int,
    dtype=jnp.float64,
) -> TargetSurface:
    """Evaluate a VMEC boundary on a uniform grid, fit it into the torus, and package it.

    The function runs four steps:

    1. Sample (R, Z) and their first derivatives on uniform theta/phi grids over
       [0, 2π) (``endpoint=False``, ``dtype`` as given).
    2. Compute a shift/scale with ``fit_RZ_surface_into_torus``, which centres the
       cross-section on ``torus_R0`` and shrinks it if it reaches beyond
       ``fit_margin * torus_a``.
    3. Apply that fit to the positions and derivatives.
    4. Convert the result to Cartesian points, unit normals, and area weights.
    """
    grid = {
        axis: jnp.linspace(0.0, 2 * jnp.pi, int(count), endpoint=False, dtype=dtype)
        for axis, count in (("theta", n_theta), ("phi", n_phi))
    }
    R, Z, R_t, R_p, Z_t, Z_p = vmec_boundary_RZ_and_derivatives(boundary, **grid)

    fit = fit_RZ_surface_into_torus(
        R=R, Z=Z, torus_R0=torus_R0, torus_a=torus_a, fit_margin=fit_margin
    )
    # apply_fit_to_RZ_and_derivatives returns its six arrays in exactly the keyword order
    # RZ_and_derivatives_to_xyz_normals_weights expects, so pair them up by name.
    names = ("R", "Z", "R_theta", "R_phi", "Z_theta", "Z_phi")
    fitted = dict(
        zip(
            names,
            apply_fit_to_RZ_and_derivatives(
                R=R,
                Z=Z,
                R_theta=R_t,
                R_phi=R_p,
                Z_theta=Z_t,
                Z_phi=Z_p,
                torus_R0=torus_R0,
                fit=fit,
            ),
        )
    )
    xyz, normals, weights = RZ_and_derivatives_to_xyz_normals_weights(phi=grid["phi"], **fitted)
    return TargetSurface(
        theta=grid["theta"], phi=grid["phi"], xyz=xyz, normals=normals, weights=weights, fit=fit
    )
def circular_torus_target_surface(
    *,
    R0: float,
    a: float,
    n_theta: int,
    n_phi: int,
    dtype=jnp.float64,
) -> TargetSurface:
    """Sample the circular torus R = R0 + a cos(theta), Z = a sin(theta) on a (theta, phi) grid.

    This helper serves two purposes:

    - Benchmarking and validation: for a purely toroidal field, axisymmetry gives
      B·n = 0 on any interior torus.
    - GUI work and optimisation: it provides a simple reference surface inside the
      winding surface.

    Both grids are uniform on [0, 2π) with ``endpoint=False``. Normals are unit vectors
    along r_theta x r_phi, which for this parametrisation point toward the torus axis
    (inward). No fitting is performed, so ``fit`` is the identity (shift 0, scale 1)
    with ``rho_max_before_m = a``.
    """
    count_theta, count_phi = int(n_theta), int(n_phi)
    major, minor = float(R0), float(a)
    full_turn = 2 * jnp.pi

    theta = jnp.linspace(0.0, full_turn, count_theta, endpoint=False, dtype=dtype)
    phi = jnp.linspace(0.0, full_turn, count_phi, endpoint=False, dtype=dtype)

    # The cross-section does not depend on phi: evaluate each quantity once per theta
    # as a (Nθ, 1) column, then broadcast it along the toroidal direction.
    grid_shape = (count_theta, count_phi)
    col = theta[:, None]
    cos_t = jnp.cos(col)
    sin_t = jnp.sin(col)

    def spread(column):
        return jnp.broadcast_to(column, grid_shape)

    no_phi_dependence = jnp.zeros(grid_shape, dtype=dtype)
    xyz, normals, weights = RZ_and_derivatives_to_xyz_normals_weights(
        R=spread(major + minor * cos_t),
        Z=spread(minor * sin_t),
        R_theta=spread(-minor * sin_t),
        R_phi=no_phi_dependence,
        Z_theta=spread(minor * cos_t),
        Z_phi=no_phi_dependence,
        phi=phi,
    )
    identity_fit = FitResult(shift_R=0.0, scale_rho=1.0, rho_max_before_m=minor)
    return TargetSurface(
        theta=theta, phi=phi, xyz=xyz, normals=normals, weights=weights, fit=identity_fit
    )