Example #1
0
File: beta.py Project: 0x0is1/jax
def logpdf(x, a, b, loc=0, scale=1):
    """Log-density of the beta(a, b) distribution shifted by ``loc``, scaled by ``scale``.

    Evaluates ``(a-1)*log(y) + (b-1)*log(1-y) - betaln(a, b) - log(scale)`` with
    ``y = (x - loc) / scale``; points outside ``[loc, loc + scale]`` map to ``-inf``.
    """
    x, a, b, loc, scale = _promote_args_inexact(
        "beta.logpdf", x, a, b, loc, scale)
    unit = lax._const(x, 1)
    normalizer = lax.neg(betaln(a, b))
    standardized = lax.div(lax.sub(x, loc), scale)
    left = xlogy(lax.sub(a, unit), standardized)
    right = xlog1py(lax.sub(b, unit), lax.neg(standardized))
    log_linear_term = lax.add(left, right)
    unscaled = lax.add(normalizer, log_linear_term)
    log_probs = lax.sub(unscaled, lax.log(scale))
    outside = logical_or(lax.gt(x, lax.add(loc, scale)), lax.lt(x, loc))
    return where(outside, -inf, log_probs)
Example #2
0
def triangular_solve_jvp_rule_a(g_a, ans, a, b, left_side, lower, transpose_a,
                                conjugate_a, unit_diagonal):
    """Tangent of ``triangular_solve`` with respect to the matrix ``a``.

    ``g_a`` is the tangent of ``a`` and ``ans`` the primal solution X. Returns
    ``-op(A)^{-1} op(dA) X`` for a left-side solve and ``-X op(dA) op(A)^{-1}``
    for a right-side solve, where ``op`` applies the requested transpose and
    conjugation.
    """
    m, n = b.shape[-2:]
    # Keep only the triangle the solve uses; a unit diagonal is excluded too.
    offset = 1 if unit_diagonal else 0
    if lower:
        g_a = np.tril(g_a, k=-offset)
    else:
        g_a = np.triu(g_a, k=offset)
    g_a = lax.neg(g_a)
    if transpose_a:
        g_a = np.swapaxes(g_a, -1, -2)
    if conjugate_a:
        g_a = np.conj(g_a)
    matmul = lax.dot if g_a.ndim == 2 else lax.batch_matmul
    dot = partial(matmul, precision=lax.Precision.HIGHEST)

    def a_inverse(rhs):
        return triangular_solve(a, rhs, left_side, lower, transpose_a,
                                conjugate_a, unit_diagonal)

    # A triangular solve costs about as much as a matrix multiplication
    # (~n^2 FLOPs for matrix/vector inputs), so associate the product in
    # whichever order is cheaper for the given m and n.
    if left_side:
        assert (g_a.shape[-2:] == a.shape[-2:] == (m, m)
                and ans.shape[-2:] == (m, n))
        if m > n:
            return a_inverse(dot(g_a, ans))  # A^{-1} (∂A X)
        return dot(a_inverse(g_a), ans)  # (A^{-1} ∂A) X
    assert (g_a.shape[-2:] == a.shape[-2:] == (n, n)
            and ans.shape[-2:] == (m, n))
    if m < n:
        return a_inverse(dot(ans, g_a))  # (X ∂A) A^{-1}
    return dot(ans, a_inverse(g_a))  # X (∂A A^{-1})
Example #3
0
def logpdf(x, loc=0, scale=1):
    """Normal log-density: ``-(log(2*pi*scale**2) + (x - loc)**2 / scale**2) / 2``."""
    x, loc, scale = _promote_args_inexact("norm.logpdf", x, loc, scale)
    two = _constant_like(x, 2)
    variance = lax.pow(scale, two)
    two_pi = _constant_like(x, 2 * np.pi)
    log_normalizer = lax.log(lax.mul(two_pi, variance))
    squared_dev = lax.pow(lax.sub(x, loc), two)
    quadratic = lax.div(squared_dev, variance)
    total = lax.add(log_normalizer, quadratic)
    return lax.div(lax.neg(total), two)
Example #4
0
def arccosh(x):
    """Inverse hyperbolic cosine built on ``lax.acosh``.

    For complex results, values whose real part is negative are negated, so
    the returned real part is non-negative.
    """
    # arccosh is multi-valued for complex input, and lax.acosh picks a
    # different branch than np.arccosh does.
    result = lax.acosh(*_promote_args_inexact("arccosh", x))
    if not dtypes.issubdtype(result.dtype, np.complexfloating):
        return result
    return _where(real(result) < 0, lax.neg(result), result)
