Beispiel #1
0
def cumprod(x, axis=0, exclusive=False):
    """Cumulative product of `x` along `axis`.

    Args:
      x: input array.
      axis: axis along which the running product is taken.
      exclusive: if True, element i holds the product of the elements strictly
        before i (so the first entry along `axis` is 1).

    Returns:
      Array with the same shape as `x`.
    """
    if not exclusive:
        return _jnp.cumprod(x, axis)
    # Work on the last axis so the shift can be written with `...` slicing.
    moved = _jnp.swapaxes(x, axis, -1)
    # Prepend a 1 and drop the last element: an exclusive running product.
    shifted = _jnp.concatenate(
        (_jnp.ones_like(moved[..., -1:]), moved[..., :-1]), -1)
    running = _jnp.cumprod(shifted, -1)
    # swapaxes is its own inverse, so this restores the original layout.
    return _jnp.swapaxes(running, axis, -1)
Beispiel #2
0
def zeta(x, q=None):
    """Hurwitz zeta function ``zeta(x, q)``.

    Only the two-argument (Hurwitz) form is supported: ``q`` must be given;
    ``q=None`` (the Riemann case) is rejected by an ``assert``.

    The value is evaluated as ``S + I + T`` following formula (5) of the
    reference cited below, with a fixed truncation ``N = M = 8`` when ``q``
    has dtype float32 and ``N = M = 16`` otherwise:

      * ``S = sum_{k=0}^{N-1} (q + k)^(-x)``
      * ``I = (q + N)^(1 - x) / (x - 1)``
      * ``T = (q + N)^(-x) * (1/2 + sum_j T1_j / coefs_j)``, where ``T1_j`` is
        the running product ``prod_{m=0}^{2j} (x + m) / (q + N)``, clipped
        to the largest finite value of the dtype.

    Args:
      x: the exponent (``s`` in the reference).
      q: the shift (``a`` in the reference); required.

    Returns:
      Array of values; presumably of the broadcast shape of ``x`` and ``q``
      (depends on ``_promote_args_inexact`` — TODO confirm).

    NOTE(review): both checks below are ``assert`` statements, which are
    removed under ``python -O``; ``q=None`` would then not be rejected here.
    """
    assert q is not None, "Riemann zeta function is not implemented yet."
    # Reference: Johansson, Fredrik.
    # "Rigorous high-precision computation of the Hurwitz zeta function and its derivatives."
    # Numerical Algorithms 69.2 (2015): 253-270.
    # https://arxiv.org/abs/1309.2877 - formula (5)
    # here we keep the same notation as in reference
    s, a = _promote_args_inexact("zeta", x, q)
    dtype = lax.dtype(a).type
    # Trailing axis of length 1 so the per-term sums below broadcast over it.
    s_, a_ = jnp.expand_dims(s, -1), jnp.expand_dims(a, -1)
    # precision ~ N, M: number of direct-sum terms and of tail-correction terms.
    N = M = dtype(8) if lax.dtype(a) == jnp.float32 else dtype(16)
    # Presumably one coefficient is needed per tail term; _BERNOULLI_COEFS is
    # defined elsewhere in the module (assumed Bernoulli-derived — confirm).
    assert M <= len(_BERNOULLI_COEFS)
    # k = 0, ..., N-1 laid out along the trailing axis.
    k = jnp.expand_dims(np.arange(N, dtype=N.dtype), tuple(range(a.ndim)))
    # S: direct partial sum of the first N terms.
    S = jnp.sum((a_ + k)**-s_, -1)
    # I: integral of the tail from N to infinity.
    I = lax.div((a + N)**(dtype(1) - s), s - dtype(1))
    T0 = (a + N)**-s
    # m = 0, ..., 2M-1 along the trailing axis.
    m = jnp.expand_dims(np.arange(2 * M, dtype=M.dtype), tuple(range(s.ndim)))
    s_over_a = (s_ + m) / (a_ + N)
    # Keep every other running product: prod_{i=0}^{2j} (s + i) / (a + N).
    T1 = jnp.cumprod(s_over_a, -1)[..., ::2]
    # Cap overflowing products at the largest finite value instead of inf.
    T1 = jnp.clip(T1, a_max=jnp.finfo(dtype).max)
    coefs = np.expand_dims(
        np.array(_BERNOULLI_COEFS[:T1.shape[-1]], dtype=dtype),
        tuple(range(a.ndim)))
    T1 = T1 / coefs
    # T: tail correction term.
    T = T0 * (dtype(0.5) + T1.sum(-1))
    return S + I + T
