コード例 #1
0
def setdiff1d(ar1, ar2, assume_unique=False, *, size=None, fill_value=None):
    """Return the values of ``ar1`` that do not appear in ``ar2``.

    Args:
      ar1: input array.
      ar2: values to exclude from ``ar1``.
      assume_unique: if True, ``ar1`` is taken to already contain unique
        values and is not passed through ``unique()``.
      size: optional static output length. When given, the result is
        truncated or padded with ``fill_value`` to exactly ``size`` elements.
        When None, the output length depends on the data, so ``ar1`` must be
        concrete (not traced).
      fill_value: padding value used when ``size`` is given; defaults to 0
        and is cast to ``ar1``'s dtype.

    Returns:
      A 1-D array of the surviving values of ``ar1``.

    Raises:
      An error from ``core.concrete_or_error`` when ``size`` is None and
      ``ar1`` is not concrete, or when ``size`` is not a concrete integer.
    """
    _check_arraylike("setdiff1d", ar1, ar2)
    if size is None:
        ar1 = core.concrete_or_error(None, ar1,
                                     "The error arose in setdiff1d()")
    else:
        size = core.concrete_or_error(operator.index, size,
                                      "The error arose in setdiff1d()")
    ar1 = asarray(ar1)
    fill_value = asarray(0 if fill_value is None else fill_value,
                         dtype=ar1.dtype)
    if ar1.size == 0:
        return full_like(ar1, fill_value, shape=size or 0)
    if not assume_unique:
        # With a static size, always ask unique() for the full input length
        # (including when size == 0) so no unique value is dropped and ar1
        # stays non-empty; `ar1[0]` below needs at least one element.
        ar1 = unique(ar1, size=None if size is None else ar1.size)
    mask = in1d(ar1, ar2, invert=True)
    if size is None:
        return ar1[mask]
    if not assume_unique:
        # The n_unique formula assumes unique() pads its output with copies
        # of ar1[0]; only the first n_unique entries are genuine, so clear
        # the mask at the padded positions.
        n_unique = ar1.size + 1 - (ar1 == ar1[0]).sum()
        mask = where(arange(ar1.size) < n_unique, mask, False)
    # Gather the first `size` surviving elements; slots beyond the number of
    # survivors receive fill_value.
    return where(
        arange(size) < mask.sum(), ar1[where(mask, size=size)], fill_value)
コード例 #2
0
ファイル: linalg.py プロジェクト: yashk2810/jax
def cond(x, p=None):
    """Compute the condition number of a matrix or stack of matrices.

    Follows ``numpy.linalg.cond``.

    Args:
      x: array whose trailing two axes hold the matrix (or matrices).
      p: norm order. ``None`` or ``2`` uses the ratio of the largest to the
        smallest singular value, and ``-2`` uses the inverse ratio. Any other
        value is passed as ``ord`` to ``la.norm``, and ``x`` is first
        validated with ``_assertRankAtLeast2`` / ``_assertNdSquareness``.

    Returns:
      The condition number of each matrix in ``x``. A NaN result is replaced
      by ``inf`` when the corresponding input matrix contains no NaN entries.
      For example, the 0/0 from an all-zero matrix becomes ``inf``. A NaN
      that originates from NaN input is left as NaN.
    """
    _assertNoEmpty2d(x)
    if p in (None, 2):
        s = la.svd(x, compute_uv=False)
        r = s[..., 0] / s[..., -1]
    elif p == -2:
        s = la.svd(x, compute_uv=False)
        r = s[..., -1] / s[..., 0]
    else:
        _assertRankAtLeast2(x)
        _assertNdSquareness(x)
        invx = la.inv(x)
        r = la.norm(x, ord=p, axis=(-2, -1)) * la.norm(
            invx, ord=p, axis=(-2, -1))

    # Convert NaNs to infs, but only for matrices that had no NaN entries to
    # begin with. This applies to every `p`, as in NumPy. A single vectorized
    # `where` replaces NumPy's `if nan_mask.any()` shortcut.
    nan_mask = jnp.logical_and(jnp.isnan(r), ~jnp.isnan(x).any(axis=(-2, -1)))
    return jnp.where(nan_mask, jnp.inf, r)
コード例 #3
0
ファイル: linalg.py プロジェクト: xueeinstein/jax
 def _nan(args):
     """Build an all-NaN array matching the first entry of ``args``.

     Only the first element of ``args`` is used; the rest are ignored. The
     result comes from ``jnp.full_like``, so it keeps that element's shape
     and dtype. NOTE(review): a non-floating input would have NaN cast to
     its dtype, so presumably callers pass inexact arrays only.
     """
     first, *_ignored = args
     return jnp.full_like(first, fill_value=jnp.nan)
