Ejemplo n.º 1
0
def _unique_sorted_mask(ar, axis):
    """Sort ``ar`` along ``axis`` and mark where each new distinct entry starts.

    Args:
      ar: input array.
      axis: axis along which entries are compared; an entry is the whole slice
        of ``ar`` at one position of this axis.

    Returns:
      ``(aux, mask, perm)``. ``aux`` is ``ar`` with ``axis`` moved to the front
      and reordered by ``perm`` (a lexicographic sort of the flattened slices).
      ``mask[j]`` is True when ``aux[j]`` differs from ``aux[j - 1]``;
      ``mask[0]`` is always True when ``aux`` is non-empty. For inexact dtypes,
      NaNs compare equal to each other.
    """
    moved = moveaxis(ar, axis, 0)
    if np.issubdtype(moved.dtype, np.complexfloating):
        # Work around issue in sorting of complex numbers with Nan only in the
        # imaginary component. This can be removed if sorting in this situation
        # is fixed to match numpy.
        moved = where(isnan(moved), _lax_const(moved, np.nan), moved)
    size, *out_shape = moved.shape
    slice_len = _prod(out_shape)
    if slice_len == 0:
        # Every slice is empty; a single dummy permutation entry suffices.
        size = 1
        perm = zeros(1, dtype=int)
    else:
        flat = moved.reshape(size, slice_len)
        # lexsort treats its last key as primary, hence the reversal.
        perm = lexsort(flat.T[::-1])
    aux = moved[perm]

    if not aux.size:
        return aux, zeros(size, dtype=bool), perm

    if dtypes.issubdtype(aux.dtype, np.inexact):
        # This is appropriate for both float and complex due to the documented behavior of np.unique:
        # See https://github.com/numpy/numpy/blob/v1.22.0/numpy/lib/arraysetops.py#L212-L220
        def neq(a, b):
            return lax.ne(a, b) & ~(isnan(a) & isnan(b))
    else:
        neq = lax.ne
    changed = any(neq(aux[1:], aux[:-1]), tuple(range(1, aux.ndim)))
    mask = ones(size, dtype=bool).at[1:].set(changed)
    return aux, mask, perm
Ejemplo n.º 2
0
def roots(p, *, strip_zeros=True):
    """Return the roots of the polynomial with coefficients ``p``.

    Ported from
    https://github.com/numpy/numpy/blob/v1.17.0/numpy/lib/polynomial.py#L168-L251

    Args:
      p: rank-1 array of coefficients; trailing zero coefficients correspond
        to roots at 0.
      strip_zeros: if True, strip leading and trailing zero coefficients. If
        False, nothing is stripped, which is unsafe when ``p`` has leading
        zeros.

    Returns:
      Array of roots. It is empty when ``p`` is identically zero, or when
      ``strip_zeros`` is False and ``p`` has at most one coefficient.

    Raises:
      ValueError: if ``p`` is not rank-1 after ``atleast_1d``.
    """
    p = atleast_1d(p)
    if p.ndim != 1:
        raise ValueError("Input must be a rank-1 array.")

    if not strip_zeros:
        # Leading zeros are not removed on this path.
        return _roots_no_zeros(p) if p.size > 1 else array([])

    if all(p == 0):
        return array([])

    first, stop = _nonzero_range(p)
    # Each trailing zero coefficient contributes one root at 0.
    num_zero_roots = p.size - stop
    trimmed = p[first:stop]

    if trimmed.size >= 2:
        return hstack((_roots_no_zeros(trimmed),
                       zeros(num_zero_roots, trimmed.dtype)))
    return zeros(num_zero_roots, trimmed.dtype)
Ejemplo n.º 3
0
def _lu_blocked(a, block_size=128):
    """Blocked LU decomposition with partial pivoting, as an unrolled loop.

    The matrix is processed in column panels of width ``block_size``. Each
    panel is factored with ``_lu_unblocked``. Its row permutation is applied
    to the whole of ``a``. The trailing rows to the right of the panel are
    solved with a unit-lower triangular solve, and the trailing submatrix
    gets a rank-``b`` update.

    Args:
      a: 2-D array of shape ``(m, n)``.
      block_size: panel width; must be a positive integer.

    Returns:
      A tuple ``(lu, pivot, perm)``:
        lu: packed factors, with U in the upper triangle and the strictly lower
          part of the unit-lower L below the diagonal.
        pivot: int32 array of length ``min(m, n)`` with global pivot rows.
        perm: int32 array of length ``m``; the accumulated row permutation.

    Raises:
      ValueError: if ``block_size`` is smaller than 1.
    """
    if block_size < 1:
        # A non-positive step would either never enter the loop (returning `a`
        # unfactored) or make range() fail with an opaque message.
        raise ValueError(
            f"block_size must be a positive integer, got {block_size}")
    m, n = a.shape
    r = min(m, n)
    pivot = jnp.zeros((r, ), dtype=jnp.int32)
    perm = jnp.arange(m, dtype=jnp.int32)
    for k in range(0, r, block_size):
        b = min(r - k, block_size)
        block_pivot, block_perm, lu_block = _lu_unblocked(a[k:, k:k + b])

        # Panel-local pivots/permutation are offset by k to become global.
        pivot = pivot.at[k:k + b].set(block_pivot + k)
        perm = perm.at[k:].set(perm[block_perm + k])
        a = a.at[k:, :].set(a[block_perm + k, :])
        a = a.at[k:, k:k + b].set(lu_block)

        if k + b < n:
            # U12 = L11^{-1} A12
            a = a.at[k:k + b, k + b:].set(
                triangular_solve(a[k:k + b, k:k + b],
                                 a[k:k + b, k + b:],
                                 left_side=True,
                                 lower=True,
                                 unit_diagonal=True))
            # A22 -= L21 @ U12
            a = a.at[k + b:, k + b:].add(
                -lax.dot(a[k + b:, k:k + b],
                         a[k:k + b, k + b:],
                         precision=lax.Precision.HIGHEST))
    return a, pivot, perm
