def update_posterior(gram, XT_y, prior):
    """Return ``(mean, sigma)`` for the precision ``diag(alpha) + beta * gram``.

    ``prior`` is unpacked as ``(alpha, beta)``. ``sigma`` is the inverse of
    that precision matrix, obtained from the inverse of its Cholesky factor
    rather than by inverting the full matrix directly. ``mean`` is
    ``beta * sigma @ XT_y``.
    """
    alpha, beta = prior
    precision = jnp.diag(alpha) + beta * gram
    # NOTE(review): lower=True below assumes `cholesky` returns the
    # lower-triangular factor (precision == L @ L.T) — confirm.
    chol_lower = cholesky(precision)
    eye = jnp.eye(alpha.shape[0])
    chol_inv = solve_triangular(chol_lower, eye, check_finite=False, lower=True)
    # inv(L @ L.T) == inv(L).T @ inv(L)
    sigma = jnp.dot(chol_inv.T, chol_inv)
    mean = beta * jnp.dot(sigma, XT_y)
    return mean, sigma
def chol_gram_blocks(dc_du, dc_dv):
    """Calculate Cholesky factors of decomposition of Gram matrix.

    For each of the three blocks ``b`` this forms
    ``D[b] = dc_dv[b] @ diag(1 / m_v[b]) @ dc_dv[b]^T``. The ``1 / m_v``
    weighting is dropped when ``metric`` is an ``IdentityMatrix``; otherwise
    ``m_v`` comes from splitting ``metric_2_diag``. Each ``D[b]`` is then
    Cholesky-factorized. Finally the matrix
    ``metric_1 + sum_b dc_du[b]^T @ inv(D[b]) @ dc_du[b]`` is factorized; for
    ``b == 1`` the sum also runs over the leading axis of the 3-D arrays.

    Returns:
        ``(chol_C, chol_D)``. ``chol_C`` is the factor of the combined
        matrix and ``chol_D`` is a tuple of the three per-block factors.
    """
    # NOTE(review): `metric`, `metric_1`, `metric_2_diag`, `split` and
    # `IdentityMatrix` are not defined here; presumably they are captured
    # from an enclosing scope — confirm.
    if isinstance(metric, IdentityMatrix):
        gram_v = tuple(
            np.einsum('...ij,...kj', dc_dv[b], dc_dv[b]) for b in range(3))
    else:
        v_sizes = (dc_dv[0].shape[1], dc_dv[1].shape[0] * dc_dv[1].shape[2])
        m_v = split(metric_2_diag, v_sizes)
        # The middle block's diagonal is stored flat; restore it to
        # (dc_dv[1].shape[0], dc_dv[1].shape[2]) so it broadcasts per batch.
        m_v[1] = m_v[1].reshape((dc_dv[1].shape[0], dc_dv[1].shape[2]))
        gram_v = tuple(
            np.einsum('...ij,...kj', dc_dv[b] / m_v[b][..., None, :], dc_dv[b])
            for b in range(3))
    chol_D = tuple(nla.cholesky(gram_v[b]) for b in range(3))
    # (factor, True): the factors above are treated as lower-triangular.
    D_inv_dc_du = tuple(
        sla.cho_solve((chol_D[b], True), dc_du[b]) for b in range(3))
    summed_terms = (
        dc_du[0].T @ D_inv_dc_du[0]
        + np.einsum('ijk,ijl->kl', dc_du[1], D_inv_dc_du[1])
        + dc_du[2].T @ D_inv_dc_du[2])
    chol_C = nla.cholesky(metric_1 + summed_terms)
    return chol_C, chol_D
