def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
    """Compute ``log(sum(b * exp(a)))`` along ``axis`` in a numerically stable way.

    The maximum of ``a`` over the reduced axes (replaced by 0 when it is not
    finite, and excluded from differentiation) is subtracted before
    exponentiating and added back after the log.

    Args:
      a: input values.
      axis: axis or axes to reduce over; ``None`` reduces over all of them.
      b: optional weights multiplied onto ``exp(a)``. Positions where
        ``b == 0`` are excluded by setting the matching ``a`` to ``-inf``.
      keepdims: keep the reduced axes with size one.
      return_sign: if true, return ``(out, sign)`` instead of ``out``.

    Returns:
      ``out``, or ``(out, sign)`` when ``return_sign`` is true. When ``b`` is
      given and ``return_sign`` is false, entries whose weighted sum is
      negative are reported as NaN.
    """
    if b is None:
        a, = _promote_args_inexact("logsumexp", a)
    else:
        a, b = _promote_args_inexact("logsumexp", a, b)
        a = jnp.where(b != 0, a, -jnp.inf)

    pos_dims, dims = _reduction_dims(a, axis)

    # Stabilizing shift; a non-finite maximum (e.g. all -inf) becomes 0 so the
    # subtraction below does not manufacture NaNs on its own.
    shift = jnp.max(a, axis=dims, keepdims=keepdims)
    shift = lax.stop_gradient(
        lax.select(lax.is_finite(shift), shift, lax.full_like(shift, 0)))
    shift_b = shift if keepdims else lax.expand_dims(shift, pos_dims)
    scaled_exp = lax.exp(lax.sub(a, shift_b))

    if b is not None:
        total = jnp.sum(lax.mul(scaled_exp, b), axis=dims, keepdims=keepdims)
        sign = lax.stop_gradient(lax.sign(total))
        out = lax.add(lax.log(lax.abs(total)), shift)
        if return_sign:
            return (out, sign)
        return jnp.where(sign < 0, np.nan, out)

    total = jnp.sum(scaled_exp, axis=dims, keepdims=keepdims)
    out = lax.add(lax.log(total), shift)
    if not return_sign:
        return out
    # Unweighted sums are non-negative: sign is 1, 0 for -inf, NaN for NaN.
    sign = jnp.where(jnp.isnan(out), np.nan, 1.0).astype(out.dtype)
    sign = jnp.where(out == -np.inf, 0.0, sign)
    return (out, sign)
def zeta(x, q=None):
    """Hurwitz zeta function ``zeta(x, q)``, evaluated elementwise.

    Implements formula (5) of Johansson, Fredrik, "Rigorous high-precision
    computation of the Hurwitz zeta function and its derivatives",
    Numerical Algorithms 69.2 (2015): 253-270,
    https://arxiv.org/abs/1309.2877. The code keeps the paper's notation:
    ``s`` is ``x``, ``a`` is ``q``, and the result is ``S + I + T``.

    Args:
      x: the exponent ``s``.
      q: the shift ``a``. Must be given; ``q=None`` (the Riemann zeta
        function) is not implemented.

    Returns:
      ``S + I + T`` in the promoted inexact dtype of ``x`` and ``q``.

    Raises:
      NotImplementedError: if ``q`` is None.
    """
    # Explicit raise rather than ``assert``: assertions are stripped under
    # ``python -O``, which would let ``q=None`` fall through to promotion.
    if q is None:
        raise NotImplementedError("Riemann zeta function is not implemented yet.")
    s, a = _promote_args_inexact("zeta", x, q)
    dtype = lax.dtype(a).type
    # Copies with a trailing axis along which series terms are laid out.
    s_, a_ = jnp.expand_dims(s, -1), jnp.expand_dims(a, -1)
    # Term counts N (direct sum) and M (correction series): 8 for float32,
    # 16 for every other dtype.
    N = M = dtype(8) if lax.dtype(a) == jnp.float32 else dtype(16)
    # Internal invariant: enough coefficients are tabulated for M terms.
    assert M <= len(_BERNOULLI_COEFS)
    # S = sum_{k=0}^{N-1} (a + k)^(-s)
    k = jnp.expand_dims(np.arange(N, dtype=N.dtype), tuple(range(a.ndim)))
    S = jnp.sum((a_ + k)**-s_, -1)
    # I = (a + N)^(1 - s) / (s - 1)
    I = lax.div((a + N)**(dtype(1) - s), s - dtype(1))
    # T = (a + N)^(-s) * (1/2 + sum_j T1_j)
    T0 = (a + N)**-s
    m = jnp.expand_dims(np.arange(2 * M, dtype=M.dtype), tuple(range(s.ndim)))
    s_over_a = (s_ + m) / (a_ + N)
    # Running products prod_{i<=j} (s + i) / (a + N); keeping entries
    # j = 0, 2, 4, ... yields the M odd-order terms.
    T1 = jnp.cumprod(s_over_a, -1)[..., ::2]
    # Cap at the largest finite value so an overflowing product stays finite.
    T1 = jnp.clip(T1, a_max=jnp.finfo(dtype).max)
    # NOTE(review): _BERNOULLI_COEFS presumably holds the Bernoulli-number
    # factors of formula (5) in reciprocal form, since T1 is divided by it -
    # confirm against its definition.
    coefs = np.expand_dims(
        np.array(_BERNOULLI_COEFS[:T1.shape[-1]], dtype=dtype),
        tuple(range(a.ndim)))
    T1 = T1 / coefs
    T = T0 * (dtype(0.5) + T1.sum(-1))
    return S + I + T
