def _compute_constants(self, resolutions, spins):
  """Precomputes transform constants (class attributes). See constructor docstring.

  Populates, in this order:
    self.wigner_deltas: jnp array stacking one zero-padded Wigner Delta per
      degree ell = 0..ell_max. Here ell_max is the largest value of
      `sphere_utils.ell_max_from_resolution` over `resolutions`.
    self.quadrature_weights: dict mapping each resolution to
      `jnp.array(sphere_utils.torus_quadrature_weights(resolution))`.
    self.swsft_forward_constants: dict mapping each spin to
      `coefficients_to_matrix` applied to the concatenation, over
      ell = 0..ell_max, of `sphere_utils.swsft_forward_constant` evaluated
      at m = -ell..ell.

  Args:
    resolutions: Resolutions to precompute constants for. It is iterated
      twice, so pass a re-iterable collection rather than a one-shot
      iterator.
    spins: Spin weights to precompute forward constants for.
  """
  # Largest degree needed across every requested resolution.
  ell_max = max(
      sphere_utils.ell_max_from_resolution(resolution)
      for resolution in resolutions)

  all_deltas = sphere_utils.compute_all_wigner_delta(ell_max)
  # Zero-pad each per-degree Delta to a common shape so the deltas can be
  # stacked. Columns are padded symmetrically; rows are padded only at the end.
  # NOTE(review): the padded rows only match across degrees if each delta has
  # ell + 1 rows (non-negative n only). If compute_all_wigner_delta returns
  # full (2*ell + 1)-row deltas, jnp.stack will fail. Confirm against
  # sphere_utils.
  self.wigner_deltas = jnp.stack([
      jnp.pad(delta, ((0, ell_max - ell), (ell_max - ell, ell_max - ell)))
      for ell, delta in enumerate(all_deltas)
  ])

  self.quadrature_weights = {
      resolution: jnp.array(sphere_utils.torus_quadrature_weights(resolution))
      for resolution in resolutions
  }

  self.swsft_forward_constants = {}
  for spin in spins:
    # One vector per degree, covering orders m = -ell..ell.
    per_degree = [
        jnp.asarray(
            sphere_utils.swsft_forward_constant(
                spin, ell, jnp.arange(-ell, ell + 1)))
        for ell in range(ell_max + 1)
    ]
    self.swsft_forward_constants[spin] = coefficients_to_matrix(
        jnp.concatenate(per_degree))
def _compute_Gnm_naive(coeffs, spin):  # pylint: disable=invalid-name
  r"""Compute Gnm with explicit Python loops (not vectorized).

  The matrix Gnm, defined in H&W, Equation (13), is closely related to the 2D
  Fourier transform of a spin-weighted spherical function:

    Gnm = (-1)^s i^(m+s) \sum_\ell c \Delta_{-n,-s} \Delta_{-n,m} _sa_m^\ell,

  where c = \sqrt{(2\ell + 1) / (4\pi)}, and _sa_m^\ell is the coefficient at
  (ell, m).

  Args:
    coeffs: See swsft_backward_naive().
    spin: See swsft_backward_naive().

  Returns:
    The complex128 matrix Gnm. If coeffs has n**2 elements, the output is
    (2*n-1, 2*n-1). With ell_max = n - 1, row i holds index n = i - ell_max
    and column j holds order m = j - ell_max.

  Raises:
    ValueError: If len(coeffs) is not a perfect square.
  """
  ell_max = sphere_utils.ell_max_from_n_coeffs(len(coeffs))
  full_deltas = sphere_utils.compute_all_wigner_delta(ell_max)
  # TODO(machc): This could be simplified. Previously we only stored
  # non-negative n in the Wigner Deltas, and used symmetries to
  # complete the result. Now that the complete Deltas are stored, we
  # could simplify the code below, but for now we just revert the deltas
  # to what was stored before and keep using the prior implementation.
  half_deltas = tuple(delta[ell:] for ell, delta in enumerate(full_deltas))

  # Rows n = 0..ell_max; columns m = -ell_max..ell_max (offset by ell_max).
  g_half = np.zeros((ell_max + 1, 2 * ell_max + 1), dtype=np.complex128)
  for ell in range(abs(spin), ell_max + 1):
    delta = half_deltas[ell]
    factor = np.sqrt((2 * ell + 1) / 4 / np.pi)
    spin_column = delta[:, ell - spin]
    for m in range(-ell, ell + 1):
      # The formula uses Deltas at negative n, but only n >= 0 rows are kept
      # here; this phase also absorbs the resulting sign change.
      phase = (1j)**(m + spin) * (-1)**m
      coeff = coeffs[_get_swsft_coeff_index(ell, m)]
      g_half[:ell + 1, ell_max + m] += (
          phase * factor * spin_column * delta[:, ell + m] * coeff)

  # Negative-n rows come from the symmetry row(-n) = (-1)**(spin + m) * row(n).
  orders = np.arange(-ell_max, ell_max + 1)
  signs = ((-1.)**(spin + orders))[None, :]
  return np.concatenate([signs * g_half[1:][::-1], g_half])
def test_compute_all_wigner_delta_matches_single(self, ell_max, ell):
  """The batched Delta at degree `ell` equals the single-degree computation."""
  batched = sphere_utils.compute_all_wigner_delta(ell_max)
  single = sphere_utils.compute_wigner_delta(ell)
  self.assertAllClose(batched[ell], single)