예제 #1
0
def logaddexp(x1, x2):
  """Compute ``log(exp(x1) + exp(x2))`` without forming the exponentials directly.

  Both arguments are first promoted to a common inexact dtype. Floating
  inputs use the ``max + log1p(exp(-|x1 - x2|))`` form; every other dtype
  takes the branch that splits the result into real and imaginary parts and
  wraps the imaginary part with ``_wrap_between(..., np.pi)``.
  """
  x1, x2 = _promote_args_inexact("logaddexp", x1, x2)
  larger = lax.max(x1, x2)

  if not dtypes.issubdtype(x1.dtype, np.floating):
    # Non-floating branch: x1 + x2 - 2 * max equals (min - max).
    shifted = lax.sub(lax.add(x1, x2),
                      lax.mul(larger, _constant_like(larger, 2)))
    result = lax.add(larger, lax.log1p(lax.exp(shifted)))
    wrapped_imag = _wrap_between(lax.imag(result), np.pi)
    return lax.complex(lax.real(result), wrapped_imag)

  diff = lax.sub(x1, x2)
  # A NaN difference means NaN inputs or infinities of the same sign; in
  # that case the plain sum already carries the right value.
  plain_sum = lax.add(x1, x2)
  stable = lax.add(larger, lax.log1p(lax.exp(lax.neg(lax.abs(diff)))))
  return lax.select(lax_internal._isnan(diff), plain_sum, stable)
예제 #2
0
파일: special.py 프로젝트: GregCT/jax
def xlog1py(x, y):
    """Return ``x * log1p(y)``, with the result forced to zero wherever ``x == 0``.

    Where ``x`` is zero, both operands are replaced by 1 before the product
    is taken, so the discarded lanes are computed from fixed placeholder
    values rather than from the caller's ``y``.
    """
    x, y = _promote_args_inexact("xlog1py", x, y)
    nonzero_x = x != 0.
    x_filled = jnp.where(nonzero_x, x, 1.)
    y_filled = jnp.where(nonzero_x, y, 1.)
    product = lax.mul(x_filled, lax.log1p(y_filled))
    return jnp.where(nonzero_x, product, jnp.zeros_like(x))
예제 #3
0
파일: cauchy.py 프로젝트: yashk2810/jax
def logpdf(x, loc=0, scale=1):
    """Log-density of the Cauchy distribution.

    Computes ``-(log(pi * scale) + log1p(z**2))`` where
    ``z = (x - loc) / scale``, after promoting all arguments to a common
    inexact dtype.
    """
    x, loc, scale = _promote_args_inexact("cauchy.logpdf", x, loc, scale)
    z = lax.div(lax.sub(x, loc), scale)
    log_norm = lax.log(lax.mul(_constant_like(x, np.pi), scale))
    log_kernel = lax.log1p(lax.mul(z, z))
    return lax.neg(lax.add(log_norm, log_kernel))
예제 #4
0
def log1m_exp(val):
    """Numerically stable implementation of `log(1 - exp(val))`.

    Two algebraically equivalent forms are used, switching at ``-log(2)``:

    * ``val > -log(2)`` (``exp(val)`` close to 1): ``log(-expm1(val))``,
      which avoids the cancellation in ``1 - exp(val)``.
    * otherwise (``exp(val)`` small): ``log1p(-exp(val))``.

    Args:
      val: scalar value; the branch is chosen with ``lax.cond``, which needs
        a scalar predicate. The result is real only for ``val <= 0``
        (``-inf`` at 0, NaN for positive inputs).

    Returns:
      ``log(1 - exp(val))``.
    """
    # The crossover must be the *negative* of log(2). With a positive
    # threshold the expm1 branch would only be reached for inputs whose
    # result is NaN, and values just below zero would lose all precision.
    threshold = lax.neg(lax.log(2.0))
    return lax.cond(
        lax.gt(val, threshold),
        lambda _: lax.log(-lax.expm1(val)),
        lambda _: lax.log1p(-lax.exp(val)),
        None,  # dummy operand; both branches close over `val`
    )
예제 #5
0
def _logaddexp(x1, x2):
  """
  Logaddexp while ignoring the custom_jvp rule.
  """
  amax = lax.max(x1, x2)
  delta = lax.sub(x1, x2)
  return lax.select(jnp.isnan(delta),
                    lax.add(x1, x2),  # NaNs or infinities of the same sign.
                    lax.add(amax, lax.log1p(lax.exp(-lax.abs(delta)))))
예제 #6
0
def logpdf(x, a, b, loc=0, scale=1):
    """Log-density of the beta distribution with shape parameters ``a`` and ``b``.

    With ``z = (x - loc) / scale`` the value is
    ``-betaln(a, b) + (a - 1) * log(z) + (b - 1) * log1p(-z) - log(scale)``.
    Points with ``x < loc`` or ``x > loc + scale`` are assigned ``-inf``.
    """
    x, a, b, loc, scale = _promote_args_inexact(
        "beta.logpdf", x, a, b, loc, scale)
    unit = _constant_like(x, 1)
    z = lax.div(lax.sub(x, loc), scale)

    neg_log_beta = lax.neg(betaln(a, b))
    a_part = lax.mul(lax.sub(a, unit), lax.log(z))
    b_part = lax.mul(lax.sub(b, unit), lax.log1p(lax.neg(z)))
    kernel = lax.add(a_part, b_part)
    inside = lax.sub(lax.add(neg_log_beta, kernel), lax.log(scale))

    above = lax.gt(x, lax.add(loc, scale))
    below = lax.lt(x, loc)
    return where(logical_or(above, below), -inf, inside)
