def swsft_forward(self, sphere, spin):
  """Spin-weighted spherical harmonics transform (fast JAX version).

  Returns coefficients in zero-padded format:

    [0    0    c00  0    0  ]
    [0    c1m1 c10  c11  0  ]
    [c2m2 c2m1 c20  c21  c22]

  See also: np_spin_spherical_harmonics.swsft_forward_naive().

  Args:
    sphere: A (n, n) array representing a spherical function. Equirectangular
      sampling, lat, long order.
    spin: Spin weight.

  Returns:
    A (n//2, n-1) array of complex64 coefficients. The coefficient at degree
    ell and order m is at position [ell, ell_max+m].

  Raises:
    ValueError: If `self.validate` rejects the input resolution and spin.
  """
  if not self.validate(resolution=sphere.shape[0], spin=spin):
    raise ValueError("Constants are invalid for given input!")
  ell_max = sphere_utils.ell_max_from_resolution(sphere.shape[0])
  Jnm = self._compute_Jnm(sphere, spin)  # pylint: disable=invalid-name
  # Reorder the m axis from [0, ..., ell_max, -ell_max, ..., -1] to
  # [-ell_max, ..., ell_max]. For the odd width 2*ell_max+1 this is exactly
  # fftshift. Unlike concatenating `Jnm[:, -ell_max:]` with
  # `Jnm[:, :ell_max + 1]`, fftshift stays correct when ell_max == 0, where
  # `x[:, -0:]` would select every column.
  Jnm = jnp.fft.fftshift(Jnm, axes=1)  # pylint: disable=invalid-name
  deltas = self._slice_wigner_deltas(ell_max)
  # Scale each delta by the entry at last-axis index ell_max - spin,
  # broadcast along the last axis.
  deltas = deltas * deltas[Ellipsis, ell_max - spin][Ellipsis, None]
  forward_constants = self._slice_forward_constants(ell_max, spin)
  return jnp.einsum("ik,ijk,jk->ik", forward_constants, deltas, Jnm)
def _compute_constants(self, resolutions, spins):
  """Precomputes transform constants and stores them as attributes.

  Sets `self.wigner_deltas`, `self.quadrature_weights` and
  `self.swsft_forward_constants`. See constructor docstring.
  """
  largest_ell = max(
      sphere_utils.ell_max_from_resolution(res) for res in resolutions)

  # Zero-pad each degree's delta matrix up to the largest degree's shape
  # (symmetrically on the last axis) so all of them can be stacked.
  deltas_by_ell = sphere_utils.compute_all_wigner_delta(largest_ell)
  self.wigner_deltas = jnp.stack([
      jnp.pad(delta, ((0, largest_ell - ell),
                      (largest_ell - ell, largest_ell - ell)))
      for ell, delta in enumerate(deltas_by_ell)
  ])

  self.quadrature_weights = {
      res: jnp.array(sphere_utils.torus_quadrature_weights(res))
      for res in resolutions
  }

  # For each spin, gather the per-degree constants for orders -ell..ell and
  # lay them out in matrix form.
  self.swsft_forward_constants = {}
  for spin in spins:
    per_ell = [
        jnp.asarray(
            sphere_utils.swsft_forward_constant(spin, ell,
                                                jnp.arange(-ell, ell + 1)))
        for ell in range(largest_ell + 1)
    ]
    self.swsft_forward_constants[spin] = coefficients_to_matrix(
        jnp.concatenate(per_ell))
def _compute_Inm(sphere, spin):  # pylint: disable=invalid-name
  r"""Computes Inm.

  This evaluates
    Inm = \int_{S^2}e^{-in\theta}e^{-im\phi} f(\theta, \phi) sin\theta d\theta d\phi,
  as defined in H&W, Equation (8). It is used in intermediate steps of the
  SWSFT computation.

  Args:
    sphere: See swsft_forward_naive().
    spin: See swsft_forward_naive().

  Returns:
    The complex128 matrix Inm. If sphere is (n, n), output will be (n-1, n-1).
  """
  ell_max = sphere_utils.ell_max_from_resolution(sphere.shape[0])
  coeffs = _extend_sphere_fft(sphere, spin=spin)
  # Along each of the first two axes, keep indices 0..ell_max followed by the
  # last ell_max entries (frequencies -ell_max..-1 if coeffs is in standard
  # FFT order). The trailing block starts at `size - ell_max` rather than
  # `-ell_max`: `x[-0:]` selects the whole axis, which would be wrong for
  # ell_max == 0. Assumes each axis is longer than ell_max.
  pos = slice(None, ell_max + 1)
  neg_rows = slice(coeffs.shape[0] - ell_max, None)
  neg_cols = slice(coeffs.shape[1] - ell_max, None)
  rows1 = np.concatenate([coeffs[pos, pos], coeffs[pos, neg_cols]], axis=1)
  rows2 = np.concatenate([coeffs[neg_rows, pos], coeffs[neg_rows, neg_cols]],
                         axis=1)
  return np.concatenate([rows1, rows2], axis=0)
