def zeta(x, q=None):
    """Evaluate the Hurwitz zeta function ``zeta(x, q)``.

    Only the two-argument (Hurwitz) form is available; the single-argument
    Riemann zeta function is not implemented.

    The evaluation follows formula (5) of

      Johansson, Fredrik.
      "Rigorous high-precision computation of the Hurwitz zeta function and
      its derivatives."
      Numerical Algorithms 69.2 (2015): 253-270.
      https://arxiv.org/abs/1309.2877

    and keeps the notation of that reference: ``s`` is the exponent (``x``),
    ``a`` is the offset (``q``), and the result is assembled as ``S + I + T``.

    Args:
      x: exponent ``s``; promoted to an inexact dtype together with ``q``.
      q: offset ``a``. Must be provided.

    Returns:
      The array ``S + I + T`` computed from the promoted arguments.

    Raises:
      NotImplementedError: if ``q`` is None.
    """
    if q is None:
        # Raised explicitly instead of asserted so that the check is not
        # stripped when Python runs with -O.
        raise NotImplementedError("Riemann zeta function is not implemented yet.")

    s, a = _promote_args_inexact("zeta", x, q)
    dtype = lax.dtype(a).type
    # Trailing singleton axis so per-term index vectors broadcast against it.
    s_, a_ = jnp.expand_dims(s, -1), jnp.expand_dims(a, -1)
    # precision ~ N, M: 8 terms for float32 inputs, 16 otherwise.
    N = M = dtype(8) if lax.dtype(a) == jnp.float32 else dtype(16)
    # Internal invariant: the coefficient table must hold at least M entries.
    assert M <= len(_BERNOULLI_COEFS)
    # S: direct sum of the first N terms, sum_{k=0}^{N-1} (a + k)^(-s).
    k = jnp.expand_dims(np.arange(N, dtype=N.dtype), tuple(range(a.ndim)))
    S = jnp.sum((a_ + k)**-s_, -1)
    # I: (a + N)^(1 - s) / (s - 1).
    I = lax.div((a + N)**(dtype(1) - s), s - dtype(1))
    # T: (a + N)^(-s) * (1/2 + sum_j T1_j), where T1_j is the running product
    # prod_{i=0}^{2j} (s + i) / (a + N) divided by the j-th table coefficient.
    T0 = (a + N)**-s
    m = jnp.expand_dims(np.arange(2 * M, dtype=M.dtype), tuple(range(s.ndim)))
    s_over_a = (s_ + m) / (a_ + N)
    # Keep every second cumulative product (indices 0, 2, 4, ...): M values.
    T1 = jnp.cumprod(s_over_a, -1)[..., ::2]
    # Cap the products at the largest finite value representable in dtype.
    T1 = jnp.clip(T1, a_max=jnp.finfo(dtype).max)
    coefs = np.expand_dims(
        np.array(_BERNOULLI_COEFS[:T1.shape[-1]], dtype=dtype),
        tuple(range(a.ndim)))
    T1 = T1 / coefs
    T = T0 * (dtype(0.5) + T1.sum(-1))
    return S + I + T