Ejemplo n.º 4
0
def fftfreq(n, d=1.0):
  """Return the discrete Fourier transform sample frequencies.

  For window length ``n`` and sample spacing ``d`` the result is
  ``[0, 1, ..., ceil(n/2) - 1, -floor(n/2), ..., -1] / (d * n)``, matching
  ``numpy.fft.fftfreq``.

  Args:
    n: window length; must be a single int.
    d: sample spacing; must be a single value.

  Returns:
    Array of length ``n`` containing the sample frequencies.

  Raises:
    ValueError: if ``n`` or ``d`` is a list or tuple.
  """
  if isinstance(n, (list, tuple)):
    raise ValueError(
          "The n argument of jax.numpy.fft.fftfreq only takes an int. "
          "Got n = %s." % list(n))

  if isinstance(d, (list, tuple)):
    raise ValueError(
          "The d argument of jax.numpy.fft.fftfreq only takes a single value. "
          "Got d = %s." % list(d))

  k = jnp.zeros(n)
  # Non-negative frequencies 0, 1, ..., (n - 1) // 2 occupy the first
  # ceil(n / 2) slots, for both even and odd n.
  num_nonneg = (n - 1) // 2 + 1
  k = k.at[:num_nonneg].set(jnp.arange(0, num_nonneg))
  # The remaining floor(n / 2) slots hold -(n // 2), ..., -1.
  k = k.at[num_nonneg:].set(jnp.arange(-(n // 2), 0))

  return k / (d * n)
Ejemplo n.º 5
0
def segment_sum(data,
                segment_ids,
                num_segments=None,
                indices_are_sorted=False,
                unique_indices=False,
                bucket_size=None):  # TODO(zhangqiaorjc): use non-None default.
    """Computes the sum within segments of an array.

    Similar to TensorFlow's segment_sum:
    https://www.tensorflow.org/api_docs/python/tf/math/segment_sum

    Args:
      data: an array with the values to be summed.
      segment_ids: an array with integer dtype that indicates the segments of
        `data` (along its leading axis) to be summed. Values can be repeated and
        need not be sorted. Values outside of the range [0, num_segments) are
        dropped and do not contribute to the sum.
      num_segments: optional, an int with nonnegative value indicating the
        number of segments. The default is the minimum number of segments that
        supports all indices in ``segment_ids``, i.e. ``max(segment_ids) + 1``.
        Because `num_segments` determines the output size, a static value must
        be provided to use ``segment_sum`` in a ``jit``-compiled function.
      indices_are_sorted: whether ``segment_ids`` is known to be sorted.
      unique_indices: whether `segment_ids` is known to be free of duplicates.
      bucket_size: size of bucket to group indices into. ``segment_sum`` is
        performed on each bucket separately to improve numerical stability of
        addition. Default ``None`` means no bucketing.

    Returns:
      An array with shape :code:`(num_segments,) + data.shape[1:]` holding the
      segment sums.
    """
    if num_segments is None:
        num_segments = jnp.max(segment_ids) + 1
    # The output shape depends on this value, so it must be a Python int.
    num_segments = int(num_segments)

    out = jnp.zeros((num_segments, ) + data.shape[1:], dtype=data.dtype)

    if bucket_size is None:
        num_buckets = 1
    else:
        num_buckets = util.ceil_of_ratio(segment_ids.size, bucket_size)

    if num_buckets == 1:
        return _scatter_update(out, segment_ids, data, lax.scatter_add,
                               indices_are_sorted, unique_indices,
                               normalize_indices=False)

    # Sum each bucket on its own (without further bucketing), then add the
    # per-bucket results; this improves numerical stability.
    data_chunks = jnp.array_split(data, num_buckets)
    id_chunks = jnp.array_split(segment_ids, num_buckets)
    bucket_sums = [
        segment_sum(chunk, chunk_ids, num_segments, indices_are_sorted,
                    unique_indices)
        for chunk, chunk_ids in zip(data_chunks, id_chunks)
    ]
    return jnp.sum(jnp.stack(bucket_sums), axis=0)
Ejemplo n.º 6
0
    def _sturm(alpha, beta_sq, pivmin, alpha0_perturbation, x):
        """Implements the Sturm sequence recurrence.

        For every entry of `x`, runs the recurrence
        `q_i = alpha[i] - beta_sq[i - 1] / q_{i-1} - x` and counts the steps
        whose value is negative (first step) or `<= pivmin` (later steps).
        NOTE(review): presumably this is the usual Sturm count, i.e. the number
        of eigenvalues of the tridiagonal matrix below each `x` — confirm
        against the caller.

        Args:
          alpha: 1-D array; `alpha.shape[0]` is the sequence length `n`.
          beta_sq: indexed as `beta_sq[i - 1]` for `i` in `1..n-1`; presumably
            the squared off-diagonal entries (per the name) — confirm.
          pivmin: threshold below which a step is counted; such `q` values are
            also clamped to at most `-pivmin` before the next division.
          alpha0_perturbation: value substituted for `q` in the first step when
            `alpha[0] == x`, so the next step does not divide by zero.
          x: array of points; the result has the same shape.

        Returns:
          int32 array with the same shape as `x` holding the counts.
        """
        n = alpha.shape[0]
        zeros = jnp.zeros(x.shape, dtype=jnp.int32)
        ones = jnp.ones(x.shape, dtype=jnp.int32)

        # The first step in the Sturm sequence recurrence
        # requires special care if x is equal to alpha[0].
        def sturm_step0():
            q = alpha[0] - x
            count = jnp.where(q < 0, ones, zeros)
            q = jnp.where(alpha[0] == x, alpha0_perturbation, q)
            return q, count

        # Subsequent steps all take this form:
        def sturm_step(i, q, count):
            q = alpha[i] - beta_sq[i - 1] / q - x
            count = jnp.where(q <= pivmin, count + 1, count)
            # Clamp small/negative q away from zero so the next division is safe.
            q = jnp.where(q <= pivmin, jnp.minimum(q, -pivmin), q)
            return q, count

        # The first step initializes q and count.
        q, count = sturm_step0()

        # Peel off ((n-1) % blocksize) steps from the main loop, so we can run
        # the bulk of the iterations unrolled by a factor of blocksize.
        blocksize = 16
        i = 1
        peel = (n - 1) % blocksize
        unroll_cnt = peel

        # NOTE: `unroll_cnt` is read from the enclosing scope when this function
        # runs (late binding), not when it is defined. The direct call below uses
        # `peel`; the while_loop, traced after `unroll_cnt` is rebound, uses
        # `blocksize`. Do not reorder these statements.
        def unrolled_steps(args):
            start, q, count = args
            for j in range(unroll_cnt):
                q, count = sturm_step(start + j, q, count)
            return start + unroll_cnt, q, count

        i, q, count = unrolled_steps((i, q, count))

        # Run the remaining steps of the Sturm sequence using a partially
        # unrolled while loop.
        unroll_cnt = blocksize

        # After peeling, n - i is a multiple of blocksize, so every loop
        # iteration runs exactly `blocksize` steps and stops at i == n.
        def cond(iqc):
            i, q, count = iqc
            return jnp.less(i, n)

        _, _, count = lax.while_loop(cond, unrolled_steps, (i, q, count))
        return count
Ejemplo n.º 7
0
def polydiv(u, v, *, trim_leading_zeros=False):
    """Divide polynomial ``u`` by polynomial ``v`` by long division.

    ``v[0]`` is used as the leading coefficient, so coefficients are ordered
    from the highest power down.

    Args:
      u: dividend coefficients.
      v: divisor coefficients.
      trim_leading_zeros: if True, strip leading near-zero coefficients from
        the remainder. The tolerance is ``sqrt(eps)`` of its dtype, which
        approximates the absolute tolerance numpy uses.

    Returns:
      A tuple ``(q, r)`` of quotient and remainder coefficients. Without
      trimming, ``r`` has the same length as ``u``.
    """
    _check_arraylike("polydiv", u, v)
    u, v = _promote_dtypes_inexact(u, v)
    deg_u, deg_v = len(u) - 1, len(v) - 1
    steps = deg_u - deg_v + 1
    inv_lead = 1. / v[0]
    # At least one slot, and the same dtype as the promoted dividend.
    quotient = zeros(max(steps, 1), dtype=u.dtype)
    for k in range(steps):
        coeff = inv_lead * u[k]
        quotient = quotient.at[k].set(coeff)
        # Subtract coeff * v aligned at position k, cancelling u[k] (up to rounding).
        u = u.at[k:k + deg_v + 1].add(-coeff * v)
    if not trim_leading_zeros:
        return quotient, u
    return quotient, trim_zeros_tol(u, tol=sqrt(finfo(u.dtype).eps), trim='f')
Ejemplo n.º 8
0
def _gen_recurrence_mask(
        l_max: int,
        is_normalized: bool = True) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Build the coefficient masks for the degree recurrence of the ALFs.

    The diagonal and first off-diagonal entries are computed separately. These
    masks drive the recurrence that fills the remaining entries.

    Args:
      l_max: maximum degree; each output axis has size ``l_max + 1``.
      is_normalized: if True, use the coefficients of the normalized
        associated Legendre functions; otherwise the unnormalized ones.

    Returns:
      ``(d0_mask_3d, d1_mask_3d)``, each of shape
      ``(l_max + 1, l_max + 1, l_max + 1)``. Entry ``[i, m, l]`` holds the
      coefficient for order ``m`` and degree ``l`` when ``l - m == i``, and
      0 elsewhere. d0 is non-zero only for ``l - m >= 1`` and d1 only for
      ``l - m >= 2``.
    """
    dim = l_max + 1
    orders, degrees = jnp.mgrid[:dim, :dim]

    if is_normalized:
        deg_sq = degrees * degrees
        ord_sq = orders * orders
        two_deg = 2.0 * degrees
        prev_deg_sq = (degrees - 1.0) * (degrees - 1.0)
        coef0 = jnp.sqrt((4.0 * deg_sq - 1.0) / (deg_sq - ord_sq))
        coef1 = jnp.sqrt(((two_deg + 1.0) * (prev_deg_sq - ord_sq)) /
                         ((two_deg - 3.0) * (deg_sq - ord_sq)))
    else:
        coef0 = (2.0 * degrees - 1.0) / (degrees - orders)
        coef1 = (degrees + orders - 1.0) / (degrees - orders)

    def upper_part(coef, offset):
        # Keep coef on/above the `offset`-th diagonal and zero elsewhere; this
        # also discards the inf/nan produced by the divisions near l == m.
        idx = jnp.triu_indices(dim, offset)
        return jnp.zeros((dim, dim)).at[idx].set(coef[idx])

    # plane[i, j, k] is 1.0 exactly where i + j == k (i.e. i == l - m).
    i, j, k = jnp.ogrid[:dim, :dim, :dim]
    plane = 1.0 * (i + j - k == 0)

    return (jnp.einsum('jk,ijk->ijk', upper_part(coef0, 1), plane),
            jnp.einsum('jk,ijk->ijk', upper_part(coef1, 2), plane))
Ejemplo n.º 9
0
 def _find_indices(self, xi):
     """Locate each point of ``xi`` between two adjacent grid edges.

     Args:
       xi: array whose rows are zipped with ``self.grid`` (one row per grid
         dimension); ``xi.shape[1]`` is the number of points.

     Returns:
       ``(indices, norm_distances, out_of_bounds)``:
         indices: per dimension, the index of the lower edge, clipped to
           ``[0, g.size - 2]``.
         norm_distances: per dimension, the distance from the lower edge
           divided by the width of that cell.
         out_of_bounds: boolean array over points. It is only accumulated
           when ``self.bounds_error`` is falsy; otherwise it stays all False.
     """
     out_of_bounds = zeros((xi.shape[1], ), dtype=bool)
     indices, norm_distances = [], []
     for coords, edges in zip(xi, self.grid):
         last_lower = edges.size - 2
         idx = searchsorted(edges, coords) - 1
         idx = where(idx < 0, 0, idx)
         idx = where(idx > last_lower, last_lower, idx)
         indices.append(idx)
         lower = edges[idx]
         norm_distances.append((coords - lower) / (edges[idx + 1] - lower))
         if not self.bounds_error:
             out_of_bounds += coords < edges[0]
             out_of_bounds += coords > edges[-1]
     return indices, norm_distances, out_of_bounds
Ejemplo n.º 10
0
def block_diag(*arrs):
  """Create a block-diagonal matrix from the given arrays.

  The arguments are promoted to a common dtype and made at least 2-D. They
  are then placed along the diagonal, with zeros everywhere else. With no
  arguments, a single ``(1, 0)`` zero block is used.

  Args:
    *arrs: arrays with at most two dimensions.

  Returns:
    A 2-D array whose row and column counts are the sums of those of the
    blocks.

  Raises:
    ValueError: if any argument has more than two dimensions.
  """
  if not arrs:
    arrs = [jnp.zeros((1, 0))]
  arrs = jnp._promote_dtypes(*arrs)
  too_many_dims = [idx for idx, arr in enumerate(arrs) if jnp.ndim(arr) > 2]
  if too_many_dims:
    first_bad = too_many_dims[0]
    raise ValueError("Arguments to jax.scipy.linalg.block_diag must have at "
                     "most 2 dimensions, got {} at argument {}."
                     .format(arrs[first_bad], first_bad))
  blocks = [jnp.atleast_2d(arr) for arr in arrs]
  result = blocks[0]
  zero = lax.dtype(result).type(0)
  for block in blocks[1:]:
    num_cols = block.shape[1]
    # Shift the new block right past all existing columns...
    block = lax.pad(block, zero, ((0, 0, 0), (result.shape[-1], 0, 0)))
    # ...widen the accumulated result with zero columns on the right...
    result = lax.pad(result, zero, ((0, 0, 0), (0, num_cols, 0)))
    # ...and stack the two vertically.
    result = lax.concatenate([result, block], dimension=0)
  return result
Ejemplo n.º 11
0
def _lu_unblocked(a):
    """Unblocked LU decomposition with partial pivoting, as a rolled loop.

    Args:
      a: 2-D array of shape ``(m, n)``.

    Returns:
      A tuple ``(pivot, perm, lu)``:
        pivot: int32 array of length ``min(m, n)``; ``pivot[k]`` is the row
          swapped with row ``k`` at step ``k``.
        perm: int32 array of length ``m``; the accumulated row permutation.
          The same swaps are applied to ``perm`` and ``a``, so ``a[perm]`` is
          the matrix that was factored.
        lu: array shaped like ``a``, with U in its upper triangle and the
          strictly lower part of the unit-lower L below the diagonal.
    """
    m, n = a.shape

    def body(k, state):
        pivot, perm, a = state
        m_idx = jnp.arange(m)
        n_idx = jnp.arange(n)

        if jnp.issubdtype(a.dtype, jnp.complexfloating):
            # |re| + |im| as the pivot magnitude avoids a square root.
            t = a[:, k]
            magnitude = jnp.abs(jnp.real(t)) + jnp.abs(jnp.imag(t))
        else:
            magnitude = jnp.abs(a[:, k])
        # Only rows k..m-1 are pivot candidates; rows above k are masked out.
        i = jnp.argmax(jnp.where(m_idx >= k, magnitude, -jnp.inf))
        pivot = pivot.at[k].set(i)

        # Swap rows k and i of `a`, and mirror the swap in `perm`.
        a = a.at[[k, i], ].set(a[[i, k], ])
        perm = perm.at[[i, k], ].set(perm[[k, i], ])

        # a[k+1:, k] /= a[k, k], adapted for loop-invariant shapes
        x = a[k, k]
        a = a.at[:, k].set(jnp.where(m_idx > k, a[:, k] / x, a[:, k]))

        # a[k+1:, k+1:] -= jnp.outer(a[k+1:, k], a[k, k+1:])
        a = a - jnp.where(
            (m_idx[:, None] > k) & (n_idx > k), jnp.outer(a[:, k], a[k, :]),
            jnp.array(0, dtype=a.dtype))
        return pivot, perm, a

    pivot = jnp.zeros((min(m, n), ), dtype=jnp.int32)
    perm = jnp.arange(m, dtype=jnp.int32)
    if m == 0 or n == 0:
        # With any empty dimension the loop body never executes, but tracing it
        # to a jaxpr can still fail because the indexing (and argmax over an
        # empty column) cannot succeed. The initial state is already the answer.
        return (pivot, perm, a)
    return lax.fori_loop(0, min(m, n), body, (pivot, perm, a))
Ejemplo n.º 12
0
Archivo: signal.py Proyecto: GJBoth/jax
def _spectral_helper(x,
                     y,
                     fs=1.0,
                     window='hann',
                     nperseg=None,
                     noverlap=None,
                     nfft=None,
                     detrend_type='constant',
                     return_onesided=True,
                     scaling='density',
                     axis=-1,
                     mode='psd',
                     boundary=None,
                     padded=False):
    """LAX-backend implementation of `scipy.signal._spectral_helper`.

    Unlike the original helper function, `y` can be None for explicitly
    indicating auto-spectral (non cross-spectral) computation.  In addition to
    this, `detrend` argument is renamed to `detrend_type` for avoiding internal
    name overlap.

    Returns:
      A tuple ``(freqs, time, result)``. For empty input all three are arrays
      of zeros.

    Raises:
      ValueError: on an unknown ``mode``, ``boundary`` or ``scaling``; a
        non-positive ``nperseg``; ``nfft < nperseg``; ``noverlap >= nperseg``;
        or, when ``y`` is given, ``mode != 'psd'``, inputs of different rank,
        or outer shapes that cannot be broadcast together.
    """
    if mode not in ('psd', 'stft'):
        raise ValueError(f"Unknown value for mode {mode}, "
                         "must be one of: ('psd', 'stft')")

    def make_pad(mode, **kwargs):
        def pad(x, n, axis=-1):
            pad_width = [(0, 0) for unused_n in range(x.ndim)]
            pad_width[axis] = (n, n)
            return jnp.pad(x, pad_width, mode, **kwargs)

        return pad

    boundary_funcs = {
        'even': make_pad('reflect'),
        'odd': odd_ext,
        'constant': make_pad('edge'),
        'zeros': make_pad('constant', constant_values=0.0),
        None: lambda x, *args, **kwargs: x
    }

    # Check/ normalize inputs
    if boundary not in boundary_funcs:
        raise ValueError(f"Unknown boundary option '{boundary}', "
                         f"must be one of: {list(boundary_funcs.keys())}")

    # Validate and convert `x` before its `.ndim` / `.shape` are used below.
    _check_arraylike("_spectral_helper", x)
    x = jnp.asarray(x)

    axis = jax.core.concrete_or_error(operator.index, axis,
                                      "axis of windowed-FFT")
    # NOTE(review): canonicalize_axis presumably returns a non-negative axis;
    # if so, the `axis != -1` checks further down are always true.
    axis = canonicalize_axis(axis, x.ndim)

    if nperseg is not None:  # if specified by user
        nperseg = jax.core.concrete_or_error(int, nperseg,
                                             "nperseg of windowed-FFT")
        if nperseg < 1:
            raise ValueError('nperseg must be a positive integer')
    # parse window; if array like, then set nperseg = win.shape
    win, nperseg = signal_helper._triage_segments(window,
                                                  nperseg,
                                                  input_length=x.shape[axis])

    if noverlap is None:
        noverlap = nperseg // 2
    else:
        noverlap = jax.core.concrete_or_error(int, noverlap,
                                              "noverlap of windowed-FFT")
    if nfft is None:
        nfft = nperseg
    else:
        nfft = jax.core.concrete_or_error(int, nfft, "nfft of windowed-FFT")

    if y is None:
        outdtype = jax.dtypes.canonicalize_dtype(
            np.result_type(x, np.complex64))
    else:
        _check_arraylike("_spectral_helper", y)
        y = jnp.asarray(y)
        outdtype = jax.dtypes.canonicalize_dtype(
            np.result_type(x, y, np.complex64))
        if mode != 'psd':
            raise ValueError(
                "two-argument mode is available only when mode=='psd'")
        if x.ndim != y.ndim:
            raise ValueError(
                f"two-arguments must have the same rank ({x.ndim} vs {y.ndim})."
            )

        # Check if we can broadcast the outer axes together
        try:
            outershape = jnp.broadcast_shapes(tuple_delete(x.shape, axis),
                                              tuple_delete(y.shape, axis))
        except ValueError as e:
            raise ValueError('x and y cannot be broadcast together.') from e

    # Special cases for size == 0
    if y is None:
        if x.size == 0:
            return jnp.zeros(x.shape), jnp.zeros(x.shape), jnp.zeros(x.shape)
    else:
        if x.size == 0 or y.size == 0:
            outshape = tuple_insert(outershape,
                                    min([x.shape[axis], y.shape[axis]]), axis)
            emptyout = jnp.zeros(outshape)
            return emptyout, emptyout, emptyout

    # Move time-axis to the end
    if x.ndim > 1:
        if axis != -1:
            x = jnp.moveaxis(x, axis, -1)
            if y is not None and y.ndim > 1:
                y = jnp.moveaxis(y, axis, -1)

    # Check if x and y are the same length, zero-pad if necessary
    if y is not None:
        if x.shape[-1] != y.shape[-1]:
            if x.shape[-1] < y.shape[-1]:
                pad_shape = list(x.shape)
                pad_shape[-1] = y.shape[-1] - x.shape[-1]
                x = jnp.concatenate((x, jnp.zeros(pad_shape)), -1)
            else:
                pad_shape = list(y.shape)
                pad_shape[-1] = x.shape[-1] - y.shape[-1]
                y = jnp.concatenate((y, jnp.zeros(pad_shape)), -1)

    if nfft < nperseg:
        raise ValueError('nfft must be greater than or equal to nperseg.')
    if noverlap >= nperseg:
        raise ValueError('noverlap must be less than nperseg.')
    nstep = nperseg - noverlap

    # Apply paddings
    if boundary is not None:
        ext_func = boundary_funcs[boundary]
        x = ext_func(x, nperseg // 2, axis=-1)
        if y is not None:
            y = ext_func(y, nperseg // 2, axis=-1)

    if padded:
        # Pad to integer number of windowed segments
        # I.e make x.shape[-1] = nperseg + (nseg-1)*nstep, with integer nseg
        nadd = (-(x.shape[-1] - nperseg) % nstep) % nperseg
        zeros_shape = list(x.shape[:-1]) + [nadd]
        x = jnp.concatenate((x, jnp.zeros(zeros_shape)), axis=-1)
        if y is not None:
            zeros_shape = list(y.shape[:-1]) + [nadd]
            y = jnp.concatenate((y, jnp.zeros(zeros_shape)), axis=-1)

    # Handle detrending and window functions
    if not detrend_type:

        def detrend_func(d):
            return d
    elif not hasattr(detrend_type, '__call__'):

        def detrend_func(d):
            return detrend(d, type=detrend_type, axis=-1)
    elif axis != -1:
        # Wrap this function so that it receives a shape that it could
        # reasonably expect to receive.
        # NOTE(review): scipy's wrapper calls moveaxis(d, -1, axis) before the
        # user function and moveaxis(d, axis, -1) after it; this one does the
        # reverse. Confirm which orientation is intended.
        def detrend_func(d):
            d = jnp.moveaxis(d, axis, -1)
            d = detrend_type(d)
            return jnp.moveaxis(d, -1, axis)
    else:
        detrend_func = detrend_type

    if np.result_type(win, np.complex64) != outdtype:
        win = win.astype(outdtype)

    # Determine scale
    if scaling == 'density':
        scale = 1.0 / (fs * (win * win).sum())
    elif scaling == 'spectrum':
        scale = 1.0 / win.sum()**2
    else:
        raise ValueError(f'Unknown scaling: {scaling}')
    if mode == 'stft':
        scale = jnp.sqrt(scale)

    # Determine onesided/ two-sided
    if return_onesided:
        sides = 'onesided'
        if jnp.iscomplexobj(x) or jnp.iscomplexobj(y):
            sides = 'twosided'
            warnings.warn('Input data is complex, switching to '
                          'return_onesided=False')
    else:
        sides = 'twosided'

    if sides == 'twosided':
        freqs = jax.numpy.fft.fftfreq(nfft, 1 / fs)
    elif sides == 'onesided':
        freqs = jax.numpy.fft.rfftfreq(nfft, 1 / fs)

    # Perform the windowed FFTs
    result = _fft_helper(x, win, detrend_func, nperseg, noverlap, nfft, sides)

    if y is not None:
        # All the same operations on the y data
        result_y = _fft_helper(y, win, detrend_func, nperseg, noverlap, nfft,
                               sides)
        result = jnp.conjugate(result) * result_y
    elif mode == 'psd':
        result = jnp.conjugate(result) * result

    result *= scale

    if sides == 'onesided' and mode == 'psd':
        # Double every bin except DC and (for even nfft) the Nyquist bin.
        end = None if nfft % 2 else -1
        result = result.at[..., 1:end].mul(2)

    time = jnp.arange(nperseg / 2, x.shape[-1] - nperseg / 2 + 1,
                      nperseg - noverlap) / fs
    if boundary is not None:
        time -= (nperseg / 2) / fs

    result = result.astype(outdtype)

    # All imaginary parts are zero anyways
    if y is None and mode != 'stft':
        result = result.real

    # Move frequency axis back to axis where the data came from
    result = jnp.moveaxis(result, -1, axis)

    return freqs, time, result
Ejemplo n.º 13
0
def _gen_associated_legendre(l_max: int, x: jnp.ndarray,
                             is_normalized: bool) -> jnp.ndarray:
    r"""Computes associated Legendre functions (ALFs) of the first kind.

    The ALFs of the first kind are used in spherical harmonics. The spherical
    harmonic of degree `l` and order `m` can be written as
    `Y_l^m(θ, φ) = N_l^m * P_l^m(cos(θ)) * exp(i m φ)`, where `N_l^m` is the
    normalization factor and θ and φ are the colatitude and longitude,
    respectively. `N_l^m` is chosen so that the spherical harmonics form a
    set of orthonormal basis functions of L^2(S^2). For the computational
    efficiency of the spherical harmonics transform, the normalization factor
    is used in the computation of the ALFs. In addition, normalizing `P_l^m`
    avoids overflow/underflow and achieves better numerical stability. Three
    recurrence relations are used in the computation.

    Args:
      l_max: The maximum degree of the associated Legendre function. Both the
        degrees and orders are `[0, 1, 2, ..., l_max]`.
      x: A vector of type `float32`, `float64` containing the sampled points in
        spherical coordinates, at which the ALFs are computed; `x` is
        essentially `cos(θ)`. For the numerical integration used by the
        spherical harmonics transforms, `x` contains the quadrature points in
        the interval `[-1, 1]`. There are several approaches to provide the
        quadrature points: the Gauss-Legendre method
        (`scipy.special.roots_legendre`), the Gauss-Chebyshev method
        (`scipy.special.roots_chebyu`), and the Driscoll & Healy method
        (Driscoll, James R., and Dennis M. Healy. "Computing Fourier transforms
        and convolutions on the 2-sphere." Advances in applied mathematics 15,
        no. 2 (1994): 202-250.). The Gauss-Legendre quadrature points are
        nearly equally spaced along θ and provide exact discrete
        orthogonality, (P^m)^T W P_m = I, where `T` represents the transpose
        operation, `W` is a diagonal matrix containing the quadrature weights,
        and `I` is the identity matrix. The Gauss-Chebyshev points are equally
        spaced, and only provide approximate discrete orthogonality. The
        Driscoll & Healy quadrature points are equally spaced and provide
        exact discrete orthogonality. The Driscoll & Healy approach requires
        twice as many sampling points as frequency points (modes), which
        enables the FFT and achieves a fast spherical harmonics transform.
      is_normalized: True if the associated Legendre functions are normalized.
        With normalization, `N_l^m` is applied such that the spherical harmonics
        form a set of orthonormal basis functions of L^2(S^2).

    Returns:
      The 3D array of shape `(l_max + 1, l_max + 1, len(x))` containing the
      values of the ALFs at `x`; the dimensions are in the sequence of order,
      degree, and evaluation points.
    """
    # p[m, l, :] holds P_l^m at every point of x; entries start at zero.
    p = jnp.zeros((l_max + 1, l_max + 1, x.shape[0]))

    # Factors for the diagonal (f_a, indexed by l = 1..l_max) and the first
    # off-diagonal (f_b, indexed by l = 0..l_max-1) recurrences.
    a_idx = jnp.arange(1, l_max + 1)
    b_idx = jnp.arange(l_max)
    if is_normalized:
        initial_value = 0.5 / jnp.sqrt(jnp.pi)  # The initial value p(0,0).
        f_a = jnp.cumprod(-1 * jnp.sqrt(1.0 + 0.5 / a_idx))
        f_b = jnp.sqrt(2.0 * b_idx + 3.0)
    else:
        initial_value = 1.0  # The initial value p(0,0).
        f_a = jnp.cumprod(1.0 - 2.0 * a_idx)
        f_b = 2.0 * b_idx + 1.0

    p = p.at[(0, 0)].set(initial_value)

    # Compute the diagonal entries p(l,l) with recurrence.
    # y[i] is the running product of sqrt(1 - x*x), i.e. its (i + 1)-th power.
    y = jnp.cumprod(jnp.broadcast_to(jnp.sqrt(1.0 - x * x),
                                     (l_max, x.shape[0])),
                    axis=0)
    p_diag = initial_value * jnp.einsum('i,ij->ij', f_a, y)
    diag_indices = jnp.diag_indices(l_max + 1)
    # Fill p(l, l) for l = 1..l_max (p(0, 0) was set above).
    p = p.at[(diag_indices[0][1:], diag_indices[1][1:])].set(p_diag)

    # Compute the off-diagonal entries with recurrence.
    # p(l, l + 1) = f_b[l] * x * p(l, l) for l = 0..l_max-1.
    p_offdiag = jnp.einsum('ij,ij->ij', jnp.einsum('i,j->ij', f_b, x),
                           p[jnp.diag_indices(l_max)])
    offdiag_indices = (diag_indices[0][:l_max], diag_indices[1][:l_max] + 1)
    p = p.at[offdiag_indices].set(p_offdiag)

    # Compute the remaining entries with recurrence.
    d0_mask_3d, d1_mask_3d = _gen_recurrence_mask(l_max,
                                                  is_normalized=is_normalized)

    # Iteration i adds d0 * x * p[m, l-1] - d1 * p[m, l-2] to the entries whose
    # mask plane i is non-zero (l - m == i, per `_gen_recurrence_mask`). Those
    # entries are still zero here; jnp.roll shifts p along the degree axis.
    def body_fun(i, p_val):
        coeff_0 = d0_mask_3d[i]
        coeff_1 = d1_mask_3d[i]
        h = (jnp.einsum(
            'ij,ijk->ijk', coeff_0,
            jnp.einsum('ijk,k->ijk', jnp.roll(p_val, shift=1, axis=1), x)) -
             jnp.einsum('ij,ijk->ijk', coeff_1, jnp.roll(
                 p_val, shift=2, axis=1)))
        p_val = p_val + h
        return p_val

    # Offsets l - m of 0 and 1 are already filled; iterate over i = 2..l_max.
    if l_max > 1:
        p = lax.fori_loop(lower=2,
                          upper=l_max + 1,
                          body_fun=body_fun,
                          init_val=p)

    return p
Ejemplo n.º 14
0
def _gen_derivatives(p: jnp.ndarray, x: jnp.ndarray,
                     is_normalized: bool) -> jnp.ndarray:
    """Generates derivatives of associated Legendre functions of the first kind.

    Args:
      p: The 3D array containing the values of associated Legendre functions;
        the dimensions are in the sequence of order (m), degree (l), and
        evaluation points. Presumably produced by `_gen_associated_legendre`
        with `is_normalized=False`, so that `num_m == num_l`. Confirm this for
        other callers.
      x: A vector of type `float32` or `float64` containing the sampled points.
      is_normalized: True if the associated Legendre functions are normalized.
        Only `False` is currently supported.

    Returns:
      A 3D array with the same shape as `p` holding the derivatives with
      respect to `x`. Entries with m > l are zero.

    Raises:
      NotImplementedError: if `is_normalized` is True.
    """

    num_m, num_l, num_x = p.shape

    # p_{l-1}^m.
    p_m_lm1 = jnp.pad(p, ((0, 0), (1, 0), (0, 0)))[:, :num_l, :]

    # p_{l-1}^{m+2}.
    p_mp2_lm1 = jnp.pad(p_m_lm1, ((0, 2), (0, 0), (0, 0)))[2:num_m + 2, :, :]

    # p_{l-1}^{m-2}.
    p_mm2_lm1 = jnp.pad(p_m_lm1, ((2, 0), (0, 0), (0, 0)))[:num_m, :, :]

    # Derivative computation requires negative orders.
    if is_normalized:
        raise NotImplementedError(
            'Negative orders for normalization is not implemented yet.')
    else:
        # Fill the slots that stand for negative orders (m - 2 < 0) from the
        # positive-order values in rows 1 and 2 of `p`.
        if num_l > 1:
            l_vec = jnp.arange(1, num_l - 1)
            p_p1 = p[1, 1:num_l - 1, :]
            coeff = -1.0 / ((l_vec + 1) * l_vec)
            update_p_p1 = jnp.einsum('i,ij->ij', coeff, p_p1)
            p_mm2_lm1 = p_mm2_lm1.at[1, 2:num_l, :].set(update_p_p1)

        if num_l > 2:
            l_vec = jnp.arange(2, num_l - 1)
            p_p2 = p[2, 2:num_l - 1, :]
            coeff = 1.0 / ((l_vec + 2) * (l_vec + 1) * l_vec)
            update_p_p2 = jnp.einsum('i,ij->ij', coeff, p_p2)
            p_mm2_lm1 = p_mm2_lm1.at[0, 3:num_l, :].set(update_p_p2)

    m_mat, l_mat = jnp.mgrid[:num_m, :num_l]

    # Coefficients are only kept on the upper triangle (m <= l); this also
    # drops the inf produced by 1 / (m - 1) at m == 1 outside of row 1.
    coeff_zeros = jnp.zeros((num_m, num_l))
    upper_0_indices = jnp.triu_indices(num_m, 0, num_l)
    zero_vec = jnp.zeros((num_l, ))

    a0 = -0.5 / (m_mat - 1.0)
    a0_masked = coeff_zeros.at[upper_0_indices].set(a0[upper_0_indices])
    # Row m == 1 is singular here and is recomputed separately below.
    a0_masked = a0_masked.at[1, :].set(zero_vec)

    b0 = l_mat + m_mat
    c0 = a0 * (b0 - 2.0) * (b0 - 1.0)
    c0_masked = coeff_zeros.at[upper_0_indices].set(c0[upper_0_indices])
    c0_masked = c0_masked.at[1, :].set(zero_vec)

    # p_l^{m-1}.
    p_mm1_l = (jnp.einsum('ij,ijk->ijk', a0_masked, p_m_lm1) +
               jnp.einsum('ij,ijk->ijk', c0_masked, p_mm2_lm1))

    d0 = -0.5 / (m_mat + 1.0)
    d0_masked = coeff_zeros.at[upper_0_indices].set(d0[upper_0_indices])
    e0 = d0 * b0 * (b0 + 1.0)
    e0_masked = coeff_zeros.at[upper_0_indices].set(e0[upper_0_indices])

    # p_l^{m+1}.
    p_mp1_l = (jnp.einsum('ij,ijk->ijk', d0_masked, p_mp2_lm1) +
               jnp.einsum('ij,ijk->ijk', e0_masked, p_m_lm1))

    f0 = b0 * (l_mat - m_mat + 1.0) / 2.0
    f0_masked = coeff_zeros.at[upper_0_indices].set(f0[upper_0_indices])
    p_derivative = jnp.einsum('ij,ijk->ijk', f0_masked,
                              p_mm1_l) - 0.5 * p_mp1_l

    # Special treatment of the singularity at m = 1.
    if num_m > 1:
        l_vec = jnp.arange(num_l)
        g0 = jnp.einsum('i,ij->ij', (l_vec + 1) * l_vec, p[0, :, :])
        if num_l > 2:
            g0 = g0 - p[2, :, :]
        p_derivative_m0 = jnp.einsum('j,ij->ij', 0.5 / jnp.sqrt(1 - x * x), g0)
        p_derivative = p_derivative.at[1, :, :].set(p_derivative_m0)
        p_derivative = p_derivative.at[1, 0, :].set(jnp.zeros((num_x, )))

    return p_derivative
Ejemplo n.º 15
0
def _unique(ar,
            axis,
            return_index=False,
            return_inverse=False,
            return_counts=False,
            size=None,
            fill_value=None,
            return_true_size=False):
    """Find the unique elements of an array along a particular axis.

    Args:
      ar: input array.
      axis: axis along which unique entries are found; sorting and detection
        of distinct entries are done by `_unique_sorted_mask(ar, axis)`.
      return_index: if True, also return `perm[ind]`. These are the positions
        in `ar` (along `axis`) of the selected unique entries, taken from the
        sort permutation.
      return_inverse: if True, also return, for each sorted slot, the index of
        its unique value, scattered back to the original order via `perm`.
      return_counts: if True, also return the number of occurrences of each
        unique value. When `size` is given, this is zero-padded to `size`.
      size: optional static number of unique values to return. If None, the
        mask must be concrete (enforced by `core.concrete_or_error`).
      fill_value: value used for result positions beyond the true number of
        unique values when `size` is given. It is required when `ar` has
        length 0 along `axis` and `size` is nonzero.
      return_true_size: if True, also return the true number of unique values
        (`mask.sum()`); useful for internal callers.

    Returns:
      The unique values (with `axis` moved back to its original position) if
      no optional output is requested. Otherwise a tuple of the unique values
      followed by the requested index, inverse, counts and true-size outputs,
      in that order.

    Raises:
      ValueError: if `ar` has length 0 along `axis`, `size` is truthy and
        `fill_value` is None.
    """
    if ar.shape[axis] == 0 and size and fill_value is None:
        raise ValueError(
            "jnp.unique: for zero-sized input with nonzero size argument, fill_value must be specified"
        )

    aux, mask, perm = _unique_sorted_mask(ar, axis)
    if size is None:
        # The output length equals the number of unique values, so the mask
        # must be known at trace time.
        ind = core.concrete_or_error(
            None, mask, "The error arose in jnp.unique(). " + UNIQUE_SIZE_HINT)
    else:
        ind = nonzero(mask, size=size)[0]
    # Gathering from an array with no elements is skipped.
    result = aux[ind] if aux.size else aux
    if fill_value is not None:
        fill_value = asarray(fill_value, dtype=result.dtype)
    if size is not None and fill_value is not None:
        if result.shape[0]:
            # Replace the padded slots (index >= number of unique values).
            valid = lax.expand_dims(
                arange(size) < mask.sum(), tuple(range(1, result.ndim)))
            result = where(valid, result, fill_value)
        else:
            result = full_like(result,
                               fill_value,
                               shape=(size, *result.shape[1:]))
    result = moveaxis(result, 0, axis)

    ret = (result, )
    if return_index:
        if aux.size:
            ret += (perm[ind], )
        else:
            ret += (perm, )
    if return_inverse:
        if aux.size:
            # imask[j] is the index of the unique value at sorted slot j;
            # scattering through `perm` restores the original order.
            imask = cumsum(mask) - 1
            inv_idx = zeros(mask.shape,
                            dtype=dtypes.canonicalize_dtype(dtypes.int_))
            inv_idx = inv_idx.at[perm].set(imask)
        else:
            inv_idx = zeros(ar.shape[axis], dtype=int)
        ret += (inv_idx, )
    if return_counts:
        if aux.size:
            if size is None:
                idx = append(nonzero(mask)[0], mask.size)
            else:
                # nonzero pads missing positions with 0. Past slot 0 (which is
                # always a run start) a 0 can only be padding, so map it to
                # mask.size; diff then yields 0 counts for padded slots.
                idx = nonzero(mask, size=size + 1)[0]
                idx = idx.at[1:].set(where(idx[1:], idx[1:], mask.size))
            ret += (diff(idx), )
        elif ar.shape[axis]:
            ret += (array([ar.shape[axis]],
                          dtype=dtypes.canonicalize_dtype(dtypes.int_)), )
        else:
            ret += (empty(0, dtype=int), )
    if return_true_size:
        # Useful for internal uses of unique().
        ret += (mask.sum(), )
    return ret[0] if len(ret) == 1 else ret