def round(a, decimals=0):
    """Round ``a`` to the given number of decimal places, ties going to even.

    Args:
      a: value to round.
      decimals: number of decimal places to keep (default 0). For a non-zero
        value the input is multiplied by ``10 ** decimals``, rounded, and
        divided back, so the result is subject to the precision of that
        rescaling.

    Returns:
      ``a`` unchanged when its dtype is an integer type (``decimals`` is not
      consulted in that case); otherwise the rounded value.
    """
    if onp.issubdtype(_dtype(a), onp.integer):
        return a  # no-op on integer types

    # numpy.round resolves exact halves to the nearest even value, whereas
    # lax.round defaults to rounding halves away from zero; request the
    # NumPy-compatible mode explicitly.
    to_even = lax.RoundingMethod.TO_NEAREST_EVEN

    if decimals == 0:
        return lax.round(a, to_even)

    factor = _constant_like(a, 10 ** decimals)
    return lax.div(lax.round(lax.mul(a, factor), to_even), factor)
def rint(x):
    """Round ``x`` to the nearest integer value, with ties going to even.

    Integer inputs are converted to ``dtypes.float_`` without further
    rounding. Complex inputs have their real and imaginary parts rounded
    independently and are then recombined.
    """
    _check_arraylike('rint', x)
    x_dtype = dtypes.dtype(x)

    if dtypes.issubdtype(x_dtype, np.integer):
        # Already whole numbers; only the element type changes.
        return lax.convert_element_type(x, dtypes.float_)

    if not dtypes.issubdtype(x_dtype, np.complexfloating):
        return lax.round(x, lax.RoundingMethod.TO_NEAREST_EVEN)

    # Complex: recurse on each component, then rebuild the complex value.
    rounded_real = rint(lax.real(x))
    rounded_imag = rint(lax.imag(x))
    return lax.complex(rounded_real, rounded_imag)
def _float_divmod(x1, x2):
    """Return ``(quotient, remainder)`` of ``x1`` divided by ``x2``.

    Follows the approach of float_divmod in CPython's floatobject.c: start
    from the ``lax.rem`` remainder, then shift it by ``x2`` (and lower the
    quotient by one) whenever the remainder is non-zero and its sign differs
    from the sign of ``x2``. The quotient is passed through ``lax.round``
    before being returned.
    """
    remainder = lax.rem(x1, x2)
    quotient = lax.div(lax.sub(x1, remainder), x2)

    # True where the remainder must be moved onto the divisor's side of zero.
    needs_fixup = lax.bitwise_and(
        remainder != 0, lax.sign(x2) != lax.sign(remainder))

    fixed_remainder = lax.select(needs_fixup, remainder + x2, remainder)
    fixed_quotient = lax.select(
        needs_fixup, quotient - _constant_like(quotient, 1), quotient)
    return lax.round(fixed_quotient), fixed_remainder
def _round_half_away_from_zero(a): return a if jnp.issubdtype(a.dtype, jnp.integer) else lax.round(a)