コード例 #1
0
def remainder(x1, x2):
    """Element-wise remainder of ``x1 / x2`` whose sign follows the divisor.

    This is floor modulo, as in Python's ``%`` and ``numpy.remainder``: a
    nonzero result always has the same sign as ``x2``.

    Args:
      x1: dividend (array-like).
      x2: divisor (array-like).

    Returns:
      The remainder, computed after ``x1`` and ``x2`` are promoted together by
      ``_promote_args``. For integer dtypes, a zero divisor yields 0, as
      ``numpy.remainder`` does.
    """
    import numpy  # local import: keeps this block independent of the module's numpy alias
    x1, x2 = _promote_args("remainder", x1, x2)
    zero = _constant_like(x1, 0)
    if numpy.issubdtype(x2.dtype, numpy.integer):
        # XLA leaves integer remainder-by-zero implementation defined, while
        # NumPy returns 0. Dividing by 1 instead produces exactly 0 (x % 1 == 0),
        # and the sign fix-up below is then skipped because the result is zero.
        x2 = lax.select(lax.eq(x2, zero), lax.full_like(x2, 1), x2)
    # lax.rem truncates toward zero, so trunc_mod carries the sign of x1.
    trunc_mod = lax.rem(x1, x2)
    trunc_mod_not_zero = lax.ne(trunc_mod, zero)
    # Shift by one divisor when the truncated remainder is nonzero and its
    # sign disagrees with the divisor's; this turns truncation into flooring.
    do_plus = lax.bitwise_and(
        lax.ne(lax.lt(trunc_mod, zero), lax.lt(x2, zero)), trunc_mod_not_zero)
    return lax.select(do_plus, lax.add(trunc_mod, x2), trunc_mod)
コード例 #2
0
def where(condition, x=None, y=None):
    """Pick elements from ``x`` where ``condition`` holds and from ``y`` elsewhere.

    Only the three-argument form is supported. A non-boolean ``condition`` is
    treated as True wherever it is nonzero. ``condition``, ``x`` and ``y`` are
    broadcast together, and ``x``/``y`` are promoted to a common dtype before
    the selection.

    Raises:
      ValueError: if either ``x`` or ``y`` is omitted.
    """
    if any(branch is None for branch in (x, y)):
        raise ValueError("Must use the three-argument form of where().")
    if onp.issubdtype(_dtype(condition), onp.bool_):
        mask = condition
    else:
        # Truthiness of a non-boolean condition means "differs from zero".
        mask = lax.ne(condition, zeros_like(condition))
    mask, x, y = broadcast_arrays(mask, x, y)
    on_true, on_false = _promote_dtypes(x, y)
    return lax.select(mask, on_true, on_false)
コード例 #3
0
def _unique_sorted_mask(ar, axis):
    """Sort the slices of ``ar`` along ``axis`` and flag where each new value starts.

    Returns:
      A tuple ``(aux, mask, perm)``:
        * ``aux`` -- ``ar`` with ``axis`` moved to the front and its leading-axis
          slices reordered by ``perm`` (a lexicographic sort of the flattened
          slices).
        * ``mask`` -- boolean vector along the leading axis. When ``aux`` is
          non-empty, entry 0 is True and entry i is True iff slice i differs
          from slice i-1 (NaNs compare equal for inexact dtypes). When ``aux``
          is empty, it is all False.
        * ``perm`` -- the permutation applied along the leading axis.
    """
    aux = moveaxis(ar, axis, 0)
    if np.issubdtype(aux.dtype, np.complexfloating):
        # Work around issue in sorting of complex numbers with Nan only in the
        # imaginary component. This can be removed if sorting in this situation
        # is fixed to match numpy.
        aux = where(isnan(aux), _lax_const(aux, np.nan), aux)
    size, *out_shape = aux.shape
    row_width = _prod(out_shape)
    if row_width == 0:
        # Every slice is empty and therefore identical: keep just one.
        size = 1
        perm = zeros(1, dtype=int)
    else:
        # lexsort uses its last key as the primary key, so reversing the
        # transposed rows makes column 0 of each flattened slice the primary key.
        perm = lexsort(aux.reshape(size, row_width).T[::-1])
    aux = aux[perm]
    if not aux.size:
        return aux, zeros(size, dtype=bool), perm
    if dtypes.issubdtype(aux.dtype, np.inexact):
        # This is appropriate for both float and complex due to the documented behavior of np.unique:
        # See https://github.com/numpy/numpy/blob/v1.22.0/numpy/lib/arraysetops.py#L212-L220
        def neq(left, right):
            return lax.ne(left, right) & ~(isnan(left) & isnan(right))
    else:
        neq = lax.ne
    trailing_axes = tuple(range(1, aux.ndim))
    # A slice starts a new unique value if any of its elements differ from
    # the corresponding element of the previous slice.
    starts_new_value = any(neq(aux[1:], aux[:-1]), trailing_axes)
    mask = ones(size, dtype=bool).at[1:].set(starts_new_value)
    return aux, mask, perm
