Exemplo n.º 1
0
def expn(n, x):
    """Generalized exponential integral E_n(x).

    Special cases are selected with ``jnp.piecewise``; the general case is
    delegated to ``_expn3`` (n >= 5000), ``_expn2`` (x > 1) and, as the
    fallback entry, ``_expn1``.

    Args:
      n: order; promoted together with ``x`` to a common inexact dtype.
      x: argument.

    Returns:
      E_n(x). Inputs with ``n < 0`` or ``x < 0`` are treated as a domain
      error and yield NaN; E_n(0) is ``inf`` for n < 2 and ``1 / (n - 1)``
      for n >= 2.
    """
    n, x = _promote_args_inexact("expn", n, x)
    _c = _constant_like
    zero = _c(x, 0)
    one = _c(x, 1)
    # Domain error -> NaN. Every other condition is ANDed with `valid` so an
    # overlapping case (e.g. x > 1 with n < 0, or n >= 5000 with x < 0) can
    # never replace the NaN, whatever precedence piecewise gives overlaps.
    invalid = (n < _c(n, 0)) | (x < zero)
    valid = ~invalid
    conds = [
        invalid,
        valid & (x == zero) & (n < _c(n, 2)),
        valid & (x == zero) & (n >= _c(n, 2)),
        valid & (n == _c(n, 0)) & (x >= zero),
        valid & (n >= _c(n, 5000)),
        valid & (x > one),
    ]
    # E_n(0) = 1 / (n - 1) for n >= 2. The constant is computed for every
    # lane, so n == 1 is replaced by 2 to keep unselected lanes from
    # dividing by zero.
    n1 = jnp.where(n == _c(n, 1), n + n, n)
    vals = [
        jnp.nan,
        jnp.inf,
        one / (n1 - one),
        jnp.exp(-x) / x,
        partial(_expn3, n),
        partial(_expn2, n),
        partial(_expn1, n),  # fallback: used where no condition holds
    ]
    ret = jnp.piecewise(x, conds, vals)
    return ret
Exemplo n.º 2
0
def _eval_expint_k(A, B, x):
    """Evaluate ``exp(x) * w * (1 + w * P(w) / Q(w))`` with ``w = 1 / x``.

    ``P`` and ``Q`` are the polynomials whose coefficients (highest degree
    first, as ``jnp.polyval`` expects) are ``A`` and ``B``; both are cast to
    ``x.dtype`` before evaluation. Shared helper for the later intervals.
    """
    numer_coeffs = jnp.array(A, dtype=x.dtype)
    denom_coeffs = jnp.array(B, dtype=x.dtype)
    unit = _constant_like(x, 1.0)
    inv_x = unit / x
    ratio = jnp.polyval(numer_coeffs, inv_x) / jnp.polyval(denom_coeffs, inv_x)
    correction = inv_x * ratio + unit
    return jnp.exp(x) * inv_x * correction
Exemplo n.º 3
0
def _expn3(n, x):
    """Large-n approximation of E_n(x) (intended for n >= 5000).

    With ``s = x + n`` this evaluates, in nested form::

        exp(-x) / s * (1 + n / s**2 + n * (n - 2x) / s**4
                         + n * (6x**2 - 8nx + n**2) / s**6)
    """
    unit = _constant_like(x, 1.0)
    shifted = x + n
    inv_sq = unit / (shifted * shifted)
    # Innermost (highest-order) term first, then fold outward.
    acc = inv_sq * n * (_constant_like(x, 6) * x * x
                        - _constant_like(x, 8) * n * x + n * n)
    acc = inv_sq * (acc + n * (n - _constant_like(x, 2) * x))
    acc = inv_sq * (acc + n)
    return (acc + unit) * jnp.exp(-x) / shifted
