Ejemplo n.º 1
def segment_sum(data,
                segment_ids,
                num_segments=None,
                indices_are_sorted=False,
                unique_indices=False,
                bucket_size=None):  # TODO(zhangqiaorjc): use non-None default.
    """Computes the sum within segments of an array.

    Similar to TensorFlow's segment_sum:
    https://www.tensorflow.org/api_docs/python/tf/math/segment_sum

    Args:
      data: an array with the values to be summed.
      segment_ids: an array with integer dtype that indicates the segments of
        `data` (along its leading axis) to be summed. Values can be repeated
        and need not be sorted. Values outside of the range [0, num_segments)
        are wrapped into that range by applying jnp.mod.
      num_segments: optional, an int with positive value indicating the number
        of segments. The default is set to be the minimum number of segments
        that would support all positive and negative indices in `segment_ids`
        calculated as ``max(max(segment_ids) + 1, max(-segment_ids))``.
        Since `num_segments` determines the size of the output, a static value
        must be provided to use `segment_sum` in a `jit`-compiled function.
      indices_are_sorted: whether `segment_ids` is known to be sorted.
      unique_indices: whether `segment_ids` is known to be free of duplicates.
      bucket_size: size of bucket to group indices into. segment_sum is
        performed on each bucket separately to improve numerical stability of
        addition. Default `None` means no bucketing.

    Returns:
      An array with shape :code:`(num_segments,) + data.shape[1:]` representing
      the segment sums.
    """
    if num_segments is None:
        num_segments = max(jnp.max(segment_ids) + 1, jnp.max(-segment_ids))
    # The output shape depends on this value, so it must be a Python int.
    num_segments = int(num_segments)

    out = jnp.zeros((num_segments, ) + data.shape[1:], dtype=data.dtype)
    # Wrap negative / out-of-range ids into [0, num_segments).
    segment_ids = jnp.mod(segment_ids, num_segments)

    num_buckets = (1 if bucket_size is None else util.ceil_of_ratio(
        segment_ids.size, bucket_size))
    if num_buckets == 1:
        return index_add(out, segment_ids, data, indices_are_sorted,
                         unique_indices)

    # Bucketize indices and perform segment_sum on each bucket to improve
    # numerical stability. The recursive calls leave `bucket_size` at its
    # default of None, so each one takes the single index_add path above.
    outs = []
    for sub_data, sub_segment_ids in zip(
            jnp.array_split(data, num_buckets),
            jnp.array_split(segment_ids, num_buckets)):
        outs.append(
            segment_sum(sub_data, sub_segment_ids, num_segments,
                        indices_are_sorted, unique_indices))
    return jnp.sum(jnp.stack(outs), axis=0)
