def logpmf(k, p, loc=0):
    """Log of the geometric probability mass function.

    With ``x = k - loc`` the result is ``xlog1py(x - 1, -p) + log(p)``;
    every position where ``x <= 0`` is replaced by ``-inf``.

    Args:
      k: point(s) at which to evaluate the log-pmf.
      p: success probability.
      loc: offset subtracted from ``k`` before evaluation.

    Returns:
      The log-pmf values, ``-inf`` where ``k - loc <= 0``.
    """
    k, p, loc = jnp._promote_args_inexact("geom.logpmf", k, p, loc)
    unit = _lax_const(k, 1)
    origin = _lax_const(k, 0)
    shifted = lax.sub(k, loc)
    # Number of trials before the final one, weighted through xlog1py.
    earlier_trials = lax.sub(shifted, unit)
    log_mass = xlog1py(earlier_trials, -p) + lax.log(p)
    not_positive = lax.le(shifted, origin)
    return jnp.where(not_positive, -jnp.inf, log_mass)
def logpmf(k, p, loc=0):
    """Log of the Bernoulli probability mass function.

    With ``x = k - loc`` the result is ``xlogy(x, p) + xlog1py(1 - x, -p)``;
    positions where ``x < 0`` or ``x > 1`` are replaced by ``-inf``.

    Args:
      k: point(s) at which to evaluate the log-pmf.
      p: success probability.
      loc: offset subtracted from ``k`` before evaluation.

    Returns:
      The log-pmf values, ``-inf`` outside the closed interval [0, 1].
    """
    k, p, loc = jnp._promote_args_inexact("bernoulli.logpmf", k, p, loc)
    lower = _lax_const(k, 0)
    upper = _lax_const(k, 1)
    outcome = lax.sub(k, loc)
    success_part = xlogy(outcome, p)
    failure_part = xlog1py(lax.sub(upper, outcome), -p)
    outside_support = jnp.logical_or(lax.lt(outcome, lower),
                                     lax.gt(outcome, upper))
    return jnp.where(outside_support, -jnp.inf, success_part + failure_part)
def cdf(x, loc=0, scale=1):
    """Cumulative distribution function of the Laplace distribution.

    With ``z = (x - loc) / scale`` the result is ``0.5 * exp(z)`` where
    ``z <= 0`` and ``1 - 0.5 * exp(-z)`` elsewhere.

    Args:
      x: point(s) at which to evaluate the CDF.
      loc: location parameter.
      scale: scale parameter.

    Returns:
      The CDF values selected element-wise from the two branches.
    """
    x, loc, scale = _promote_args_inexact("laplace.cdf", x, loc, scale)
    zero = _lax_const(x, 0)
    half = _lax_const(x, 0.5)
    one = _lax_const(x, 1)
    z = lax.div(lax.sub(x, loc), scale)
    # Both branches are computed; lax.select picks one per element.
    left_branch = lax.mul(half, lax.exp(z))
    right_branch = lax.sub(one, lax.mul(half, lax.exp(lax.neg(z))))
    return lax.select(lax.le(z, zero), left_branch, right_branch)
def logpdf(x, df, loc=0, scale=1):
    """Log of the Student's t probability density function.

    With ``z = (x - loc) / scale`` the result is::

        -( lgamma(df/2) + log(scale**2 * pi * df) / 2 - lgamma(df/2 + 1/2)
           + (df/2 + 1/2) * log1p(z**2 / df) )

    Args:
      x: point(s) at which to evaluate the log-density.
      df: degrees of freedom.
      loc: location parameter.
      scale: scale parameter.

    Returns:
      The log-density values.
    """
    x, df, loc, scale = _promote_args_inexact("t.logpdf", x, df, loc, scale)
    two = _lax_const(x, 2)
    half = _lax_const(x, 0.5)
    pi = _lax_const(x, np.pi)

    z = lax.div(lax.sub(x, loc), scale)
    half_df = lax.div(df, two)
    half_df_plus_half = lax.add(half_df, half)

    # log(scale**2 * pi * df) / 2
    scale_sq_pi = lax.mul(lax.mul(scale, scale), pi)
    half_log_area = lax.div(lax.log(lax.mul(scale_sq_pi, df)), two)

    log_norm = lax.sub(lax.add(lax.lgamma(half_df), half_log_area),
                       lax.lgamma(half_df_plus_half))
    z_sq_over_df = lax.div(lax.mul(z, z), df)
    tail = lax.mul(half_df_plus_half, lax.log1p(z_sq_over_df))
    return lax.neg(lax.add(log_norm, tail))