Beispiel #3
0
def volumetric_rendering(raw,
                         z_vals,
                         dirs,
                         use_white_background,
                         sigma_activation=nn.relu,
                         sample_at_infinity=True,
                         eps=1e-10):
    """Volumetric Rendering Function.

  Args:
    raw: mapping of raw network outputs with keys 'rgb' and 'alpha'.
      raw['rgb'] holds colour logits (passed through a sigmoid), presumably
      [batch_size, num_coarse_samples, 3]. raw['alpha'] holds raw densities
      with a trailing singleton axis (it is squeezed with axis=-1), presumably
      [batch_size, num_coarse_samples, 1].
    z_vals: jnp.ndarray(float32), [batch_size, num_coarse_samples].
    dirs: jnp.ndarray(float32), [batch_size, 3].
    use_white_background: bool; if True, composite onto a white background.
    sigma_activation: the activation function applied to the squeezed
      raw['alpha'] values to obtain densities.
    sample_at_infinity: if True, the segment after the last sample has length
      1e10 (a sample "at infinity") and its weight is excluded from the
      returned `acc`; otherwise that segment has negligible length (1e-19).
    eps: a small number to prevent numerical issues.

  Returns:
    A 6-tuple (rgb, exp_depth, med_depth, disp, acc, weights):
    rgb: [batch_size, 3] composited colour.
    exp_depth: [batch_size] expected depth, sum(weights * z_vals).
    med_depth: [batch_size] depth from compute_depth_map(weights, z_vals).
    disp: [batch_size] 1 / exp_depth, replaced by 1 / eps wherever it is not
      in (0, 1 / eps) or acc <= eps.
    acc: [batch_size] accumulated weight (see `sample_at_infinity`).
    weights: [batch_size, num_coarse_samples] per-sample weights.
  """
    rgb = nn.sigmoid(raw['rgb'])
    sigma = sigma_activation(jnp.squeeze(raw['alpha'], axis=-1))
    # TODO(keunhong): remove this hack.
    last_sample_z = 1e10 if sample_at_infinity else 1e-19
    dists = jnp.concatenate([
        z_vals[..., 1:] - z_vals[..., :-1],
        jnp.broadcast_to([last_sample_z], z_vals[..., :1].shape)
    ], -1)
    # Scale by the ray direction norm to get distances in world units.
    dists = dists * jnp.linalg.norm(dirs[..., None, :], axis=-1)
    alpha = 1.0 - jnp.exp(-sigma * dists)
    # Exclusive cumulative product of (1 - alpha): transmittance reaching
    # each sample (1.0 for the first one).
    accum_prod = jnp.concatenate([
        jnp.full_like(alpha[..., :1], 1., alpha.dtype),
        jnp.cumprod(1.0 - alpha[..., :-1] + eps, axis=-1)
    ],
                                 axis=-1)
    weights = alpha * accum_prod

    rgb = (weights[..., None] * rgb).sum(axis=-2)
    exp_depth = (weights * z_vals).sum(axis=-1)
    med_depth = compute_depth_map(weights, z_vals)
    acc = weights.sum(axis=-1)
    if use_white_background:
        rgb = rgb + (1. - acc[..., None])

    # NOTE(review): unlike the usual NeRF disparity (acc / exp_depth), this is
    # 1 / exp_depth without normalizing by acc — confirm this is intended.
    inv_eps = 1.0 / eps
    disp = 1.0 / exp_depth
    disp = jnp.where((disp > 0) & (disp < inv_eps) & (acc > eps), disp,
                     inv_eps)

    if sample_at_infinity:
        # Report opacity without the artificial sample at infinity.
        acc = weights[..., :-1].sum(axis=-1)

    return rgb, exp_depth, med_depth, disp, acc, weights
Beispiel #4
0
    def __call__(self, x):
        """Signed stick-breaking map.

        Squashes `x` into (-1, 1) with tanh. Each component is then scaled by
        the product of (1 - |t_j|) over all preceding components j along the
        last axis; the first component is left unscaled. The output has the
        same shape as `x`.
        """
        t = jnp.tanh(x)
        # Stick left after each of the first K-1 components.
        stick_left = jnp.cumprod(1 - jnp.abs(t[..., :-1]), axis=-1)
        # The first component sees the whole stick (length 1).
        full_stick = jnp.ones_like(t[..., :1])
        stick_left = jnp.concatenate([full_stick, stick_left], axis=-1)
        return t * stick_left