Ejemplo n.º 2
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
    """Computes ``log(sum(b * exp(a)))`` over the reduced axes, stably.

    The maximum of ``a`` over the reduced axes is subtracted before
    exponentiating and added back after taking the logarithm. A non-finite
    maximum is replaced by 0 so that the shift itself cannot introduce NaNs.

    Args:
      a: the input array.
      axis: the axis or axes to reduce over; resolved by ``_reduction_dims``.
      b: optional scaling factors for ``exp(a)``. Where ``b == 0`` the
        corresponding entry of ``a`` is replaced by ``-inf``.
      keepdims: passed to the reductions; whether reduced axes are kept.
      return_sign: if true, return ``(out, sign)`` instead of ``out``.

    Returns:
      ``out``, or the tuple ``(out, sign)`` when ``return_sign`` is true. With
      ``b`` given and ``return_sign`` false, entries whose weighted sum is
      negative are returned as NaN.
    """
    if b is not None:
        a, b = _promote_args_inexact("logsumexp", a, b)
        a = jnp.where(b != 0, a, -jnp.inf)
    pos_dims, dims = _reduction_dims(a, axis)
    amax = jnp.max(a, axis=dims, keepdims=keepdims)
    # The shift is treated as a constant for differentiation.
    amax = lax.stop_gradient(
        lax.select(lax.is_finite(amax), amax, lax.full_like(amax, 0)))
    amax_with_dims = amax if keepdims else lax.expand_dims(amax, pos_dims)
    if b is None:
        out = lax.add(
            lax.log(
                jnp.sum(lax.exp(lax.sub(a, amax_with_dims)),
                        axis=dims,
                        keepdims=keepdims)), amax)
        # Without weights the sum is non-negative: sign is 1, 0 where the
        # result is -inf, and NaN where the result is NaN.
        sign = jnp.where(jnp.isnan(out), np.nan, 1.0).astype(out.dtype)
        sign = jnp.where(out == -np.inf, 0.0, sign)
    else:
        sumexp = jnp.sum(lax.mul(lax.exp(lax.sub(a, amax_with_dims)), b),
                         axis=dims,
                         keepdims=keepdims)
        sign = lax.stop_gradient(lax.sign(sumexp))
        out = lax.add(lax.log(lax.abs(sumexp)), amax)
    if return_sign:
        return (out, sign)
    if b is not None:
        # The log of a negative sum is undefined when no sign is returned.
        out = jnp.where(sign < 0, np.nan, out)
    return out
Ejemplo n.º 3
def _segment_update(name: str,
                    data: Array,
                    segment_ids: Array,
                    scatter_op: Callable,
                    num_segments: Optional[int] = None,
                    indices_are_sorted: bool = False,
                    unique_indices: bool = False,
                    bucket_size: Optional[int] = None,
                    reducer: Optional[Callable] = None,
                    mode: Optional[lax.GatherScatterMode] = None) -> Array:
    """Applies ``scatter_op`` to reduce ``data`` within segments.

    Args:
      name: name of the calling function, used when checking the inputs.
      data: the values to be reduced.
      segment_ids: segment index for each entry along the leading axis of
        ``data``.
      scatter_op: scatter primitive performing the per-segment update; its
        identity element (from ``_get_identity``) fills the output buffer.
      num_segments: number of output segments. Defaults to
        ``max(segment_ids) + 1``; must be concrete.
      indices_are_sorted: forwarded to ``_scatter_update``.
      unique_indices: forwarded to ``_scatter_update``.
      bucket_size: if given, indices are grouped into buckets of this size,
        each bucket is reduced separately, and the per-bucket results are
        combined with ``reducer``.
      reducer: reduction applied over the bucket axis; required when more
        than one bucket is used.
      mode: scatter mode; defaults to
        ``lax.GatherScatterMode.FILL_OR_DROP``.

    Returns:
      An array of shape ``(num_segments,) + data.shape[1:]`` with the dtype
      of ``data``.

    Raises:
      ValueError: if ``num_segments`` is negative.
    """
    jnp._check_arraylike(name, data, segment_ids)
    mode = lax.GatherScatterMode.FILL_OR_DROP if mode is None else mode
    data = jnp.asarray(data)
    segment_ids = jnp.asarray(segment_ids)
    dtype = data.dtype
    if num_segments is None:
        num_segments = jnp.max(segment_ids) + 1
    num_segments = core.concrete_or_error(
        int, num_segments, "segment_sum() `num_segments` argument.")
    if num_segments is not None and num_segments < 0:
        raise ValueError("num_segments must be non-negative.")

    num_buckets = (1 if bucket_size is None else util.ceil_of_ratio(
        segment_ids.size, bucket_size))
    if num_buckets == 1:
        out = jnp.full((num_segments, ) + data.shape[1:],
                       _get_identity(scatter_op, dtype),
                       dtype=dtype)
        return _scatter_update(out,
                               segment_ids,
                               data,
                               scatter_op,
                               indices_are_sorted,
                               unique_indices,
                               normalize_indices=False,
                               mode=mode)

    # Bucketize indices and perform segment_update on each bucket to improve
    # numerical stability for operations like product and sum.
    assert reducer is not None
    out = jnp.full((num_buckets, num_segments) + data.shape[1:],
                   _get_identity(scatter_op, dtype),
                   dtype=dtype)
    # Entry i is scattered into row (i // bucket_size) of the bucketed buffer,
    # at column segment_ids[i]; the bucket axis is then reduced away.
    out = _scatter_update(
        out,
        np.index_exp[lax.div(jnp.arange(segment_ids.shape[0]), bucket_size),
                     segment_ids[None, :]],
        data,
        scatter_op,
        indices_are_sorted,
        unique_indices,
        normalize_indices=False,
        mode=mode)
    return reducer(out, axis=0).astype(dtype)
