def _compute_constants(self, resolutions, spins):
        """Computes constants (class attributes). See constructor docstring.

        Populates `self.wigner_deltas`, `self.quadrature_weights` (keyed by
        resolution) and `self.swsft_forward_constants` (keyed by spin).

        Args:
          resolutions: Resolutions to precompute constants for. Iterated more
            than once.
          spins: Spins to precompute the forward transform constants for.
        """
        # All per-ell constants are computed up to the largest bandwidth
        # required by any of the requested resolutions.
        largest_ell = max(
            [sphere_utils.ell_max_from_resolution(resolution)
             for resolution in resolutions])

        def _pad_delta(ell, delta):
            # Zero-pad the end of axis 0 and both sides of axis 1 by
            # (largest_ell - ell) so that the Deltas can be stacked.
            margin = largest_ell - ell
            return jnp.pad(delta, ((0, margin), (margin, margin)))

        self.wigner_deltas = jnp.stack([
            _pad_delta(ell, delta) for ell, delta in enumerate(
                sphere_utils.compute_all_wigner_delta(largest_ell))
        ])

        weights_by_resolution = {}
        for resolution in resolutions:
            weights_by_resolution[resolution] = jnp.array(
                sphere_utils.torus_quadrature_weights(resolution))
        self.quadrature_weights = weights_by_resolution

        self.swsft_forward_constants = {}
        for spin in spins:
            # One array of constants per degree, covering m = -ell, ..., ell.
            per_ell_constants = [
                jnp.asarray(
                    sphere_utils.swsft_forward_constant(
                        spin, ell, jnp.arange(-ell, ell + 1)))
                for ell in range(largest_ell + 1)
            ]
            self.swsft_forward_constants[spin] = coefficients_to_matrix(
                jnp.concatenate(per_ell_constants))
def _compute_Gnm_naive(coeffs, spin):  # pylint: disable=invalid-name
    r"""Compute Gnm with explicit loops over ell and m (not vectorized).

    The matrix Gnm, defined in H&W, Equation (13), is closely related to the
    2D Fourier transform of a spin-weighted spherical function.

    Gnm = (-1)^s i^(m+s) \sum_\ell c \Delta_{-n,-s} \Delta_{-n,m} _sa_m^\ell,
    where c = \sqrt{(2\ell + 1) / (4\pi)}, and _sa_m^\ell is the coefficient
    at (ell, m).

    Args:
      coeffs: See swsft_backward_naive().
      spin: See swsft_backward_naive().

    Returns:
      The complex128 matrix Gnm, of shape (2*ell_max + 1, 2*ell_max + 1). If
      coeffs has n**2 elements, the output is (2*n-1, 2*n-1).

    Raises:
      ValueError: If len(coeffs) is not a perfect square.
    """
    ell_max = sphere_utils.ell_max_from_n_coeffs(len(coeffs))
    full_deltas = sphere_utils.compute_all_wigner_delta(ell_max)
    # TODO(machc): This could be simplified. Only the rows from index `ell`
    # onwards of each complete Delta are kept (the non-negative n that used to
    # be stored), and symmetries are used below to complete the result, as in
    # the prior implementation.
    half_deltas = tuple(
        delta[ell:] for ell, delta in enumerate(full_deltas))

    # Rows are n = 0, ..., ell_max; columns are m = -ell_max, ..., ell_max.
    n_columns = 2 * ell_max + 1
    half_gnm = np.zeros((ell_max + 1, n_columns), dtype=np.complex128)
    for ell in range(abs(spin), ell_max + 1):
        normalization = np.sqrt((2 * ell + 1) / 4 / np.pi)
        delta_ell = half_deltas[ell]
        # This column does not depend on m.
        delta_spin_column = delta_ell[:, ell - spin]
        for m in range(-ell, ell + 1):
            # The phase also fixes the signs: the Deltas should be evaluated
            # at negative n, but only values for non-negative n are kept.
            phase = (1j)**(m + spin) * (-1)**m
            coefficient = coeffs[_get_swsft_coeff_index(ell, m)]
            half_gnm[:ell + 1, ell_max + m] += (
                phase * normalization * delta_spin_column *
                delta_ell[:, ell + m] * coefficient)

    # Use symmetry to obtain entries for negative n: flip the rows n >= 1 and
    # apply a per-column sign (-1)^(spin + m).
    exponents = (spin + np.arange(-ell_max, ell_max + 1))[None, :]
    negative_n_rows = (-1.)**exponents * half_gnm[1:][::-1]
    return np.concatenate([negative_n_rows, half_gnm])
 def test_compute_all_wigner_delta_matches_single(self, ell_max, ell):
   """Checks entry `ell` of the batched Deltas equals the single-ell Delta."""
   all_deltas = sphere_utils.compute_all_wigner_delta(ell_max)
   expected_delta = sphere_utils.compute_wigner_delta(ell)
   actual_delta = all_deltas[ell]
   self.assertAllClose(actual_delta, expected_delta)