def logpdf(x, loc=0, scale=1):
    """Log-density of the Cauchy distribution.

    With ``z = (x - loc) / scale``, returns ``-(log(pi * scale) + log1p(z**2))``.
    """
    x, loc, scale = _promote_args_inexact("cauchy.logpdf", x, loc, scale)
    z = lax.div(lax.sub(x, loc), scale)
    log_norm = lax.log(lax.mul(_constant_like(x, np.pi), scale))
    log_tail = lax.log1p(lax.mul(z, z))
    return lax.neg(lax.add(log_norm, log_tail))
def logpmf(k, p, loc=0):
    """Log-PMF of the geometric distribution shifted by ``loc``.

    With ``x = k - loc``, returns ``xlog1py(x - 1, -p) + log(p)``, i.e.
    ``(x - 1) * log1p(-p) + log(p)``, and ``-inf`` wherever ``x <= 0``.
    """
    k, p, loc = jnp._promote_args_inexact("geom.logpmf", k, p, loc)
    x = lax.sub(k, loc)
    x_minus_one = lax.sub(x, lax._const(k, 1))
    log_probs = xlog1py(x_minus_one, -p) + lax.log(p)
    not_in_support = lax.le(x, lax._const(k, 0))
    return jnp.where(not_in_support, -jnp.inf, log_probs)
def xlog1py(x, y):
    """Return ``x * log1p(y)`` elementwise, defined as 0 wherever ``x == 0``.

    Where ``x == 0`` both operands are replaced by 1 before the product, so
    the discarded branch never evaluates ``log1p`` at the caller's ``y``.
    """
    x, y = _promote_args_inexact("xlog1py", x, y)
    x_is_zero = x == 0.
    safe_x = jnp.where(x_is_zero, 1., x)
    safe_y = jnp.where(x_is_zero, 1., y)
    product = lax.mul(safe_x, lax.log1p(safe_y))
    return jnp.where(x_is_zero, jnp.zeros_like(x), product)
def expn(n, x):
    """Generalized exponential integral ``E_n(x)``, selected piecewise.

    ``conds[i]`` pairs with ``vals[i]``. The trailing extra entry of ``vals``
    (``_expn1``) is the default used where no condition holds, as in
    ``numpy.piecewise``. Where conditions overlap, later entries presumably
    take precedence, as in ``numpy.piecewise``; confirm against jnp.piecewise.

    Args:
      n: order of the integral.
      x: argument.

    Returns:
      ``E_n(x)``; NaN for ``n < 0`` or ``x < 0``.
    """
    n, x = _promote_args_inexact("expn", n, x)
    _c = _lax_const
    zero = _c(x, 0)
    one = _c(x, 1)
    conds = [
        (n < _c(n, 0)) | (x < zero),
        (x == zero) & (n < _c(n, 2)),
        (x == zero) & (n >= _c(n, 2)),
        (n == _c(n, 0)) & (x >= zero),
        (n >= _c(n, 5000)),
        (x > one),
    ]
    # E_n(0) = 1 / (n - 1) for n >= 2. The denominator is computed for every
    # element, not only where that branch is selected. Mapping n == 1 to 2
    # keeps n - 1 nonzero there.
    n1 = jnp.where(n == _c(n, 1), n + n, n)
    vals = [
        jnp.nan,              # n < 0 or x < 0
        jnp.inf,              # x == 0, n < 2
        one / (n1 - one),     # x == 0, n >= 2: E_n(0) = 1 / (n - 1)
        jnp.exp(-x) / x,      # n == 0: E_0(x) = exp(-x) / x
        partial(_expn3, n),   # n >= 5000
        partial(_expn2, n),   # x > 1
        partial(_expn1, n),   # default
    ]
    return jnp.piecewise(x, conds, vals)
