Beispiel #1
0
def logpmf(k, p, loc=0):
    """Log-probability mass of the geometric distribution.

    Args:
      k: evaluation point(s).
      p: success probability.
      loc: shift subtracted from ``k`` before evaluation.

    Returns:
      ``xlog1py(k - loc - 1, -p) + log(p)``, replaced by ``-inf`` wherever
      ``k - loc <= 0``.
    """
    k, p, loc = jnp._promote_args_inexact("geom.logpmf", k, p, loc)
    shifted = lax.sub(k, loc)
    failures = lax.sub(shifted, _lax_const(k, 1))
    log_mass = xlog1py(failures, -p) + lax.log(p)
    # Shifted values at or below zero are assigned zero probability.
    off_support = lax.le(shifted, _lax_const(k, 0))
    return jnp.where(off_support, -jnp.inf, log_mass)
Beispiel #2
0
def logpmf(k, p, loc=0):
    """Log-probability mass of the Bernoulli distribution.

    Args:
      k: evaluation point(s).
      p: success probability.
      loc: shift subtracted from ``k`` before evaluation.

    Returns:
      ``xlogy(x, p) + xlog1py(1 - x, -p)`` with ``x = k - loc``, replaced by
      ``-inf`` wherever ``x`` falls outside ``[0, 1]``.
    """
    k, p, loc = jnp._promote_args_inexact("bernoulli.logpmf", k, p, loc)
    lower = _lax_const(k, 0)
    upper = _lax_const(k, 1)
    shifted = lax.sub(k, loc)
    success_term = xlogy(shifted, p)
    failure_term = xlog1py(lax.sub(upper, shifted), -p)
    out_of_range = jnp.logical_or(lax.lt(shifted, lower),
                                  lax.gt(shifted, upper))
    return jnp.where(out_of_range, -jnp.inf, success_term + failure_term)
Beispiel #3
0
def cdf(x, loc=0, scale=1):
    """Cumulative distribution function of the Laplace distribution.

    Args:
      x: evaluation point(s).
      loc: location parameter.
      scale: scale parameter.

    Returns:
      ``0.5 * exp(z)`` where ``z <= 0`` and ``1 - 0.5 * exp(-z)`` elsewhere,
      with ``z = (x - loc) / scale``.
    """
    x, loc, scale = _promote_args_inexact("laplace.cdf", x, loc, scale)
    z = lax.div(lax.sub(x, loc), scale)
    half = _lax_const(x, 0.5)
    # Both tails are computed and the matching one is selected elementwise.
    lower_tail = lax.mul(half, lax.exp(z))
    upper_tail = lax.sub(_lax_const(x, 1),
                         lax.mul(half, lax.exp(lax.neg(z))))
    return lax.select(lax.le(z, _lax_const(x, 0)), lower_tail, upper_tail)
Beispiel #4
0
def logpdf(x, df, loc=0, scale=1):
  """Log-density of Student's t distribution.

  Args:
    x: evaluation point(s).
    df: degrees of freedom.
    loc: location parameter.
    scale: scale parameter.

  Returns:
    ``-(lgamma(df/2) + 0.5*log(pi*scale**2*df) - lgamma(df/2 + 0.5)
    + (df/2 + 0.5) * log1p(z**2 / df))`` with ``z = (x - loc) / scale``.
  """
  x, df, loc, scale = _promote_args_inexact("t.logpdf", x, df, loc, scale)
  two = _lax_const(x, 2)
  half_df = lax.div(df, two)
  half_df_plus_half = lax.add(half_df, _lax_const(x, 0.5))
  z = lax.div(lax.sub(x, loc), scale)

  # 0.5 * log(pi * scale**2 * df)
  pi_scale_sq = lax.mul(lax.mul(scale, scale), _lax_const(x, np.pi))
  half_log_term = lax.div(lax.log(lax.mul(pi_scale_sq, df)), two)
  log_norm = lax.sub(lax.add(lax.lgamma(half_df), half_log_term),
                     lax.lgamma(half_df_plus_half))

  # (df/2 + 0.5) * log1p(z**2 / df)
  tail = lax.mul(half_df_plus_half,
                 lax.log1p(lax.div(lax.mul(z, z), df)))
  return lax.neg(lax.add(log_norm, tail))
