Example #1
0
def update_posterior(gram, XT_y, prior):
    """Return the posterior mean and covariance of a Gaussian linear model.

    The names suggest a Bayesian linear-regression update (TODO confirm). The
    posterior precision is ``diag(alpha) + beta * gram``. Its inverse ``sigma``
    is built from a Cholesky factor, and the mean is ``beta * sigma . XT_y``.

    Args:
        gram: square matrix; scaled by ``beta`` and added to the prior
            precision.
        XT_y: right-hand side multiplied (via ``jnp.dot``) by ``sigma``.
        prior: pair ``(alpha, beta)``. ``alpha`` is presumably 1-D, since
            ``jnp.diag`` builds the diagonal prior-precision matrix from it.
            ``beta`` scales ``gram`` and the mean.

    Returns:
        Tuple ``(mean, sigma)``.
    """
    alpha, beta = prior
    precision = jnp.diag(alpha) + beta * gram
    # A = C @ C.T  =>  inv(A) = inv(C).T @ inv(C). solve_triangular is told
    # the factor is lower-triangular, so ``cholesky`` must return the lower one.
    chol = cholesky(precision)
    identity = jnp.eye(alpha.shape[0])
    chol_inv = solve_triangular(chol, identity, check_finite=False, lower=True)
    sigma = jnp.matmul(chol_inv.T, chol_inv)
    mean = beta * jnp.dot(sigma, XT_y)
    return mean, sigma
Example #2
0
 def chol_gram_blocks(dc_du, dc_dv):
     """Calculate Cholesky factors of a block decomposition of the Gram matrix.

     For each block ``i`` in 0..2 this forms
     ``D[i] = einsum('...ij,...kj', dc_dv[i] / m_v[i][..., None, :], dc_dv[i])``.
     The division by the per-block slice ``m_v[i]`` of ``metric_2_diag`` is
     skipped when ``metric`` is an ``IdentityMatrix``. Each ``D[i]`` is then
     Cholesky-factorised. The ``dc_du`` blocks are folded in to factorise
     ``metric_1 + sum_i dc_du[i]^T D[i]^{-1} dc_du[i]``. For block 1 the sum
     also runs over its leading axis, via ``einsum('ijk,ijl->kl', ...)``.

     ``metric``, ``IdentityMatrix``, ``metric_1``, ``metric_2_diag``, ``split``,
     ``nla`` and ``sla`` are not parameters. Presumably they come from an
     enclosing scope — TODO confirm.

     Args:
         dc_du: sequence of three blocks. Blocks 0 and 2 are used as matrices
             (``.T @``); block 1 is contracted over its first two axes.
         dc_dv: sequence of three blocks. Block 1 is indexed with three axes
             when splitting the metric diagonal.

     Returns:
         ``(chol_C, chol_D)``. ``chol_C`` is the factor of the combined matrix
         above; ``chol_D`` is a tuple of the three (possibly batched) factors
         of ``D[i]``. The factors are treated as lower-triangular
         (``cho_solve`` is given ``lower=True``).
     """
     blocks = range(3)
     if isinstance(metric, IdentityMatrix):
         # Unit metric: the v-derivatives enter D[i] without rescaling.
         lhs = [dc_dv[i] for i in blocks]
     else:
         # Carve the v part of the metric diagonal into per-block pieces; the
         # middle piece is reshaped to the outer/inner axes of dc_dv[1].
         sizes = (dc_dv[0].shape[1], dc_dv[1].shape[0] * dc_dv[1].shape[2])
         metric_v = split(metric_2_diag, sizes)
         metric_v[1] = metric_v[1].reshape(
             (dc_dv[1].shape[0], dc_dv[1].shape[2]))
         lhs = [dc_dv[i] / metric_v[i][..., None, :] for i in blocks]
     gram_v = tuple(np.einsum('...ij,...kj', lhs[i], dc_dv[i]) for i in blocks)
     chol_D = tuple(nla.cholesky(block) for block in gram_v)
     solved = [sla.cho_solve((chol_D[i], True), dc_du[i]) for i in blocks]
     du_sum = (
         dc_du[0].T @ solved[0]
         + np.einsum('ijk,ijl->kl', dc_du[1], solved[1])
         + dc_du[2].T @ solved[2])
     chol_C = nla.cholesky(metric_1 + du_sum)
     return chol_C, chol_D