def logpdf(x, loc=0, scale=1):
    """Log-density of the normal distribution.

    Returns ``-(log(2 * pi * scale**2) + (x - loc)**2 / scale**2) / 2``.
    """
    x, loc, scale = _promote_args_inexact("norm.logpdf", x, loc, scale)
    two = _constant_like(x, 2)
    variance = lax.pow(scale, two)
    log_normalizer = lax.log(lax.mul(_constant_like(x, 2 * np.pi), variance))
    squared_dev = lax.pow(lax.sub(x, loc), two)
    quadratic = lax.div(squared_dev, variance)
    total = lax.add(log_normalizer, quadratic)
    return lax.div(lax.neg(total), two)
def logpmf(k, p, loc=0):
    """Log-PMF of the Bernoulli distribution shifted by ``loc``.

    With ``x = k - loc``, returns ``xlogy(x, p) + xlog1py(1 - x, -p)``.
    Outside ``0 <= x <= 1`` it returns ``-inf``.
    """
    k, p, loc = jnp._promote_args_inexact("bernoulli.logpmf", k, p, loc)
    lower = lax._const(k, 0)
    upper = lax._const(k, 1)
    x = lax.sub(k, loc)
    success_term = xlogy(x, p)
    failure_term = xlog1py(lax.sub(upper, x), -p)
    # Kept as an OR of strict comparisons so NaN inputs propagate as NaN.
    out_of_range = jnp.logical_or(lax.lt(x, lower), lax.gt(x, upper))
    return jnp.where(out_of_range, -jnp.inf, success_term + failure_term)
def cdf(x, loc=0, scale=1):
    """CDF of the Laplace distribution.

    With ``z = (x - loc) / scale``, returns ``0.5 * exp(z)`` for ``z <= 0``
    and ``1 - 0.5 * exp(-z)`` otherwise.
    """
    x, loc, scale = _promote_args_inexact("laplace.cdf", x, loc, scale)
    half = _constant_like(x, 0.5)
    one = _constant_like(x, 1)
    zero = _constant_like(x, 0)
    diff = lax.div(lax.sub(x, loc), scale)
    is_left = lax.le(diff, zero)
    # lax.select evaluates both branches. Gradients reach the unselected one
    # multiplied by zero, but exp() of a large |diff| there overflows to inf,
    # and 0 * inf = NaN. Each branch therefore sees 0 wherever it is not
    # selected. Selected values are unchanged.
    zeros = lax.full_like(diff, 0)
    left_diff = lax.select(is_left, diff, zeros)
    right_diff = lax.select(is_left, zeros, diff)
    left = lax.mul(half, lax.exp(left_diff))
    right = lax.sub(one, lax.mul(half, lax.exp(lax.neg(right_diff))))
    return lax.select(is_left, left, right)
def logpdf(x, a, loc=0, scale=1):
    """Log-density of the gamma distribution with shape ``a``.

    With ``y = (x - loc) / scale``, returns
    ``xlogy(a - 1, y) - y - (gammaln(a) + log(scale))``, and ``-inf`` for
    ``x < loc``.
    """
    x, a, loc, scale = _promote_args_inexact("gamma.logpdf", x, a, loc, scale)
    y = lax.div(lax.sub(x, loc), scale)
    shape_minus_one = lax.sub(a, _lax_const(x, 1))
    log_kernel = lax.sub(xlogy(shape_minus_one, y), y)
    log_norm = lax.add(gammaln(a), lax.log(scale))
    log_probs = lax.sub(log_kernel, log_norm)
    return where(lax.lt(x, loc), -inf, log_probs)
