def test_coefficients_vector_to_matrix_has_right_shape(self, num_ell):
  """Checks the shape produced by `coefficients_to_matrix`.

  A flat vector of `num_ell**2` ones is converted and the result is expected
  to have shape `(num_ell, 2 * num_ell - 1)`.

  Args:
    num_ell: Integer that sizes the input vector and the expected matrix;
      presumably injected by a parameterized decorator not visible here.
  """
  flat_ones = jnp.ones(num_ell**2)
  expected_shape = (num_ell, 2 * num_ell - 1)
  actual_shape = spin_spherical_harmonics.coefficients_to_matrix(
      flat_ones).shape
  self.assertEqual(actual_shape, expected_shape)
def test_swsft_backward_matches_np(self, num_coefficients, spin):
  """Compares `swsft_backward` with the naive reference implementation.

  The same complex ramp of coefficients is fed to
  `np_spin_spherical_harmonics.swsft_backward_naive` (flat layout) and to the
  transformer's `swsft_backward` (matrix layout); both outputs must be close.

  Args:
    num_coefficients: Length of the flat coefficient vector.
    spin: Spin value forwarded to both implementations.
  """
  transformer = _get_transformer()

  # Deterministic complex input: real part in [-1, 1], imaginary in [0, 1].
  real_part = jnp.linspace(-1, 1, num_coefficients)
  imag_part = jnp.linspace(0, 1, num_coefficients)
  flat_coeffs = real_part + 1j * imag_part
  coeff_matrix = spin_spherical_harmonics.coefficients_to_matrix(flat_coeffs)

  expected = np_spin_spherical_harmonics.swsft_backward_naive(
      flat_coeffs, spin)
  actual = transformer.swsft_backward(coeff_matrix, spin)
  self.assertAllClose(expected, actual)
def test_swsft_backward_matches_with_symmetry(self, num_coefficients, spin):
  """Checks `swsft_backward_with_symmetry` against plain `swsft_backward`.

  Both methods are called on the same coefficient matrix and spin, and their
  outputs are required to be close.

  Args:
    num_coefficients: Length of the flat coefficient vector.
    spin: Spin value forwarded to both transformer methods.
  """
  transformer = _get_transformer()

  # Deterministic complex input: real part in [-1, 1], imaginary in [0, 1].
  real_part = jnp.linspace(-1, 1, num_coefficients)
  imag_part = jnp.linspace(0, 1, num_coefficients)
  coeff_matrix = spin_spherical_harmonics.coefficients_to_matrix(
      real_part + 1j * imag_part)

  plain = transformer.swsft_backward(coeff_matrix, spin)
  symmetric = transformer.swsft_backward_with_symmetry(coeff_matrix, spin)
  self.assertAllClose(plain, symmetric)
def test_swsft_forward_matches_np(self, resolution, spin):
  """Compares `swsft_forward` with the naive reference implementation.

  A real-valued `(resolution, resolution)` ramp is transformed by
  `np_spin_spherical_harmonics.swsft_forward_naive` and by the transformer's
  `swsft_forward`. The naive output is converted with `coefficients_to_matrix`
  before the two are compared.

  Args:
    resolution: Side length of the square input grid.
    spin: Spin value forwarded to both implementations.
  """
  transformer = _get_transformer()

  grid_shape = (resolution, resolution)
  sphere = jnp.linspace(-1, 1, resolution**2).reshape(grid_shape)

  reference_flat = np_spin_spherical_harmonics.swsft_forward_naive(
      sphere, spin)
  actual = transformer.swsft_forward(sphere, spin)
  expected = spin_spherical_harmonics.coefficients_to_matrix(reference_flat)
  self.assertAllClose(actual, expected)
def test_swsft_backward_validate_raises(self):
  """Check that swsft_backward() raises exception if constants are invalid.

  The transformer is built for resolutions 4 and 8 and spins 0 and 1; calling
  `swsft_backward` with anything else is expected to raise `ValueError`.
  """
  transformer = spin_spherical_harmonics.SpinSphericalFourierTransformer(
      resolutions=[4, 8], spins=(0, 1))

  def _ramp_matrix(n_coeffs):
    """Returns a complex ramp of `n_coeffs` coefficients in matrix layout."""
    real_part = jnp.linspace(-1, 1, n_coeffs)
    imag_part = jnp.linspace(0, 1, n_coeffs)
    return spin_spherical_harmonics.coefficients_to_matrix(
        real_part + 1j * imag_part)

  # Wrong ell_max, right spin: 64 coefficients correspond to
  # resolution == 16, which the transformer was not built for.
  unsupported_resolution = _ramp_matrix(64)
  self.assertRaises(
      ValueError, transformer.swsft_backward, unsupported_resolution, spin=1)

  # Right ell_max, wrong spin: 16 coefficients correspond to
  # resolution == 8, but spin -1 is not among the configured spins.
  supported_resolution = _ramp_matrix(16)
  self.assertRaises(
      ValueError, transformer.swsft_backward, supported_resolution, spin=-1)