def swsft_forward(self, sphere, spin):
        """Computes the forward spin-weighted spherical Fourier transform.

        The output is a zero-padded matrix with one row per degree:
        [0    0    c00 0   0  ]
        [0    c1m1 c10 c11 0  ]
        [c2m2 c2m1 c20 c21 c22]

        See also: np_spin_spherical_harmonics.swsft_forward_naive().

        Args:
          sphere: A (n, n) array holding a spherical function sampled on an
            equirectangular grid, in lat, long order.
          spin: Spin weight.

        Returns:
          A (n//2, n-1) array of complex64 coefficients. The entry for degree
          ell and order m is stored at [ell, ell_max+m].

        Raises:
          ValueError: If `self.validate` rejects the input resolution and spin.
        """
        resolution = sphere.shape[0]
        if not self.validate(resolution=resolution, spin=spin):
            raise ValueError("Constants are invalid for given input!")

        ell_max = sphere_utils.ell_max_from_resolution(resolution)

        # Move the trailing ell_max columns of Jnm in front of the leading
        # ell_max + 1 columns before contracting.
        jnm_raw = self._compute_Jnm(sphere, spin)
        trailing_columns = jnm_raw[:, -ell_max:]
        leading_columns = jnm_raw[:, :ell_max + 1]
        jnm = jnp.concatenate([trailing_columns, leading_columns], axis=1)

        # Scale every delta by the entry of its row found at column
        # ell_max - spin of the last axis.
        wigner = self._slice_wigner_deltas(ell_max)
        spin_column = wigner[..., ell_max - spin]
        scaled_wigner = wigner * spin_column[..., None]

        constants = self._slice_forward_constants(ell_max, spin)
        return jnp.einsum("ik,ijk,jk->ik", constants, scaled_wigner, jnm)
    def _compute_constants(self, resolutions, spins):
        """Precomputes the transform constants and stores them on `self`.

        See the constructor docstring. This sets three attributes:
        `wigner_deltas`, `quadrature_weights` and `swsft_forward_constants`.

        Args:
          resolutions: Resolutions to support. Iterated twice, so it must be
            re-iterable (e.g. a list or tuple).
          spins: Spin weights to support.
        """
        ell_max = max(
            [sphere_utils.ell_max_from_resolution(res) for res in resolutions])

        def pad_delta(ell, delta):
            # Zero-pad so that the deltas of every degree share one shape and
            # can be stacked: rows are padded at the end, columns on both
            # sides.
            margin = ell_max - ell
            return jnp.pad(delta, ((0, margin), (margin, margin)))

        all_deltas = sphere_utils.compute_all_wigner_delta(ell_max)
        self.wigner_deltas = jnp.stack(
            [pad_delta(ell, delta) for ell, delta in enumerate(all_deltas)])

        weights_by_resolution = {}
        for res in resolutions:
            weights_by_resolution[res] = jnp.array(
                sphere_utils.torus_quadrature_weights(res))
        self.quadrature_weights = weights_by_resolution

        self.swsft_forward_constants = {}
        for spin in spins:
            # One vector of constants per degree, covering orders -ell..ell.
            per_degree = [
                jnp.asarray(
                    sphere_utils.swsft_forward_constant(
                        spin, ell, jnp.arange(-ell, ell + 1)))
                for ell in range(ell_max + 1)
            ]
            self.swsft_forward_constants[spin] = coefficients_to_matrix(
                jnp.concatenate(per_degree))
def _compute_Inm(sphere, spin):  # pylint: disable=invalid-name
    r"""Evaluates the integrals Inm.

    Inm = \int_{S^2}e^{-in\theta}e^{-im\phi} f(\theta, \phi) sin\theta d\theta
    d\phi, as defined in H&W, Equation (8). This is an intermediate quantity
    of the SWSFT computation.

    Args:
      sphere: See swsft_forward_naive().
      spin: See swsft_forward_naive().

    Returns:
      The complex128 matrix Inm. If sphere is (n, n), output will be
      (n-1, n-1).
    """
    ell_max = sphere_utils.ell_max_from_resolution(sphere.shape[0])
    spectrum = _extend_sphere_fft(sphere, spin=spin)

    # Along both leading axes, keep the first ell_max + 1 entries followed by
    # the last ell_max entries, and drop everything in between.
    head = slice(None, ell_max + 1)
    tail = slice(-ell_max, None)

    kept_rows = np.concatenate([spectrum[head], spectrum[tail]], axis=0)
    return np.concatenate([kept_rows[:, head], kept_rows[:, tail]], axis=1)
