def _slogdet_qr(a):
  """Compute the sign and log-magnitude of a determinant via QR.

  One reason we might prefer QR decomposition is that it is more amenable to
  a fast batched implementation on TPU because of the lack of row pivoting.

  Args:
    a: real-valued array whose last two axes hold the matrix (or matrices)
      to take the determinant of.

  Returns:
    A pair ``(sign, log_abs_det)``, each reduced over the last two axes of
    ``a``.

  Raises:
    NotImplementedError: if ``a`` has a complex dtype.
  """
  if jnp.issubdtype(lax.dtype(a), jnp.complexfloating):
    raise NotImplementedError("slogdet method='qr' not implemented for complex "
                              "inputs")
  n = a.shape[-1]
  a, taus = lax_linalg.geqrf(a)
  # The determinant of a triangular matrix is the product of its diagonal
  # elements, so the diagonal is all we need from the factored array. Extract
  # it once: taking the log of every entry would be wasted work and would
  # create -inf intermediates for any zero off-diagonal entry.
  diag = jnp.diagonal(a, axis1=-2, axis2=-1)
  # We are working in log space, so we compute the magnitude as the sum of
  # the log-absolute values, and we compute the sign separately.
  log_abs_det = jnp.sum(jnp.log(jnp.abs(diag)), axis=-1)
  sign_diag = jnp.prod(jnp.sign(diag), axis=-1)
  # The determinant of a Householder reflector is -1. So whenever we actually
  # made a reflection (tau != 0), multiply the result by -1.
  sign_taus = jnp.prod(
      jnp.where(taus[..., :(n-1)] != 0, -1, 1), axis=-1
  ).astype(sign_diag.dtype)
  return sign_diag * sign_taus, log_abs_det
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
  """Evaluate ``log(sum(b * exp(a)))`` with a max-shift for stability.

  Args:
    a: input array; promoted to an inexact dtype.
    axis: axis or axes to reduce over, as interpreted by ``_reduction_dims``.
    b: optional scaling factors multiplied into ``exp(a)``. Entries of ``a``
      whose factor is zero are replaced by ``-inf`` before the reduction.
    keepdims: whether the reduced axes are kept in the result.
    return_sign: if true, return ``(out, sign)`` instead of just ``out``.

  Returns:
    ``out``, or the tuple ``(out, sign)`` when ``return_sign`` is true. When
    ``b`` is given, the result is not complex and ``return_sign`` is false,
    entries whose scaled sum is negative are NaN.
  """
  scaled = b is not None
  if scaled:
    a, b = _promote_args_inexact("logsumexp", a, b)
    a = jnp.where(b != 0, a, -jnp.inf)
  else:
    a, = _promote_args_inexact("logsumexp", a)
  pos_dims, dims = _reduction_dims(a, axis)

  # Shift by the maximum before exponentiating; a non-finite maximum is
  # replaced by zero, and the shift is excluded from differentiation.
  peak = jnp.max(a, axis=dims, keepdims=keepdims)
  peak = lax.stop_gradient(
      lax.select(jnp.isfinite(peak), peak, lax.full_like(peak, 0)))
  peak_bcast = peak if keepdims else lax.expand_dims(peak, pos_dims)
  terms = lax.exp(lax.sub(a, peak_bcast))

  if not scaled and not np.issubdtype(a.dtype, np.complexfloating):
    # Fast path: the sum of real exponentials cannot be negative.
    total = jnp.sum(terms, axis=dims, keepdims=keepdims)
    out = lax.add(lax.log(total), peak)
    sign = jnp.where(jnp.isnan(out), out, 1.0)
    sign = jnp.where(jnp.isneginf(out), 0.0, sign).astype(out.dtype)
  else:
    if scaled:
      terms = lax.mul(terms, b)
    total = jnp.sum(terms, axis=dims, keepdims=keepdims)
    sign = lax.stop_gradient(jnp.sign(total))
    if not np.issubdtype(total.dtype, np.complexfloating):
      out = lax.add(lax.log(lax.abs(total)), peak)
    else:
      if return_sign:
        total = sign * total
      out = lax.add(lax.log(total), peak)

  if return_sign:
    return (out, sign)
  if scaled and not np.issubdtype(out.dtype, np.complexfloating):
    # Use jnp.array(nan) to avoid false positives in debug_nans
    # (see https://github.com/google/jax/issues/7634)
    out = jnp.where(sign < 0, jnp.array(np.nan, dtype=out.dtype), out)
  return out
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
  """Compute ``log(sum(b * exp(a)))`` in a numerically stable way.

  Args:
    a: input array; promoted to an inexact dtype.
    axis: axis or axes to reduce over, as interpreted by ``_reduction_dims``.
    b: optional scaling factors for ``exp(a)``. Entries of ``a`` whose factor
      is zero are masked to ``-inf`` before the reduction.
    keepdims: whether the reduced axes are kept in the result.
    return_sign: if true, return ``(out, sign)`` instead of just ``out``.

  Returns:
    ``out``, or the tuple ``(out, sign)`` when ``return_sign`` is true. When
    ``b`` is given, the result is not complex and ``return_sign`` is false,
    entries whose scaled sum is negative are NaN.
  """
  if b is not None:
    a, b = _promote_args_inexact("logsumexp", a, b)
    a = jnp.where(b != 0, a, -jnp.inf)
  else:
    a, = _promote_args_inexact("logsumexp", a)
  pos_dims, dims = _reduction_dims(a, axis)
  # Subtract the maximum before exponentiating; a non-finite maximum is
  # replaced by zero, and the shift is excluded from differentiation.
  amax = jnp.max(a, axis=dims, keepdims=keepdims)
  amax = lax.stop_gradient(
      lax.select(jnp.isfinite(amax), amax, lax.full_like(amax, 0)))
  amax_with_dims = amax if keepdims else lax.expand_dims(amax, pos_dims)
  # fast path if the result cannot be negative.
  if b is None and not np.issubdtype(a.dtype, np.complexfloating):
    out = lax.add(
        lax.log(
            jnp.sum(lax.exp(lax.sub(a, amax_with_dims)),
                    axis=dims, keepdims=keepdims)),
        amax)
    # Propagate NaN from `out` itself rather than from a literal NaN constant,
    # which can trigger false positives in debug_nans
    # (see https://github.com/google/jax/issues/7634).
    sign = jnp.where(jnp.isnan(out), out, 1.0)
    sign = jnp.where(jnp.isneginf(out), 0.0, sign).astype(out.dtype)
  else:
    expsub = lax.exp(lax.sub(a, amax_with_dims))
    if b is not None:
      expsub = lax.mul(expsub, b)
    sumexp = jnp.sum(expsub, axis=dims, keepdims=keepdims)
    sign = lax.stop_gradient(jnp.sign(sumexp))
    if np.issubdtype(sumexp.dtype, np.complexfloating):
      if return_sign:
        sumexp = sign * sumexp
      out = lax.add(lax.log(sumexp), amax)
    else:
      out = lax.add(lax.log(lax.abs(sumexp)), amax)
  if return_sign:
    return (out, sign)
  if b is not None:
    if not np.issubdtype(out.dtype, np.complexfloating):
      # Use jnp.array(nan) to avoid false positives in debug_nans
      # (see https://github.com/google/jax/issues/7634)
      out = jnp.where(sign < 0, jnp.array(np.nan, dtype=out.dtype), out)
  return out