def logpdf(x, b, loc=0, scale=1):
    """Log-density of the Pareto distribution with shape ``b``.

    With ``z = (x - loc) / scale``, returns
    ``-(log(scale / b) + (b + 1) * log(z))``, and ``-inf`` for
    ``x < loc + scale``.
    """
    x, b, loc, scale = _promote_args_inexact("pareto.logpdf", x, b, loc, scale)
    z = lax.div(lax.sub(x, loc), scale)
    log_norm = lax.log(lax.div(scale, b))
    exponent = lax.add(b, _constant_like(x, 1))
    log_probs = lax.neg(lax.add(log_norm, lax.mul(exponent, lax.log(z))))
    below_support = lax.lt(x, lax.add(loc, scale))
    return where(below_support, -inf, log_probs)
def logpmf(k, n, p, loc=0):
    """JAX implementation of scipy.stats.nbinom.logpmf.

    With ``y = k - loc``, returns
    ``gammaln(y + n) - gammaln(n) - gammaln(y + 1) + xlogy(n, p) + xlogy(y, 1 - p)``.
    Returns ``-inf`` for ``k < loc``.
    """
    k, n, p, loc = _promote_args_inexact("nbinom.logpmf", k, n, p, loc)
    one = _lax_const(k, 1)
    y = lax.sub(k, loc)
    log_comb = lax.sub(
        lax.sub(gammaln(lax.add(y, n)), gammaln(n)),
        gammaln(lax.add(y, one)))
    log_kernel = lax.add(xlogy(n, p), xlogy(y, lax.sub(one, p)))
    return where(lax.lt(k, loc), -inf, lax.add(log_comb, log_kernel))
def multigammaln(a, d):
    """Log of the multivariate gamma function of dimension ``d``.

    Returns ``d * (d - 1) / 4 * log(pi) + sum_{j=0}^{d-1} gammaln(a - j / 2)``.
    The sum is taken over a new trailing axis.

    Args:
      a: array argument.
      d: dimension; must be a concrete integer because it sets the number of
        summed terms.
    """
    d = core.concrete_or_error(int, d, "d argument of multigammaln")
    # Use a promoted float copy of d for arithmetic, but keep the Python int
    # ``d`` for jnp.arange. Under jit the promoted copy can be a tracer, and
    # arange needs a concrete length.
    a, d_ = _promote_args_inexact("multigammaln", a, d)
    constant = lax.mul(
        lax.mul(lax.mul(_constant_like(a, 0.25), d_),
                lax.sub(d_, _constant_like(a, 1))),
        lax.log(_constant_like(a, np.pi)))
    b = lax.div(jnp.arange(d, dtype=d_.dtype), _constant_like(a, 2))
    # Explicit expand_dims avoids relying on implicit rank promotion.
    res = jnp.sum(
        gammaln(jnp.expand_dims(a, axis=-1) -
                jnp.expand_dims(b, axis=tuple(range(a.ndim)))),
        axis=-1)
    return res + constant
