# Example 1
def _slogdet_jvp(primals, tangents):
    """Forward-mode (JVP) rule for ``slogdet``.

    Args:
        primals: one-element tuple ``(x,)`` holding the input matrix (or batch
            of matrices) passed to ``slogdet``.
        tangents: one-element tuple ``(g,)`` holding the tangent of ``x``.

    Returns:
        ``((sign, ans), (sign_dot, ans_dot))``: the primal outputs of
        ``slogdet(x)`` and their tangents.
    """
    x, = primals
    g, = tangents
    sign, ans = slogdet(x)
    # Jacobi's formula: d log det(x) = tr(x^{-1} g), assuming solve(x, g)
    # returns x^{-1} @ g over the trailing two axes.
    ans_dot = jnp.trace(solve(x, g), axis1=-1, axis2=-2)
    if jnp.iscomplexobj(x):
        # For complex x only the real part of tr(x^{-1} g) changes log|det x|;
        # the imaginary part rotates the unit-modulus `sign`.
        ans_dot_real = jnp.real(ans_dot)
        sign_dot = (ans_dot - ans_dot_real) * sign
        ans_dot = ans_dot_real
    else:
        # A real sign is piecewise constant, so its tangent is zero.
        sign_dot = jnp.zeros_like(sign)
    return (sign, ans), (sign_dot, ans_dot)
# Example 2
def _slogdet_qr(a):
  # Implementation of slogdet using QR decomposition. One reason we might prefer
  # QR decomposition is that it is more amenable to a fast batched
  # implementation on TPU because of the lack of row pivoting.
  if jnp.issubdtype(lax.dtype(a), jnp.complexfloating):
    raise NotImplementedError("slogdet method='qr' not implemented for complex "
                              "inputs")
  n = a.shape[-1]
  a, taus = lax_linalg.geqrf(a)
  # The determinant of a triangular matrix is the product of its diagonal
  # elements. We are working in log space, so we compute the magnitude as the
  # the trace of the log-absolute values, and we compute the sign separately.
  log_abs_det = jnp.trace(jnp.log(jnp.abs(a)), axis1=-2, axis2=-1)
  sign_diag = jnp.prod(jnp.sign(jnp.diagonal(a, axis1=-2, axis2=-1)), axis=-1)
  # The determinant of a Householder reflector is -1. So whenever we actually
  # made a reflection (tau != 0), multiply the result by -1.
  sign_taus = jnp.prod(jnp.where(taus[..., :(n-1)] != 0, -1, 1), axis=-1).astype(sign_diag.dtype)
  return sign_diag * sign_taus, log_abs_det
# Example 3
def split_spectrum(H, n, split_point, V0=None):
    """Split the Hermitian matrix `H` into two blocks about `split_point`.

    `H` is split into two matrices, `H_minus` and `H_plus`, respectively
    sharing its eigenspaces beneath and above the value `split_point`.

    Returns, in addition, `V_minus` and `V_plus`, isometries such that
    `Hi = Vi.conj().T @ H @ Vi`. If `V0` is not None, `V0 @ Vi` are
    returned instead; this allows the overall isometries mapping from
    an initial input matrix to progressively smaller blocks to be formed.

    Args:
      H: The Hermitian matrix to split.
      n: Size of the active block of `H`. It is passed to `qdwh.qdwh` as
        `dynamic_shape=(n, n)` and used to mask the identity matrix.
        Presumably the remaining rows/columns of `H` are padding; confirm
        against callers.
      split_point: The eigenvalue (a value, not an index) to split along.
        May be a Python scalar or a scalar array.
      V0: Matrix of isometries to be updated, or None.
    Returns:
      H_minus: A Hermitian matrix sharing the eigenvalues of `H` beneath
        `split_point`.
      V_minus: An isometry from the input space of `V0` to `H_minus`.
      H_plus: A Hermitian matrix sharing the eigenvalues of `H` above
        `split_point`.
      V_plus: An isometry from the input space of `V0` to `H_plus`.
      rank: The dynamic size of the `H_minus` block, computed as
        `round(trace(Re P))` for the projector `P` built below.
    """
    N, _ = H.shape
    # The shift below reads `split_point.dtype`; converting first lets callers
    # pass a plain Python scalar as well as an array.
    split_point = jnp.asarray(split_point)
    H_shift = H - (split_point * jnp.eye(N, dtype=split_point.dtype)).astype(
        H.dtype)
    # NOTE(review): assumes qdwh.qdwh returns the unitary polar factor U of
    # H_shift first; for Hermitian input U = sign(H_shift), so
    # P = (I - U) / 2 projects onto the eigenspace beneath `split_point`.
    U, _, _, _ = qdwh.qdwh(H_shift, is_hermitian=True, dynamic_shape=(n, n))
    P = -0.5 * (U - _mask(jnp.eye(N, dtype=H.dtype), (n, n)))
    rank = jnp.round(jnp.trace(jnp.real(P))).astype(jnp.int32)

    V_minus, V_plus = _projector_subspace(P, H, n, rank)
    H_minus = (V_minus.conj().T @ H) @ V_minus
    H_plus = (V_plus.conj().T @ H) @ V_plus
    if V0 is not None:
        V_minus = jnp.dot(V0, V_minus)
        V_plus = jnp.dot(V0, V_plus)
    return H_minus, V_minus, H_plus, V_plus, rank
# Example 4
def _det_jvp(primals, tangents):
    """Forward-mode (JVP) rule for the determinant.

    Args:
        primals: one-element tuple holding the input matrix `x`.
        tangents: one-element tuple holding its tangent `g`.

    Returns:
        The first output of `_cofactor_solve(x, g)` (the primal result) and
        the trace of its second output over the trailing two axes. Presumably
        that second output is adj(x) @ g, making this Jacobi's formula
        d det(x) = tr(adj(x) dx); confirm against `_cofactor_solve`.
    """
    (matrix,) = primals
    (tangent,) = tangents
    primal_out, product = _cofactor_solve(matrix, tangent)
    tangent_out = jnp.trace(product, axis1=-1, axis2=-2)
    return primal_out, tangent_out