def remainder(x1, x2):
  """Elementwise remainder of ``x1 / x2`` whose sign follows the divisor.

  ``lax.rem`` yields the truncated remainder. Wherever that remainder is
  nonzero and its sign disagrees with the sign of ``x2``, one extra ``x2`` is
  added so the result lands on the divisor's side of zero.
  """
  x1, x2 = _promote_args("remainder", x1, x2)
  zero = _constant_like(x1, 0)
  rem_trunc = lax.rem(x1, x2)
  # True where exactly one of (remainder, divisor) is negative.
  signs_disagree = lax.ne(lax.lt(rem_trunc, zero), lax.lt(x2, zero))
  # A zero remainder never needs correcting, whatever the divisor's sign.
  needs_fixup = lax.bitwise_and(signs_disagree, lax.ne(rem_trunc, zero))
  shifted = lax.add(rem_trunc, x2)
  return lax.select(needs_fixup, shifted, rem_trunc)
def where(condition, x=None, y=None):
  """Choose elements from ``x`` where ``condition`` holds, else from ``y``.

  Only the three-argument form is supported. A non-boolean ``condition`` is
  converted to booleans by comparing it against zero. ``condition``, ``x``
  and ``y`` are broadcast together, and ``x``/``y`` are promoted to a common
  dtype before selection.

  Raises:
    ValueError: if either ``x`` or ``y`` is omitted.
  """
  if x is None or y is None:
    raise ValueError("Must use the three-argument form of where().")
  already_bool = onp.issubdtype(_dtype(condition), onp.bool_)
  if not already_bool:
    condition = lax.ne(condition, zeros_like(condition))
  pred, x_b, y_b = broadcast_arrays(condition, x, y)
  x_p, y_p = _promote_dtypes(x_b, y_b)
  return lax.select(pred, x_p, y_p)
def _unique_sorted_mask(ar, axis):
  """Sort the slices of ``ar`` along ``axis`` and mark the start of each run of equal slices.

  ``axis`` is moved to the front and the resulting rows are sorted
  lexicographically. A boolean mask then flags every row that differs from
  the row before it. The first row is always flagged when the array has
  elements, and NaNs in inexact dtypes compare equal to each other.

  Args:
    ar: array to sort.
    axis: axis whose slices are treated as the rows being sorted.

  Returns:
    A tuple ``(aux, mask, perm)``:
      aux: the row-sorted array, with ``axis`` moved to position 0.
      mask: boolean vector of length ``size`` (the number of rows of ``aux``).
        It is true where a row starts a new run of equal rows.
      perm: the row permutation that was applied (``aux = moved_ar[perm]``).
  """
  aux = moveaxis(ar, axis, 0)
  if np.issubdtype(aux.dtype, np.complexfloating):
    # Work around issue in sorting of complex numbers with Nan only in the
    # imaginary component. This can be removed if sorting in this situation
    # is fixed to match numpy.
    aux = where(isnan(aux), _lax_const(aux, np.nan), aux)
  size, *out_shape = aux.shape
  if _prod(out_shape) == 0:
    # Every row holds zero elements, so there is nothing to compare.
    # Force a single row, selected by the trivial permutation [0].
    size = 1
    perm = zeros(1, dtype=int)
  else:
    # Flatten each row into one key vector. The transpose makes the columns
    # the sort keys, and [::-1] reverses them so that column 0 becomes the
    # primary key. numpy's lexsort treats the LAST key as primary; presumably
    # this lexsort follows the same convention.
    perm = lexsort(aux.reshape(size, _prod(out_shape)).T[::-1])
  aux = aux[perm]
  if aux.size:
    if dtypes.issubdtype(aux.dtype, np.inexact):
      # This is appropriate for both float and complex due to the documented behavior of np.unique:
      # See https://github.com/numpy/numpy/blob/v1.22.0/numpy/lib/arraysetops.py#L212-L220
      # Two NaNs are treated as equal here, so adjacent NaNs fall into one run.
      neq = lambda x, y: lax.ne(x, y) & ~(isnan(x) & isnan(y))
    else:
      neq = lax.ne
    # Row 0 always starts a run. Row i >= 1 starts one if any of its elements
    # differ from row i-1. This `any` takes an axis tuple, so it is the
    # array-level reduction, not the Python builtin.
    mask = ones(size, dtype=bool).at[1:].set(
        any(neq(aux[1:], aux[:-1]), tuple(range(1, aux.ndim))))
  else:
    # The array has no elements: no row is marked.
    mask = zeros(size, dtype=bool)
  return aux, mask, perm