def _compute_Jnm(sphere, spin):  # pylint: disable=invalid-name
    """Computes Jnm, the trimmed version of Inm.

    Jnm = I0m for n=0
        = Inm + (-1)^{m+s}I_(-n)m for n>0

    This matrix is defined in H&W, Equation (10).

    Args:
      sphere: See swsft_forward_naive().
      spin: See swsft_forward_naive().

    Returns:
      The complex128 matrix Jnm. If sphere is (n, n), output will be
      (n//2, n-1).

    Raises:
      ValueError: If sphere rank is not 2.
    """
    if sphere.ndim != 2:
        raise ValueError("Input sphere rank must be 2.")
    ell_max = sphere_utils.ell_max_from_resolution(sphere.shape[0])

    inm = _compute_Inm(sphere, spin=spin)

    # Column orders: 0..ell_max followed by -ell_max..-1. The sign
    # (-1)^{m+s} is constant along each column.
    orders = np.concatenate(
        [np.arange(ell_max + 1), np.arange(-ell_max, 0)])[None]
    column_signs = (-1.)**(orders + spin)

    # Rows for n >= 0; copied so that Inm itself is not modified.
    folded = inm[:ell_max + 1].copy()
    # The last ell_max rows, reversed, so that row k lines up with row k + 1
    # of `folded`.
    mirrored_rows = inm[-ell_max:][::-1]
    folded[1:] += column_signs * mirrored_rows

    return folded
# Example no. 5
# 0
    def __call__(self, sphere_set):
        """Runs the spin-weighted spherical convolution on a batch of inputs.

        Args:
          sphere_set: A (batch_size, resolution, resolution, n_spins_in,
            n_channels_in) array of spin-weighted spherical functions (SWSF)
            with equiangular sampling.

        Returns:
          A (batch_size, resolution, resolution, n_spins_out, n_channels_out)
          complex64 array of SWSF with equiangular H&W sampling.

        Raises:
          ValueError: If axes 1 and 2 differ in size, if axis 3 does not match
            the number of input spins, or if the transformer constants do not
            cover the resolution and spins in use.
        """
        resolution = sphere_set.shape[1]
        if resolution != sphere_set.shape[2]:
            raise ValueError("Axes 1 and 2 must have the same dimensions!")
        spin_axis_size = sphere_set.shape[3]
        if spin_axis_size != len(list(self.spins_in)):
            raise ValueError("Input axis 3 (spins_in) doesn't match layer's.")

        # The transformer must hold constants for every input and output spin
        # at this resolution.
        required_spins = set(self.spins_in).union(self.spins_out)
        if not all(
                self.transformer.validate(resolution, spin)
                for spin in required_spins):
            raise ValueError("Constants are invalid for given input!")

        ell_max = sphere_utils.ell_max_from_resolution(resolution)
        num_channels_in = sphere_set.shape[-1]
        make_kernel = (self._get_kernel if self.num_filter_params is None else
                       self._get_localized_kernel)
        kernel = make_kernel(ell_max, num_channels_in)

        # Only the second argument (the inputs) is mapped over its batch axis;
        # all other arguments are shared across the batch.
        batched_convolution = jax.vmap(_swsconv_spatial_spectral,
                                       in_axes=(None, 0, None, None, None))
        return batched_convolution(self.transformer, sphere_set, kernel,
                                   self.spins_in, self.spins_out)
    def swsft_forward_spins_channels(self, sphere_set, spins):
        """Runs the forward transform over stacked spins and channels.

        Args:
          sphere_set: An (n, n, n_spins, n_channels) array of spherical
            functions. Equirectangular sampling, leading dimensions are lat,
            long.
          spins: An (n_spins,) list of int spin weights.

        Returns:
          An (n//2, n-1, n_spins, n_channels) complex64 array of coefficients.

        Raises:
          ValueError: If `self.validate` rejects the resolution for any spin.
        """
        resolution = sphere_set.shape[0]
        if not all(
                self.validate(resolution=resolution, spin=spin)
                for spin in spins):
            raise ValueError("Constants are invalid for given input!")

        ell_max = sphere_utils.ell_max_from_resolution(resolution)

        # Give the spins singleton axes 0, 1 and 3 so that they broadcast
        # against the (lat, long, spin, channel) input.
        spins_broadcast = jnp.expand_dims(jnp.array(spins), [0, 1, 3])
        inm = self._compute_Inm(sphere_set, spins_broadcast)
        inm = jnp.fft.fftshift(inm, axes=(0, 1))

        deltas = self._slice_wigner_deltas(ell_max, include_negative_m=True)
        # For each spin, the slice of the deltas at index ell_max - spin of
        # the last axis; stacked on a new trailing spin axis.
        spin_slices = [deltas[..., ell_max - spin] for spin in spins]
        deltas_s = jnp.stack(spin_slices, axis=-1)

        constants_per_spin = [
            self._slice_forward_constants(ell_max, spin) for spin in spins
        ]
        forward_constants = jnp.stack(constants_per_spin, axis=-1)

        return jnp.einsum("lms,lnm,lns,nms...->lms...", forward_constants,
                          deltas, deltas_s, inm)