Example #5
0
def logpdf(x, loc=0, scale=1):
    """Cauchy log-density: ``-log(pi * scale) - log1p(((x - loc) / scale)**2)``."""
    x, loc, scale = _promote_args_inexact("cauchy.logpdf", x, loc, scale)
    pi = _constant_like(x, np.pi)
    z = lax.div(lax.sub(x, loc), scale)
    log_norm = lax.log(lax.mul(pi, scale))
    log_tail = lax.log1p(lax.mul(z, z))
    return lax.neg(lax.add(log_norm, log_tail))
Example #6
0
def cdf(x, loc=0, scale=1):
    """Laplace CDF: ``exp(z)/2`` for ``z <= 0``, else ``1 - exp(-z)/2``, with ``z = (x-loc)/scale``."""
    x, loc, scale = _promote_args_inexact("laplace.cdf", x, loc, scale)
    half = _constant_like(x, 0.5)
    one = _constant_like(x, 1)
    zero = _constant_like(x, 0)
    z = lax.div(lax.sub(x, loc), scale)
    lower_branch = lax.mul(half, lax.exp(z))
    upper_branch = lax.sub(one, lax.mul(half, lax.exp(lax.neg(z))))
    return lax.select(lax.le(z, zero), lower_branch, upper_branch)
Example #7
0
def logpdf(x, b, loc=0, scale=1):
    """Pareto log-density with shape ``b``: ``-(log(scale/b) + (b+1)*log(z))``.

    ``z = (x - loc) / scale``; points below ``loc + scale`` map to ``-inf``.
    """
    x, b, loc, scale = _promote_args_inexact("pareto.logpdf", x, b, loc, scale)
    one = _constant_like(x, 1)
    z = lax.div(lax.sub(x, loc), scale)
    log_norm = lax.log(lax.div(scale, b))
    power_term = lax.mul(lax.add(b, one), lax.log(z))
    log_probs = lax.neg(lax.add(log_norm, power_term))
    below_support = lax.lt(x, lax.add(loc, scale))
    return where(below_support, -inf, log_probs)
Example #8
0
def logpmf(k, n, a, b, loc=0):
    """JAX implementation of scipy.stats.betabinom.logpmf.

    Args:
      k: points at which to evaluate; floored before use.
      n: number of trials.
      a, b: beta shape parameters.
      loc: shift; the pmf is evaluated at ``y = floor(k) - loc``.

    Returns:
      ``log C(n, y) + log B(y + a, n - y + b) - log B(a, b)``.
      - ``-inf`` where ``y`` lies outside ``[0, n]``.
      - ``nan`` where ``n < 1``, ``a < 0`` or ``b < 0``.
    """
    k, n, a, b, loc = _promote_args_inexact("betabinom.logpmf", k, n, a, b,
                                            loc)
    y = lax.sub(lax.floor(k), loc)
    one = _lax_const(y, 1)
    zero = _lax_const(y, 0)
    # log C(n, y) = -log(n + 1) - log B(n - y + 1, y + 1)
    combiln = lax.neg(
        lax.add(lax.log1p(n),
                betaln(lax.add(lax.sub(n, y), one), lax.add(y, one))))
    beta_lns = lax.sub(betaln(lax.add(y, a), lax.add(lax.sub(n, y), b)),
                       betaln(a, b))
    log_probs = lax.add(combiln, beta_lns)
    # y has already been shifted by loc, so the support is 0 <= y <= n.
    # (Comparing y against -loc and n - loc would subtract loc twice.)
    y_cond = logical_or(lax.lt(y, zero), lax.gt(y, n))
    log_probs = where(y_cond, -inf, log_probs)
    n_a_b_cond = logical_or(logical_or(lax.lt(n, one), lax.lt(a, zero)),
                            lax.lt(b, zero))
    return where(n_a_b_cond, nan, log_probs)