def logpdf(x, a, b, loc=0, scale=1):
    """Log-density of the beta distribution on ``[loc, loc + scale]``.

    With ``y = (x - loc) / scale``, returns
    ``-betaln(a, b) + xlogy(a - 1, y) + xlog1py(b - 1, -y) - log(scale)``.
    Returns ``-inf`` outside ``[loc, loc + scale]``.
    """
    x, a, b, loc, scale = _promote_args_inexact("beta.logpdf", x, a, b, loc, scale)
    one = lax._const(x, 1)
    y = lax.div(lax.sub(x, loc), scale)
    log_kernel = lax.add(xlogy(lax.sub(a, one), y),
                         xlog1py(lax.sub(b, one), lax.neg(y)))
    log_density = lax.add(lax.neg(betaln(a, b)), log_kernel)
    log_probs = lax.sub(log_density, lax.log(scale))
    outside = logical_or(lax.gt(x, lax.add(loc, scale)), lax.lt(x, loc))
    return where(outside, -inf, log_probs)
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
    """Compute ``log(sum(b * exp(a)))`` along ``axis`` in a numerically stable way.

    The maximum of ``a`` over the reduced axes (replaced by 0 when not
    finite, and excluded from differentiation) is subtracted before
    exponentiating and added back after the log.

    Args:
      a: input values; real or complex.
      axis: axis or axes to reduce over; ``None`` reduces over all of them.
      b: optional weights multiplied onto ``exp(a)``. Positions where
        ``b == 0`` are excluded by setting the matching ``a`` to ``-inf``.
      keepdims: keep the reduced axes with size one.
      return_sign: if true, return ``(out, sign)`` instead of ``out``.

    Returns:
      ``out``, or ``(out, sign)`` when ``return_sign`` is true. For real
      results with ``b`` given and ``return_sign`` false, entries whose
      weighted sum is negative are reported as NaN.
    """
    if b is None:
        a, = _promote_args_inexact("logsumexp", a)
    else:
        a, b = _promote_args_inexact("logsumexp", a, b)
        a = jnp.where(b != 0, a, -jnp.inf)
    pos_dims, dims = _reduction_dims(a, axis)

    shift = jnp.max(a, axis=dims, keepdims=keepdims)
    shift = lax.stop_gradient(
        lax.select(jnp.isfinite(shift), shift, lax.full_like(shift, 0)))
    shift_b = shift if keepdims else lax.expand_dims(shift, pos_dims)

    if b is None and not np.issubdtype(a.dtype, np.complexfloating):
        # Fast path: an unweighted real sum of exponentials cannot be negative.
        total = jnp.sum(lax.exp(lax.sub(a, shift_b)), axis=dims, keepdims=keepdims)
        out = lax.add(lax.log(total), shift)
        if not return_sign:
            return out
        sign = jnp.where(jnp.isnan(out), out, 1.0)
        sign = jnp.where(jnp.isneginf(out), 0.0, sign).astype(out.dtype)
        return (out, sign)

    terms = lax.exp(lax.sub(a, shift_b))
    if b is not None:
        terms = lax.mul(terms, b)
    total = jnp.sum(terms, axis=dims, keepdims=keepdims)
    sign = lax.stop_gradient(jnp.sign(total))
    if np.issubdtype(total.dtype, np.complexfloating):
        # NOTE(review): ``sign * total`` equals |total| only when jnp.sign of
        # a complex value is a real +-1 (the pre-NumPy-2 convention). Under the
        # z/|z| convention it is not; confirm for the JAX version in use.
        if return_sign:
            total = sign * total
        out = lax.add(lax.log(total), shift)
    else:
        out = lax.add(lax.log(lax.abs(total)), shift)

    if return_sign:
        return (out, sign)
    if b is not None and not np.issubdtype(out.dtype, np.complexfloating):
        # A negative weighted sum has no real log: report NaN deliberately,
        # without tripping the debug_nans checker.
        with jax.debug_nans(False):
            out = jnp.where(sign < 0, jnp.array(np.nan, dtype=out.dtype), out)
    return out
def logpdf(x, df, loc=0, scale=1):
    """Log-density of Student's t distribution with ``df`` degrees of freedom.

    With ``z = (x - loc) / scale``, returns
    ``-(lgamma(df/2) + log(scale**2 * pi * df) / 2 - lgamma(df/2 + 1/2)
    + (df/2 + 1/2) * log1p(z**2 / df))``.
    """
    x, df, loc, scale = _promote_args_inexact("t.logpdf", x, df, loc, scale)
    two = _lax_const(x, 2)
    z = lax.div(lax.sub(x, loc), scale)
    half_df = lax.div(df, two)
    half_df_plus_half = lax.add(half_df, _lax_const(x, 0.5))
    scaled_var = lax.mul(lax.mul(scale, scale), _lax_const(x, np.pi))
    half_log_var = lax.div(lax.log(lax.mul(scaled_var, df)), two)
    log_norm = lax.sub(lax.add(lax.lgamma(half_df), half_log_var),
                       lax.lgamma(half_df_plus_half))
    quadratic = lax.div(lax.mul(z, z), df)
    tail = lax.mul(half_df_plus_half, lax.log1p(quadratic))
    return lax.neg(lax.add(log_norm, tail))
