def expn(n, x):
    """Generalized exponential integral E_n(x) = integral_1^inf exp(-x*t) / t**n dt.

    The result is assembled with ``jnp.piecewise`` from closed-form special
    cases and three evaluation helpers:

    * n < 0 or x < 0               -> nan (outside the domain)
    * x == 0 and n < 2             -> inf
    * x == 0 and n >= 2            -> 1 / (n - 1)
    * x > 0 and n == 0             -> exp(-x) / x
    * x > 0 and n >= 5000          -> ``_expn3`` (large-order asymptotic form)
    * x > 1 and 0 < n < 5000       -> ``_expn2`` (continued fraction)
    * anything else (0 < x <= 1 with 0 < n < 5000, or NaN input)
                                   -> ``_expn1`` (power series)

    Args:
      n: order of the integral; promoted together with ``x`` to a common
        inexact dtype.
      x: evaluation point.

    Returns:
      The value of E_n(x).
    """
    n, x = _promote_args_inexact("expn", n, x)
    _c = _constant_like
    zero = _c(x, 0)
    one = _c(x, 1)

    # jnp.piecewise does not give priority to earlier conditions: when several
    # hold, the later one is used (NumPy semantics).  The cases are therefore
    # made mutually exclusive so the precedence documented above holds no
    # matter how they are ordered.
    domain_error = (n < _c(n, 0)) | (x < zero)
    valid = ~domain_error
    at_zero = valid & (x == zero)
    positive = valid & (x > zero)
    n_is_zero = n == _c(n, 0)
    n_is_big = n >= _c(n, 5000)
    conds = [
        domain_error,
        at_zero & (n < _c(n, 2)),
        at_zero & (n >= _c(n, 2)),
        positive & n_is_zero,
        positive & ~n_is_zero & n_is_big,
        valid & (x > one) & ~n_is_zero & ~n_is_big,
    ]

    # E_n(0) = 1 / (n - 1) for n >= 2.  The constant is computed for every
    # element, so n == 1 is replaced by 2 to keep the unselected division
    # finite (no 1/0 that could poison gradients).
    n1 = jnp.where(n == _c(n, 1), n + n, n)
    vals = [
        jnp.nan,
        jnp.inf,
        one / (n1 - one),
        jnp.exp(-x) / x,
        partial(_expn3, n),
        partial(_expn2, n),
        partial(_expn1, n),  # default branch
    ]
    return jnp.piecewise(x, conds, vals)
def _eval_expint_k(A, B, x):
    """Evaluate a rational approximation of the form used for the outer intervals.

    Computes ``exp(x) / x * (1 + P(1/x) / (x * Q(1/x)))``, where ``P`` and ``Q``
    are the polynomials whose coefficients (highest degree first, as
    ``jnp.polyval`` expects) are ``A`` and ``B``.  Both coefficient lists are
    cast to ``x.dtype`` first.
    """
    num_coeffs = jnp.array(A, dtype=x.dtype)
    den_coeffs = jnp.array(B, dtype=x.dtype)
    unit = _constant_like(x, 1.0)
    inv_x = unit / x
    ratio = jnp.polyval(num_coeffs, inv_x) / jnp.polyval(den_coeffs, inv_x)
    correction = inv_x * ratio + unit
    return jnp.exp(x) * inv_x * correction
def _expn3(n, x):
    """Large-order asymptotic expansion of E_n(x) (used for n >= 5000).

    With s = x + n, evaluates

        exp(-x) / s * (1 + n/s**2 + n*(n - 2x)/s**4
                       + n*(6x**2 - 8nx + n**2)/s**6)

    in nested (Horner-like) form, innermost term first.
    """
    _c = _constant_like
    one = _c(x, 1.0)
    shifted = x + n
    inv_sq = one / (shifted * shifted)
    order = n
    acc = inv_sq * order * (_c(x, 6) * x * x - _c(x, 8) * order * x + order * order)
    # Fold in the remaining two correction terms, each scaled by 1/s**2.
    for addend in (order * (order - _c(x, 2) * x), order):
        acc = inv_sq * (acc + addend)
    return (acc + one) * jnp.exp(-x) / shifted
