def swsft_backward_naive(coeffs, spin):
    """Inverse spin-weighted spherical harmonics transform (naive version).

    The heavy lifting is delegated to the non-vectorized Gnm computation, so
    this routine is slow and intended as a simple reference implementation.

    Args:
      coeffs: List of n**2 SWSFT coefficients.
      spin: Spin weight (int).

    Returns:
      A complex128 matrix corresponding to the (H&W) equiangular sampling of a
      spin-weighted spherical function.
    """
    ell_max = sphere_utils.ell_max_from_n_coeffs(len(coeffs))
    # Number of non-negative frequencies (0..ell_max) along each axis.
    n_freqs = ell_max + 1

    # Gnm is related to the 2D Fourier transform of the output; ifftshift moves
    # the zero frequency to index 0, which is the layout ifft2 expects.
    spectrum = np.fft.ifftshift(_compute_Gnm_naive(coeffs, spin))

    # Zero-pad between the non-negative and the negative frequencies: one extra
    # column (2*ell_max + 2 columns in total) ...
    zero_column = np.zeros((2 * ell_max + 1, 1))
    spectrum = np.concatenate(
        [spectrum[:, :n_freqs], zero_column, spectrum[:, n_freqs:]], axis=1)
    # ... and 2*ell_max + 1 extra rows (4*ell_max + 2 rows in total).
    # NOTE(review): presumably this makes the cropped output match the grid
    # resolution used by the forward transform -- confirm.
    zero_rows = np.zeros_like(spectrum)
    spectrum = np.concatenate(
        [spectrum[:n_freqs], zero_rows, spectrum[n_freqs:]])

    # Keep only the first 2*(ell_max + 1) rows of the inverse transform and
    # multiply by the element count to undo the 1/N scaling applied by ifft2.
    n_rows_out = 2 * n_freqs
    return np.fft.ifft2(spectrum)[:n_rows_out] * spectrum.size
def _compute_Gnm_naive(coeffs, spin):  # pylint: disable=invalid-name
    r"""Evaluate the matrix Gnm with explicit loops over ell and m.

    Gnm, defined in H&W, Equation (13), is closely related to the 2D Fourier
    transform of a spin-weighted spherical function:

    Gnm = (-1)^s i^(m+s) \sum_\ell c \Delta_{-n,-s} \Delta_{-n,m} _sa_m^\ell,

    where c = \sqrt{(2\ell + 1) / (4\pi)}, and _sa_m^\ell is the coefficient
    at (ell, m).

    Args:
      coeffs: See swsft_backward_naive().
      spin: See swsft_backward_naive().

    Returns:
      The complex128 matrix Gnm. If coeffs has n**2 elements, the output is
      (2*n-1, 2*n-1).

    Raises:
      ValueError: If len(coeffs) is not a perfect square.
    """
    ell_max = sphere_utils.ell_max_from_n_coeffs(len(coeffs))
    full_deltas = sphere_utils.compute_all_wigner_delta(ell_max)
    # TODO(machc): This could be simplified. Previously we only stored
    # non-negative n in the Wigner Deltas, and used symmetries to
    # complete the result. Now that the complete Deltas are stored, we
    # could simplify the code below, but for now we just drop the leading
    # `ell` rows of each Delta to recover what was stored before and keep
    # using the prior implementation.
    half_deltas = tuple(delta[ell:] for ell, delta in enumerate(full_deltas))

    # Rows hold the half of Gnm that is computed directly (ell_max + 1 rows);
    # column `ell_max + m` holds order m.
    upper_half = np.zeros((ell_max + 1, 2 * ell_max + 1), dtype=np.complex128)
    for ell in range(abs(spin), ell_max + 1):
        delta = half_deltas[ell]
        normalization = np.sqrt((2 * ell + 1) / 4 / np.pi)
        # This Delta column depends only on (ell, spin), not on m.
        delta_spin = delta[:, ell - spin]
        for m in range(-ell, ell + 1):
            # The (-1)**m term also fixes the signs: the Deltas should be
            # evaluated at negative n but only values for non-negative n are
            # kept in `half_deltas`.
            phase = (1j)**(m + spin) * (-1)**m
            coeff = coeffs[_get_swsft_coeff_index(ell, m)]
            upper_half[:ell + 1, ell_max + m] += (
                phase * normalization * delta_spin * delta[:, ell + m] * coeff)

    # Use symmetry to obtain the remaining rows (negative n): mirror every row
    # but the first and flip the sign according to the parity of spin + m.
    orders = np.arange(-ell_max, ell_max + 1)
    signs = (-1.)**(spin + orders)[None, :]
    mirrored = upper_half[1:][::-1]
    return np.concatenate([signs * mirrored, upper_half])
def coefficients_to_matrix(coeffs):
  """Lays out a flat coefficient array as a zero-padded 2D array.

  Row `ell` contains the 2*ell + 1 coefficients of degree `ell`, centered and
  padded with zeros on both sides. For input
  [c00, c1m1, c10, c11, c2m2, c2m1, c20, c21, c22], this returns:
  [0    0    c00 0   0  ]
  [0    c1m1 c10 c11 0  ]
  [c2m2 c2m1 c20 c21 c22]

  Args:
    coeffs: List of n**2 SWSFT coefficients.

  Returns:
    A (n, 2n-1) array with one coefficient in the center of the first row,
    three in the second row, and so on (see example above), padded with zeros.
  """
  ell_max = sphere_utils.ell_max_from_n_coeffs(len(coeffs))

  def centered_row(ell):
    # Degree `ell` occupies coeffs[ell**2:(ell + 1)**2]; pad equally on both
    # sides so every row ends up with 2*ell_max + 1 entries.
    margin = ell_max - ell
    return jnp.pad(coeffs[ell**2:(ell + 1)**2], (margin, margin))

  return jnp.stack([centered_row(ell) for ell in range(ell_max + 1)])