def _compute_constants(self, resolutions, spins):
  """Computes constants (class attributes). See constructor docstring.

  Populates three attributes on `self`:
    wigner_deltas: Wigner Delta matrices for every degree 0..ell_max, where
      ell_max is the largest `ell_max_from_resolution` over `resolutions`.
      Each matrix is zero-padded (ell_max - ell rows at the bottom and
      ell_max - ell columns on each side) so all degrees share one shape,
      then stacked along a new leading axis.
    quadrature_weights: dict mapping each resolution to its torus
      quadrature weights as a jnp array.
    swsft_forward_constants: dict mapping each spin to the forward SWSFT
      constants for all degrees 0..ell_max and orders -ell..ell, laid out
      by `coefficients_to_matrix`.

  Args:
    resolutions: Resolutions to precompute for. Iterated twice (once to
      find ell_max, once for the quadrature weights), so it should be a
      re-iterable collection rather than a one-shot iterator.
    spins: Spin weights for which forward constants are precomputed.
  """
  ell_max = max(
      sphere_utils.ell_max_from_resolution(resolution)
      for resolution in resolutions)

  def _pad_to_ell_max(degree, delta):
    # Pads degree `degree` up to the ell_max shape so jnp.stack can combine
    # all degrees; columns are padded symmetrically to keep order 0 centered.
    margin = ell_max - degree
    return jnp.pad(delta, ((0, margin), (margin, margin)))

  all_deltas = sphere_utils.compute_all_wigner_delta(ell_max)
  self.wigner_deltas = jnp.stack(
      [_pad_to_ell_max(degree, delta)
       for degree, delta in enumerate(all_deltas)])

  self.quadrature_weights = {
      resolution: jnp.array(sphere_utils.torus_quadrature_weights(resolution))
      for resolution in resolutions
  }

  self.swsft_forward_constants = {}
  for spin in spins:
    # One vector of constants per degree, covering orders -degree..degree.
    per_degree = [
        jnp.asarray(
            sphere_utils.swsft_forward_constant(
                spin, degree, jnp.arange(-degree, degree + 1)))
        for degree in range(ell_max + 1)
    ]
    self.swsft_forward_constants[spin] = coefficients_to_matrix(
        jnp.concatenate(per_degree))
def _swsft_forward_single(sphere, spin, ell, m):
  r"""Computes one spin-weighted spherical Fourier transform coefficient.

  Returns _sa_m^\ell, where s is the spin weight, ell the degree and m the
  order. The sum over n (often written m') runs over nonnegative n only.
  Entries for negative column indices (-s when spin > 0, and m when m < 0)
  are recovered from the nonnegative ones by multiplying by (-1)**(ell + n).

  Args:
    sphere: See swsft_forward_naive().
    spin: See swsft_forward_naive().
    ell: Degree (int).
    m: Order (int).

  Returns:
    A complex128 coefficient. If abs(spin) is not a valid column of the
    sliced Wigner Delta, the sum is empty and the result is 0 times the
    forward constant.
  """
  # TODO(machc): The following could be simplified. See comment in
  # `_compute_Gnm_naive`.
  # Keeps only the nonnegative-index quadrant; presumably
  # compute_wigner_delta(ell) is indexed from -ell on both axes -- TODO confirm.
  delta = sphere_utils.compute_wigner_delta(ell)[ell:, ell:]
  Jnm = _compute_Jnm(sphere, spin)  # pylint: disable=invalid-name

  spin_col = abs(spin)
  order_col = abs(m)
  total = 0
  # The bounds check does not depend on n, so test it once up front.
  if spin_col < delta.shape[1]:
    for n in range(ell + 1):  # n here is sometimes called m'.
      parity = (-1)**(ell + n)
      delta_s = delta[n, spin_col]
      delta_m = delta[n, order_col]
      if spin > 0:  # The column we need is (-s); use the parity symmetry.
        delta_s = delta_s * parity
      if m < 0:
        delta_m = delta_m * parity
      total += delta_m * delta_s * Jnm[n, m]
  return total * sphere_utils.swsft_forward_constant(spin, ell, m)
def test_SpinSphericalFourierTransformer_forward_constants_matches_np(
    self, spin, ell):
  """Cached forward constants for (spin, ell) match sphere_utils directly."""
  constants_matrix = _get_transformer().swsft_forward_constants[spin]
  # The matrix has ell_max + 1 rows; orders -ell..ell of degree `ell` are
  # expected in row `ell`, centered on column ell_max.
  center = constants_matrix.shape[0] - 1
  cached_row = constants_matrix[ell, center - ell:center + ell + 1]
  expected = sphere_utils.swsft_forward_constant(
      spin, ell, jnp.arange(-ell, ell + 1))
  self.assertAllClose(cached_row, expected)