Exemplo n.º 4
0
def det(a):
    """Determinant of a square matrix or a stack of square matrices.

    The trailing two axes must be square. 2x2 and 3x3 inputs use dedicated
    helpers; every other size is computed as ``sign * exp(logdet)`` from
    ``slogdet``.

    Raises:
      ValueError: if ``a`` has fewer than two axes or its last two axes
        differ in length.
    """
    a = _promote_arg_dtypes(jnp.asarray(a))
    shape = jnp.shape(a)
    if len(shape) < 2 or shape[-1] != shape[-2]:
        msg = "Argument to _det() must have shape [..., n, n], got {}"
        raise ValueError(msg.format(shape))
    size = shape[-1]
    if size == 2:
        return _det_2x2(a)
    if size == 3:
        return _det_3x3(a)
    sign, logdet = slogdet(a)
    return sign * jnp.exp(logdet)
Exemplo n.º 5
0
def _expn2(n, x):
    """Continued-fraction evaluation of E_n(x), used by ``expn`` for x > 1.

    Iterates the three-term recurrences
        p_k = p_{k-1} * y_k + p_{k-2} * x_k,  q_k = q_{k-1} * y_k + q_{k-2} * x_k
    with (y_k, x_k) = (1, n + (k - 1) / 2) for odd k and (x, k / 2) for even k.
    ``ans`` tracks the latest ratio p_k / q_k. Iteration stops once the
    relative change ``t`` drops to machine epsilon, or immediately when
    x <= 0. Returns ``ans * exp(-x)``.

    NOTE(review): several carry entries start as scalars while the body
    writes values shaped like ``x``, so this appears to assume scalar ``x``
    (lax.while_loop needs matching carry shapes) -- confirm with callers.
    """
    # x > 1.
    _c = _constant_like
    # 1.44115188075855872e17 == 2**57: rescaling threshold for p/q below.
    BIG = _c(x, 1.44115188075855872e17)
    # Convergence tolerance on the relative change between successive ratios.
    # NOTE(review): the original author flagged this choice with "?"; eps of
    # x's dtype is used -- confirm it is the intended tolerance.
    MACHEP = jnp.finfo(BIG.dtype).eps
    zero = _c(x, 0.0)
    one = _c(x, 1.0)

    init = dict(
        k=_c(n, 1),
        pkm2=one,
        qkm2=x,
        pkm1=one,
        qkm1=x + n,
        ans=one / (x + n),  # first ratio pkm1 / qkm1
        t=_c(x, jnp.inf),   # forces at least one iteration when x > 0
        r=zero,
        x=x,
    )

    def body(d):
        x = d["x"]
        d["k"] += _c(d["k"], 1)
        k = d["k"]
        odd = k % _c(k, 2) == _c(k, 1)
        # Partial numerator/denominator alternate between odd and even k.
        yk = jnp.where(odd, one, x)
        xk = jnp.where(odd, n + (k - _c(k, 1)) / _c(k, 2), k / _c(k, 2))
        pk = d["pkm1"] * yk + d["pkm2"] * xk
        qk = d["qkm1"] * yk + d["qkm2"] * xk
        # Only accept a new ratio when qk != 0; otherwise keep the previous
        # values and report t = 1 so the loop keeps going.
        nz = qk != zero
        d["r"] = r = jnp.where(nz, pk / qk, d["r"])
        d["t"] = jnp.where(nz, abs((d["ans"] - r) / r), one)
        d["ans"] = jnp.where(nz, r, d["ans"])
        # Shift the recurrence window: (k-2, k-1) <- (k-1, k).
        d["pkm2"] = d["pkm1"]
        d["pkm1"] = pk
        d["qkm2"] = d["qkm1"]
        d["qkm1"] = qk
        # When |pk| exceeds BIG, divide all four of pkm1, pkm2, qkm1, qkm2
        # by BIG. A common factor does not change the ratios p/q, and this
        # keeps the recurrence terms from growing without bound.
        is_big = abs(pk) > BIG
        for s in "pq":
            for i in "12":
                key = s + "km" + i
                d[key] = jnp.where(is_big, d[key] / BIG, d[key])
        return d

    def cond(d):
        # Continue only for x > 0 and while not yet converged.
        # NOTE(review): the x > 0 guard presumably stops lanes outside this
        # branch's domain from iterating indefinitely -- confirm.
        return (d["x"] > _c(d["k"], 0)) & (d["t"] > MACHEP)

    d = lax.while_loop(cond, body, init)
    return d["ans"] * jnp.exp(-x)