def _cofactor_solve(a, b):
    """Return ``det(a)`` together with ``adjugate(a) @ b``.

    For a nonsingular matrix this equals ``det(a) * solve(a, b)``. It is an
    intermediate helper for the jvp and vjp of ``det``, and borrows heavily
    from ``jax.numpy.linalg.solve`` and ``jax.numpy.linalg.slogdet`` so that
    the gradient of the determinant stays well defined for low-rank matrices.

    Two situations are distinguished:

      * rank(a) == n or n-1
      * rank(a) < n-1

    Rank n or n-1. For a rank n-1 matrix the gradient of the determinant is a
    rank 1 matrix. Evaluating ``det(a) * solve(a, b)`` directly would produce
    NaN, so the LU factors are used instead. With ``a = p @ l @ u``::

      det(a) * solve(a, b)
        = prod(diag(u)) * u^-1 @ l^-1 @ p^-1 b
        = prod(diag(u)) * triangular_solve(u, solve(p @ l, b))

    When ``a`` has rank n-1 the lower-right entry of ``u`` is zero and the
    triangular solve breaks down. Writing ``x = solve(p @ l, b)`` and
    ``y = det(a) * solve(a, b)``, the last component satisfies::

      y_{n} = x_{n} / u_{nn} * prod_{i=1...n}(u_{ii})
            = x_{n} * prod_{i=1...n-1}(u_{ii})

    Hence overwriting the lower-right entry of ``u`` with
    ``prod_{i=1...n-1}(u_{ii})^-1`` keeps the triangular solve from failing.
    The remaining components ``y_{i}``, ``i != n``, are obtained by scaling
    ``x_{i}`` by ``det(a)``, which is zero when rank(a) = n-1.

    Rank below n-1. The matrix is probed to see whether a solve yields NaN or
    Inf; if so a matrix of zeros is returned, because the gradient of the
    determinant of such a matrix is 0. Rank n-1 matrices are unaffected, since
    the probe runs *after* the lower-right entry of ``u`` has been replaced.

    Args:
      a: A square matrix or batch of matrices, possibly singular.
      b: A matrix, or batch of matrices of the same dimension as a.

    Returns:
      det(a) and cofactor(a)^T*b, aka adjugate(a)*b

    Raises:
      ValueError: if ``a`` is not ``[..., m, m]`` or the trailing two
        dimensions of ``b`` differ from those of ``a``.
    """
    a = _promote_arg_dtypes(jnp.asarray(a))
    b = _promote_arg_dtypes(jnp.asarray(b))
    shape_a, shape_b = jnp.shape(a), jnp.shape(b)

    # Short-circuit keeps the index lookups safe for inputs of rank < 2.
    is_square = len(shape_a) >= 2 and shape_a[-1] == shape_a[-2]
    if not (is_square and shape_b[-2:] == shape_a[-2:]):
        template = ("The arguments to _cofactor_solve must have shapes "
                    "a=[..., m, m] and b=[..., m, m]; got a={} and b={}")
        raise ValueError(template.format(shape_a, shape_b))

    n = shape_a[-1]
    if n == 1:
        # 1x1 case: the determinant is the single entry; b is returned as-is.
        return a[..., 0, 0], b

    # `lu` packs u in its upper triangle and l in its strict lower triangle;
    # the diagonal of l is taken to be ones without loss of generality.
    lu, pivots, permutation = lax_linalg.lu(a)
    batch_shape = lax.broadcast_shapes(lu.shape[:-2], b.shape[:-2])
    rhs = jnp.broadcast_to(b, batch_shape + b.shape[-2:])
    lu = jnp.broadcast_to(lu, batch_shape + lu.shape[-2:])

    # Running product of diag(u), signed by the parity of the row swaps.
    u_diag = jnp.diagonal(lu, axis1=-2, axis2=-1)
    swap_count = jnp.count_nonzero(pivots != jnp.arange(n), axis=-1)
    sign = jnp.asarray(-2 * (swap_count % 2) + 1, dtype=lax.dtype(a))
    # partial_det[..., -1] is the full determinant and
    # partial_det[..., -2] is det(u) / u_{nn}.
    partial_det = jnp.cumprod(u_diag, axis=-1) * sign[..., None]
    # Replace u_{nn} so the upper-triangular solve survives rank n-1 inputs.
    lu = lu.at[..., -1, -1].set(1.0 / partial_det[..., -2])

    permutation = jnp.broadcast_to(permutation, batch_shape + (n,))
    batch_index = jnp.ix_(
        *(lax.iota(jnp.int32, extent) for extent in batch_shape + (1,)))

    # Probe solve against a vector of ones: any NaN/Inf in the result marks
    # a matrix whose output must be zeroed.
    probe = jnp.ones(rhs.shape[:-1], rhs.dtype)
    probe = lax_linalg.triangular_solve(lu, probe, left_side=True, lower=False)
    degenerate = jnp.any(
        jnp.logical_or(jnp.isnan(probe), jnp.isinf(probe)), axis=-1)
    zero_mask = jnp.tile(degenerate[..., None, None],
                         degenerate.ndim * (1,) + rhs.shape[-2:])
    rhs = jnp.where(zero_mask, jnp.zeros_like(rhs), rhs)  # first filter

    # Apply the row permutation, then solve with the unit-lower factor l.
    y = rhs[batch_index[:-1] + (permutation, slice(None))]
    y = lax_linalg.triangular_solve(lu, y, left_side=True, lower=True,
                                    unit_diagonal=True)
    # Scale every row except the last by det(a); the last row is already
    # accounted for by the patched u_{nn}.
    leading_rows = y[..., :-1, :] * partial_det[..., -1, None, None]
    last_row = y[..., -1:, :]
    y = jnp.concatenate((leading_rows, last_row), axis=-2)
    y = lax_linalg.triangular_solve(lu, y, left_side=True, lower=False)
    y = jnp.where(zero_mask, jnp.zeros_like(y), y)  # second filter

    return partial_det[..., -1], y