def multigammaln(a, d):
    """Sum of ``gammaln(a - j / 2)`` for ``j = 0..d-1`` plus a constant.

    The constant added is ``0.25 * d * (d - 1) * log(pi)``.

    Args:
      a: array of arguments; the sum runs over a new trailing axis.
      d: dimension; passed through ``core.concrete_or_error(int, ...)``
        before use.

    Returns:
      An array with the shape of ``a``.
    """
    d = core.concrete_or_error(int, d, "d argument of multigammaln")
    a, d_ = _promote_args_inexact("multigammaln", a, d)

    quarter = _lax_const(a, 0.25)
    one = _lax_const(a, 1)
    two = _lax_const(a, 2)
    log_pi = lax.log(_lax_const(a, np.pi))
    constant = lax.mul(lax.mul(lax.mul(quarter, d_), lax.sub(d_, one)), log_pi)

    # Offsets 0, 1/2, 1, ... laid out along a new trailing axis of `a`.
    offsets = lax.div(jnp.arange(d, dtype=d_.dtype), two)
    leading_axes = tuple(range(a.ndim))
    shifted = (jnp.expand_dims(a, axis=-1)
               - jnp.expand_dims(offsets, axis=leading_axes))
    return jnp.sum(gammaln(shifted), axis=-1) + constant
def logpdf(x, df, loc=0, scale=1):
    """Log of the chi-squared probability density function.

    With ``y = (x - loc) / scale`` the result is::

        -(lgamma(df/2) + log(2) * df / 2) - log(scale)
        + (df/2 - 1) * log(y) - y / 2

    and ``-inf`` wherever ``x < loc``.

    Args:
      x: point(s) at which to evaluate the log-density.
      df: degrees of freedom.
      loc: location parameter.
      scale: scale parameter.

    Returns:
      The log-density values.
    """
    x, df, loc, scale = _promote_args_inexact("chi2.logpdf", x, df, loc, scale)
    one = _lax_const(x, 1)
    two = _lax_const(x, 2)

    y = lax.div(lax.sub(x, loc), scale)
    half_df = lax.div(df, two)

    power_part = lax.mul(lax.sub(half_df, one), lax.log(y))
    kernel = lax.sub(power_part, lax.div(y, two))

    half_df_log_two = lax.div(lax.mul(lax.log(two), df), two)
    neg_log_norm = lax.neg(lax.add(lax.lgamma(half_df), half_df_log_two))

    log_density = lax.add(lax.sub(neg_log_norm, lax.log(scale)), kernel)
    below_support = lax.lt(x, loc)
    return where(below_support, -inf, log_density)
def logpdf(x, loc=0, scale=1):
    """Log of the Cauchy probability density function.

    With ``z = (x - loc) / scale`` the result is
    ``-(log(pi * scale) + log1p(z**2))``.

    Args:
      x: point(s) at which to evaluate the log-density.
      loc: location parameter.
      scale: scale parameter.

    Returns:
      The log-density values.
    """
    x, loc, scale = _promote_args_inexact("cauchy.logpdf", x, loc, scale)
    z = lax.div(lax.sub(x, loc), scale)
    log_norm = lax.log(lax.mul(_lax_const(x, np.pi), scale))
    log_tail = lax.log1p(lax.mul(z, z))
    return lax.neg(lax.add(log_norm, log_tail))