Ejemplo n.º 4
def sph_harm(m: jnp.ndarray,
             n: jnp.ndarray,
             theta: jnp.ndarray,
             phi: jnp.ndarray,
             n_max: Optional[int] = None) -> jnp.ndarray:
    r"""Computes the spherical harmonics.

    The JAX version has one extra argument `n_max`, the maximum value in `n`.

    The spherical harmonic of degree `n` and order `m` can be written as
    :math:`Y_n^m(\theta, \phi) = N_n^m * P_n^m(\cos \phi) * \exp(i m \theta)`,
    where :math:`N_n^m = \sqrt{\frac{\left(2n+1\right) \left(n-m\right)!}
    {4 \pi \left(n+m\right)!}}` is the normalization factor and :math:`\phi`
    and :math:`\theta` are the colatitude and longitude, respectively.
    :math:`N_n^m` is chosen in the way that the spherical harmonics form a set
    of orthonormal basis functions of :math:`L^2(S^2)`.

    Args:
      m: The order of the harmonic; must have `|m| <= n`. Return values for
        `|m| > n` are undefined.
      n: The degree of the harmonic; must have `n >= 0`. The standard notation
        for degree in descriptions of spherical harmonics is
        `l (lower case L)`. We use `n` here to be consistent with
        `scipy.special.sph_harm`. Return values for `n < 0` are undefined.
      theta: The azimuthal (longitudinal) coordinate; must be in [0, 2*pi].
      phi: The polar (colatitudinal) coordinate; must be in [0, pi].
      n_max: The maximum degree `max(n)`. If the supplied `n_max` is not the
        true maximum value of `n`, the results are clipped to `n_max`. For
        example,
        `sph_harm(m=jnp.array([2]), n=jnp.array([10]), theta, phi, n_max=6)`
        actually returns
        `sph_harm(m=jnp.array([2]), n=jnp.array([6]), theta, phi, n_max=6)`

    Returns:
      A 1D array containing the spherical harmonics at (m, n, theta, phi).
    """

    # A scalar `phi` is promoted to a one-element array.
    if jnp.isscalar(phi):
        phi = jnp.array([phi])

    if n_max is None:
        n_max = jnp.max(n)
    # `n_max` has to be a concrete Python int; a traced value raises here.
    n_max = core.concrete_or_error(
        int, n_max,
        'The `n_max` argument of `jnp.scipy.special.sph_harm` must '
        'be statically specified to use `sph_harm` within JAX transformations.'
    )

    return _sph_harm(m, n, theta, phi, n_max)
