def test_coefficients_vector_to_matrix_has_right_shape(self, num_ell):
        """Check the packed shape of a flat coefficient vector.

        A vector of num_ell**2 coefficients should be packed into a matrix
        of shape (num_ell, 2 * num_ell - 1).
        """
        expected_shape = (num_ell, 2 * num_ell - 1)
        flat = jnp.ones(num_ell**2)

        packed = spin_spherical_harmonics.coefficients_to_matrix(flat)

        self.assertEqual(packed.shape, expected_shape)
    def test_swsft_backward_matches_np(self, num_coefficients, spin):
        """Compare the JAX inverse transform against the NumPy reference.

        The same complex coefficient ramp is fed to
        `np_spin_spherical_harmonics.swsft_backward_naive` (flat form) and to
        the transformer's `swsft_backward` (packed matrix form). The
        resulting sphere samples must be close.
        """
        transformer = _get_transformer()
        real_part = jnp.linspace(-1, 1, num_coefficients)
        imag_part = jnp.linspace(0, 1, num_coefficients)
        flat_coeffs = real_part + 1j * imag_part
        matrix_coeffs = spin_spherical_harmonics.coefficients_to_matrix(
            flat_coeffs)

        expected = np_spin_spherical_harmonics.swsft_backward_naive(
            flat_coeffs, spin)
        actual = transformer.swsft_backward(matrix_coeffs, spin)

        self.assertAllClose(expected, actual)
    def test_swsft_backward_matches_with_symmetry(self, num_coefficients,
                                                  spin):
        """Check that both inverse transforms agree.

        `swsft_backward_with_symmetry` and `swsft_backward` must give close
        results for the same packed coefficient matrix.
        """
        transformer = _get_transformer()
        flat = (jnp.linspace(-1, 1, num_coefficients)
                + 1j * jnp.linspace(0, 1, num_coefficients))
        packed = spin_spherical_harmonics.coefficients_to_matrix(flat)

        plain = transformer.swsft_backward(packed, spin)
        symmetric = transformer.swsft_backward_with_symmetry(packed, spin)

        self.assertAllClose(plain, symmetric)
    def test_swsft_forward_matches_np(self, resolution, spin):
        """Compare the JAX forward transform against the NumPy reference.

        A deterministic real-valued (resolution, resolution) grid is passed
        to both implementations. The flat NumPy coefficients are packed into
        matrix form before comparison.
        """
        transformer = _get_transformer()
        num_points = resolution**2
        grid = jnp.linspace(-1, 1, num_points).reshape(
            (resolution, resolution))

        reference = np_spin_spherical_harmonics.swsft_forward_naive(
            grid, spin)
        computed = transformer.swsft_forward(grid, spin)

        reference_matrix = spin_spherical_harmonics.coefficients_to_matrix(
            reference)
        self.assertAllClose(computed, reference_matrix)
    def test_swsft_backward_validate_raises(self):
        """Check that swsft_backward() rejects unsupported constants.

        The transformer is built only for resolutions 4 and 8 and spins 0
        and 1. Coefficients for another ell_max, or a spin outside that set,
        must raise ValueError.
        """
        transformer = spin_spherical_harmonics.SpinSphericalFourierTransformer(
            resolutions=[4, 8], spins=(0, 1))

        def packed_coeffs(count):
            # Complex ramp of `count` coefficients, packed into matrix form.
            flat = jnp.linspace(-1, 1, count) + 1j * jnp.linspace(0, 1, count)
            return spin_spherical_harmonics.coefficients_to_matrix(flat)

        # 64 coefficients correspond to resolution == 16, which is not
        # configured (wrong ell_max). Spin 1 is configured.
        wrong_ell_max = packed_coeffs(64)
        with self.assertRaises(ValueError):
            transformer.swsft_backward(wrong_ell_max, spin=1)

        # 16 coefficients correspond to resolution == 8, which is configured.
        # Spin -1 is not among the configured spins.
        right_ell_max = packed_coeffs(16)
        with self.assertRaises(ValueError):
            transformer.swsft_backward(right_ell_max, spin=-1)