コード例 #4
0
def _unique(ar,
            axis,
            return_index=False,
            return_inverse=False,
            return_counts=False,
            size=None,
            fill_value=None,
            return_true_size=False):
    """Find the unique elements of an array along a particular axis.

    Args:
      ar: input array.
      axis: axis along which uniqueness is determined.
      return_index: also return ``perm[ind]``. Presumably these are the
        positions in ``ar`` of the selected unique entries, assuming ``perm``
        is the sorting permutation (confirm in ``_unique_sorted_mask``).
      return_inverse: also return, for each position of ``ar`` along
        ``axis``, the index of its unique value in the result.
      return_counts: also return the number of occurrences of each unique
        value. With ``size`` given, counts for padded slots are 0.
      size: if not None, the result has the static length ``size`` along
        ``axis`` instead of a data-dependent length.
      fill_value: value written into result slots beyond the true number of
        unique elements, used only when ``size`` is given.
      return_true_size: also return ``mask.sum()``, the true number of
        unique elements.

    Returns:
      The unique values alone if no extras are requested. Otherwise a tuple
      ``(unique, [index], [inverse], [counts], [true_size])`` holding only
      the requested extras, in that order.

    Raises:
      ValueError: if ``ar`` has length 0 along ``axis``, ``size`` is truthy
        (nonzero), and ``fill_value`` is None.
    """
    if ar.shape[axis] == 0 and size and fill_value is None:
        raise ValueError(
            "jnp.unique: for zero-sized input with nonzero size argument, fill_value must be specified"
        )

    # NOTE(review): presumably `aux` is `ar` sorted with `axis` moved to the
    # front (the result is moved back with moveaxis below), `mask` flags the
    # first element of each run of equal values, and `perm` is the sorting
    # permutation. Confirm in _unique_sorted_mask.
    aux, mask, perm = _unique_sorted_mask(ar, axis)
    if size is None:
        # Without a static size the output length depends on the data, so
        # the mask must be concrete (not traced); it is then used directly
        # as a boolean index.
        ind = core.concrete_or_error(
            None, mask, "The error arose in jnp.unique(). " + UNIQUE_SIZE_HINT)
    else:
        # Fixed-length index list. Presumably padded by jnp.nonzero's default
        # fill (index 0, i.e. aux[0]) when there are fewer than `size`
        # uniques, and truncated when there are more.
        ind = nonzero(mask, size=size)[0]
    # Skip indexing when aux is empty (zero-sized input).
    result = aux[ind] if aux.size else aux
    if fill_value is not None:
        fill_value = asarray(fill_value, dtype=result.dtype)
    if size is not None and fill_value is not None:
        if result.shape[0]:
            # Rows at positions >= the true unique count are padding. Expand
            # the 1-D validity mask so it broadcasts over trailing dims.
            valid = lax.expand_dims(
                arange(size) < mask.sum(), tuple(range(1, result.ndim)))
            result = where(valid, result, fill_value)
        else:
            # Nothing to keep: the whole padded result is fill_value.
            result = full_like(result,
                               fill_value,
                               shape=(size, *result.shape[1:]))
    result = moveaxis(result, 0, axis)

    ret = (result, )
    if return_index:
        if aux.size:
            ret += (perm[ind], )
        else:
            ret += (perm, )
    if return_inverse:
        if aux.size:
            # cumsum(mask) - 1 gives the unique-group id of every sorted
            # position; scattering through `perm` maps it back to the
            # original positions.
            imask = cumsum(mask) - 1
            inv_idx = zeros(mask.shape,
                            dtype=dtypes.canonicalize_dtype(dtypes.int_))
            inv_idx = inv_idx.at[perm].set(imask)
        else:
            inv_idx = zeros(ar.shape[axis], dtype=int)
        ret += (inv_idx, )
    if return_counts:
        if aux.size:
            if size is None:
                # Start index of each group plus an end sentinel; consecutive
                # differences are the group sizes.
                idx = append(nonzero(mask)[0], mask.size)
            else:
                # size + 1 boundaries for size counts. Padding entries (0)
                # after the first position are replaced by the mask.size
                # sentinel, so trailing diffs (padded counts) come out as 0.
                idx = nonzero(mask, size=size + 1)[0]
                idx = idx.at[1:].set(where(idx[1:], idx[1:], mask.size))
            ret += (diff(idx), )
        elif ar.shape[axis]:
            # aux is empty but `axis` is not. Presumably another axis has
            # length 0, so all slices along `axis` are identical: a single
            # unique value occurring ar.shape[axis] times.
            ret += (array([ar.shape[axis]],
                          dtype=dtypes.canonicalize_dtype(dtypes.int_)), )
        else:
            ret += (empty(0, dtype=int), )
    if return_true_size:
        # Useful for internal uses of unique().
        ret += (mask.sum(), )
    return ret[0] if len(ret) == 1 else ret