def logpdf(x, b, loc=0, scale=1):
    """Log of the Pareto probability density function.

    With ``z = (x - loc) / scale`` the result is
    ``-(log(scale / b) + (b + 1) * log(z))`` and ``-inf`` wherever
    ``x < loc + scale``.

    Args:
      x: point(s) at which to evaluate the log-density.
      b: shape parameter.
      loc: location parameter.
      scale: scale parameter.

    Returns:
      The log-density values.
    """
    x, b, loc, scale = _promote_args_inexact("pareto.logpdf", x, b, loc, scale)
    unit = _lax_const(x, 1)
    z = lax.div(lax.sub(x, loc), scale)
    log_norm = lax.log(lax.div(scale, b))
    decay = lax.mul(lax.add(b, unit), lax.log(z))
    log_density = lax.neg(lax.add(log_norm, decay))
    support_start = lax.add(loc, scale)
    return where(lax.lt(x, support_start), -inf, log_density)
def _eval_expint_k(A, B, x):
    """Evaluate ``exp(x) * w * (w * P_A(w) / P_B(w) + 1)`` with ``w = 1 / x``.

    Helper function for all subsequent intervals. ``A`` and ``B`` are
    polynomial coefficient sequences handed to ``jnp.polyval`` after being
    converted to arrays with the dtype of ``x``.
    """
    numer_coeffs = jnp.array(A, dtype=x.dtype)
    denom_coeffs = jnp.array(B, dtype=x.dtype)
    one = _lax_const(x, 1.0)
    inv_x = one / x
    rational = jnp.polyval(numer_coeffs, inv_x) / jnp.polyval(denom_coeffs, inv_x)
    correction = inv_x * rational + one
    return jnp.exp(x) * inv_x * correction
def logpdf(x, a, loc=0, scale=1):
    """Log of the gamma probability density function.

    With ``y = (x - loc) / scale`` the result is
    ``xlogy(a - 1, y) - y - (gammaln(a) + log(scale))`` and ``-inf``
    wherever ``x < loc``.

    Args:
      x: point(s) at which to evaluate the log-density.
      a: shape parameter.
      loc: location parameter.
      scale: scale parameter.

    Returns:
      The log-density values.
    """
    x, a, loc, scale = _promote_args_inexact("gamma.logpdf", x, a, loc, scale)
    unit = _lax_const(x, 1)
    y = lax.div(lax.sub(x, loc), scale)
    log_norm = lax.add(gammaln(a), lax.log(scale))
    kernel = lax.sub(xlogy(lax.sub(a, unit), y), y)
    log_density = lax.sub(kernel, log_norm)
    below_support = lax.lt(x, loc)
    return where(below_support, -inf, log_density)
def logpmf(k, n, a, b, loc=0):
    """JAX implementation of scipy.stats.betabinom.logpmf.

    With ``y = floor(k) - loc`` the result is::

        -(log1p(n) + betaln(n - y + 1, y + 1))
        + betaln(y + a, n - y + b) - betaln(a, b)

    Positions where the shifted count ``y`` lies outside ``[0, n]`` give
    ``-inf``; positions where ``n < 1``, ``a < 0`` or ``b < 0`` give ``nan``.

    Args:
      k: point(s) at which to evaluate the log-pmf; floored before use.
      n: number of trials.
      a: first shape parameter.
      b: second shape parameter.
      loc: offset subtracted from ``floor(k)``.

    Returns:
      The log-pmf values.
    """
    k, n, a, b, loc = _promote_args_inexact("betabinom.logpmf", k, n, a, b, loc)
    y = lax.sub(lax.floor(k), loc)
    one = _lax_const(y, 1)
    zero = _lax_const(y, 0)
    combiln = lax.neg(
        lax.add(lax.log1p(n),
                betaln(lax.add(lax.sub(n, y), one), lax.add(y, one))))
    beta_lns = lax.sub(betaln(lax.add(y, a), lax.add(lax.sub(n, y), b)),
                       betaln(a, b))
    log_probs = lax.add(combiln, beta_lns)
    # `y` is already shifted by `loc`, so the support test is on y in [0, n].
    # Comparing against -loc / n - loc would shift twice and reject valid
    # points (e.g. k == loc + n) whenever loc != 0.
    y_cond = logical_or(lax.lt(y, zero), lax.gt(y, n))
    log_probs = where(y_cond, -inf, log_probs)
    n_a_b_cond = logical_or(logical_or(lax.lt(n, one), lax.lt(a, zero)),
                            lax.lt(b, zero))
    return where(n_a_b_cond, nan, log_probs)
