コード例 #1
0
def arccosh(x):
    """Compute the inverse hyperbolic cosine of ``x`` using ``lax.acosh``.

    The argument is first passed through ``_promote_args_inexact``. When the
    result has a complex dtype, every element whose real part is negative is
    negated before being returned.
    """
    (promoted,) = _promote_args_inexact("arccosh", x)
    result = lax.acosh(promoted)
    if not dtypes.issubdtype(result.dtype, np.complexfloating):
        return result
    # arccosh is multi-valued for complex input, and lax.acosh follows a
    # different convention than np.arccosh; flip the sign of the values that
    # have a negative real part.
    return _where(real(result) < 0, lax.neg(result), result)
コード例 #2
0
def _logaddexp2_jvp(primals, tangents):
  """JVP rule for ``logaddexp2``: return the primal output and its tangent.

  Args:
    primals: pair ``(x1, x2)`` of inputs.
    tangents: pair ``(t1, t2)`` of tangents, one per primal.

  Returns:
    ``(primal_out, tangent_out)`` where ``primal_out = logaddexp2(x1, x2)`` and
    ``tangent_out = t1 * 2**(x1 - primal_out) + t2 * 2**(x2 - primal_out)``,
    with each operand of the subtractions passed through ``_replace_inf``.
  """
  lhs, rhs = primals
  lhs_dot, rhs_dot = tangents
  lhs, rhs, lhs_dot, rhs_dot = _promote_args_inexact(
      "logaddexp2_jvp", lhs, rhs, lhs_dot, rhs_dot)
  primal_out = logaddexp2(lhs, rhs)

  def _weighted(operand, operand_dot):
    # operand_dot * 2**(operand - primal_out); both exponent operands go
    # through _replace_inf before the subtraction.
    exponent = lax.sub(_replace_inf(operand), _replace_inf(primal_out))
    return lax.mul(operand_dot, exp2(exponent))

  tangent_out = lax.add(_weighted(lhs, lhs_dot), _weighted(rhs, rhs_dot))
  return primal_out, tangent_out
コード例 #3
0
def _one_to_one_binop(numpy_fn, lax_fn, promote_to_inexact=False, lax_doc=False):
  """Build a jitted two-argument function that forwards to ``lax_fn``.

  Args:
    numpy_fn: function whose ``__name__`` is handed to the promotion helper
      and which is passed to ``_wraps``.
    lax_fn: binary function that receives the promoted arguments.
    promote_to_inexact: if true, promote with ``_promote_args_inexact``;
      otherwise with ``_promote_args``.
    lax_doc: if true, pass ``lax_fn``'s docstring with its first paragraph
      dropped to ``_wraps`` as ``lax_description``.

  Returns:
    The jitted function, wrapped by ``_wraps(numpy_fn, ...)``.
  """
  if promote_to_inexact:
    fn = lambda x1, x2: lax_fn(*_promote_args_inexact(numpy_fn.__name__, x1, x2))
  else:
    fn = lambda x1, x2: lax_fn(*_promote_args(numpy_fn.__name__, x1, x2))
  fn = jit(fn, inline=True)
  if not lax_doc:
    return _wraps(numpy_fn)(fn)
  # ``__doc__`` is None when lax_fn carries no docstring (for instance when
  # Python runs with -OO); treat that as an empty docstring instead of
  # failing with AttributeError on ``None.split``.
  paragraphs = (lax_fn.__doc__ or '').split('\n\n')
  # Drop the first paragraph and keep the remainder as the description.
  doc = dedent('\n\n'.join(paragraphs[1:])).strip()
  return _wraps(numpy_fn, lax_description=doc)(fn)
コード例 #4
0
def logaddexp(x1, x2):
  """Compute ``log(exp(x1) + exp(x2))`` relative to the larger operand.

  Both arguments are promoted with ``_promote_args_inexact``. Floating-point
  inputs use ``max + log1p(exp(-|x1 - x2|))``; any other promoted dtype takes
  the complex path, whose imaginary part is passed through ``_wrap_between``
  with ``np.pi``.
  """
  lhs, rhs = _promote_args_inexact("logaddexp", x1, x2)
  larger = lax.max(lhs, rhs)

  if not dtypes.issubdtype(lhs.dtype, np.floating):
    # Non-floating promoted dtype (presumably complex -- confirm against
    # _promote_args_inexact). x1 + x2 - 2 * max is the exponent fed to log1p.
    total = lax.add(lhs, rhs)
    twice_larger = lax.mul(larger, _constant_like(larger, 2))
    gap = lax.sub(total, twice_larger)
    result = lax.add(larger, lax.log1p(lax.exp(gap)))
    real_part = lax.real(result)
    imag_part = _wrap_between(lax.imag(result), np.pi)
    return lax.complex(real_part, imag_part)

  gap = lax.sub(lhs, rhs)
  gap_is_nan = lax_internal._isnan(gap)
  # Selected where the difference is NaN: NaNs or infinities of the same sign.
  plain_sum = lax.add(lhs, rhs)
  shifted = lax.add(larger, lax.log1p(lax.exp(lax.neg(lax.abs(gap)))))
  return lax.select(gap_is_nan, plain_sum, shifted)
コード例 #5
0
def rad2deg(x):
    """Convert ``x`` from radians to degrees by multiplying by ``180 / pi``."""
    (radians,) = _promote_args_inexact("rad2deg", x)
    degrees_per_radian = _lax_const(radians, 180 / np.pi)
    return lax.mul(radians, degrees_per_radian)
コード例 #6
0
def deg2rad(x):
    """Convert ``x`` from degrees to radians by multiplying by ``pi / 180``."""
    (degrees,) = _promote_args_inexact("deg2rad", x)
    radians_per_degree = _lax_const(degrees, np.pi / 180)
    return lax.mul(degrees, radians_per_degree)
コード例 #7
0
def exp2(x):
    """Compute ``2 ** x``, evaluated as ``exp(log(2) * x)``."""
    (exponent,) = _promote_args_inexact("exp2", x)
    two = _constant_like(exponent, 2)
    scaled = lax.mul(lax.log(two), exponent)
    return lax.exp(scaled)
コード例 #8
0
def log10(x):
    """Compute the base-10 logarithm of ``x`` as ``log(x) / log(10)``."""
    (value,) = _promote_args_inexact("log10", x)
    natural_log = lax.log(value)
    ten = _constant_like(value, 10)
    return lax.div(natural_log, lax.log(ten))
コード例 #9
0
def true_divide(x1, x2):
    """Divide ``x1`` by ``x2`` after promoting both with ``_promote_args_inexact``."""
    numerator, denominator = _promote_args_inexact("true_divide", x1, x2)
    quotient = lax.div(numerator, denominator)
    return quotient
コード例 #10
0
def copysign(x1, x2):
    """Return the magnitude of ``x1`` carrying the sign of ``x2``.

    Both arguments are promoted with ``_promote_args_inexact`` first.

    Raises:
        TypeError: if the promoted ``x1`` has a complex dtype.
    """
    magnitude_src, sign_src = _promote_args_inexact("copysign", x1, x2)
    if dtypes.issubdtype(dtypes.dtype(magnitude_src), np.complexfloating):
        raise TypeError("copysign does not support complex-valued inputs")
    # Where x2's sign bit is set pick -|x1|, elsewhere pick |x1|.
    take_negative = signbit(sign_src).astype(bool)
    negative = -lax.abs(magnitude_src)
    positive = lax.abs(magnitude_src)
    return _where(take_negative, negative, positive)