Ejemplo n.º 5
def _segment_update(name: str,
                    data: Array,
                    segment_ids: Array,
                    scatter_op: Callable,
                    num_segments: Optional[int] = None,
                    indices_are_sorted: bool = False,
                    unique_indices: bool = False,
                    bucket_size: Optional[int] = None,
                    reducer: Optional[Callable] = None) -> Array:
    """Applies ``scatter_op`` to reduce ``data`` within segments.

    Args:
      name: name of the calling function, used when checking the inputs.
      data: the values to be reduced.
      segment_ids: segment index for each entry along the leading axis of
        ``data``.
      scatter_op: scatter primitive performing the per-segment update; its
        identity element (from ``_get_identity``) fills the output buffer.
      num_segments: number of output segments. Defaults to
        ``max(segment_ids) + 1``; must be concrete.
      indices_are_sorted: forwarded to ``_scatter_update``.
      unique_indices: forwarded to ``_scatter_update``.
      bucket_size: if given, the inputs are split into
        ``ceil(segment_ids.size / bucket_size)`` chunks that are reduced
        separately and then combined with ``reducer``.
      reducer: reduction applied across the per-bucket results; required when
        more than one bucket is used.

    Returns:
      An array of shape ``(num_segments,) + data.shape[1:]`` with the dtype
      of ``data``.

    Raises:
      ValueError: if ``num_segments`` is negative.
    """
    jnp._check_arraylike(name, data, segment_ids)
    data = jnp.asarray(data)
    segment_ids = jnp.asarray(segment_ids)
    dtype = data.dtype
    if num_segments is None:
        num_segments = jnp.max(segment_ids) + 1
    num_segments = core.concrete_or_error(
        int, num_segments, "segment_sum() `num_segments` argument.")
    if num_segments is not None and num_segments < 0:
        raise ValueError("num_segments must be non-negative.")

    out = jnp.full((num_segments, ) + data.shape[1:],
                   _get_identity(scatter_op, dtype),
                   dtype=dtype)

    num_buckets = (1 if bucket_size is None else util.ceil_of_ratio(
        segment_ids.size, bucket_size))
    if num_buckets == 1:
        return _scatter_update(out,
                               segment_ids,
                               data,
                               scatter_op,
                               indices_are_sorted,
                               unique_indices,
                               normalize_indices=False)

    # Bucketize indices and perform segment_update on each bucket to improve
    # numerical stability for operations like product and sum. The recursive
    # calls leave `bucket_size` unset, so each takes the single-scatter path.
    assert reducer is not None
    outs = []
    for sub_data, sub_segment_ids in zip(
            jnp.array_split(data, num_buckets),
            jnp.array_split(segment_ids, num_buckets)):
        outs.append(
            _segment_update(name, sub_data, sub_segment_ids, scatter_op,
                            num_segments, indices_are_sorted, unique_indices))
    return reducer(jnp.stack(outs), axis=0).astype(dtype)
Ejemplo n.º 6
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
    """Computes ``log(sum(b * exp(a)))`` over the reduced axes, stably.

    The maximum of ``a`` over the reduced axes is subtracted before
    exponentiating and added back after taking the logarithm. A non-finite
    maximum is replaced by 0 so that the shift itself cannot introduce NaNs.

    Args:
      a: the input array.
      axis: the axis or axes to reduce over; resolved by ``_reduction_dims``.
      b: optional scaling factors for ``exp(a)``. Where ``b == 0`` the
        corresponding entry of ``a`` is replaced by ``-inf``.
      keepdims: passed to the reductions; whether reduced axes are kept.
      return_sign: if true, return ``(out, sign)`` instead of ``out``.

    Returns:
      ``out``, or the tuple ``(out, sign)`` when ``return_sign`` is true. With
      ``b`` given, ``return_sign`` false and a non-complex result, entries
      whose weighted sum is negative are returned as NaN.
    """
    if b is not None:
        a, b = _promote_args_inexact("logsumexp", a, b)
        a = jnp.where(b != 0, a, -jnp.inf)
    else:
        a, = _promote_args_inexact("logsumexp", a)
    pos_dims, dims = _reduction_dims(a, axis)
    amax = jnp.max(a, axis=dims, keepdims=keepdims)
    # The shift is treated as a constant for differentiation.
    amax = lax.stop_gradient(
        lax.select(jnp.isfinite(amax), amax, lax.full_like(amax, 0)))
    amax_with_dims = amax if keepdims else lax.expand_dims(amax, pos_dims)
    # fast path if the result cannot be negative.
    if b is None and not np.issubdtype(a.dtype, np.complexfloating):
        out = lax.add(
            lax.log(
                jnp.sum(lax.exp(lax.sub(a, amax_with_dims)),
                        axis=dims,
                        keepdims=keepdims)), amax)
        # sign is 1, except NaN where `out` is NaN and 0 where it is -inf.
        sign = jnp.where(jnp.isnan(out), out, 1.0)
        sign = jnp.where(jnp.isneginf(out), 0.0, sign).astype(out.dtype)
    else:
        expsub = lax.exp(lax.sub(a, amax_with_dims))
        if b is not None:
            expsub = lax.mul(expsub, b)
        sumexp = jnp.sum(expsub, axis=dims, keepdims=keepdims)

        sign = lax.stop_gradient(jnp.sign(sumexp))
        if np.issubdtype(sumexp.dtype, np.complexfloating):
            if return_sign:
                sumexp = sign * sumexp
            out = lax.add(lax.log(sumexp), amax)
        else:
            out = lax.add(lax.log(lax.abs(sumexp)), amax)
    if return_sign:
        return (out, sign)
    if b is not None:
        if not np.issubdtype(out.dtype, np.complexfloating):
            # Use jnp.array(nan) to avoid false positives in debug_nans
            # (see https://github.com/google/jax/issues/7634)
            out = jnp.where(sign < 0, jnp.array(np.nan, dtype=out.dtype), out)
    return out