예제 #7
0
def logpdf(x, df, loc=0, scale=1):
  """Log-density of Student's t distribution with ``df`` degrees of freedom.

  With ``z = (x - loc) / scale`` the value is
  ``-(lgamma(df/2) + 0.5*log(pi * df * scale**2) - lgamma(df/2 + 0.5)
  + (df/2 + 0.5) * log1p(z**2 / df))``.
  """
  x, df, loc, scale = _promote_args_inexact("t.logpdf", x, df, loc, scale)
  two = _lax_const(x, 2)
  z = lax.div(lax.sub(x, loc), scale)
  half_df = lax.div(df, two)
  half_df_plus_half = lax.add(half_df, _lax_const(x, 0.5))

  # 0.5 * log(scale**2 * pi * df)
  scale_sq_pi = lax.mul(lax.mul(scale, scale), _lax_const(x, np.pi))
  half_log = lax.div(lax.log(lax.mul(scale_sq_pi, df)), two)
  log_norm = lax.sub(lax.add(lax.lgamma(half_df), half_log),
                     lax.lgamma(half_df_plus_half))

  z_sq_over_df = lax.div(lax.mul(z, z), df)
  tail = lax.mul(half_df_plus_half, lax.log1p(z_sq_over_df))
  return lax.neg(lax.add(log_norm, tail))
예제 #8
0
def logpmf(k, n, a, b, loc=0):
    """JAX implementation of scipy.stats.betabinom.logpmf.

    Args:
      k: point(s) at which to evaluate the log-PMF; floored before use.
      n: number of trials.
      a, b: shape parameters of the beta mixing distribution.
      loc: shift applied to the distribution; the PMF is evaluated at
        ``y = floor(k) - loc``.

    Returns:
      The log-probability, ``-inf`` where ``y`` lies outside ``[0, n]``, and
      NaN where ``n < 1``, ``a < 0`` or ``b < 0``.
    """
    k, n, a, b, loc = _promote_args_inexact("betabinom.logpmf", k, n, a, b,
                                            loc)
    y = lax.sub(lax.floor(k), loc)
    one = _lax_const(y, 1)
    zero = _lax_const(y, 0)
    # -log(n + 1) - betaln(n - y + 1, y + 1) == log(binomial(n, y))
    combiln = lax.neg(
        lax.add(lax.log1p(n),
                betaln(lax.add(lax.sub(n, y), one), lax.add(y, one))))
    beta_lns = lax.sub(betaln(lax.add(y, a), lax.add(lax.sub(n, y), b)),
                       betaln(a, b))
    log_probs = lax.add(combiln, beta_lns)
    # `y` already has `loc` subtracted, so the support in terms of `y` is
    # simply 0 <= y <= n. Comparing against -loc / n - loc would apply the
    # shift twice and mask valid points whenever loc != 0.
    y_cond = logical_or(lax.lt(y, zero), lax.gt(y, n))
    log_probs = where(y_cond, -inf, log_probs)
    n_a_b_cond = logical_or(logical_or(lax.lt(n, one), lax.lt(a, zero)),
                            lax.lt(b, zero))
    return where(n_a_b_cond, nan, log_probs)
예제 #9
0
def xlog1py_jvp_lhs(g, x, y, jaxpr, aval, consts):
    """JVP contribution of ``xlog1py`` for its first argument: ``g * log1p(y)``.

    ``y`` is promoted against ``x`` and then against the tangent ``g``
    before use. ``jaxpr``, ``aval`` and ``consts`` are accepted but not
    read by this function.
    """
    _, y = _promote_args_like(osp_special.xlog1py, x, y)
    g, y = _promote_args_like(osp_special.xlog1py, g, y)
    # Each factor is passed through lax._brcast against the other operand
    # (presumably to align shapes -- confirm against lax._brcast).
    tangent = lax._brcast(g, y)
    log_term = lax._brcast(lax.log1p(y), g)
    return lax._safe_mul(tangent, log_term)
예제 #10
0
파일: logistic.py 프로젝트: yashk2810/jax
def logpdf(x):
    """Log-density of the standard logistic distribution.

    Mathematically this is ``-x - 2 * log1p(exp(-x))``. The density is
    symmetric in ``x``, so the formula is evaluated at ``|x|``: the
    exponential argument is then never positive and cannot overflow. Without
    this, large negative inputs produce ``-inf`` instead of roughly ``x``.

    Args:
      x: floating-point value(s) at which to evaluate the log-density.

    Returns:
      The log-density, with the same shape as ``x``.
    """
    abs_x = lax.abs(x)
    return lax.neg(abs_x) - 2. * lax.log1p(lax.exp(lax.neg(abs_x)))