Example #1
0
def _slogdet_lu(a):
    """Compute ``(sign, log|det|)`` of square matrices from an LU factorization.

    The determinant is the product of the diagonal of ``U`` (``L`` is packed
    into the same matrix with an implicit unit diagonal), with one factor of
    -1 for every row interchange recorded in ``pivot``.

    Args:
      a: array whose trailing two axes hold square matrices.

    Returns:
      A pair ``(sign, logdet)``. ``sign`` has ``a``'s dtype. For complex input
      it is the unit-modulus phase of the determinant; for real input it is
      +1 or -1. ``logdet`` is always real. If ``U`` has an exactly-zero
      diagonal entry, ``sign`` is 0 and ``logdet`` is ``-inf``.
    """
    dt = lax.dtype(a)
    factored, pivots, _ = lax_linalg.lu(a)
    u_diag = jnp.diagonal(factored, axis1=-2, axis2=-1)
    singular = jnp.any(u_diag == jnp.array(0, dtype=dt), axis=-1)

    # Row indices 0..n-1, given leading singleton axes to match the batch
    # dimensions of `pivots`. Every position whose pivot differs from its own
    # index is one row swap.
    batch_axes = range(pivots.ndim - 1)
    identity_pivots = lax.expand_dims(jnp.arange(a.shape[-1]), batch_axes)
    swaps = jnp.count_nonzero(pivots != identity_pivots, axis=-1)

    if not jnp.iscomplexobj(a):
        # Real input: count negative pivots together with the row swaps, so
        # the overall sign is just the parity of the total.
        phase = jnp.array(1, dtype=dt)
        swaps = swaps + jnp.count_nonzero(u_diag < 0, axis=-1)
    else:
        # Complex input: the phase is the product of the normalized pivots.
        # Division by zero here is harmless because `singular` masks it below.
        phase = jnp.prod(u_diag / jnp.abs(u_diag), axis=-1)

    # +1 for an even number of swaps, -1 for an odd number.
    parity_factor = jnp.array(1 - 2 * (swaps % 2), dtype=dt)
    sign = jnp.where(singular, jnp.array(0, dtype=dt), phase * parity_factor)

    log_abs = jnp.sum(jnp.log(jnp.abs(u_diag)), axis=-1)
    logdet = jnp.where(singular, jnp.array(-jnp.inf, dtype=dt), log_abs)
    return sign, jnp.real(logdet)
Example #2
0
def _slogdet_qr(a):
  # Implementation of slogdet using QR decomposition. One reason we might prefer
  # QR decomposition is that it is more amenable to a fast batched
  # implementation on TPU because of the lack of row pivoting.
  if jnp.issubdtype(lax.dtype(a), jnp.complexfloating):
    raise NotImplementedError("slogdet method='qr' not implemented for complex "
                              "inputs")
  n = a.shape[-1]
  a, taus = lax_linalg.geqrf(a)
  # The determinant of a triangular matrix is the product of its diagonal
  # elements. We are working in log space, so we compute the magnitude as the
  # the trace of the log-absolute values, and we compute the sign separately.
  log_abs_det = jnp.trace(jnp.log(jnp.abs(a)), axis1=-2, axis2=-1)
  sign_diag = jnp.prod(jnp.sign(jnp.diagonal(a, axis1=-2, axis2=-1)), axis=-1)
  # The determinant of a Householder reflector is -1. So whenever we actually
  # made a reflection (tau != 0), multiply the result by -1.
  sign_taus = jnp.prod(jnp.where(taus[..., :(n-1)] != 0, -1, 1), axis=-1).astype(sign_diag.dtype)
  return sign_diag * sign_taus, log_abs_det
Example #3
0
def slogdet(a):
    """Return the sign and natural log of the absolute determinant of ``a``.

    The computation uses an LU factorization with row pivoting.

    Args:
      a: array-like whose trailing two axes form square matrices. It is first
        converted with ``jnp.asarray`` and then passed through
        ``_promote_arg_dtypes``, which is defined elsewhere (presumably it
        promotes to an inexact dtype; confirm).

    Returns:
      ``(sign, logdet)``. ``sign`` has the promoted dtype: a unit-modulus
      phase for complex input, +1 or -1 for real input, and 0 when ``U`` has
      an exactly-zero pivot. ``logdet`` is real, and ``-inf`` in the
      zero-pivot case.

    Raises:
      ValueError: if ``a`` has fewer than two dimensions or its trailing two
        dimensions differ.
    """
    a = _promote_arg_dtypes(jnp.asarray(a))
    dt = lax.dtype(a)
    shape = jnp.shape(a)
    is_square_stack = len(shape) >= 2 and shape[-1] == shape[-2]
    if not is_square_stack:
        msg = "Argument to slogdet() must have shape [..., n, n], got {}"
        raise ValueError(msg.format(shape))

    factored, pivots, _ = lax_linalg.lu(a)
    u_diag = jnp.diagonal(factored, axis1=-2, axis2=-1)
    singular = jnp.any(u_diag == jnp.array(0, dtype=dt), axis=-1)
    # arange(n) broadcasts against the trailing axis of `pivots`. Each
    # mismatch is one row interchange.
    swaps = jnp.count_nonzero(pivots != jnp.arange(shape[-1]), axis=-1)

    if not jnp.iscomplexobj(a):
        # Real input: negative pivots flip the sign just like row swaps do.
        phase = jnp.array(1, dtype=dt)
        swaps = swaps + jnp.count_nonzero(u_diag < 0, axis=-1)
    else:
        phase = jnp.prod(u_diag / jnp.abs(u_diag), axis=-1)

    parity_factor = jnp.array(1 - 2 * (swaps % 2), dtype=dt)
    sign = jnp.where(singular, jnp.array(0, dtype=dt), phase * parity_factor)
    log_abs = jnp.sum(jnp.log(jnp.abs(u_diag)), axis=-1)
    logdet = jnp.where(singular, jnp.array(-jnp.inf, dtype=dt), log_abs)
    return sign, jnp.real(logdet)