def _multivariate_normal(key, mean, cov, shape, dtype):
    """Sample from a multivariate normal with mean ``mean`` and covariance ``cov``.

    Args:
      key: PRNG key, forwarded to ``normal``.
      mean: array whose last axis has length ``n``; leading axes are batch axes.
      cov: array of shape ``(..., n, n)``; leading axes are batch axes.
      shape: batch shape of the output, or ``None`` to use the broadcast of
        ``mean.shape[:-1]`` and ``cov.shape[:-2]``. When given, it is checked
        against both batch shapes via ``_check_shape``.
      dtype: dtype forwarded to ``normal``.

    Returns:
      ``mean + L @ z``, where ``L = cholesky(cov)`` and ``z`` is a standard
      normal draw of shape ``shape + (n,)``. The product is batched over the
      leading axes, so each sample is transformed by its own covariance block.

    Raises:
      ValueError: if ``mean.ndim < 1``, ``cov.ndim < 2`` or
        ``cov.shape[-2:] != (n, n)``.
    """
    if not onp.ndim(mean) >= 1:
        msg = "multivariate_normal requires mean.ndim >= 1, got mean.ndim == {}"
        raise ValueError(msg.format(onp.ndim(mean)))
    if not onp.ndim(cov) >= 2:
        msg = "multivariate_normal requires cov.ndim >= 2, got cov.ndim == {}"
        raise ValueError(msg.format(onp.ndim(cov)))
    n = mean.shape[-1]
    if onp.shape(cov)[-2:] != (n, n):
        msg = ("multivariate_normal requires cov.shape == (..., n, n) for n={n}, "
               "but got cov.shape == {shape}.")
        raise ValueError(msg.format(n=n, shape=onp.shape(cov)))

    if shape is None:
        shape = lax.broadcast_shapes(mean.shape[:-1], cov.shape[:-2])
    else:
        # The covariance batch shape is cov.shape[:-2] (not mean.shape[:-2]),
        # matching the broadcast used in the shape=None branch above.
        _check_shape("normal", shape, mean.shape[:-1], cov.shape[:-2])

    # NOTE(review): assumes `cholesky` returns the lower factor L with
    # cov == L @ L.T — confirm.
    chol_factor = cholesky(cov)
    normal_samples = normal(key, shape + mean.shape[-1:], dtype)
    # Batched matrix-vector product L @ z. For 2-D cov this equals
    # tensordot(z, L, [-1, 1]), but unlike tensordot it also pairs batched
    # covariance blocks with their own samples.
    return mean + np.einsum('...ij,...j->...i', chol_factor, normal_samples)
def _calc_canon_mat(X: Arrays, Y: Arrays, λs) -> np.DeviceArray:
    """Return the scaled cross-product of X and Y between inverse Cholesky factors.

    Computes ``inv(cholesky(Cx)) @ (X.T @ Y) / scale @ inv(cholesky(Cy))``.
    ``Cx`` is ``CanonicalRidge._ridge_cov(X, λs[0])``, ``Cy`` is
    ``CanonicalRidge._ridge_cov(Y, λs[1])`` and
    ``scale = X.T.shape[0] * Y.shape[1]``.
    """
    # https://stackoverflow.com/questions/15670094/speed-up-solving-a-triangular-linear-system-with-numpy
    # NOTE(review): the link above suggests triangular solves were intended;
    # explicit `inv` of the Cholesky factors is used instead.
    left_factor_inv = inv(cholesky(CanonicalRidge._ridge_cov(X, λs[0])))
    cross = X.T @ Y
    right_factor_inv = inv(cholesky(CanonicalRidge._ridge_cov(Y, λs[1])))
    # NOTE(review): for 2-D X this normalizes by feature counts
    # (X.shape[1] * Y.shape[1]) rather than the sample count — confirm intent.
    scale = X.T.shape[0] * Y.shape[1]
    # NOTE(review): if `cholesky` is lower-triangular (C == L @ L.T), symmetric
    # whitening of the Y side would use inv(L_y).T on the right, not inv(L_y);
    # confirm against how K is consumed.
    K = (left_factor_inv @ cross / scale) @ right_factor_inv
    return K
def cholesky(a):
    """Return ``linalg.cholesky`` of ``a``, wrapped in a ``JaxArray``.

    A ``JaxArray`` argument is unwrapped to its ``.value`` first; any other
    argument is passed to ``linalg.cholesky`` unchanged.
    """
    raw = a.value if isinstance(a, JaxArray) else a
    return JaxArray(linalg.cholesky(raw))