def unique(ar, return_index=False, return_inverse=False, return_counts=False,
           axis: Optional[int] = None, *, size=None, fill_value=None):
  _check_arraylike("unique", ar)
  if size is None:
    ar = core.concrete_or_error(
        None, ar,
        "The error arose for the first argument of jnp.unique(). " + UNIQUE_SIZE_HINT)
  else:
    size = core.concrete_or_error(
        operator.index, size,
        "The error arose for the size argument of jnp.unique(). " + UNIQUE_SIZE_HINT)
  ar = asarray(ar)
  if axis is None:
    axis = 0
    ar = ar.flatten()
  axis = core.concrete_or_error(operator.index, axis, "axis argument of jnp.unique()")
  return _unique(ar, axis, return_index, return_inverse, return_counts,
                 size=size, fill_value=fill_value)
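# Example (a minimal sketch): because `size` is static, `jnp.unique` becomes
# usable under `jit`; surplus output slots are padded with `fill_value`.
import jax
import jax.numpy as jnp

@jax.jit
def jitted_unique(x):
  return jnp.unique(x, size=4, fill_value=-1)

print(jitted_unique(jnp.array([3, 1, 3, 2])))  # [ 1  2  3 -1]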
def setdiff1d(ar1, ar2, assume_unique=False, *, size=None, fill_value=None):
  _check_arraylike("setdiff1d", ar1, ar2)
  if size is None:
    ar1 = core.concrete_or_error(None, ar1, "The error arose in setdiff1d()")
  else:
    size = core.concrete_or_error(operator.index, size, "The error arose in setdiff1d()")
  ar1 = asarray(ar1)
  fill_value = asarray(0 if fill_value is None else fill_value, dtype=ar1.dtype)
  if ar1.size == 0:
    return full_like(ar1, fill_value, shape=size or 0)
  if not assume_unique:
    ar1 = unique(ar1, size=size and ar1.size)
  mask = in1d(ar1, ar2, invert=True)
  if size is None:
    return ar1[mask]
  else:
    if not (assume_unique or size is None):
      # Set mask to zero at locations corresponding to unique() padding.
      n_unique = ar1.size + 1 - (ar1 == ar1[0]).sum()
      mask = where(arange(ar1.size) < n_unique, mask, False)
    return where(arange(size) < mask.sum(), ar1[where(mask, size=size)],
                 fill_value)
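# Example (hedged sketch): with a static `size`, `jnp.setdiff1d` pads its
# result with `fill_value` (0 by default) so the output shape is fixed under
# `jit`.
import jax
import jax.numpy as jnp

x = jnp.array([1, 2, 3, 4])
y = jnp.array([2, 4])
print(jnp.setdiff1d(x, y))                                      # [1 3]
print(jax.jit(lambda a, b: jnp.setdiff1d(a, b, size=4))(x, y))  # [1 3 0 0]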
def intersect1d(ar1, ar2, assume_unique=False, return_indices=False):
  _check_arraylike("intersect1d", ar1, ar2)
  ar1 = core.concrete_or_error(None, ar1, "The error arose in intersect1d()")
  ar2 = core.concrete_or_error(None, ar2, "The error arose in intersect1d()")

  if not assume_unique:
    if return_indices:
      ar1, ind1 = unique(ar1, return_index=True)
      ar2, ind2 = unique(ar2, return_index=True)
    else:
      ar1 = unique(ar1)
      ar2 = unique(ar2)
  else:
    ar1 = ravel(ar1)
    ar2 = ravel(ar2)

  if return_indices:
    aux, mask, aux_sort_indices = _intersect1d_sorted_mask(ar1, ar2, return_indices)
  else:
    aux, mask = _intersect1d_sorted_mask(ar1, ar2, return_indices)

  int1d = aux[:-1][mask]

  if return_indices:
    ar1_indices = aux_sort_indices[:-1][mask]
    ar2_indices = aux_sort_indices[1:][mask] - ar1.size
    if not assume_unique:
      ar1_indices = ind1[ar1_indices]
      ar2_indices = ind2[ar2_indices]
    return int1d, ar1_indices, ar2_indices
  else:
    return int1d
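# Example (illustrative): `return_indices` yields the positions of the first
# occurrence of each common value in the original inputs. Note that the
# inputs must be concrete, so this cannot run under `jit`.
import jax.numpy as jnp

a = jnp.array([1, 3, 4, 3])
b = jnp.array([3, 1, 2, 1])
vals, ia, ib = jnp.intersect1d(a, b, return_indices=True)
print(vals, ia, ib)  # [1 3] [0 1] [1 0]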
def svd(a: Any, full_matrices: bool, compute_uv: bool = True,
        hermitian: bool = False,
        max_iterations: int = 10) -> Union[Any, Sequence[Any]]:
  """Singular value decomposition.

  Args:
    a: A matrix of shape `m x n`.
    full_matrices: If True, `u` and `vh` have the shapes `m x m` and `n x n`,
      respectively. If False, the shapes are `m x k` and `k x n`, respectively,
      where `k = min(m, n)`.
    compute_uv: Whether to also compute `u` and `vh` in addition to `s`.
    hermitian: True if `a` is Hermitian.
    max_iterations: The predefined maximum number of iterations of QDWH.

  Returns:
    A 3-tuple (`u`, `s`, `vh`), where `u` and `vh` are unitary matrices,
    `s` is a vector of length `k` containing the singular values in
    non-increasing order, and `k = min(m, n)`. The shapes of `u` and `vh`
    depend on the value of `full_matrices`. For `compute_uv=False`, only `s`
    is returned.
  """
  full_matrices = core.concrete_or_error(
      bool, full_matrices, 'The `full_matrices` argument must be statically '
      'specified to use `svd` within JAX transformations.')

  compute_uv = core.concrete_or_error(
      bool, compute_uv, 'The `compute_uv` argument must be statically '
      'specified to use `svd` within JAX transformations.')

  hermitian = core.concrete_or_error(
      bool, hermitian, 'The `hermitian` argument must be statically '
      'specified to use `qdwh` within JAX transformations.')

  max_iterations = core.concrete_or_error(
      int, max_iterations, 'The `max_iterations` argument must be statically '
      'specified to use `qdwh` within JAX transformations.')

  # The QDWH algorithm fails on a zero matrix `A` and produces all NaNs, as
  # can be seen from the dynamically weighted Halley (DWH) iteration:
  # X_{k+1} = X_k (a_k I + b_k X_k^H X_k)(I + c_k X_k^H X_k)^{-1}, with
  # X_0 = A / alpha, where alpha = ||A||_2, the triplet (a_k, b_k, c_k) are
  # weighting parameters, and X_k denotes the k-th iterate.
  return jax.lax.cond(jnp.all(a == 0),
                      functools.partial(_zero_svd, full_matrices=full_matrices,
                                        compute_uv=compute_uv),
                      functools.partial(_qdwh_svd, full_matrices=full_matrices,
                                        compute_uv=compute_uv,
                                        hermitian=hermitian,
                                        max_iterations=max_iterations),
                      operand=a)
def union1d(ar1, ar2, *, size=None, fill_value=None):
  _check_arraylike("union1d", ar1, ar2)
  if size is None:
    ar1 = core.concrete_or_error(None, ar1, "The error arose in union1d()")
    ar2 = core.concrete_or_error(None, ar2, "The error arose in union1d()")
  else:
    size = core.concrete_or_error(operator.index, size, "The error arose in union1d()")
  return unique(concatenate((ar1, ar2), axis=None), size=size, fill_value=fill_value)
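# Example (illustrative): `union1d` is unique() applied to the concatenation,
# so the same static-`size` padding convention applies under `jit`.
import jax
import jax.numpy as jnp

f = jax.jit(lambda a, b: jnp.union1d(a, b, size=4, fill_value=0))
print(f(jnp.array([1, 2]), jnp.array([2, 3])))  # [1 2 3 0] (padded with 0)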
def lpmn_values(m: int, n: int, z: jnp.ndarray, is_normalized: bool) -> jnp.ndarray:
  r"""The associated Legendre functions (ALFs) of the first kind.

  Unlike `lpmn`, this function only computes the values of ALFs.
  The ALFs of the first kind can be used in spherical harmonics. The
  spherical harmonic of degree `l` and order `m` can be written as
  :math:`Y_l^m(\theta, \phi) = N_l^m * P_l^m(\cos \theta) * \exp(i m \phi)`,
  where :math:`N_l^m` is the normalization factor and :math:`\theta` and
  :math:`\phi` are the colatitude and longitude, respectively. :math:`N_l^m`
  is chosen such that the spherical harmonics form a set of orthonormal basis
  functions of :math:`L^2(S^2)`. Normalizing :math:`P_l^m` avoids
  overflow/underflow and achieves better numerical stability.

  Args:
    m: The maximum order of the associated Legendre functions.
    n: The maximum degree of the associated Legendre function, often called
      `l` in describing ALFs. Both the degrees and orders are
      `[0, 1, 2, ..., l_max]`, where `l_max` denotes the maximum degree.
    z: A vector of type `float32` or `float64` containing the sampling
      points at which the ALFs are computed.
    is_normalized: True if the associated Legendre functions are normalized.
      With normalization, :math:`N_l^m` is applied such that the spherical
      harmonics form a set of orthonormal basis functions of :math:`L^2(S^2)`.

  Returns:
    A 3D array of shape `(l_max + 1, l_max + 1, len(z))` containing
    the values of the associated Legendre functions of the first kind. The
    return type matches the type of `z`.

  Raises:
    TypeError if elements of array `z` are not in (float32, float64).
    ValueError if array `z` is not 1D.
    NotImplementedError if `m!=n`.
  """
  dtype = lax.dtype(z)
  if dtype not in (jnp.float32, jnp.float64):
    raise TypeError(
        'z.dtype={} is not supported, see docstring for supported types.'
        .format(dtype))

  if z.ndim != 1:
    raise ValueError('z must be a 1D array.')

  m = core.concrete_or_error(int, m, 'Argument m of lpmn.')
  n = core.concrete_or_error(int, n, 'Argument n of lpmn.')

  if m != n:
    raise NotImplementedError('Computations for m!=n are not yet supported.')

  l_max = n
  return _gen_associated_legendre(l_max, z, is_normalized)
def _make_1d_grid_from_slice(s: slice, op_name: str):
  start = core.concrete_or_error(None, s.start, f"slice start of jnp.{op_name}") or 0
  stop = core.concrete_or_error(None, s.stop, f"slice stop of jnp.{op_name}")
  step = core.concrete_or_error(None, s.step, f"slice step of jnp.{op_name}") or 1
  if np.iscomplex(step):
    newobj = linspace(start, stop, int(abs(step)))
  else:
    newobj = arange(start, stop, step)
  return newobj
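# Example (an assumption: this helper backs index tricks such as jnp.r_, as
# the `jnp.{op_name}` f-strings suggest). A complex `step` selects
# linspace-style semantics with |step| points; a real step selects arange
# semantics.
import jax.numpy as jnp

print(jnp.r_[0:5:2])   # arange(0, 5, 2)   -> [0 2 4]
print(jnp.r_[0:1:5j])  # linspace(0, 1, 5) -> [0.   0.25 0.5  0.75 1.  ]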
def svd(a: jnp.ndarray, is_hermitian: bool = False,
        max_iterations: int = 10) -> Sequence[jnp.ndarray]:
  """Singular value decomposition.

  Args:
    a: A matrix of shape `m x n`.
    is_hermitian: True if `a` is Hermitian.
    max_iterations: The predefined maximum number of iterations of QDWH.

  Returns:
    A 3-tuple (`u`, `s`, `vh`), where `u` is a unitary matrix of shape `m x k`,
    `s` is a vector of length `k` containing the singular values in descending
    order, `vh` is a unitary matrix of shape `k x n`, `k = min(m, n)`, and
    `a = (u * s) @ vh`.
  """
  is_hermitian = core.concrete_or_error(
      bool, is_hermitian, 'The `is_hermitian` argument must be statically '
      'specified to use `qdwh` within JAX transformations.')

  max_iterations = core.concrete_or_error(
      int, max_iterations, 'The `max_iterations` argument must be statically '
      'specified to use `qdwh` within JAX transformations.')

  m, n = a.shape
  is_flip = False
  if m < n:
    a = a.T.conj()
    m, n = a.shape
    is_flip = True

  reduce_to_square = False
  if m > 1.15 * n:
    m = n
    q, a = lax.linalg.qr(a, full_matrices=False)
    reduce_to_square = True

  u_out, s_out, v_out = _svd(a, is_hermitian, max_iterations)

  if reduce_to_square:
    u_out = q @ u_out

  if is_flip:
    return (v_out, s_out, u_out.T.conj())

  return (u_out, s_out, v_out.T.conj())
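# The (u, s, vh) contract above can be checked against any SVD routine; here
# jnp.linalg.svd stands in for the QDWH-based implementation (a hedged
# illustration, not a claim that jnp.linalg.svd dispatches to it).
import jax.numpy as jnp

a = jnp.array([[1., 2.], [3., 4.], [5., 6.]])
u, s, vh = jnp.linalg.svd(a, full_matrices=False)
print(jnp.allclose((u * s) @ vh, a, atol=1e-5))  # True: a == (u * s) @ vh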
def one_hot(x: Array, num_classes: int, *,
            dtype: Any = jnp.float_, axis: Union[int, AxisName] = -1) -> Array:
  """One-hot encodes the given indices.

  Each index in the input ``x`` is encoded as a vector of zeros of length
  ``num_classes`` with the element at ``index`` set to one::

    >>> jax.nn.one_hot(jnp.array([0, 1, 2]), 3)
    DeviceArray([[1., 0., 0.],
                 [0., 1., 0.],
                 [0., 0., 1.]], dtype=float32)

  Indices outside the range [0, num_classes) will be encoded as zeros::

    >>> jax.nn.one_hot(jnp.array([-1, 3]), 3)
    DeviceArray([[0., 0., 0.],
                 [0., 0., 0.]], dtype=float32)

  Args:
    x: A tensor of indices.
    num_classes: Number of classes in the one-hot dimension.
    dtype: optional, a float dtype for the returned values (default
      :obj:`jnp.float_`).
    axis: the axis or axes along which the function should be computed.
  """
  num_classes = core.concrete_or_error(
      int, num_classes,
      "The error arose in jax.nn.one_hot argument `num_classes`.")
  return _one_hot(x, num_classes, dtype=dtype, axis=axis)
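# Example (illustrative): `axis` controls where the one-hot dimension is
# inserted; with axis=0 the classes run along the leading axis.
import jax
import jax.numpy as jnp

print(jax.nn.one_hot(jnp.array([0, 2]), 3, axis=0))
# [[1. 0.]
#  [0. 0.]
#  [0. 1.]]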
def one_hot(x, num_classes, *, dtype=jnp.float64):
  """One-hot encodes the given indices.

  Each index in the input ``x`` is encoded as a vector of zeros of length
  ``num_classes`` with the element at ``index`` set to one::

    >>> jax.nn.one_hot(jnp.array([0, 1, 2]), 3)
    DeviceArray([[1., 0., 0.],
                 [0., 1., 0.],
                 [0., 0., 1.]], dtype=float32)

  Indices outside the range [0, num_classes) will be encoded as zeros::

    >>> jax.nn.one_hot(jnp.array([-1, 3]), 3)
    DeviceArray([[0., 0., 0.],
                 [0., 0., 0.]], dtype=float32)

  Args:
    x: A tensor of indices.
    num_classes: Number of classes in the one-hot dimension.
    dtype: optional, a float dtype for the returned values (default float64
      if jax_enable_x64 is true, otherwise float32).
  """
  num_classes = core.concrete_or_error(
      int, num_classes,
      "The error arose in jax.nn.one_hot argument `num_classes`.")
  dtype = dtypes.canonicalize_dtype(dtype)
  x = jnp.asarray(x)
  lhs = x[..., jnp.newaxis]
  rhs = lax.broadcast_to_rank(jnp.arange(num_classes, dtype=x.dtype), lhs.ndim)
  return jnp.array(lhs == rhs, dtype=dtype)
def setxor1d(ar1, ar2, assume_unique=False):
  _check_arraylike("setxor1d", ar1, ar2)
  ar1 = core.concrete_or_error(None, ar1, "The error arose in setxor1d()")
  ar2 = core.concrete_or_error(None, ar2, "The error arose in setxor1d()")

  ar1 = ravel(ar1)
  ar2 = ravel(ar2)

  if not assume_unique:
    ar1 = unique(ar1)
    ar2 = unique(ar2)

  aux = concatenate((ar1, ar2))
  if aux.size == 0:
    return aux

  aux = sort(aux)
  flag = concatenate((array([True]), aux[1:] != aux[:-1], array([True])))
  return aux[flag[1:] & flag[:-1]]
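# Example (illustrative): the symmetric difference keeps values that appear
# in exactly one of the two inputs.
import jax.numpy as jnp

print(jnp.setxor1d(jnp.array([1, 2, 3]), jnp.array([2, 3, 4])))  # [1 4]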
def lpmn(m: int, n: int, z: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
  """The associated Legendre functions (ALFs) of the first kind.

  Args:
    m: The maximum order of the associated Legendre functions.
    n: The maximum degree of the associated Legendre function, often called
      `l` in describing ALFs. Both the degrees and orders are
      `[0, 1, 2, ..., l_max]`, where `l_max` denotes the maximum degree.
    z: A vector of type `float32` or `float64` containing the sampling
      points at which the ALFs are computed.

  Returns:
    A 2-tuple of 3D arrays of shape `(l_max + 1, l_max + 1, len(z))` containing
    the values and derivatives of the associated Legendre functions of the
    first kind. The return type matches the type of `z`.

  Raises:
    TypeError if elements of array `z` are not in (float32, float64).
    ValueError if array `z` is not 1D.
    NotImplementedError if `m!=n`.
  """
  dtype = lax.dtype(z)
  if dtype not in (jnp.float32, jnp.float64):
    raise TypeError(
        'z.dtype={} is not supported, see docstring for supported types.'
        .format(dtype))

  if z.ndim != 1:
    raise ValueError('z must be a 1D array.')

  m = core.concrete_or_error(int, m, 'Argument m of lpmn.')
  n = core.concrete_or_error(int, n, 'Argument n of lpmn.')

  if m != n:
    raise NotImplementedError('Computations for m!=n are not yet supported.')

  l_max = n
  is_normalized = False
  p_vals = _gen_associated_legendre(l_max, z, is_normalized)
  p_derivatives = _gen_derivatives(p_vals, z, is_normalized)
  return (p_vals, p_derivatives)
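# Example (a minimal sketch, assuming the jax.scipy.special.lpmn entry point):
# for m = n = 2 and two sample points, the values and derivatives are both
# (l_max + 1, l_max + 1, len(z)) = (3, 3, 2) arrays.
import jax.numpy as jnp
from jax.scipy.special import lpmn

z = jnp.array([0.1, 0.5], dtype=jnp.float32)
p, dp = lpmn(2, 2, z)
print(p.shape, dp.shape)  # (3, 3, 2) (3, 3, 2)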
def polyder(p, m=1):
  _check_arraylike("polyder", p)
  m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyder")
  p, = _promote_dtypes_inexact(p)
  if m < 0:
    raise ValueError("Order of derivative must be non-negative")
  if m == 0:
    return p
  coeff = (arange(len(p), m, -1)[np.newaxis, :] - 1 -
           arange(m)[:, np.newaxis]).prod(0)
  return p[:-m] * coeff
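# Example (illustrative): coefficients are highest power first, as in
# numpy.polyder.
import jax.numpy as jnp

p = jnp.array([3., 2., 1.])  # 3x^2 + 2x + 1
print(jnp.polyder(p))        # [6. 2.]  i.e. 6x + 2
print(jnp.polyder(p, m=2))   # [6.]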
def _segment_update(name: str,
                    data: Array,
                    segment_ids: Array,
                    scatter_op: Callable,
                    num_segments: Optional[int] = None,
                    indices_are_sorted: bool = False,
                    unique_indices: bool = False,
                    bucket_size: Optional[int] = None,
                    reducer: Optional[Callable] = None,
                    mode: Optional[lax.GatherScatterMode] = None) -> Array:
  jnp._check_arraylike(name, data, segment_ids)
  mode = lax.GatherScatterMode.FILL_OR_DROP if mode is None else mode
  data = jnp.asarray(data)
  segment_ids = jnp.asarray(segment_ids)
  dtype = data.dtype
  if num_segments is None:
    num_segments = jnp.max(segment_ids) + 1
  num_segments = core.concrete_or_error(
      int, num_segments, "segment_sum() `num_segments` argument.")
  if num_segments is not None and num_segments < 0:
    raise ValueError("num_segments must be non-negative.")

  num_buckets = 1 if bucket_size is None \
                  else util.ceil_of_ratio(segment_ids.size, bucket_size)
  if num_buckets == 1:
    out = jnp.full((num_segments,) + data.shape[1:],
                   _get_identity(scatter_op, dtype), dtype=dtype)
    return _scatter_update(
        out, segment_ids, data, scatter_op, indices_are_sorted,
        unique_indices, normalize_indices=False, mode=mode)

  # Bucketize indices and perform segment_update on each bucket to improve
  # numerical stability for operations like product and sum.
  assert reducer is not None
  out = jnp.full((num_buckets, num_segments) + data.shape[1:],
                 _get_identity(scatter_op, dtype), dtype=dtype)
  out = _scatter_update(
      out, np.index_exp[lax.div(jnp.arange(segment_ids.shape[0]), bucket_size),
                        segment_ids[None, :]],
      data, scatter_op, indices_are_sorted, unique_indices,
      normalize_indices=False, mode=mode)
  return reducer(out, axis=0).astype(dtype)
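# Example (illustrative): `_segment_update` is the shared core behind the
# public segment reductions such as jax.ops.segment_sum; a static
# `num_segments` fixes the output shape so the reduction works under `jit`.
import jax
import jax.numpy as jnp

data = jnp.array([1., 2., 3., 4.])
segment_ids = jnp.array([0, 0, 1, 1])
print(jax.ops.segment_sum(data, segment_ids, num_segments=2))  # [3. 7.]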
def qdwh(x, is_symmetric, max_iterations=10):
  """QR-based dynamically weighted Halley iteration for polar decomposition.

  Args:
    x: A full-rank matrix of shape `m x n` with `m >= n`.
    is_symmetric: True if `x` is symmetric.
    max_iterations: The predefined maximum number of iterations.

  Returns:
    A four-tuple of (u, h, num_iters, is_converged) containing the polar
    decomposition of `x = u * h`, the number of iterations to compute `u`,
    and `is_converged`, whose value is `True` when the convergence is achieved
    within the maximum number of iterations.
  """
  m, n = x.shape

  if m < n:
    raise ValueError('The input matrix of shape m x n must have m >= n.')

  max_iterations = core.concrete_or_error(
      int, max_iterations, 'The `max_iterations` argument must be statically '
      'specified to use `qdwh` within JAX transformations.')

  is_symmetric = core.concrete_or_error(
      bool, is_symmetric, 'The `is_symmetric` argument must be statically '
      'specified to use `qdwh` within JAX transformations.')

  if is_symmetric:
    eps = jnp.finfo(x.dtype).eps
    tol = 50.0 * eps
    relative_diff = jnp.linalg.norm(x - x.T.conj()) / jnp.linalg.norm(x)
    if relative_diff > tol:
      raise ValueError('The input `x` is NOT symmetric because '
                       '`norm(x-x.H) / norm(x)` is {}, which is greater than '
                       'the tolerance {}.'.format(relative_diff, tol))

  with jax.default_matmul_precision('float32'):
    u, h, num_iters, is_converged = _qdwh(x, is_symmetric, max_iterations)

  return u, h, num_iters, is_converged
def multigammaln(a, d):
  d = core.concrete_or_error(int, d, "d argument of multigammaln")
  a, d = _promote_args_inexact("multigammaln", a, d)

  constant = lax.mul(lax.mul(lax.mul(_constant_like(a, 0.25), d),
                             lax.sub(d, _constant_like(a, 1))),
                     lax.log(_constant_like(a, np.pi)))
  res = jnp.sum(gammaln(jnp.expand_dims(a, axis=-1) -
                        lax.div(jnp.arange(d), _constant_like(a, 2))),
                axis=-1)
  return res + constant
def _ensure_optional_axes(x):
  def force(x):
    if x is None:
      return None
    try:
      return operator.index(x)
    except TypeError:
      return tuple(i if isinstance(i, str) else operator.index(i) for i in x)
  return core.concrete_or_error(
      force, x, "The axis argument must be known statically.")
def multigammaln(a, d):
  d = core.concrete_or_error(int, d, "d argument of multigammaln")
  a, d_ = _promote_args_inexact("multigammaln", a, d)

  constant = lax.mul(lax.mul(lax.mul(_lax_const(a, 0.25), d_),
                             lax.sub(d_, _lax_const(a, 1))),
                     lax.log(_lax_const(a, np.pi)))
  b = lax.div(jnp.arange(d, dtype=d_.dtype), _lax_const(a, 2))
  res = jnp.sum(gammaln(jnp.expand_dims(a, axis=-1) -
                        jnp.expand_dims(b, axis=tuple(range(a.ndim)))),
                axis=-1)
  return res + constant
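# Example (a hedged numerical check): for d = 2 the multivariate log-gamma
# reduces to log(pi)/2 + gammaln(a) + gammaln(a - 1/2).
import jax.numpy as jnp
from jax.scipy.special import gammaln, multigammaln

a = 4.0
expected = 0.5 * jnp.log(jnp.pi) + gammaln(a) + gammaln(a - 0.5)
print(jnp.allclose(multigammaln(a, 2), expected))  # True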
def one_hot(x: Array, num_classes: int, *,
            dtype: Any = jnp.float64, axis: Union[int, AxisName] = -1) -> Array:
  """One-hot encodes the given indices.

  Each index in the input ``x`` is encoded as a vector of zeros of length
  ``num_classes`` with the element at ``index`` set to one::

    >>> jax.nn.one_hot(jnp.array([0, 1, 2]), 3)
    DeviceArray([[1., 0., 0.],
                 [0., 1., 0.],
                 [0., 0., 1.]], dtype=float32)

  Indices outside the range [0, num_classes) will be encoded as zeros::

    >>> jax.nn.one_hot(jnp.array([-1, 3]), 3)
    DeviceArray([[0., 0., 0.],
                 [0., 0., 0.]], dtype=float32)

  Args:
    x: A tensor of indices.
    num_classes: Number of classes in the one-hot dimension.
    dtype: optional, a float dtype for the returned values (default float64
      if jax_enable_x64 is true, otherwise float32).
    axis: the axis or axes along which the function should be computed.
  """
  num_classes = core.concrete_or_error(
      int, num_classes,
      "The error arose in jax.nn.one_hot argument `num_classes`.")
  dtype = dtypes.canonicalize_dtype(dtype)
  x = jnp.asarray(x)
  try:
    output_pos_axis = util.canonicalize_axis(axis, x.ndim + 1)
  except TypeError:
    axis_size = lax.psum(1, axis)
    if num_classes != axis_size:
      raise ValueError(f"Expected num_classes to match the size of axis {axis}, "
                       f"but {num_classes} != {axis_size}") from None
    axis_idx = lax.axis_index(axis)
    return jnp.asarray(x == axis_idx, dtype=dtype)
  axis = operator.index(axis)
  lhs = lax.expand_dims(x, (axis,))
  rhs_shape = [1] * x.ndim
  rhs_shape.insert(output_pos_axis, num_classes)
  rhs = lax.broadcast_in_dim(jnp.arange(num_classes, dtype=x.dtype),
                             rhs_shape, (output_pos_axis,))
  return jnp.asarray(lhs == rhs, dtype=dtype)
def _coo_fromdense(mat, *, nse, index_dtype=jnp.int32):
  """Create COO-format sparse matrix from a dense matrix.

  Args:
    mat : array to be converted to COO.
    nse : number of specified entries in ``mat``
    index_dtype : dtype of sparse indices

  Returns:
    data : array of shape ``(nse,)`` and dtype ``mat.dtype``
    row : array of shape ``(nse,)`` and dtype ``index_dtype``
    col : array of shape ``(nse,)`` and dtype ``index_dtype``
  """
  mat = jnp.asarray(mat)
  nse = core.concrete_or_error(operator.index, nse, "nse argument of coo_fromdense()")
  return coo_fromdense_p.bind(mat, nse=nse, index_dtype=index_dtype)
def csr_fromdense(mat, *, nnz, index_dtype=np.int32):
  """Create CSR-format sparse matrix from a dense matrix.

  Args:
    mat : array to be converted to CSR.
    nnz : number of nonzero entries in ``mat``
    index_dtype : dtype of sparse indices

  Returns:
    data : array of shape ``(nnz,)`` and dtype ``mat.dtype``.
    indices : array of shape ``(nnz,)`` and dtype ``index_dtype``
    indptr : array of shape ``(mat.shape[0] + 1,)`` and dtype ``index_dtype``
  """
  mat = jnp.asarray(mat)
  nnz = core.concrete_or_error(operator.index, nnz, "nnz argument of csr_fromdense()")
  return csr_fromdense_p.bind(mat, nnz=nnz, index_dtype=np.dtype(index_dtype))
def sph_harm(m: jnp.ndarray,
             n: jnp.ndarray,
             theta: jnp.ndarray,
             phi: jnp.ndarray,
             n_max: Optional[int] = None) -> jnp.ndarray:
  r"""Computes the spherical harmonics.

  The JAX version has one extra argument `n_max`, the maximum value in `n`.

  The spherical harmonic of degree `n` and order `m` can be written as
  :math:`Y_n^m(\theta, \phi) = N_n^m * P_n^m(\cos \phi) * \exp(i m \theta)`,
  where :math:`N_n^m = \sqrt{\frac{\left(2n+1\right) \left(n-m\right)!}
  {4 \pi \left(n+m\right)!}}` is the normalization factor and :math:`\phi` and
  :math:`\theta` are the colatitude and longitude, respectively. :math:`N_n^m`
  is chosen such that the spherical harmonics form a set of orthonormal basis
  functions of :math:`L^2(S^2)`.

  Args:
    m: The order of the harmonic; must have `|m| <= n`. Return values for
      `|m| > n` are undefined.
    n: The degree of the harmonic; must have `n >= 0`. The standard notation
      for degree in descriptions of spherical harmonics is `l (lower case L)`.
      We use `n` here to be consistent with `scipy.special.sph_harm`. Return
      values for `n < 0` are undefined.
    theta: The azimuthal (longitudinal) coordinate; must be in [0, 2*pi].
    phi: The polar (colatitudinal) coordinate; must be in [0, pi].
    n_max: The maximum degree `max(n)`. If the supplied `n_max` is not the true
      maximum value of `n`, the results are clipped to `n_max`. For example,
      `sph_harm(m=jnp.array([2]), n=jnp.array([10]), theta, phi, n_max=6)`
      actually returns
      `sph_harm(m=jnp.array([2]), n=jnp.array([6]), theta, phi, n_max=6)`.

  Returns:
    A 1D array containing the spherical harmonics at (m, n, theta, phi).
  """
  if jnp.isscalar(phi):
    phi = jnp.array([phi])

  if n_max is None:
    n_max = jnp.max(n)
  n_max = core.concrete_or_error(
      int, n_max, 'The `n_max` argument of `jnp.scipy.special.sph_harm` must '
      'be statically specified to use `sph_harm` within JAX transformations.')

  return _sph_harm(m, n, theta, phi, n_max)
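# Example (illustrative): Y_0^0 is the constant 1 / (2 * sqrt(pi)), regardless
# of the angles.
import jax.numpy as jnp
from jax.scipy.special import sph_harm

m = jnp.array([0])
n = jnp.array([0])
theta = jnp.array([0.3])
phi = jnp.array([0.7])
print(sph_harm(m, n, theta, phi, n_max=0))  # approx [0.28209479+0.j]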
def coo_fromdense(mat, *, nse=None, index_dtype=jnp.int32):
  """Create a COO-format sparse matrix from a dense matrix.

  Args:
    mat : array to be converted to COO.
    nse : number of specified entries in ``mat``. If not specified,
      it will be computed from the input matrix.
    index_dtype : dtype of sparse indices

  Returns:
    mat_coo : COO representation of the matrix.
  """
  if nse is None:
    nse = (mat != 0).sum()
  nse = core.concrete_or_error(operator.index, nse, "coo_fromdense nse argument")
  return COO(_coo_fromdense(mat, nse=nse, index_dtype=index_dtype),
             shape=mat.shape, rows_sorted=True)
def _segment_update(name: str,
                    data: Array,
                    segment_ids: Array,
                    scatter_op: Callable,
                    num_segments: Optional[int] = None,
                    indices_are_sorted: bool = False,
                    unique_indices: bool = False,
                    bucket_size: Optional[int] = None,
                    reducer: Optional[Callable] = None) -> Array:
  jnp._check_arraylike(name, data, segment_ids)
  data = jnp.asarray(data)
  segment_ids = jnp.asarray(segment_ids)
  dtype = data.dtype
  if num_segments is None:
    num_segments = jnp.max(segment_ids) + 1
  num_segments = core.concrete_or_error(
      int, num_segments, "segment_sum() `num_segments` argument.")
  if num_segments is not None and num_segments < 0:
    raise ValueError("num_segments must be non-negative.")

  out = jnp.full((num_segments,) + data.shape[1:],
                 _get_identity(scatter_op, dtype), dtype=dtype)

  num_buckets = 1 if bucket_size is None \
                  else util.ceil_of_ratio(segment_ids.size, bucket_size)
  if num_buckets == 1:
    return _scatter_update(out, segment_ids, data, scatter_op,
                           indices_are_sorted, unique_indices,
                           normalize_indices=False)

  # Bucketize indices and perform segment_update on each bucket to improve
  # numerical stability for operations like product and sum.
  assert reducer is not None
  outs = []
  for sub_data, sub_segment_ids in zip(
      jnp.array_split(data, num_buckets),
      jnp.array_split(segment_ids, num_buckets)):
    outs.append(
        _segment_update(name, sub_data, sub_segment_ids, scatter_op,
                        num_segments, indices_are_sorted, unique_indices))
  return reducer(jnp.stack(outs), axis=0).astype(dtype)
def polyint(p, m=1, k=None):
  m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyint")
  k = 0 if k is None else k
  _check_arraylike("polyint", p, k)
  p, k = _promote_dtypes_inexact(p, k)
  if m < 0:
    raise ValueError("Order of integral must be non-negative (see polyder)")
  k = atleast_1d(k)
  if len(k) == 1:
    k = full((m,), k[0])
  if k.shape != (m,):
    raise ValueError("k must be a scalar or a rank-1 array of length 1 or m.")
  if m == 0:
    return p
  else:
    coeff = maximum(1, arange(len(p) + m, 0, -1)[np.newaxis, :] - 1 -
                    arange(m)[:, np.newaxis]).prod(0)
    return true_divide(concatenate((p, k)), coeff)
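# Example (illustrative): `polyint` inverts `polyder` up to the integration
# constant `k`, which becomes the trailing coefficient.
import jax.numpy as jnp

p = jnp.array([6., 2.])      # 6x + 2
print(jnp.polyint(p))        # [3. 2. 0.]  i.e. 3x^2 + 2x
print(jnp.polyint(p, k=5.))  # [3. 2. 5.]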
def qdwh(x, *, is_hermitian=False, max_iterations=None, eps=None,
         dynamic_shape: Optional[Tuple[int, int]] = None):
  """QR-based dynamically weighted Halley iteration for polar decomposition.

  Args:
    x: A full-rank matrix, with shape `M x N`. The matrix may be padded up to
      that size from a smaller true shape (``dynamic_shape``).
    is_hermitian: True if `x` is Hermitian. Defaults to `False`.
    eps: The final result will satisfy
      ``|x_k - x_k-1| < |x_k| * (4*eps)**(1/3)`` where `x_k` is the iterate.
    max_iterations: Iterations will terminate after this many steps even if
      the above is unsatisfied.
    dynamic_shape: the unpadded shape as an ``(m, n)`` tuple; optional.

  Returns:
    A four-tuple of (u, h, num_iters, is_converged) containing the polar
    decomposition of `x = u * h`, the number of iterations to compute `u`,
    and `is_converged`, whose value is `True` when the convergence is achieved
    within the maximum number of iterations.
  """
  is_hermitian = core.concrete_or_error(
      bool, is_hermitian, 'The `is_hermitian` argument must be statically '
      'specified to use `qdwh` within JAX transformations.')

  if max_iterations is None:
    max_iterations = 10

  M, N = x.shape
  if M < N:
    raise ValueError('The input matrix of shape M x N must have M >= N.')
  if dynamic_shape is not None:
    m, n = dynamic_shape
    x = _mask(x, (m, n))
  else:
    m, n = M, N

  with jax.default_matmul_precision('float32'):
    u, h, num_iters, is_converged = _qdwh(x, m, n, is_hermitian,
                                          max_iterations, eps)

  return u, h, num_iters, is_converged
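# Example (an assumption: jax.scipy.linalg.polar exposes this QDWH iteration
# via method='qdwh'). The polar factors reconstruct the input:
import jax.numpy as jnp
import jax.scipy.linalg as jsl

x = jnp.array([[3., 1.], [1., 2.], [0., 1.]])
u, h = jsl.polar(x, method='qdwh')
print(jnp.allclose(u @ h, x, atol=1e-4))  # True: x == u @ h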
def roots(p, *, strip_zeros=True):
  _check_arraylike("roots", p)
  p = atleast_1d(*_promote_dtypes_inexact(p))
  if p.ndim != 1:
    raise ValueError("Input must be a rank-1 array.")
  if p.size < 2:
    return array([], dtype=dtypes._to_complex_dtype(p.dtype))
  num_leading_zeros = _where(all(p == 0), len(p), argmin(p == 0))

  if strip_zeros:
    num_leading_zeros = core.concrete_or_error(
        int, num_leading_zeros,
        "The error occurred in the jnp.roots() function. To use this within a "
        "JIT-compiled context, pass strip_zeros=False, but be aware that leading "
        "zeros will result in some returned roots being set to NaN.")
    return _roots_no_zeros(p[num_leading_zeros:])
  else:
    return _roots_with_zeros(p, num_leading_zeros)
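# Example (illustrative): with the default strip_zeros=True the input must be
# concrete; strip_zeros=False trades that for NaN placeholders, enabling use
# under `jit`.
import jax
import jax.numpy as jnp

p = jnp.array([1., -3., 2.])  # x^2 - 3x + 2 = (x - 1)(x - 2)
print(jnp.roots(p))           # roots 1 and 2 (order may vary)
print(jax.jit(lambda q: jnp.roots(q, strip_zeros=False))(p))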
def value_and_grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
                   **kwargs) -> Callable[..., Tuple[Any, Any]]:
  """Sparse-aware version of :func:`jax.value_and_grad`

  Arguments and return values are the same as :func:`jax.value_and_grad`, but
  when taking the gradient with respect to a BCOO matrix, the matrix indices
  are ignored.
  """
  # The approach here is to set allow_int=True (so that gradients of indices
  # don't raise an error) and then at the end replace the float0 outputs with
  # the input indices.
  allow_int = kwargs.pop('allow_int', False)
  kwargs['allow_int'] = True
  raw_value_and_grad_fun = jax.value_and_grad(fun, argnums=argnums, **kwargs)
  argnums = core.concrete_or_error(_ensure_index, argnums)

  def maybe_copy_index(arg_in, arg_out):
    if isinstance(arg_in, BCOO) and isinstance(arg_out, BCOO):
      assert arg_in.indices.shape == arg_out.indices.shape
      return BCOO((arg_out.data, arg_in.indices), shape=arg_out.shape)
    else:
      return arg_out

  @wraps(fun, docstr=raw_value_and_grad_fun.__doc__, argnums=argnums)
  @api_boundary
  def value_and_grad_fun(*args, **kwargs):
    if not allow_int:
      dyn_args = [args[i] for i in _ensure_index_tuple(argnums)]
      dyn_args_flat, _ = tree_util.tree_flatten(
          dyn_args, is_leaf=lambda arg: isinstance(arg, BCOO))
      for arg in dyn_args_flat:
        dtype = np.dtype(arg)
        if not (np.issubdtype(dtype, np.floating) or
                np.issubdtype(dtype, np.complexfloating)):
          raise TypeError("grad requires real- or complex-valued inputs (input "
                          "dtype that is a sub-dtype of np.floating or "
                          f"np.complexfloating), but got {dtype.name}. If you "
                          "want to use integer-valued inputs, set allow_int "
                          "to True.")
    value, grad = raw_value_and_grad_fun(*args, **kwargs)
    if isinstance(argnums, int):
      grad = maybe_copy_index(args[argnums], grad)
    else:
      grad = tuple(maybe_copy_index(args[argnum], g)
                   for argnum, g in safe_zip(argnums, grad))
    return value, grad

  return value_and_grad_fun
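# Example (a minimal sketch, assuming jax.experimental.sparse exports
# value_and_grad and todense): the gradient of a BCOO input is itself a BCOO
# that reuses the input's indices.
import jax.numpy as jnp
from jax.experimental import sparse

M = sparse.BCOO.fromdense(jnp.array([[0., 1.], [2., 0.]]))
val, grad = sparse.value_and_grad(lambda m: sparse.todense(m).sum())(M)
print(val, grad.indices)  # 3.0, indices identical to M.indices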
def bcoo_fromdense(mat, *, nse=None, n_batch=0, n_dense=0, index_dtype=jnp.int32):
  """Create BCOO-format sparse matrix from a dense matrix.

  Args:
    mat : array to be converted to COO, with
      ``ndim = n_batch + n_sparse + n_dense``.
    nse : number of specified elements in each batch
    n_batch : number of batch dimensions (default: 0)
    n_dense : number of dense dimensions (default: 0)
    index_dtype : dtype of sparse indices (default: int32)

  Returns:
    data : array of shape
      ``mat.shape[:n_batch] + (nse,) + mat.shape[mat.ndim - n_dense:]``
      and dtype ``mat.dtype``
    indices : array of shape ``mat.shape[:n_batch] + (n_sparse, nse)``
  """
  mat = jnp.asarray(mat)
  if nse is None:
    nse = _bcoo_nse(mat, n_batch, n_dense)
  nse = core.concrete_or_error(operator.index, nse, "nse argument of bcoo_fromdense")
  return bcoo_fromdense_p.bind(mat, nse=nse, n_batch=n_batch, n_dense=n_dense,
                               index_dtype=index_dtype)
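# Example (illustrative): the public BCOO.fromdense wrapper infers `nse` from
# a concrete input; under `jit` the entry count must be supplied statically.
import jax
import jax.numpy as jnp
from jax.experimental import sparse

M = jnp.array([[0., 1., 0.], [2., 0., 3.]])
M_sp = sparse.BCOO.fromdense(M)                                   # nse inferred
M_sp_jit = jax.jit(lambda m: sparse.BCOO.fromdense(m, nse=3))(M)  # static nse
print(M_sp.data)                 # [1. 2. 3.]
print(sparse.todense(M_sp_jit))  # round-trips to M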
def _one_hot(x: Array, num_classes: int, *,
             dtype: Any, axis: Union[int, AxisName]) -> Array:
  num_classes = core.concrete_or_error(
      int, num_classes,
      "The error arose in jax.nn.one_hot argument `num_classes`.")
  dtype = dtypes.canonicalize_dtype(dtype)
  x = jnp.asarray(x)
  try:
    output_pos_axis = util.canonicalize_axis(axis, x.ndim + 1)
  except TypeError:
    axis_size = lax.psum(1, axis)
    if num_classes != axis_size:
      raise ValueError(f"Expected num_classes to match the size of axis {axis}, "
                       f"but {num_classes} != {axis_size}") from None
    axis_idx = lax.axis_index(axis)
    return jnp.asarray(x == axis_idx, dtype=dtype)
  axis = operator.index(axis)  # type: ignore[arg-type]
  lhs = lax.expand_dims(x, (axis,))
  rhs_shape = [1] * x.ndim
  rhs_shape.insert(output_pos_axis, num_classes)
  rhs = lax.broadcasted_iota(x.dtype, rhs_shape, output_pos_axis)
  return jnp.asarray(lhs == rhs, dtype=dtype)