Beispiel #5
0
def _compute_hash_constants(spatial_dimension, cells_per_side):
  """Multipliers that flatten a `spatial_dimension`-D cell index to 1-D.

  Args:
    spatial_dimension: number of spatial dimensions D.
    cells_per_side: number of cells along each side. Either a single value
      (any array of size 1) shared by every dimension, or D per-dimension
      values given as a (D,) or (1, D) array.

  Returns:
    int64 multipliers [1, n_0, n_0 * n_1, ...]. For a single value n the
    result has shape (1, D) with entries n ** d. For per-dimension values it
    has shape (D,).

  Raises:
    ValueError: if `cells_per_side` has neither 1 nor D elements.
  """
  cells_per_side = np.asarray(cells_per_side)
  if cells_per_side.size == 1:
    # `.item()` yields a Python scalar. The powers are then computed with
    # arbitrary precision (no int32 wrap-around before the int64 cast), and
    # the result is (1, D) whatever the rank of the size-1 input.
    base = cells_per_side.item()
    return np.array([[base ** d for d in range(spatial_dimension)]],
                    dtype=np.int64)
  elif cells_per_side.size == spatial_dimension:
    # Accept (D,) as well as (1, D) input.
    row = cells_per_side.reshape(1, -1)
    one = np.array([[1]], dtype=np.int32)
    shifted = np.concatenate((one, row[:, :-1]), axis=1)
    # Exclusive running product [1, n_0, n_0 * n_1, ...]. np.cumprod without
    # an axis flattens, so the result is (D,) as before.
    return np.array(np.cumprod(shifted), dtype=np.int64)
  else:
    raise ValueError(
        f'cells_per_side must have 1 or {spatial_dimension} elements; '
        f'got {cells_per_side.size}.')
Beispiel #6
0
 def __call__(self, x):
     """Stick-breaking map from K unconstrained reals to K + 1 fractions.

     The k-th entry is z_k times the stick left after the first k - 1 breaks,
     and the final entry is whatever stick remains. The K + 1 outputs
     therefore telescope to a total of 1.
     """
     num_breaks = x.shape[-1]
     # Shift so that x == 0 maps to the uniform split (1/K, ..., 1/K).
     x = x - jnp.log(num_breaks - jnp.arange(num_breaks))
     # Fraction of the *remaining* stick taken at each break.
     fractions = _clipped_expit(x)
     remaining = jnp.cumprod(1 - fractions, axis=-1)
     untouched_axes = [(0, 0)] * (x.ndim - 1)
     # Append a final fraction of 1 (take everything that is left) ...
     fractions_ext = jnp.pad(fractions, untouched_axes + [(0, 1)],
                             mode="constant", constant_values=1.)
     # ... and prepend a full stick of length 1 before the first break.
     remaining_ext = jnp.pad(remaining, untouched_axes + [(1, 0)],
                             mode="constant", constant_values=1.)
     return fractions_ext * remaining_ext
Beispiel #7
0
def volumetric_rendering(rgb, sigma, z_vals, dirs, white_bkgd):
    """Volumetric Rendering Function.

    Args:
      rgb: jnp.ndarray(float32), color, [batch_size, num_samples, 3]
      sigma: jnp.ndarray(float32), density, [batch_size, num_samples, 1].
      z_vals: jnp.ndarray(float32), [batch_size, num_samples].
      dirs: jnp.ndarray(float32), [batch_size, 3].
      white_bkgd: bool.

    Returns:
      comp_rgb: jnp.ndarray(float32), [batch_size, 3].
      disp: jnp.ndarray(float32), [batch_size].
      acc: jnp.ndarray(float32), [batch_size].
      weights: jnp.ndarray(float32), [batch_size, num_samples]
    """
    eps = 1e-10
    dists = jnp.concatenate(
        [
            z_vals[Ellipsis, 1:] - z_vals[Ellipsis, :-1],
            jnp.broadcast_to([1e10], z_vals[Ellipsis, :1].shape),
        ],
        -1,
    )
    dists = dists * jnp.linalg.norm(dirs[Ellipsis, None, :], axis=-1)
    # Note that we're quietly turning sigma from [..., 0] to [...].
    alpha = 1.0 - jnp.exp(-sigma[Ellipsis, 0] * dists)
    accum_prod = jnp.concatenate(
        [
            jnp.ones_like(alpha[Ellipsis, :1], alpha.dtype),
            jnp.cumprod(1.0 - alpha[Ellipsis, :-1] + eps, axis=-1),
        ],
        axis=-1,
    )
    weights = alpha * accum_prod

    comp_rgb = (weights[Ellipsis, None] * rgb).sum(axis=-2)
    depth = (weights * z_vals).sum(axis=-1)
    acc = weights.sum(axis=-1)  # Alpha
    # Equivalent to (but slightly more efficient and stable than):
    #  disp = 1 / max(eps, where(acc > eps, depth / acc, 0))
    inv_eps = 1 / eps
    disp = acc / depth
    disp = jnp.where((disp > 0) & (disp < inv_eps) & (acc > eps), disp,
                     inv_eps)
    if white_bkgd:
        comp_rgb = comp_rgb + (1.0 - acc[Ellipsis, None])
    return comp_rgb, disp, acc, weights
