def body_fun(i, p_val):
  """Applies step `i` of the degree recurrence to `p_val` and returns the result.

  The update is `p_val + d0[i] * x * roll(p_val, 1) - d1[i] * roll(p_val, 2)`,
  where the rolls shift along axis 1 and `d0`/`d1` are `d0_mask_3d`/`d1_mask_3d`.

  NOTE(review): at module level this function reads `d0_mask_3d`,
  `d1_mask_3d` and `x` as free (global) names, which are not defined at module
  scope in the visible source. The equivalent closure inside
  `_gen_associated_legendre` binds them locally. Confirm whether this
  top-level copy is needed.
  """
  lag1_coeff, lag2_coeff = d0_mask_3d[i], d1_mask_3d[i]
  shifted_by_one = jnp.roll(p_val, shift=1, axis=1)
  shifted_by_two = jnp.roll(p_val, shift=2, axis=1)
  x_term = jnp.einsum('ij,ijk->ijk', lag1_coeff,
                      jnp.einsum('ijk,k->ijk', shifted_by_one, x))
  lag_term = jnp.einsum('ij,ijk->ijk', lag2_coeff, shifted_by_two)
  return p_val + (x_term - lag_term)
def _gen_recurrence_mask( l_max: int, is_normalized: bool = True) -> Tuple[jnp.ndarray, jnp.ndarray]: """Generates mask for recurrence relation on the remaining entries. The remaining entries are with respect to the diagonal and offdiagonal entries. Args: l_max: see `gen_normalized_legendre`. is_normalized: True if the recurrence mask is used by normalized associated Legendre functions. Returns: Arrays representing the mask used by the recurrence relations. """ # Computes all coefficients. m_mat, l_mat = jnp.mgrid[:l_max + 1, :l_max + 1] if is_normalized: c0 = l_mat * l_mat c1 = m_mat * m_mat c2 = 2.0 * l_mat c3 = (l_mat - 1.0) * (l_mat - 1.0) d0 = jnp.sqrt((4.0 * c0 - 1.0) / (c0 - c1)) d1 = jnp.sqrt(((c2 + 1.0) * (c3 - c1)) / ((c2 - 3.0) * (c0 - c1))) else: d0 = (2.0 * l_mat - 1.0) / (l_mat - m_mat) d1 = (l_mat + m_mat - 1.0) / (l_mat - m_mat) d0_mask_indices = jnp.triu_indices(l_max + 1, 1) d1_mask_indices = jnp.triu_indices(l_max + 1, 2) d_zeros = jnp.zeros((l_max + 1, l_max + 1)) d0_mask = d_zeros.at[d0_mask_indices].set(d0[d0_mask_indices]) d1_mask = d_zeros.at[d1_mask_indices].set(d1[d1_mask_indices]) # Creates a 3D mask that contains 1s on the diagonal plane and 0s elsewhere. # i = jnp.arange(l_max + 1)[:, None, None] # j = jnp.arange(l_max + 1)[None, :, None] # k = jnp.arange(l_max + 1)[None, None, :] i, j, k = jnp.ogrid[:l_max + 1, :l_max + 1, :l_max + 1] mask = 1.0 * (i + j - k == 0) d0_mask_3d = jnp.einsum('jk,ijk->ijk', d0_mask, mask) d1_mask_3d = jnp.einsum('jk,ijk->ijk', d1_mask, mask) return (d0_mask_3d, d1_mask_3d)
def _gen_associated_legendre(l_max: int,
                             x: jnp.ndarray,
                             is_normalized: bool) -> jnp.ndarray:
  r"""Computes associated Legendre functions (ALFs) of the first kind.

  The ALFs of the first kind are used in spherical harmonics. The spherical
  harmonic of degree `l` and order `m` can be written as
  `Y_l^m(θ, φ) = N_l^m * P_l^m(cos(θ)) * exp(i m φ)`, where `N_l^m` is the
  normalization factor and θ and φ are the colatitude and longitude,
  respectively. `N_l^m` is chosen so that the spherical harmonics form a set
  of orthonormal basis functions of L^2(S^2). For the computational
  efficiency of the spherical harmonics transform, the normalization factor
  is folded into the computation of the ALFs. In addition, normalizing
  `P_l^m` avoids overflow/underflow and achieves better numerical stability.

  Three recurrence relations are used in the computation:
  - one for the diagonal entries (`l == m`);
  - one for the first off-diagonal entries (`l == m + 1`);
  - one for all remaining entries.

  Args:
    l_max: The maximum degree of the associated Legendre function. Both the
      degrees and orders are `[0, 1, 2, ..., l_max]`.
    x: A vector of type `float32` or `float64` containing the sampled points
      in spherical coordinates, at which the ALFs are computed; `x` is
      essentially `cos(θ)`. For the numerical integration used by the
      spherical harmonics transforms, `x` contains the quadrature points in
      the interval `[-1, 1]`. There are several approaches to provide the
      quadrature points:
      - the Gauss-Legendre method (`scipy.special.roots_legendre`);
      - the Gauss-Chebyshev method (`scipy.special.roots_chebyu`);
      - the Driscoll & Healy method (Driscoll, James R., and Dennis M. Healy.
        "Computing Fourier transforms and convolutions on the 2-sphere."
        Advances in applied mathematics 15, no. 2 (1994): 202-250.).

      The Gauss-Legendre quadrature points are nearly equally spaced along θ
      and provide exact discrete orthogonality, `(P^m)^T W P^m = I`. Here `T`
      denotes the transpose, `W` is a diagonal matrix containing the
      quadrature weights, and `I` is the identity matrix.

      The Gauss-Chebyshev points are equally spaced and provide only
      approximate discrete orthogonality.

      The Driscoll & Healy quadrature points are equally spaced and provide
      exact discrete orthogonality. That approach requires twice as many
      sampling points as frequency points (modes), which enables the FFT and
      yields a fast spherical harmonics transform.
    is_normalized: True if the associated Legendre functions are normalized.
      With normalization, `N_l^m` is applied such that the spherical harmonics
      form a set of orthonormal basis functions of L^2(S^2).

  Returns:
    The 3D array of shape `(l_max + 1, l_max + 1, len(x))` containing the
    values of the ALFs at `x`. The dimensions are, in order: order, degree,
    and evaluation points.
  """
  size = l_max + 1
  num_points = x.shape[0]

  # Per-degree factors for the diagonal step (accumulated with cumprod) and
  # for the off-diagonal step from P_l^l to P_{l+1}^l.
  if is_normalized:
    p00 = 0.5 / jnp.sqrt(jnp.pi)  # The initial value p(0,0).
    diag_factors = jnp.cumprod(-1 * jnp.sqrt(1.0 + 0.5 / jnp.arange(1, size)))
    offdiag_factors = jnp.sqrt(2.0 * jnp.arange(l_max) + 3.0)
  else:
    p00 = 1.0  # The initial value p(0,0).
    diag_factors = jnp.cumprod(1.0 - 2.0 * jnp.arange(1, size))
    offdiag_factors = 2.0 * jnp.arange(l_max) + 1.0

  # Row r of `sin_powers` is the running product of r + 1 copies of
  # sqrt(1 - x^2).
  sin_powers = jnp.cumprod(
      jnp.broadcast_to(jnp.sqrt(1.0 - x * x), (l_max, num_points)), axis=0)
  diag_values = p00 * jnp.einsum('i,ij->ij', diag_factors, sin_powers)

  p = jnp.zeros((size, size, num_points))
  p = p.at[0, 0].set(p00)
  diag_pos = jnp.arange(1, size)
  p = p.at[diag_pos, diag_pos].set(diag_values)

  # The off-diagonal entry p(m, m + 1) scales the (already filled) diagonal
  # entry p(m, m) by x.
  off_pos = jnp.arange(l_max)
  offdiag_values = jnp.einsum(
      'ij,ij->ij', jnp.einsum('i,j->ij', offdiag_factors, x), p[off_pos, off_pos])
  p = p.at[off_pos, off_pos + 1].set(offdiag_values)

  # Remaining entries, one band (degree - order == band) per loop step.
  coef_prev1, coef_prev2 = _gen_recurrence_mask(
      l_max, is_normalized=is_normalized)

  def fill_band(band, q):
    # The rolls along the degree axis wrap around, but the wrapped positions
    # (degrees 0 and 1) carry zero coefficients for every band >= 2.
    q_lm1 = jnp.roll(q, shift=1, axis=1)
    q_lm2 = jnp.roll(q, shift=2, axis=1)
    step = (jnp.einsum('ij,ijk->ijk', coef_prev1[band],
                       jnp.einsum('ijk,k->ijk', q_lm1, x)) -
            jnp.einsum('ij,ijk->ijk', coef_prev2[band], q_lm2))
    return q + step

  if l_max > 1:
    p = lax.fori_loop(lower=2, upper=size, body_fun=fill_band, init_val=p)
  return p