def _compute_Jnm(sphere, spin):  # pylint: disable=invalid-name
  """Computes Jnm, the trimmed (n >= 0) version of Inm.

    Jnm = I0m                         for n = 0
        = Inm + (-1)^{m+s} I_(-n)m    for n > 0

  This matrix is defined in H&W, Equation (10).

  Args:
    sphere: See swsft_forward_naive().
    spin: See swsft_forward_naive().

  Returns:
    The complex128 matrix Jnm. If sphere is (n, n), output will be
    (n//2, n-1).

  Raises:
    ValueError: If sphere rank is not 2.
  """
  if sphere.ndim != 2:
    raise ValueError("Input sphere rank must be 2.")
  ell_max = sphere_utils.ell_max_from_resolution(sphere.shape[0])
  Inm = _compute_Inm(sphere, spin=spin)  # pylint: disable=invalid-name

  # Column orders m laid out as 0, ..., ell_max, -ell_max, ..., -1.
  nonneg_m = np.arange(ell_max + 1)
  m = np.concatenate([nonneg_m, -nonneg_m[1:][::-1]])[None]
  signs = (-1.)**(m + spin)

  # The last ell_max rows of Inm, flipped so they line up with
  # n = 1, ..., ell_max in the output.
  mirrored_rows = Inm[-ell_max:][::-1]

  Jnm = Inm[:ell_max + 1].copy()  # pylint: disable=invalid-name
  Jnm[1:] += signs * mirrored_rows
  return Jnm
def __call__(self, sphere_set):
  """Applies the spin-weighted convolution to a batch of inputs.

  Args:
    sphere_set: A (batch_size, resolution, resolution, n_spins_in,
      n_channels_in) array of spin-weighted spherical functions (SWSF) with
      equiangular sampling.

  Returns:
    A (batch_size, resolution, resolution, n_spins_out, n_channels_out)
    complex64 array of SWSF with equiangular H&W sampling.

  Raises:
    ValueError: If axes 1 and 2 differ, if axis 3 does not match
      `self.spins_in`, or if the transformer rejects the resolution for any
      input or output spin.
  """
  resolution = sphere_set.shape[1]
  if resolution != sphere_set.shape[2]:
    raise ValueError("Axes 1 and 2 must have the same dimensions!")
  if sphere_set.shape[3] != len(list(self.spins_in)):
    raise ValueError("Input axis 3 (spins_in) doesn't match layer's.")

  # Every spin used on either side must be covered by the transformer's
  # precomputed constants at this resolution.
  needed_spins = set(self.spins_in).union(self.spins_out)
  if not all(self.transformer.validate(resolution, s) for s in needed_spins):
    raise ValueError("Constants are invalid for given input!")

  ell_max = sphere_utils.ell_max_from_resolution(resolution)
  num_channels_in = sphere_set.shape[-1]
  make_kernel = (self._get_kernel if self.num_filter_params is None
                 else self._get_localized_kernel)
  kernel = make_kernel(ell_max, num_channels_in)

  # Only sphere_set is mapped (axis 0); the other arguments are shared.
  batched_conv = jax.vmap(_swsconv_spatial_spectral,
                          in_axes=(None, 0, None, None, None))
  return batched_conv(self.transformer, sphere_set, kernel, self.spins_in,
                      self.spins_out)