Beispiel #8
0
def rtrun(dtau, S):
    """Radiative transfer using the two-stream approximation + 2E3 (Helios-R1 type).

    Layer i contributes (1 - T_i) * S_i, attenuated by the product of the
    transmissions T_j of all layers j before it. T comes from `trans2E3(dtau)`
    and is forced to exactly 1 where `dtau == 0`.

    Args:
        dtau: opacity matrix [N_layer, N_nus].
        S: source matrix [N_layer, N_nus].

    Returns:
        flux in the unit of [erg/cm2/s/cm-1] if using piBarr as a source function.
    """
    num_nus = jnp.shape(dtau)[1]
    transmission = jnp.where(dtau == 0, 1.0, trans2E3(dtau))
    # Row i = product of the transmissions of the layers above layer i.
    attenuation = jnp.cumprod(jnp.vstack([jnp.ones(num_nus), transmission]),
                              axis=0)
    # Emission per layer; the extra zero row pairs with the last cumprod row.
    emission = jnp.vstack([(1 - transmission) * S, jnp.zeros(num_nus)])
    return jnp.sum(emission * attenuation, axis=0)
Beispiel #9
0
def piecewise_interpolate_schedule(
        interpolate_type: str,
        init_value: float,
        boundaries_and_scales: Optional[Dict[int,
                                             float]] = None) -> base.Schedule:
    """Returns a function which implements a piecewise interpolated schedule.

  Args:
    interpolate_type: 'linear' or 'cosine', specifying the interpolation
      strategy.
    init_value: An initial value `init_v`.
    boundaries_and_scales: A map from boundaries `b_i` to non-negative scaling
      factors `f_i`. At boundary step `b_i`, the schedule returns `init_v`
      scaled by the product of all factors `f_j` such that `b_j` <= `b_i`. The
      values in between each boundary will be interpolated as per
      `interpolate_type`.

  Returns:
    schedule: A function that maps step counts to values.

  Raises:
    ValueError: if `interpolate_type` is not 'linear' or 'cosine', or if any
      scale factor is negative.
  """
    if interpolate_type == 'linear':
        interpolate_fn = _linear_interpolate
    elif interpolate_type == 'cosine':
        interpolate_fn = _cosine_interpolate
    else:
        raise ValueError(
            '`interpolate_type` must be either \'cosine\' or \'linear\'')

    if boundaries_and_scales:
        boundaries, scales = zip(*sorted(boundaries_and_scales.items()))
        if not all(scale >= 0. for scale in scales):
            raise ValueError(
                '`piecewise_interpolate_schedule` expects non-negative scale factors'
            )
    else:
        boundaries, scales = (), ()

    bounds = jnp.stack((0, ) + boundaries)
    # values[i] is the target value reached at bounds[i].
    values = jnp.cumprod(jnp.stack((init_value, ) + scales))
    interval_sizes = (bounds[1:] - bounds[:-1])
    # A boundary at step 0 gives a zero-length first interval. Dividing by it
    # would make `pct` inf/nan, and `0 * nan` in the dot product below would
    # turn every schedule value into nan, even though that interval is never
    # active. Such intervals (their indicator is always False) get a dummy
    # size of 1.
    safe_interval_sizes = jnp.where(interval_sizes > 0, interval_sizes, 1)

    def schedule(count):
        # One-hot selector of the interval containing `count` (all False once
        # count >= the last boundary).
        indicator = (bounds[:-1] <= count) & (count < bounds[1:])
        pct = (count - bounds[:-1]) / safe_interval_sizes
        interp_vals = interpolate_fn(values[:-1], values[1:], pct)
        return indicator.dot(interp_vals) + (bounds[-1] <= count) * values[-1]

    return schedule
Beispiel #10
0
    def __init__(self, beta_min=0.1, beta_max=20, N=1000):
        """Construct a Variance Preserving SDE.

        Args:
          beta_min: value of beta(0).
          beta_max: value of beta(1).
          N: number of discretization steps.
        """
        super().__init__(N)
        self.beta_0 = beta_min
        self.beta_1 = beta_max
        self.N = N
        # Per-step betas spaced linearly from beta_min / N to beta_max / N.
        betas = jnp.linspace(beta_min / N, beta_max / N, N)
        alphas = 1. - betas
        alpha_bars = jnp.cumprod(alphas, axis=0)
        self.discrete_betas = betas
        self.alphas = alphas
        self.alphas_cumprod = alpha_bars
        self.sqrt_alphas_cumprod = jnp.sqrt(alpha_bars)
        self.sqrt_1m_alphas_cumprod = jnp.sqrt(1. - alpha_bars)