def logpmf(k, n, p, loc=0):
    """JAX implementation of scipy.stats.nbinom.logpmf.

    With ``y = k - loc`` the result is::

        gammaln(y + n) - gammaln(n) - gammaln(y + 1)
        + xlogy(n, p) + xlogy(y, 1 - p)

    and ``-inf`` wherever ``k < loc``.

    Args:
      k: point(s) at which to evaluate the log-pmf.
      n: number-of-successes parameter.
      p: success probability.
      loc: offset subtracted from ``k``.

    Returns:
      The log-pmf values.
    """
    k, n, p, loc = _promote_args_inexact("nbinom.logpmf", k, n, p, loc)
    unit = _lax_const(k, 1)
    y = lax.sub(k, loc)

    log_numerator = gammaln(lax.add(y, n))
    log_combinations = lax.sub(lax.sub(log_numerator, gammaln(n)),
                               gammaln(lax.add(y, unit)))

    success_part = xlogy(n, p)
    failure_part = xlogy(y, lax.sub(unit, p))
    log_mass = lax.add(log_combinations, lax.add(success_part, failure_part))

    below_support = lax.lt(k, loc)
    return where(below_support, -inf, log_mass)
def logpdf(x, a, b, loc=0, scale=1):
    """Log of the beta probability density function.

    With ``y = (x - loc) / scale`` the result is::

        -betaln(a, b) + xlogy(a - 1, y) + xlog1py(b - 1, -y) - log(scale)

    and ``-inf`` wherever ``x > loc + scale`` or ``x < loc``.

    Args:
      x: point(s) at which to evaluate the log-density.
      a: first shape parameter.
      b: second shape parameter.
      loc: location parameter.
      scale: scale parameter.

    Returns:
      The log-density values.
    """
    x, a, b, loc, scale = _promote_args_inexact("beta.logpdf", x, a, b, loc, scale)
    unit = _lax_const(x, 1)
    y = lax.div(lax.sub(x, loc), scale)

    neg_log_norm = lax.neg(betaln(a, b))
    left_power = xlogy(lax.sub(a, unit), y)
    right_power = xlog1py(lax.sub(b, unit), lax.neg(y))
    kernel = lax.add(left_power, right_power)
    log_density = lax.sub(lax.add(neg_log_norm, kernel), lax.log(scale))

    above = lax.gt(x, lax.add(loc, scale))
    below = lax.lt(x, loc)
    return where(logical_or(above, below), -inf, log_density)
def logpdf(x, alpha):
    """Log of the Dirichlet probability density function.

    ``alpha`` must be one-dimensional. The leading axis of ``x`` must have
    either as many entries as ``alpha`` or one fewer; in the latter case the
    missing entry is filled in as ``1 - x.sum(0)``.

    Args:
      x: point(s) on the simplex, indexed along axis 0.
      alpha: concentration parameters.

    Returns:
      ``sum(xlogy(alpha - 1, x), axis=0) - (sum(gammaln(alpha)) -
      gammaln(sum(alpha)))`` where ``_is_simplex(x)`` holds, ``-inf``
      elsewhere.

    Raises:
      ValueError: if ``alpha`` is not one-dimensional or the leading size of
        ``x`` does not match ``alpha`` as described above.
    """
    x, alpha = _promote_dtypes_inexact(x, alpha)
    if alpha.ndim != 1:
        raise ValueError(
            f"`alpha` must be one-dimensional; got alpha.shape={alpha.shape}")

    n_params = alpha.shape[0]
    if x.shape[0] not in (n_params, n_params - 1):
        raise ValueError(
            "`x` must have either the same number of entries as `alpha` "
            f"or one entry fewer; got x.shape={x.shape}, alpha.shape={alpha.shape}"
        )

    one = _lax_const(x, 1)
    if x.shape[0] != n_params:
        # One entry short: append the remainder so the entries sum to one.
        remainder = lax.sub(one, x.sum(0, keepdims=True))
        x = jnp.concatenate([x, remainder], axis=0)

    # Computed from the 1-D alpha, before it is reshaped for broadcasting.
    log_norm = jnp.sum(gammaln(alpha)) - gammaln(jnp.sum(alpha))

    if x.ndim > 1:
        trailing_ones = (1, ) * (x.ndim - 1)
        alpha = lax.broadcast_in_dim(alpha, alpha.shape + trailing_ones, (0, ))

    weighted_logs = xlogy(lax.sub(alpha, one), x)
    log_density = lax.sub(jnp.sum(weighted_logs, axis=0), log_norm)
    return jnp.where(_is_simplex(x), log_density, -jnp.inf)