# Example no. 7
# 0
    def swsft_forward(self, sphere, spin):
        """Computes the forward spin-weighted spherical Fourier transform.

        The output is a zero-padded matrix with one row per degree:
        [0    0    c00 0   0  ]
        [0    c1m1 c10 c11 0  ]
        [c2m2 c2m1 c20 c21 c22]

        See also: np_spin_spherical_harmonics.swsft_forward_naive().

        Args:
          sphere: A (n, n) array holding a spherical function sampled on an
            equirectangular grid, in lat, long order.
          spin: Spin weight.

        Returns:
          A (n//2, n-1) array of complex64 coefficients. The entry for degree
          ell and order m is stored at [ell, ell_max+m].

        Raises:
          ValueError: If `self.validate` rejects the input resolution and spin.
        """
        # This variant performs more operations overall but is usually faster
        # on TPU because it has less overhead.
        resolution = sphere.shape[0]
        if not self.validate(resolution=resolution, spin=spin):
            raise ValueError("Constants are invalid for given input!")

        ell_max = sphere_utils.ell_max_from_resolution(resolution)

        # fftshift with its default axes shifts every axis of Inm.
        inm = self._compute_Inm(sphere, spin)
        inm = jnp.fft.fftshift(inm)

        # Scale every delta by the entry of its row found at column
        # ell_max - spin of the last axis.
        wigner = self._slice_wigner_deltas(ell_max, include_negative_m=True)
        spin_column = wigner[..., ell_max - spin]
        scaled_wigner = wigner * spin_column[..., None]

        constants = self._slice_forward_constants(ell_max, spin)
        return jnp.einsum("ik,ijk,jk->ik", constants, scaled_wigner, inm)
  def _compute_Inm(self, sphere, spin):  # pylint: disable=invalid-name
    """Computes Inm; see np_spin_spherical_harmonics._compute_Inm().

    Along each of the two leading axes of the output of
    `self._extend_sphere_fft`, the first ell_max + 1 entries are kept,
    followed by the last ell_max entries.
    """
    ell_max = sphere_utils.ell_max_from_resolution(sphere.shape[0])
    spectrum = self._extend_sphere_fft(sphere, spin)

    head = slice(None, ell_max + 1)
    tail = slice(-ell_max, None)

    kept_rows = jnp.concatenate([spectrum[head], spectrum[tail]], axis=0)
    return jnp.concatenate([kept_rows[:, head], kept_rows[:, tail]], axis=1)
    def _compute_Jnm(self, sphere, spin):  # pylint: disable=invalid-name
        """Computes Jnm; see np_spin_spherical_harmonics._compute_Jnm().

        Row 0 of the result is row 0 of Inm. Each row n > 0 is row n of Inm
        plus (-1)^{m+s} times a row taken from the last ell_max rows of Inm,
        read in reverse order.
        """
        ell_max = sphere_utils.ell_max_from_resolution(sphere.shape[0])
        inm = self._compute_Inm(sphere, spin)

        # Column orders: 0..ell_max followed by -ell_max..-1. The sign
        # (-1)^{m+s} is constant along each column.
        orders = jnp.concatenate(
            [jnp.arange(ell_max + 1), jnp.arange(-ell_max, 0)])[None]
        column_signs = (-1.)**(orders + spin)

        # Start from the first ell_max + 1 rows.
        leading_rows = inm[:ell_max + 1]
        # Reverse the last ell_max rows so that row k lines up with row k + 1
        # of `leading_rows`.
        mirrored_rows = jnp.flip(inm[-ell_max:], axis=0)
        return leading_rows.at[1:].add(column_signs * mirrored_rows)
    def _compute_Jnm_spins_channels(self, sphere_set, spins):  # pylint: disable=invalid-name
        """Computes Jnm for several spins and channels at once.

        Same folding as `_compute_Jnm`, except that the sign (-1)^{m+s} is
        built per spin and broadcast over the spin axis (axis 2).
        """
        ell_max = sphere_utils.ell_max_from_resolution(sphere_set.shape[0])

        # Spins get singleton axes 0, 1 and 3.
        spins_broadcast = jnp.expand_dims(jnp.array(spins), [0, 1, 3])
        inm = self._compute_Inm(sphere_set, spins_broadcast)

        # Column orders: 0..ell_max followed by -ell_max..-1, placed on axis 1.
        orders = jnp.concatenate(
            [jnp.arange(ell_max + 1), jnp.arange(-ell_max, 0)])
        orders_broadcast = jnp.expand_dims(orders, (0, 2, 3))
        signs = (-1.)**(orders_broadcast + spins_broadcast)

        # Start from the first ell_max + 1 rows, then add the reversed last
        # ell_max rows (with signs) to rows 1 and up.
        leading_rows = inm[:ell_max + 1]
        mirrored_rows = jnp.flip(inm[-ell_max:], axis=0)
        return leading_rows.at[1:].add(signs * mirrored_rows)
    def swsft_forward_with_symmetry(self, sphere, spin):
        """Same as swsft, but with the intermediate Jnm computation.

        Args:
          sphere: Input array; `sphere.shape[0]` is used as the resolution.
          spin: Spin weight.

        Returns:
          The result of contracting the forward constants, the scaled Wigner
          deltas and the column-reordered Jnm.

        Raises:
          ValueError: If `self.validate` rejects the input resolution and spin.
        """
        # This variant performs fewer operations overall but is usually slower
        # on TPU because it has more overhead.
        resolution = sphere.shape[0]
        if not self.validate(resolution=resolution, spin=spin):
            raise ValueError("Constants are invalid for given input!")

        ell_max = sphere_utils.ell_max_from_resolution(resolution)

        # Move the trailing ell_max columns of Jnm in front of the leading
        # ell_max + 1 columns before contracting.
        jnm_raw = self._compute_Jnm(sphere, spin)
        trailing_columns = jnm_raw[:, -ell_max:]
        leading_columns = jnm_raw[:, :ell_max + 1]
        jnm = jnp.concatenate([trailing_columns, leading_columns], axis=1)

        wigner = self._slice_wigner_deltas(ell_max, include_negative_m=False)
        spin_column = wigner[..., ell_max - spin]
        scaled_wigner = wigner * spin_column[..., None]

        constants = self._slice_forward_constants(ell_max, spin)
        return jnp.einsum("ik,ijk,jk->ik", constants, scaled_wigner, jnm)