Beispiel #5
0
def multigammaln(a, d):
    """Sum of ``gammaln(a - j/2)`` for ``j = 0..d-1`` plus a ``log(pi)`` term.

    Args:
      a: evaluation point(s).
      d: dimension; must be convertible to a concrete ``int``.

    Returns:
      ``sum_j gammaln(a - j/2) + 0.25 * d * (d - 1) * log(pi)``, reduced over
      a trailing axis that is added for the ``j`` index.
    """
    d = core.concrete_or_error(int, d, "d argument of multigammaln")
    a, d_ = _promote_args_inexact("multigammaln", a, d)

    # 0.25 * d * (d - 1) * log(pi)
    prefactor = lax.mul(lax.mul(_lax_const(a, 0.25), d_),
                        lax.sub(d_, _lax_const(a, 1)))
    constant = lax.mul(prefactor, lax.log(_lax_const(a, np.pi)))

    # Offsets 0, 1/2, 1, ... laid out on a new last axis so they broadcast
    # against every element of ``a``.
    half_steps = lax.div(jnp.arange(d, dtype=d_.dtype), _lax_const(a, 2))
    leading_axes = tuple(range(a.ndim))
    shifted = (jnp.expand_dims(a, axis=-1) -
               jnp.expand_dims(half_steps, axis=leading_axes))
    return jnp.sum(gammaln(shifted), axis=-1) + constant
Beispiel #6
0
def logpdf(x, df, loc=0, scale=1):
    """Log-density of the chi-squared distribution.

    Args:
      x: evaluation point(s).
      df: degrees of freedom.
      loc: location parameter.
      scale: scale parameter.

    Returns:
      The log-density at ``(x - loc) / scale`` minus ``log(scale)``, or
      ``-inf`` wherever ``x < loc``.
    """
    x, df, loc, scale = _promote_args_inexact("chi2.logpdf", x, df, loc, scale)
    two = _lax_const(x, 2)
    half_df = lax.div(df, two)
    z = lax.div(lax.sub(x, loc), scale)

    # -(lgamma(df / 2) + df * log(2) / 2)
    log_norm = lax.neg(
        lax.add(lax.lgamma(half_df),
                lax.div(lax.mul(lax.log(two), df), two)))

    # (df / 2 - 1) * log(z) - z / 2
    power_term = lax.mul(lax.sub(half_df, _lax_const(x, 1)), lax.log(z))
    log_kernel = lax.sub(power_term, lax.div(z, two))

    log_density = lax.add(lax.sub(log_norm, lax.log(scale)), log_kernel)
    return where(lax.lt(x, loc), -inf, log_density)
Beispiel #7
0
def logpdf(x, loc=0, scale=1):
    """Log-density of the Cauchy distribution.

    Args:
      x: evaluation point(s).
      loc: location parameter.
      scale: scale parameter.

    Returns:
      ``-(log(pi * scale) + log1p(z**2))`` with ``z = (x - loc) / scale``.
    """
    x, loc, scale = _promote_args_inexact("cauchy.logpdf", x, loc, scale)
    z = lax.div(lax.sub(x, loc), scale)
    log_norm = lax.log(lax.mul(_lax_const(x, np.pi), scale))
    tail = lax.log1p(lax.mul(z, z))
    return lax.neg(lax.add(log_norm, tail))
Beispiel #8
0
def logpdf(x, b, loc=0, scale=1):
    """Log-density of the Pareto distribution.

    Args:
      x: evaluation point(s).
      b: shape parameter.
      loc: location parameter.
      scale: scale parameter.

    Returns:
      ``-(log(scale / b) + (b + 1) * log(z))`` with ``z = (x - loc) / scale``,
      or ``-inf`` wherever ``x < loc + scale``.
    """
    x, b, loc, scale = _promote_args_inexact("pareto.logpdf", x, b, loc, scale)
    z = lax.div(lax.sub(x, loc), scale)
    log_norm = lax.log(lax.div(scale, b))
    power_term = lax.mul(lax.add(b, _lax_const(x, 1)), lax.log(z))
    log_density = lax.neg(lax.add(log_norm, power_term))
    below_support = lax.lt(x, lax.add(loc, scale))
    return where(below_support, -inf, log_density)