def expit(x):
    """Logistic sigmoid, computed as ``1 / (1 + exp(-x))``.

    The input is first passed through ``_promote_args_inexact``.
    """
    x, = _promote_args_inexact("expit", x)
    unit = _lax_const(x, 1)
    denominator = lax.add(unit, lax.exp(lax.neg(x)))
    return lax.div(unit, denominator)
def _norm_logpdf(x):
    """Return ``-0.5 * x**2 - _norm_logpdf_constant`` in the dtype of ``x``."""
    quadratic = lax.mul(_lax_const(x, -0.5), lax.square(x))
    offset = _lax_const(x, _norm_logpdf_constant)
    return lax.sub(quadratic, offset)
def expn_jvp(n, primals, tangents):
    """JVP rule for ``expn`` with respect to ``x``.

    Args:
      n: order argument, passed through unchanged.
      primals: one-element sequence holding ``x``.
      tangents: one-element sequence holding the tangent of ``x``.

    Returns:
      A pair ``(expn(n, x), -x_dot * expn(n - 1, x))``.
    """
    x, = primals
    x_dot, = tangents
    primal_out = expn(n, x)
    lower_order = lax.sub(n, _lax_const(n, 1))
    tangent_out = lax.mul(lax.neg(x_dot), expn(lower_order, x))
    return primal_out, tangent_out
def entr(x):
    """Element-wise entropy term.

    Returns ``-xlogy(x, x)`` where ``x`` is not negative and ``-inf`` where
    ``x < 0``.
    """
    x, = _promote_args_inexact("entr", x)
    is_negative = lax.lt(x, _lax_const(x, 0))
    neg_infinity = lax.full_like(x, -np.inf)
    entropy = lax.neg(xlogy(x, x))
    return lax.select(is_negative, neg_infinity, entropy)
def expit(x):
    """Logistic sigmoid, computed as ``1 / (1 + exp(-x))``.

    The input is converted with ``asarray`` before use.
    """
    x = asarray(x)
    unit = _lax_const(x, 1)
    decayed = lax.exp(lax.neg(x))
    return lax.div(unit, lax.add(unit, decayed))
def logit(x):
    """Log-odds, computed as ``log(x / (1 - x))``.

    The input is first passed through ``_promote_args_inexact``.
    """
    x, = _promote_args_inexact("logit", x)
    complement = lax.sub(_lax_const(x, 1), x)
    odds = lax.div(x, complement)
    return lax.log(odds)
# Inverse error function: promote the input, then defer to lax.erf_inv.
@_wraps(osp_special.erfinv)
def erfinv(x):
    x, = _promote_args_inexact("erfinv", x)
    return lax.erf_inv(x)


# Log-odds log(x / (1 - x)); registered as a custom-JVP function so the
# derivative rule below is used.
@api.custom_jvp
@_wraps(osp_special.logit, update_doc=False)
def logit(x):
    x = asarray(x)
    return lax.log(lax.div(x, lax.sub(_lax_const(x, 1), x)))
# JVP rule for logit: g / (x * (1 - x)).
logit.defjvps(
    lambda g, ans, x: lax.div(g, lax.mul(x, lax.sub(_lax_const(x, 1), x))))


# Logistic sigmoid 1 / (1 + exp(-x)); registered as a custom-JVP function so
# the derivative rule below is used.
@api.custom_jvp
@_wraps(osp_special.expit, update_doc=False)
def expit(x):
    x = asarray(x)
    one = _lax_const(x, 1)
    return lax.div(one, lax.add(one, lax.exp(lax.neg(x))))
