def _slogdet_jvp(primals, tangents):
  """JVP rule for ``slogdet``.

  Uses d log det(x) = tr(x^{-1} dx), computed here as the trace of
  ``solve(x, g)`` over the last two axes.

  For real inputs the sign is locally constant, so its tangent is zero. For
  complex inputs the real part of the trace is the tangent of the log-magnitude.
  The imaginary remainder, multiplied by ``sign``, is the tangent of the sign.

  Args:
    primals: One-element sequence holding the input matrix ``x``.
    tangents: One-element sequence holding the tangent ``g`` of ``x``.

  Returns:
    ``((sign, logabsdet), (sign_dot, logabsdet_dot))``.
  """
  (x,) = primals
  (g,) = tangents
  sign, logabsdet = slogdet(x)
  trace_term = jnp.trace(solve(x, g), axis1=-1, axis2=-2)
  if not jnp.issubdtype(jnp._dtype(x), jnp.complexfloating):
    return (sign, logabsdet), (jnp.zeros_like(sign), trace_term)
  # Split the complex trace into its real part (d log|det|) and the purely
  # imaginary remainder, which rotates the unit-modulus sign.
  real_part = np.real(trace_term)
  sign_dot = (trace_term - real_part) * sign
  return (sign, logabsdet), (sign_dot, real_part)
def _slogdet_qr(a): # Implementation of slogdet using QR decomposition. One reason we might prefer # QR decomposition is that it is more amenable to a fast batched # implementation on TPU because of the lack of row pivoting. if jnp.issubdtype(lax.dtype(a), jnp.complexfloating): raise NotImplementedError("slogdet method='qr' not implemented for complex " "inputs") n = a.shape[-1] a, taus = lax_linalg.geqrf(a) # The determinant of a triangular matrix is the product of its diagonal # elements. We are working in log space, so we compute the magnitude as the # the trace of the log-absolute values, and we compute the sign separately. log_abs_det = jnp.trace(jnp.log(jnp.abs(a)), axis1=-2, axis2=-1) sign_diag = jnp.prod(jnp.sign(jnp.diagonal(a, axis1=-2, axis2=-1)), axis=-1) # The determinant of a Householder reflector is -1. So whenever we actually # made a reflection (tau != 0), multiply the result by -1. sign_taus = jnp.prod(jnp.where(taus[..., :(n-1)] != 0, -1, 1), axis=-1).astype(sign_diag.dtype) return sign_diag * sign_taus, log_abs_det
def split_spectrum(H, n, split_point, V0=None):
  """Split Hermitian `H` into blocks below and above the value `split_point`.

  `H` is shifted by `split_point * I`. A spectral projector `P` onto the part
  of the spectrum below `split_point` is then built from the QDWH polar factor
  of the shifted matrix. `_projector_subspace` turns `P` into two isometries,
  `V_minus` and `V_plus`. The returned blocks are `Hi = Vi.conj().T @ H @ Vi`,
  so `H_minus` shares the eigenvalues of `H` beneath `split_point` and
  `H_plus` shares those above it.

  If `V0` is not None, `V0 @ Vi` is returned in place of `Vi`. This lets the
  overall isometries, which map from an initial input matrix to progressively
  smaller blocks, be accumulated across recursive splits. The returned blocks
  `Hi` are always formed from the local `Vi`, before the `V0` update.

  Args:
    H: The 2-D Hermitian matrix to split, of static shape (N, N).
    n: Passed as `dynamic_shape=(n, n)` to `qdwh.qdwh` and used to mask the
      identity. Presumably this is the size of the active top-left n x n block
      of a padded `H`; confirm against `qdwh` and `_mask`.
    split_point: The value (not an index) at which the spectrum is split. It
      must have a `.dtype`; it is used to build the shift
      `split_point * eye(N)`.
    V0: Optional matrix of isometries to be left-multiplied onto the results.

  Returns:
    H_minus: A Hermitian matrix sharing the eigenvalues of `H` beneath
      `split_point`.
    V_minus: An isometry from the input space of `V0` (or of `H` when `V0` is
      None) to `H_minus`.
    H_plus: A Hermitian matrix sharing the eigenvalues of `H` above
      `split_point`.
    V_plus: An isometry from the input space of `V0` (or of `H` when `V0` is
      None) to `H_plus`.
    rank: An int32 scalar, the dynamic size of the `H_minus` block. It is the
      rounded trace of the projector `P`. The static shapes of the returned
      matrices may be larger (presumably padded; confirm against
      `_projector_subspace`).
  """
  N, _ = H.shape
  # Shift the spectrum so the split point sits at zero.
  H_shift = H - (split_point * jnp.eye(N, dtype=split_point.dtype)).astype(
      H.dtype)
  U, _, _, _ = qdwh.qdwh(H_shift, is_hermitian=True, dynamic_shape=(n, n))
  # P = (I_n - U) / 2, where I_n is the identity masked to the (n, n) region.
  # Assuming U acts as the matrix sign of H_shift, P projects onto the
  # eigenspace with eigenvalues below split_point.
  P = -0.5 * (U - _mask(jnp.eye(N, dtype=H.dtype), (n, n)))
  # The trace of a projector equals its rank; round away floating-point noise.
  rank = jnp.round(jnp.trace(jnp.real(P))).astype(jnp.int32)
  V_minus, V_plus = _projector_subspace(P, H, n, rank)
  H_minus = (V_minus.conj().T @ H) @ V_minus
  H_plus = (V_plus.conj().T @ H) @ V_plus
  if V0 is not None:
    # Accumulate the overall isometries for recursive use.
    V_minus = jnp.dot(V0, V_minus)
    V_plus = jnp.dot(V0, V_plus)
  return H_minus, V_minus, H_plus, V_plus, rank
def _det_jvp(primals, tangents):
  """JVP rule for ``det``.

  ``_cofactor_solve(x, g)`` yields a pair. The first element is returned
  unchanged as the primal output. The trace over the last two axes of the
  second element is returned as the tangent output.

  Args:
    primals: One-element sequence holding the input matrix ``x``.
    tangents: One-element sequence holding the tangent ``g`` of ``x``.

  Returns:
    ``(primal_out, tangent_out)``.
  """
  (x,) = primals
  (g,) = tangents
  primal_out, cofactor_term = _cofactor_solve(x, g)
  tangent_out = jnp.trace(cofactor_term, axis1=-1, axis2=-2)
  return primal_out, tangent_out