# Example no. 12
# 0
def get_spin_spherical(transformer, shape, spins):
    """Builds a reproducible set of spin-weighted spherical functions.

    Args:
      transformer: SpinSphericalFourierTransformer instance.
      shape: Desired shape (batch, latitude, longitude, spins, channels).
      spins: Desired spins.

    Returns:
      Array of spherical functions and array of their spectral coefficients.

    Raises:
      ValueError: If `len(spins)` differs from the spins entry of `shape`.
    """
    batch_size, resolution, _, num_spins, num_channels = shape
    if len(spins) != num_spins:
        raise ValueError('len(spins) must match desired shape.')

    ell_max = sphere_utils.ell_max_from_resolution(resolution)
    num_coefficients = np_spin_spherical_harmonics.n_coeffs_from_ell_max(
        ell_max)
    flat_shape = (batch_size, num_spins, num_channels, num_coefficients)

    # The values are arbitrary but deliberately not random: random
    # coefficients give functions that are hard to interpret visually. A
    # simpler ramp such as linspace(-1-1j, 1+1j) would give every complex
    # number the same phase, which is also undesirable.
    ramp = jnp.linspace(-0.5, 0.7 + 0.5j, np.prod(flat_shape))
    flat_coefficients = ramp.reshape(flat_shape)

    # Apply coefficients_to_matrix over the last axis; the leading axes are
    # broadcast.
    to_matrix = jnp.vectorize(spin_spherical_harmonics.coefficients_to_matrix,
                              signature='(i)->(j,k)')
    matrices = to_matrix(flat_coefficients)
    # Reorder to (batch, ell, m, spin, channel).
    coefficients = jnp.transpose(matrices, (0, 3, 4, 1, 2))

    # Coefficients for ell < |spin| are always zero.
    for spin_index, spin in enumerate(spins):
        coefficients = coefficients.at[:, :abs(spin), :, spin_index].set(0.0)

    # Map the backward transform over the batch axis to get spatial samples.
    to_spatial = jax.vmap(transformer.swsft_backward_spins_channels,
                          in_axes=(0, None))
    return to_spatial(coefficients, spins), coefficients
    def swsft_forward_spins_channels_with_symmetry(self, sphere_set, spins):
        """Same as `swsft_forward_spins_channels`, but leveraging symmetry.

        Args:
          sphere_set: Input array; `sphere_set.shape[0]` is used as the
            resolution.
          spins: Sequence of spin weights.

        Returns:
          The result of contracting the per-spin forward constants, the
          Wigner deltas, their per-spin slices and the column-reordered Jnm.

        Raises:
          ValueError: If `self.validate` rejects the resolution for any spin.
        """
        resolution = sphere_set.shape[0]
        if not all(
                self.validate(resolution=resolution, spin=spin)
                for spin in spins):
            raise ValueError("Constants are invalid for given input!")

        ell_max = sphere_utils.ell_max_from_resolution(resolution)

        # Move the trailing ell_max columns of Jnm in front of the leading
        # ell_max + 1 columns before contracting.
        jnm_raw = self._compute_Jnm_spins_channels(sphere_set, spins)
        trailing_columns = jnm_raw[:, -ell_max:]
        leading_columns = jnm_raw[:, :ell_max + 1]
        jnm = jnp.concatenate([trailing_columns, leading_columns], axis=1)

        deltas = self._slice_wigner_deltas(ell_max, include_negative_m=False)
        # For each spin, the slice of the deltas at index ell_max - spin of
        # the last axis; stacked on a new trailing spin axis.
        spin_slices = [deltas[..., ell_max - spin] for spin in spins]
        deltas_s = jnp.stack(spin_slices, axis=-1)

        constants_per_spin = [
            self._slice_forward_constants(ell_max, spin) for spin in spins
        ]
        forward_constants = jnp.stack(constants_per_spin, axis=-1)

        return jnp.einsum("lms,lnm,lns,nms...->lms...", forward_constants,
                          deltas, deltas_s, jnm)
def swsft_forward_naive(sphere, spin):
    """Forward spin-weighted spherical harmonics transform, one term at a time.

    Naive and slow, but useful as a reference in tests; computing multiple
    coefficients in a vectorized fashion is much faster.

    Args:
      sphere: A (n, n) array representing a spherical function with
        equirectangular sampling, lat, long order.
      spin: Spin weight (int).

    Returns:
      A ((n/2)**2,) Array of complex128 coefficients sorted by increasing
      degrees. Coefficient of degree ell and order m is at position
      ell**2 + m + ell.
    """
    ell_max = sphere_utils.ell_max_from_resolution(sphere.shape[0])
    # Degrees in increasing order; within each degree, orders run -ell..ell.
    return np.array([
        _swsft_forward_single(sphere, spin, ell, m)
        for ell in range(ell_max + 1)
        for m in range(-ell, ell + 1)
    ])