def setdiff1d(ar1, ar2, assume_unique=False, *, size=None, fill_value=None):
  """Return the values of ``ar1`` that do not appear in ``ar2``.

  Args:
    ar1: input array whose values are kept or dropped.
    ar2: values to remove from ``ar1``.
    assume_unique: if true, ``ar1`` is used as-is instead of being passed
      through ``unique()`` first.
    size: optional static output length. When ``None``, ``ar1`` must be
      concrete (not traced). When given, it must be convertible with
      ``operator.index``.
    fill_value: value used for padded entries when ``size`` exceeds the number
      of kept values. Defaults to 0, cast to ``ar1``'s dtype.

  Returns:
    ``ar1[mask]`` when ``size`` is None. Otherwise an array of length
    ``size`` whose trailing unused slots hold ``fill_value``.
  """
  _check_arraylike("setdiff1d", ar1, ar2)
  if size is not None:
    size = core.concrete_or_error(operator.index, size, "The error arose in setdiff1d()")
  else:
    # Without a static size the output shape depends on the data, so ar1
    # must be concrete.
    ar1 = core.concrete_or_error(None, ar1, "The error arose in setdiff1d()")

  ar1 = asarray(ar1)
  fill_value = asarray(fill_value if fill_value is not None else 0, dtype=ar1.dtype)

  if ar1.size == 0:
    return full_like(ar1, fill_value, shape=size or 0)

  if not assume_unique:
    ar1 = unique(ar1, size=size and ar1.size)

  keep = in1d(ar1, ar2, invert=True)
  if size is None:
    return ar1[keep]

  if not assume_unique:
    # unique() was called with a static size, so ar1 may carry padding.
    # This relies on that padding being copies of ar1[0]: the count of
    # entries equal to ar1[0] is 1 + the number of padded slots, which
    # yields the number of genuine unique values. Padded slots are then
    # excluded from the mask.
    n_unique = ar1.size + 1 - (ar1 == ar1[0]).sum()
    keep = where(arange(ar1.size) < n_unique, keep, False)

  in_range = arange(size) < keep.sum()
  return where(in_range, ar1[where(keep, size=size)], fill_value)
def cond(x, p=None):
  """Compute the condition number of a matrix or a stack of matrices.

  Mirrors ``numpy.linalg.cond``.

  Args:
    x: array whose last two axes hold the matrix (or matrices).
    p: order of the norm.
      ``None`` or ``2`` gives the ratio of the largest to the smallest
      singular value. ``-2`` gives the inverse ratio. Any other value is
      forwarded as ``ord`` to ``la.norm`` and evaluated on both ``x`` and
      ``inv(x)``.

  Returns:
    The condition number(s), reduced over the last two axes. A NaN result is
    reported as ``inf`` unless the corresponding input matrix itself contains
    a NaN entry.
  """
  _assertNoEmpty2d(x)
  if p in (None, 2, -2):
    s = la.svd(x, compute_uv=False)
    if p == -2:
      r = s[..., -1] / s[..., 0]
    else:
      r = s[..., 0] / s[..., -1]
  else:
    _assertRankAtLeast2(x)
    _assertNdSquareness(x)
    invx = la.inv(x)
    r = la.norm(x, ord=p, axis=(-2, -1)) * la.norm(invx, ord=p, axis=(-2, -1))

  # Convert NaNs to infs unless the original matrix had NaN entries.
  # A NaN that did not propagate from x comes from a degenerate computation,
  # e.g. 0/0 for an all-zero matrix. The check is per matrix: it must be
  # gated on NaNs in x, never on NaNs in r itself.
  nan_mask = jnp.logical_and(jnp.isnan(r), ~jnp.isnan(x).any(axis=(-2, -1)))
  return jnp.where(nan_mask, jnp.inf, r)
def _nan(args): A, *_ = args return jnp.full_like(A, jnp.nan)
def _unique(ar, axis, return_index=False, return_inverse=False,
            return_counts=False, size=None, fill_value=None,
            return_true_size=False):
  """Find the unique elements of an array along a particular axis.

  Args:
    ar: input array.
    axis: axis along which uniqueness is determined.
    return_index: also return ``perm[ind]``, where ``perm`` is the permutation
      returned by ``_unique_sorted_mask`` (presumably the sort order of ``ar``
      along ``axis`` — confirm against that helper).
    return_inverse: also return, for each position along ``axis``, the index
      of the unique entry it maps to. This is built by scattering
      ``cumsum(mask) - 1`` through ``perm``.
    return_counts: also return the number of occurrences of each unique entry.
    size: optional static output length. When ``None``, the boolean mask must
      be concrete (not traced).
    fill_value: value written to padded slots when ``size`` exceeds the number
      of unique entries. When ``None``, padded slots keep whatever ``aux[ind]``
      yields for the padded indices.
    return_true_size: also return ``mask.sum()``, the actual number of unique
      entries. Intended for internal callers.

  Returns:
    The unique array alone when no extra outputs are requested. Otherwise a
    tuple in the order: unique values, index, inverse, counts, true size
    (only the requested ones).

  Raises:
    ValueError: if ``ar`` has length 0 along ``axis``, ``size`` is nonzero,
      and ``fill_value`` is None.
  """
  n_along_axis = ar.shape[axis]
  if n_along_axis == 0 and size and fill_value is None:
    raise ValueError(
        "jnp.unique: for zero-sized input with nonzero size argument, fill_value must be specified"
    )

  aux, mask, perm = _unique_sorted_mask(ar, axis)
  nonempty = bool(aux.size)

  if size is not None:
    ind = nonzero(mask, size=size)[0]
  else:
    ind = core.concrete_or_error(
        None, mask, "The error arose in jnp.unique(). " + UNIQUE_SIZE_HINT)

  result = aux[ind] if nonempty else aux

  if fill_value is not None:
    fill_value = asarray(fill_value, dtype=result.dtype)
    if size is not None:
      if result.shape[0]:
        # Keep the first mask.sum() rows and overwrite the padded rest.
        in_range = arange(size) < mask.sum()
        valid = lax.expand_dims(in_range, tuple(range(1, result.ndim)))
        result = where(valid, result, fill_value)
      else:
        result = full_like(result, fill_value, shape=(size, *result.shape[1:]))

  result = moveaxis(result, 0, axis)

  outputs = [result]

  if return_index:
    outputs.append(perm[ind] if nonempty else perm)

  if return_inverse:
    if nonempty:
      inv_idx = zeros(mask.shape, dtype=dtypes.canonicalize_dtype(dtypes.int_))
      inv_idx = inv_idx.at[perm].set(cumsum(mask) - 1)
    else:
      inv_idx = zeros(n_along_axis, dtype=int)
    outputs.append(inv_idx)

  if return_counts:
    if nonempty:
      if size is None:
        boundaries = append(nonzero(mask)[0], mask.size)
      else:
        # nonzero pads with 0. Map padded boundaries (after the first) to
        # mask.size so that their differences come out as zero counts.
        boundaries = nonzero(mask, size=size + 1)[0]
        tail = boundaries[1:]
        boundaries = boundaries.at[1:].set(where(tail, tail, mask.size))
      counts = diff(boundaries)
    elif n_along_axis:
      counts = array([n_along_axis], dtype=dtypes.canonicalize_dtype(dtypes.int_))
    else:
      counts = empty(0, dtype=int)
    outputs.append(counts)

  if return_true_size:
    # Useful for internal uses of unique().
    outputs.append(mask.sum())

  return outputs[0] if len(outputs) == 1 else tuple(outputs)