Example #1
def unique(ar,
           return_index=False,
           return_inverse=False,
           return_counts=False,
           axis: Optional[int] = None,
           *,
           size=None,
           fill_value=None):
    _check_arraylike("unique", ar)
    if size is None:
        ar = core.concrete_or_error(
            None, ar,
            "The error arose for the first argument of jnp.unique(). " +
            UNIQUE_SIZE_HINT)
    else:
        size = core.concrete_or_error(
            operator.index, size,
            "The error arose for the size argument of jnp.unique(). " +
            UNIQUE_SIZE_HINT)
    ar = asarray(ar)
    if axis is None:
        axis = 0
        ar = ar.flatten()
    axis = core.concrete_or_error(operator.index, axis,
                                  "axis argument of jnp.unique()")
    return _unique(ar,
                   axis,
                   return_index,
                   return_inverse,
                   return_counts,
                   size=size,
                   fill_value=fill_value)
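
A usage sketch (added for illustration, not part of the original source): because `size` is concrete, the output shape is static and the call works under `jit`, with extra slots padded by `fill_value`.

import jax
import jax.numpy as jnp

# Pad/truncate the unique values to a fixed length of 4.
@jax.jit
def four_uniques(x):
    return jnp.unique(x, size=4, fill_value=-1)

print(four_uniques(jnp.array([3, 1, 3, 2])))  # [ 1  2  3 -1]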
Example #2
def setdiff1d(ar1, ar2, assume_unique=False, *, size=None, fill_value=None):
    _check_arraylike("setdiff1d", ar1, ar2)
    if size is None:
        ar1 = core.concrete_or_error(None, ar1,
                                     "The error arose in setdiff1d()")
    else:
        size = core.concrete_or_error(operator.index, size,
                                      "The error arose in setdiff1d()")
    ar1 = asarray(ar1)
    fill_value = asarray(0 if fill_value is None else fill_value,
                         dtype=ar1.dtype)
    if ar1.size == 0:
        return full_like(ar1, fill_value, shape=size or 0)
    if not assume_unique:
        ar1 = unique(ar1, size=size and ar1.size)
    mask = in1d(ar1, ar2, invert=True)
    if size is None:
        return ar1[mask]
    else:
        if not assume_unique:
            # size is not None on this branch.
            # Set mask to zero at locations corresponding to unique() padding.
            n_unique = ar1.size + 1 - (ar1 == ar1[0]).sum()
            mask = where(arange(ar1.size) < n_unique, mask, False)
        return where(
            arange(size) < mask.sum(), ar1[where(mask, size=size)], fill_value)
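
A usage sketch (added): without `size` the inputs must be concrete; with `size`, the result is padded to a fixed length, by default with 0 (per the `fill_value` handling above).

import jax.numpy as jnp

print(jnp.setdiff1d(jnp.array([1, 2, 3, 4]), jnp.array([2, 4])))          # [1 3]
print(jnp.setdiff1d(jnp.array([1, 2, 3, 4]), jnp.array([2, 4]), size=3))  # [1 3 0]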
Example #3
def intersect1d(ar1, ar2, assume_unique=False, return_indices=False):
    _check_arraylike("intersect1d", ar1, ar2)
    ar1 = core.concrete_or_error(None, ar1, "The error arose in intersect1d()")
    ar2 = core.concrete_or_error(None, ar2, "The error arose in intersect1d()")

    if not assume_unique:
        if return_indices:
            ar1, ind1 = unique(ar1, return_index=True)
            ar2, ind2 = unique(ar2, return_index=True)
        else:
            ar1 = unique(ar1)
            ar2 = unique(ar2)
    else:
        ar1 = ravel(ar1)
        ar2 = ravel(ar2)

    if return_indices:
        aux, mask, aux_sort_indices = _intersect1d_sorted_mask(
            ar1, ar2, return_indices)
    else:
        aux, mask = _intersect1d_sorted_mask(ar1, ar2, return_indices)

    int1d = aux[:-1][mask]

    if return_indices:
        ar1_indices = aux_sort_indices[:-1][mask]
        ar2_indices = aux_sort_indices[1:][mask] - ar1.size
        if not assume_unique:
            ar1_indices = ind1[ar1_indices]
            ar2_indices = ind2[ar2_indices]

        return int1d, ar1_indices, ar2_indices
    else:
        return int1d
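
A usage sketch (added): with `return_indices=True` the positions of the common values in each input are returned as well.

import jax.numpy as jnp

vals, ind1, ind2 = jnp.intersect1d(
    jnp.array([1, 3, 4, 3]), jnp.array([3, 1, 2, 1]), return_indices=True)
print(vals, ind1, ind2)  # [1 3] [0 1] [1 0]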
Example #4
def svd(a: Any,
        full_matrices: bool,
        compute_uv: bool = True,
        hermitian: bool = False,
        max_iterations: int = 10) -> Union[Any, Sequence[Any]]:
    """Singular value decomposition.

  Args:
    a: A matrix of shape `m x n`.
    full_matrices: If True, `u` and `vh` have the shapes `m x m` and `n x n`,
      respectively. If False, the shapes are `m x k` and `k x n`, respectively,
      where `k = min(m, n)`.
    compute_uv: Whether to compute `u` and `vh` in addition to `s`.
    hermitian: True if `a` is Hermitian.
    max_iterations: The predefined maximum number of iterations of QDWH.

  Returns:
    A 3-tuple (`u`, `s`, `vh`), where `u` and `vh` are unitary matrices,
    `s` is a vector of length `k` containing the singular values in
    non-increasing order, and `k = min(m, n)`. The shapes of `u` and `vh`
    depend on the value of `full_matrices`. For `compute_uv=False`,
    only `s` is returned.
  """
    full_matrices = core.concrete_or_error(
        bool, full_matrices, 'The `full_matrices` argument must be statically '
        'specified to use `svd` within JAX transformations.')

    compute_uv = core.concrete_or_error(
        bool, compute_uv, 'The `compute_uv` argument must be statically '
        'specified to use `svd` within JAX transformations.')

    hermitian = core.concrete_or_error(
        bool, hermitian, 'The `hermitian` argument must be statically '
        'specified to use `svd` within JAX transformations.')

    max_iterations = core.concrete_or_error(
        int, max_iterations,
        'The `max_iterations` argument must be statically '
        'specified to use `svd` within JAX transformations.')

    # QDWH algorithm fails at zero-matrix `A` and produces all NaNs, which can
    # be seen from a dynamically weighted Halley (DWH) iteration:
    # X_{k+1} = X_k(a_k I + b_k {X_k}^H X_k)(I + c_k {X_k}^H X_k)^{-1} and
    # X_0 = A/alpha, where alpha = ||A||_2, the triplet (a_k, b_k, c_k) are
    # weighting parameters, and X_k denotes the k^{th} iterate.
    return jax.lax.cond(jnp.all(a == 0),
                        functools.partial(_zero_svd,
                                          full_matrices=full_matrices,
                                          compute_uv=compute_uv),
                        functools.partial(_qdwh_svd,
                                          full_matrices=full_matrices,
                                          compute_uv=compute_uv,
                                          hermitian=hermitian,
                                          max_iterations=max_iterations),
                        operand=a)
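
The pattern this enforces, sketched against the public jnp.linalg.svd (an analogous API, not this internal function): flags that shape the computation must be Python values at trace time.

import jax
import jax.numpy as jnp

# Close over the flag (or use static_argnums) so it stays concrete under jit.
thin_svd = jax.jit(lambda a: jnp.linalg.svd(a, full_matrices=False))
u, s, vh = thin_svd(jnp.eye(3))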
Example #5
def union1d(ar1, ar2, *, size=None, fill_value=None):
    _check_arraylike("union1d", ar1, ar2)
    if size is None:
        ar1 = core.concrete_or_error(None, ar1, "The error arose in union1d()")
        ar2 = core.concrete_or_error(None, ar2, "The error arose in union1d()")
    else:
        size = core.concrete_or_error(operator.index, size,
                                      "The error arose in union1d()")
    return unique(concatenate((ar1, ar2), axis=None),
                  size=size,
                  fill_value=fill_value)
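
A usage sketch (added): a static `size` makes the union jittable, padded with `fill_value`.

import jax.numpy as jnp

print(jnp.union1d(jnp.array([1, 2, 3]), jnp.array([2, 4])))
# [1 2 3 4]
print(jnp.union1d(jnp.array([1, 2, 3]), jnp.array([2, 4]), size=6, fill_value=0))
# [1 2 3 4 0 0]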
Example #6
def lpmn_values(m: int, n: int, z: jnp.ndarray,
                is_normalized: bool) -> jnp.ndarray:
    r"""The associated Legendre functions (ALFs) of the first kind.

  Unlike `lpmn`, this function only computes the values of ALFs.
  The ALFs of the first kind can be used in spherical harmonics. The
  spherical harmonic of degree `l` and order `m` can be written as
  :math:`Y_l^m(\theta, \phi) = N_l^m * P_l^m(\cos \theta) * \exp(i m \phi)`,
  where :math:`N_l^m` is the normalization factor and θ and φ are the
  colatitude and longitude, respectively. :math:`N_l^m` is chosen such
  that the spherical harmonics form a set of orthonormal basis functions
  of :math:`L^2(S^2)`. Normalizing :math:`P_l^m` avoids overflow/underflow
  and achieves better numerical stability.

  Args:
    m: The maximum order of the associated Legendre functions.
    n: The maximum degree of the associated Legendre function, often called
      `l` in describing ALFs. Both the degrees and orders are
      `[0, 1, 2, ..., l_max]`, where `l_max` denotes the maximum degree.
    z: A vector of type `float32` or `float64` containing the sampling
      points at which the ALFs are computed.
    is_normalized: True if the associated Legendre functions are normalized.
      With normalization, :math:`N_l^m` is applied such that the spherical
      harmonics form a set of orthonormal basis functions of :math:`L^2(S^2)`.

  Returns:
    A 3D array of shape `(l_max + 1, l_max + 1, len(z))` containing
    the values of the associated Legendre functions of the first kind. The
    return type matches the type of `z`.

  Raises:
    TypeError if elements of array `z` are not in (float32, float64).
    ValueError if array `z` is not 1D.
    NotImplementedError if `m!=n`.
  """
    dtype = lax.dtype(z)
    if dtype not in (jnp.float32, jnp.float64):
        raise TypeError(
            'z.dtype={} is not supported, see docstring for supported types.'.
            format(dtype))

    if z.ndim != 1:
        raise ValueError('z must be a 1D array.')

    m = core.concrete_or_error(int, m, 'Argument m of lpmn.')
    n = core.concrete_or_error(int, n, 'Argument n of lpmn.')

    if m != n:
        raise NotImplementedError(
            'Computations for m!=n are not yet supported.')

    l_max = n

    return _gen_associated_legendre(l_max, z, is_normalized)
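
A usage sketch (added), assuming the public wrapper jax.scipy.special.lpmn_values:

import jax.numpy as jnp
from jax.scipy.special import lpmn_values

z = jnp.linspace(0.0, 0.9, 4)
p = lpmn_values(2, 2, z, is_normalized=True)  # m == n is required
print(p.shape)  # (3, 3, 4), i.e. (l_max + 1, l_max + 1, len(z))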
Example #7
def _make_1d_grid_from_slice(s: slice, op_name: str):
    start = core.concrete_or_error(None, s.start,
                                   f"slice start of jnp.{op_name}") or 0
    stop = core.concrete_or_error(None, s.stop, f"slice stop of jnp.{op_name}")
    step = core.concrete_or_error(None, s.step,
                                  f"slice step of jnp.{op_name}") or 1
    if np.iscomplex(step):
        newobj = linspace(start, stop, int(abs(step)))
    else:
        newobj = arange(start, stop, step)

    return newobj
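
This helper follows the NumPy index-trick convention used by constructs such as jnp.r_ (an assumption about the surrounding module): a real step produces arange semantics, while an imaginary step requests that many linspace points.

import jax.numpy as jnp

print(jnp.r_[0:5:1])   # [0 1 2 3 4]               (arange)
print(jnp.r_[0:1:5j])  # [0.   0.25 0.5  0.75 1. ] (linspace with 5 points)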
Example #8
def svd(a: jnp.ndarray,
        is_hermitian: bool = False,
        max_iterations: int = 10) -> Sequence[jnp.ndarray]:
    """Singular value decomposition.

  Args:
    a: A matrix of shape `m x n`.
    is_hermitian: True if `a` is Hermitian.
    max_iterations: The predefined maximum number of iterations of QDWH.

  Returns:
    A 3-tuple (`u`, `s`, `vh`), where `u` is a unitary matrix of shape `m x k`,
    `s` is a vector of length `k` containing the singular values in descending
    order, `vh` is a unitary matrix of shape `k x n`, `k = min(m, n)`, and
    `a = (u * s) @ vh`.
  """

    is_hermitian = core.concrete_or_error(
        bool, is_hermitian, 'The `is_hermitian` argument must be statically '
        'specified to use `qdwh` within JAX transformations.')

    max_iterations = core.concrete_or_error(
        int, max_iterations,
        'The `max_iterations` argument must be statically '
        'specified to use `qdwh` within JAX transformations.')

    m, n = a.shape

    is_flip = False
    if m < n:
        a = a.T.conj()
        m, n = a.shape
        is_flip = True

    reduce_to_square = False
    if m > 1.15 * n:
        m = n
        q, a = lax.linalg.qr(a, full_matrices=False)
        reduce_to_square = True

    u_out, s_out, v_out = _svd(a, is_hermitian, max_iterations)

    if reduce_to_square:
        u_out = q @ u_out

    if is_flip:
        return (v_out, s_out, u_out.T.conj())

    return (u_out, s_out, v_out.T.conj())
Example #9
def one_hot(x: Array,
            num_classes: int,
            *,
            dtype: Any = jnp.float_,
            axis: Union[int, AxisName] = -1) -> Array:
    """One-hot encodes the given indicies.

  Each index in the input ``x`` is encoded as a vector of zeros of length
  ``num_classes`` with the element at ``index`` set to one::

    >>> jax.nn.one_hot(jnp.array([0, 1, 2]), 3)
    DeviceArray([[1., 0., 0.],
                 [0., 1., 0.],
                  [0., 0., 1.]], dtype=float32)

  Indices outside the range [0, num_classes) will be encoded as zeros::

    >>> jax.nn.one_hot(jnp.array([-1, 3]), 3)
    DeviceArray([[0., 0., 0.],
                 [0., 0., 0.]], dtype=float32)

  Args:
    x: A tensor of indices.
    num_classes: Number of classes in the one-hot dimension.
    dtype: optional, a float dtype for the returned values (default :obj:`jnp.float_`).
    axis: the axis or axes along which the function should be
      computed.
  """
    num_classes = core.concrete_or_error(
        int, num_classes,
        "The error arose in jax.nn.one_hot argument `num_classes`.")
    return _one_hot(x, num_classes, dtype=dtype, axis=axis)
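
Since `num_classes` determines the output shape, mark it static when jitting (an added sketch):

import jax
import jax.numpy as jnp

one_hot_jit = jax.jit(jax.nn.one_hot, static_argnums=1)
print(one_hot_jit(jnp.array([0, 2]), 3))
# [[1. 0. 0.]
#  [0. 0. 1.]]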
Example #10
def one_hot(x, num_classes, *, dtype=jnp.float64):
  """One-hot encodes the given indicies.

  Each index in the input ``x`` is encoded as a vector of zeros of length
  ``num_classes`` with the element at ``index`` set to one::

  >>> jax.nn.one_hot(jnp.array([0, 1, 2]), 3)
  DeviceArray([[1., 0., 0.],
               [0., 1., 0.],
               [0., 0., 1.]], dtype=float32)

  Indices outside the range [0, num_classes) will be encoded as zeros::

  >>> jax.nn.one_hot(jnp.array([-1, 3]), 3)
  DeviceArray([[0., 0., 0.],
               [0., 0., 0.]], dtype=float32)

  Args:
    x: A tensor of indices.
    num_classes: Number of classes in the one-hot dimension.
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).
  """
  num_classes = core.concrete_or_error(
      int, num_classes,
      "The error arose in jax.nn.one_hot argument `num_classes`.")
  dtype = dtypes.canonicalize_dtype(dtype)
  x = jnp.asarray(x)
  lhs = x[..., jnp.newaxis]
  rhs = lax.broadcast_to_rank(jnp.arange(num_classes, dtype=x.dtype), lhs.ndim)
  return jnp.array(lhs == rhs, dtype=dtype)
Example #11
def setxor1d(ar1, ar2, assume_unique=False):
    _check_arraylike("setxor1d", ar1, ar2)
    ar1 = core.concrete_or_error(None, ar1, "The error arose in setxor1d()")
    ar2 = core.concrete_or_error(None, ar2, "The error arose in setxor1d()")

    ar1 = ravel(ar1)
    ar2 = ravel(ar2)

    if not assume_unique:
        ar1 = unique(ar1)
        ar2 = unique(ar2)

    aux = concatenate((ar1, ar2))
    if aux.size == 0:
        return aux

    aux = sort(aux)
    flag = concatenate((array([True]), aux[1:] != aux[:-1], array([True])))
    return aux[flag[1:] & flag[:-1]]
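
A usage sketch (added): the symmetric difference keeps elements present in exactly one of the inputs; both must be concrete.

import jax.numpy as jnp

print(jnp.setxor1d(jnp.array([1, 2, 3, 2]), jnp.array([2, 3, 5])))  # [1 5]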
Example #12
def lpmn(m: int, n: int, z: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """The associated Legendre functions (ALFs) of the first kind.

  Args:
    m: The maximum order of the associated Legendre functions.
    n: The maximum degree of the associated Legendre function, often called
      `l` in describing ALFs. Both the degrees and orders are
      `[0, 1, 2, ..., l_max]`, where `l_max` denotes the maximum degree.
    z: A vector of type `float32` or `float64` containing the sampling
      points at which the ALFs are computed.

  Returns:
    A 2-tuple of 3D arrays of shape `(l_max + 1, l_max + 1, len(z))` containing
    the values and derivatives of the associated Legendre functions of the
    first kind. The return type matches the type of `z`.

  Raises:
    TypeError if elements of array `z` are not in (float32, float64).
    ValueError if array `z` is not 1D.
    NotImplementedError if `m!=n`.
  """
    dtype = lax.dtype(z)
    if dtype not in (jnp.float32, jnp.float64):
        raise TypeError(
            'z.dtype={} is not supported, see docstring for supported types.'.
            format(dtype))

    if z.ndim != 1:
        raise ValueError('z must be a 1D array.')

    m = core.concrete_or_error(int, m, 'Argument m of lpmn.')
    n = core.concrete_or_error(int, n, 'Argument n of lpmn.')

    if m != n:
        raise NotImplementedError(
            'Computations for m!=n are not yet supported.')

    l_max = n
    is_normalized = False
    p_vals = _gen_associated_legendre(l_max, z, is_normalized)
    p_derivatives = _gen_derivatives(p_vals, z, is_normalized)

    return (p_vals, p_derivatives)
Example #13
def polyder(p, m=1):
  _check_arraylike("polyder", p)
  m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyder")
  p, = _promote_dtypes_inexact(p)
  if m < 0:
    raise ValueError("Order of derivative must be positive")
  if m == 0:
    return p
  coeff = (arange(len(p), m, -1)[np.newaxis, :] - 1 - arange(m)[:, np.newaxis]).prod(0)
  return p[:-m] * coeff
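
A usage sketch (added): coefficients are ordered highest power first, as in NumPy.

import jax.numpy as jnp

# d/dx (3x^2 + 2x + 1) = 6x + 2; the second derivative is the constant 6.
print(jnp.polyder(jnp.array([3., 2., 1.])))       # [6. 2.]
print(jnp.polyder(jnp.array([3., 2., 1.]), m=2))  # [6.]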
Example #14
def _segment_update(name: str,
                    data: Array,
                    segment_ids: Array,
                    scatter_op: Callable,
                    num_segments: Optional[int] = None,
                    indices_are_sorted: bool = False,
                    unique_indices: bool = False,
                    bucket_size: Optional[int] = None,
                    reducer: Optional[Callable] = None,
                    mode: Optional[lax.GatherScatterMode] = None) -> Array:
    jnp._check_arraylike(name, data, segment_ids)
    mode = lax.GatherScatterMode.FILL_OR_DROP if mode is None else mode
    data = jnp.asarray(data)
    segment_ids = jnp.asarray(segment_ids)
    dtype = data.dtype
    if num_segments is None:
        num_segments = jnp.max(segment_ids) + 1
    num_segments = core.concrete_or_error(
        int, num_segments, "segment_sum() `num_segments` argument.")
    if num_segments < 0:
        raise ValueError("num_segments must be non-negative.")

    num_buckets = 1 if bucket_size is None \
                    else util.ceil_of_ratio(segment_ids.size, bucket_size)
    if num_buckets == 1:
        out = jnp.full((num_segments, ) + data.shape[1:],
                       _get_identity(scatter_op, dtype),
                       dtype=dtype)
        return _scatter_update(out,
                               segment_ids,
                               data,
                               scatter_op,
                               indices_are_sorted,
                               unique_indices,
                               normalize_indices=False,
                               mode=mode)

    # Bucketize indices and perform segment_update on each bucket to improve
    # numerical stability for operations like product and sum.
    assert reducer is not None
    out = jnp.full((num_buckets, num_segments) + data.shape[1:],
                   _get_identity(scatter_op, dtype),
                   dtype=dtype)
    out = _scatter_update(
        out,
        np.index_exp[lax.div(jnp.arange(segment_ids.shape[0]), bucket_size),
                     segment_ids[None, :]],
        data,
        scatter_op,
        indices_are_sorted,
        unique_indices,
        normalize_indices=False,
        mode=mode)
    return reducer(out, axis=0).astype(dtype)
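
Public entry points such as jax.ops.segment_sum are built on this helper; a usage sketch (added): a concrete num_segments fixes the output shape, so the reduction jits cleanly.

import jax
import jax.numpy as jnp

data = jnp.array([1., 2., 3., 4.])
segment_ids = jnp.array([0, 0, 1, 1])
seg_sum = jax.jit(lambda d, s: jax.ops.segment_sum(d, s, num_segments=2))
print(seg_sum(data, segment_ids))  # [3. 7.]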
Example #15
def qdwh(x, is_symmetric, max_iterations=10):
    """QR-based dynamically weighted Halley iteration for polar decomposition.

  Args:
    x: A full-rank matrix of shape `m x n` with `m >= n`.
    is_symmetric: True if `x` is symmetric.
    max_iterations: The predefined maximum number of iterations.

  Returns:
    A four-tuple of (u, h, num_iters, is_converged) containing the
    polar decomposition of `x = u * h`, the number of iterations to compute `u`,
    and `is_converged`, whose value is `True` when the convergence is achieved
    within the maximum number of iterations.
  """
    m, n = x.shape

    if m < n:
        raise ValueError('The input matrix of shape m x n must have m >= n.')

    max_iterations = core.concrete_or_error(
        int, max_iterations,
        'The `max_iterations` argument must be statically '
        'specified to use `qdwh` within JAX transformations.')

    is_symmetric = core.concrete_or_error(
        bool, is_symmetric, 'The `is_symmetric` argument must be statically '
        'specified to use `qdwh` within JAX transformations.')

    if is_symmetric:
        eps = jnp.finfo(x.dtype).eps
        tol = 50.0 * eps
        relative_diff = jnp.linalg.norm(x - x.T.conj()) / jnp.linalg.norm(x)
        if relative_diff > tol:
            raise ValueError(
                'The input `x` is NOT symmetric because '
                '`norm(x-x.H) / norm(x)` is {}, which is greater than '
                'the tolerance {}.'.format(relative_diff, tol))

    with jax.default_matmul_precision('float32'):
        u, h, num_iters, is_converged = _qdwh(x, is_symmetric, max_iterations)

    return u, h, num_iters, is_converged
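
In recent JAX versions the polar decomposition is exposed as jax.scipy.linalg.polar, which can dispatch to a QDWH routine like this one (an assumption about your JAX version; sketch added):

import jax.numpy as jnp
from jax.scipy.linalg import polar

a = jnp.array([[2., 1.], [0., 1.]])
u, h = polar(a, method='qdwh')  # a ~= u @ h, u unitary, h positive semidefinite
print(jnp.allclose(u @ h, a, atol=1e-5))  # True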
Example #16
def multigammaln(a, d):
  d = core.concrete_or_error(int, d, "d argument of multigammaln")
  a, d = _promote_args_inexact("multigammaln", a, d)

  constant = lax.mul(lax.mul(lax.mul(_constant_like(a, 0.25), d),
                             lax.sub(d, _constant_like(a, 1))),
                     lax.log(_constant_like(a, np.pi)))
  res = jnp.sum(gammaln(jnp.expand_dims(a, axis=-1) -
                        lax.div(jnp.arange(d), _constant_like(a, 2))),
                axis=-1)
  return res + constant
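
A sanity check (added): for d=1 the constant term vanishes and the sum reduces to a single gammaln term, so multigammaln(a, 1) equals gammaln(a).

import jax.numpy as jnp
from jax.scipy.special import gammaln, multigammaln

a = jnp.array([2.5, 4.0])
print(jnp.allclose(multigammaln(a, 1), gammaln(a)))  # True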
Example #17
def _ensure_optional_axes(x):
    def force(x):
        if x is None:
            return None
        try:
            return operator.index(x)
        except TypeError:
            return tuple(i if isinstance(i, str) else operator.index(i)
                         for i in x)

    return core.concrete_or_error(
        force, x, "The axis argument must be known statically.")
Example #18
def multigammaln(a, d):
    d = core.concrete_or_error(int, d, "d argument of multigammaln")
    a, d_ = _promote_args_inexact("multigammaln", a, d)

    constant = lax.mul(
        lax.mul(lax.mul(_lax_const(a, 0.25), d_),
                lax.sub(d_, _lax_const(a, 1))), lax.log(_lax_const(a, np.pi)))
    b = lax.div(jnp.arange(d, dtype=d_.dtype), _lax_const(a, 2))
    res = jnp.sum(gammaln(
        jnp.expand_dims(a, axis=-1) -
        jnp.expand_dims(b, axis=tuple(range(a.ndim)))),
                  axis=-1)
    return res + constant
Example #19
def one_hot(x: Array,
            num_classes: int,
            *,
            dtype: Any = jnp.float64,
            axis: Union[int, AxisName] = -1) -> Array:
    """One-hot encodes the given indicies.

  Each index in the input ``x`` is encoded as a vector of zeros of length
  ``num_classes`` with the element at ``index`` set to one::

    >>> jax.nn.one_hot(jnp.array([0, 1, 2]), 3)
    DeviceArray([[1., 0., 0.],
                 [0., 1., 0.],
                  [0., 0., 1.]], dtype=float32)

  Indices outside the range [0, num_classes) will be encoded as zeros::

    >>> jax.nn.one_hot(jnp.array([-1, 3]), 3)
    DeviceArray([[0., 0., 0.],
                 [0., 0., 0.]], dtype=float32)

  Args:
    x: A tensor of indices.
    num_classes: Number of classes in the one-hot dimension.
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).
    axis: the axis or axes along which the function should be
      computed.
  """
    num_classes = core.concrete_or_error(
        int, num_classes,
        "The error arose in jax.nn.one_hot argument `num_classes`.")
    dtype = dtypes.canonicalize_dtype(dtype)
    x = jnp.asarray(x)
    try:
        output_pos_axis = util.canonicalize_axis(axis, x.ndim + 1)
    except TypeError:
        axis_size = lax.psum(1, axis)
        if num_classes != axis_size:
            raise ValueError(
                f"Expected num_classes to match the size of axis {axis}, "
                f"but {num_classes} != {axis_size}") from None
        axis_idx = lax.axis_index(axis)
        return jnp.asarray(x == axis_idx, dtype=dtype)
    axis = operator.index(axis)
    lhs = lax.expand_dims(x, (axis, ))
    rhs_shape = [1] * x.ndim
    rhs_shape.insert(output_pos_axis, num_classes)
    rhs = lax.broadcast_in_dim(jnp.arange(num_classes, dtype=x.dtype),
                               rhs_shape, (output_pos_axis, ))
    return jnp.asarray(lhs == rhs, dtype=dtype)
Example #20
def _coo_fromdense(mat, *, nse, index_dtype=jnp.int32):
  """Create COO-format sparse matrix from a dense matrix.

  Args:
    mat : array to be converted to COO.
    nse : number of specified entries in ``mat``
    index_dtype : dtype of sparse indices

  Returns:
    data : array of shape ``(nse,)`` and dtype ``mat.dtype``
    row : array of shape ``(nse,)`` and dtype ``index_dtype``
    col : array of shape ``(nse,)`` and dtype ``index_dtype``
  """
  mat = jnp.asarray(mat)
  nse = core.concrete_or_error(operator.index, nse, "nse argument of coo_fromdense()")
  return coo_fromdense_p.bind(mat, nse=nse, index_dtype=index_dtype)
Example #21
def csr_fromdense(mat, *, nnz, index_dtype=np.int32):
  """Create CSR-format sparse matrix from a dense matrix.

  Args:
    mat : array to be converted to CSR.
    nnz : number of nonzero entries in ``mat``
    index_dtype : dtype of sparse indices

  Returns:
    data : array of shape ``(nnz,)`` and dtype ``mat.dtype``.
    indices : array of shape ``(nnz,)`` and dtype ``index_dtype``
    indptr : array of shape ``(mat.shape[0] + 1,)`` and dtype ``index_dtype``
  """
  mat = jnp.asarray(mat)
  nnz = core.concrete_or_error(operator.index, nnz, "nnz argument of csr_fromdense()")
  return csr_fromdense_p.bind(mat, nnz=nnz, index_dtype=np.dtype(index_dtype))
Example #22
def sph_harm(m: jnp.ndarray,
             n: jnp.ndarray,
             theta: jnp.ndarray,
             phi: jnp.ndarray,
             n_max: Optional[int] = None) -> jnp.ndarray:
    r"""Computes the spherical harmonics.

  The JAX version has one extra argument `n_max`, the maximum value in `n`.

  The spherical harmonic of degree `n` and order `m` can be written as
  :math:`Y_n^m(\theta, \phi) = N_n^m * P_n^m(\cos \phi) * \exp(i m \theta)`,
  where :math:`N_n^m = \sqrt{\frac{\left(2n+1\right) \left(n-m\right)!}
  {4 \pi \left(n+m\right)!}}` is the normalization factor and :math:`\phi` and
  :math:`\theta` are the colatitude and longitude, respectively. :math:`N_n^m` is
  chosen such that the spherical harmonics form a set of orthonormal basis
  functions of :math:`L^2(S^2)`.

  Args:
    m: The order of the harmonic; must have `|m| <= n`. Return values for
      `|m| > n` are undefined.
    n: The degree of the harmonic; must have `n >= 0`. The standard notation for
      degree in descriptions of spherical harmonics is `l (lower case L)`. We
      use `n` here to be consistent with `scipy.special.sph_harm`. Return
      values for `n < 0` are undefined.
    theta: The azimuthal (longitudinal) coordinate; must be in [0, 2*pi].
    phi: The polar (colatitudinal) coordinate; must be in [0, pi].
    n_max: The maximum degree `max(n)`. If the supplied `n_max` is not the true
      maximum value of `n`, the results are clipped to `n_max`. For example,
      `sph_harm(m=jnp.array([2]), n=jnp.array([10]), theta, phi, n_max=6)`
      actually returns
      `sph_harm(m=jnp.array([2]), n=jnp.array([6]), theta, phi, n_max=6)`.

  Returns:
    A 1D array containing the spherical harmonics at (m, n, theta, phi).
  """

    if jnp.isscalar(phi):
        phi = jnp.array([phi])

    if n_max is None:
        n_max = jnp.max(n)
    n_max = core.concrete_or_error(
        int, n_max,
        'The `n_max` argument of `jnp.scipy.special.sph_harm` must '
        'be statically specified to use `sph_harm` within JAX transformations.'
    )

    return _sph_harm(m, n, theta, phi, n_max)
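
A usage sketch (added), assuming the jax.scipy.special.sph_harm wrapper: pinning `n_max` statically is what makes the call jittable.

import jax
import jax.numpy as jnp
from jax.scipy.special import sph_harm

m, n = jnp.array([0]), jnp.array([1])
theta, phi = jnp.array([0.3]), jnp.array([1.2])
harm = jax.jit(lambda m, n, t, p: sph_harm(m, n, t, p, n_max=1))
print(harm(m, n, theta, phi))  # a length-1 complex array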
Example #23
def coo_fromdense(mat, *, nse=None, index_dtype=jnp.int32):
  """Create a COO-format sparse matrix from a dense matrix.

  Args:
    mat : array to be converted to COO.
    nse : number of specified entries in ``mat``. If not specified,
      it will be computed from the input matrix.
    index_dtype : dtype of sparse indices

  Returns:
    mat_coo : COO representation of the matrix.
  """
  if nse is None:
    nse = (mat != 0).sum()
  nse = core.concrete_or_error(operator.index, nse, "coo_fromdense nse argument")
  return COO(_coo_fromdense(mat, nse=nse, index_dtype=index_dtype),
             shape=mat.shape, rows_sorted=True)
Example #24
def _segment_update(name: str,
                    data: Array,
                    segment_ids: Array,
                    scatter_op: Callable,
                    num_segments: Optional[int] = None,
                    indices_are_sorted: bool = False,
                    unique_indices: bool = False,
                    bucket_size: Optional[int] = None,
                    reducer: Optional[Callable] = None) -> Array:
    jnp._check_arraylike(name, data, segment_ids)
    data = jnp.asarray(data)
    segment_ids = jnp.asarray(segment_ids)
    dtype = data.dtype
    if num_segments is None:
        num_segments = jnp.max(segment_ids) + 1
    num_segments = core.concrete_or_error(
        int, num_segments, "segment_sum() `num_segments` argument.")
    if num_segments < 0:
        raise ValueError("num_segments must be non-negative.")

    out = jnp.full((num_segments, ) + data.shape[1:],
                   _get_identity(scatter_op, dtype),
                   dtype=dtype)

    num_buckets = 1 if bucket_size is None \
                    else util.ceil_of_ratio(segment_ids.size, bucket_size)
    if num_buckets == 1:
        return _scatter_update(out,
                               segment_ids,
                               data,
                               scatter_op,
                               indices_are_sorted,
                               unique_indices,
                               normalize_indices=False)

    # Bucketize indices and perform segment_update on each bucket to improve
    # numerical stability for operations like product and sum.
    assert reducer is not None
    outs = []
    for sub_data, sub_segment_ids in zip(
            jnp.array_split(data, num_buckets),
            jnp.array_split(segment_ids, num_buckets)):
        outs.append(
            _segment_update(name, sub_data, sub_segment_ids, scatter_op,
                            num_segments, indices_are_sorted, unique_indices))
    return reducer(jnp.stack(outs), axis=0).astype(dtype)
Example #25
def polyint(p, m=1, k=None):
  m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyint")
  k = 0 if k is None else k
  _check_arraylike("polyint", p, k)
  p, k = _promote_dtypes_inexact(p, k)
  if m < 0:
    raise ValueError("Order of integral must be positive (see polyder)")
  k = atleast_1d(k)
  if len(k) == 1:
    k = full((m,), k[0])
  if k.shape != (m,):
    raise ValueError("k must be a scalar or a rank-1 array of length 1 or m.")
  if m == 0:
    return p
  else:
    coeff = maximum(1, arange(len(p) + m, 0, -1)[np.newaxis, :] - 1 - arange(m)[:, np.newaxis]).prod(0)
    return true_divide(concatenate((p, k)), coeff)
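
A usage sketch (added): the integration constant(s) enter through `k`, defaulting to 0.

import jax.numpy as jnp

# The antiderivative of 6x + 2 is 3x^2 + 2x + k.
print(jnp.polyint(jnp.array([6., 2.])))        # [3. 2. 0.]
print(jnp.polyint(jnp.array([6., 2.]), k=1.))  # [3. 2. 1.]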
Example #26
def qdwh(x,
         *,
         is_hermitian=False,
         max_iterations=None,
         eps=None,
         dynamic_shape: Optional[Tuple[int, int]] = None):
    """QR-based dynamically weighted Halley iteration for polar decomposition.

  Args:
    x: A full-rank matrix, with shape `M x N`. The matrix may be
      padded up to that size from a smaller true shape (``dynamic_shape``).
    is_hermitian: True if `x` is Hermitian. Defaults to `False`.
    eps: The final result will satisfy
      ``|x_k - x_k-1| < |x_k| * (4*eps)**(1/3)`` where `x_k` is the iterate.
    max_iterations: Iterations will terminate after this many steps even if the
      above is unsatisfied.
    dynamic_shape: the unpadded shape as an ``(m, n)`` tuple; optional.

  Returns:
    A four-tuple of (u, h, num_iters, is_converged) containing the
    polar decomposition of `x = u * h`, the number of iterations to compute `u`,
    and `is_converged`, whose value is `True` when the convergence is achieved
    within the maximum number of iterations.
  """
    is_hermitian = core.concrete_or_error(
        bool, is_hermitian, 'The `is_hermitian` argument must be statically '
        'specified to use `qdwh` within JAX transformations.')

    if max_iterations is None:
        max_iterations = 10

    M, N = x.shape
    if M < N:
        raise ValueError('The input matrix of shape M x N must have M >= N.')
    if dynamic_shape is not None:
        m, n = dynamic_shape
        x = _mask(x, (m, n))
    else:
        m, n = M, N

    with jax.default_matmul_precision('float32'):
        u, h, num_iters, is_converged = _qdwh(x, m, n, is_hermitian,
                                              max_iterations, eps)

    return u, h, num_iters, is_converged
Example #27
def roots(p, *, strip_zeros=True):
    _check_arraylike("roots", p)
    p = atleast_1d(*_promote_dtypes_inexact(p))
    if p.ndim != 1:
        raise ValueError("Input must be a rank-1 array.")
    if p.size < 2:
        return array([], dtype=dtypes._to_complex_dtype(p.dtype))
    num_leading_zeros = _where(all(p == 0), len(p), argmin(p == 0))

    if strip_zeros:
        num_leading_zeros = core.concrete_or_error(
            int, num_leading_zeros,
            "The error occurred in the jnp.roots() function. To use this within a "
            "JIT-compiled context, pass strip_zeros=False, but be aware that leading zeros "
            "will be result in some returned roots being set to NaN.")
        return _roots_no_zeros(p[num_leading_zeros:])
    else:
        return _roots_with_zeros(p, num_leading_zeros)
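
A usage sketch (added): inside jit, pass strip_zeros=False so the output shape stays static.

import jax
import jax.numpy as jnp

# x^2 - 3x + 2 = (x - 1)(x - 2)
print(jnp.roots(jnp.array([1., -3., 2.])))  # roots 2 and 1 (order may vary)
roots_jit = jax.jit(lambda p: jnp.roots(p, strip_zeros=False))
print(roots_jit(jnp.array([1., -3., 2.])))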
Example #28
def value_and_grad(fun: Callable,
                   argnums: Union[int, Sequence[int]] = 0,
                   **kwargs) -> Callable[..., Tuple[Any, Any]]:
  """Sparse-aware version of :func:`jax.value_and_grad`

  Arguments and return values are the same as :func:`jax.value_and_grad`, but when
  taking the gradient with respect to a BCOO matrix, the matrix indices are ignored.
  """
  # The approach here is to set allow_int=True (so that gradients of indices don't raise an error)
  # and then at the end replace the float0 outputs with the input indices.
  allow_int = kwargs.pop('allow_int', False)
  kwargs['allow_int'] = True
  raw_value_and_grad_fun = jax.value_and_grad(fun, argnums=argnums, **kwargs)
  argnums = core.concrete_or_error(_ensure_index, argnums)

  def maybe_copy_index(arg_in, arg_out):
    if isinstance(arg_in, BCOO) and isinstance(arg_out, BCOO):
      assert arg_in.indices.shape == arg_out.indices.shape
      return BCOO((arg_out.data, arg_in.indices), shape=arg_out.shape)
    else:
      return arg_out

  @wraps(fun, docstr=raw_value_and_grad_fun.__doc__, argnums=argnums)
  @api_boundary
  def value_and_grad_fun(*args, **kwargs):
    if not allow_int:
      dyn_args = [args[i] for i in _ensure_index_tuple(argnums)]
      dyn_args_flat, _ = tree_util.tree_flatten(dyn_args, is_leaf=lambda arg: isinstance(arg, BCOO))
      for arg in dyn_args_flat:
        dtype = np.dtype(arg.dtype)
        if not (np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating)):
          raise TypeError("grad requires real- or complex-valued inputs (input dtype that "
                          "is a sub-dtype of np.floating or np.complexfloating), "
                          f"but got {dtype.name}. If you want to use integer-valued "
                          "inputs, set allow_int to True.")
    value, grad = raw_value_and_grad_fun(*args, **kwargs)
    if isinstance(argnums, int):
      grad = maybe_copy_index(args[argnums], grad)
    else:
      grad = tuple(maybe_copy_index(args[argnum], g) for argnum, g in safe_zip(argnums, grad))
    return value, grad

  return value_and_grad_fun
Example #29
def bcoo_fromdense(mat, *, nse=None, n_batch=0, n_dense=0, index_dtype=jnp.int32):
  """Create COO-format sparse matrix from a dense matrix.

  Args:
    mat : array to be converted to COO, with ``ndim = n_batch + n_sparse + n_dense``.
    nse : number of specified elements in each batch
    n_batch : number of batch dimensions (default: 0)
    n_dense : number of dense dimensions (default: 0)
    index_dtype : dtype of sparse indices (default: int32)

  Returns:
    data : array of shape ``mat.shape[:n_batch] + (nse,) + mat.shape[mat.ndim - n_dense:]``
      and dtype ``mat.dtype``
    indices : array of shape ``mat.shape[:n_batch] + (n_sparse, nse)``
  """
  mat = jnp.asarray(mat)
  if nse is None:
    nse = _bcoo_nse(mat, n_batch, n_dense)
  nse = core.concrete_or_error(operator.index, nse, "nse argument of bcoo_fromdense")
  return bcoo_fromdense_p.bind(mat, nse=nse, n_batch=n_batch, n_dense=n_dense,
                               index_dtype=index_dtype)
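
The public path to this primitive is sparse.BCOO.fromdense; a usage sketch (added): passing a concrete `nse` keeps the sparse representation's shapes static.

import jax.numpy as jnp
from jax.experimental import sparse

M = jnp.array([[0., 1.], [2., 0.]])
mat = sparse.BCOO.fromdense(M, nse=2)
print(mat.data, mat.indices.shape)  # [1. 2.] (2, 2)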
Example #30
def _one_hot(x: Array, num_classes: int, *, dtype: Any,
             axis: Union[int, AxisName]) -> Array:
    num_classes = core.concrete_or_error(
        int, num_classes,
        "The error arose in jax.nn.one_hot argument `num_classes`.")
    dtype = dtypes.canonicalize_dtype(dtype)
    x = jnp.asarray(x)
    try:
        output_pos_axis = util.canonicalize_axis(axis, x.ndim + 1)
    except TypeError:
        axis_size = lax.psum(1, axis)
        if num_classes != axis_size:
            raise ValueError(
                f"Expected num_classes to match the size of axis {axis}, "
                f"but {num_classes} != {axis_size}") from None
        axis_idx = lax.axis_index(axis)
        return jnp.asarray(x == axis_idx, dtype=dtype)
    axis = operator.index(axis)  # type: ignore[arg-type]
    lhs = lax.expand_dims(x, (axis, ))
    rhs_shape = [1] * x.ndim
    rhs_shape.insert(output_pos_axis, num_classes)
    rhs = lax.broadcasted_iota(x.dtype, rhs_shape, output_pos_axis)
    return jnp.asarray(lhs == rhs, dtype=dtype)
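
The iota-comparison trick in isolation (an added sketch): compare each index against a broadcast [0, ..., num_classes-1] along the new axis.

import jax.numpy as jnp
from jax import lax

x = jnp.array([1, 0, 2])
iota = lax.broadcasted_iota(x.dtype, (3, 3), 1)  # each row is [0, 1, 2]
print((x[:, None] == iota).astype(jnp.float32))
# [[0. 1. 0.]
#  [1. 0. 0.]
#  [0. 0. 1.]]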