def _fft_norm(s, func_name, norm): if norm is None or norm == "backward": return 1 elif norm == "ortho": return jnp.sqrt(jnp.prod(s)) if func_name.startswith('i') else 1/jnp.sqrt(jnp.prod(s)) elif norm == "forward": return jnp.prod(s) if func_name.startswith('i') else 1/jnp.prod(s) raise ValueError(f'Invalid norm value {norm}; should be "backward",' '"ortho" or "forward".')
def _slogdet_qr(a): # Implementation of slogdet using QR decomposition. One reason we might prefer # QR decomposition is that it is more amenable to a fast batched # implementation on TPU because of the lack of row pivoting. if jnp.issubdtype(lax.dtype(a), jnp.complexfloating): raise NotImplementedError("slogdet method='qr' not implemented for complex " "inputs") n = a.shape[-1] a, taus = lax_linalg.geqrf(a) # The determinant of a triangular matrix is the product of its diagonal # elements. We are working in log space, so we compute the magnitude as the # the trace of the log-absolute values, and we compute the sign separately. log_abs_det = jnp.trace(jnp.log(jnp.abs(a)), axis1=-2, axis2=-1) sign_diag = jnp.prod(jnp.sign(jnp.diagonal(a, axis1=-2, axis2=-1)), axis=-1) # The determinant of a Householder reflector is -1. So whenever we actually # made a reflection (tau != 0), multiply the result by -1. sign_taus = jnp.prod(jnp.where(taus[..., :(n-1)] != 0, -1, 1), axis=-1).astype(sign_diag.dtype) return sign_diag * sign_taus, log_abs_det
def _slogdet_lu(a): dtype = lax.dtype(a) lu, pivot, _ = lax_linalg.lu(a) diag = jnp.diagonal(lu, axis1=-2, axis2=-1) is_zero = jnp.any(diag == jnp.array(0, dtype=dtype), axis=-1) iota = lax.expand_dims(jnp.arange(a.shape[-1]), range(pivot.ndim - 1)) parity = jnp.count_nonzero(pivot != iota, axis=-1) if jnp.iscomplexobj(a): sign = jnp.prod(diag / jnp.abs(diag), axis=-1) else: sign = jnp.array(1, dtype=dtype) parity = parity + jnp.count_nonzero(diag < 0, axis=-1) sign = jnp.where(is_zero, jnp.array(0, dtype=dtype), sign * jnp.array(-2 * (parity % 2) + 1, dtype=dtype)) logdet = jnp.where(is_zero, jnp.array(-jnp.inf, dtype=dtype), jnp.sum(jnp.log(jnp.abs(diag)), axis=-1)) return sign, jnp.real(logdet)
def slogdet(a):
    """Compute the sign and log-magnitude of the determinant of ``a``.

    Args:
      a: array-like of shape ``[..., n, n]``. It is converted with
        ``jnp.asarray`` and passed through ``_promote_arg_dtypes``
        (presumably promoting to an inexact dtype -- confirm against that
        helper).

    Returns:
      The ``(sign, logdet)`` pair produced by ``_slogdet_lu``.

    Raises:
      ValueError: if ``a`` has fewer than two dimensions or its trailing two
        dimensions differ.
    """
    a = _promote_arg_dtypes(jnp.asarray(a))
    a_shape = jnp.shape(a)
    if len(a_shape) < 2 or a_shape[-1] != a_shape[-2]:
        msg = "Argument to slogdet() must have shape [..., n, n], got {}"
        raise ValueError(msg.format(a_shape))
    # The LU-based computation lives in _slogdet_lu; delegating keeps a
    # single implementation instead of two copies that could drift apart.
    return _slogdet_lu(a)