def det(a):
    """Determinant over the last two axes of ``a``.

    Square 2x2 and 3x3 inputs use closed-form helpers.  Larger square inputs
    are computed as ``sign * exp(logdet)`` from ``slogdet``.

    Raises:
      ValueError: if ``a`` has fewer than two dimensions or its trailing two
        dimensions differ.
    """
    a = _promote_arg_dtypes(jnp.asarray(a))
    shape = jnp.shape(a)
    if len(shape) < 2 or shape[-1] != shape[-2]:
        msg = "Argument to _det() must have shape [..., n, n], got {}"
        raise ValueError(msg.format(shape))
    size = shape[-1]
    if size == 2:
        return _det_2x2(a)
    if size == 3:
        return _det_3x3(a)
    sign, logdet = slogdet(a)
    return sign * jnp.exp(logdet)
def _expn2(n, x):
    """Evaluate E_n(x) with a continued fraction (the branch ``expn`` uses for x > 1).

    Successive convergents p_k / q_k of the continued fraction are generated
    inside ``lax.while_loop``.  Iteration stops when the relative change
    between consecutive convergents is no larger than the machine epsilon of
    the working dtype.  The converged value is then multiplied by exp(-x).
    NOTE(review): the structure appears to mirror the Cephes ``expn``
    continued-fraction branch; confirm against that source before changing
    the recurrence.

    Args:
      n: order of the integral (presumably already promoted to x's dtype by
        ``expn`` -- TODO confirm for any other callers).
      x: evaluation point.  When x <= 0 (or NaN), the loop condition is false
        from the start and no iterations run.

    Returns:
      Approximation of E_n(x).
    """
    # x > 1.
    _c = _constant_like
    # Rescaling threshold (144115188075855872 == 2**57).  Once |p_k| exceeds
    # it, all stored numerators/denominators are divided by BIG; their ratios
    # are unchanged, but overflow is avoided.
    BIG = _c(x, 1.44115188075855872e17)
    # Convergence tolerance: machine epsilon of BIG's dtype (presumably x's).
    MACHEP = jnp.finfo(BIG.dtype).eps
    zero = _c(x, 0.0)
    one = _c(x, 1.0)
    init = dict(
        k=_c(n, 1),  # term counter
        pkm2=one,  # p_{k-2}
        qkm2=x,  # q_{k-2}
        pkm1=one,  # p_{k-1}
        qkm1=x + n,  # q_{k-1}
        ans=one / (x + n),  # first convergent p_{k-1} / q_{k-1}
        t=_c(x, jnp.inf),  # relative change; inf so the tolerance test passes on entry
        r=zero,  # most recent convergent p_k / q_k
        x=x,
    )

    def body(d):
        x = d["x"]
        d["k"] += _c(d["k"], 1)
        k = d["k"]
        # The partial terms alternate with the parity of k:
        #   odd k  -> (yk, xk) = (1, n + (k - 1) / 2)
        #   even k -> (yk, xk) = (x, k / 2)
        odd = k % _c(k, 2) == _c(k, 1)
        yk = jnp.where(odd, one, x)
        xk = jnp.where(odd, n + (k - _c(k, 1)) / _c(k, 2), k / _c(k, 2))
        # Three-term recurrence for the next numerator/denominator.
        pk = d["pkm1"] * yk + d["pkm2"] * xk
        qk = d["qkm1"] * yk + d["qkm2"] * xk
        # Update the estimate only when q_k != 0.  Otherwise keep the previous
        # value and set t = 1 so the loop keeps going.
        nz = qk != zero
        d["r"] = r = jnp.where(nz, pk / qk, d["r"])
        d["t"] = jnp.where(nz, abs((d["ans"] - r) / r), one)
        d["ans"] = jnp.where(nz, r, d["ans"])
        # Shift the recurrence window.
        d["pkm2"] = d["pkm1"]
        d["pkm1"] = pk
        d["qkm2"] = d["qkm1"]
        d["qkm1"] = qk
        # Rescale pkm1, pkm2, qkm1, qkm2 together when |p_k| exceeds BIG.
        is_big = abs(pk) > BIG
        for s in "pq":
            for i in "12":
                key = s + "km" + i
                d[key] = jnp.where(is_big, d[key] / BIG, d[key])
        return d

    def cond(d):
        # Continue while x > 0 and not yet converged.  The x > 0 test
        # presumably guards inputs outside this branch's x > 1 domain.
        return (d["x"] > _c(d["k"], 0)) & (d["t"] > MACHEP)

    d = lax.while_loop(cond, body, init)
    return d["ans"] * jnp.exp(-x)
