def _compute_constants(self, resolutions, spins):
        """Precomputes the constants kept as class attributes.

        Populates `self.wigner_deltas`, `self.quadrature_weights` and
        `self.swsft_forward_constants`. See the constructor docstring for what
        `resolutions` and `spins` mean.

        Args:
          resolutions: Iterable of resolutions; it is traversed twice, once for
            the maximum degree and once for the quadrature weights.
          spins: Iterable of spins to build forward-transform constants for.
        """
        # Largest degree needed across all requested resolutions.
        ell_max = max(
            sphere_utils.ell_max_from_resolution(res) for res in resolutions)

        # Zero-pad each per-degree Wigner delta to a common shape so that
        # jnp.stack can combine them. Assumes the entry for degree `ell` is
        # (ell + 1, 2 * ell + 1), which this padding brings to
        # (ell_max + 1, 2 * ell_max + 1) -- TODO confirm against sphere_utils.
        all_deltas = sphere_utils.compute_all_wigner_delta(ell_max)
        self.wigner_deltas = jnp.stack([
            jnp.pad(delta_ell,
                    ((0, ell_max - ell), (ell_max - ell, ell_max - ell)))
            for ell, delta_ell in enumerate(all_deltas)
        ])

        # One weight array per resolution, keyed by resolution.
        self.quadrature_weights = dict(
            (res, jnp.array(sphere_utils.torus_quadrature_weights(res)))
            for res in resolutions)

        # For each spin, concatenate the per-degree constants for
        # m = -ell..ell over every ell <= ell_max, then reshape to matrix form.
        self.swsft_forward_constants = {}
        for spin in spins:
            per_degree = [
                jnp.asarray(
                    sphere_utils.swsft_forward_constant(
                        spin, ell, jnp.arange(-ell, ell + 1)))
                for ell in range(ell_max + 1)
            ]
            self.swsft_forward_constants[spin] = coefficients_to_matrix(
                jnp.concatenate(per_degree))
def _extend_sphere_fft(sphere, spin):
    """Applies 2D FFT to a spherical function by extending it to a torus.

  Args:
    sphere: See swsft_forward_naive().
    spin: See swsft_forward_naive().

  Returns:
    Matrix of complex128 Fourier coefficients. If the input shape is (n, n), the
    output will be (2*n-2, n).

  Raises:
    ValueError: If input dimensions are not even.
  """
    n = sphere.shape[1]
    if n % 2 != 0:
        raise ValueError("Input sphere must have even height!")
    torus = (-1)**spin * np.roll(sphere[1:-1][::-1], n // 2, axis=1)
    torus = np.concatenate([sphere, torus], axis=0)
    weights = sphere_utils.torus_quadrature_weights(n)
    torus = weights[:, None] * torus
    coeffs = np.fft.fft2(torus) * 2 * np.pi / n

    return coeffs
  def test_torus_quadrature_weights_curve(self, resolution):
    """Verifies the torus quadrature weights trace the curve of H&W, Figure 5.

    The leading `resolution` weights belong to the original spherical samples
    and should look like the naive sin(colatitude) rule: they rise from the
    first pole towards the middle and then stop rising. The remaining weights
    belong to the torus extension and should carry less total mass.

    Args:
      resolution: int, original spherical resolution.

    Returns:
      None.
    """
    torus_weights = sphere_utils.torus_quadrature_weights(resolution)
    sphere_part = torus_weights[:resolution]
    extension_part = torus_weights[resolution:]
    # Index of the last consecutive difference expected to be strictly rising.
    split = resolution // 2 - 1

    # Rising over the first half of the sphere rows, not rising throughout
    # the rest.
    rising = np.diff(sphere_part) > 0
    self.assertTrue(rising[:split].all())
    self.assertFalse(rising[split:].all())
    # The torus extension carries less total weight than the sphere rows.
    self.assertGreater(sphere_part.sum(), extension_part.sum())
    # Total weight equals the integral of sin(x) over [0, pi], i.e. 2.
    self.assertAllClose(torus_weights.sum(), 2.)