Example #3
0
def _multivariate_normal(key, mean, cov, shape, dtype):
  """Sample from N(mean, cov) by transforming standard normals with chol(cov).

  Args:
    key: PRNG key forwarded to ``normal``.
    mean: array with ``ndim >= 1``. The last axis has length ``n``; any
      leading axes are batch dimensions.
    cov: array with ``ndim >= 2`` and trailing shape ``(n, n)``. Any leading
      axes are batch dimensions that broadcast against those of ``mean``.
    shape: batch shape of the result (any sequence of ints), or ``None`` to use
      the broadcast of the batch shapes of ``mean`` and ``cov``. A given shape
      is validated against both parameter batch shapes via ``_check_shape``.
    dtype: dtype forwarded to ``normal``.

  Returns:
    ``mean + L z`` with ``z`` standard-normal of shape ``shape + (n,)`` and
    ``L = cholesky(cov)``. The factor is presumably lower-triangular
    (``cov = L @ L.T``) — TODO confirm for the ``cholesky`` in scope.

  Raises:
    ValueError: if ``mean``/``cov`` have too few dimensions, their trailing
      sizes disagree, or ``shape`` fails ``_check_shape``.
  """
  if not onp.ndim(mean) >= 1:
    msg = "multivariate_normal requires mean.ndim >= 1, got mean.ndim == {}"
    raise ValueError(msg.format(onp.ndim(mean)))
  if not onp.ndim(cov) >= 2:
    msg = "multivariate_normal requires cov.ndim >= 2, got cov.ndim == {}"
    raise ValueError(msg.format(onp.ndim(cov)))
  n = mean.shape[-1]
  if onp.shape(cov)[-2:] != (n, n):
    msg = ("multivariate_normal requires cov.shape == (..., n, n) for n={n}, "
           "but got cov.shape == {shape}.")
    raise ValueError(msg.format(n=n, shape=onp.shape(cov)))

  if shape is None:
    shape = lax.broadcast_shapes(mean.shape[:-1], cov.shape[:-2])
  else:
    # Normalise so a list is accepted and ``shape + (n,)`` below works.
    shape = tuple(shape)
    # cov's batch dims are everything before its trailing (n, n) block.
    _check_shape("multivariate_normal", shape, mean.shape[:-1], cov.shape[:-2])

  chol_factor = cholesky(cov)
  normal_samples = normal(key, shape + mean.shape[-1:], dtype)
  # Batched matrix-vector product L @ z. The ellipsis broadcasts cov's batch
  # dims against the sample batch dims (tensordot cannot do this).
  return mean + np.einsum('...ij,...j->...i', chol_factor, normal_samples)
 def _calc_canon_mat(X: Arrays, Y: Arrays, λs) -> np.DeviceArray:
     """Build the Cholesky-whitened cross-product ("canonical") matrix of X, Y.

     Returns ``inv(chol(Cx)) @ (X.T @ Y) / (X.T.shape[0] * Y.shape[1])
     @ inv(chol(Cy))``. Here ``Cx = CanonicalRidge._ridge_cov(X, λs[0])`` and
     ``Cy = CanonicalRidge._ridge_cov(Y, λs[1])``.

     NOTE(review): if ``cholesky`` returns the lower factor (C = L @ L.T), the
     usual whitening is ``inv(Lx) @ Cxy @ inv(Ly).T``. Here the Y factor's
     inverse is not transposed, which changes the singular values — confirm
     this is intended.

     NOTE(review): the cross product is divided by
     ``X.T.shape[0] * Y.shape[1]``. If rows are samples, these are feature
     counts rather than a sample count — verify against ``_ridge_cov``.
     """
     # https://stackoverflow.com/questions/15670094/speed-up-solving-a-triangular-linear-system-with-numpy
     inv_chol_x = inv(cholesky(CanonicalRidge._ridge_cov(X, λs[0])))
     denom = X.T.shape[0] * Y.shape[1]
     scaled_cross = inv_chol_x @ (X.T @ Y) / denom
     inv_chol_y = inv(cholesky(CanonicalRidge._ridge_cov(Y, λs[1])))
     return scaled_cross @ inv_chol_y
Example #5
0
def cholesky(a):
    """Return ``linalg.cholesky`` of ``a``, wrapped in a ``JaxArray``.

    A ``JaxArray`` input is unwrapped through its ``.value`` attribute first.
    Any other input is handed to ``linalg.cholesky`` as-is.
    """
    raw = a.value if isinstance(a, JaxArray) else a
    return JaxArray(linalg.cholesky(raw))