Beispiel #11
0
def _compute_R(order, factor):
    """
    Computes the R matrix with entries given by the first equation on page 8
    of [1].

    This is used to update the differences matrix when step size h is varied
    according to factor = h_{n+1} / h_n.

    Note that the U matrix also defined in the same section can also be found
    using factor = 1, which corresponds to R with a constant step size.

    Args:
      order: not used by this function; the full matrix up to MAX_ORDER is
        always built.
      factor: step-size ratio h_{n+1} / h_n.

    Returns:
      (MAX_ORDER + 1, MAX_ORDER + 1) array: the column-wise cumulative product
      of M, where M[0, :] = 1, M[i, j] = (i - 1 - factor * j) / i for
      i, j >= 1, and M[1:, 0] = 0.
    """
    I = jnp.arange(1, MAX_ORDER + 1).reshape(-1, 1)
    J = jnp.arange(1, MAX_ORDER + 1)
    # JAX has no uninitialized arrays (jnp.empty returns zeros), so zeros
    # states the M[1:, 0] == 0 entries explicitly.
    M = jnp.zeros((MAX_ORDER + 1, MAX_ORDER + 1))
    # `jax.ops.index_update` was removed from JAX; `.at[...].set(...)` is the
    # supported functional update.
    M = M.at[1:, 1:].set((I - 1 - factor * J) / I)
    M = M.at[0].set(1)
    R = jnp.cumprod(M, axis=0)

    return R
Beispiel #12
0
def signed_stick_breaking_tril(t):
    """Map values in (-1, 1) to a lower-triangular matrix via signed stick-breaking.

    `t` is first clipped strictly inside (-1, 1) and placed below the diagonal
    by `vec_to_tril_matrix(t, diagonal=-1)`. The identity diagonal is added,
    and each entry is scaled by the running product of sqrt(1 - r**2) over
    the preceding entries of its row.
    """
    # Keep t strictly inside (-1, 1).
    tiny = jnp.finfo(t.dtype).eps
    t = jnp.clip(t, a_min=(-1 + tiny), a_max=(1 - tiny))
    r = vec_to_tril_matrix(t, diagonal=-1)

    # Stick-breaking on the squared values. We never form s = z * z_cumprod,
    # because sign(r) * sqrt(z * z_cumprod) == r * sqrt(z_cumprod).
    z = r ** 2
    remaining_sqrt = jnp.cumprod(jnp.sqrt(1 - z), axis=-1)
    # Shift right by one along the row, with 1 as the leading value.
    shifted = jnp.concatenate(
        [jnp.ones_like(remaining_sqrt[..., :1]), remaining_sqrt[..., :-1]],
        axis=-1)
    return (r + jnp.identity(r.shape[-1])) * shifted
Beispiel #13
0
def interpolation_matrix(
        order: Union[np.ndarray, int],
        step_size_ratio: Union[np.ndarray, float, int]) -> np.ndarray:
    """Creates the matrix used to interpolate backward differences.

    Entry (i, j), for 1 <= i, j <= order, is
    `1/j! (0 - i * step_size_ratio) * ... * ((j-1) - i * step_size_ratio)`.
    Entries with i > order or j > order are zero. The matrix is always
    MAX_ORDER x MAX_ORDER.
    """
    degrees = np.arange(1, MAX_ORDER + 1)
    rows = degrees[:, np.newaxis]
    cols = degrees[np.newaxis, :]
    # Running product over j of ((j - 1) - i * ratio) / j.
    full = np.cumprod(((cols - 1) - rows * step_size_ratio) / cols, axis=1)
    # Keep only the leading order x order block.
    active = degrees <= order
    keep = active[:, np.newaxis] & active[np.newaxis, :]
    return np.where(keep, full, np.zeros_like(full))
Beispiel #14
0
def cumprod(x):
    """Running product of `x` along its last axis."""
    last_axis = -1
    running_product = np.cumprod(x, axis=last_axis)
    return running_product
Beispiel #15
0
 def cumop(x, axis=axis, mode=mode):
     """Cumulative sum when mode == "add", cumulative product otherwise."""
     accumulate = jnp.cumsum if mode == "add" else jnp.cumprod
     return accumulate(x, axis=axis)
Beispiel #16
0
 def testCumProd(self):
     x = np.arange(9).reshape(3, 3) + 1
     y = vmap(lambda x: np.cumprod(x, axis=-1))(x)
     self.assertAllClose(onp.cumprod(x, axis=1, dtype=np.int_),
                         y,
                         check_dtypes=True)
Beispiel #17
0
def cumprod_exclusive(tensor):
    """Exclusive cumulative product along the last axis.

    output[..., i] == prod(tensor[..., :i]), so output[..., 0] == 1.
    """
    # Inclusive cumprod shifted right by one; slot 0 then holds the total
    # product and is overwritten with 1 below.
    prod = jnp.roll(jnp.cumprod(tensor, axis=-1), 1, axis=-1)
    # `jax.ops.index_update` / `jax.ops.index` were removed from JAX; the
    # `.at[...].set(...)` form is the supported functional update.
    return prod.at[..., 0].set(1.0)
Beispiel #18
0
def _cumprod_impl(x):
    """Running product of `x` along its trailing axis."""
    trailing_axis = -1
    result = np.cumprod(x, axis=trailing_axis)
    return result