Ejemplo n.º 7
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
    """Computes ``log(sum(b * exp(a)))`` over the reduced axes, stably.

    The maximum of ``a`` over the reduced axes is subtracted before
    exponentiating and added back after taking the logarithm. A non-finite
    maximum is replaced by 0 so that the shift itself cannot introduce NaNs.

    Args:
      a: the input array.
      axis: the axis or axes to reduce over; resolved by ``_reduction_dims``.
      b: optional scaling factors for ``exp(a)``. Where ``b == 0`` the
        corresponding entry of ``a`` is replaced by ``-inf``.
      keepdims: passed to the reductions; whether reduced axes are kept.
      return_sign: if true, return ``(out, sign)`` instead of ``out``.

    Returns:
      ``out``, or the tuple ``(out, sign)`` when ``return_sign`` is true. With
      ``b`` given, ``return_sign`` false and a non-complex result, entries
      whose weighted sum is negative are returned as NaN.
    """
    if b is not None:
        a, b = _promote_args_inexact("logsumexp", a, b)
        a = jnp.where(b != 0, a, -jnp.inf)
    else:
        a, = _promote_args_inexact("logsumexp", a)
    pos_dims, dims = _reduction_dims(a, axis)
    amax = jnp.max(a, axis=dims, keepdims=keepdims)
    # The shift is treated as a constant for differentiation.
    amax = lax.stop_gradient(
        lax.select(jnp.isfinite(amax), amax, lax.full_like(amax, 0)))
    amax_with_dims = amax if keepdims else lax.expand_dims(amax, pos_dims)
    # fast path if the result cannot be negative.
    if b is None and not np.issubdtype(a.dtype, np.complexfloating):
        out = lax.add(
            lax.log(
                jnp.sum(lax.exp(lax.sub(a, amax_with_dims)),
                        axis=dims,
                        keepdims=keepdims)), amax)
        # sign is 1, except NaN where `out` is NaN and 0 where it is -inf.
        sign = jnp.where(jnp.isnan(out), np.nan, 1.0).astype(out.dtype)
        sign = jnp.where(out == -np.inf, 0.0, sign)
    else:
        expsub = lax.exp(lax.sub(a, amax_with_dims))
        if b is not None:
            expsub = lax.mul(expsub, b)
        sumexp = jnp.sum(expsub, axis=dims, keepdims=keepdims)

        sign = lax.stop_gradient(jnp.sign(sumexp))
        if np.issubdtype(sumexp.dtype, np.complexfloating):
            if return_sign:
                sumexp = sign * sumexp
            out = lax.add(lax.log(sumexp), amax)
        else:
            out = lax.add(lax.log(lax.abs(sumexp)), amax)
    if return_sign:
        return (out, sign)
    if b is not None:
        if not np.issubdtype(out.dtype, np.complexfloating):
            # The log of a negative sum is undefined when no sign is returned.
            out = jnp.where(sign < 0, np.nan, out)
    return out