def swsft_forward_spins_channels(self, sphere_set, spins):
  """Runs swsft_forward() over stacked spins and channels in one einsum.

  Args:
    sphere_set: An (n, n, n_spins, n_channels) array representing spherical
      functions. Equirectangular sampling; the leading axes are lat, long.
    spins: An (n_spins,) list of int spin weights.

  Returns:
    An (n//2, n-1, n_spins, n_channels) complex64 array of coefficients.

  Raises:
    ValueError: If `self.validate` rejects the resolution for any spin.
  """
  resolution = sphere_set.shape[0]
  for spin in spins:
    if not self.validate(resolution=resolution, spin=spin):
      raise ValueError("Constants are invalid for given input!")
  ell_max = sphere_utils.ell_max_from_resolution(resolution)

  # Spins reshaped to (1, 1, n_spins, 1) so they broadcast against axis 2.
  spins_4d = jnp.expand_dims(jnp.array(spins), [0, 1, 3])
  Inm = self._compute_Inm(sphere_set, spins_4d)  # pylint: disable=invalid-name
  Inm = jnp.fft.fftshift(Inm, axes=(0, 1))  # pylint: disable=invalid-name

  deltas = self._slice_wigner_deltas(ell_max, include_negative_m=True)
  deltas_s = jnp.stack([deltas[Ellipsis, ell_max - s] for s in spins],
                       axis=-1)
  forward_constants = jnp.stack(
      [self._slice_forward_constants(ell_max, s) for s in spins], axis=-1)
  return jnp.einsum("lms,lnm,lns,nms...->lms...", forward_constants, deltas,
                    deltas_s, Inm)
def swsft_forward(self, sphere, spin):
  """Spin-weighted spherical harmonics transform (fast JAX version).

  Returns coefficients in zero-padded format:

    [0    0    c00  0    0  ]
    [0    c1m1 c10  c11  0  ]
    [c2m2 c2m1 c20  c21  c22]

  See also: np_spin_spherical_harmonics.swsft_forward_naive().

  Args:
    sphere: A (n, n) array representing a spherical function. Equirectangular
      sampling, lat, long order.
    spin: Spin weight.

  Returns:
    A (n//2, n-1) array of complex64 coefficients. The coefficient at degree
    ell and order m is at position [ell, ell_max+m].

  Raises:
    ValueError: If `self.validate` rejects the input resolution and spin.
  """
  # Compared with the symmetry-based variant this does more arithmetic, but
  # it is usually faster on TPU because it has less overhead.
  resolution = sphere.shape[0]
  if not self.validate(resolution=resolution, spin=spin):
    raise ValueError("Constants are invalid for given input!")
  ell_max = sphere_utils.ell_max_from_resolution(resolution)

  Inm = self._compute_Inm(sphere, spin)  # pylint: disable=invalid-name
  Inm = jnp.fft.fftshift(Inm)  # pylint: disable=invalid-name

  deltas = self._slice_wigner_deltas(ell_max, include_negative_m=True)
  # Entries at last-axis index ell_max - spin, broadcast along the last axis.
  spin_slice = deltas[Ellipsis, ell_max - spin][Ellipsis, None]
  forward_constants = self._slice_forward_constants(ell_max, spin)
  return jnp.einsum("ik,ijk,jk->ik", forward_constants, deltas * spin_slice,
                    Inm)
def _compute_Inm(self, sphere, spin):  # pylint: disable=invalid-name
  """See np_spin_spherical_harmonics._compute_Inm().

  Only the first two axes of the extended-sphere FFT are sliced, so any
  trailing axes it returns pass through unchanged.
  """
  ell_max = sphere_utils.ell_max_from_resolution(sphere.shape[0])
  coeffs = self._extend_sphere_fft(sphere, spin)
  # Along each of the first two axes, keep indices 0..ell_max followed by the
  # last ell_max entries. The trailing block starts at `size - ell_max`
  # instead of `-ell_max`: `x[-0:]` selects the whole axis, which would be
  # wrong for ell_max == 0. Assumes each axis is longer than ell_max.
  pos = slice(None, ell_max + 1)
  neg_rows = slice(coeffs.shape[0] - ell_max, None)
  neg_cols = slice(coeffs.shape[1] - ell_max, None)
  rows1 = jnp.concatenate([coeffs[pos, pos], coeffs[pos, neg_cols]], axis=1)
  rows2 = jnp.concatenate([coeffs[neg_rows, pos], coeffs[neg_rows, neg_cols]],
                          axis=1)
  return jnp.concatenate([rows1, rows2], axis=0)