def multigammaln(a, d):
    """Log of the multivariate gamma function of dimension ``d``.

    Returns ``d * (d - 1) / 4 * log(pi) + sum_{j=0}^{d-1} gammaln(a - j / 2)``.
    The sum is taken over a new trailing axis.

    Args:
      a: array argument.
      d: dimension; must be concrete because it sets the number of terms.
    """
    a, = _promote_args_inexact("multigammaln", a)
    # A float copy of d for arithmetic. jnp.arange below receives the original
    # ``d``: under jit the converted copy can be a tracer, and arange needs
    # a concrete length.
    d_ = lax.convert_element_type(d, lax.dtype(a))
    constant = lax.mul(
        lax.mul(lax.mul(_constant_like(a, 0.25), d_),
                lax.sub(d_, _constant_like(a, 1))),
        lax.log(_constant_like(a, np.pi)))
    b = lax.div(jnp.arange(d, dtype=d_.dtype), _constant_like(a, 2))
    # Explicit expand_dims avoids relying on implicit rank promotion.
    res = jnp.sum(
        gammaln(jnp.expand_dims(a, axis=-1) -
                jnp.expand_dims(b, axis=tuple(range(a.ndim)))),
        axis=-1)
    return res + constant
def multigammaln(a, d):
    """Log of the multivariate gamma function of dimension ``d``.

    Returns ``d * (d - 1) / 4 * log(pi) + sum_{j=0}^{d-1} gammaln(a - j / 2)``.
    The sum is taken over a new trailing axis. ``d`` must be a concrete
    integer, because it sets the number of summed terms.
    """
    d = core.concrete_or_error(int, d, "d argument of multigammaln")
    a, d_float = _promote_args_inexact("multigammaln", a, d)
    quarter = _lax_const(a, 0.25)
    log_pi = lax.log(_lax_const(a, np.pi))
    d_term = lax.mul(lax.mul(quarter, d_float), lax.sub(d_float, _lax_const(a, 1)))
    constant = lax.mul(d_term, log_pi)
    half_steps = lax.div(jnp.arange(d, dtype=d_float.dtype), _lax_const(a, 2))
    half_steps = jnp.expand_dims(half_steps, axis=tuple(range(a.ndim)))
    terms = gammaln(jnp.expand_dims(a, axis=-1) - half_steps)
    return jnp.sum(terms, axis=-1) + constant
def logpdf(x, df, loc=0, scale=1):
    """Log-density of the chi-squared distribution with ``df`` degrees of freedom.

    With ``y = (x - loc) / scale``, returns
    ``(df/2 - 1) * log(y) - y/2 - lgamma(df/2) - df * log(2) / 2 - log(scale)``.
    The ``log(y)`` term is taken as 0 when ``df/2 - 1 == 0``. Returns
    ``-inf`` for ``x < loc``.
    """
    x, df, loc, scale = _promote_args_inexact("chi2.logpdf", x, df, loc, scale)
    zero = _constant_like(x, 0)
    one = _constant_like(x, 1)
    two = _constant_like(x, 2)
    y = lax.div(lax.sub(x, loc), scale)
    df_on_two = lax.div(df, two)
    # xlogy-style product: a plain (df/2 - 1) * log(y) is 0 * -inf = NaN for
    # df == 2 at y == 0, where the density is finite. The inner where keeps
    # log() away from y in that case, so gradients stay finite as well.
    log_coef = lax.sub(df_on_two, one)
    coef_is_zero = lax.eq(log_coef, zero)
    safe_y = where(coef_is_zero, one, y)
    log_term = where(coef_is_zero, zero, lax.mul(log_coef, lax.log(safe_y)))
    kernel = lax.sub(log_term, lax.div(y, two))
    nrml_cnst = lax.neg(lax.add(lax.lgamma(df_on_two),
                                lax.div(lax.mul(lax.log(two), df), two)))
    log_probs = lax.add(lax.sub(nrml_cnst, lax.log(scale)), kernel)
    return where(lax.lt(x, loc), -inf, log_probs)