Beispiel #9
0
def _eval_expint_k(A, B, x):
    # Helper function for all subsequent intervals.
    # Evaluates exp(x) * w * (1 + w * P_A(w) / P_B(w)) with w = 1 / x, where
    # P_A and P_B are the polynomials whose coefficients are A and B.
    numer_coeffs = jnp.array(A, dtype=x.dtype)
    denom_coeffs = jnp.array(B, dtype=x.dtype)
    one = _lax_const(x, 1.0)
    inv_x = one / x
    ratio = jnp.polyval(numer_coeffs, inv_x) / jnp.polyval(denom_coeffs, inv_x)
    correction = inv_x * ratio + one
    return jnp.exp(x) * inv_x * correction
Beispiel #10
0
def logpdf(x, a, loc=0, scale=1):
    """Log-density of the gamma distribution.

    Args:
      x: evaluation point(s).
      a: shape parameter.
      loc: location parameter.
      scale: scale parameter.

    Returns:
      ``xlogy(a - 1, z) - z - gammaln(a) - log(scale)`` with
      ``z = (x - loc) / scale``, or ``-inf`` wherever ``x < loc``.
    """
    x, a, loc, scale = _promote_args_inexact("gamma.logpdf", x, a, loc, scale)
    z = lax.div(lax.sub(x, loc), scale)
    # xlogy(a - 1, z) - z
    kernel = lax.sub(xlogy(lax.sub(a, _lax_const(x, 1)), z), z)
    # gammaln(a) + log(scale)
    log_norm = lax.add(gammaln(a), lax.log(scale))
    log_density = lax.sub(kernel, log_norm)
    return where(lax.lt(x, loc), -inf, log_density)
Beispiel #11
0
def logpmf(k, n, a, b, loc=0):
    """JAX implementation of scipy.stats.betabinom.logpmf.

    Args:
      k: evaluation point(s); floored before use.
      n: number of trials.
      a: first beta shape parameter.
      b: second beta shape parameter.
      loc: shift; the mass function is evaluated at ``y = floor(k) - loc``.

    Returns:
      The log-probability at ``y``; ``-inf`` wherever ``y`` lies outside
      ``[0, n]``; ``nan`` wherever ``n < 1``, ``a < 0`` or ``b < 0``.
    """
    k, n, a, b, loc = _promote_args_inexact("betabinom.logpmf", k, n, a, b,
                                            loc)
    y = lax.sub(lax.floor(k), loc)
    one = _lax_const(y, 1)
    zero = _lax_const(y, 0)
    # -(log1p(n) + betaln(n - y + 1, y + 1)): the log binomial coefficient.
    combiln = lax.neg(
        lax.add(lax.log1p(n),
                betaln(lax.add(lax.sub(n, y), one), lax.add(y, one))))
    # betaln(y + a, n - y + b) - betaln(a, b)
    beta_lns = lax.sub(betaln(lax.add(y, a), lax.add(lax.sub(n, y), b)),
                       betaln(a, b))
    log_probs = lax.add(combiln, beta_lns)
    # ``y`` already has ``loc`` removed, so the support is 0 <= y <= n.
    # Comparing against ``-loc`` / ``n - loc`` would apply the shift twice and
    # reject valid points whenever ``loc`` is non-zero.
    y_cond = logical_or(lax.lt(y, zero), lax.gt(y, n))
    log_probs = where(y_cond, -inf, log_probs)
    # Invalid parameters override everything else with nan.
    n_a_b_cond = logical_or(logical_or(lax.lt(n, one), lax.lt(a, zero)),
                            lax.lt(b, zero))
    return where(n_a_b_cond, nan, log_probs)
