Source code for torus_solver.biot_savart
from __future__ import annotations
import jax
import jax.numpy as jnp
from .torus import TorusSurface
MU0 = 4e-7 * jnp.pi
[docs]
def biot_savart_surface(
surface: TorusSurface,
K: jnp.ndarray,
eval_points: jnp.ndarray,
*,
mu0: float = MU0,
eps: float = 1e-9,
chunk_size: int | None = 256,
) -> jnp.ndarray:
"""Magnetic field B (Tesla) from a surface current density K (A/m).
In continuous form:
B(x) = μ0/(4π) ∬ K(r') × (x-r') / norm(x-r')^3 dA'
Notes:
- `eps` is a small softening length to avoid numerical issues when points are
extremely close to the surface.
- This routine is JAX-differentiable w.r.t. `K` and `eval_points`.
"""
eval_points = jnp.asarray(eval_points)
r_surf = surface.r.reshape((-1, 3))
K_surf = jnp.asarray(K).reshape((-1, 3))
dA = surface.area_weights.reshape((-1,))
pref = mu0 / (4 * jnp.pi)
eps2 = eps * eps
def field_chunk(P: jnp.ndarray) -> jnp.ndarray:
# P: (N,3)
R = P[:, None, :] - r_surf[None, :, :] # (N,M,3)
r2 = jnp.sum(R * R, axis=-1) + eps2 # (N,M)
inv_r3 = r2 ** (-1.5)
dB = jnp.cross(K_surf[None, :, :], R) * inv_r3[..., None] * dA[None, :, None]
return pref * jnp.sum(dB, axis=1) # (N,3)
# For small point counts, the vmap path is fine; for larger problems, chunked
# evaluation avoids allocating an (N_eval, N_surface, 3) tensor that can become
# large in double precision.
n_eval = eval_points.shape[0]
if chunk_size is None or n_eval <= chunk_size:
return field_chunk(eval_points)
n_chunks = (n_eval + chunk_size - 1) // chunk_size
pad = n_chunks * chunk_size - n_eval
eval_pad = jnp.pad(eval_points, ((0, pad), (0, 0)))
eval_chunks = eval_pad.reshape((n_chunks, chunk_size, 3))
def scan_step(_, pts):
return None, field_chunk(pts)
_, B_chunks = jax.lax.scan(scan_step, None, eval_chunks)
B_pad = B_chunks.reshape((n_chunks * chunk_size, 3))
return B_pad[:n_eval, :]