Source code for torus_solver.surface_ops

from __future__ import annotations

import jax.numpy as jnp

from .spectral import spectral_derivative
from .torus import TorusSurface


def contravariant_components_torus(
    surface: TorusSurface, v_xyz: jnp.ndarray
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Project a tangential Cartesian field onto the torus coordinate basis.

    The circular-torus parameterization is orthogonal (F = 0), with metric
    coefficients E = a^2 and G = (R0 + a cosθ)^2.  Hence

        v · r_θ = E v^θ   and   v · r_φ = G v^φ,

    so each contravariant component is a dot product with the matching
    tangent vector, divided by the corresponding diagonal metric entry.

    Args:
        surface: Torus geometry providing ``r_theta``, ``r_phi`` (tangent
            vectors, Cartesian components on the last axis), the minor
            radius ``a`` and the metric entry ``G``.
        v_xyz: Cartesian vector field. Its last axis is presumably the
            xyz component axis and must broadcast against ``surface.r_theta``
            (TODO confirm the grid layout against ``TorusSurface``).

    Returns:
        Tuple ``(v^θ, v^φ)`` of contravariant components.
    """
    field = jnp.asarray(v_xyz)
    metric_e = surface.a * surface.a  # E = a^2 for the circular torus

    # Covariant components v · r_θ and v · r_φ (dot product over xyz).
    cov_theta = jnp.sum(field * surface.r_theta, axis=-1)
    cov_phi = jnp.sum(field * surface.r_phi, axis=-1)

    # Raise the index with the diagonal inverse metric (valid because F = 0).
    return cov_theta / metric_e, cov_phi / surface.G
def surface_divergence_torus(surface: TorusSurface, v_xyz: jnp.ndarray) -> jnp.ndarray:
    """Compute the surface divergence of a tangential field on a circular torus.

    Written in terms of the contravariant components (v^θ, v^φ) and the
    area element sqrt(g):

        div_s v = (1/sqrt(g)) [ ∂_θ (sqrt(g) v^θ) + ∂_φ (sqrt(g) v^φ) ].

    Both partial derivatives are evaluated with ``spectral_derivative``:
    the θ-derivative along axis 0 using ``surface.k_theta`` and the
    φ-derivative along axis 1 using ``surface.k_phi``. The field is
    presumably sampled on a (Nθ, Nφ) grid (TODO confirm against
    ``TorusSurface``).

    Args:
        surface: Torus geometry with ``sqrt_g``, ``k_theta`` and ``k_phi``,
            plus the attributes needed by ``contravariant_components_torus``.
        v_xyz: Cartesian tangential vector field on the surface grid.

    Returns:
        The scalar surface divergence on the same grid.
    """
    comp_theta, comp_phi = contravariant_components_torus(surface, v_xyz)
    sqrt_g = surface.sqrt_g

    # Densitised fluxes sqrt(g) v^θ and sqrt(g) v^φ, differentiated spectrally.
    d_theta = spectral_derivative(sqrt_g * comp_theta, surface.k_theta, axis=0)
    d_phi = spectral_derivative(sqrt_g * comp_phi, surface.k_phi, axis=1)

    # The 1e-30 offset guards against dividing by an exactly-zero sqrt(g);
    # it is negligible at ordinary magnitudes of sqrt(g).
    return (d_theta + d_phi) / (sqrt_g + 1e-30)