def _slogdet_lu(a):
  """Compute the sign and log-magnitude of a determinant via LU factorization.

  All dtype conversions are explicit so the function also works when implicit
  dtype promotion is disabled (``jax_numpy_dtype_promotion='strict'``).

  Args:
    a: array whose last two dimensions form the matrices to factor. The shape
      is not validated here.

  Returns:
    A ``(sign, logabsdet)`` pair reduced over the last two axes of ``a``.
    ``sign`` has the dtype of ``a``; ``logabsdet`` is the real part of the
    summed log-magnitudes. If any diagonal entry of the LU factor is exactly
    zero, ``sign`` is 0 and ``logabsdet`` is ``-inf``.
  """
  dtype = lax.dtype(a)
  lu, pivot, _ = lax_linalg.lu(a)
  diag = jnp.diagonal(lu, axis1=-2, axis2=-1)
  is_zero = jnp.any(diag == jnp.array(0, dtype=dtype), axis=-1)

  # Every pivot entry that differs from its own position is counted as one
  # sign flip. Building the index vector in the pivot dtype keeps the
  # comparison free of implicit integer promotion.
  iota = lax.expand_dims(jnp.arange(a.shape[-1], dtype=pivot.dtype),
                         range(pivot.ndim - 1))
  parity = jnp.count_nonzero(pivot != iota, axis=-1)

  if jnp.iscomplexobj(a):
    # Product of the unit-modulus phases of the diagonal. abs() of a complex
    # array is real, so cast it back before dividing.
    sign = jnp.prod(diag / jnp.abs(diag).astype(diag.dtype), axis=-1)
  else:
    sign = jnp.array(1, dtype=dtype)
    # Each negative diagonal entry contributes one more sign flip.
    parity = parity + jnp.count_nonzero(diag < 0, axis=-1)

  sign = jnp.where(is_zero,
                   jnp.array(0, dtype=dtype),
                   sign * jnp.array(-2 * (parity % 2) + 1, dtype=dtype))
  # Cast the (real) log-magnitudes to `dtype` so both branches of the where()
  # agree; the imaginary part is dropped again on return.
  logdet = jnp.where(is_zero,
                     jnp.array(-jnp.inf, dtype=dtype),
                     jnp.sum(jnp.log(jnp.abs(diag)).astype(dtype), axis=-1))
  return sign, jnp.real(logdet)
def _slogdet_qr(a):
  """Compute the sign and log-magnitude of a determinant via QR (geqrf).

  One reason we might prefer QR decomposition is that it is more amenable to
  a fast batched implementation on TPU because of the lack of row pivoting.

  Args:
    a: real-valued array whose last two dimensions form the matrices to
      factor. The shape is not validated here.

  Returns:
    A ``(sign, logabsdet)`` pair reduced over the last two axes of ``a``.

  Raises:
    NotImplementedError: if ``a`` has a complex floating dtype.
  """
  if jnp.issubdtype(lax.dtype(a), jnp.complexfloating):
    raise NotImplementedError("slogdet method='qr' not implemented for complex "
                              "inputs")
  n = a.shape[-1]
  a, taus = lax_linalg.geqrf(a)
  # The determinant of a triangular matrix is the product of its diagonal
  # elements. We are working in log space, so we compute the magnitude as the
  # sum of the log-absolute values of the diagonal, and we compute the sign
  # separately. Only the diagonal is extracted, so no logarithms are taken of
  # the off-diagonal entries of the geqrf output.
  a_diag = jnp.diagonal(a, axis1=-2, axis2=-1)
  log_abs_det = jnp.sum(jnp.log(jnp.abs(a_diag)), axis=-1)
  sign_diag = jnp.prod(jnp.sign(a_diag), axis=-1)
  # The determinant of a Householder reflector is -1. So whenever we actually
  # made a reflection (tau != 0), multiply the result by -1.
  sign_taus = jnp.prod(jnp.where(taus[..., :(n - 1)] != 0, -1, 1),
                       axis=-1).astype(sign_diag.dtype)
  return sign_diag * sign_taus, log_abs_det