Example #4
0
def _expint1(x):
    """Rational approximation used on the interval 0 < x <= 2.

    Returns ``x * P(x) / Q(x) + euler_gamma + log(x)``. This matches the form
    of the series for the exponential integral Ei(x); the coefficients appear
    to come from the Cephes ``ei`` routine (confirm).

    Args:
      x: floating-point array. The coefficient arrays are cast to ``x.dtype``.

    Returns:
      Array of the same dtype as ``x``.
    """
    numerator_coeffs = (
        -5.350447357812542947283e0,
        2.185049168816613393830e2,
        -4.176572384826693777058e3,
        5.541176756393557601232e4,
        -3.313381331178144034309e5,
        1.592627163384945414220e6,
    )
    # Monic denominator: the leading 1.0 is the x**6 coefficient.
    denominator_coeffs = (
        1.0,
        -5.250547959112862969197e1,
        1.259616186786790571525e3,
        -1.756549581973534652631e4,
        1.493062117002725991967e5,
        -7.294949239640527645655e5,
        1.592627163384945429726e6,
    )
    num_poly = jnp.array(numerator_coeffs, dtype=x.dtype)
    den_poly = jnp.array(denominator_coeffs, dtype=x.dtype)
    ratio = jnp.polyval(num_poly, x) / jnp.polyval(den_poly, x)
    return x * ratio + jnp.euler_gamma + jnp.log(x)
Example #5
0
def _expn1(n, x):
    """Evaluate the exponential integral E_n(x) by a power series.

    The series is summed term by term inside a ``lax.while_loop`` until the
    convergence measure ``|latest yk / ans|`` drops to machine epsilon. The
    result is then combined as::

        (-x)**(n - 1) * psi / exp(gammaln(n)) - ans

    where ``psi = -euler_gamma - log(x) + sum_{i=1}^{n-1} 1/i``.

    For ``x <= 0`` the loop body never runs, so ``ans`` keeps its initial
    value. Which x range callers actually route here is decided outside this
    block.

    NOTE(review): the structure mirrors the small-x power-series branch of
    Cephes ``expn``; confirm against the caller.

    Args:
      n: integer order, which is also used as a ``fori_loop`` bound. This
        presumably expects n >= 1; confirm with callers.
      x: argument, converted with ``jnp.array``. It must have a floating
        dtype, because ``jnp.finfo(x.dtype)`` is called on it.

    Returns:
      Array with the series estimate of E_n(x).
    """
    # exponential integral En (power-series evaluation)
    # NOTE(review): `_constant_like` is defined elsewhere. From its uses here
    # it appears to build a constant with the dtype of its first argument;
    # confirm.
    _c = _constant_like
    x = jnp.array(x)
    # Convergence threshold for the series: machine epsilon of x's dtype.
    MACHEP = jnp.finfo(x.dtype).eps

    zero = _c(x, 0.0)
    one = _c(x, 1.0)
    # psi = -euler_gamma - log(x) + sum_{i=1}^{n-1} 1/i
    psi = -jnp.euler_gamma - jnp.log(x)
    psi = lax.fori_loop(_c(n, 1), n, lambda i, psi: psi + one / i, psi)
    # jnp.where evaluates both branches. Replacing n == 1 with 2 keeps the
    # unused `one / (one - n1)` branch from dividing by zero.
    n1 = jnp.where(n == _c(n, 1), one + one, n)
    # Loop state, where k is the number of terms added so far:
    #   x   - the argument, carried so `cond` can test it
    #   z   - -x
    #   xk  - k
    #   yk  - z**k / k!
    #   pk  - k + 1 - n
    #   ans - running sum; starts at 0 for n == 1, otherwise 1 / (1 - n)
    #   t   - convergence measure; inf so it cannot stop the loop before the
    #         first iteration
    init = dict(
        x=x,
        z=-x,
        xk=zero,
        yk=one,
        pk=one - n,
        ans=jnp.where(n == _c(n, 1), zero, one / (one - n1)),
        t=jnp.inf,
    )

    def body(d):
        # Advance to term k: yk becomes z**k / k! and pk becomes k + 1 - n.
        d["xk"] += one
        d["yk"] *= d["z"] / d["xk"]
        d["pk"] += one
        # Add yk / pk, skipping the term whose denominator is zero (k == n - 1).
        d["ans"] += jnp.where(d["pk"] != zero, d["yk"] / d["pk"], zero)
        # Convergence measure |yk / ans|, or 1 while ans is exactly zero.
        d["t"] = jnp.where(d["ans"] != zero, abs(d["yk"] / d["ans"]), one)
        return d

    def cond(d):
        # Iterate only while x > 0 and the convergence measure exceeds MACHEP.
        return (d["x"] > _c(d["x"], 0.0)) & (d["t"] > MACHEP)

    d = lax.while_loop(cond, body, init)
    # Closing formula: z**(n-1) * psi / exp(gammaln(n)) - ans.
    # NOTE(review): `gammaln` is defined/imported elsewhere. It is presumably
    # log-gamma, which would make exp(gammaln(n)) equal Gamma(n); confirm.
    t = n
    r = n - _c(n, 1)
    return d["z"]**r * psi / jnp.exp(gammaln(t)) - d["ans"]