Beispiel #19
0
def volumetric_rendering(rgb,
                         sigma,
                         z_vals,
                         dirs,
                         use_white_background,
                         sample_at_infinity=True,
                         return_weights=False,
                         eps=1e-10):
    """Volumetric Rendering Function.

  Args:
    rgb: an array of size (B,S,3) containing the RGB color values.
    sigma: an array of size (B,S) or (B,S,1) containing the densities; a
      trailing singleton axis is squeezed away.
    z_vals: an array of size (B,S) containing the z-coordinate of the samples.
    dirs: an array of size (B,3) containing the directions of rays.
    use_white_background: whether to assume a white background or not.
    sample_at_infinity: if True, the segment after the last sample has length
      1e10 (a sample "at infinity") and its weight is excluded from 'acc';
      otherwise that segment has negligible length (1e-19).
    return_weights: if True returns the weights in the dictionary.
    eps: a small number to prevent numerical issues.

  Returns:
    A dictionary containing:
      rgb: an array of size (B,3) containing the rendered colors.
      depth: an array of size (B,) containing the expected depth.
      med_depth: an array of size (B,) from `compute_depth_map(weights,
        z_vals)`.
      acc: an array of size (B,) containing the accumulated weight (see
        `sample_at_infinity`).
      weights: an array of size (B,S) containing the weight of each sample;
        only present when `return_weights` is True.
  """
    # A (B,S,1) sigma would broadcast against the (B,S) `dists` below into a
    # wrong shape (or fail), so drop the trailing singleton channel.
    if jnp.ndim(sigma) == jnp.ndim(z_vals) + 1 and jnp.shape(sigma)[-1] == 1:
        sigma = jnp.squeeze(sigma, axis=-1)
    # TODO(keunhong): remove this hack.
    last_sample_z = 1e10 if sample_at_infinity else 1e-19
    dists = jnp.concatenate([
        z_vals[..., 1:] - z_vals[..., :-1],
        jnp.broadcast_to([last_sample_z], z_vals[..., :1].shape)
    ], -1)
    # Scale by the ray direction norm to get distances in world units.
    dists = dists * jnp.linalg.norm(dirs[..., None, :], axis=-1)
    alpha = 1.0 - jnp.exp(-sigma * dists)
    # Prepend a 1.0 to make this an 'exclusive' cumprod as in `tf.math.cumprod`.
    accum_prod = jnp.concatenate([
        jnp.ones_like(alpha[..., :1], alpha.dtype),
        jnp.cumprod(1.0 - alpha[..., :-1] + eps, axis=-1),
    ],
                                 axis=-1)
    weights = alpha * accum_prod

    rgb = (weights[..., None] * rgb).sum(axis=-2)
    exp_depth = (weights * z_vals).sum(axis=-1)
    med_depth = compute_depth_map(weights, z_vals)
    acc = weights.sum(axis=-1)
    if use_white_background:
        rgb = rgb + (1. - acc[..., None])

    if sample_at_infinity:
        # Report opacity without the artificial sample at infinity.
        acc = weights[..., :-1].sum(axis=-1)

    out = {
        'rgb': rgb,
        'depth': exp_depth,
        'med_depth': med_depth,
        'acc': acc,
    }
    if return_weights:
        out['weights'] = weights
    return out
Beispiel #20
0
def spherical_to_cartesian(phi_x):
    """Convert n-spherical coordinates [r, phi_1, ..., phi_{n-1}] to Cartesian.

    Component k is r * sin(phi_1) * ... * sin(phi_k) * cos(phi_{k+1}). The
    last component uses no cosine factor.
    """
    radius = phi_x[0]
    angles = phi_x[1:]
    # [1, sin(phi_1), sin(phi_1) sin(phi_2), ...]
    sine_products = jnp.hstack([1.0, jnp.cumprod(jnp.sin(angles))])
    # [cos(phi_1), ..., cos(phi_{n-1}), 1]
    cosine_tail = jnp.hstack([jnp.cos(angles), 1.0])
    return radius * sine_products * cosine_tail
Beispiel #21
0
def cumprod(a, axis=None, dtype=None):
  """Cumulative product of `a`, returned wrapped in a JaxArray.

  A JaxArray argument is unwrapped to its `.value` before the computation.
  """
  data = a.value if isinstance(a, JaxArray) else a
  result = jnp.cumprod(a=data, axis=axis, dtype=dtype)
  return JaxArray(result)