Example #9
0
def triangular_solve_jvp_rule_a(g_a, ans, a, b, left_side, lower, transpose_a,
                                conjugate_a):
    """Tangent of ``triangular_solve`` with respect to the matrix ``a``.

    ``g_a`` is the tangent of ``a`` and ``ans`` the primal solution X. Returns
    ``-op(A)^{-1} op(dA) X`` for a left-side solve and ``-X op(dA) op(A)^{-1}``
    for a right-side solve, where ``op`` applies the requested transpose and
    conjugation.
    """
    # The solve only reads one triangle of ``a``, so tangent entries in the
    # other triangle must not contribute.
    g_a = np.tril(g_a) if lower else np.triu(g_a)
    g_a = lax.neg(g_a)
    g_a = np.swapaxes(g_a, -1, -2) if transpose_a else g_a
    # op(A) includes the conjugation, so op(dA) must include it as well.
    g_a = np.conj(g_a) if conjugate_a else g_a
    tmp = triangular_solve(a, g_a, left_side, lower, transpose_a, conjugate_a)
    dot = lax.dot if g_a.ndim == 2 else lax.batch_matmul
    if left_side:
        return dot(tmp, ans)
    else:
        return dot(ans, tmp)
Example #10
0
def logaddexp(x1, x2):
  """Compute ``log(exp(x1) + exp(x2))`` by shifting by the larger argument."""
  x1, x2 = _promote_args_inexact("logaddexp", x1, x2)
  amax = lax.max(x1, x2)
  if not dtypes.issubdtype(x1.dtype, np.floating):
    # Complex inputs: shift by amax, then wrap the imaginary part with
    # _wrap_between(..., pi).
    shifted = lax.sub(lax.add(x1, x2), lax.mul(amax, _constant_like(amax, 2)))
    out = lax.add(amax, lax.log1p(lax.exp(shifted)))
    return lax.complex(lax.real(out), _wrap_between(lax.imag(out), np.pi))
  delta = lax.sub(x1, x2)
  # exp(-|delta|) <= 1, so this form cannot overflow.
  stable = lax.add(amax, lax.log1p(lax.exp(lax.neg(lax.abs(delta)))))
  # delta is NaN for NaN inputs or same-sign infinities; x1 + x2 handles both.
  return lax.select(lax_internal._isnan(delta), lax.add(x1, x2), stable)
Example #11
0
def logpdf(x, df, loc=0, scale=1):
  """Student's t log-density with ``df`` degrees of freedom, location and scale."""
  x, df, loc, scale = _promote_args_inexact("t.logpdf", x, df, loc, scale)
  two = _lax_const(x, 2)
  z = lax.div(lax.sub(x, loc), scale)
  half_df = lax.div(df, two)
  half_df_plus_half = lax.add(half_df, _lax_const(x, 0.5))
  # log(pi * df * scale**2) / 2
  pi_scale_sq = lax.mul(lax.mul(scale, scale), _lax_const(x, np.pi))
  half_log_term = lax.div(lax.log(lax.mul(pi_scale_sq, df)), two)
  # lgamma(df/2) + log(pi*df*scale**2)/2 - lgamma((df+1)/2)
  normalize_term = lax.sub(lax.add(lax.lgamma(half_df), half_log_term),
                           lax.lgamma(half_df_plus_half))
  quadratic = lax.div(lax.mul(z, z), df)
  tail = lax.mul(half_df_plus_half, lax.log1p(quadratic))
  return lax.neg(lax.add(normalize_term, tail))
Example #12
0
def logpdf(x, df, loc=0, scale=1):
    """Chi-squared log-density with ``df`` degrees of freedom; ``-inf`` below ``loc``."""
    x, df, loc, scale = _promote_args_inexact("chi2.logpdf", x, df, loc, scale)
    one = _constant_like(x, 1)
    two = _constant_like(x, 2)
    y = lax.div(lax.sub(x, loc), scale)
    half_df = lax.div(df, two)

    # (df/2 - 1) * log(y) - y/2
    log_kernel = lax.sub(lax.mul(lax.sub(half_df, one), lax.log(y)),
                         lax.div(y, two))

    # -(lgamma(df/2) + log(2) * df / 2)
    log_norm = lax.neg(lax.add(lax.lgamma(half_df),
                               lax.div(lax.mul(lax.log(two), df), two)))

    log_probs = lax.add(lax.sub(log_norm, lax.log(scale)), log_kernel)
    return where(lax.lt(x, loc), -inf, log_probs)