def logpmf(k, n, a, b, loc=0):
    """JAX implementation of scipy.stats.betabinom.logpmf.

    With ``y = floor(k) - loc``, returns
    ``log C(n, y) + betaln(y + a, n - y + b) - betaln(a, b)``.
    Here ``log C(n, y) = -(log1p(n) + betaln(n - y + 1, y + 1))``.

    Returns:
      ``-inf`` outside the support ``0 <= y <= n`` (i.e.
      ``loc <= floor(k) <= loc + n``), and NaN when ``n < 1``, ``a < 0`` or
      ``b < 0``.
    """
    k, n, a, b, loc = _promote_args_inexact("betabinom.logpmf", k, n, a, b, loc)
    y = lax.sub(lax.floor(k), loc)
    one = _lax_const(y, 1)
    zero = _lax_const(y, 0)
    combiln = lax.neg(
        lax.add(lax.log1p(n),
                betaln(lax.add(lax.sub(n, y), one), lax.add(y, one))))
    beta_lns = lax.sub(betaln(lax.add(y, a), lax.add(lax.sub(n, y), b)),
                       betaln(a, b))
    log_probs = lax.add(combiln, beta_lns)
    # y is already shifted by loc, so the support test must not subtract loc
    # a second time.
    y_cond = logical_or(lax.lt(y, zero), lax.gt(y, n))
    log_probs = where(y_cond, -inf, log_probs)
    # NOTE(review): scipy's parameter check is believed to be n >= 0, a > 0,
    # b > 0; this keeps the existing n < 1 / a < 0 / b < 0 rule - confirm.
    n_a_b_cond = logical_or(logical_or(lax.lt(n, one), lax.lt(a, zero)),
                            lax.lt(b, zero))
    return where(n_a_b_cond, nan, log_probs)
def i1(x):
    """Modified Bessel function of the first kind of order 1.

    Computed as ``exp(|x|) * bessel_i1e(x)``, which multiplies the
    exponential scaling of ``lax.bessel_i1e`` back in.
    """
    promoted, = _promote_args_inexact("i1", x)
    growth = lax.exp(lax.abs(promoted))
    return lax.mul(growth, lax.bessel_i1e(promoted))
def i1e(x):
    """Exponentially scaled modified Bessel function of order 1 (``lax.bessel_i1e``)."""
    promoted, = _promote_args_inexact("i1e", x)
    return lax.bessel_i1e(promoted)
def gammaincc(a, x):
    """Regularized upper incomplete gamma function, computed by ``lax.igammac``."""
    a_arr, x_arr = _promote_args_inexact("gammaincc", a, x)
    return lax.igammac(a_arr, x_arr)
def digamma(x):
    """Digamma function, computed elementwise by ``lax.digamma``."""
    promoted, = _promote_args_inexact("digamma", x)
    return lax.digamma(promoted)
def betainc(a, b, x):
    """Regularized incomplete beta function, computed by ``lax.betainc``."""
    a_arr, b_arr, x_arr = _promote_args_inexact("betainc", a, b, x)
    return lax.betainc(a_arr, b_arr, x_arr)
def betaln(x, y):
    """Log of the beta function: ``lgamma(x) + lgamma(y) - lgamma(x + y)``."""
    x, y = _promote_args_inexact("betaln", x, y)
    log_numerator = lax.lgamma(x) + lax.lgamma(y)
    return log_numerator - lax.lgamma(x + y)
def polygamma(n, x):
    """Polygamma function of order ``n``, evaluated elementwise.

    ``n`` and ``x`` are promoted to a common inexact dtype and broadcast to a
    common shape before calling ``_polygamma``.

    Raises:
      TypeError: if ``n`` does not have an integer dtype.
    """
    # Explicit check instead of ``assert`` so the validation still runs under
    # ``python -O``.
    if not jnp.issubdtype(lax.dtype(n), jnp.integer):
        raise TypeError(
            f"polygamma: n must have an integer dtype, got {lax.dtype(n)}")
    n, x = _promote_args_inexact("polygamma", n, x)
    shape = lax.broadcast_shapes(n.shape, x.shape)
    return _polygamma(jnp.broadcast_to(n, shape), jnp.broadcast_to(x, shape))
def erfc(x):
    """Complementary error function, computed elementwise by ``lax.erfc``."""
    promoted, = _promote_args_inexact("erfc", x)
    return lax.erfc(promoted)
def erfinv(x):
    """Inverse error function, computed elementwise by ``lax.erf_inv``."""
    promoted, = _promote_args_inexact("erfinv", x)
    return lax.erf_inv(promoted)
def gammaln(x):
    """Log-gamma function, computed elementwise by ``lax.lgamma``."""
    promoted, = _promote_args_inexact("gammaln", x)
    return lax.lgamma(promoted)