def _compute_constants(self, resolutions, spins):
  """Precomputes the constants stored as attributes; see constructor docstring.

  Sets:
    self.wigner_deltas: the Wigner deltas for ell = 0..ell_max, each
      zero-padded by (0, ell_max - ell) rows and ell_max - ell columns on both
      sides so they can be stacked into a single array.
    self.quadrature_weights: dict mapping each resolution to its torus
      quadrature weights (as a jnp array).
    self.swsft_forward_constants: dict mapping each spin to the forward
      constants for ell = 0..ell_max (m = -ell..ell), concatenated and passed
      through coefficients_to_matrix().

  Args:
    resolutions: iterable of spherical resolutions; ell_max is the largest
      band limit among them.
    spins: iterable of spin weights to precompute forward constants for.
  """
  ell_max = max(
      sphere_utils.ell_max_from_resolution(res) for res in resolutions)

  # Pad every per-ell delta up to the ell_max footprint so jnp.stack accepts
  # them (jnp.stack requires equal shapes).
  deltas = sphere_utils.compute_all_wigner_delta(ell_max)
  self.wigner_deltas = jnp.stack([
      jnp.pad(delta, ((0, ell_max - ell), (ell_max - ell, ell_max - ell)))
      for ell, delta in enumerate(deltas)
  ])

  self.quadrature_weights = {
      res: jnp.array(sphere_utils.torus_quadrature_weights(res))
      for res in resolutions
  }

  # Filled one spin at a time, as before, so partial state on failure matches.
  self.swsft_forward_constants = {}
  for spin in spins:
    per_ell = [
        jnp.asarray(
            sphere_utils.swsft_forward_constant(
                spin, ell, jnp.arange(-ell, ell + 1)))
        for ell in range(ell_max + 1)
    ]
    self.swsft_forward_constants[spin] = coefficients_to_matrix(
        jnp.concatenate(per_ell))
def _extend_sphere_fft(sphere, spin):
  """Applies 2D FFT to a spherical function by extending it to a torus.

  The extension appends rows 1..n-2 of `sphere` in reverse order, rolled by
  n // 2 columns along axis 1 and multiplied by (-1)**spin. The resulting
  (2n - 2, n) array is weighted row-wise by the torus quadrature weights and
  transformed with a 2D FFT.

  Args:
    sphere: See swsft_forward_naive().
    spin: See swsft_forward_naive().

  Returns:
    Matrix of complex128 Fourier coefficients. If the input shape is (n, n),
    the output will be (2*n-2, n).

  Raises:
    ValueError: If the input is not a square 2D array, or if its dimensions
      are not even.
  """
  # The quadrature weights are computed from the column count but applied
  # per row, so a non-square input would otherwise fail with an unrelated
  # broadcasting error (or, for a single row, silently broadcast).
  if sphere.ndim != 2 or sphere.shape[0] != sphere.shape[1]:
    raise ValueError(
        f"Input sphere must be a square 2D array; got shape {sphere.shape}.")
  n = sphere.shape[1]
  # Rolling by n // 2 columns is presumably a half turn in longitude, which
  # is exact only for even n.
  if n % 2 != 0:
    raise ValueError("Input sphere must have even height!")
  # Interior rows reversed and spin-weighted: presumably the continuation of
  # colatitude past pi onto the torus.
  extension = (-1)**spin * np.roll(sphere[1:-1][::-1], n // 2, axis=1)
  torus = np.concatenate([sphere, extension], axis=0)
  weights = sphere_utils.torus_quadrature_weights(n)
  torus = weights[:, None] * torus
  return np.fft.fft2(torus) * 2 * np.pi / n
def test_torus_quadrature_weights_curve(self, resolution):
  """Checks the torus quadrature weights against the curve in H&W, Figure 5.

  The leading `resolution` weights belong to the original spherical function
  and look like the naive sin(colatitude) rule: small near the poles and
  largest near the equator. The trailing weights belong to the torus
  extension and are small.

  Args:
    resolution: int, original spherical resolution.
  """
  weights = sphere_utils.torus_quadrature_weights(resolution)
  sphere_part = weights[:resolution]
  extension_part = weights[resolution:]
  midpoint = resolution // 2 - 1

  # Rising-then-falling shape over the spherical part: every step before the
  # midpoint rises, and not every step from the midpoint on does.
  rising = np.diff(sphere_part) > 0
  self.assertTrue(rising[:midpoint].all())
  self.assertFalse(rising[midpoint:].all())

  # The extension carries less total weight than the spherical part.
  self.assertGreater(sphere_part.sum(), extension_part.sum())

  # Total weight equals 2, the integral of sin(x) over [0, pi].
  self.assertAllClose(weights.sum(), 2.)