def _gen_derivatives(p: jnp.ndarray, x: jnp.ndarray, is_normalized: bool) -> jnp.ndarray: """Generates derivatives of associated Legendre functions of the first kind. Args: p: The 3D array containing the values of associated Legendre functions; the dimensions are in the sequence of order (m), degree (l), and evalution points. x: A vector of type `float32` or `float64` containing the sampled points. is_normalized: True if the associated Legendre functions are normalized. Returns: The 3D array representing the derivatives of associated Legendre functions of the first kind. """ num_m, num_l, num_x = p.shape # p_{l-1}^m. p_m_lm1 = jnp.pad(p, ((0, 0), (1, 0), (0, 0)))[:, :num_l, :] # p_{l-1}^{m+2}. p_mp2_lm1 = jnp.pad(p_m_lm1, ((0, 2), (0, 0), (0, 0)))[2:num_m + 2, :, :] # p_{l-1}^{m-2}. p_mm2_lm1 = jnp.pad(p_m_lm1, ((2, 0), (0, 0), (0, 0)))[:num_m, :, :] # Derivative computation requires negative orders. if is_normalized: raise NotImplementedError( 'Negative orders for normalization is not implemented yet.') else: if num_l > 1: l_vec = jnp.arange(1, num_l - 1) p_p1 = p[1, 1:num_l - 1, :] coeff = -1.0 / ((l_vec + 1) * l_vec) update_p_p1 = jnp.einsum('i,ij->ij', coeff, p_p1) p_mm2_lm1 = p_mm2_lm1.at[ops.index[1, 2:num_l, :]].set(update_p_p1) if num_l > 2: l_vec = jnp.arange(2, num_l - 1) p_p2 = p[2, 2:num_l - 1, :] coeff = 1.0 / ((l_vec + 2) * (l_vec + 1) * l_vec) update_p_p2 = jnp.einsum('i,ij->ij', coeff, p_p2) p_mm2_lm1 = p_mm2_lm1.at[ops.index[0, 3:num_l, :]].set(update_p_p2) m_mat, l_mat = jnp.mgrid[:num_m, :num_l] coeff_zeros = jnp.zeros((num_m, num_l)) upper_0_indices = jnp.triu_indices(num_m, 0, num_l) zero_vec = jnp.zeros((num_l, )) a0 = -0.5 / (m_mat - 1.0) a0_masked = coeff_zeros.at[upper_0_indices].set(a0[upper_0_indices]) a0_masked = a0_masked.at[1, :].set(zero_vec) b0 = l_mat + m_mat c0 = a0 * (b0 - 2.0) * (b0 - 1.0) c0_masked = coeff_zeros.at[upper_0_indices].set(c0[upper_0_indices]) c0_masked = c0_masked.at[1, :].set(zero_vec) # p_l^{m-1}. 
p_mm1_l = (jnp.einsum('ij,ijk->ijk', a0_masked, p_m_lm1) + jnp.einsum('ij,ijk->ijk', c0_masked, p_mm2_lm1)) d0 = -0.5 / (m_mat + 1.0) d0_masked = coeff_zeros.at[upper_0_indices].set(d0[upper_0_indices]) e0 = d0 * b0 * (b0 + 1.0) e0_masked = coeff_zeros.at[upper_0_indices].set(e0[upper_0_indices]) # p_l^{m+1}. p_mp1_l = (jnp.einsum('ij,ijk->ijk', d0_masked, p_mp2_lm1) + jnp.einsum('ij,ijk->ijk', e0_masked, p_m_lm1)) f0 = b0 * (l_mat - m_mat + 1.0) / 2.0 f0_masked = coeff_zeros.at[upper_0_indices].set(f0[upper_0_indices]) p_derivative = jnp.einsum('ij,ijk->ijk', f0_masked, p_mm1_l) - 0.5 * p_mp1_l # Special treatment of the singularity at m = 1. if num_m > 1: l_vec = jnp.arange(num_l) g0 = jnp.einsum('i,ij->ij', (l_vec + 1) * l_vec, p[0, :, :]) if num_l > 2: g0 = g0 - p[2, :, :] p_derivative_m0 = jnp.einsum('j,ij->ij', 0.5 / jnp.sqrt(1 - x * x), g0) p_derivative = p_derivative.at[1, :, :].set(p_derivative_m0) p_derivative = p_derivative.at[1, 0, :].set(jnp.zeros((num_x, ))) return p_derivative