예제 #1
0
def round(a, decimals=0):
    """Round ``a`` to ``decimals`` decimal places, sending ties to even.

    Ties are resolved with ``lax.RoundingMethod.TO_NEAREST_EVEN`` so that
    the result follows NumPy's ``round`` convention (``0.5 -> 0``,
    ``1.5 -> 2``, ``2.5 -> 2``) instead of ``lax.round``'s default of
    rounding halves away from zero.

    Args:
        a: Value(s) to round.
        decimals: Number of decimal places to keep. ``0`` rounds to whole
            numbers; other values are handled by scaling with
            ``10**decimals``, rounding, and dividing the scale back out.

    Returns:
        ``a`` itself when its dtype is an integer type, otherwise the
        rounded result produced by the ``lax`` primitives.
    """
    if onp.issubdtype(_dtype(a), onp.integer):
        # NOTE(review): integers are returned untouched for every value of
        # ``decimals``; NumPy would round to tens/hundreds for negative
        # ``decimals``. Left as-is to keep existing behavior.
        return a

    nearest_even = lax.RoundingMethod.TO_NEAREST_EVEN

    if decimals == 0:
        return lax.round(a, nearest_even)

    # Move the target digit into the units place, round there, then undo
    # the scaling. The scaled product is itself subject to floating-point
    # error, so results are only as exact as that multiplication.
    factor = _constant_like(a, 10**decimals)
    return lax.div(lax.round(lax.mul(a, factor), nearest_even), factor)
예제 #2
0
def rint(x):
    """Round ``x`` to the nearest integral value, with ties going to even.

    Integer inputs are converted to ``dtypes.float_`` without any rounding.
    Complex inputs have their real and imaginary parts rounded separately
    and are then recombined. Everything else is rounded directly.
    """
    _check_arraylike('rint', x)
    x_dtype = dtypes.dtype(x)

    def _has_kind(kind):
        return dtypes.issubdtype(x_dtype, kind)

    if _has_kind(np.integer):
        # Already integral: only the element type changes.
        return lax.convert_element_type(x, dtypes.float_)

    if _has_kind(np.complexfloating):
        rounded_real = rint(lax.real(x))
        rounded_imag = rint(lax.imag(x))
        return lax.complex(rounded_real, rounded_imag)

    return lax.round(x, lax.RoundingMethod.TO_NEAREST_EVEN)
예제 #3
0
def _float_divmod(x1, x2):
    """Return ``(quotient, remainder)`` of ``x1`` by ``x2``, divmod-style.

    Modelled on ``float_divmod`` in CPython's ``floatobject.c``: start from
    the truncating remainder, then shift any non-zero remainder whose sign
    differs from ``x2`` so that it takes the divisor's sign, lowering the
    quotient by one to compensate.
    """
    remainder = lax.rem(x1, x2)
    quotient = lax.div(lax.sub(x1, remainder), x2)

    # A correction is needed only where the remainder is non-zero and its
    # sign disagrees with the divisor's.
    is_nonzero = remainder != 0
    sign_mismatch = lax.sign(x2) != lax.sign(remainder)
    needs_fixup = lax.bitwise_and(is_nonzero, sign_mismatch)

    one = _constant_like(quotient, 1)
    fixed_remainder = lax.select(needs_fixup, remainder + x2, remainder)
    fixed_quotient = lax.select(needs_fixup, quotient - one, quotient)

    # Snap the quotient to an integral value.
    return lax.round(fixed_quotient), fixed_remainder
예제 #4
0
파일: ndimage.py 프로젝트: gnecula/jax
def _round_half_away_from_zero(a):
    """Round ``a`` with ``lax.round``'s default mode; integers pass through.

    Arrays with an integer dtype are returned unchanged. Anything else is
    handed to ``lax.round`` with no explicit rounding method, relying on
    its default of rounding halves away from zero.
    """
    if jnp.issubdtype(a.dtype, jnp.integer):
        return a
    return lax.round(a)