def _gen_associated_legendre(l_max: int,
                             x: jnp.ndarray,
                             is_normalized: bool) -> jnp.ndarray:
    r"""Computes associated Legendre functions (ALFs) of the first kind.

    ALFs of the first kind appear in spherical harmonics. The spherical
    harmonic of degree `l` and order `m` can be written as
    `Y_l^m(θ, φ) = N_l^m * P_l^m(cos(θ)) * exp(i m φ)`, where `N_l^m` is the
    normalization factor and θ and φ are the colatitude and longitude,
    respectively. `N_l^m` is chosen so that the spherical harmonics form an
    orthonormal basis of L^2(S^2). For the computational efficiency of the
    spherical harmonics transform, the normalization factor is folded into the
    computation of the ALFs. Normalizing `P_l^m` also avoids
    overflow/underflow and gives better numerical stability. Three recurrence
    relations are used in the computation.

    Args:
      l_max: The maximum degree of the associated Legendre function. Both the
        degrees and orders are `[0, 1, 2, ..., l_max]`.
      x: A vector of type `float32` or `float64` holding the sample points in
        spherical coordinates at which the ALFs are evaluated; `x` is
        essentially `cos(θ)`. For the numerical integration used by spherical
        harmonics transforms, `x` contains quadrature points in the interval
        `[-1, 1]`. Several approaches provide such points: the Gauss-Legendre
        method (`scipy.special.roots_legendre`), the Gauss-Chebyshev method
        (`scipy.special.roots_chebyu`), and the Driscoll & Healy method
        (Driscoll, James R., and Dennis M. Healy. "Computing Fourier
        transforms and convolutions on the 2-sphere." Advances in applied
        mathematics 15, no. 2 (1994): 202-250.). The Gauss-Legendre quadrature
        points are nearly equally spaced along θ and provide exact discrete
        orthogonality, (P^m)^T W P_m = I, where `T` is the transpose, `W` is a
        diagonal matrix of quadrature weights, and `I` is the identity matrix.
        The Gauss-Chebyshev points are equally spaced and provide only
        approximate discrete orthogonality. The Driscoll & Healy quadrature
        points are equally spaced and provide exact discrete orthogonality;
        that approach requires twice as many sampling points as frequency
        points (modes), which enables FFT and a fast spherical harmonics
        transform.
      is_normalized: True if the associated Legendre functions are normalized.
        With normalization, `N_l^m` is applied such that the spherical
        harmonics form a set of orthonormal basis functions of L^2(S^2).

    Returns:
      The 3D array of shape `(l_max + 1, l_max + 1, len(x))` containing the
      values of the ALFs at `x`; the dimensions are, in sequence, order,
      degree, and evaluation points.
    """
    size = l_max + 1
    num_points = x.shape[0]
    p = jnp.zeros((size, size, num_points))

    diag_steps = jnp.arange(1, size)   # 1, ..., l_max
    offdiag_steps = jnp.arange(l_max)  # 0, ..., l_max - 1
    if is_normalized:
        p00 = 0.5 / jnp.sqrt(jnp.pi)  # Starting value p(0,0).
        diag_factors = jnp.cumprod(-1 * jnp.sqrt(1.0 + 0.5 / diag_steps))
        offdiag_factors = jnp.sqrt(2.0 * offdiag_steps + 3.0)
    else:
        p00 = 1.0  # Starting value p(0,0).
        diag_factors = jnp.cumprod(1.0 - 2.0 * diag_steps)
        offdiag_factors = 2.0 * offdiag_steps + 1.0

    p = p.at[(0, 0)].set(p00)

    # First recurrence: diagonal entries p(l,l), built from successive powers
    # of sqrt(1 - x^2).
    root = jnp.sqrt(1.0 - x * x)
    root_powers = jnp.cumprod(
        jnp.broadcast_to(root, (l_max, num_points)), axis=0)
    diag_values = p00 * jnp.einsum('i,ij->ij', diag_factors, root_powers)
    rows, cols = jnp.diag_indices(size)
    p = p.at[(rows[1:], cols[1:])].set(diag_values)

    # Second recurrence: entries one step off the diagonal, obtained from the
    # first l_max diagonal entries scaled by a factor times x.
    scaled_x = jnp.einsum('i,j->ij', offdiag_factors, x)
    offdiag_values = jnp.einsum('ij,ij->ij', scaled_x,
                                p[jnp.diag_indices(l_max)])
    p = p.at[(rows[:l_max], cols[:l_max] + 1)].set(offdiag_values)

    # Third recurrence: all remaining entries, filled in by a loop whose
    # per-iteration coefficients come from the recurrence masks.
    d0_mask_3d, d1_mask_3d = _gen_recurrence_mask(
        l_max, is_normalized=is_normalized)

    def recurrence_step(i, p_val):
        # Combine the array shifted by one and by two positions along axis 1.
        prev_1 = jnp.roll(p_val, shift=1, axis=1)
        prev_2 = jnp.roll(p_val, shift=2, axis=1)
        term_1 = jnp.einsum('ij,ijk->ijk', d0_mask_3d[i],
                            jnp.einsum('ijk,k->ijk', prev_1, x))
        term_2 = jnp.einsum('ij,ijk->ijk', d1_mask_3d[i], prev_2)
        return p_val + (term_1 - term_2)

    if l_max <= 1:
        # Nothing beyond the diagonal and first off-diagonal to fill in.
        return p
    return lax.fori_loop(lower=2, upper=l_max + 1,
                         body_fun=recurrence_step, init_val=p)