Beispiel #22
0
def raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, rng=None):
    """Transforms model's predictions to semantically meaningful values.
    Args:
        raw: (num_rays, num_samples || num_importance, 4) prediction from model
        z_vals: (num_rays, num_samples || num_importance) integration time
        rays_d: (num_rays, 3) direction of each ray
        raw_noise_std: std of noise added for regularization
        white_bkgd: whether to use the alpha channel for white background
        rng: random key
    Returns:
        acc_map: (num_rays) sum of weights along each ray
        depth_map: (num_rays) estimated distance to object
        disp_map: (num_rays) disparity map (inverse of depth map)
        rgb_map: (num_rays, 3) estimated RGB color of a ray
        weights: (num_rays, num_samples || num_importance) weights assigned to each sampled color
    """

    # compute 'distance' (in time) between each integration time along a ray
    dists = z_vals[..., 1:] - z_vals[..., :-1]

    # the 'distance' from the last integration time is infinity
    dists = jnp.concatenate(
        [dists, jnp.broadcast_to([1e10], dists[..., :1].shape)], axis=-1)
    dists = dists.astype(z_vals.dtype)  # [num_rays, num_samples]

    # multiply each distance by the norm of its corresponding direction ray
    # to convert to real world distance (accounts for non-unit directions)
    dists = dists * jnp.linalg.norm(rays_d[..., None, :], axis=-1)

    # extract RGB of each sample position along each ray
    rgb = nn.sigmoid(raw[..., :3])  # [num_rays, num_samples, 3]

    # add noise to predictions for density, can be used to (this value is strictly between [0, 1])
    # regularize network during training (prevents floater artifacts)
    noise = 0.0
    if raw_noise_std > 0.0 and rng is not None:
        noise = random.normal(rng, raw[..., 3].shape) * raw_noise_std

    # predict density of each sample along each ray (alpha channel)
    # higher values imply higher likelihood of being absorbed at this point
    alpha = 1.0 - jnp.exp(-nn.relu(raw[..., 3] + noise) * dists)

    # compute weight for RGB of each sample along each ray
    # cumprod() is used to express the idea of the ray not having reflected up to this sample yet
    # weights = alpha * tf.math.cumprod(1.0 - alpha + 1e-10, axis=-1, exclusive=True)
    alpha_ = jnp.clip(1.0 - alpha, 1e-5, 1.0)
    weights = jnp.concatenate(
        [jnp.ones_like(alpha_[..., :1]), alpha_[..., :-1]], -1)
    weights = alpha * jnp.cumprod(weights, -1)  # [num_rays, num_samples]

    # computed weighted color of each sample along each ray
    rgb_map = jnp.einsum("ij,ijk->ik", weights, rgb)  # [num_rays, 3]

    # estimated depth map is expected distance
    depth_map = jnp.einsum("ij,ij->i", weights, z_vals)  # [num_rays]

    # sum of weights along each ray (this value is in [0, 1] up to numerical error)
    acc_map = jnp.einsum("ij->i", weights)  # [num_rays]

    # disparity map is inverse depth
    i_depth = depth_map / jnp.clip(acc_map, 1e-5)
    disp_map = 1.0 / jnp.clip(i_depth, 1e-5)

    # to composite onto a white background, use the accumulated alpha map
    if white_bkgd:
        rgb_map += 1.0 - acc_map[..., None]

    return {
        "rgb": rgb_map.astype(jnp.float32),
        "disp": disp_map.astype(jnp.float32),
        "acc": acc_map.astype(jnp.float32),
        "depth": depth_map.astype(jnp.float32),
    }, weights
Beispiel #23
0
 def testCumProd(self):
  x = jnp.arange(9).reshape(3, 3) + 1
  y = vmap(lambda x: jnp.cumprod(x, axis=-1))(x)
  self.assertAllClose(np.cumprod(x, axis=1, dtype=int), y)
