Source code for torus_solver.spectral

from __future__ import annotations

import jax.numpy as jnp


def make_wavenumbers(n: int, period: float = 2 * jnp.pi) -> jnp.ndarray:
    """Return angular Fourier wavenumbers for a periodic grid of length `period`.

    The grid has `n` equally spaced points with spacing ``period / n``. The
    result is ``2π * fftfreq(n, d=period / n)``: angular wavenumbers in
    radians per unit length (the length unit of `period`), in standard FFT
    order (zero, then positive, then negative frequencies). With the default
    ``period = 2π`` the entries are the integer mode numbers.

    Args:
        n: Number of grid points; must be at least 1.
        period: Length of the periodic domain.

    Returns:
        A 1-D array of length `n`, ordered to match the output of
        ``jnp.fft.fft`` along the corresponding axis.

    Raises:
        ValueError: If ``n < 1``.
    """
    if n < 1:
        # Guard before dividing by n: n == 0 would otherwise surface as a bare
        # ZeroDivisionError, and a negative n would reach fftfreq with a
        # meaningless size.
        raise ValueError(f"n must be a positive integer, got {n!r}")
    dx = period / n
    return (2 * jnp.pi) * jnp.fft.fftfreq(n, d=dx)
def spectral_derivative(f: jnp.ndarray, k: jnp.ndarray, *, axis: int) -> jnp.ndarray:
    """Compute ∂f/∂x along one periodic axis of `f` using FFTs.

    Each Fourier coefficient along `axis` is multiplied by ``i * k`` and the
    product is transformed back. When the number of points along `axis` is
    even, the Nyquist coefficient (index ``n // 2``) gets a zero multiplier:
    that mode cannot distinguish +n/2 from -n/2 on the grid, so its first
    derivative is taken to be zero. For real `f` this is, in exact arithmetic,
    the same as discarding the imaginary part of the inverse FFT, so real
    results are unchanged. For complex `f` it keeps the derivative well defined.

    Args:
        f: Samples on a uniform periodic grid.
        k: Wavenumbers in FFT order with exactly one entry per point along
            `axis` (e.g. as produced by ``make_wavenumbers``).
        axis: Axis of `f` to differentiate along; negative values are allowed.

    Returns:
        An array with the same shape as `f`. It is real when `f` is real and
        complex when `f` is complex.

    Raises:
        ValueError: If `k` does not have exactly one entry per point along
            `axis`.
    """
    n = f.shape[axis]
    k_vec = jnp.ravel(jnp.asarray(k))
    if k_vec.shape[0] != n:
        # Without this check a length-1 k would broadcast silently and give a
        # wrong derivative instead of an error.
        raise ValueError(
            f"k has {k_vec.shape[0]} entries but f has {n} points along axis {axis}"
        )

    multiplier = 1j * k_vec
    if n % 2 == 0:
        multiplier = multiplier.at[n // 2].set(0)

    # Reshape the 1-D multiplier so it broadcasts along `axis` only.
    shape = [1] * f.ndim
    shape[axis] = n
    F = jnp.fft.fft(f, axis=axis)
    df = jnp.fft.ifft(multiplier.reshape(shape) * F, axis=axis)
    # For real input the imaginary part is round-off only; drop it.
    return df if jnp.iscomplexobj(f) else df.real
def spectral_second_derivative(f: jnp.ndarray, k: jnp.ndarray, *, axis: int) -> jnp.ndarray:
    """Compute ∂²f/∂x² along one periodic axis of `f` using FFTs.

    Each Fourier coefficient along `axis` is multiplied by ``-k**2`` and the
    product is transformed back. Unlike the first derivative, the Nyquist
    coefficient (for an even number of points) is kept. ``k**2`` does not
    depend on the sign ambiguity of that mode, so its second derivative is
    well defined.

    Args:
        f: Samples on a uniform periodic grid.
        k: Wavenumbers in FFT order with exactly one entry per point along
            `axis` (e.g. as produced by ``make_wavenumbers``).
        axis: Axis of `f` to differentiate along; negative values are allowed.

    Returns:
        An array with the same shape as `f`. It is real when `f` is real and
        complex when `f` is complex.

    Raises:
        ValueError: If `k` does not have exactly one entry per point along
            `axis`.
    """
    n = f.shape[axis]
    k_vec = jnp.ravel(jnp.asarray(k))
    if k_vec.shape[0] != n:
        # Without this check a length-1 k would broadcast silently and give a
        # wrong derivative instead of an error.
        raise ValueError(
            f"k has {k_vec.shape[0]} entries but f has {n} points along axis {axis}"
        )

    multiplier = -(k_vec**2)

    # Reshape the 1-D multiplier so it broadcasts along `axis` only.
    shape = [1] * f.ndim
    shape[axis] = n
    F = jnp.fft.fft(f, axis=axis)
    d2f = jnp.fft.ifft(multiplier.reshape(shape) * F, axis=axis)
    # For real input the imaginary part is round-off only; drop it.
    return d2f if jnp.iscomplexobj(f) else d2f.real