def _compute_Jnm(self, sphere, spin):  # pylint: disable=invalid-name
  """See np_spin_spherical_harmonics._compute_Jnm().

  Keeps the n >= 0 rows of Inm and adds (-1)^{m+s} times the mirrored
  negative-n rows onto rows n = 1, ..., ell_max.
  """
  ell_max = sphere_utils.ell_max_from_resolution(sphere.shape[0])
  Inm = self._compute_Inm(sphere, spin)  # pylint: disable=invalid-name

  # Column orders m laid out as 0, ..., ell_max, -ell_max, ..., -1.
  nonneg_m = jnp.arange(ell_max + 1)
  m = jnp.concatenate([nonneg_m, -nonneg_m[1:][::-1]])[None]
  signs = (-1.)**(m + spin)

  # The last ell_max rows, flipped to align with n = 1, ..., ell_max.
  mirrored = Inm[-ell_max:][::-1]
  return Inm[:ell_max + 1].at[1:].add(signs * mirrored)
def _compute_Jnm_spins_channels(self, sphere_set, spins):  # pylint: disable=invalid-name
  """Computes Jnm over different spins and channels at once.

  Args:
    sphere_set: Stacked spherical functions; axis 0 sets the resolution.
      Presumably (n, n, n_spins, n_channels) -- TODO confirm against callers.
    spins: Sequence of int spin weights, broadcast along axis 2.

  Returns:
    Jnm with rows n = 0..ell_max; trailing axes follow `_compute_Inm`.
  """
  ell_max = sphere_utils.ell_max_from_resolution(sphere_set.shape[0])
  # Spins reshaped to (1, 1, n_spins, 1).
  spins_4d = jnp.expand_dims(jnp.array(spins), [0, 1, 3])
  Inm = self._compute_Inm(sphere_set, spins_4d)  # pylint: disable=invalid-name

  # (-1)^{m+s}: m along axis 1 (0..ell_max, -ell_max..-1), s along axis 2.
  nonneg_m = jnp.arange(ell_max + 1)
  m = jnp.concatenate([nonneg_m, -nonneg_m[1:][::-1]])
  signs = (-1.)**(jnp.expand_dims(m, (0, 2, 3)) + spins_4d)

  # The last ell_max rows, flipped to align with n = 1, ..., ell_max.
  mirrored = Inm[-ell_max:][::-1]
  return Inm[:ell_max + 1].at[1:].add(signs * mirrored)
def swsft_forward_with_symmetry(self, sphere, spin):
  """Same as swsft_forward, but with the intermediate Jnm computation.

  Args:
    sphere: A (n, n) array representing a spherical function. Equirectangular
      sampling, lat, long order.
    spin: Spin weight.

  Returns:
    Coefficients in the same zero-padded layout as swsft_forward().

  Raises:
    ValueError: If `self.validate` rejects the input resolution and spin.
  """
  # This version uses fewer operations overall but is usually slower on TPU
  # due to more overhead.
  if not self.validate(resolution=sphere.shape[0], spin=spin):
    raise ValueError("Constants are invalid for given input!")
  ell_max = sphere_utils.ell_max_from_resolution(sphere.shape[0])
  Jnm = self._compute_Jnm(sphere, spin)  # pylint: disable=invalid-name
  # Reorder the m axis from [0, ..., ell_max, -ell_max, ..., -1] to
  # [-ell_max, ..., ell_max]. For the odd width 2*ell_max+1 this is exactly
  # fftshift. It also stays correct for ell_max == 0, where a slice
  # `x[:, -0:]` would select every column.
  Jnm = jnp.fft.fftshift(Jnm, axes=1)  # pylint: disable=invalid-name
  deltas = self._slice_wigner_deltas(ell_max, include_negative_m=False)
  # Scale each delta by the entry at last-axis index ell_max - spin.
  deltas = deltas * deltas[Ellipsis, ell_max - spin][Ellipsis, None]
  forward_constants = self._slice_forward_constants(ell_max, spin)
  return jnp.einsum("ik,ijk,jk->ik", forward_constants, deltas, Jnm)