def _expn1(n, x):
    """Evaluate E_n(x) with a power series (the default branch of ``expn``).

    Computes

        E_n(x) = (-x)**(n-1) * psi / exp(gammaln(n))
                 - sum_{k >= 0, k != n-1} (-x)**k / (k! * (k + 1 - n))

    where psi = -euler_gamma - log(x) + sum_{i=1}^{n-1} 1/i (for integer n).
    The series is accumulated in ``lax.while_loop`` until |yk / ans|, the size
    of the latest power term relative to the running sum, is no larger than
    machine epsilon.

    Args:
      n: order of the integral (presumably already promoted to x's dtype by
        ``expn`` -- TODO confirm for any other callers).
      x: evaluation point.  When x <= 0 (or NaN), the loop condition is false
        from the start and no series terms are added.

    Returns:
      Approximation of E_n(x).
    """
    # exponential integral En
    _c = _constant_like
    x = jnp.array(x)
    MACHEP = jnp.finfo(x.dtype).eps
    zero = _c(x, 0.0)
    one = _c(x, 1.0)
    # psi = -gamma - log(x) + sum_{i=1}^{n-1} 1/i.  The loop bounds are in n's
    # dtype.  NOTE(review): for non-integer n this runs while i < n -- confirm
    # that is the intended behavior.
    psi = -jnp.euler_gamma - jnp.log(x)
    psi = lax.fori_loop(_c(n, 1), n, lambda i, psi: psi + one / i, psi)
    # Replace n == 1 by 2 so that the unselected branch of the jnp.where
    # below does not divide by zero.
    n1 = jnp.where(n == _c(n, 1), one + one, n)
    init = dict(
        x=x,
        z=-x,
        xk=zero,  # k, as a float
        yk=one,  # (-x)**k / k!
        pk=one - n,  # k + 1 - n
        # k = 0 term 1 / (1 - n); it is excluded (sum starts at 0) when n == 1.
        ans=jnp.where(n == _c(n, 1), zero, one / (one - n1)),
        t=jnp.inf,  # convergence measure; inf so the tolerance test passes on entry
    )

    def body(d):
        d["xk"] += one
        d["yk"] *= d["z"] / d["xk"]
        d["pk"] += one
        # Skip the k == n - 1 term (pk == 0); it is carried by the psi term instead.
        d["ans"] += jnp.where(d["pk"] != zero, d["yk"] / d["pk"], zero)
        d["t"] = jnp.where(d["ans"] != zero, abs(d["yk"] / d["ans"]), one)
        return d

    def cond(d):
        return (d["x"] > _c(d["x"], 0.0)) & (d["t"] > MACHEP)

    d = lax.while_loop(cond, body, init)
    # t and r are reused here as the gamma-function argument (n) and the
    # power (n - 1).
    t = n
    r = n - _c(n, 1)
    return d["z"]**r * psi / jnp.exp(gammaln(t)) - d["ans"]
def _polygamma(n, x):
    """Polygamma function psi^(n)(x).

    For n == 0 this is ``digamma(x)``.  Otherwise it uses the identity

        psi^(n)(x) = (-1)**(n+1) * n! * zeta(n + 1, x),

    with n! computed as ``exp(gammaln(n + 1))``.
    """
    scalar_type = lax.dtype(n).type
    order = n + scalar_type(1)
    # (-1)**(n+1): +1 when n + 1 is even, -1 when it is odd.
    parity_sign = scalar_type(1) - (order % scalar_type(2)) * scalar_type(2)
    higher_order = parity_sign * jnp.exp(gammaln(order)) * zeta(order, x)
    return jnp.where(n == 0, digamma(x), higher_order)
def pmf(k, p, loc=0):
    """Probability mass function, obtained by exponentiating ``logpmf(k, p, loc)``."""
    log_prob = logpmf(k, p, loc)
    return jnp.exp(log_prob)
def expi_jvp(primals, tangents): (x, ) = primals (x_dot, ) = tangents return expi(x), jnp.exp(x) / x * x_dot