Ejemplo n.º 8
def segment_sum(
    data: Array,
    segment_ids: Array,
    num_segments: Optional[int] = None,
    indices_are_sorted: bool = False,
    unique_indices: bool = False,
    # TODO(zhangqiaorjc): use non-None default for bucket_size.
    bucket_size: Optional[int] = None
) -> Array:
    """Computes the sum within segments of an array.

    Similar to TensorFlow's segment_sum:
    https://www.tensorflow.org/api_docs/python/tf/math/segment_sum

    Args:
      data: an array with the values to be summed.
      segment_ids: an array with integer dtype that indicates the segments of
        `data` (along its leading axis) to be summed. Values can be repeated
        and need not be sorted. Values outside of the range [0, num_segments)
        are dropped and do not contribute to the sum.
      num_segments: optional, an int with nonnegative value indicating the
        number of segments. The default is set to be the minimum number of
        segments that would support all indices in ``segment_ids``, calculated
        as ``max(segment_ids) + 1``.
        Since `num_segments` determines the size of the output, a static value
        must be provided to use ``segment_sum`` in a ``jit``-compiled function.
      indices_are_sorted: whether ``segment_ids`` is known to be sorted.
      unique_indices: whether `segment_ids` is known to be free of duplicates.
      bucket_size: size of bucket to group indices into. ``segment_sum`` is
        performed on each bucket separately to improve numerical stability of
        addition. Default ``None`` means no bucketing.

    Returns:
      An array with shape :code:`(num_segments,) + data.shape[1:]` representing
      the segment sums.

    Raises:
      ValueError: if ``num_segments`` is negative.

    Examples:
      Simple 1D segment sum:

      >>> data = jnp.arange(5)
      >>> segment_ids = jnp.array([0, 0, 1, 1, 2])
      >>> segment_sum(data, segment_ids)
      DeviceArray([1, 5, 4], dtype=int32)

      Using JIT requires static `num_segments`:

      >>> from jax import jit
      >>> jit(segment_sum, static_argnums=2)(data, segment_ids, 3)
      DeviceArray([1, 5, 4], dtype=int32)
    """
    if num_segments is None:
        num_segments = jnp.max(segment_ids) + 1
    # The output shape depends on this value, so it must be concrete.
    num_segments = core.concrete_or_error(
        int, num_segments, "segment_sum() `num_segments` argument.")

    if num_segments is not None and num_segments < 0:
        raise ValueError("num_segments must be non-negative.")

    out = jnp.zeros((num_segments, ) + data.shape[1:], dtype=data.dtype)

    num_buckets = (1 if bucket_size is None else util.ceil_of_ratio(
        segment_ids.size, bucket_size))
    if num_buckets == 1:
        return _scatter_update(out,
                               segment_ids,
                               data,
                               lax.scatter_add,
                               indices_are_sorted,
                               unique_indices,
                               normalize_indices=False)

    # Bucketize indices and perform segment_sum on each bucket to improve
    # numerical stability. The recursive calls leave `bucket_size` at its
    # default of None, so each one takes the single-scatter path above.
    outs = []
    for sub_data, sub_segment_ids in zip(
            jnp.array_split(data, num_buckets),
            jnp.array_split(segment_ids, num_buckets)):
        outs.append(
            segment_sum(sub_data, sub_segment_ids, num_segments,
                        indices_are_sorted, unique_indices))
    return jnp.sum(jnp.stack(outs), axis=0)