Beispiel #24
0
def _gen_associated_legendre(l_max: int, x: jnp.ndarray,
                             is_normalized: bool) -> jnp.ndarray:
    r"""Computes associated Legendre functions (ALFs) of the first kind.

  The ALFs of the first kind are used in spherical harmonics. The spherical
  harmonic of degree `l` and order `m` can be written as
  `Y_l^m(θ, φ) = N_l^m * P_l^m(cos(θ)) * exp(i m φ)`, where `N_l^m` is the
  normalization factor and θ and φ are the colatitude and longitude,
  respectively. `N_l^m` is chosen so that the spherical harmonics form
  a set of orthonormal basis functions of L^2(S^2). For the computational
  efficiency of spherical harmonics transform, the normalization factor is
  used in the computation of the ALFs. In addition, normalizing `P_l^m`
  avoids overflow/underflow and achieves better numerical stability. Three
  recurrence relations are used in the computation.

  Args:
    l_max: The maximum degree of the associated Legendre function. Both the
      degrees and orders are `[0, 1, 2, ..., l_max]`.
    x: A vector of type `float32`, `float64` containing the sampled points in
      spherical coordinates, at which the ALFs are computed; `x` is essentially
      `cos(θ)`. For the numerical integration used by the spherical harmonics
      transforms, `x` contains the quadrature points in the interval of
      `[-1, 1]`. There are several approaches to provide the quadrature points:
      Gauss-Legendre method (`scipy.special.roots_legendre`), Gauss-Chebyshev
      method (`scipy.special.roots_chebyu`), and Driscoll & Healy
      method (Driscoll, James R., and Dennis M. Healy. "Computing Fourier
      transforms and convolutions on the 2-sphere." Advances in applied
      mathematics 15, no. 2 (1994): 202-250.). The Gauss-Legendre quadrature
      points are nearly equal-spaced along θ and provide exact discrete
      orthogonality, (P^m)^T W P^m = I, where `T` represents the transpose
      operation, `W` is a diagonal matrix containing the quadrature weights,
      and `I` is the identity matrix. The Gauss-Chebyshev points are equally
      spaced, which only provide approximate discrete orthogonality. The
      Driscoll & Healy quadrature points are equally spaced and provide the
      exact discrete orthogonality. The number of sampling points is required to
      be twice the number of frequency points (modes) in the Driscoll & Healy
      approach, which enables FFT and achieves a fast spherical harmonics
      transform.
    is_normalized: True if the associated Legendre functions are normalized.
      With normalization, `N_l^m` is applied such that the spherical harmonics
      form a set of orthonormal basis functions of L^2(S^2).

  Returns:
    The 3D array of shape `(l_max + 1, l_max + 1, len(x))` containing the values
    of the ALFs at `x`; the dimensions in the sequence of order, degree, and
    evaluation points.
  """
    # Output buffer p[order, degree, point]; filled in three stages below:
    # the diagonal, the first off-diagonal, then everything else.
    p = jnp.zeros((l_max + 1, l_max + 1, x.shape[0]))

    a_idx = jnp.arange(1, l_max + 1)
    b_idx = jnp.arange(l_max)
    if is_normalized:
        initial_value = 0.5 / jnp.sqrt(jnp.pi)  # The initial value p(0,0).
        # Diagonal factors: running product of -sqrt(1 + 1/(2l)), l = 1..l_max.
        f_a = jnp.cumprod(-1 * jnp.sqrt(1.0 + 0.5 / a_idx))
        # Off-diagonal factors sqrt(2m + 3), m = 0..l_max-1.
        f_b = jnp.sqrt(2.0 * b_idx + 3.0)
    else:
        initial_value = 1.0  # The initial value p(0,0).
        # Diagonal factors: running product of (1 - 2l), i.e. (-1)^l (2l-1)!!.
        f_a = jnp.cumprod(1.0 - 2.0 * a_idx)
        # Off-diagonal factors (2m + 1), m = 0..l_max-1.
        f_b = 2.0 * b_idx + 1.0

    p = p.at[(0, 0)].set(initial_value)

    # Compute the diagonal entries p(l,l) with recurrence.
    # y[l-1] = (1 - x^2)^(l/2) for l = 1..l_max.
    y = jnp.cumprod(jnp.broadcast_to(jnp.sqrt(1.0 - x * x),
                                     (l_max, x.shape[0])),
                    axis=0)
    p_diag = initial_value * jnp.einsum('i,ij->ij', f_a, y)
    diag_indices = jnp.diag_indices(l_max + 1)
    # Write p[l, l] for l = 1..l_max (p[0, 0] was set above).
    p = p.at[(diag_indices[0][1:], diag_indices[1][1:])].set(p_diag)

    # Compute the off-diagonal entries with recurrence.
    # p[m, m + 1] = f_b[m] * x * p[m, m] for m = 0..l_max-1.
    p_offdiag = jnp.einsum('ij,ij->ij', jnp.einsum('i,j->ij', f_b, x),
                           p[jnp.diag_indices(l_max)])
    offdiag_indices = (diag_indices[0][:l_max], diag_indices[1][:l_max] + 1)
    p = p.at[offdiag_indices].set(p_offdiag)

    # Compute the remaining entries with recurrence.
    # The masks come from _gen_recurrence_mask (defined elsewhere). They are
    # presumably the per-step recurrence coefficients, zero outside the
    # entries updated at that step — confirm.
    d0_mask_3d, d1_mask_3d = _gen_recurrence_mask(l_max,
                                                  is_normalized=is_normalized)

    def body_fun(i, p_val):
        # Three-term recurrence in degree: combine p[:, l-1] (times x) and
        # p[:, l-2], obtained by rolling along the degree axis (axis=1).
        coeff_0 = d0_mask_3d[i]
        coeff_1 = d1_mask_3d[i]
        h = (jnp.einsum(
            'ij,ijk->ijk', coeff_0,
            jnp.einsum('ijk,k->ijk', jnp.roll(p_val, shift=1, axis=1), x)) -
             jnp.einsum('ij,ijk->ijk', coeff_1, jnp.roll(
                 p_val, shift=2, axis=1)))
        p_val = p_val + h
        return p_val

    # TODO(jakevdp): use some sort of fixed-point procedure here instead?
    # Promote the buffer to the common dtype of p, x and the masks, presumably
    # so the fori_loop carry keeps a fixed dtype across iterations.
    p = p.astype(jnp.result_type(p, x, d0_mask_3d))
    if l_max > 1:
        p = lax.fori_loop(lower=2,
                          upper=l_max + 1,
                          body_fun=body_fun,
                          init_val=p)

    return p