Example #1
0
def remainder(x1, x2):
    """Element-wise remainder whose sign follows the divisor (Python ``%``).

    ``lax.rem`` is a truncated remainder, so its result carries the sign of
    the dividend. Wherever that truncated remainder is non-zero and its sign
    disagrees with the divisor's, one ``x2`` is added back to move it onto
    the divisor's side.
    """
    x1, x2 = _promote_args("remainder", x1, x2)
    zero = _constant_like(x1, 0)
    r = lax.rem(x1, x2)
    # True where exactly one of (r, x2) is negative.
    signs_differ = lax.ne(lax.lt(r, zero), lax.lt(x2, zero))
    needs_shift = lax.bitwise_and(signs_differ, lax.ne(r, zero))
    shifted = lax.add(r, x2)
    return lax.select(needs_shift, shifted, r)
Example #2
0
def _wrap_between(x, _a):
    """Shift `x` by a multiple of `2 * _a` so it lands in `[-_a, _a]`.

    Mathematically the result lies in `[-_a, _a)`. Floating-point rounding
    when a tiny negative remainder is moved up by one period can still
    produce the upper end `_a`.
    """
    half_period = _constant_like(x, _a)
    period = _constant_like(x, 2 * _a)
    shifted = lax.rem(lax.add(x, half_period), period)
    # lax.rem truncates toward zero, so negative inputs leave a negative
    # remainder; lift those by one period into [0, period).
    is_negative = lax.lt(shifted, _constant_like(x, 0))
    shifted = lax.select(is_negative, lax.add(shifted, period), shifted)
    return lax.sub(shifted, half_period)
Example #3
0
def floor_divide(x1, x2):
  """Element-wise ``floor(x1 / x2)`` after promoting both arguments.

  Non-integer dtypes are delegated to ``_float_divmod``. Integer dtypes use
  truncating division and then step the quotient down by one where needed.
  """
  x1, x2 = _promote_args("floor_divide", x1, x2)
  if not onp.issubdtype(_dtype(x1), onp.integer):
    return _float_divmod(x1, x2)[0]
  truncated = lax.div(x1, x2)
  # Truncated division rounds toward zero; the floor is one lower exactly
  # when the operands' signs differ and the division is inexact.
  round_down = logical_and(lax.sign(x1) != lax.sign(x2),
                           lax.rem(x1, x2) != 0)
  # TODO(mattjj): investigate why subtracting a scalar was causing promotion
  one = onp.array(1, _dtype(truncated))
  return where(round_down, truncated - one, truncated)
Example #4
0
def _float_divmod(x1, x2):
    """Return ``(floor-quotient, divisor-signed remainder)`` for floats.

    Mirrors the core of ``float_divmod`` in CPython's ``floatobject.c``.
    The signed-zero special cases of the CPython version are not reproduced.
    """
    r = lax.rem(x1, x2)
    q = lax.div(lax.sub(x1, r), x2)

    # lax.rem truncates, so a non-zero remainder whose sign disagrees with the
    # divisor is moved by one x2, and the quotient is lowered by one to match.
    fix = lax.bitwise_and(r != 0, lax.sign(x2) != lax.sign(r))
    q = lax.select(fix, q - _constant_like(q, 1), q)
    r = lax.select(fix, r + x2, r)

    # q is (x1 - r) / x2, i.e. an integer up to rounding error; snap it.
    return lax.round(q), r
Example #5
0
def floor_divide(x1, x2):
  """Element-wise floor division after promoting both arguments.

  * Integer dtypes: truncating ``lax.div`` stepped down by one where needed.
  * Complex dtypes: ``floor(Re(x1 / x2))``, cast back to the complex dtype.
  * Otherwise: the quotient from ``_float_divmod``.
  """
  x1, x2 = _promote_args("floor_divide", x1, x2)
  dtype = dtypes.dtype(x1)

  if dtypes.issubdtype(dtype, np.integer):
    trunc = lax.div(x1, x2)
    # Truncation rounds toward zero; step down by one where the operands'
    # signs differ and the division leaves a remainder.
    needs_floor = logical_and(lax.sign(x1) != lax.sign(x2),
                              lax.rem(x1, x2) != 0)
    # TODO(mattjj): investigate why subtracting a scalar was causing promotion
    return _where(needs_floor, trunc - 1, trunc)

  if dtypes.issubdtype(dtype, np.complexfloating):
    a, b = lax.real(x1), lax.imag(x1)
    c, d = lax.real(x2), lax.imag(x2)
    # Re(x1 / x2) = (a*c + b*d) / (c*c + d*d). Numerator and denominator are
    # both scaled by 1/max(|c|, |d|) (Smith-style), presumably so the
    # denominator is never formed as c*c + d*d directly.
    real_dominant = lax.ge(lax.abs(c), lax.abs(d))
    w_c = _where(real_dominant, lax.full_like(d, 1), lax.div(c, d))
    w_d = _where(real_dominant, lax.div(d, c), _lax_const(d, 1))
    numerator = lax.add(lax.mul(a, w_c), lax.mul(b, w_d))
    denominator = lax.add(lax.mul(c, w_c), lax.mul(d, w_d))
    floored = lax.floor(lax.div(numerator, denominator))
    return lax.convert_element_type(floored, dtype)

  return _float_divmod(x1, x2)[0]
Example #6
0
def fmod(x1, x2):
    """Element-wise truncated remainder (``lax.rem``) of ``x1`` by ``x2``.

    When the promoted result dtype is an integer, zero divisors are replaced
    by one before the call, so ``fmod(x, 0)`` evaluates ``rem(x, 1)`` (i.e.
    0) instead of performing an integer division by zero.
    """
    _check_arraylike("fmod", x1, x2)
    result_is_int = dtypes.issubdtype(dtypes.result_type(x1, x2), np.integer)
    if result_is_int:
        safe_divisor = _where(x2 == 0, lax_internal._ones(x2), x2)
    else:
        safe_divisor = x2
    promoted = _promote_args("fmod", x1, safe_divisor)
    return lax.rem(*promoted)
Example #7
0
def remainder(x1, x2):
    """Element-wise remainder whose sign follows the divisor (Python ``%``).

    ``lax.rem`` is a truncated remainder, so its result carries the sign of
    ``x1``. Where that truncated remainder is non-zero and its sign disagrees
    with ``x2``'s, one ``x2`` is added to move it onto the divisor's side.

    The correction is applied only where it is needed, not through the
    unconditional ``rem(rem(x1, x2) + x2, x2)`` formulation. That has two
    effects:

    * For fixed-width integers, the intermediate ``rem + x2`` is bounded in
      magnitude by ``|x2|`` and cannot overflow. (Previously, int8
      ``remainder(50, 100)`` wrapped to ``-6``.)
    * For floats, small remainders are no longer rounded away by adding
      ``x2``. (Previously, ``remainder(1e-20, 1.0)`` returned ``0.0``.)
    """
    x1, x2 = _promote_args("remainder", x1, x2)
    trunc_mod = lax.rem(x1, x2)
    # Zeros built per operand so each comparison has matching shapes.
    mod_is_neg = lax.lt(trunc_mod, lax.full_like(trunc_mod, 0))
    div_is_neg = lax.lt(x2, lax.full_like(x2, 0))
    mod_nonzero = lax.ne(trunc_mod, lax.full_like(trunc_mod, 0))
    needs_fix = lax.bitwise_and(lax.ne(mod_is_neg, div_is_neg), mod_nonzero)
    return lax.select(needs_fix, lax.add(trunc_mod, x2), trunc_mod)