Example #13
0
def triangular_solve_jvp_rule_a(g_a, ans, a, b, left_side, lower, transpose_a,
                                conjugate_a, unit_diagonal):
    """Tangent of ``triangular_solve`` with respect to ``a``.

    Computes ``-op(A)^{-1} op(dA) X`` for a left-side solve or
    ``-X op(dA) op(A)^{-1}`` for a right-side solve, where X is ``ans`` and
    ``op`` applies the requested transpose/conjugation.
    """
    diag_offset = 1 if unit_diagonal else 0
    if lower:
        masked = np.tril(g_a, k=-diag_offset)
    else:
        masked = np.triu(g_a, k=diag_offset)
    rhs = lax.neg(masked)
    if transpose_a:
        rhs = np.swapaxes(rhs, -1, -2)
    if conjugate_a:
        rhs = np.conj(rhs)
    solved = triangular_solve(a, rhs, left_side, lower, transpose_a,
                              conjugate_a, unit_diagonal)
    matmul = lax.dot if rhs.ndim == 2 else lax.batch_matmul
    return matmul(solved, ans) if left_side else matmul(ans, solved)
Example #14
0
File: poisson.py Project: lmmx/mcx
def _random_poisson(rng_key, lmbda, shape):
    """Draw Poisson(``lmbda``) samples of the given ``shape`` (Knuth's method).

    For each sample, fresh uniforms are multiplied into a running product
    until it falls below ``exp(-lmbda)``. The number of factors used, minus
    one, is the Poisson variate.

    References
    ----------
    .. [1] Knuth, Donald E. Art of computer programming, volume 2:
           Seminumerical algorithms. Addison-Wesley Professional, 2014 (p 137).
    """
    L = lax.exp(lax.neg(lmbda))
    k = np.zeros(shape=shape)
    p = np.ones(shape=shape)

    is_done = p < L
    while not is_done.all():
        # Use a fresh subkey for the draw and keep the other half for future
        # splits, so no key is both consumed and split again.
        rng_key, subkey = random.split(rng_key)
        u = random.uniform(subkey, shape=shape)
        p = np.where(is_done, p, u * p)
        k = np.where(is_done, k, k + 1)
        is_done = p < L

    # k counts the uniform factors, which is one more than the Poisson
    # variate (Knuth's algorithm returns k - 1).
    return k - 1
Example #15
0
File: jet.py Project: 0x0is1/jax
    # NOTE(review): the enclosing ``def`` line of this body is missing from the
    # excerpt; it presumably has the signature
    # ``(prim, deriv, primals_in, series_in)`` and binds ``x, = primals_in`` —
    # confirm against the original file.
    series, = series_in
    # Primal output of the primitive itself.
    primal_out = prim.bind(x)
    # Taylor coefficients of the derivative function evaluated along the input.
    c0, cs = jet(deriv, primals_in, series_in)
    c = [c0] + cs
    u = [x] + series
    v = [primal_out] + [None] * len(series)
    # Each output coefficient v[k] combines the derivative's coefficients (c)
    # with the input's coefficients (u) through _scale and fact.
    for k in range(1, len(v)):
        v[k] = fact(k - 1) * sum(
            _scale(k, j) * c[k - j] * u[j] for j in range(1, k + 1))
    primal_out, *series_out = v
    return primal_out, series_out


# erf'(x) = 2 / sqrt(pi) * exp(-x**2)
def_deriv(lax.erf_p,
          lambda t: lax.mul(lax._const(t, 2. / np.sqrt(np.pi)),
                            lax.exp(lax.neg(lax.square(t)))))


def def_comp(prim, comp):
    """Register a jet rule for ``prim`` that runs ``jet`` on the composite ``comp``.

    ``comp`` expresses ``prim`` in terms of simpler primitives.
    """
    rule = partial(jet, comp)
    jet_rules[prim] = rule


def _expm1_taylor(primals_in, series_in):
    """Jet rule for ``expm1``: series of ``exp(x) - 1`` with an accurate primal.

    ``exp(x) - 1`` and ``expm1(x)`` have identical derivatives, so their series
    terms agree. The primal ``exp(x) - 1``, however, cancels catastrophically
    for small ``|x|``, so it is recomputed with ``lax.expm1``.
    """
    _, series_out = jet(lambda x: lax.exp(x) - 1, primals_in, series_in)
    return lax.expm1(*primals_in), series_out


def _log1p_taylor(primals_in, series_in):
    """Jet rule for ``log1p``: series of ``log(1 + x)`` with an accurate primal.

    Forming ``1 + x`` rounds away the low bits of small ``x``, so the primal is
    recomputed with ``lax.log1p``.
    """
    _, series_out = jet(lambda x: lax.log(1 + x), primals_in, series_in)
    return lax.log1p(*primals_in), series_out