def count_nonzero(a, axis: Optional[Union[int, Tuple[int, ...]]] = None,
                  keepdims=False):
  """Count the entries of ``a`` that are not equal to zero.

  Args:
    a: array-like input.
    axis: axis or axes to reduce over. Forwarded to ``sum``.
    keepdims: forwarded to ``sum``.

  Returns:
    The counts, accumulated in the canonicalized ``np.int_`` dtype.
  """
  _check_arraylike("count_nonzero", a)
  is_nonzero = lax.ne(a, _lax_const(a, 0))
  count_dtype = dtypes.canonicalize_dtype(np.int_)
  return sum(is_nonzero, axis=axis, dtype=count_dtype, keepdims=keepdims)
def _power(x1, x2):
  """Elementwise ``x1 ** x2`` after promoting both operands to a common dtype.

  Non-integer dtypes go straight to ``lax.pow``. Integer dtypes use
  exponentiation by squaring that consults only the low 6 bits of ``x2``.
  This is still exact for bases -1, 0 and 1:
    - -1 depends only on parity (bit 0);
    - 0 is handled by the initial value;
    - 1 is unaffected.
  Any larger base would overflow well before 2**64 anyway.
  """
  x1, x2 = _promote_args("power", x1, x2)
  dtype = dtypes.dtype(x1)
  if not dtypes.issubdtype(dtype, np.integer):
    return lax.pow(x1, x2)

  # Integer power => use binary exponentiation.
  # TODO(phawkins): add integer pow support to XLA.
  zero = _constant_like(x2, 0)
  one = _constant_like(x2, 1)
  # Start at 0 where the base is 0 and the exponent nonzero, otherwise at 1.
  # This gives pow(0, x2) == 0 for x2 != 0 even if no exponent bit is visited.
  base_is_zero = lax.eq(x1, zero)
  exp_is_nonzero = lax.ne(x2, zero)
  result = _where(lax.bitwise_and(base_is_zero, exp_is_nonzero), zero, one)

  base, exponent = x1, x2
  n_bits = 6  # Anything more would overflow for any x1 > 1
  for _ in range(n_bits):
    low_bit = lax.bitwise_and(exponent, one)
    result = _where(low_bit, lax.mul(result, base), result)
    base = lax.mul(base, base)
    exponent = lax.shift_right_logical(exponent, one)
  return result
def isnan(x):
  """Return a boolean array that is true exactly where ``x`` is NaN.

  Under IEEE 754, NaN is the only value that compares unequal to itself, so
  the self-inequality test ``x != x`` singles out NaN entries.
  """
  _check_arraylike("isnan", x)
  unequal_to_self = lax.ne(x, x)
  return unequal_to_self
def op(*args):
  """Combine the truth values of ``args`` with ``bitwise_op``.

  Boolean operands pass through unchanged. Any other operand becomes a
  boolean by comparing it against a scalar zero of its own dtype. The
  converted operands are promoted together, under ``np_op``'s name, before
  ``bitwise_op`` is applied.
  """
  def _as_bool(x):
    # Already boolean: use as-is.
    if dtypes.issubdtype(dtypes.dtype(x), np.bool_):
      return x
    # Otherwise compare against a 0-d zero of x's dtype.
    return lax.ne(x, lax.full_like(x, shape=(), fill_value=0))

  return bitwise_op(*_promote_args(np_op.__name__, *map(_as_bool, args)))