Exemplo n.º 6
0
def _expn1(n, x):
    """Power-series evaluation of the exponential integral E_n(x).

    ``expn`` passes this as the last (fallback) entry of its piecewise table.

    Computes ``z**(n - 1) * psi / Gamma(n) - S`` where ``z = -x``,
    ``psi = -euler_gamma - log(x) + sum_{i=1}^{n-1} 1 / i`` and
    ``S = sum_{k >= 0, k != n - 1} z**k / (k! * (k + 1 - n))``.
    The sum is truncated once ``|term / S|`` falls to the machine epsilon of
    ``x``'s dtype. It is not iterated at all when x <= 0.

    NOTE(review): several carry entries start as scalars while the body
    writes values shaped like ``x``, so this appears to assume scalar ``x``
    (lax.while_loop needs matching carry shapes) -- confirm with callers.
    """
    _c = _constant_like
    x = jnp.array(x)
    MACHEP = jnp.finfo(x.dtype).eps

    zero = _c(x, 0.0)
    one = _c(x, 1.0)
    # psi = -euler_gamma - log(x) + sum_{i=1}^{n-1} 1/i
    psi = -jnp.euler_gamma - jnp.log(x)
    psi = lax.fori_loop(_c(n, 1), n, lambda i, psi: psi + one / i, psi)
    # Replace n == 1 by 2 so the unselected branch of the `where` below never
    # divides by zero. For n == 1 the k = 0 term is the skipped one (ans = 0).
    n1 = jnp.where(n == _c(n, 1), one + one, n)
    init = dict(
        x=x,
        z=-x,
        xk=zero,      # k (as a float)
        yk=one,       # z**k / k!
        pk=one - n,   # k + 1 - n
        ans=jnp.where(n == _c(n, 1), zero, one / (one - n1)),
        # Typed like the values `body` writes back (x.dtype) rather than as a
        # bare Python float, keeping the loop carry's dtype consistent.
        t=_c(x, jnp.inf),
    )

    def body(d):
        d["xk"] += one
        d["yk"] *= d["z"] / d["xk"]
        d["pk"] += one
        # Skip the k = n - 1 term, whose denominator (k + 1 - n) is zero.
        d["ans"] += jnp.where(d["pk"] != zero, d["yk"] / d["pk"], zero)
        # Relative size of the latest term; 1 keeps iterating while S == 0.
        d["t"] = jnp.where(d["ans"] != zero, abs(d["yk"] / d["ans"]), one)
        return d

    def cond(d):
        return (d["x"] > _c(d["x"], 0.0)) & (d["t"] > MACHEP)

    d = lax.while_loop(cond, body, init)
    t = n
    r = n - _c(n, 1)
    # Gamma(n) is formed as exp(gammaln(n)).
    return d["z"]**r * psi / jnp.exp(gammaln(t)) - d["ans"]
Exemplo n.º 7
0
def _polygamma(n, x):
    dtype = lax.dtype(n).type
    n_plus = n + dtype(1)
    sign = dtype(1) - (n_plus % dtype(2)) * dtype(2)
    return jnp.where(n == 0, digamma(x),
                     sign * jnp.exp(gammaln(n_plus)) * zeta(n_plus, x))
Exemplo n.º 8
0
def pmf(k, p, loc=0):
  """Probability mass at ``k``, obtained by exponentiating ``logpmf(k, p, loc)``."""
  log_mass = logpmf(k, p, loc)
  return jnp.exp(log_mass)
Exemplo n.º 9
0
def expi_jvp(primals, tangents):
    """Forward-mode derivative rule for ``expi``.

    Returns ``(expi(x), exp(x) / x * x_dot)``. The derivative of expi at
    ``x`` is taken as ``exp(x) / x``. ``primals`` and ``tangents`` must each
    hold exactly one element.
    """
    [x] = primals
    [x_dot] = tangents
    slope = jnp.exp(x) / x
    return expi(x), slope * x_dot