jet_rules[lax.expm1_p] = _expm1_taylor
jet_rules[lax.log1p_p] = _log1p_taylor
def_comp(lax.sqrt_p, lambda x: x**0.5)
def_comp(lax.rsqrt_p, lambda x: x**-0.5)
def_comp(lax.asinh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) + 1)))
def_comp(lax.acosh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) - 1)))
Example #16
0
def eager_unary(state):
    """Benchmark eager ``lax.neg`` on a device scalar, blocking on every result."""
    operand = jax.device_put(1)
    # One call before the loop; presumably keeps one-time setup out of the
    # measured iterations.
    lax.neg(operand).block_until_ready()
    while state:
        lax.neg(operand).block_until_ready()
Example #17
0
def eager_unary_dispatch(state):
    """Benchmark eager ``lax.neg`` dispatch without waiting for results."""
    operand = jax.device_put(1)
    # One call before the loop; presumably keeps one-time setup out of the
    # measured iterations.
    lax.neg(operand)
    while state:
        lax.neg(operand)
Example #18
0
def expit(x):
    """Logistic sigmoid ``1 / (1 + exp(-x))``."""
    x = asarray(x)
    one = lax._const(x, 1)
    denominator = lax.add(one, lax.exp(lax.neg(x)))
    return lax.div(one, denominator)
Example #19
0
def expn_jvp(n, primals, tangents):
    """JVP of ``expn(n, x)`` in ``x``: the tangent is ``-x_dot * expn(n - 1, x)``."""
    (x,), (x_dot,) = primals, tangents
    primal_out = expn(n, x)
    lower_order = expn(lax.sub(n, _lax_const(n, 1)), x)
    tangent_out = lax.mul(lax.neg(x_dot), lower_order)
    return primal_out, tangent_out
Example #20
0
def logpdf(x):
    """Standard logistic log-density, computed as ``-2 * logaddexp(x/2, -x/2)``."""
    x, = _promote_args_inexact("logistic.logpdf", x)
    two = _lax_const(x, 2)
    half_x = lax.div(x, two)
    log_sum = jnp.logaddexp(half_x, lax.neg(half_x))
    return lax.mul(lax.neg(two), log_sum)
Example #21
0
def logpdf(x, loc=0, scale=1):
    """Exponential log-density ``-((x - loc)/scale + log(scale))``; ``-inf`` below ``loc``."""
    x, loc, scale = _promote_args_inexact("expon.logpdf", x, loc, scale)
    log_scale = lax.log(scale)
    standardized = lax.div(lax.sub(x, loc), scale)
    log_probs = lax.neg(lax.add(standardized, log_scale))
    below_support = lax.lt(x, loc)
    return where(below_support, -inf, log_probs)
Example #22
0
def sf(x):
    """Return ``expit(-x)`` (presumably the logistic survival function)."""
    negated = lax.neg(x)
    return expit(negated)
Example #23
0
def sinh(x):
    """Hyperbolic sine, ``(exp(x) - exp(-x)) / 2``.

    Uses ``expm1(x) - expm1(-x)``, which equals ``exp(x) - exp(-x)``. The
    direct difference of exponentials cancels catastrophically near zero: in
    float32 it rounds to 0 for ``x = 1e-8``. The ``expm1`` form keeps full
    relative precision there and matches the old behaviour for large ``|x|``.
    """
    x, = _promote_to_result_dtype(onp.sinh, x)
    return lax.div(lax.sub(lax.expm1(x), lax.expm1(lax.neg(x))),
                   _constant_like(x, 2))
Example #24
0
def logpdf(x, loc=0, scale=1):
    """Laplace log-density ``-(|x - loc| / scale + log(2 * scale))``."""
    x, loc, scale = _promote_args_inexact("laplace.logpdf", x, loc, scale)
    two = _constant_like(x, 2)
    abs_dev = lax.abs(lax.sub(x, loc))
    linear_term = lax.div(abs_dev, scale)
    log_norm = lax.log(lax.mul(two, scale))
    return lax.neg(lax.add(linear_term, log_norm))
Example #25
0
 def fun(x, y, z):
   """Return ``y`` when ``x < 3`` and ``-z`` otherwise, via per-branch-operand ``lax.cond``."""
   def take_y(operand):
     return operand

   def negate_z(operand):
     return lax.neg(operand)

   return lax.cond(lax.lt(x, 3), y, take_y, z, negate_z)
Example #26
0
def cosh(x):
    """Hyperbolic cosine, ``(exp(x) + exp(-x)) / 2``."""
    x, = _promote_to_result_dtype(onp.cosh, x)
    exp_pos = lax.exp(x)
    exp_neg = lax.exp(lax.neg(x))
    return lax.div(lax.add(exp_pos, exp_neg), _constant_like(x, 2))