Beispiel #12
0
def logpmf(k, n, p, loc=0):
    """JAX implementation of scipy.stats.nbinom.logpmf.

    Args:
      k: evaluation point(s).
      n: number-of-successes parameter.
      p: success probability.
      loc: shift subtracted from ``k`` before evaluation.

    Returns:
      The log-probability at ``y = k - loc``, or ``-inf`` wherever
      ``k < loc``.
    """
    k, n, p, loc = _promote_args_inexact("nbinom.logpmf", k, n, p, loc)
    one = _lax_const(k, 1)
    shifted = lax.sub(k, loc)
    # gammaln(y + n) - gammaln(n) - gammaln(y + 1)
    log_coeff = lax.sub(lax.sub(gammaln(lax.add(shifted, n)), gammaln(n)),
                        gammaln(lax.add(shifted, one)))
    # xlogy(n, p) + xlogy(y, 1 - p)
    log_powers = lax.add(xlogy(n, p), xlogy(shifted, lax.sub(one, p)))
    log_mass = lax.add(log_coeff, log_powers)
    return where(lax.lt(k, loc), -inf, log_mass)
Beispiel #13
0
def logpdf(x, a, b, loc=0, scale=1):
    """Log-density of the beta distribution.

    Args:
      x: evaluation point(s).
      a: first shape parameter.
      b: second shape parameter.
      loc: location parameter.
      scale: scale parameter.

    Returns:
      ``-betaln(a, b) + xlogy(a - 1, y) + xlog1py(b - 1, -y) - log(scale)``
      with ``y = (x - loc) / scale``, or ``-inf`` wherever ``x`` lies outside
      ``[loc, loc + scale]``.
    """
    x, a, b, loc, scale = _promote_args_inexact("beta.logpdf", x, a, b, loc,
                                                scale)
    one = _lax_const(x, 1)
    y = lax.div(lax.sub(x, loc), scale)
    neg_log_beta = lax.neg(betaln(a, b))
    left_power = xlogy(lax.sub(a, one), y)
    right_power = xlog1py(lax.sub(b, one), lax.neg(y))
    kernel = lax.add(left_power, right_power)
    log_density = lax.sub(lax.add(neg_log_beta, kernel), lax.log(scale))
    outside = logical_or(lax.lt(x, loc), lax.gt(x, lax.add(loc, scale)))
    return where(outside, -inf, log_density)
Beispiel #14
0
def logpdf(x, alpha):
    """Log-density of the Dirichlet distribution.

    Args:
      x: points to evaluate; the leading axis indexes the components and must
        have as many entries as ``alpha`` or one fewer (the missing entry is
        then filled in as ``1 - sum(x)``).
      alpha: one-dimensional concentration parameters.

    Returns:
      ``sum(xlogy(alpha - 1, x)) - (sum(gammaln(alpha)) - gammaln(sum(alpha)))``
      reduced over the leading axis, or ``-inf`` wherever ``_is_simplex(x)``
      is false.

    Raises:
      ValueError: if ``alpha`` is not one-dimensional or the leading
        dimension of ``x`` does not match as described above.
    """
    x, alpha = _promote_dtypes_inexact(x, alpha)
    if alpha.ndim != 1:
        raise ValueError(
            f"`alpha` must be one-dimensional; got alpha.shape={alpha.shape}")
    n_components = alpha.shape[0]
    n_given = x.shape[0]
    if n_given != n_components and n_given != n_components - 1:
        raise ValueError(
            "`x` must have either the same number of entries as `alpha` "
            f"or one entry fewer; got x.shape={x.shape}, alpha.shape={alpha.shape}"
        )
    one = _lax_const(x, 1)
    if n_given != n_components:
        # Complete the point with the implied last coordinate.
        remainder = lax.sub(one, x.sum(0, keepdims=True))
        x = jnp.concatenate([x, remainder], axis=0)
    log_norm = jnp.sum(gammaln(alpha)) - gammaln(jnp.sum(alpha))
    if x.ndim > 1:
        # Append singleton axes so alpha broadcasts along x's trailing axes.
        trailing = (1, ) * (x.ndim - 1)
        alpha = lax.broadcast_in_dim(alpha, alpha.shape + trailing, (0, ))
    weighted = xlogy(lax.sub(alpha, one), x)
    log_probs = lax.sub(jnp.sum(weighted, axis=0), log_norm)
    return jnp.where(_is_simplex(x), log_probs, -jnp.inf)