# JVP rule for expit, written in terms of the primal output: g * ans * (1 - ans).
expit.defjvps(lambda g, ans, x: g * ans * (_lax_const(ans, 1) - ans))


# NOTE(review): the body of logsumexp is not visible in this chunk.
@_wraps(osp_special.logsumexp)
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
def logpdf(x):
    """Log of the standard logistic probability density function.

    Computed as ``-2 * logaddexp(x / 2, -x / 2)``.
    """
    x, = _promote_args_inexact("logistic.logpdf", x)
    two = _lax_const(x, 2)
    halved = lax.div(x, two)
    symmetric_sum = jnp.logaddexp(halved, lax.neg(halved))
    return lax.mul(lax.neg(two), symmetric_sum)
def logit(x):
    """Log-odds, computed as ``log(x / (1 - x))``.

    The input is converted with ``asarray`` before use.
    """
    x = asarray(x)
    unit = _lax_const(x, 1)
    odds = lax.div(x, lax.sub(unit, x))
    return lax.log(odds)
def cdf(k, mu, loc=0):
    """Cumulative distribution function of the Poisson distribution.

    With ``x = k - loc`` the result is ``gammaincc(floor(1 + x), mu)`` and
    zero wherever ``x < 0``.

    Args:
      k: point(s) at which to evaluate the CDF.
      mu: rate parameter.
      loc: offset subtracted from ``k``.

    Returns:
      The CDF values.
    """
    # The label names this function so diagnostics from the promotion helper
    # point at the CDF rather than at logpmf.
    k, mu, loc = jnp._promote_args_inexact("poisson.cdf", k, mu, loc)
    zero = _lax_const(k, 0)
    x = lax.sub(k, loc)
    p = gammaincc(jnp.floor(1 + x), mu)
    return jnp.where(lax.lt(x, zero), zero, p)
def logpdf(x, loc=0, scale=1):
    """Log of the Laplace probability density function.

    Computed as ``-(|x - loc| / scale + log(2 * scale))``.

    Args:
      x: point(s) at which to evaluate the log-density.
      loc: location parameter.
      scale: scale parameter.

    Returns:
      The log-density values.
    """
    x, loc, scale = _promote_args_inexact("laplace.logpdf", x, loc, scale)
    abs_deviation = lax.abs(lax.sub(x, loc))
    scaled_distance = lax.div(abs_deviation, scale)
    log_norm = lax.log(lax.mul(_lax_const(x, 2), scale))
    return lax.neg(lax.add(scaled_distance, log_norm))
def logpmf(k, mu, loc=0):
    """Log of the Poisson probability mass function.

    With ``x = k - loc`` the result is ``xlogy(x, mu) - gammaln(x + 1) - mu``
    and ``-inf`` wherever ``x < 0``.

    Args:
      k: point(s) at which to evaluate the log-pmf.
      mu: rate parameter.
      loc: offset subtracted from ``k``.

    Returns:
      The log-pmf values.
    """
    k, mu, loc = jnp._promote_args_inexact("poisson.logpmf", k, mu, loc)
    origin = _lax_const(k, 0)
    count = lax.sub(k, loc)
    log_mass = xlogy(count, mu) - gammaln(count + 1) - mu
    is_negative = lax.lt(count, origin)
    return jnp.where(is_negative, -jnp.inf, log_mass)
def logpdf(x, loc=0, scale=1):
    """Log of the normal probability density function.

    Computed as ``(log(2 * pi * scale**2) + (x - loc)**2 / scale**2) / -2``.

    Args:
      x: point(s) at which to evaluate the log-density.
      loc: mean.
      scale: standard deviation.

    Returns:
      The log-density values.
    """
    x, loc, scale = _promote_args_inexact("norm.logpdf", x, loc, scale)
    variance = lax.square(scale)
    two_pi = _lax_const(x, 2 * np.pi)
    log_norm = lax.log(lax.mul(two_pi, variance))
    squared_deviation = lax.square(lax.sub(x, loc))
    scaled_sq_dev = lax.div(squared_deviation, variance)
    return lax.div(lax.add(log_norm, scaled_sq_dev), _lax_const(x, -2))