Example #27
0
def entr(x):
    """Elementwise entropy ``-xlogy(x, x)``; ``-inf`` wherever ``x < 0``."""
    x, = _promote_args_inexact("entr", x)
    negative = lax.lt(x, _constant_like(x, 0))
    neg_inf = lax.full_like(x, -np.inf)
    return lax.select(negative, neg_inf, lax.neg(xlogy(x, x)))
Example #28
0
def expit(x):
    """Logistic sigmoid ``1 / (1 + exp(-x))`` after promotion to an inexact dtype."""
    x, = _promote_args_inexact("expit", x)
    one = _lax_const(x, 1)
    denominator = lax.add(one, lax.exp(lax.neg(x)))
    return lax.div(one, denominator)
Example #29
0
def deriv_prop(prim, deriv, primals_in, series_in):
  """Jet rule for a unary primitive ``prim`` whose derivative is ``deriv``.

  The series of ``deriv`` along the input comes from ``jet``. Each output term
  is ``fact(k-1) * sum_{j=1..k} _scale(k, j) * c[k-j] * u[j]``, where ``c``
  holds the derivative's terms and ``u`` the input's terms.
  """
  x, = primals_in
  series, = series_in
  primal_out = prim.bind(x)
  c0, cs = jet(deriv, primals_in, series_in)
  deriv_terms = [c0] + cs
  input_terms = [x] + series
  series_out = [
      fact(k - 1) * sum(_scale(k, j) * deriv_terms[k - j] * input_terms[j]
                        for j in range(1, k + 1))
      for k in range(1, len(series) + 1)
  ]
  return primal_out, series_out


# erf'(x) = 2 / sqrt(pi) * exp(-x**2)
def_deriv(lax.erf_p,
          lambda t: lax.mul(lax._const(t, 2. / np.sqrt(np.pi)),
                            lax.exp(lax.neg(lax.square(t)))))


def def_comp(prim, comp):
  """Register a jet rule for ``prim`` that runs ``jet`` on the composite ``comp``.

  ``comp`` expresses ``prim`` in terms of simpler primitives.
  """
  rule = partial(jet, comp)
  jet_rules[prim] = rule


def _expm1_taylor(primals_in, series_in):
  """Jet rule for ``expm1``: series of ``exp(x) - 1`` with an accurate primal.

  ``exp(x) - 1`` and ``expm1(x)`` have identical derivatives, so their series
  terms agree. The primal ``exp(x) - 1``, however, cancels catastrophically for
  small ``|x|``, so it is recomputed with ``lax.expm1``.
  """
  _, series_out = jet(lambda x: lax.exp(x) - 1, primals_in, series_in)
  return lax.expm1(*primals_in), series_out


def _log1p_taylor(primals_in, series_in):
  """Jet rule for ``log1p``: series of ``log(1 + x)`` with an accurate primal.

  Forming ``1 + x`` rounds away the low bits of small ``x``, so the primal is
  recomputed with ``lax.log1p``.
  """
  _, series_out = jet(lambda x: lax.log(1 + x), primals_in, series_in)
  return lax.log1p(*primals_in), series_out


jet_rules[lax.expm1_p] = _expm1_taylor
jet_rules[lax.log1p_p] = _log1p_taylor
def_comp(lax.sqrt_p, lambda x: x ** 0.5)
def_comp(lax.rsqrt_p, lambda x: x ** -0.5)
def_comp(lax.asinh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) + 1)))
def_comp(lax.acosh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) - 1)))
Example #30
0
File: poisson.py Project: lmmx/mcx
 def logpdf(self, x):
     """Log-probability of count ``x`` under Poisson(``self.lmbda``).

     Computes ``x * log(lmbda) - lmbda - log(x!)``, with ``log(x!)``
     evaluated as ``lgamma(x + 1)``.
     """
     # Promote integer counts to floating point before the lax ops.
     x = x * 1.0
     # xlogy(x, y) already takes the log of y, so pass lmbda itself; passing
     # log(lmbda) would compute x * log(log(lmbda)).
     log_rate_term = xlogy(x, self.lmbda)
     # log(x!) == lgamma(x + 1); lgamma(x) would be log((x - 1)!).
     log_factorial = lax.lgamma(x + 1.0)
     return lax.add(
         log_rate_term,
         lax.add(lax.neg(self.lmbda), lax.neg(log_factorial)),
     )