Ejemplo n.º 1
0
def _gen_recurrence_mask(
        l_max: int,
        is_normalized: bool = True) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Builds the masks used by the recurrence over the remaining entries.

  The "remaining" entries are those other than the diagonal and off-diagonal
  entries, which are initialized separately.

  Args:
    l_max: see `gen_normalized_legendre`.
    is_normalized: True to build the coefficients for normalized associated
      Legendre functions, False for the unnormalized recurrence.

  Returns:
    A pair `(d0_mask_3d, d1_mask_3d)`, each of shape
    `(l_max + 1, l_max + 1, l_max + 1)`. Entry `[i, j, k]` holds the
    recurrence coefficient for order `j` and degree `k` when `i + j == k` and
    `k >= j + 1` (first array) or `k >= j + 2` (second array); all other
    entries are zero.
  """
    size = l_max + 1
    orders, degrees = jnp.mgrid[:size, :size]

    # Coefficients for every (order, degree) pair. Entries on or below the
    # diagonal may be inf/nan; only the strictly-upper part is kept below.
    if not is_normalized:
        denom = degrees - orders
        coef_a = (2.0 * degrees - 1.0) / denom
        coef_b = (degrees + orders - 1.0) / denom
    else:
        deg_sq = degrees * degrees
        ord_sq = orders * orders
        two_deg = 2.0 * degrees
        prev_deg_sq = (degrees - 1.0) * (degrees - 1.0)
        coef_a = jnp.sqrt((4.0 * deg_sq - 1.0) / (deg_sq - ord_sq))
        coef_b = jnp.sqrt(((two_deg + 1.0) * (prev_deg_sq - ord_sq)) /
                          ((two_deg - 3.0) * (deg_sq - ord_sq)))

    zeros = jnp.zeros((size, size))

    def _keep_above_diagonal(coef, offset):
        # Copies `coef` where column >= row + offset; zero everywhere else.
        rows_cols = jnp.triu_indices(size, offset)
        return zeros.at[rows_cols].set(coef[rows_cols])

    coef_a_2d = _keep_above_diagonal(coef_a, 1)
    coef_b_2d = _keep_above_diagonal(coef_b, 2)

    # plane[i, j, k] is 1.0 on the plane i + j == k and 0.0 elsewhere.
    idx_i, idx_j, idx_k = jnp.ogrid[:size, :size, :size]
    plane = 1.0 * (idx_i + idx_j - idx_k == 0)

    return (jnp.einsum('jk,ijk->ijk', coef_a_2d, plane),
            jnp.einsum('jk,ijk->ijk', coef_b_2d, plane))
Ejemplo n.º 2
0
def _gen_derivatives(p: jnp.ndarray, x: jnp.ndarray,
                     is_normalized: bool) -> jnp.ndarray:
    """Generates derivatives of associated Legendre functions of the first kind.

  For order m != 1 the derivative with respect to `x` is assembled as
  ((l + m)(l - m + 1) Q_l^{m-1} - Q_l^{m+1}) / 2, where
  Q_l^k = P_l^k / sqrt(1 - x^2) is obtained by recurrence from degree l - 1.
  That recurrence is singular at m = 1, so order 1 is computed directly as
  (l (l + 1) P_l^0 - P_l^2) / (2 sqrt(1 - x^2)).

  Args:
    p: The 3D array containing the values of associated Legendre functions; the
      dimensions are in the sequence of order (m), degree (l), and evaluation
      points.
    x: A vector of type `float32` or `float64` containing the sampled points.
    is_normalized: True if the associated Legendre functions are normalized.

  Returns:
    The 3D array representing the derivatives of associated Legendre functions
    of the first kind, laid out like `p`. Entries at x = +/-1 may be
    non-finite because of the 1 / sqrt(1 - x^2) factor.

  Raises:
    NotImplementedError: If `is_normalized` is True.
  """
    num_m, num_l, num_x = p.shape

    # p_{l-1}^m.
    p_m_lm1 = jnp.pad(p, ((0, 0), (1, 0), (0, 0)))[:, :num_l, :]

    # p_{l-1}^{m+2}.
    p_mp2_lm1 = jnp.pad(p_m_lm1, ((0, 2), (0, 0), (0, 0)))[2:num_m + 2, :, :]

    # p_{l-1}^{m-2}; rows m = 0 and m = 1 need negative orders (filled below).
    p_mm2_lm1 = jnp.pad(p_m_lm1, ((2, 0), (0, 0), (0, 0)))[:num_m, :, :]

    # Derivative computation requires negative orders.
    if is_normalized:
        raise NotImplementedError(
            'Negative orders for normalization is not implemented yet.')

    # Negative orders via P_L^{-k} = (-1)^k (L - k)! / (L + k)! P_L^k.
    # The factorial ratios are formed in floating point (`+ 1.0` etc.): as
    # default int32 products they would overflow for degrees above ~215.
    # NOTE(review): p[1] and p[2] are indexed under guards on num_l only, which
    # presumes num_m >= num_l (square (m, l) input) -- confirm with callers.
    if num_l > 1:
        # P_L^{-1} = -P_L^1 / ((L + 1) L), stored at degree index L + 1.
        l_vec = jnp.arange(1, num_l - 1)
        coeff = -1.0 / ((l_vec + 1.0) * l_vec)
        update_p_p1 = jnp.einsum('i,ij->ij', coeff, p[1, 1:num_l - 1, :])
        p_mm2_lm1 = p_mm2_lm1.at[1, 2:num_l, :].set(update_p_p1)

    if num_l > 2:
        # P_L^{-2} = P_L^2 / ((L + 2)(L + 1) L (L - 1)), stored at index L + 1.
        # (The (L - 1) factor is 1 at L = 2 but required for every L >= 3.)
        l_vec = jnp.arange(2, num_l - 1)
        coeff = 1.0 / ((l_vec + 2.0) * (l_vec + 1.0) * l_vec * (l_vec - 1.0))
        update_p_p2 = jnp.einsum('i,ij->ij', coeff, p[2, 2:num_l - 1, :])
        p_mm2_lm1 = p_mm2_lm1.at[0, 3:num_l, :].set(update_p_p2)

    m_mat, l_mat = jnp.mgrid[:num_m, :num_l]

    coeff_zeros = jnp.zeros((num_m, num_l))
    upper_0_indices = jnp.triu_indices(num_m, 0, num_l)
    zero_vec = jnp.zeros((num_l, ))

    # Coefficients of the Q_l^{m-1} recurrence; a0 is singular at m = 1, so
    # that row is zeroed here and recomputed at the end.
    a0 = -0.5 / (m_mat - 1.0)
    a0_masked = coeff_zeros.at[upper_0_indices].set(a0[upper_0_indices])
    a0_masked = a0_masked.at[1, :].set(zero_vec)

    b0 = l_mat + m_mat
    c0 = a0 * (b0 - 2.0) * (b0 - 1.0)
    c0_masked = coeff_zeros.at[upper_0_indices].set(c0[upper_0_indices])
    c0_masked = c0_masked.at[1, :].set(zero_vec)

    # p_l^{m-1} / sqrt(1 - x^2).
    p_mm1_l = (jnp.einsum('ij,ijk->ijk', a0_masked, p_m_lm1) +
               jnp.einsum('ij,ijk->ijk', c0_masked, p_mm2_lm1))

    d0 = -0.5 / (m_mat + 1.0)
    d0_masked = coeff_zeros.at[upper_0_indices].set(d0[upper_0_indices])
    e0 = d0 * b0 * (b0 + 1.0)
    e0_masked = coeff_zeros.at[upper_0_indices].set(e0[upper_0_indices])

    # p_l^{m+1} / sqrt(1 - x^2).
    p_mp1_l = (jnp.einsum('ij,ijk->ijk', d0_masked, p_mp2_lm1) +
               jnp.einsum('ij,ijk->ijk', e0_masked, p_m_lm1))

    f0 = b0 * (l_mat - m_mat + 1.0) / 2.0
    f0_masked = coeff_zeros.at[upper_0_indices].set(f0[upper_0_indices])
    p_derivative = jnp.einsum('ij,ijk->ijk', f0_masked,
                              p_mm1_l) - 0.5 * p_mp1_l

    # Special treatment of the singularity at m = 1:
    # (l (l + 1) P_l^0 - P_l^2) / (2 sqrt(1 - x^2)), forced to 0 at l = 0.
    if num_m > 1:
        l_vec = jnp.arange(num_l)
        g0 = jnp.einsum('i,ij->ij', (l_vec + 1.0) * l_vec, p[0, :, :])
        if num_l > 2:
            g0 = g0 - p[2, :, :]
        p_derivative_m0 = jnp.einsum('j,ij->ij', 0.5 / jnp.sqrt(1 - x * x), g0)
        p_derivative = p_derivative.at[1, :, :].set(p_derivative_m0)
        p_derivative = p_derivative.at[1, 0, :].set(jnp.zeros((num_x, )))

    return p_derivative