Beispiel #15
0
def expit(x):
    """Logistic sigmoid ``1 / (1 + exp(-x))`` after inexact-dtype promotion."""
    (x, ) = _promote_args_inexact("expit", x)
    unit = _lax_const(x, 1)
    denominator = lax.add(unit, lax.exp(lax.neg(x)))
    return lax.div(unit, denominator)
Beispiel #16
0
def _norm_logpdf(x):
    """Return ``-0.5 * x**2 - _norm_logpdf_constant``."""
    quadratic = lax.mul(_lax_const(x, -0.5), lax.square(x))
    return lax.sub(quadratic, _lax_const(x, _norm_logpdf_constant))
Beispiel #17
0
def expn_jvp(n, primals, tangents):
    """JVP rule for ``expn`` with respect to its ``x`` argument.

    Args:
      n: order passed through to ``expn``; it is not part of ``primals``.
      primals: one-element tuple holding ``x``.
      tangents: one-element tuple holding the tangent of ``x``.

    Returns:
      ``(expn(n, x), -x_dot * expn(n - 1, x))``.
    """
    (x, ) = primals
    (x_dot, ) = tangents
    primal_out = expn(n, x)
    lower_order = lax.sub(n, _lax_const(n, 1))
    tangent_out = lax.mul(lax.neg(x_dot), expn(lower_order, x))
    return primal_out, tangent_out
Beispiel #18
0
def entr(x):
    """Elementwise ``-xlogy(x, x)``, with ``-inf`` for negative inputs."""
    (x, ) = _promote_args_inexact("entr", x)
    is_negative = lax.lt(x, _lax_const(x, 0))
    neg_inf = lax.full_like(x, -np.inf)
    entropy = lax.neg(xlogy(x, x))
    return lax.select(is_negative, neg_inf, entropy)
Beispiel #19
0
def expit(x):
    """Logistic sigmoid ``1 / (1 + exp(-x))`` of ``asarray(x)``."""
    x = asarray(x)
    unit = _lax_const(x, 1)
    denominator = lax.add(unit, lax.exp(lax.neg(x)))
    return lax.div(unit, denominator)
Beispiel #20
0
def logit(x):
    """Log-odds ``log(x / (1 - x))`` after inexact-dtype promotion."""
    (x, ) = _promote_args_inexact("logit", x)
    complement = lax.sub(_lax_const(x, 1), x)
    odds = lax.div(x, complement)
    return lax.log(odds)
Beispiel #21
0
@_wraps(osp_special.erfinv)
def erfinv(x):
    # Promote to an inexact dtype, then defer to the lax primitive.
    (promoted, ) = _promote_args_inexact("erfinv", x)
    return lax.erf_inv(promoted)


@api.custom_jvp
@_wraps(osp_special.logit, update_doc=False)
def logit(x):
    # log(x / (1 - x)) computed on asarray(x).
    x = asarray(x)
    complement = lax.sub(_lax_const(x, 1), x)
    odds = lax.div(x, complement)
    return lax.log(odds)


# Tangent rule for logit: incoming tangent divided by x * (1 - x).
logit.defjvps(
    lambda tangent, _out, primal: lax.div(
        tangent, lax.mul(primal, lax.sub(_lax_const(primal, 1), primal))))


@api.custom_jvp
@_wraps(osp_special.expit, update_doc=False)
def expit(x):
    # 1 / (1 + exp(-x)) computed on asarray(x).
    x = asarray(x)
    unit = _lax_const(x, 1)
    denominator = lax.add(unit, lax.exp(lax.neg(x)))
    return lax.div(unit, denominator)


