from __future__ import annotations
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Literal
import numpy as np
import jax
import jax.numpy as jnp
import optax
from .biot_savart import MU0, biot_savart_surface
from .fieldline import trace_field_lines_batch
from .fields import tokamak_like_field
from .optimize import SourceParams
from .poisson import solve_current_potential, surface_current_from_potential
from .sources import deposit_current_sources
from .targets import vmec_target_surface
from .torus import TorusSurface, make_torus_surface
from .vmec import read_vmec_boundary
try:
import vtk # type: ignore
from vtk.util.numpy_support import numpy_to_vtk # type: ignore
except Exception as e: # pragma: no cover
vtk = None
numpy_to_vtk = None
_vtk_import_error = e
ScalarName = Literal["|K|", "V", "s", "K_theta", "K_phi"]
CutScalarName = Literal["|K|", "V", "K_theta", "K_phi"]
# Avoid hard runtime dependence on VTK in type hints (important for docs builds).
VtkPolyData = Any
VtkActor = Any
if vtk is not None: # pragma: no cover
VtkPolyData = vtk.vtkPolyData
VtkActor = vtk.vtkActor
[docs]
@dataclass(frozen=True)
class GUIConfig:
# Geometry + numerics (keep small for interactivity).
R0: float = 3.0
a: float = 1.0
n_theta: int = 32
n_phi: int = 32
sigma_theta: float = 0.25
sigma_phi: float = 0.25
sigma_s: float = 1.0
cg_tol: float = 1e-8
cg_maxiter: int = 800
# Electrodes
n_electrodes_max: int = 32
current_default_A: float = 1000.0
current_slider_max_A: float = 3000.0
# Biot–Savart / tracing
biot_savart_eps: float = 1e-8
n_fieldlines: int = 12
fieldline_steps: int = 500
fieldline_step_size_m: float = 0.03
# Optional background field for field-line tracing / visualization:
# ideal toroidal B ~ 1/R with magnitude Bext0 at R=R0.
Bext0: float = 1e-4 # Tesla at R=R0 (0 disables)
# Optional background poloidal field (tokamak-like proxy) for field-line tracing:
# B_pol = Bpol0 * (R0/R) e_theta.
Bpol0: float = 0.0 # Tesla at R=R0 (0 disables)
bg_field_default_on: bool = False
bg_poloidal_default_on: bool = False
# Rendering
surface_opacity: float = 0.35
window_size: tuple[int, int] = (1200, 820)
[docs]
@dataclass(frozen=True)
class CutGUIConfig:
# Geometry + numerics (keep small for interactivity).
R0: float = 3.0
a: float = 0.3
n_theta: int = 32
n_phi: int = 32
theta_cut: float = float(np.pi) # where the cut/jump appears in the plotted V
sigma_s: float = 1.0
# "Battery": voltage drop across the cut (signed).
V_cut_default: float = 1.0
V_cut_slider_max: float = 5.0
# Optional extra electrodes (current sources/sinks) on top of the cut-driven current.
sigma_theta: float = 0.25
sigma_phi: float = 0.25
cg_tol: float = 1e-8
cg_maxiter: int = 800
n_electrodes_max: int = 32
current_default_A: float = 1000.0
current_slider_max_A: float = 3000.0
# Biot–Savart / tracing
biot_savart_eps: float = 1e-8
n_fieldlines: int = 12
fieldline_steps: int = 500
fieldline_step_size_m: float = 0.03
# Optional background field for field-line tracing / visualization:
# ideal toroidal B ~ 1/R with magnitude Bext0 at R=R0.
Bext0: float = 1e-4 # Tesla at R=R0 (0 disables)
# Optional background poloidal field (tokamak-like proxy) for field-line tracing:
# B_pol = Bpol0 * (R0/R) e_theta.
Bpol0: float = 0.0 # Tesla at R=R0 (0 disables)
bg_field_default_on: bool = False
bg_poloidal_default_on: bool = False
# Rendering
surface_opacity: float = 0.35
window_size: tuple[int, int] = (1200, 820)
[docs]
@dataclass(frozen=True)
class VmecOptGUIConfig:
# VMEC target surface (boundary) inside the circular torus.
vmec_input: str = "examples/data/vmec/input.QA_nfp2"
surf_n_theta: int = 32
surf_n_phi: int = 56
fit_margin: float = 0.7
# Winding surface (circular torus).
R0: float = 1.0
a: float = 0.3
n_theta: int = 32
n_phi: int = 32
# Background field (ideal toroidal 1/R).
B0: float = 1.0 # Tesla at R=R0
# Optional background poloidal field (tokamak-like proxy) at R=R0.
Bpol0: float = 0.0 # Tesla at R=R0
# For diagnostics: allow tracing field lines with/without the background field.
trace_include_bg_default_on: bool = True
# Electrodes (fixed size arrays for interactivity).
n_electrodes_max: int = 64
n_electrodes_init: int = 32
current_default_A: float = 1e6
current_slider_max_A: float = 1e7
# Electrode model.
sigma_theta: float = 0.25
sigma_phi: float = 0.25
sigma_s: float = 1.0
# Scaling: I_phys = current_scale * currents_raw (then projected to net-zero over active electrodes).
current_scale: float | None = None
init_current_raw_rms: float = 1.0
# Optimization controls.
lr: float = 1e-2
reg_currents: float = 1e-3 # weight on ⟨(I/I_scale)^2⟩
optimize_positions: bool = True
steps_per_opt: int = 25
bn_p: int = 8 # use p-norm objective to reduce max|Bn/B| (p>=2; p→∞ approximates max)
# Poisson solve.
cg_tol: float = 1e-10
cg_maxiter: int = 2000
use_preconditioner: bool = False
# Biot–Savart / tracing.
biot_savart_eps: float = 1e-8
n_fieldlines: int = 12
fieldline_steps: int = 500
fieldline_step_size_m: float = 0.03
# Rendering.
surface_opacity: float = 0.35
target_opacity: float = 0.25
window_size: tuple[int, int] = (1300, 880)
def _require_vtk() -> None: # pragma: no cover
if vtk is None:
raise ImportError(
"VTK is required for the interactive GUI. "
"Install with `pip install vtk` (or `pip install .[gui]` if you add extras)."
) from _vtk_import_error
def _setup_text_actor(
actor, *, x: float, y: float, font_size: int = 16, top: bool = True
) -> None: # pragma: no cover
"""Configure a vtkTextActor for robust visibility on HiDPI / resized windows.
We anchor text using normalized viewport coordinates so it stays on-screen even
when the window is resized or when Cocoa/Retina scaling is active.
"""
prop = actor.GetTextProperty()
prop.SetFontSize(int(font_size))
prop.SetColor(0.0, 0.0, 0.0)
prop.SetJustificationToLeft()
if top:
prop.SetVerticalJustificationToTop()
else:
prop.SetVerticalJustificationToBottom()
# Normalized viewport coordinates: (0,0)=bottom-left, (1,1)=top-right.
actor.GetPositionCoordinate().SetCoordinateSystemToNormalizedViewport()
actor.SetPosition(float(x), float(y))
# Keep pixel-based font size rather than scaling with viewport (if supported).
try:
actor.SetTextScaleModeToNone()
except Exception:
pass
def _resolve_vmec_input_path(vmec_input: str) -> Path:
p = Path(vmec_input)
if p.exists():
return p
repo_root = Path(__file__).resolve().parents[2]
candidates = [
repo_root / vmec_input, # allow paths relative to repo root
repo_root / "examples" / vmec_input,
repo_root / "examples" / "data" / "vmec" / Path(vmec_input).name,
]
for cand in candidates:
if cand.exists():
return cand
msg = (
"VMEC input file not found.\n"
f" got: {vmec_input}\n"
f" tried: {p.resolve()}\n"
+ "".join(f" tried: {c}\n" for c in candidates)
+ "Tip: run with `--vmec-input examples/data/vmec/input.QA_nfp2` (from the repo root).\n"
)
raise FileNotFoundError(msg)
[docs]
def torus_xyz(R0: float, a: float, theta: np.ndarray, phi: np.ndarray) -> np.ndarray:
"""Map (theta,phi) -> xyz on the circular torus."""
theta = np.asarray(theta, dtype=float)
phi = np.asarray(phi, dtype=float)
R = R0 + a * np.cos(theta)
x = R * np.cos(phi)
y = R * np.sin(phi)
z = a * np.sin(theta)
return np.stack([x, y, z], axis=-1)
[docs]
def torus_angles_from_point(R0: float, p: np.ndarray) -> tuple[float, float]:
"""Map xyz -> (theta,phi) for points on/near the torus surface."""
x, y, z = map(float, p)
phi = np.arctan2(y, x) % (2 * np.pi)
R = np.sqrt(x * x + y * y)
theta = np.arctan2(z, R - R0) % (2 * np.pi)
return theta, phi
def _wrap_angle_np(x: np.ndarray) -> np.ndarray:
return (x + np.pi) % (2 * np.pi) - np.pi
def _cut_phase_theta_np(*, theta: np.ndarray, R0: float, a: float, theta_cut: float) -> np.ndarray:
"""Phase f(θ) in [0,1) with a jump at theta_cut (for plotting a cut potential)."""
theta = np.asarray(theta, dtype=float)
theta_cut = float(theta_cut) % (2 * np.pi)
diff = _wrap_angle_np(theta - theta_cut)
k0 = int(np.argmin(np.abs(diff)))
R = R0 + a * np.cos(theta)
q = 1.0 / R
q_roll = np.roll(q, -k0)
dtheta = float(2 * np.pi / theta.size)
total = float(np.sum(q_roll) * dtheta)
cum = np.concatenate([[0.0], np.cumsum(q_roll[:-1])]) * dtheta
f_roll = cum / total
return np.roll(f_roll, k0)
def _build_torus_polydata(xyz: np.ndarray) -> VtkPolyData: # pragma: no cover
_require_vtk()
n_theta, n_phi, _ = xyz.shape
pts = xyz.reshape((-1, 3))
poly = vtk.vtkPolyData()
vtk_points = vtk.vtkPoints()
vtk_points.SetData(numpy_to_vtk(pts, deep=True))
poly.SetPoints(vtk_points)
cells = vtk.vtkCellArray()
for i in range(n_theta):
i2 = (i + 1) % n_theta
for j in range(n_phi):
j2 = (j + 1) % n_phi
ids = [
i * n_phi + j,
i2 * n_phi + j,
i2 * n_phi + j2,
i * n_phi + j2,
]
quad = vtk.vtkQuad()
for k, pid in enumerate(ids):
quad.GetPointIds().SetId(k, int(pid))
cells.InsertNextCell(quad)
poly.SetPolys(cells)
return poly
def _build_fieldlines_polydata(n_lines: int, n_pts: int) -> VtkPolyData: # pragma: no cover
_require_vtk()
poly = vtk.vtkPolyData()
points = vtk.vtkPoints()
points.SetNumberOfPoints(n_lines * n_pts)
poly.SetPoints(points)
lines = vtk.vtkCellArray()
for i in range(n_lines):
pl = vtk.vtkPolyLine()
pl.GetPointIds().SetNumberOfIds(n_pts)
base = i * n_pts
for k in range(n_pts):
pl.GetPointIds().SetId(k, base + k)
lines.InsertNextCell(pl)
poly.SetLines(lines)
return poly
[docs]
class TorusElectrodeGUI: # pragma: no cover
"""VTK GUI: click to add/move electrodes; slider to change current; see field lines."""
def __init__(self, cfg: GUIConfig, *, initial_electrodes: dict | None = None):
_require_vtk()
jax.config.update("jax_enable_x64", True)
self.cfg = cfg
self.surface: TorusSurface = make_torus_surface(
R0=cfg.R0, a=cfg.a, n_theta=cfg.n_theta, n_phi=cfg.n_phi
)
# Electrode state (fixed size to avoid recompilation when user adds/removes).
self.N = cfg.n_electrodes_max
self.theta_src = np.zeros((self.N,), dtype=float)
self.phi_src = np.zeros((self.N,), dtype=float)
self.currents_raw = np.zeros((self.N,), dtype=float)
self.active = np.zeros((self.N,), dtype=float)
if initial_electrodes is not None:
th = np.asarray(initial_electrodes.get("theta", []), dtype=float)
ph = np.asarray(initial_electrodes.get("phi", []), dtype=float)
I = np.asarray(initial_electrodes.get("I", []), dtype=float)
n0 = min(self.N, th.size, ph.size, I.size)
self.theta_src[:n0] = th[:n0]
self.phi_src[:n0] = ph[:n0]
self.currents_raw[:n0] = I[:n0]
self.active[:n0] = 1.0
self.selected: int | None = int(np.argmax(self.active)) if np.any(self.active) else None
self.mode: Literal["none", "add_source", "add_sink", "move"] = "none"
# Precompute unit vectors for K decomposition.
self._e_theta = self.surface.r_theta / self.surface.a
self._e_phi = self.surface.r_phi / jnp.sqrt(self.surface.G)[..., None]
# Field line seeds (inside the torus).
theta_seed = jnp.linspace(0.0, 2 * jnp.pi, cfg.n_fieldlines, endpoint=False)
rho = 0.5 * cfg.a
R = cfg.R0 + rho * jnp.cos(theta_seed)
Z = rho * jnp.sin(theta_seed)
self._seeds = jnp.stack([R, jnp.zeros_like(R), Z], axis=-1)
self.scalar_name: ScalarName = "|K|"
self.show_fieldlines = True
self.Bext0 = float(self.cfg.Bext0)
self.Bpol0 = float(self.cfg.Bpol0)
self.include_bg_tor = bool(self.cfg.bg_field_default_on and self.Bext0 != 0.0)
self.include_bg_pol = bool(self.cfg.bg_poloidal_default_on and self.Bpol0 != 0.0)
# Cached computed state (numpy) so we can update visualization (e.g. scalar choice)
# without re-running the JAX solve.
self._cache: dict[str, np.ndarray] = {}
self._traj_cache: np.ndarray | None = None
self._Iproj_cache: np.ndarray | None = None
# Text entry ("textbox") state.
self._edit_mode: Literal["none", "current"] = "none"
self._edit_buffer: str = ""
# Build VTK scene.
self._build_scene()
# Compile compute function once.
self._compute_jit = jax.jit(self._compute_state)
print("Compiling JAX pipeline (first update may take a moment)...")
self.update_solution()
def _build_scene(self) -> None:
cfg = self.cfg
# Renderer / window / interactor
self.renderer = vtk.vtkRenderer()
self.renderer.SetBackground(1.0, 1.0, 1.0)
self.window = vtk.vtkRenderWindow()
self.window.AddRenderer(self.renderer)
self.window.SetSize(*cfg.window_size)
self.interactor = vtk.vtkRenderWindowInteractor()
self.interactor.SetRenderWindow(self.window)
style = vtk.vtkInteractorStyleTrackballCamera()
self.interactor.SetInteractorStyle(style)
# Torus surface actor (with scalars updated each solve).
xyz = np.asarray(self.surface.r)
self.torus_poly = _build_torus_polydata(xyz)
self.surface_scalars_np = np.zeros((cfg.n_theta * cfg.n_phi,), dtype=np.float32)
self.surface_scalars_vtk = numpy_to_vtk(self.surface_scalars_np, deep=False)
self.surface_scalars_vtk.SetName("scalar")
self.torus_poly.GetPointData().SetScalars(self.surface_scalars_vtk)
self.lut = vtk.vtkLookupTable()
self.lut.SetNumberOfTableValues(256)
self.lut.Build()
self.torus_mapper = vtk.vtkPolyDataMapper()
self.torus_mapper.SetInputData(self.torus_poly)
self.torus_mapper.SetLookupTable(self.lut)
self.torus_mapper.SetScalarModeToUsePointData()
self.torus_mapper.ScalarVisibilityOn()
self.torus_actor = vtk.vtkActor()
self.torus_actor.SetMapper(self.torus_mapper)
self.torus_actor.GetProperty().SetOpacity(cfg.surface_opacity)
self.torus_actor.GetProperty().SetInterpolationToPhong()
self.torus_actor.GetProperty().SetSpecular(0.2)
self.torus_actor.GetProperty().SetSpecularPower(30.0)
self.renderer.AddActor(self.torus_actor)
# Axis curve for reference.
axis_phi = np.linspace(0, 2 * np.pi, 200, endpoint=True)
axis = np.stack([cfg.R0 * np.cos(axis_phi), cfg.R0 * np.sin(axis_phi), 0.0 * axis_phi], axis=-1)
axis_poly = vtk.vtkPolyData()
axis_points = vtk.vtkPoints()
axis_points.SetData(numpy_to_vtk(axis, deep=True))
axis_poly.SetPoints(axis_points)
axis_lines = vtk.vtkCellArray()
pl = vtk.vtkPolyLine()
pl.GetPointIds().SetNumberOfIds(axis.shape[0])
for i in range(axis.shape[0]):
pl.GetPointIds().SetId(i, i)
axis_lines.InsertNextCell(pl)
axis_poly.SetLines(axis_lines)
axis_mapper = vtk.vtkPolyDataMapper()
axis_mapper.SetInputData(axis_poly)
axis_actor = vtk.vtkActor()
axis_actor.SetMapper(axis_mapper)
axis_actor.GetProperty().SetColor(0.0, 0.0, 0.0)
axis_actor.GetProperty().SetLineWidth(2.0)
self.renderer.AddActor(axis_actor)
# Field lines actor (updated each solve).
n_pts_line = self.cfg.fieldline_steps + 1
self.field_poly = _build_fieldlines_polydata(self.cfg.n_fieldlines, n_pts_line)
self.field_points_np = np.zeros(
(self.cfg.n_fieldlines * n_pts_line, 3), dtype=np.float32
)
self.field_points_vtk = numpy_to_vtk(self.field_points_np, deep=False)
self.field_points_vtk.SetName("field_points")
self.field_poly.GetPoints().SetData(self.field_points_vtk)
self.field_mapper = vtk.vtkPolyDataMapper()
self.field_mapper.SetInputData(self.field_poly)
self.field_actor = vtk.vtkActor()
self.field_actor.SetMapper(self.field_mapper)
self.field_actor.GetProperty().SetColor(0.1, 0.25, 0.9)
self.field_actor.GetProperty().SetLineWidth(2.0)
self.renderer.AddActor(self.field_actor)
# Electrode actors (spheres).
self.electrode_actors = []
self._electrode_actor_to_index = {}
for i in range(self.N):
src = vtk.vtkSphereSource()
src.SetThetaResolution(16)
src.SetPhiResolution(16)
src.SetRadius(0.05 * self.cfg.a)
mapper = vtk.vtkPolyDataMapper()
mapper.SetInputConnection(src.GetOutputPort())
actor = vtk.vtkActor()
actor.SetMapper(mapper)
actor.SetVisibility(False)
self.renderer.AddActor(actor)
self.electrode_actors.append((src, actor))
self._electrode_actor_to_index[actor] = i
# Text overlay (help + status).
self.text = vtk.vtkTextActor()
_setup_text_actor(self.text, x=0.01, y=0.99, font_size=16, top=True)
self.renderer.AddActor2D(self.text)
# Editable numeric "textbox".
self.input_text = vtk.vtkTextActor()
_setup_text_actor(self.input_text, x=0.01, y=0.01, font_size=16, top=False)
self.renderer.AddActor2D(self.input_text)
# Current slider for the selected electrode.
rep = vtk.vtkSliderRepresentation2D()
rep.SetMinimumValue(-self.cfg.current_slider_max_A)
rep.SetMaximumValue(+self.cfg.current_slider_max_A)
rep.SetValue(0.0)
rep.SetTitleText("Selected electrode current I [A]")
rep.SetLabelFormat("%0.0f")
rep.SetSliderLength(0.02)
rep.SetSliderWidth(0.03)
rep.SetTubeWidth(0.006)
rep.SetEndCapLength(0.01)
rep.SetEndCapWidth(0.03)
rep.GetPoint1Coordinate().SetCoordinateSystemToNormalizedDisplay()
rep.GetPoint1Coordinate().SetValue(0.10, 0.06)
rep.GetPoint2Coordinate().SetCoordinateSystemToNormalizedDisplay()
rep.GetPoint2Coordinate().SetValue(0.55, 0.06)
self.slider_rep = rep
self.slider = vtk.vtkSliderWidget()
self.slider.SetInteractor(self.interactor)
self.slider.SetRepresentation(rep)
self.slider.EnabledOn()
def on_slider_end(_obj, _evt):
if self.selected is None:
return
val = float(self.slider_rep.GetValue())
self.currents_raw[self.selected] = val
self.update_solution()
self.slider.AddObserver(vtk.vtkCommand.EndInteractionEvent, on_slider_end)
# Picking helpers.
self.cell_picker = vtk.vtkCellPicker()
self.cell_picker.SetTolerance(0.0005)
self.prop_picker = vtk.vtkPropPicker()
# Interactor events.
self.interactor.AddObserver(vtk.vtkCommand.KeyPressEvent, self._on_keypress)
self.interactor.AddObserver(vtk.vtkCommand.LeftButtonPressEvent, self._on_left_click)
# Initial camera.
self.renderer.ResetCamera()
def _help_text(self) -> str:
return (
"Torus electrode GUI (VTK)\n"
"Mouse: rotate/zoom as usual\n"
"Click electrode: select\n"
"Keys:\n"
" a: add SOURCE (next click on surface)\n"
" z: add SINK (next click on surface)\n"
" m: move selected (next click on surface)\n"
" d: delete selected\n"
" tab: cycle selected\n"
" c: cycle surface scalar (|K|, V, s, Kθ, Kφ)\n"
" f: toggle field lines\n"
" b: toggle external toroidal field (ideal 1/R)\n"
" p: toggle external poloidal field (tokamak-like 1/R)\n"
" [/]: decrease/increase Bext0\n"
" ,/. : decrease/increase Bpol0\n"
" r: recompute\n"
" e: export ParaView (.vtu/.vtm)\n"
" i (or v): type selected I\n"
" s: save screenshot\n"
)
def _status_text(self) -> str:
n_active = int(np.sum(self.active))
sel = self.selected
if sel is None:
sel_txt = "none"
else:
Iraw = float(self.currents_raw[sel])
Iproj = None
if self._Iproj_cache is not None and sel < self._Iproj_cache.size:
Iproj = float(self._Iproj_cache[sel])
if Iproj is None:
sel_txt = f"{sel} Iraw={Iraw:+.0f} A (active={self.active[sel]:.0f})"
else:
sel_txt = (
f"{sel} Iraw={Iraw:+.0f} A Iproj={Iproj:+.0f} A (active={self.active[sel]:.0f})"
)
return (
f"Active electrodes: {n_active}/{self.N}\n"
f"Selected: {sel_txt}\n"
f"Mode: {self.mode}\n"
f"Scalar: {self.scalar_name}\n"
f"bg toroidal: {'ON' if self.include_bg_tor else 'OFF'} Bext0={self.Bext0:.3g} T "
f"bg poloidal: {'ON' if self.include_bg_pol else 'OFF'} Bpol0={self.Bpol0:.3g} T "
f"(at R0={self.cfg.R0:g} m)\n"
f"sigma_theta={self.cfg.sigma_theta:.3f} sigma_phi={self.cfg.sigma_phi:.3f}\n"
f"fieldlines: {self.cfg.n_fieldlines} steps: {self.cfg.fieldline_steps} ds: {self.cfg.fieldline_step_size_m}\n"
)
def _update_text(self) -> None:
self.text.SetInput(self._help_text() + "\n" + self._status_text())
if self._edit_mode == "none":
self.input_text.SetInput("Type: press 'i' (or 'v') to enter the selected electrode current I.")
else:
self.input_text.SetInput(
f"Input [current]: {self._edit_buffer} (Enter=apply, Esc=cancel)"
)
def _project_currents(self) -> np.ndarray:
mask = self.active.astype(float)
I = self.currents_raw * mask
n = float(np.sum(mask))
if n <= 0:
return np.zeros_like(I)
mean = float(np.sum(I) / n)
return I - mean * mask
def _apply_surface_scalar(self) -> None:
if not self._cache:
return
if self.scalar_name == "|K|":
scal = self._cache["Kmag"].reshape((-1,))
elif self.scalar_name == "V":
scal = self._cache["V"].reshape((-1,))
elif self.scalar_name == "s":
scal = self._cache["s"].reshape((-1,))
elif self.scalar_name == "K_theta":
scal = self._cache["Ktheta"].reshape((-1,))
elif self.scalar_name == "K_phi":
scal = self._cache["Kphi"].reshape((-1,))
else:
raise ValueError(self.scalar_name)
# Update VTK scalar array in-place.
self.surface_scalars_np[:] = scal.astype(np.float32, copy=False)
self.surface_scalars_vtk.Modified()
self.torus_poly.Modified()
# Adjust color range robustly.
smin = float(np.nanmin(scal))
smax = float(np.nanmax(scal))
if not np.isfinite(smin) or not np.isfinite(smax) or smin == smax:
smin, smax = 0.0, 1.0
if self.scalar_name in ("s", "K_theta", "K_phi", "V"):
vmax = max(abs(smin), abs(smax))
smin, smax = -vmax, vmax
self.torus_mapper.SetScalarRange(smin, smax)
def _update_electrode_actors(self) -> None:
Iproj = self._Iproj_cache
if Iproj is None:
Iproj = self._project_currents()
for i in range(self.N):
src, actor = self.electrode_actors[i]
if self.active[i] <= 0.0:
actor.SetVisibility(False)
continue
actor.SetVisibility(True)
p = torus_xyz(self.cfg.R0, self.cfg.a, self.theta_src[i], self.phi_src[i])
src.SetCenter(float(p[0]), float(p[1]), float(p[2]))
I = float(Iproj[i])
if i == self.selected:
actor.GetProperty().SetColor(1.0, 0.8, 0.2)
else:
if I > 0:
actor.GetProperty().SetColor(0.85, 0.15, 0.15)
elif I < 0:
actor.GetProperty().SetColor(0.15, 0.25, 0.85)
else:
actor.GetProperty().SetColor(0.4, 0.4, 0.4)
r0 = 0.04 * self.cfg.a
r = r0 * (0.6 + 0.8 * min(abs(I) / self.cfg.current_slider_max_A, 1.0))
src.SetRadius(float(r))
src.Modified()
def _compute_state(
self,
theta_src: jnp.ndarray,
phi_src: jnp.ndarray,
currents_raw: jnp.ndarray,
active: jnp.ndarray,
Bext0: jnp.ndarray,
Bpol0: jnp.ndarray,
include_bg_tor: jnp.ndarray,
include_bg_pol: jnp.ndarray,
compute_traj: jnp.ndarray,
):
# Project currents to net-zero over active electrodes.
mask = active
I = currents_raw * mask
n = jnp.sum(mask)
mean = jnp.where(n > 0, jnp.sum(I) / n, 0.0)
I = I - mean * mask
s = deposit_current_sources(
self.surface,
theta_src=theta_src,
phi_src=phi_src,
currents=I,
sigma_theta=self.cfg.sigma_theta,
sigma_phi=self.cfg.sigma_phi,
)
if float(self.cfg.sigma_s) <= 0.0:
raise ValueError("sigma_s must be > 0 for the electrode model.")
V, _ = solve_current_potential(
self.surface, s / float(self.cfg.sigma_s), tol=self.cfg.cg_tol, maxiter=self.cfg.cg_maxiter
)
K = surface_current_from_potential(self.surface, V, sigma_s=self.cfg.sigma_s)
Kmag = jnp.linalg.norm(K, axis=-1)
Ktheta = jnp.sum(K * self._e_theta, axis=-1)
Kphi = jnp.sum(K * self._e_phi, axis=-1)
def B_fn(xyz: jnp.ndarray) -> jnp.ndarray:
B = biot_savart_surface(self.surface, K, xyz, eps=self.cfg.biot_savart_eps)
Btor = jnp.asarray(include_bg_tor, dtype=B.dtype) * jnp.asarray(Bext0, dtype=B.dtype)
Bpol = jnp.asarray(include_bg_pol, dtype=B.dtype) * jnp.asarray(Bpol0, dtype=B.dtype)
B = B + tokamak_like_field(xyz, B_tor0=Btor, B_pol0=Bpol, R0=float(self.cfg.R0))
return B
n_steps = self.cfg.fieldline_steps
n_lines = self.cfg.n_fieldlines
def do_trace(_):
return trace_field_lines_batch(
B_fn,
self._seeds,
step_size=self.cfg.fieldline_step_size_m,
n_steps=n_steps,
normalize=True,
)
def no_trace(_):
return jnp.zeros((n_steps + 1, n_lines, 3), dtype=jnp.float64)
traj = jax.lax.cond(compute_traj, do_trace, no_trace, operand=None)
return V, s, Kmag, Ktheta, Kphi, traj, I
[docs]
def update_solution(self) -> None:
t0 = time.perf_counter()
self._update_text()
self.window.Render()
th = jnp.asarray(self.theta_src)
ph = jnp.asarray(self.phi_src)
Iraw = jnp.asarray(self.currents_raw)
act = jnp.asarray(self.active)
V, s, Kmag, Ktheta, Kphi, traj, Iproj = self._compute_jit(
th,
ph,
Iraw,
act,
jnp.asarray(self.Bext0, dtype=jnp.float64),
jnp.asarray(self.Bpol0, dtype=jnp.float64),
jnp.asarray(self.include_bg_tor),
jnp.asarray(self.include_bg_pol),
jnp.asarray(self.show_fieldlines),
)
V.block_until_ready()
t1 = time.perf_counter()
# Cache computed state for fast GUI updates (cycle scalars, toggle lines, selection).
self._cache = {
"V": np.asarray(V, dtype=np.float32),
"s": np.asarray(s, dtype=np.float32),
"Kmag": np.asarray(Kmag, dtype=np.float32),
"Ktheta": np.asarray(Ktheta, dtype=np.float32),
"Kphi": np.asarray(Kphi, dtype=np.float32),
}
self._traj_cache = np.asarray(traj, dtype=np.float32)
self._Iproj_cache = np.asarray(Iproj, dtype=float)
self._apply_surface_scalar()
self._update_electrode_actors()
# Update field lines geometry.
self.field_actor.SetVisibility(bool(self.show_fieldlines))
if self.show_fieldlines and self._traj_cache is not None:
self.field_points_np[:] = self._traj_cache.reshape((-1, 3))
self.field_points_vtk.Modified()
self.field_poly.Modified()
if self.selected is not None:
self.slider_rep.SetValue(float(self.currents_raw[self.selected]))
self._update_text()
self.window.Render()
t2 = time.perf_counter()
print(
"update: solve+trace {:.3f}s, total {:.3f}s (scalar={}, active={})".format(
t1 - t0, t2 - t0, self.scalar_name, int(np.sum(self.active))
)
)
def _select_next(self) -> None:
active_idx = np.flatnonzero(self.active > 0.0)
if active_idx.size == 0:
self.selected = None
return
if self.selected is None or self.selected not in active_idx:
self.selected = int(active_idx[0])
return
k = int(np.where(active_idx == self.selected)[0][0])
self.selected = int(active_idx[(k + 1) % active_idx.size])
def _delete_selected(self) -> None:
if self.selected is None:
return
i = self.selected
self.active[i] = 0.0
self.currents_raw[i] = 0.0
self.theta_src[i] = 0.0
self.phi_src[i] = 0.0
self._select_next()
self.update_solution()
def _add_electrode(self, theta: float, phi: float, current: float) -> None:
free = np.flatnonzero(self.active <= 0.0)
if free.size == 0:
print("No free electrode slots; increase n_electrodes_max.")
return
i = int(free[0])
self.theta_src[i] = float(theta)
self.phi_src[i] = float(phi)
self.currents_raw[i] = float(current)
self.active[i] = 1.0
self.selected = i
self.slider_rep.SetValue(float(self.currents_raw[i]))
self.update_solution()
def _on_keypress(self, _obj, _evt) -> None:
key_sym = self.interactor.GetKeySym()
key_code = self.interactor.GetKeyCode()
if self._edit_mode != "none":
if key_sym in ("Escape",):
self._edit_mode = "none"
self._edit_buffer = ""
self._update_text()
self.window.Render()
return
if key_sym in ("Return", "KP_Enter"):
try:
val = float(self._edit_buffer.strip())
except Exception:
print(f"Could not parse number: {self._edit_buffer!r}")
return
if self.selected is not None:
self.currents_raw[self.selected] = float(val)
self.slider_rep.SetValue(float(self.currents_raw[self.selected]))
self._edit_mode = "none"
self._edit_buffer = ""
self.update_solution()
return
if key_sym in ("BackSpace", "Delete"):
self._edit_buffer = self._edit_buffer[:-1]
self._update_text()
self.window.Render()
return
if key_code and key_code in "0123456789+-eE.":
self._edit_buffer = self._edit_buffer + key_code
self._update_text()
self.window.Render()
return
return
if key_sym in ("i", "I", "v", "V"):
if self.selected is None:
return
self._edit_mode = "current"
self._edit_buffer = f"{float(self.currents_raw[self.selected]):.6g}"
self._update_text()
self.window.Render()
return
if key_sym in ("a", "A"):
self.mode = "add_source"
elif key_sym in ("z", "Z"):
self.mode = "add_sink"
elif key_sym in ("m", "M"):
self.mode = "move"
elif key_sym in ("d", "Delete", "BackSpace"):
self._delete_selected()
return
elif key_sym in ("Tab",):
self._select_next()
self._update_electrode_actors()
elif key_sym in ("c", "C"):
order: list[ScalarName] = ["|K|", "V", "s", "K_theta", "K_phi"]
k = order.index(self.scalar_name)
self.scalar_name = order[(k + 1) % len(order)]
self._apply_surface_scalar()
elif key_sym in ("f", "F"):
self.show_fieldlines = not self.show_fieldlines
if self.show_fieldlines:
self.update_solution()
return
self.field_actor.SetVisibility(False)
elif key_sym in ("b", "B"):
self.include_bg_tor = not self.include_bg_tor
print(
f"External toroidal field (1/R): {'ON' if self.include_bg_tor else 'OFF'} "
f"(Bext0={self.Bext0:.3g} T at R0={self.cfg.R0:g} m)"
)
if self.show_fieldlines:
self.update_solution()
return
elif key_sym in ("p", "P"):
self.include_bg_pol = not self.include_bg_pol
print(
f"External poloidal field (1/R): {'ON' if self.include_bg_pol else 'OFF'} "
f"(Bpol0={self.Bpol0:.3g} T at R0={self.cfg.R0:g} m)"
)
if self.show_fieldlines:
self.update_solution()
return
elif key_sym in ("bracketleft",):
self.Bext0 = float(self.Bext0) / 1.2
print(f"Bext0 -> {self.Bext0:.6g} T")
if self.show_fieldlines:
self.update_solution()
return
elif key_sym in ("bracketright",):
self.Bext0 = float(self.Bext0) * 1.2
print(f"Bext0 -> {self.Bext0:.6g} T")
if self.show_fieldlines:
self.update_solution()
return
elif key_sym in ("comma",):
self.Bpol0 = float(self.Bpol0) / 1.2
print(f"Bpol0 -> {self.Bpol0:.6g} T")
if self.show_fieldlines:
self.update_solution()
return
elif key_sym in ("period",):
self.Bpol0 = float(self.Bpol0) * 1.2
print(f"Bpol0 -> {self.Bpol0:.6g} T")
if self.show_fieldlines:
self.update_solution()
return
elif key_sym in ("r", "R"):
self.update_solution()
return
elif key_sym in ("e", "E"):
self._export_paraview()
elif key_sym in ("s", "S"):
self._save_screenshot()
self._update_text()
self.window.Render()
def _on_left_click(self, _obj, _evt) -> None:
x, y = self.interactor.GetEventPosition()
# 1) Try selecting an electrode.
self.prop_picker.Pick(x, y, 0, self.renderer)
actor = self.prop_picker.GetActor()
if actor in self._electrode_actor_to_index:
self.selected = int(self._electrode_actor_to_index[actor])
self.slider_rep.SetValue(float(self.currents_raw[self.selected]))
self._update_electrode_actors()
self._update_text()
self.window.Render()
return
# 2) Add/move electrode by picking torus surface.
if self.mode in ("add_source", "add_sink", "move"):
if not self.cell_picker.Pick(x, y, 0, self.renderer):
self.mode = "none"
self._update_text()
self.window.Render()
return
p = np.array(self.cell_picker.GetPickPosition(), dtype=float)
theta, phi = torus_angles_from_point(self.cfg.R0, p)
if self.mode == "move":
if self.selected is not None and self.active[self.selected] > 0:
self.theta_src[self.selected] = theta
self.phi_src[self.selected] = phi
self.mode = "none"
self.update_solution()
else:
self.mode = "none"
elif self.mode == "add_source":
self.mode = "none"
self._add_electrode(theta, phi, +self.cfg.current_default_A)
elif self.mode == "add_sink":
self.mode = "none"
self._add_electrode(theta, phi, -self.cfg.current_default_A)
return
# Fall back to default camera interaction.
self.interactor.GetInteractorStyle().OnLeftButtonDown()
def _save_screenshot(self) -> None:
outdir = Path("figures/gui_screenshots")
outdir.mkdir(parents=True, exist_ok=True)
ts = time.strftime("%Y%m%d_%H%M%S")
path = outdir / f"torus_gui_{ts}.png"
w2i = vtk.vtkWindowToImageFilter()
w2i.SetInput(self.window)
w2i.Update()
writer = vtk.vtkPNGWriter()
writer.SetFileName(str(path))
writer.SetInputConnection(w2i.GetOutputPort())
writer.Write()
print(f"Saved screenshot: {path}")
def _export_paraview(self) -> None:
from .paraview import fieldlines_to_vtu, point_cloud_to_vtu, torus_surface_to_vtu, write_vtm, write_vtu
ts = time.strftime("%Y%m%d_%H%M%S")
outdir = Path("paraview") / f"gui_torus_electrodes_{ts}"
outdir.mkdir(parents=True, exist_ok=True)
V = self._cache.get("V")
s = self._cache.get("s")
Ktheta = self._cache.get("Ktheta")
Kphi = self._cache.get("Kphi")
Kmag = self._cache.get("Kmag")
if V is None or s is None or Ktheta is None or Kphi is None or Kmag is None:
print("ParaView export: no cached solution yet.")
return
e_theta = np.asarray(self._e_theta, dtype=float)
e_phi = np.asarray(self._e_phi, dtype=float)
K_vec = Ktheta[..., None] * e_theta + Kphi[..., None] * e_phi
surf = write_vtu(
outdir / "winding_surface.vtu",
torus_surface_to_vtu(
surface=self.surface,
point_data={
"V": V.reshape(-1),
"s": s.reshape(-1),
"K": K_vec.reshape(-1, 3),
"Ktheta": Ktheta.reshape(-1),
"Kphi": Kphi.reshape(-1),
"|K|": Kmag.reshape(-1),
},
),
)
blocks: dict[str, str] = {"winding_surface": surf.name}
active = np.flatnonzero(self.active > 0.0)
if active.size > 0:
xyz = torus_xyz(self.cfg.R0, self.cfg.a, self.theta_src[active], self.phi_src[active])
I = (
np.asarray(self._Iproj_cache, dtype=float)[active]
if self._Iproj_cache is not None
else np.asarray(self.currents_raw, dtype=float)[active]
)
elec = write_vtu(
outdir / "electrodes.vtu",
point_cloud_to_vtu(
points=np.asarray(xyz, dtype=float),
point_data={"I_A": I, "sign_I": np.sign(I)},
),
)
blocks["electrodes"] = elec.name
if self.show_fieldlines and self._traj_cache is not None:
traj_pv = np.transpose(self._traj_cache, (1, 0, 2))
fl = write_vtu(outdir / "fieldlines.vtu", fieldlines_to_vtu(traj=traj_pv))
blocks["fieldlines"] = fl.name
scene = write_vtm(outdir / "scene.vtm", blocks)
print(f"Saved ParaView scene: {scene}")
[docs]
def run(self) -> None:
print("Starting GUI. Close the window to exit.")
self._update_text()
self.window.Render()
self.interactor.Initialize()
self.interactor.Start()
[docs]
def run_torus_electrode_gui(
*,
cfg: GUIConfig = GUIConfig(),
initial_electrodes: dict | None = None,
) -> None: # pragma: no cover
"""Entry point for examples."""
_require_vtk()
print("Interactive torus electrode GUI")
print(f" R0={cfg.R0} a={cfg.a} n_theta={cfg.n_theta} n_phi={cfg.n_phi}")
print(
f" sigma_theta={cfg.sigma_theta} sigma_phi={cfg.sigma_phi} "
f"n_lines={cfg.n_fieldlines} steps={cfg.fieldline_steps} ds={cfg.fieldline_step_size_m}"
)
print(f" mu0={float(MU0):.6e}")
app = TorusElectrodeGUI(cfg, initial_electrodes=initial_electrodes)
app.run()
[docs]
class TorusCutVoltageGUI: # pragma: no cover
"""VTK GUI: a toroidal cut voltage drives poloidal current; optional extra electrodes add sources/sinks."""
def __init__(self, cfg: CutGUIConfig, *, initial_electrodes: dict | None = None):
_require_vtk()
jax.config.update("jax_enable_x64", True)
self.cfg = cfg
self.surface: TorusSurface = make_torus_surface(
R0=cfg.R0, a=cfg.a, n_theta=cfg.n_theta, n_phi=cfg.n_phi
)
# Axisymmetric cut-driven solution uses ∂θV = C/R, with C fixed by V_cut.
Rtheta = self.surface.R[:, 0] # (Nθ,)
self._I1 = jnp.sum((1.0 / Rtheta) * self.surface.dtheta) # ∮ dθ/R(θ)
f = _cut_phase_theta_np(
theta=np.asarray(self.surface.theta),
R0=cfg.R0,
a=cfg.a,
theta_cut=cfg.theta_cut,
)
self._f_theta = jnp.asarray(f, dtype=jnp.float64) # (Nθ,)
# Electrode state (fixed size to avoid recompilation when user adds/removes).
self.N = cfg.n_electrodes_max
self.theta_src = np.zeros((self.N,), dtype=float)
self.phi_src = np.zeros((self.N,), dtype=float)
self.currents_raw = np.zeros((self.N,), dtype=float)
self.active = np.zeros((self.N,), dtype=float)
if initial_electrodes is not None:
th = np.asarray(initial_electrodes.get("theta", []), dtype=float)
ph = np.asarray(initial_electrodes.get("phi", []), dtype=float)
I = np.asarray(initial_electrodes.get("I", []), dtype=float)
n0 = min(self.N, th.size, ph.size, I.size)
self.theta_src[:n0] = th[:n0]
self.phi_src[:n0] = ph[:n0]
self.currents_raw[:n0] = I[:n0]
self.active[:n0] = 1.0
self.selected: int | None = int(np.argmax(self.active)) if np.any(self.active) else None
self.mode: Literal["none", "add_source", "add_sink", "move"] = "none"
# Precompute unit vectors for K decomposition.
self._e_theta = self.surface.r_theta / self.surface.a
self._e_phi = self.surface.r_phi / jnp.sqrt(self.surface.G)[..., None]
# Field line seeds (inside the torus).
theta_seed = jnp.linspace(0.0, 2 * jnp.pi, cfg.n_fieldlines, endpoint=False)
rho = 0.5 * cfg.a
R = cfg.R0 + rho * jnp.cos(theta_seed)
Z = rho * jnp.sin(theta_seed)
self._seeds = jnp.stack([R, jnp.zeros_like(R), Z], axis=-1)
self.scalar_name: ScalarName = "|K|"
self.show_fieldlines = True
self.Bext0 = float(self.cfg.Bext0)
self.Bpol0 = float(self.cfg.Bpol0)
self.include_bg_tor = bool(self.cfg.bg_field_default_on and self.Bext0 != 0.0)
self.include_bg_pol = bool(self.cfg.bg_poloidal_default_on and self.Bpol0 != 0.0)
self.V_cut = float(cfg.V_cut_default)
# Cached computed state (numpy) so we can update visualization without recomputing.
self._cache: dict[str, np.ndarray] = {}
self._traj_cache: np.ndarray | None = None
self._Iproj_cache: np.ndarray | None = None
# Text entry ("textbox") state.
self._edit_mode: Literal["none", "V_cut", "current"] = "none"
self._edit_buffer: str = ""
self._build_scene()
self._compute_jit = jax.jit(self._compute_state)
print("Compiling JAX pipeline (first update may take a moment)...")
self.update_solution()
def _build_scene(self) -> None:
cfg = self.cfg
self.renderer = vtk.vtkRenderer()
self.renderer.SetBackground(1.0, 1.0, 1.0)
self.window = vtk.vtkRenderWindow()
self.window.AddRenderer(self.renderer)
self.window.SetSize(*cfg.window_size)
self.interactor = vtk.vtkRenderWindowInteractor()
self.interactor.SetRenderWindow(self.window)
style = vtk.vtkInteractorStyleTrackballCamera()
self.interactor.SetInteractorStyle(style)
xyz = np.asarray(self.surface.r)
self.torus_poly = _build_torus_polydata(xyz)
self.surface_scalars_np = np.zeros((cfg.n_theta * cfg.n_phi,), dtype=np.float32)
self.surface_scalars_vtk = numpy_to_vtk(self.surface_scalars_np, deep=False)
self.surface_scalars_vtk.SetName("scalar")
self.torus_poly.GetPointData().SetScalars(self.surface_scalars_vtk)
self.lut = vtk.vtkLookupTable()
self.lut.SetNumberOfTableValues(256)
self.lut.Build()
self.torus_mapper = vtk.vtkPolyDataMapper()
self.torus_mapper.SetInputData(self.torus_poly)
self.torus_mapper.SetLookupTable(self.lut)
self.torus_mapper.SetScalarModeToUsePointData()
self.torus_mapper.ScalarVisibilityOn()
self.torus_actor = vtk.vtkActor()
self.torus_actor.SetMapper(self.torus_mapper)
self.torus_actor.GetProperty().SetOpacity(cfg.surface_opacity)
self.torus_actor.GetProperty().SetInterpolationToPhong()
self.torus_actor.GetProperty().SetSpecular(0.2)
self.torus_actor.GetProperty().SetSpecularPower(30.0)
self.renderer.AddActor(self.torus_actor)
# Axis curve for reference.
axis_phi = np.linspace(0, 2 * np.pi, 200, endpoint=True)
axis = np.stack([cfg.R0 * np.cos(axis_phi), cfg.R0 * np.sin(axis_phi), 0.0 * axis_phi], axis=-1)
axis_poly = vtk.vtkPolyData()
axis_points = vtk.vtkPoints()
axis_points.SetData(numpy_to_vtk(axis, deep=True))
axis_poly.SetPoints(axis_points)
axis_lines = vtk.vtkCellArray()
pl = vtk.vtkPolyLine()
pl.GetPointIds().SetNumberOfIds(axis.shape[0])
for i in range(axis.shape[0]):
pl.GetPointIds().SetId(i, i)
axis_lines.InsertNextCell(pl)
axis_poly.SetLines(axis_lines)
axis_mapper = vtk.vtkPolyDataMapper()
axis_mapper.SetInputData(axis_poly)
axis_actor = vtk.vtkActor()
axis_actor.SetMapper(axis_mapper)
axis_actor.GetProperty().SetColor(0.0, 0.0, 0.0)
axis_actor.GetProperty().SetLineWidth(2.0)
self.renderer.AddActor(axis_actor)
# Cut curve + terminals for reference (where V jumps in the visualization).
cut_phi = np.linspace(0.0, 2 * np.pi, 240, endpoint=True)
cut_theta = float(cfg.theta_cut) % (2 * np.pi)
cut = torus_xyz(cfg.R0, cfg.a, cut_theta * np.ones_like(cut_phi), cut_phi)
def _polyline_actor(points_xyz: np.ndarray, *, rgb: tuple[float, float, float], width: float) -> VtkActor:
poly = vtk.vtkPolyData()
pts = vtk.vtkPoints()
pts.SetData(numpy_to_vtk(points_xyz, deep=True))
poly.SetPoints(pts)
lines = vtk.vtkCellArray()
pl = vtk.vtkPolyLine()
pl.GetPointIds().SetNumberOfIds(points_xyz.shape[0])
for i in range(points_xyz.shape[0]):
pl.GetPointIds().SetId(i, i)
lines.InsertNextCell(pl)
poly.SetLines(lines)
mapper = vtk.vtkPolyDataMapper()
mapper.SetInputData(poly)
actor = vtk.vtkActor()
actor.SetMapper(mapper)
actor.GetProperty().SetColor(*rgb)
actor.GetProperty().SetLineWidth(width)
return actor
self.cut_actor = _polyline_actor(cut, rgb=(0.15, 0.15, 0.15), width=3.0)
self.renderer.AddActor(self.cut_actor)
# Two nearby rings show the "before" and "after" sides of the cut.
delta = 0.5 * float(self.surface.dtheta)
th_before = (cut_theta - delta) % (2 * np.pi)
th_after = (cut_theta + 0.0) % (2 * np.pi)
ring_before = torus_xyz(cfg.R0, cfg.a, th_before * np.ones_like(cut_phi), cut_phi)
ring_after = torus_xyz(cfg.R0, cfg.a, th_after * np.ones_like(cut_phi), cut_phi)
self.cut_before_actor = _polyline_actor(ring_before, rgb=(0.85, 0.15, 0.15), width=2.5)
self.cut_after_actor = _polyline_actor(ring_after, rgb=(0.15, 0.25, 0.85), width=2.5)
self.renderer.AddActor(self.cut_before_actor)
self.renderer.AddActor(self.cut_after_actor)
# Field lines actor (updated each solve).
n_pts_line = self.cfg.fieldline_steps + 1
self.field_poly = _build_fieldlines_polydata(self.cfg.n_fieldlines, n_pts_line)
self.field_points_np = np.zeros((self.cfg.n_fieldlines * n_pts_line, 3), dtype=np.float32)
self.field_points_vtk = numpy_to_vtk(self.field_points_np, deep=False)
self.field_points_vtk.SetName("field_points")
self.field_poly.GetPoints().SetData(self.field_points_vtk)
self.field_mapper = vtk.vtkPolyDataMapper()
self.field_mapper.SetInputData(self.field_poly)
self.field_actor = vtk.vtkActor()
self.field_actor.SetMapper(self.field_mapper)
self.field_actor.GetProperty().SetColor(0.1, 0.25, 0.9)
self.field_actor.GetProperty().SetLineWidth(2.0)
self.renderer.AddActor(self.field_actor)
# Electrode actors (spheres).
self.electrode_actors = []
self._electrode_actor_to_index = {}
for i in range(self.N):
src = vtk.vtkSphereSource()
src.SetThetaResolution(16)
src.SetPhiResolution(16)
src.SetRadius(0.05 * self.cfg.a)
mapper = vtk.vtkPolyDataMapper()
mapper.SetInputConnection(src.GetOutputPort())
actor = vtk.vtkActor()
actor.SetMapper(mapper)
actor.SetVisibility(False)
self.renderer.AddActor(actor)
self.electrode_actors.append((src, actor))
self._electrode_actor_to_index[actor] = i
# Text overlays: help+status (top) and an editable numeric "textbox" (bottom).
self.text = vtk.vtkTextActor()
_setup_text_actor(self.text, x=0.01, y=0.99, font_size=16, top=True)
self.renderer.AddActor2D(self.text)
self.input_text = vtk.vtkTextActor()
_setup_text_actor(self.input_text, x=0.01, y=0.01, font_size=16, top=False)
self.renderer.AddActor2D(self.input_text)
# V_cut slider.
rep_v = vtk.vtkSliderRepresentation2D()
rep_v.SetMinimumValue(-self.cfg.V_cut_slider_max)
rep_v.SetMaximumValue(+self.cfg.V_cut_slider_max)
rep_v.SetValue(self.V_cut)
rep_v.SetTitleText("Cut voltage V_cut [arb] (press 'v' to type)")
rep_v.SetLabelFormat("%0.3f")
rep_v.SetSliderLength(0.02)
rep_v.SetSliderWidth(0.03)
rep_v.SetTubeWidth(0.006)
rep_v.SetEndCapLength(0.01)
rep_v.SetEndCapWidth(0.03)
rep_v.GetPoint1Coordinate().SetCoordinateSystemToNormalizedDisplay()
rep_v.GetPoint1Coordinate().SetValue(0.10, 0.06)
rep_v.GetPoint2Coordinate().SetCoordinateSystemToNormalizedDisplay()
rep_v.GetPoint2Coordinate().SetValue(0.55, 0.06)
self.V_slider_rep = rep_v
self.V_slider = vtk.vtkSliderWidget()
self.V_slider.SetInteractor(self.interactor)
self.V_slider.SetRepresentation(rep_v)
self.V_slider.EnabledOn()
def on_v_slider_end(_obj, _evt):
self.V_cut = float(self.V_slider_rep.GetValue())
self.update_solution()
self.V_slider.AddObserver(vtk.vtkCommand.EndInteractionEvent, on_v_slider_end)
# Current slider for the selected electrode.
rep_i = vtk.vtkSliderRepresentation2D()
rep_i.SetMinimumValue(-self.cfg.current_slider_max_A)
rep_i.SetMaximumValue(+self.cfg.current_slider_max_A)
rep_i.SetValue(0.0)
rep_i.SetTitleText("Selected electrode current I [A] (press 'i' to type)")
rep_i.SetLabelFormat("%0.0f")
rep_i.SetSliderLength(0.02)
rep_i.SetSliderWidth(0.03)
rep_i.SetTubeWidth(0.006)
rep_i.SetEndCapLength(0.01)
rep_i.SetEndCapWidth(0.03)
rep_i.GetPoint1Coordinate().SetCoordinateSystemToNormalizedDisplay()
rep_i.GetPoint1Coordinate().SetValue(0.58, 0.06)
rep_i.GetPoint2Coordinate().SetCoordinateSystemToNormalizedDisplay()
rep_i.GetPoint2Coordinate().SetValue(0.98, 0.06)
self.I_slider_rep = rep_i
self.I_slider = vtk.vtkSliderWidget()
self.I_slider.SetInteractor(self.interactor)
self.I_slider.SetRepresentation(rep_i)
self.I_slider.EnabledOn()
def on_i_slider_end(_obj, _evt):
if self.selected is None:
return
val = float(self.I_slider_rep.GetValue())
self.currents_raw[self.selected] = val
self.update_solution()
self.I_slider.AddObserver(vtk.vtkCommand.EndInteractionEvent, on_i_slider_end)
# Picking helpers.
self.cell_picker = vtk.vtkCellPicker()
self.cell_picker.SetTolerance(0.0005)
self.prop_picker = vtk.vtkPropPicker()
# Interactor events.
self.interactor.AddObserver(vtk.vtkCommand.KeyPressEvent, self._on_keypress)
self.interactor.AddObserver(vtk.vtkCommand.LeftButtonPressEvent, self._on_left_click)
self.renderer.ResetCamera()
def _help_text(self) -> str:
return (
"Torus cut+electrodes GUI (VTK)\n"
"Mouse: rotate/zoom as usual\n"
"Electrodes: red=source (+I), blue=sink (-I), yellow=selected\n"
"Cut terminals: red ring = higher V side, blue ring = lower V side; thick gray ring = cut location\n"
"Keys:\n"
" a/z/m: add source / add sink / move selected (next click)\n"
" d: delete selected tab: cycle selected\n"
" c: cycle scalar (|K|, V, s, Kθ, Kφ)\n"
" f: toggle field lines r: recompute\n"
" b: toggle external toroidal field (ideal 1/R)\n"
" p: toggle external poloidal field (tokamak-like 1/R)\n"
" [/]: decrease/increase Bext0 ,/. : decrease/increase Bpol0\n"
" v: type V_cut i: type selected I\n"
" e: export ParaView (.vtu/.vtm)\n"
" s: save screenshot\n"
)
def _status_text(self) -> str:
n_active = int(np.sum(self.active))
if self.selected is None:
sel_txt = "none"
else:
Iraw = float(self.currents_raw[self.selected])
Iproj = None
if self._Iproj_cache is not None and self.selected < self._Iproj_cache.size:
Iproj = float(self._Iproj_cache[self.selected])
if Iproj is None:
sel_txt = f"{self.selected} Iraw={Iraw:+.0f} A"
else:
sel_txt = f"{self.selected} Iraw={Iraw:+.0f} A Iproj={Iproj:+.0f} A"
return (
f"V_cut={self.V_cut:+.3f} theta_cut={float(self.cfg.theta_cut)%(2*np.pi):.3f} sigma_s={self.cfg.sigma_s:.3f}\n"
f"Active electrodes: {n_active}/{self.N} Selected: {sel_txt} Mode: {self.mode}\n"
f"Scalar: {self.scalar_name}\n"
f"bg toroidal: {'ON' if self.include_bg_tor else 'OFF'} Bext0={self.Bext0:.3g} T "
f"bg poloidal: {'ON' if self.include_bg_pol else 'OFF'} Bpol0={self.Bpol0:.3g} T "
f"(at R0={self.cfg.R0:g} m)\n"
f"sigma_theta={self.cfg.sigma_theta:.3f} sigma_phi={self.cfg.sigma_phi:.3f}\n"
f"fieldlines: {self.cfg.n_fieldlines} steps: {self.cfg.fieldline_steps} ds: {self.cfg.fieldline_step_size_m}\n"
)
def _update_text(self) -> None:
self.text.SetInput(self._help_text() + "\n" + self._status_text())
if self._edit_mode == "none":
self.input_text.SetInput("Type: press 'v' (V_cut) or 'i' (selected I) to enter a value.")
else:
self.input_text.SetInput(
f"Input [{self._edit_mode}]: {self._edit_buffer} (Enter=apply, Esc=cancel)"
)
def _project_currents(self) -> np.ndarray:
mask = self.active.astype(float)
I = self.currents_raw * mask
n = float(np.sum(mask))
if n <= 0:
return np.zeros_like(I)
mean = float(np.sum(I) / n)
return I - mean * mask
def _compute_state(
self,
theta_src: jnp.ndarray,
phi_src: jnp.ndarray,
currents_raw: jnp.ndarray,
active: jnp.ndarray,
V_cut: jnp.ndarray,
Bext0: jnp.ndarray,
Bpol0: jnp.ndarray,
include_bg_tor: jnp.ndarray,
include_bg_pol: jnp.ndarray,
compute_traj: jnp.ndarray,
):
# Project currents to net-zero over active electrodes.
mask = active
I = currents_raw * mask
n = jnp.sum(mask)
mean = jnp.where(n > 0, jnp.sum(I) / n, 0.0)
I = I - mean * mask
# Electrode-driven contribution (sources/sinks).
s = deposit_current_sources(
self.surface,
theta_src=theta_src,
phi_src=phi_src,
currents=I,
sigma_theta=self.cfg.sigma_theta,
sigma_phi=self.cfg.sigma_phi,
)
V_e, _ = solve_current_potential(
self.surface,
s / float(self.cfg.sigma_s),
tol=self.cfg.cg_tol,
maxiter=self.cfg.cg_maxiter,
)
K_e = surface_current_from_potential(self.surface, V_e, sigma_s=self.cfg.sigma_s)
# Cut-driven poloidal current (topological drive).
V_cut = jnp.asarray(V_cut, dtype=jnp.float64)
C = V_cut / self._I1
dV_dtheta = (C / self.surface.R) # (Nθ,1)
K_cut = (-self.cfg.sigma_s) * (dV_dtheta / (self.surface.a * self.surface.a))[..., None] * self.surface.r_theta
K = K_cut + K_e
Kmag = jnp.linalg.norm(K, axis=-1)
Ktheta = jnp.sum(K * self._e_theta, axis=-1)
Kphi = jnp.sum(K * self._e_phi, axis=-1)
# Visualization potential: multi-valued cut component + single-valued electrode component.
V_vis = V_e + V_cut * self._f_theta[:, None]
def B_fn(xyz: jnp.ndarray) -> jnp.ndarray:
B = biot_savart_surface(self.surface, K, xyz, eps=self.cfg.biot_savart_eps)
Btor = jnp.asarray(include_bg_tor, dtype=B.dtype) * jnp.asarray(Bext0, dtype=B.dtype)
Bpol = jnp.asarray(include_bg_pol, dtype=B.dtype) * jnp.asarray(Bpol0, dtype=B.dtype)
B = B + tokamak_like_field(xyz, B_tor0=Btor, B_pol0=Bpol, R0=float(self.cfg.R0))
return B
n_steps = self.cfg.fieldline_steps
n_lines = self.cfg.n_fieldlines
def do_trace(_):
return trace_field_lines_batch(
B_fn,
self._seeds,
step_size=self.cfg.fieldline_step_size_m,
n_steps=n_steps,
normalize=True,
)
def no_trace(_):
return jnp.zeros((n_steps + 1, n_lines, 3), dtype=jnp.float64)
traj = jax.lax.cond(compute_traj, do_trace, no_trace, operand=None)
return V_vis, s, Kmag, Ktheta, Kphi, traj, I
def _apply_surface_scalar(self) -> None:
if not self._cache:
return
if self.scalar_name == "|K|":
scal = self._cache["Kmag"].reshape((-1,))
elif self.scalar_name == "V":
scal = self._cache["V"].reshape((-1,))
elif self.scalar_name == "s":
scal = self._cache["s"].reshape((-1,))
elif self.scalar_name == "K_theta":
scal = self._cache["Ktheta"].reshape((-1,))
elif self.scalar_name == "K_phi":
scal = self._cache["Kphi"].reshape((-1,))
else:
raise ValueError(self.scalar_name)
self.surface_scalars_np[:] = scal.astype(np.float32, copy=False)
self.surface_scalars_vtk.Modified()
self.torus_poly.Modified()
smin = float(np.nanmin(scal))
smax = float(np.nanmax(scal))
if not np.isfinite(smin) or not np.isfinite(smax) or smin == smax:
smin, smax = 0.0, 1.0
if self.scalar_name in ("s", "V", "K_theta", "K_phi"):
vmax = max(abs(smin), abs(smax))
smin, smax = -vmax, vmax
self.torus_mapper.SetScalarRange(smin, smax)
def _update_electrode_actors(self) -> None:
Iproj = self._Iproj_cache
if Iproj is None:
Iproj = self._project_currents()
for i in range(self.N):
src, actor = self.electrode_actors[i]
if self.active[i] <= 0.0:
actor.SetVisibility(False)
continue
actor.SetVisibility(True)
p = torus_xyz(self.cfg.R0, self.cfg.a, self.theta_src[i], self.phi_src[i])
src.SetCenter(float(p[0]), float(p[1]), float(p[2]))
I = float(Iproj[i])
if i == self.selected:
actor.GetProperty().SetColor(1.0, 0.8, 0.2)
else:
if I > 0:
actor.GetProperty().SetColor(0.85, 0.15, 0.15)
elif I < 0:
actor.GetProperty().SetColor(0.15, 0.25, 0.85)
else:
actor.GetProperty().SetColor(0.4, 0.4, 0.4)
r0 = 0.04 * self.cfg.a
r = r0 * (0.6 + 0.8 * min(abs(I) / self.cfg.current_slider_max_A, 1.0))
src.SetRadius(float(r))
src.Modified()
if self.selected is not None:
self.I_slider_rep.SetValue(float(self.currents_raw[self.selected]))
def _update_cut_terminal_colors(self) -> None:
# In the visualization, V jumps from ~V_cut to 0 at theta_cut.
# "before" (red by default) is the high-potential side for V_cut>0.
v_before = self.V_cut
v_after = 0.0
if v_before >= v_after:
hi, lo = self.cut_before_actor, self.cut_after_actor
else:
hi, lo = self.cut_after_actor, self.cut_before_actor
hi.GetProperty().SetColor(0.85, 0.15, 0.15)
lo.GetProperty().SetColor(0.15, 0.25, 0.85)
[docs]
def update_solution(self) -> None:
t0 = time.perf_counter()
self._update_text()
self.window.Render()
th = jnp.asarray(self.theta_src)
ph = jnp.asarray(self.phi_src)
Iraw = jnp.asarray(self.currents_raw)
act = jnp.asarray(self.active)
V_vis, s, Kmag, Ktheta, Kphi, traj, Iproj = self._compute_jit(
th,
ph,
Iraw,
act,
jnp.asarray(self.V_cut),
jnp.asarray(self.Bext0),
jnp.asarray(self.Bpol0),
jnp.asarray(self.include_bg_tor),
jnp.asarray(self.include_bg_pol),
jnp.asarray(self.show_fieldlines),
)
V_vis.block_until_ready()
t1 = time.perf_counter()
self._cache = {
"V": np.asarray(V_vis, dtype=np.float32),
"s": np.asarray(s, dtype=np.float32),
"Kmag": np.asarray(Kmag, dtype=np.float32),
"Ktheta": np.asarray(Ktheta, dtype=np.float32),
"Kphi": np.asarray(Kphi, dtype=np.float32),
}
self._traj_cache = np.asarray(traj, dtype=np.float32)
self._Iproj_cache = np.asarray(Iproj, dtype=float)
self._apply_surface_scalar()
self._update_electrode_actors()
self._update_cut_terminal_colors()
self.field_actor.SetVisibility(bool(self.show_fieldlines))
if self.show_fieldlines and self._traj_cache is not None:
self.field_points_np[:] = self._traj_cache.reshape((-1, 3))
self.field_points_vtk.Modified()
self.field_poly.Modified()
self._update_text()
self.window.Render()
t2 = time.perf_counter()
print(
"update: solve+trace {:.3f}s, total {:.3f}s (scalar={}, active={})".format(
t1 - t0, t2 - t0, self.scalar_name, int(np.sum(self.active))
)
)
def _select_next(self) -> None:
active_idx = np.flatnonzero(self.active > 0.0)
if active_idx.size == 0:
self.selected = None
return
if self.selected is None or self.selected not in active_idx:
self.selected = int(active_idx[0])
return
k = int(np.where(active_idx == self.selected)[0][0])
self.selected = int(active_idx[(k + 1) % active_idx.size])
def _delete_selected(self) -> None:
if self.selected is None:
return
i = self.selected
self.active[i] = 0.0
self.currents_raw[i] = 0.0
self.theta_src[i] = 0.0
self.phi_src[i] = 0.0
self._select_next()
self.update_solution()
def _add_electrode(self, theta: float, phi: float, current: float) -> None:
free = np.flatnonzero(self.active <= 0.0)
if free.size == 0:
print("No free electrode slots; increase n_electrodes_max.")
return
i = int(free[0])
self.theta_src[i] = float(theta)
self.phi_src[i] = float(phi)
self.currents_raw[i] = float(current)
self.active[i] = 1.0
self.selected = i
self.I_slider_rep.SetValue(float(self.currents_raw[i]))
self.update_solution()
def _begin_edit(self, mode: Literal["V_cut", "current"]) -> None:
self._edit_mode = mode
if mode == "V_cut":
self._edit_buffer = f"{self.V_cut:.6g}"
else:
if self.selected is None:
self._edit_mode = "none"
self._edit_buffer = ""
return
self._edit_buffer = f"{float(self.currents_raw[self.selected]):.6g}"
self._update_text()
self.window.Render()
def _handle_edit_key(self, key_sym: str, key_code: str) -> bool:
if self._edit_mode == "none":
return False
if key_sym in ("Escape",):
self._edit_mode = "none"
self._edit_buffer = ""
self._update_text()
self.window.Render()
return True
if key_sym in ("Return", "KP_Enter"):
try:
val = float(self._edit_buffer.strip())
except Exception:
print(f"Could not parse number: {self._edit_buffer!r}")
return True
if self._edit_mode == "V_cut":
self.V_cut = float(val)
self.V_slider_rep.SetValue(float(self.V_cut))
self._edit_mode = "none"
self._edit_buffer = ""
self.update_solution()
return True
# current
if self.selected is not None:
self.currents_raw[self.selected] = float(val)
self.I_slider_rep.SetValue(float(self.currents_raw[self.selected]))
self._edit_mode = "none"
self._edit_buffer = ""
self.update_solution()
return True
if key_sym in ("BackSpace", "Delete"):
self._edit_buffer = self._edit_buffer[:-1]
self._update_text()
self.window.Render()
return True
if key_code and key_code in "0123456789+-eE.":
self._edit_buffer = self._edit_buffer + key_code
self._update_text()
self.window.Render()
return True
return True
def _on_keypress(self, _obj, _evt) -> None:
key_sym = self.interactor.GetKeySym()
key_code = self.interactor.GetKeyCode()
if self._handle_edit_key(key_sym, key_code):
return
if key_sym in ("v", "V"):
self._begin_edit("V_cut")
return
if key_sym in ("i", "I"):
self._begin_edit("current")
return
if key_sym in ("a", "A"):
self.mode = "add_source"
elif key_sym in ("z", "Z"):
self.mode = "add_sink"
elif key_sym in ("m", "M"):
self.mode = "move"
elif key_sym in ("d", "Delete", "BackSpace"):
self._delete_selected()
return
elif key_sym in ("Tab",):
self._select_next()
self._update_electrode_actors()
elif key_sym in ("c", "C"):
order: list[ScalarName] = ["|K|", "V", "s", "K_theta", "K_phi"]
k = order.index(self.scalar_name)
self.scalar_name = order[(k + 1) % len(order)]
self._apply_surface_scalar()
elif key_sym in ("f", "F"):
self.show_fieldlines = not self.show_fieldlines
# Need to recompute when turning on.
if self.show_fieldlines:
self.update_solution()
return
self.field_actor.SetVisibility(False)
elif key_sym in ("b", "B"):
self.include_bg_tor = not self.include_bg_tor
print(
f"External toroidal field (1/R): {'ON' if self.include_bg_tor else 'OFF'} "
f"(Bext0={self.Bext0:.3g} T at R0={self.cfg.R0:g} m)"
)
if self.show_fieldlines:
self.update_solution()
return
elif key_sym in ("p", "P"):
self.include_bg_pol = not self.include_bg_pol
print(
f"External poloidal field (tokamak-like 1/R): {'ON' if self.include_bg_pol else 'OFF'} "
f"(Bpol0={self.Bpol0:.3g} T at R0={self.cfg.R0:g} m)"
)
if self.show_fieldlines:
self.update_solution()
return
elif key_sym in ("bracketleft",):
self.Bext0 = float(self.Bext0 / 1.2)
if self.show_fieldlines and self.include_bg_tor:
self.update_solution()
return
elif key_sym in ("bracketright",):
self.Bext0 = float(self.Bext0 * 1.2)
if self.show_fieldlines and self.include_bg_tor:
self.update_solution()
return
elif key_sym in ("comma",):
self.Bpol0 = float(self.Bpol0 / 1.2)
if self.show_fieldlines and self.include_bg_pol:
self.update_solution()
return
elif key_sym in ("period",):
self.Bpol0 = float(self.Bpol0 * 1.2)
if self.show_fieldlines and self.include_bg_pol:
self.update_solution()
return
elif key_sym in ("r", "R"):
self.update_solution()
return
elif key_sym in ("e", "E"):
self._export_paraview()
elif key_sym in ("s", "S"):
self._save_screenshot()
self._update_text()
self.window.Render()
def _on_left_click(self, _obj, _evt) -> None:
x, y = self.interactor.GetEventPosition()
# 1) Try selecting an electrode.
self.prop_picker.Pick(x, y, 0, self.renderer)
actor = self.prop_picker.GetActor()
if actor in self._electrode_actor_to_index:
self.selected = int(self._electrode_actor_to_index[actor])
self.I_slider_rep.SetValue(float(self.currents_raw[self.selected]))
self._update_electrode_actors()
self._update_text()
self.window.Render()
return
# 2) Add/move electrode by picking torus surface.
if self.mode in ("add_source", "add_sink", "move"):
if not self.cell_picker.Pick(x, y, 0, self.renderer):
self.mode = "none"
self._update_text()
self.window.Render()
return
p = np.array(self.cell_picker.GetPickPosition(), dtype=float)
theta, phi = torus_angles_from_point(self.cfg.R0, p)
if self.mode == "move":
if self.selected is not None and self.active[self.selected] > 0:
self.theta_src[self.selected] = theta
self.phi_src[self.selected] = phi
self.mode = "none"
self.update_solution()
else:
self.mode = "none"
elif self.mode == "add_source":
self.mode = "none"
self._add_electrode(theta, phi, +self.cfg.current_default_A)
elif self.mode == "add_sink":
self.mode = "none"
self._add_electrode(theta, phi, -self.cfg.current_default_A)
return
# Fall back to default camera interaction.
self.interactor.GetInteractorStyle().OnLeftButtonDown()
def _save_screenshot(self) -> None:
outdir = Path("figures/gui_screenshots")
outdir.mkdir(parents=True, exist_ok=True)
ts = time.strftime("%Y%m%d_%H%M%S")
path = outdir / f"torus_cut_gui_{ts}.png"
w2i = vtk.vtkWindowToImageFilter()
w2i.SetInput(self.window)
w2i.Update()
writer = vtk.vtkPNGWriter()
writer.SetFileName(str(path))
writer.SetInputConnection(w2i.GetOutputPort())
writer.Write()
print(f"Saved screenshot: {path}")
def _export_paraview(self) -> None:
from .paraview import fieldlines_to_vtu, point_cloud_to_vtu, torus_surface_to_vtu, write_vtm, write_vtu
ts = time.strftime("%Y%m%d_%H%M%S")
outdir = Path("paraview") / f"gui_torus_cut_{ts}"
outdir.mkdir(parents=True, exist_ok=True)
V = self._cache.get("V")
s = self._cache.get("s")
Ktheta = self._cache.get("Ktheta")
Kphi = self._cache.get("Kphi")
Kmag = self._cache.get("Kmag")
if V is None or s is None or Ktheta is None or Kphi is None or Kmag is None:
print("ParaView export: no cached solution yet.")
return
e_theta = np.asarray(self._e_theta, dtype=float)
e_phi = np.asarray(self._e_phi, dtype=float)
K_vec = Ktheta[..., None] * e_theta + Kphi[..., None] * e_phi
surf = write_vtu(
outdir / "winding_surface.vtu",
torus_surface_to_vtu(
surface=self.surface,
point_data={
"V": V.reshape(-1),
"s": s.reshape(-1),
"K": K_vec.reshape(-1, 3),
"Ktheta": Ktheta.reshape(-1),
"Kphi": Kphi.reshape(-1),
"|K|": Kmag.reshape(-1),
"V_cut": np.full((V.size,), float(self.V_cut), dtype=float),
},
),
)
blocks: dict[str, str] = {"winding_surface": surf.name}
active = np.flatnonzero(self.active > 0.0)
if active.size > 0:
xyz = torus_xyz(self.cfg.R0, self.cfg.a, self.theta_src[active], self.phi_src[active])
I = (
np.asarray(self._Iproj_cache, dtype=float)[active]
if self._Iproj_cache is not None
else np.asarray(self.currents_raw, dtype=float)[active]
)
elec = write_vtu(
outdir / "electrodes.vtu",
point_cloud_to_vtu(
points=np.asarray(xyz, dtype=float),
point_data={"I_A": I, "sign_I": np.sign(I)},
),
)
blocks["electrodes"] = elec.name
# Cut ring (where the potential jump is placed), exported as a point cloud for reference.
theta0 = float(self.cfg.theta_cut) % (2 * np.pi)
phi_line = np.linspace(0.0, 2 * np.pi, 80, endpoint=False)
theta_line = theta0 * np.ones_like(phi_line)
cut_xyz = torus_xyz(self.cfg.R0, self.cfg.a, theta_line, phi_line)
cut = write_vtu(outdir / "cut_ring.vtu", point_cloud_to_vtu(points=cut_xyz))
blocks["cut_ring"] = cut.name
if self.show_fieldlines and self._traj_cache is not None:
traj_pv = np.transpose(self._traj_cache, (1, 0, 2))
fl = write_vtu(outdir / "fieldlines.vtu", fieldlines_to_vtu(traj=traj_pv))
blocks["fieldlines"] = fl.name
scene = write_vtm(outdir / "scene.vtm", blocks)
print(f"Saved ParaView scene: {scene}")
[docs]
def run(self) -> None:
print("Starting cut+electrodes GUI. Close the window to exit.")
self._update_text()
self.window.Render()
self.interactor.Initialize()
self.interactor.Start()
[docs]
def run_torus_cut_voltage_gui(
*, cfg: CutGUIConfig = CutGUIConfig(), initial_electrodes: dict | None = None
) -> None: # pragma: no cover
"""Entry point for the cut-voltage (+electrodes) GUI examples."""
_require_vtk()
print("Interactive torus cut+electrodes GUI")
print(f" R0={cfg.R0} a={cfg.a} n_theta={cfg.n_theta} n_phi={cfg.n_phi}")
print(
f" V_cut_default={cfg.V_cut_default} V_cut_slider_max={cfg.V_cut_slider_max} sigma_s={cfg.sigma_s}"
)
print(
f" electrodes: n_max={cfg.n_electrodes_max} I0={cfg.current_default_A} Imax={cfg.current_slider_max_A} "
f"sigma_theta={cfg.sigma_theta} sigma_phi={cfg.sigma_phi} cg_tol={cfg.cg_tol} cg_maxiter={cfg.cg_maxiter}"
)
print(
f" n_lines={cfg.n_fieldlines} steps={cfg.fieldline_steps} ds={cfg.fieldline_step_size_m} eps={cfg.biot_savart_eps}"
)
print(f" mu0={float(MU0):.6e}")
app = TorusCutVoltageGUI(cfg, initial_electrodes=initial_electrodes)
app.run()
[docs]
class TorusVmecBnOptimizeGUI: # pragma: no cover
"""VTK GUI: optimize electrode sources/sinks to reduce (B·n)/norm(B) on a target VMEC surface."""
def __init__(self, cfg: VmecOptGUIConfig):
_require_vtk()
jax.config.update("jax_enable_x64", True)
self.cfg = cfg
self.surface: TorusSurface = make_torus_surface(
R0=cfg.R0, a=cfg.a, n_theta=cfg.n_theta, n_phi=cfg.n_phi
)
# Target VMEC surface.
self._target_xyz_grid, self._target_points, self._target_normals, self._target_weights = (
self._build_target_surface()
)
# Background field control.
self.B0 = float(cfg.B0)
self.Bpol0 = float(cfg.Bpol0)
self.trace_include_bg_tor = bool(cfg.trace_include_bg_default_on and self.B0 != 0.0)
self.trace_include_bg_pol = bool(cfg.trace_include_bg_default_on and self.Bpol0 != 0.0)
# Scaling: I_phys = current_scale * currents_raw (then projected to net-zero over active electrodes).
self.auto_current_scale = cfg.current_scale is None
if cfg.current_scale is None:
self.current_scale = self._auto_current_scale(B0=self.B0, R0=cfg.R0)
else:
self.current_scale = float(cfg.current_scale)
# Optimization controls.
self.lr = float(cfg.lr)
self.reg_currents = float(cfg.reg_currents)
self.sigma_s = float(cfg.sigma_s)
self.steps_per_opt = int(cfg.steps_per_opt)
self.optimize_positions = bool(cfg.optimize_positions)
# Electrode state (fixed size to avoid recompilation when user adds/removes).
self.N = int(cfg.n_electrodes_max)
self.theta_src = np.zeros((self.N,), dtype=float)
self.phi_src = np.zeros((self.N,), dtype=float)
self.currents_raw = np.zeros((self.N,), dtype=float) # dimensionless
self.active = np.zeros((self.N,), dtype=float)
# Initialize a random set of active electrodes.
n0 = int(min(cfg.n_electrodes_init, cfg.n_electrodes_max))
rng = np.random.default_rng(0)
self.theta_src[:n0] = rng.uniform(0.0, 2 * np.pi, size=(n0,))
self.phi_src[:n0] = rng.uniform(0.0, 2 * np.pi, size=(n0,))
self.currents_raw[:n0] = float(cfg.init_current_raw_rms) * rng.standard_normal(size=(n0,))
self.active[:n0] = 1.0
self.selected: int | None = int(np.argmax(self.active)) if np.any(self.active) else None
self.mode: Literal["none", "add_source", "add_sink", "move"] = "none"
# Precompute unit vectors for K decomposition.
self._e_theta = self.surface.r_theta / self.surface.a
self._e_phi = self.surface.r_phi / jnp.sqrt(self.surface.G)[..., None]
# Field line seeds (inside the torus).
theta_seed = jnp.linspace(0.0, 2 * jnp.pi, cfg.n_fieldlines, endpoint=False)
rho = 0.5 * cfg.a
R = cfg.R0 + rho * jnp.cos(theta_seed)
Z = rho * jnp.sin(theta_seed)
self._seeds = jnp.stack([R, jnp.zeros_like(R), Z], axis=-1)
self.scalar_name: ScalarName = "|K|"
self.show_fieldlines = True
# Cached computed state (numpy).
self._cache: dict[str, np.ndarray] = {}
self._traj_cache: np.ndarray | None = None
self._Iproj_cache: np.ndarray | None = None
self._target_Bn_over_B_cache: np.ndarray | None = None
self._metrics_cache: dict[str, float] = {}
# Text entry ("textbox") state.
self._edit_mode: Literal[
"none",
"current",
"B0",
"Bpol0",
"current_scale",
"lr",
"steps_per_opt",
"reg_currents",
"sigma_s",
] = "none"
self._edit_buffer: str = ""
# Optimizer state.
self._opt = optax.adam(self.lr)
self._opt_state = self._opt.init(self._params_jax())
self._build_scene()
# Compile compute + optimization functions once.
self._compute_jit = jax.jit(self._compute_state)
self._opt_step_jit = jax.jit(self._opt_step)
print("Compiling JAX pipelines (first update may take a moment)...")
self.update_solution()
@staticmethod
def _auto_current_scale(*, B0: float, R0: float) -> float:
mu0 = float(4e-7 * np.pi)
if float(B0) == 0.0:
return 1.0
return float(2 * np.pi * R0 * abs(B0) / mu0)
def _build_target_surface(self) -> tuple[np.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
cfg = self.cfg
vmec_path = _resolve_vmec_input_path(cfg.vmec_input)
boundary = read_vmec_boundary(str(vmec_path))
print("VMEC target surface:")
print(f" file={vmec_path} NFP={boundary.nfp} nmodes={boundary.m.size}")
target = vmec_target_surface(
boundary,
torus_R0=float(cfg.R0),
torus_a=float(cfg.a),
fit_margin=float(cfg.fit_margin),
n_theta=int(cfg.surf_n_theta),
n_phi=int(cfg.surf_n_phi),
dtype=jnp.float64,
)
print(" fit into circular torus:")
print(f" torus: R0={cfg.R0} a={cfg.a} fit_margin={cfg.fit_margin}")
print(
" shift_R={:+.6e} m scale={:.6e} rho_max(before)={:.6e} m".format(
float(target.fit.shift_R), float(target.fit.scale_rho), float(target.fit.rho_max_before_m)
)
)
xyz = target.xyz
n_hat = target.normals
w_area = target.weights
xyz_grid_np = np.asarray(xyz, dtype=float)
points = xyz.reshape((-1, 3))
normals = n_hat.reshape((-1, 3))
weights = w_area.reshape((-1,))
return xyz_grid_np, points, normals, weights
def _params_jax(self) -> SourceParams:
return SourceParams(
theta_src=jnp.asarray(self.theta_src, dtype=jnp.float64),
phi_src=jnp.asarray(self.phi_src, dtype=jnp.float64),
currents_raw=jnp.asarray(self.currents_raw, dtype=jnp.float64),
)
def _compute_state(
self,
theta_src: jnp.ndarray,
phi_src: jnp.ndarray,
currents_raw: jnp.ndarray,
active: jnp.ndarray,
B0: jnp.ndarray,
Bpol0: jnp.ndarray,
current_scale: jnp.ndarray,
sigma_s: jnp.ndarray,
reg_currents: jnp.ndarray,
trace_include_bg_tor: jnp.ndarray,
trace_include_bg_pol: jnp.ndarray,
compute_traj: jnp.ndarray,
):
# Project currents to net-zero over active electrodes (in raw units).
mask = active
Iraw = currents_raw * mask
n = jnp.sum(mask)
mean = jnp.where(n > 0, jnp.sum(Iraw) / n, 0.0)
Iraw = Iraw - mean * mask
I = current_scale * Iraw # physical currents [A], projected
s = deposit_current_sources(
self.surface,
theta_src=theta_src,
phi_src=phi_src,
currents=I,
sigma_theta=self.cfg.sigma_theta,
sigma_phi=self.cfg.sigma_phi,
)
V, _ = solve_current_potential(
self.surface,
s / (sigma_s + 1e-30),
tol=self.cfg.cg_tol,
maxiter=self.cfg.cg_maxiter,
use_preconditioner=self.cfg.use_preconditioner,
)
K = surface_current_from_potential(self.surface, V, sigma_s=sigma_s)
Kmag = jnp.linalg.norm(K, axis=-1)
Ktheta = jnp.sum(K * self._e_theta, axis=-1)
Kphi = jnp.sum(K * self._e_phi, axis=-1)
B_shell = biot_savart_surface(self.surface, K, self._target_points, eps=self.cfg.biot_savart_eps)
B_bg = tokamak_like_field(self._target_points, B_tor0=B0, B_pol0=Bpol0, R0=self.cfg.R0)
B_tot = B_bg + B_shell
Bn = jnp.sum(B_tot * self._target_normals, axis=-1)
Bmag = jnp.linalg.norm(B_tot, axis=-1)
Bn_over_B = Bn / (Bmag + 1e-30)
w = self._target_weights
wsum = jnp.sum(w)
# Use a weighted p-norm objective to more strongly penalize localized peaks in Bn/B.
p = float(self.cfg.bn_p)
abs_bn = jnp.abs(Bn_over_B)
mean_p = jnp.sum(w * (abs_bn**p)) / (wsum + 1e-30)
loss_bn = mean_p ** (2.0 / p)
loss_reg = reg_currents * jnp.mean(Iraw * Iraw)
loss = loss_bn + loss_reg
Bn_over_B_rms = jnp.sqrt(jnp.sum(w * (Bn_over_B * Bn_over_B)) / (wsum + 1e-30))
Bn_over_B_max = jnp.max(jnp.abs(Bn_over_B))
I_rms = jnp.sqrt(jnp.mean(I * I))
def B_fn(xyz: jnp.ndarray) -> jnp.ndarray:
B = biot_savart_surface(self.surface, K, xyz, eps=self.cfg.biot_savart_eps)
Btor = jnp.asarray(trace_include_bg_tor, dtype=B.dtype) * jnp.asarray(B0, dtype=B.dtype)
Bpol = jnp.asarray(trace_include_bg_pol, dtype=B.dtype) * jnp.asarray(Bpol0, dtype=B.dtype)
B = B + tokamak_like_field(xyz, B_tor0=Btor, B_pol0=Bpol, R0=self.cfg.R0)
return B
n_steps = self.cfg.fieldline_steps
n_lines = self.cfg.n_fieldlines
def do_trace(_):
return trace_field_lines_batch(
B_fn,
self._seeds,
step_size=self.cfg.fieldline_step_size_m,
n_steps=n_steps,
normalize=True,
)
def no_trace(_):
return jnp.zeros((n_steps + 1, n_lines, 3), dtype=jnp.float64)
traj = jax.lax.cond(compute_traj, do_trace, no_trace, operand=None)
return (
V,
s,
Kmag,
Ktheta,
Kphi,
Bn_over_B,
loss,
loss_bn,
loss_reg,
Bn_over_B_rms,
Bn_over_B_max,
I_rms,
traj,
I,
)
def _loss_fn(
self,
params: SourceParams,
*,
active: jnp.ndarray,
B0: jnp.ndarray,
Bpol0: jnp.ndarray,
current_scale: jnp.ndarray,
sigma_s: jnp.ndarray,
reg_currents: jnp.ndarray,
):
(
_V,
_s,
_Kmag,
_Ktheta,
_Kphi,
_Bn_over_B,
loss,
loss_bn,
loss_reg,
Bn_over_B_rms,
Bn_over_B_max,
I_rms,
_traj,
_Iproj,
) = (
self._compute_state(
params.theta_src,
params.phi_src,
params.currents_raw,
active,
B0,
Bpol0,
current_scale,
sigma_s,
reg_currents,
jnp.asarray(True),
jnp.asarray(True),
jnp.asarray(False),
)
)
aux = {
"loss": loss,
"loss_bn": loss_bn,
"loss_reg": loss_reg,
"Bn_over_B_rms": Bn_over_B_rms,
"Bn_over_B_max": Bn_over_B_max,
"I_rms_A": I_rms,
}
return loss, aux
def _opt_step(
self,
params: SourceParams,
opt_state,
*,
active: jnp.ndarray,
B0: jnp.ndarray,
Bpol0: jnp.ndarray,
current_scale: jnp.ndarray,
sigma_s: jnp.ndarray,
reg_currents: jnp.ndarray,
optimize_positions: jnp.ndarray,
):
(loss, aux), g = jax.value_and_grad(self._loss_fn, has_aux=True)(
params,
active=active,
B0=B0,
Bpol0=Bpol0,
current_scale=current_scale,
sigma_s=sigma_s,
reg_currents=reg_currents,
)
theta_g = jnp.where(optimize_positions, g.theta_src, jnp.zeros_like(g.theta_src))
phi_g = jnp.where(optimize_positions, g.phi_src, jnp.zeros_like(g.phi_src))
g = SourceParams(theta_src=theta_g, phi_src=phi_g, currents_raw=g.currents_raw)
updates, opt_state2 = self._opt.update(g, opt_state, params)
p2 = optax.apply_updates(params, updates)
twopi = 2.0 * jnp.pi
p2 = SourceParams(
theta_src=jnp.mod(p2.theta_src, twopi),
phi_src=jnp.mod(p2.phi_src, twopi),
currents_raw=p2.currents_raw,
)
return p2, opt_state2, aux
def _build_scene(self) -> None:
cfg = self.cfg
self.renderer = vtk.vtkRenderer()
self.renderer.SetBackground(1.0, 1.0, 1.0)
self.window = vtk.vtkRenderWindow()
self.window.AddRenderer(self.renderer)
self.window.SetSize(*cfg.window_size)
self.interactor = vtk.vtkRenderWindowInteractor()
self.interactor.SetRenderWindow(self.window)
style = vtk.vtkInteractorStyleTrackballCamera()
self.interactor.SetInteractorStyle(style)
# Torus surface actor (with scalars updated each solve).
xyz = np.asarray(self.surface.r)
self.torus_poly = _build_torus_polydata(xyz)
self.surface_scalars_np = np.zeros((cfg.n_theta * cfg.n_phi,), dtype=np.float32)
self.surface_scalars_vtk = numpy_to_vtk(self.surface_scalars_np, deep=False)
self.surface_scalars_vtk.SetName("scalar")
self.torus_poly.GetPointData().SetScalars(self.surface_scalars_vtk)
self.lut = vtk.vtkLookupTable()
self.lut.SetNumberOfTableValues(256)
self.lut.Build()
self.torus_mapper = vtk.vtkPolyDataMapper()
self.torus_mapper.SetInputData(self.torus_poly)
self.torus_mapper.SetLookupTable(self.lut)
self.torus_mapper.SetScalarModeToUsePointData()
self.torus_mapper.ScalarVisibilityOn()
self.torus_actor = vtk.vtkActor()
self.torus_actor.SetMapper(self.torus_mapper)
self.torus_actor.GetProperty().SetOpacity(cfg.surface_opacity)
self.torus_actor.GetProperty().SetInterpolationToPhong()
self.torus_actor.GetProperty().SetSpecular(0.2)
self.torus_actor.GetProperty().SetSpecularPower(30.0)
self.renderer.AddActor(self.torus_actor)
# Target VMEC surface actor (colored by (B·n)/|B|).
self.target_poly = _build_torus_polydata(self._target_xyz_grid)
self.target_scalars_np = np.zeros((cfg.surf_n_theta * cfg.surf_n_phi,), dtype=np.float32)
self.target_scalars_vtk = numpy_to_vtk(self.target_scalars_np, deep=False)
self.target_scalars_vtk.SetName("Bn_over_B")
self.target_poly.GetPointData().SetScalars(self.target_scalars_vtk)
self.target_mapper = vtk.vtkPolyDataMapper()
self.target_mapper.SetInputData(self.target_poly)
self.target_mapper.SetLookupTable(self.lut)
self.target_mapper.SetScalarModeToUsePointData()
self.target_mapper.ScalarVisibilityOn()
self.target_actor = vtk.vtkActor()
self.target_actor.SetMapper(self.target_mapper)
self.target_actor.GetProperty().SetOpacity(cfg.target_opacity)
self.target_actor.GetProperty().SetInterpolationToPhong()
self.target_actor.GetProperty().SetSpecular(0.1)
self.target_actor.GetProperty().SetSpecularPower(20.0)
self.renderer.AddActor(self.target_actor)
# Axis curve for reference.
axis_phi = np.linspace(0, 2 * np.pi, 200, endpoint=True)
axis = np.stack([cfg.R0 * np.cos(axis_phi), cfg.R0 * np.sin(axis_phi), 0.0 * axis_phi], axis=-1)
axis_poly = vtk.vtkPolyData()
axis_points = vtk.vtkPoints()
axis_points.SetData(numpy_to_vtk(axis, deep=True))
axis_poly.SetPoints(axis_points)
axis_lines = vtk.vtkCellArray()
pl = vtk.vtkPolyLine()
pl.GetPointIds().SetNumberOfIds(axis.shape[0])
for i in range(axis.shape[0]):
pl.GetPointIds().SetId(i, i)
axis_lines.InsertNextCell(pl)
axis_poly.SetLines(axis_lines)
axis_mapper = vtk.vtkPolyDataMapper()
axis_mapper.SetInputData(axis_poly)
axis_actor = vtk.vtkActor()
axis_actor.SetMapper(axis_mapper)
axis_actor.GetProperty().SetColor(0.0, 0.0, 0.0)
axis_actor.GetProperty().SetLineWidth(2.0)
self.renderer.AddActor(axis_actor)
# Field lines actor (updated each solve).
n_pts_line = self.cfg.fieldline_steps + 1
self.field_poly = _build_fieldlines_polydata(self.cfg.n_fieldlines, n_pts_line)
self.field_points_np = np.zeros((self.cfg.n_fieldlines * n_pts_line, 3), dtype=np.float32)
self.field_points_vtk = numpy_to_vtk(self.field_points_np, deep=False)
self.field_points_vtk.SetName("field_points")
self.field_poly.GetPoints().SetData(self.field_points_vtk)
self.field_mapper = vtk.vtkPolyDataMapper()
self.field_mapper.SetInputData(self.field_poly)
self.field_actor = vtk.vtkActor()
self.field_actor.SetMapper(self.field_mapper)
self.field_actor.GetProperty().SetColor(0.1, 0.25, 0.9)
self.field_actor.GetProperty().SetLineWidth(2.0)
self.renderer.AddActor(self.field_actor)
# Electrode actors (spheres).
self.electrode_actors = []
self._electrode_actor_to_index = {}
for i in range(self.N):
src = vtk.vtkSphereSource()
src.SetThetaResolution(16)
src.SetPhiResolution(16)
src.SetRadius(0.05 * self.cfg.a)
mapper = vtk.vtkPolyDataMapper()
mapper.SetInputConnection(src.GetOutputPort())
actor = vtk.vtkActor()
actor.SetMapper(mapper)
actor.SetVisibility(False)
self.renderer.AddActor(actor)
self.electrode_actors.append((src, actor))
self._electrode_actor_to_index[actor] = i
# Text overlay (help + status).
self.text = vtk.vtkTextActor()
_setup_text_actor(self.text, x=0.01, y=0.99, font_size=16, top=True)
self.renderer.AddActor2D(self.text)
# Editable numeric "textbox".
self.input_text = vtk.vtkTextActor()
_setup_text_actor(self.input_text, x=0.01, y=0.01, font_size=16, top=False)
self.renderer.AddActor2D(self.input_text)
# Current slider for the selected electrode (physical A).
rep = vtk.vtkSliderRepresentation2D()
rep.SetMinimumValue(-self.cfg.current_slider_max_A)
rep.SetMaximumValue(+self.cfg.current_slider_max_A)
rep.SetValue(0.0)
rep.SetTitleText("Selected electrode current I [A]")
rep.SetLabelFormat("%0.0f")
rep.SetSliderLength(0.02)
rep.SetSliderWidth(0.03)
rep.SetTubeWidth(0.006)
rep.SetEndCapLength(0.01)
rep.SetEndCapWidth(0.03)
rep.GetPoint1Coordinate().SetCoordinateSystemToNormalizedDisplay()
rep.GetPoint1Coordinate().SetValue(0.10, 0.06)
rep.GetPoint2Coordinate().SetCoordinateSystemToNormalizedDisplay()
rep.GetPoint2Coordinate().SetValue(0.55, 0.06)
self.slider_rep = rep
self.slider = vtk.vtkSliderWidget()
self.slider.SetInteractor(self.interactor)
self.slider.SetRepresentation(rep)
self.slider.EnabledOn()
def on_slider_end(_obj, _evt):
if self.selected is None:
return
val_A = float(self.slider_rep.GetValue())
if float(self.current_scale) == 0.0:
return
self.currents_raw[self.selected] = val_A / float(self.current_scale)
self.update_solution()
self.slider.AddObserver(vtk.vtkCommand.EndInteractionEvent, on_slider_end)
# Picking helpers.
self.cell_picker = vtk.vtkCellPicker()
self.cell_picker.SetTolerance(0.0005)
self.cell_picker.PickFromListOn()
self.cell_picker.AddPickList(self.torus_actor)
self.prop_picker = vtk.vtkPropPicker()
# Interactor events.
self.interactor.AddObserver(vtk.vtkCommand.KeyPressEvent, self._on_keypress)
self.interactor.AddObserver(vtk.vtkCommand.LeftButtonPressEvent, self._on_left_click)
# Initial camera.
self.renderer.ResetCamera()
def _help_text(self) -> str:
return (
"VMEC (B·n)/|B| optimization GUI (VTK)\n"
"Goal: drive target-surface (B·n)/|B| -> 0 by moving/setting electrode sources/sinks\n"
"Mouse: rotate/zoom as usual\n"
"Click electrode: select\n"
"Keys:\n"
" o: optimize (run N steps)\n"
" space: single optimization step\n"
" p: toggle optimize positions\n"
" a: add SOURCE (next click on torus)\n"
" z: add SINK (next click on torus)\n"
" m: move selected (next click on torus)\n"
" d: delete selected\n"
" tab: cycle selected\n"
" c: cycle torus scalar (|K|, V, s, Kθ, Kφ)\n"
" f: toggle field lines\n"
" t: toggle toroidal background in field lines\n"
" y: toggle poloidal background in field lines\n"
" [/]: decrease/increase B0 ,/. : decrease/increase Bpol0\n"
" r: recompute\n"
" e: export ParaView (.vtu/.vtm)\n"
" i (or v): type selected electrode current [A]\n"
" b: type background B0 [T]\n"
" u: type background Bpol0 [T]\n"
" k: type current_scale [A/unit]\n"
" l: type learning rate\n"
" n: type opt steps per 'o'\n"
" g: type reg_currents\n"
" x: type sigma_s\n"
" s: save screenshot\n"
)
def _status_text(self) -> str:
n_active = int(np.sum(self.active))
sel = self.selected
if sel is None:
sel_txt = "none"
else:
Iraw_A = float(self.currents_raw[sel] * float(self.current_scale))
Iproj = None
if self._Iproj_cache is not None and sel < self._Iproj_cache.size:
Iproj = float(self._Iproj_cache[sel])
if Iproj is None:
sel_txt = f"{sel} Iraw={Iraw_A:+.3e} A (active={self.active[sel]:.0f})"
else:
sel_txt = f"{sel} Iraw={Iraw_A:+.3e} A Iproj={Iproj:+.3e} A (active={self.active[sel]:.0f})"
m = self._metrics_cache
metric_lines = ""
if m:
metric_lines = (
f"loss={m.get('loss', np.nan):.3e} bn_obj={m.get('loss_bn', np.nan):.3e} reg={m.get('loss_reg', np.nan):.3e}\n"
f"rms(Bn/B)={m.get('Bn_over_B_rms', np.nan):.3e} max|Bn/B|={m.get('Bn_over_B_max', np.nan):.3e} I_rms={m.get('I_rms_A', np.nan):.3e} A\n"
)
auto_txt = "auto" if self.auto_current_scale else "manual"
trace_tor = "ON" if self.trace_include_bg_tor else "OFF"
trace_pol = "ON" if self.trace_include_bg_pol else "OFF"
return (
f"Active electrodes: {n_active}/{self.N} optimize_positions={self.optimize_positions}\n"
f"Selected: {sel_txt}\n"
f"Mode: {self.mode}\n"
f"Torus scalar: {self.scalar_name} Target scalar: (B·n)/|B|\n"
f"B0={self.B0:.6g} T Bpol0={self.Bpol0:.6g} T trace_tor={trace_tor} trace_pol={trace_pol}\n"
f"current_scale={self.current_scale:.3e} A/unit ({auto_txt})\n"
f"lr={self.lr:.3e} steps_per_opt={self.steps_per_opt} bn_p={int(self.cfg.bn_p)} reg_currents={self.reg_currents:.3e} sigma_s={self.sigma_s:.3e}\n"
+ metric_lines
)
def _update_text(self) -> None:
self.text.SetInput(self._help_text() + "\n" + self._status_text())
if self._edit_mode == "none":
self.input_text.SetInput(
"Type: i/v=current [A], b=B0, u=Bpol0, k=current_scale, l=lr, n=steps, g=reg, x=sigma_s"
)
else:
self.input_text.SetInput(
f"Input [{self._edit_mode}]: {self._edit_buffer} (Enter=apply, Esc=cancel)"
)
def _apply_surface_scalar(self) -> None:
if not self._cache:
return
if self.scalar_name == "|K|":
scal = self._cache["Kmag"].reshape((-1,))
elif self.scalar_name == "V":
scal = self._cache["V"].reshape((-1,))
elif self.scalar_name == "s":
scal = self._cache["s"].reshape((-1,))
elif self.scalar_name == "K_theta":
scal = self._cache["Ktheta"].reshape((-1,))
elif self.scalar_name == "K_phi":
scal = self._cache["Kphi"].reshape((-1,))
else:
raise ValueError(self.scalar_name)
self.surface_scalars_np[:] = scal.astype(np.float32, copy=False)
self.surface_scalars_vtk.Modified()
self.torus_poly.Modified()
smin = float(np.nanmin(scal))
smax = float(np.nanmax(scal))
if not np.isfinite(smin) or not np.isfinite(smax) or smin == smax:
smin, smax = 0.0, 1.0
if self.scalar_name in ("s", "K_theta", "K_phi", "V"):
vmax = max(abs(smin), abs(smax))
smin, smax = -vmax, vmax
self.torus_mapper.SetScalarRange(smin, smax)
def _apply_target_scalar(self) -> None:
if self._target_Bn_over_B_cache is None:
return
scal = self._target_Bn_over_B_cache.reshape((-1,))
self.target_scalars_np[:] = scal.astype(np.float32, copy=False)
self.target_scalars_vtk.Modified()
self.target_poly.Modified()
smin = float(np.nanmin(scal))
smax = float(np.nanmax(scal))
if not np.isfinite(smin) or not np.isfinite(smax) or smin == smax:
smin, smax = 0.0, 1.0
vmax = max(abs(smin), abs(smax))
self.target_mapper.SetScalarRange(-vmax, vmax)
def _update_electrode_actors(self) -> None:
Iproj = self._Iproj_cache
if Iproj is None:
Iproj = np.zeros_like(self.currents_raw)
for i in range(self.N):
src, actor = self.electrode_actors[i]
if self.active[i] <= 0.0:
actor.SetVisibility(False)
continue
actor.SetVisibility(True)
p = torus_xyz(self.cfg.R0, self.cfg.a, self.theta_src[i], self.phi_src[i])
src.SetCenter(float(p[0]), float(p[1]), float(p[2]))
I = float(Iproj[i])
if i == self.selected:
actor.GetProperty().SetColor(1.0, 0.8, 0.2)
else:
if I > 0:
actor.GetProperty().SetColor(0.85, 0.15, 0.15)
elif I < 0:
actor.GetProperty().SetColor(0.15, 0.25, 0.85)
else:
actor.GetProperty().SetColor(0.4, 0.4, 0.4)
r0 = 0.04 * self.cfg.a
r = r0 * (0.6 + 0.8 * min(abs(I) / self.cfg.current_slider_max_A, 1.0))
src.SetRadius(float(r))
src.Modified()
[docs]
def update_solution(self) -> None:
t0 = time.perf_counter()
self._update_text()
self.window.Render()
th = jnp.asarray(self.theta_src, dtype=jnp.float64)
ph = jnp.asarray(self.phi_src, dtype=jnp.float64)
Iraw = jnp.asarray(self.currents_raw, dtype=jnp.float64)
act = jnp.asarray(self.active, dtype=jnp.float64)
(
V,
s,
Kmag,
Ktheta,
Kphi,
Bn_over_B,
loss,
loss_bn,
loss_reg,
Bn_over_B_rms,
Bn_over_B_max,
I_rms,
traj,
Iproj,
) = self._compute_jit(
th,
ph,
Iraw,
act,
jnp.asarray(self.B0, dtype=jnp.float64),
jnp.asarray(self.Bpol0, dtype=jnp.float64),
jnp.asarray(self.current_scale, dtype=jnp.float64),
jnp.asarray(self.sigma_s, dtype=jnp.float64),
jnp.asarray(self.reg_currents, dtype=jnp.float64),
jnp.asarray(self.trace_include_bg_tor),
jnp.asarray(self.trace_include_bg_pol),
jnp.asarray(self.show_fieldlines),
)
V.block_until_ready()
t1 = time.perf_counter()
self._cache = {
"V": np.asarray(V, dtype=np.float32),
"s": np.asarray(s, dtype=np.float32),
"Kmag": np.asarray(Kmag, dtype=np.float32),
"Ktheta": np.asarray(Ktheta, dtype=np.float32),
"Kphi": np.asarray(Kphi, dtype=np.float32),
}
self._traj_cache = np.asarray(traj, dtype=np.float32)
self._Iproj_cache = np.asarray(Iproj, dtype=float)
self._target_Bn_over_B_cache = np.asarray(Bn_over_B, dtype=np.float32)
self._metrics_cache = {
"loss": float(loss),
"loss_bn": float(loss_bn),
"loss_reg": float(loss_reg),
"Bn_over_B_rms": float(Bn_over_B_rms),
"Bn_over_B_max": float(Bn_over_B_max),
"I_rms_A": float(I_rms),
}
self._apply_surface_scalar()
self._apply_target_scalar()
self._update_electrode_actors()
# Update field lines geometry.
self.field_actor.SetVisibility(bool(self.show_fieldlines))
if self.show_fieldlines and self._traj_cache is not None:
self.field_points_np[:] = self._traj_cache.reshape((-1, 3))
self.field_points_vtk.Modified()
self.field_poly.Modified()
if self.selected is not None:
self.slider_rep.SetValue(float(self.currents_raw[self.selected] * float(self.current_scale)))
self._update_text()
self.window.Render()
t2 = time.perf_counter()
print(
"update: solve+trace {:.3f}s, total {:.3f}s (scalar={}, active={})".format(
t1 - t0, t2 - t0, self.scalar_name, int(np.sum(self.active))
)
)
def _select_next(self) -> None:
active_idx = np.flatnonzero(self.active > 0.0)
if active_idx.size == 0:
self.selected = None
return
if self.selected is None or self.selected not in active_idx:
self.selected = int(active_idx[0])
return
k = int(np.where(active_idx == self.selected)[0][0])
self.selected = int(active_idx[(k + 1) % active_idx.size])
def _delete_selected(self) -> None:
if self.selected is None:
return
i = self.selected
self.active[i] = 0.0
self.currents_raw[i] = 0.0
self.theta_src[i] = 0.0
self.phi_src[i] = 0.0
self._select_next()
self.update_solution()
def _add_electrode(self, theta: float, phi: float, current_A: float) -> None:
free = np.flatnonzero(self.active <= 0.0)
if free.size == 0:
print("No free electrode slots; increase n_electrodes_max.")
return
i = int(free[0])
self.theta_src[i] = float(theta)
self.phi_src[i] = float(phi)
if float(self.current_scale) != 0.0:
self.currents_raw[i] = float(current_A) / float(self.current_scale)
else:
self.currents_raw[i] = 0.0
self.active[i] = 1.0
self.selected = i
self.slider_rep.SetValue(float(current_A))
self.update_solution()
def _begin_edit(self, mode: str) -> None:
self._edit_mode = mode # type: ignore[assignment]
if mode == "current":
if self.selected is None:
self._edit_mode = "none"
self._edit_buffer = ""
return
I_A = float(self.currents_raw[self.selected] * float(self.current_scale))
self._edit_buffer = f"{I_A:.6g}"
elif mode == "B0":
self._edit_buffer = f"{self.B0:.6g}"
elif mode == "current_scale":
self._edit_buffer = f"{self.current_scale:.6g}"
elif mode == "lr":
self._edit_buffer = f"{self.lr:.6g}"
elif mode == "steps_per_opt":
self._edit_buffer = f"{self.steps_per_opt:d}"
elif mode == "reg_currents":
self._edit_buffer = f"{self.reg_currents:.6g}"
elif mode == "sigma_s":
self._edit_buffer = f"{self.sigma_s:.6g}"
else:
self._edit_mode = "none"
self._edit_buffer = ""
return
self._update_text()
self.window.Render()
def _handle_edit_key(self, key_sym: str, key_code: str) -> bool:
if self._edit_mode == "none":
return False
if key_sym in ("Escape",):
self._edit_mode = "none"
self._edit_buffer = ""
self._update_text()
self.window.Render()
return True
if key_sym in ("Return", "KP_Enter"):
mode = self._edit_mode
txt = self._edit_buffer.strip()
try:
if mode == "steps_per_opt":
val = int(float(txt))
else:
val = float(txt)
except Exception:
print(f"Could not parse number: {self._edit_buffer!r}")
return True
if mode == "current" and self.selected is not None:
if float(self.current_scale) != 0.0:
self.currents_raw[self.selected] = float(val) / float(self.current_scale)
self.slider_rep.SetValue(float(val))
elif mode == "B0":
self.B0 = float(val)
if self.auto_current_scale:
self.current_scale = self._auto_current_scale(B0=self.B0, R0=self.cfg.R0)
elif mode == "Bpol0":
self.Bpol0 = float(val)
elif mode == "current_scale":
self.current_scale = float(val)
self.auto_current_scale = False
elif mode == "lr":
self.lr = float(val)
self._opt = optax.adam(self.lr)
self._opt_state = self._opt.init(self._params_jax())
self._opt_step_jit = jax.jit(self._opt_step)
elif mode == "steps_per_opt":
self.steps_per_opt = max(1, int(val))
elif mode == "reg_currents":
self.reg_currents = float(val)
elif mode == "sigma_s":
self.sigma_s = float(val)
self._edit_mode = "none"
self._edit_buffer = ""
self.update_solution()
return True
if key_sym in ("BackSpace", "Delete"):
self._edit_buffer = self._edit_buffer[:-1]
self._update_text()
self.window.Render()
return True
if key_code and key_code in "0123456789+-eE.":
self._edit_buffer = self._edit_buffer + key_code
self._update_text()
self.window.Render()
return True
return True
def _optimize(self, n_steps: int) -> None:
if n_steps <= 0:
return
params = self._params_jax()
act = jnp.asarray(self.active, dtype=jnp.float64)
B0 = jnp.asarray(self.B0, dtype=jnp.float64)
Bpol0 = jnp.asarray(self.Bpol0, dtype=jnp.float64)
current_scale = jnp.asarray(self.current_scale, dtype=jnp.float64)
sigma_s = jnp.asarray(self.sigma_s, dtype=jnp.float64)
reg = jnp.asarray(self.reg_currents, dtype=jnp.float64)
opt_pos = jnp.asarray(self.optimize_positions)
t0 = time.perf_counter()
aux_last = None
for _k in range(int(n_steps)):
params, self._opt_state, aux = self._opt_step_jit(
params,
self._opt_state,
active=act,
B0=B0,
Bpol0=Bpol0,
current_scale=current_scale,
sigma_s=sigma_s,
reg_currents=reg,
optimize_positions=opt_pos,
)
aux_last = aux
params.theta_src.block_until_ready()
t1 = time.perf_counter()
self.theta_src[:] = np.asarray(params.theta_src)
self.phi_src[:] = np.asarray(params.phi_src)
self.currents_raw[:] = np.asarray(params.currents_raw)
if aux_last is not None:
print(
"optimize: steps={} loss={:.3e} rms(Bn/B)={:.3e} max|Bn/B|={:.3e} I_rms={:.3e}A wall={:.3f}s".format(
int(n_steps),
float(aux_last["loss"]),
float(aux_last["Bn_over_B_rms"]),
float(aux_last["Bn_over_B_max"]),
float(aux_last["I_rms_A"]),
t1 - t0,
)
)
self.update_solution()
def _on_keypress(self, _obj, _evt) -> None:
key_sym = self.interactor.GetKeySym()
key_code = self.interactor.GetKeyCode()
if self._handle_edit_key(key_sym, key_code):
return
if key_sym in ("i", "I", "v", "V"):
self._begin_edit("current")
return
if key_sym in ("b", "B"):
self._begin_edit("B0")
return
if key_sym in ("u", "U"):
self._begin_edit("Bpol0")
return
if key_sym in ("k", "K"):
self._begin_edit("current_scale")
return
if key_sym in ("l", "L"):
self._begin_edit("lr")
return
if key_sym in ("n", "N"):
self._begin_edit("steps_per_opt")
return
if key_sym in ("g", "G"):
self._begin_edit("reg_currents")
return
if key_sym in ("x", "X"):
self._begin_edit("sigma_s")
return
if key_sym in ("space",):
self._optimize(1)
return
if key_sym in ("o", "O"):
self._optimize(self.steps_per_opt)
return
if key_sym in ("p", "P"):
self.optimize_positions = not self.optimize_positions
if key_sym in ("a", "A"):
self.mode = "add_source"
elif key_sym in ("z", "Z"):
self.mode = "add_sink"
elif key_sym in ("m", "M"):
self.mode = "move"
elif key_sym in ("d", "Delete", "BackSpace"):
self._delete_selected()
return
elif key_sym in ("Tab",):
self._select_next()
self._update_electrode_actors()
elif key_sym in ("c", "C"):
order: list[ScalarName] = ["|K|", "V", "s", "K_theta", "K_phi"]
k = order.index(self.scalar_name)
self.scalar_name = order[(k + 1) % len(order)]
self._apply_surface_scalar()
elif key_sym in ("f", "F"):
self.show_fieldlines = not self.show_fieldlines
if self.show_fieldlines:
self.update_solution()
return
self.field_actor.SetVisibility(False)
elif key_sym in ("t", "T"):
self.trace_include_bg_tor = not self.trace_include_bg_tor
print(
f"Fieldline tracing background (toroidal): {'ON' if self.trace_include_bg_tor else 'OFF'} "
f"(B0={self.B0:.6g} T)"
)
if self.show_fieldlines:
self.update_solution()
return
elif key_sym in ("y", "Y"):
self.trace_include_bg_pol = not self.trace_include_bg_pol
print(
f"Fieldline tracing background (poloidal): {'ON' if self.trace_include_bg_pol else 'OFF'} "
f"(Bpol0={self.Bpol0:.6g} T)"
)
if self.show_fieldlines:
self.update_solution()
return
elif key_sym in ("bracketleft",):
self.B0 = float(self.B0 / 1.2)
if self.auto_current_scale:
self.current_scale = self._auto_current_scale(B0=self.B0, R0=self.cfg.R0)
self.update_solution()
return
elif key_sym in ("bracketright",):
self.B0 = float(self.B0 * 1.2)
if self.auto_current_scale:
self.current_scale = self._auto_current_scale(B0=self.B0, R0=self.cfg.R0)
self.update_solution()
return
elif key_sym in ("comma",):
self.Bpol0 = float(self.Bpol0 / 1.2)
self.update_solution()
return
elif key_sym in ("period",):
self.Bpol0 = float(self.Bpol0 * 1.2)
self.update_solution()
return
elif key_sym in ("r", "R"):
self.update_solution()
return
elif key_sym in ("e", "E"):
self._export_paraview()
elif key_sym in ("s", "S"):
self._save_screenshot()
self._update_text()
self.window.Render()
def _on_left_click(self, _obj, _evt) -> None:
x, y = self.interactor.GetEventPosition()
# 1) Try selecting an electrode.
self.prop_picker.Pick(x, y, 0, self.renderer)
actor = self.prop_picker.GetActor()
if actor in self._electrode_actor_to_index:
self.selected = int(self._electrode_actor_to_index[actor])
self.slider_rep.SetValue(float(self.currents_raw[self.selected] * float(self.current_scale)))
self._update_electrode_actors()
self._update_text()
self.window.Render()
return
# 2) Add/move electrode by picking the torus surface.
if self.mode in ("add_source", "add_sink", "move"):
if not self.cell_picker.Pick(x, y, 0, self.renderer):
self.mode = "none"
self._update_text()
self.window.Render()
return
p = np.array(self.cell_picker.GetPickPosition(), dtype=float)
theta, phi = torus_angles_from_point(self.cfg.R0, p)
if self.mode == "move":
if self.selected is not None and self.active[self.selected] > 0:
self.theta_src[self.selected] = theta
self.phi_src[self.selected] = phi
self.mode = "none"
self.update_solution()
else:
self.mode = "none"
elif self.mode == "add_source":
self.mode = "none"
self._add_electrode(theta, phi, +self.cfg.current_default_A)
elif self.mode == "add_sink":
self.mode = "none"
self._add_electrode(theta, phi, -self.cfg.current_default_A)
return
self.interactor.GetInteractorStyle().OnLeftButtonDown()
def _save_screenshot(self) -> None:
outdir = Path("figures/gui_screenshots")
outdir.mkdir(parents=True, exist_ok=True)
ts = time.strftime("%Y%m%d_%H%M%S")
path = outdir / f"torus_vmec_opt_gui_{ts}.png"
w2i = vtk.vtkWindowToImageFilter()
w2i.SetInput(self.window)
w2i.Update()
writer = vtk.vtkPNGWriter()
writer.SetFileName(str(path))
writer.SetInputConnection(w2i.GetOutputPort())
writer.Write()
print(f"Saved screenshot: {path}")
def _export_paraview(self) -> None:
from .paraview import fieldlines_to_vtu, point_cloud_to_vtu, torus_surface_to_vtu, write_vtm, write_vtu
ts = time.strftime("%Y%m%d_%H%M%S")
outdir = Path("paraview") / f"gui_torus_vmec_opt_{ts}"
outdir.mkdir(parents=True, exist_ok=True)
V = self._cache.get("V")
s = self._cache.get("s")
Ktheta = self._cache.get("Ktheta")
Kphi = self._cache.get("Kphi")
Kmag = self._cache.get("Kmag")
if V is None or s is None or Ktheta is None or Kphi is None or Kmag is None:
print("ParaView export: no cached winding-surface solution yet.")
return
e_theta = np.asarray(self._e_theta, dtype=float)
e_phi = np.asarray(self._e_phi, dtype=float)
K_vec = Ktheta[..., None] * e_theta + Kphi[..., None] * e_phi
surf = write_vtu(
outdir / "winding_surface.vtu",
torus_surface_to_vtu(
surface=self.surface,
point_data={
"V": V.reshape(-1),
"s": s.reshape(-1),
"K": K_vec.reshape(-1, 3),
"Ktheta": Ktheta.reshape(-1),
"Kphi": Kphi.reshape(-1),
"|K|": Kmag.reshape(-1),
},
),
)
blocks: dict[str, str] = {"winding_surface": surf.name}
# Target points (VMEC surface) with Bn/B scalar if available.
tgt_pts = np.asarray(self._target_points, dtype=float)
pd = {
"n_hat": np.asarray(self._target_normals, dtype=float),
"weight": np.asarray(self._target_weights, dtype=float),
}
if self._target_Bn_over_B_cache is not None:
pd["Bn_over_B"] = np.asarray(self._target_Bn_over_B_cache, dtype=float)
tgt = write_vtu(outdir / "target_points.vtu", point_cloud_to_vtu(points=tgt_pts, point_data=pd))
blocks["target_points"] = tgt.name
active = np.flatnonzero(self.active > 0.0)
if active.size > 0:
xyz = torus_xyz(self.cfg.R0, self.cfg.a, self.theta_src[active], self.phi_src[active])
I = (
np.asarray(self._Iproj_cache, dtype=float)[active]
if self._Iproj_cache is not None
else np.asarray(self.currents_raw, dtype=float)[active] * float(self.current_scale)
)
elec = write_vtu(
outdir / "electrodes.vtu",
point_cloud_to_vtu(
points=np.asarray(xyz, dtype=float),
point_data={"I_A": I, "sign_I": np.sign(I)},
),
)
blocks["electrodes"] = elec.name
if self.show_fieldlines and self._traj_cache is not None:
traj_pv = np.transpose(self._traj_cache, (1, 0, 2))
fl = write_vtu(outdir / "fieldlines.vtu", fieldlines_to_vtu(traj=traj_pv))
blocks["fieldlines"] = fl.name
scene = write_vtm(outdir / "scene.vtm", blocks)
print(f"Saved ParaView scene: {scene}")
[docs]
def run(self) -> None:
print("Starting VMEC (B·n)/|B| optimization GUI. Close the window to exit.")
self._update_text()
self.window.Render()
self.interactor.Initialize()
self.interactor.Start()
[docs]
def run_torus_vmec_optimize_gui(*, cfg: VmecOptGUIConfig = VmecOptGUIConfig()) -> None: # pragma: no cover
_require_vtk()
print("Interactive torus VMEC (B·n)/|B| optimization GUI")
print(f" R0={cfg.R0} a={cfg.a} n_theta={cfg.n_theta} n_phi={cfg.n_phi}")
print(f" target: vmec_input={cfg.vmec_input} surf_n_theta={cfg.surf_n_theta} surf_n_phi={cfg.surf_n_phi}")
print(f" B0={cfg.B0}T sigma_theta={cfg.sigma_theta} sigma_phi={cfg.sigma_phi} sigma_s={cfg.sigma_s}")
print(f" opt: lr={cfg.lr} reg_currents={cfg.reg_currents} steps_per_opt={cfg.steps_per_opt}")
print(f" electrodes: init={cfg.n_electrodes_init}/{cfg.n_electrodes_max} init_current_raw_rms={cfg.init_current_raw_rms}")
print(f" mu0={float(MU0):.6e}")
app = TorusVmecBnOptimizeGUI(cfg)
app.run()