def arccosh(x):
    """Inverse hyperbolic cosine of ``x``, following NumPy's branch choice.

    The argument is promoted to an inexact dtype and handed to ``lax.acosh``.
    arccosh is multi-valued for complex input, and ``lax.acosh`` uses a
    different convention than ``np.arccosh``; for complex results, every
    element with a negative real part is therefore negated.
    """
    (operand,) = _promote_args_inexact("arccosh", x)
    result = lax.acosh(operand)
    if not dtypes.issubdtype(result.dtype, np.complexfloating):
        # Non-complex results need no branch correction.
        return result
    # Select the representative whose real part is non-negative.
    has_negative_real = real(result) < 0
    return _where(has_negative_real, lax.neg(result), result)
def _logaddexp2_jvp(primals, tangents):
    """Forward-mode derivative rule for ``logaddexp2``.

    Args:
      primals: pair ``(x1, x2)`` of primal inputs.
      tangents: pair ``(t1, t2)`` of the matching input tangents.

    Returns:
      ``(primal_out, tangent_out)`` where ``primal_out = logaddexp2(x1, x2)``
      and ``tangent_out = t1 * 2**(x1 - primal_out) + t2 * 2**(x2 - primal_out)``,
      with each operand of the subtraction first passed through
      ``_replace_inf``.
    """
    lhs, rhs = primals
    lhs_dot, rhs_dot = tangents
    lhs, rhs, lhs_dot, rhs_dot = _promote_args_inexact(
        "logaddexp2_jvp", lhs, rhs, lhs_dot, rhs_dot)
    primal_out = logaddexp2(lhs, rhs)

    def _weighted(tangent, primal):
        # NOTE(review): _replace_inf presumably keeps the subtraction from
        # producing inf - inf; confirm against its definition.
        exponent = lax.sub(_replace_inf(primal), _replace_inf(primal_out))
        return lax.mul(tangent, exp2(exponent))

    tangent_out = lax.add(_weighted(lhs_dot, lhs), _weighted(rhs_dot, rhs))
    return primal_out, tangent_out
def _one_to_one_binop(numpy_fn, lax_fn, promote_to_inexact=False, lax_doc=False):
    """Create a jitted binary function that forwards to ``lax_fn``.

    Args:
      numpy_fn: the NumPy function being mirrored; its ``__name__`` is passed
        to the promotion helper and the function itself to ``_wraps``.
      lax_fn: the lax primitive applied to the promoted operands.
      promote_to_inexact: if true, promote with ``_promote_args_inexact``
        instead of ``_promote_args``.
      lax_doc: if true, pass ``lax_fn``'s docstring, minus its first
        blank-line-separated paragraph, to ``_wraps`` as ``lax_description``.

    Returns:
      The result of applying ``_wraps(numpy_fn, ...)`` to the jitted function.
    """
    promoter = _promote_args_inexact if promote_to_inexact else _promote_args
    fn = jit(
        lambda x1, x2: lax_fn(*promoter(numpy_fn.__name__, x1, x2)),
        inline=True,
    )
    if not lax_doc:
        return _wraps(numpy_fn)(fn)

    # Keep everything after the first paragraph of the lax docstring.
    paragraphs = lax_fn.__doc__.split('\n\n')
    doc = dedent('\n\n'.join(paragraphs[1:])).strip()
    return _wraps(numpy_fn, lax_description=doc)(fn)
def logaddexp(x1, x2):
    """Compute ``log(exp(x1) + exp(x2))`` in a numerically careful way.

    Both inputs are promoted to a common inexact dtype. For floating dtypes
    the result is ``max + log1p(exp(-|x1 - x2|))``, except where ``x1 - x2``
    is NaN, in which case ``x1 + x2`` is returned. For every other dtype the
    result is ``max + log1p(exp(x1 + x2 - 2 * max))``, with the imaginary
    part passed through ``_wrap_between(..., np.pi)``.
    """
    x1, x2 = _promote_args_inexact("logaddexp", x1, x2)
    amax = lax.max(x1, x2)

    if dtypes.issubdtype(x1.dtype, np.floating):
        delta = lax.sub(x1, x2)
        delta_is_nan = lax_internal._isnan(delta)
        # Used where delta is NaN: NaNs or infinities of the same sign.
        plain_sum = lax.add(x1, x2)
        shifted = lax.add(amax, lax.log1p(lax.exp(lax.neg(lax.abs(delta)))))
        return lax.select(delta_is_nan, plain_sum, shifted)

    # Non-floating dtype (presumably complex after inexact promotion).
    total = lax.add(x1, x2)
    delta = lax.sub(total, lax.mul(amax, _constant_like(amax, 2)))
    out = lax.add(amax, lax.log1p(lax.exp(delta)))
    # NOTE(review): _wrap_between presumably folds the imaginary part into a
    # range bounded by pi; confirm against its definition.
    wrapped_imag = _wrap_between(lax.imag(out), np.pi)
    return lax.complex(lax.real(out), wrapped_imag)
def rad2deg(x):
    """Convert angles from radians to degrees.

    The input is promoted to an inexact dtype and multiplied by ``180 / pi``.
    """
    (radians,) = _promote_args_inexact("rad2deg", x)
    degrees_per_radian = _lax_const(radians, 180 / np.pi)
    return lax.mul(radians, degrees_per_radian)
def deg2rad(x):
    """Convert angles from degrees to radians.

    The input is promoted to an inexact dtype and multiplied by ``pi / 180``.
    """
    (degrees,) = _promote_args_inexact("deg2rad", x)
    radians_per_degree = _lax_const(degrees, np.pi / 180)
    return lax.mul(degrees, radians_per_degree)
def exp2(x):
    """Compute ``2 ** x`` elementwise as ``exp(log(2) * x)``.

    The input is promoted to an inexact dtype before the computation.
    """
    (exponent,) = _promote_args_inexact("exp2", x)
    ln2 = lax.log(_constant_like(exponent, 2))
    return lax.exp(lax.mul(ln2, exponent))
def log10(x):
    """Compute the base-10 logarithm elementwise as ``log(x) / log(10)``.

    The input is promoted to an inexact dtype before the computation.
    """
    (value,) = _promote_args_inexact("log10", x)
    natural_log = lax.log(value)
    ln10 = lax.log(_constant_like(value, 10))
    return lax.div(natural_log, ln10)
def true_divide(x1, x2):
    """Divide ``x1`` by ``x2`` elementwise after promoting both to an inexact dtype."""
    numerator, denominator = _promote_args_inexact("true_divide", x1, x2)
    return lax.div(numerator, denominator)
def copysign(x1, x2):
    """Return the magnitude of ``x1`` carrying the sign of ``x2``, elementwise.

    Both inputs are promoted to a common inexact dtype.

    Raises:
      TypeError: if the promoted dtype is complex.
    """
    x1, x2 = _promote_args_inexact("copysign", x1, x2)
    if dtypes.issubdtype(dtypes.dtype(x1), np.complexfloating):
        raise TypeError("copysign does not support complex-valued inputs")
    magnitude = lax.abs(x1)
    # Negate the magnitude wherever signbit reports x2's sign bit as set.
    take_negative = signbit(x2).astype(bool)
    return _where(take_negative, -magnitude, magnitude)