コード例 #4
0
def count_nonzero(a,
                  axis: Optional[Union[int, Tuple[int, ...]]] = None,
                  keepdims=False):
    """Count the entries of ``a`` that differ from zero.

    Args:
      a: array-like input (validated by ``_check_arraylike``).
      axis: axis or axes to count along; ``None`` counts over the whole array.
      keepdims: keep the reduced axes as size-1 dimensions.

    Returns:
      The counts, with dtype ``dtypes.canonicalize_dtype(np.int_)``.
    """
    _check_arraylike("count_nonzero", a)
    is_nonzero = lax.ne(a, _lax_const(a, 0))
    count_dtype = dtypes.canonicalize_dtype(np.int_)
    return sum(is_nonzero, axis=axis, dtype=count_dtype, keepdims=keepdims)
コード例 #5
0
def _power(x1, x2):
  """Element-wise ``x1 ** x2`` after promoting both inputs to a common dtype.

  Non-integer dtypes defer to ``lax.pow``. Integer dtypes use square-and-multiply
  over only the low six bits of the exponent.

  NOTE(review): negative integer exponents are not rejected; only their low six
  two's-complement bits take part in the loop below.
  """
  x1, x2 = _promote_args("power", x1, x2)
  if not dtypes.issubdtype(dtypes.dtype(x1), np.integer):
    return lax.pow(x1, x2)

  # TODO(phawkins): add integer pow support to XLA.
  # Original rationale: exponents of 64 or more would overflow for any x1 > 1,
  # so six exponent bits are processed.
  num_bits = 6
  zero = _constant_like(x2, 0)
  one = _constant_like(x2, 1)
  base, exponent = x1, x2
  # Seed the product with 0 when base == 0 and exponent != 0 (so 0 ** n == 0),
  # and with 1 otherwise (so 0 ** 0 == 1).
  zero_to_positive = lax.bitwise_and(lax.eq(base, zero), lax.ne(exponent, zero))
  result = _where(zero_to_positive, zero, one)
  for _bit in range(num_bits):
    # Multiply in the current base when the exponent's lowest bit is set,
    # then square the base and drop that bit.
    low_bit = lax.bitwise_and(exponent, one)
    result = _where(low_bit, lax.mul(result, base), result)
    base = lax.mul(base, base)
    exponent = lax.shift_right_logical(exponent, one)
  return result
コード例 #6
0
def isnan(x):
    """Return a boolean mask that is True exactly where ``x`` is NaN.

    Relies on NaN being the only value that compares unequal to itself, so the
    mask is all False for inputs that cannot hold NaN (e.g. integers).
    """
    _check_arraylike("isnan", x)
    not_self_equal = lax.ne(x, x)
    return not_self_equal
コード例 #7
0
 def op(*args):
     """Combine the truth values of ``args`` with ``bitwise_op``.

     Boolean arguments pass through unchanged; every other argument is first
     mapped to ``arg != 0``. The resulting values are promoted together via
     ``_promote_args`` (called with ``np_op.__name__`` as the function name)
     and then passed to ``bitwise_op``.
     """
     def truth_value(arg):
         if dtypes.issubdtype(dtypes.dtype(arg), np.bool_):
             return arg
         # Compare against a scalar zero of the argument's own dtype.
         scalar_zero = lax.full_like(arg, shape=(), fill_value=0)
         return lax.ne(arg, scalar_zero)
     bool_args = [truth_value(arg) for arg in args]
     return bitwise_op(*_promote_args(np_op.__name__, *bool_args))