def slogdet(a):
  """Compute the sign and log-magnitude of the determinant of ``a``.

  The input is converted with ``jnp.asarray`` and passed through
  ``_promote_arg_dtypes`` before an LU factorization is taken. All dtype
  conversions after that point are explicit, so the computation does not
  depend on implicit dtype promotion
  (``jax_numpy_dtype_promotion='strict'``).

  Args:
    a: array-like of shape ``[..., n, n]``.

  Returns:
    A ``(sign, logabsdet)`` pair reduced over the last two axes of ``a``.
    ``sign`` has the promoted dtype of ``a``; ``logabsdet`` is the real part
    of the summed log-magnitudes. If any diagonal entry of the LU factor is
    exactly zero, ``sign`` is 0 and ``logabsdet`` is ``-inf``.

  Raises:
    ValueError: if ``a`` has fewer than two dimensions or its last two
      dimensions differ.
  """
  a = _promote_arg_dtypes(jnp.asarray(a))
  dtype = lax.dtype(a)
  a_shape = jnp.shape(a)
  if len(a_shape) < 2 or a_shape[-1] != a_shape[-2]:
    msg = "Argument to slogdet() must have shape [..., n, n], got {}"
    raise ValueError(msg.format(a_shape))

  lu, pivot, _ = lax_linalg.lu(a)
  diag = jnp.diagonal(lu, axis1=-2, axis2=-1)
  is_zero = jnp.any(diag == jnp.array(0, dtype=dtype), axis=-1)

  # Every pivot entry that differs from its own position is counted as one
  # sign flip. The index vector is built in the pivot dtype so the comparison
  # involves no implicit integer promotion.
  row_index = jnp.arange(a_shape[-1], dtype=pivot.dtype)
  parity = jnp.count_nonzero(pivot != row_index, axis=-1)

  if jnp.iscomplexobj(a):
    # Product of the unit-modulus phases of the diagonal. abs() of a complex
    # array is real, so cast it back before dividing.
    sign = jnp.prod(diag / jnp.abs(diag).astype(diag.dtype), axis=-1)
  else:
    sign = jnp.array(1, dtype=dtype)
    # Each negative diagonal entry contributes one more sign flip.
    parity = parity + jnp.count_nonzero(diag < 0, axis=-1)

  sign = jnp.where(is_zero,
                   jnp.array(0, dtype=dtype),
                   sign * jnp.array(-2 * (parity % 2) + 1, dtype=dtype))
  # Cast the (real) log-magnitudes to `dtype` so both branches of the where()
  # agree; the imaginary part is dropped again on return.
  logdet = jnp.where(is_zero,
                     jnp.array(-jnp.inf, dtype=dtype),
                     jnp.sum(jnp.log(jnp.abs(diag)).astype(dtype), axis=-1))
  return sign, jnp.real(logdet)
def _expint1(x):
  """Evaluate ``x * P(x) / Q(x) + euler_gamma + log(x)``.

  ``P`` and ``Q`` are fixed polynomials whose coefficients are listed from
  the highest power down (the ordering ``jnp.polyval`` expects). The original
  comment marks this branch as intended for ``0 < x <= 2``; the range is not
  checked here.

  Args:
    x: array; its ``dtype`` is used for the coefficient arrays.

  Returns:
    Array of the same dtype as ``x`` with the value of the expression above.
  """
  numerator_coeffs = (
      -5.350447357812542947283e0,
      2.185049168816613393830e2,
      -4.176572384826693777058e3,
      5.541176756393557601232e4,
      -3.313381331178144034309e5,
      1.592627163384945414220e6,
  )
  denominator_coeffs = (
      1.0,
      -5.250547959112862969197e1,
      1.259616186786790571525e3,
      -1.756549581973534652631e4,
      1.493062117002725991967e5,
      -7.294949239640527645655e5,
      1.592627163384945429726e6,
  )
  numerator = jnp.polyval(jnp.array(numerator_coeffs, dtype=x.dtype), x)
  denominator = jnp.polyval(jnp.array(denominator_coeffs, dtype=x.dtype), x)
  ratio = numerator / denominator
  return x * ratio + jnp.euler_gamma + jnp.log(x)
def _expn1(n, x):
  """Exponential integral En (per the original comment), via a series loop.

  The result is ``(-x)**(n-1) * psi / exp(gammaln(n)) - ans`` where
  ``psi = -euler_gamma - log(x) + sum_{i=1}^{n-1} 1/i`` and ``ans`` is a
  series accumulated term by term until the ratio of the latest term to the
  running sum drops to machine epsilon for ``x``'s dtype.

  Args:
    n: order; compared against 1 and used as the loop bound and in the final
      power/gammaln terms.
    x: argument; converted with ``jnp.array``. The series loop only runs
      while ``x > 0``.

  Returns:
    The value of the expression above.
  """
  _c = _constant_like
  x = jnp.array(x)
  eps = jnp.finfo(x.dtype).eps
  zero = _c(x, 0.0)
  one = _c(x, 1.0)
  n_is_one = n == _c(n, 1)

  def add_reciprocal(i, acc):
    return acc + one / i

  # psi starts at -euler_gamma - log(x) and gains 1/i for i in [1, n).
  psi = lax.fori_loop(_c(n, 1), n, add_reciprocal,
                      -jnp.euler_gamma - jnp.log(x))

  # Substitute 2 for n where n == 1 so the division below never sees 1 - 1;
  # those entries are replaced by zero through the where() anyway.
  n1 = jnp.where(n_is_one, one + one, n)
  state0 = {
      "x": x,
      "z": -x,
      "xk": zero,
      "yk": one,
      "pk": one - n,
      "ans": jnp.where(n_is_one, zero, one / (one - n1)),
      "t": jnp.inf,
  }

  def keep_going(s):
    return (s["x"] > _c(s["x"], 0.0)) & (s["t"] > eps)

  def step(s):
    xk = s["xk"] + one
    yk = s["yk"] * (s["z"] / xk)
    pk = s["pk"] + one
    # The term with pk == 0 is skipped rather than divided by zero.
    ans = s["ans"] + jnp.where(pk != zero, yk / pk, zero)
    # Ratio of the latest term to the running sum; forced to 1 while the sum
    # is still zero so the loop keeps iterating.
    t = jnp.where(ans != zero, abs(yk / ans), one)
    return {"x": s["x"], "z": s["z"], "xk": xk, "yk": yk, "pk": pk,
            "ans": ans, "t": t}

  final = lax.while_loop(keep_going, step, state0)
  exponent = n - _c(n, 1)
  return final["z"] ** exponent * psi / jnp.exp(gammaln(n)) - final["ans"]