def _compute_constants(self, resolutions, spins):
        """Precomputes transform constants and stores them on `self`.

        See the constructor docstring for the meaning of `resolutions` and
        `spins`.

        Sets:
          self.wigner_deltas: every Wigner delta from
            `sphere_utils.compute_all_wigner_delta(ell_max)`, zero-padded to a
            common size and stacked along a new leading axis.
          self.quadrature_weights: dict mapping each resolution to
            `torus_quadrature_weights(res)` as a jnp array.
          self.swsft_forward_constants: dict mapping each spin to
            `coefficients_to_matrix` of the forward constants for every
            ell in [0, ell_max] and every m in [-ell, ell].
        """
        ell_max = max([
            sphere_utils.ell_max_from_resolution(res) for res in resolutions
        ])

        # Delta number `ell` gets (ell_max - ell) extra rows at the end and
        # (ell_max - ell) extra columns on each side, so that all deltas share
        # one shape and can be stacked.
        self.wigner_deltas = jnp.stack([
            jnp.pad(delta,
                    ((0, ell_max - ell), (ell_max - ell, ell_max - ell)))
            for ell, delta in enumerate(
                sphere_utils.compute_all_wigner_delta(ell_max))
        ])

        self.quadrature_weights = {
            res: jnp.array(sphere_utils.torus_quadrature_weights(res))
            for res in resolutions
        }

        self.swsft_forward_constants = {}
        for spin in spins:
            # One 1-D chunk per degree, covering orders -ell..ell.
            per_degree = [
                jnp.asarray(
                    sphere_utils.swsft_forward_constant(
                        spin, degree, jnp.arange(-degree, degree + 1)))
                for degree in range(ell_max + 1)
            ]
            self.swsft_forward_constants[spin] = coefficients_to_matrix(
                jnp.concatenate(per_degree))
def _swsft_forward_single(sphere, spin, ell, m):
    r"""Computes one spin-weighted spherical Fourier coefficient.

    Evaluates ${}_s a_m^\ell$, where s is the spin weight, ell the degree and
    m the order. The value is a sum over an intermediate order n (often
    written m') of two Wigner-delta entries times `Jnm[n, m]`, scaled by
    `sphere_utils.swsft_forward_constant(spin, ell, m)`.

    Args:
      sphere: See swsft_forward_naive().
      spin: See swsft_forward_naive().
      ell: Degree (int).
      m: Order (int).

    Returns:
      A complex128 coefficient.
    """
    # TODO(machc): The following could be simplified. See comment in
    # `_compute_Gnm_naive`.
    # Keep only the rows/columns from index `ell` onward of the Wigner delta.
    half_delta = sphere_utils.compute_wigner_delta(ell)[ell:, ell:]
    Jnm = _compute_Jnm(sphere, spin)  # pylint: disable=invalid-name
    abs_spin = abs(spin)
    abs_m = abs(m)

    total = 0
    # If |spin| does not index a column of `half_delta`, the sum is empty and
    # the coefficient stays 0.
    if abs_spin < half_delta.shape[1]:
        for n in range(ell + 1):  # n here is sometimes called m'
            sign = (-1)**(ell + n)
            delta_s = half_delta[n, abs_spin]
            delta_m = half_delta[n, abs_m]
            if spin > 0:  # The index is really (-s); flip the sign.
                delta_s *= sign
            if m < 0:
                delta_m *= sign
            # NOTE(review): a negative `m` indexes Jnm from the end (NumPy/JAX
            # wraparound) — presumably FFT frequency ordering; confirm in
            # `_compute_Jnm`.
            total += delta_m * delta_s * Jnm[n, m]

    total *= sphere_utils.swsft_forward_constant(spin, ell, m)
    return total
 def test_SpinSphericalFourierTransformer_forward_constants_matches_np(
         self, spin, ell):
     """Precomputed forward constants for (spin, ell) match sphere_utils."""
     table = _get_transformer().swsft_forward_constants[spin]
     # The table has ell_max + 1 rows. Orders -ell..ell of row `ell` are
     # expected to sit in the columns centered on column ell_max.
     center = table.shape[0] - 1
     actual = table[ell, center - ell:center + ell + 1]
     expected = sphere_utils.swsft_forward_constant(
         spin, ell, jnp.arange(-ell, ell + 1))
     self.assertAllClose(actual, expected)