def get_spin_spherical(transformer, shape, spins):
  """Builds a reproducible batch of spin-weighted spherical functions.

  Args:
    transformer: SpinSphericalFourierTransformer instance.
    shape: Desired shape (batch, latitude, longitude, spins, channels).
    spins: Desired spins.

  Returns:
    Array of spherical functions and array of their spectral coefficients.

  Raises:
    ValueError: If len(spins) does not match the spins entry of `shape`.
  """
  batch_size, resolution, _, num_spins, num_channels = shape
  if num_spins != len(spins):
    raise ValueError('len(spins) must match desired shape.')

  ell_max = sphere_utils.ell_max_from_resolution(resolution)
  num_coefficients = np_spin_spherical_harmonics.n_coeffs_from_ell_max(
      ell_max)
  flat_shape = (batch_size, num_spins, num_channels, num_coefficients)

  # Deterministic but non-trivial values. Random coefficients give functions
  # that are hard to inspect visually, and something like
  # linspace(-1-1j, 1+1j) would give every entry the same phase.
  ramp = jnp.linspace(-0.5, 0.7 + 0.5j, np.prod(flat_shape))
  coefficients = ramp.reshape(flat_shape)

  # Expand the flat coefficient axis into (ell, m), then reorder to
  # (batch, ell, m, spin, channel).
  to_matrix = jnp.vectorize(spin_spherical_harmonics.coefficients_to_matrix,
                            signature='(i)->(j,k)')
  coefficients = jnp.transpose(to_matrix(coefficients), (0, 3, 4, 1, 2))

  # Degrees ell < |spin| must be zero for each spin.
  for index, spin in enumerate(spins):
    coefficients = coefficients.at[:, :abs(spin), :, index].set(0.0)

  # Synthesize the spatial functions, one batch element at a time.
  backward_per_example = jax.vmap(transformer.swsft_backward_spins_channels,
                                  in_axes=(0, None))
  sphere = backward_per_example(coefficients, spins)
  return sphere, coefficients
def swsft_forward_spins_channels_with_symmetry(self, sphere_set, spins):
  """Same as `swsft_forward_spins_channels`, but leveraging symmetry.

  Args:
    sphere_set: An (n, n, n_spins, n_channels) array of spherical functions.
    spins: An (n_spins,) list of int spin weights.

  Returns:
    Coefficients in the same layout as `swsft_forward_spins_channels`.

  Raises:
    ValueError: If `self.validate` rejects the resolution for any spin.
  """
  for spin in spins:
    if not self.validate(resolution=sphere_set.shape[0], spin=spin):
      raise ValueError("Constants are invalid for given input!")
  ell_max = sphere_utils.ell_max_from_resolution(sphere_set.shape[0])
  Jnm = self._compute_Jnm_spins_channels(sphere_set, spins)  # pylint: disable=invalid-name
  # Reorder the m axis (axis 1) from [0, ..., ell_max, -ell_max, ..., -1] to
  # [-ell_max, ..., ell_max]. fftshift is exact for the odd width
  # 2*ell_max+1 and, unlike slicing with `-ell_max:`, is also correct when
  # ell_max == 0.
  Jnm = jnp.fft.fftshift(Jnm, axes=1)  # pylint: disable=invalid-name
  deltas = self._slice_wigner_deltas(ell_max, include_negative_m=False)
  deltas_s = jnp.stack(
      [deltas[Ellipsis, ell_max - spin] for spin in spins], axis=-1)
  forward_constants = jnp.stack(
      [self._slice_forward_constants(ell_max, spin) for spin in spins],
      axis=-1)
  return jnp.einsum("lms,lnm,lns,nms...->lms...", forward_constants, deltas,
                    deltas_s, Jnm)
def swsft_forward_naive(sphere, spin):
  """Forward spin-weighted spherical harmonics transform, one coefficient at a time.

  Slow reference implementation intended for testing. Computing many
  coefficients in a vectorized fashion is much faster.

  Args:
    sphere: A (n, n) array representing a spherical function with
      equirectangular sampling, lat, long order.
    spin: Spin weight (int).

  Returns:
    A ((n/2)**2,) array of complex128 coefficients sorted by increasing
    degree. The coefficient of degree ell and order m is at position
    ell**2 + m + ell.
  """
  ell_max = sphere_utils.ell_max_from_resolution(sphere.shape[0])
  return np.array([
      _swsft_forward_single(sphere, spin, ell, m)
      for ell in range(ell_max + 1)
      for m in range(-ell, ell + 1)
  ])