# Tangent rule for expit: incoming tangent times out * (1 - out).
expit.defjvps(
    lambda tangent, out, _primal: tangent * out * (_lax_const(out, 1) - out))


@_wraps(osp_special.logsumexp)
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
Beispiel #22
0
def logpdf(x):
    """Log-density of the standard logistic distribution.

    Computed as ``-2 * logaddexp(x / 2, -x / 2)``.
    """
    (x, ) = _promote_args_inexact("logistic.logpdf", x)
    two = _lax_const(x, 2)
    half = lax.div(x, two)
    symmetric_lse = jnp.logaddexp(half, lax.neg(half))
    return lax.mul(lax.neg(two), symmetric_lse)
Beispiel #23
0
def logit(x):
    """Log-odds ``log(x / (1 - x))`` of ``asarray(x)``."""
    x = asarray(x)
    complement = lax.sub(_lax_const(x, 1), x)
    odds = lax.div(x, complement)
    return lax.log(odds)
Beispiel #24
0
def cdf(k, mu, loc=0):
  """Cumulative distribution function of the Poisson distribution.

  Args:
    k: evaluation point(s).
    mu: rate parameter.
    loc: shift subtracted from ``k`` before evaluation.

  Returns:
    ``gammaincc(floor(1 + x), mu)`` with ``x = k - loc``, or zero wherever
    ``x < 0``.
  """
  # The label names this function; it is presumably what the promotion helper
  # reports in its messages, so it must not read "poisson.logpmf".
  k, mu, loc = jnp._promote_args_inexact("poisson.cdf", k, mu, loc)
  zero = _lax_const(k, 0)
  x = lax.sub(k, loc)
  p = gammaincc(jnp.floor(1 + x), mu)
  return jnp.where(lax.lt(x, zero), zero, p)
Beispiel #25
0
def logpdf(x, loc=0, scale=1):
    """Log-density of the Laplace distribution.

    Args:
      x: evaluation point(s).
      loc: location parameter.
      scale: scale parameter.

    Returns:
      ``-(|x - loc| / scale + log(2 * scale))``.
    """
    x, loc, scale = _promote_args_inexact("laplace.logpdf", x, loc, scale)
    abs_deviation = lax.abs(lax.sub(x, loc))
    scaled_deviation = lax.div(abs_deviation, scale)
    log_norm = lax.log(lax.mul(_lax_const(x, 2), scale))
    return lax.neg(lax.add(scaled_deviation, log_norm))
Beispiel #26
0
def logpmf(k, mu, loc=0):
  """Log-probability mass of the Poisson distribution.

  Args:
    k: evaluation point(s).
    mu: rate parameter.
    loc: shift subtracted from ``k`` before evaluation.

  Returns:
    ``xlogy(x, mu) - gammaln(x + 1) - mu`` with ``x = k - loc``, or ``-inf``
    wherever ``x < 0``.
  """
  k, mu, loc = jnp._promote_args_inexact("poisson.logpmf", k, mu, loc)
  shifted = lax.sub(k, loc)
  log_mass = xlogy(shifted, mu) - gammaln(shifted + 1) - mu
  below_support = lax.lt(shifted, _lax_const(k, 0))
  return jnp.where(below_support, -jnp.inf, log_mass)
Beispiel #27
0
def logpdf(x, loc=0, scale=1):
  """Log-density of the normal distribution.

  Args:
    x: evaluation point(s).
    loc: mean.
    scale: standard deviation.

  Returns:
    ``(log(2 * pi * scale**2) + (x - loc)**2 / scale**2) / -2``.
  """
  x, loc, scale = _promote_args_inexact("norm.logpdf", x, loc, scale)
  variance = lax.square(scale)
  squared_deviation = lax.square(lax.sub(x, loc))
  log_norm = lax.log(lax.mul(_lax_const(x, 2 * np.pi), variance))
  total = lax.add(log_norm, lax.div(squared_deviation, variance))
  return lax.div(total, _lax_const(x, -2))