def kl_divergence_multivariate_normal(a_mean, a_scale_tril, b_mean, b_scale_tril, lower=True):
    def log_abs_determinant(scale_tril_arg):
        diag_scale_tril = np.diagonal(scale_tril_arg, axis1=-2, axis2=-1)
        return 2 * np.sum(np.log(diag_scale_tril), axis=-1)

    def squared_frobenius_norm(x):
        """Helper to make KL calculation slightly more readable."""
        return np.sum(np.square(x), axis=[-2, -1])

    if b_scale_tril.shape[0] == 1:
        tiles = tuple([b_mean.shape[0]] + [1 for _ in range(len(b_scale_tril.shape) - 1)])
        scale_tril = np.tile(b_scale_tril, tiles)
    else:
        scale_tril = b_scale_tril

    b_inv_a = solve_triangular(b_scale_tril, a_scale_tril, lower=lower)
    kl = 0.5 * (log_abs_determinant(b_scale_tril)
                - log_abs_determinant(a_scale_tril)
                - a_scale_tril.shape[-1]
                + squared_frobenius_norm(b_inv_a)
                + squared_frobenius_norm(
                    solve_triangular(scale_tril, (b_mean - a_mean)[..., np.newaxis], lower=lower)))
    return kl
def posterior_sample(self, key, sample, X_star, **kwargs):
    # Fetch training data
    batch = kwargs['batch']
    X = batch['X']
    # Fetch params
    var = sample['kernel_var']
    length = sample['kernel_length']
    beta = sample['beta']
    eta = sample['eta']
    theta = np.concatenate([var, length])
    # Compute kernels
    K_xx = self.kernel(X, X, theta) + np.eye(X.shape[0]) * 1e-8
    k_pp = self.kernel(X_star, X_star, theta) + np.eye(X_star.shape[0]) * 1e-8
    k_pX = self.kernel(X_star, X, theta)
    L = cholesky(K_xx, lower=True)
    f = np.matmul(L, eta) + beta
    tmp_1 = solve_triangular(L.T, solve_triangular(L, f, lower=True))
    tmp_2 = solve_triangular(L.T, solve_triangular(L, k_pX.T, lower=True))
    # Compute predictive mean
    mu = np.matmul(k_pX, tmp_1)
    cov = k_pp - np.matmul(k_pX, tmp_2)
    std = np.sqrt(np.clip(np.diag(cov), a_min=0.))
    sample = mu + std * random.normal(key, mu.shape)
    return mu, sample
def posterior_sample(self, key, sample, X_star, **kwargs):
    # Fetch training data
    norm_const = kwargs['norm_const']
    batch = kwargs['batch']
    X, y = batch['X'], batch['y']
    # Fetch params
    var = sample['kernel_var']
    length = sample['kernel_length']
    noise = sample['noise_var']
    params = np.concatenate(
        [np.array([var]), np.array(length), np.array([noise])])
    theta = params[:-1]
    # Compute kernels
    k_pp = self.kernel(X_star, X_star, theta) + np.eye(X_star.shape[0]) * (noise + 1e-8)
    k_pX = self.kernel(X_star, X, theta)
    L = self.compute_cholesky(params, batch)
    alpha = solve_triangular(L.T, solve_triangular(L, y, lower=True))
    beta = solve_triangular(L.T, solve_triangular(L, k_pX.T, lower=True))
    # Compute predictive mean, std
    mu = np.matmul(k_pX, alpha)
    cov = k_pp - np.matmul(k_pX, beta)
    std = np.sqrt(np.clip(np.diag(cov), a_min=0.))
    sample = mu + std * random.normal(key, mu.shape)
    mu = mu * norm_const['sigma_y'] + norm_const['mu_y']
    sample = sample * norm_const['sigma_y'] + norm_const['mu_y']
    return mu, sample
def solve_tri(A, B, lower=True, from_left=True, transp_L=False):
    """Solve a triangular system: A @ X = B if from_left, else X @ A = B.
    With transp_L=True, A is replaced by A.T in the system."""
    if not from_left:
        # X @ A = B  <=>  A.T @ X.T = B.T
        return sla.solve_triangular(A.T, B.T, trans=transp_L, lower=not lower).T
    else:
        return sla.solve_triangular(A, B, trans=transp_L, lower=lower)
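# A minimal usage sketch for solve_tri above (added for illustration; not part of the
# original source). It exercises the from_left=False branch, i.e. solving X @ A = B,
# and assumes scipy.linalg is bound to the name `sla`, which the function expects.
import numpy as onp
import scipy.linalg as sla

A = onp.tril(onp.random.rand(4, 4)) + 4.0 * onp.eye(4)  # well-conditioned lower-triangular matrix
B = onp.random.rand(3, 4)
X = solve_tri(A, B, lower=True, from_left=False)         # solves X @ A = B
assert onp.allclose(X @ A, B)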
def __phi_H_2(self, x, p, xtilde, ptilde):
    ptilde = ptilde - 0.5 * self.__epsilon * self.__hamiltonian.jacobian_at(
        xtilde, p)
    L = self.__target.metric(xtilde)
    # Dense equivalent: x = x + 0.5 * self.__epsilon * np.linalg.solve(L @ L.T, p)
    # L is assumed to be the lower-triangular Cholesky factor of the metric, so
    # (L @ L.T)^{-1} @ p is computed with two triangular solves.
    x = x + 0.5 * self.__epsilon * sla.solve_triangular(
        L.T, sla.solve_triangular(L, p, lower=True), lower=False)
    return x, p, xtilde, ptilde
def evaluate(self):
    K = self.model.kernel.function(self.model.X, self.model.parameters) \
        + jnp.eye(self.N) * (self.model.parameters["noise"] + 1e-8)
    self.L = cholesky(K, lower=True)
    self.alpha = solve_triangular(
        self.L.T, solve_triangular(self.L, self.model.y, lower=True))
def mvn_kl(mu_a, L_a, mu_b, L_b):
    def squared_frobenius_norm(x):
        return np.sum(np.square(x))

    b_inv_a = solve_triangular(L_b, L_a, lower=True)
    kl_div = (
        np.sum(np.log(np.diag(L_b))) - np.sum(np.log(np.diag(L_a))) +
        0.5 * (-L_a.shape[-1] + squared_frobenius_norm(b_inv_a) +
               squared_frobenius_norm(
                   solve_triangular(L_b, mu_b[:, None] - mu_a[:, None], lower=True))))
    return kl_div
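# A small, self-contained check of mvn_kl (added as an illustration, not from the original
# source): with identical unit covariances the KL reduces to 0.5 * ||mu_a - mu_b||^2.
# Assumes `np` is a NumPy-compatible module and `solve_triangular` follows the scipy
# signature, the names mvn_kl relies on.
import numpy as np
from scipy.linalg import solve_triangular

mu_a = np.array([1.0, 2.0, 3.0])
mu_b = np.zeros(3)
L_eye = np.eye(3)
kl = mvn_kl(mu_a, L_eye, mu_b, L_eye)
assert np.allclose(kl, 0.5 * np.sum((mu_a - mu_b) ** 2))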
def _pred_factorize(self, xtest):
    Kux = self.kernel.cross_covariance(self.x_u, xtest)
    Ws = solve_triangular(self.Luu, Kux, lower=True)
    # pack
    pack = jnp.concatenate([self.W_Dinv_y, Ws], axis=1)
    Linv_pack = solve_triangular(self.L, pack, lower=True)
    # unpack
    Linv_W_Dinv_y = Linv_pack[:, :self.W_Dinv_y.shape[1]]
    Linv_Ws = Linv_pack[:, self.W_Dinv_y.shape[1]:]
    return Ws, Linv_W_Dinv_y, Linv_Ws
def _pred_factorize(params, xtest):
    Kux = rbf_kernel(params["x_u"], xtest, params["variance"], params["length_scale"])
    Ws = solve_triangular(params["Luu"], Kux, lower=True)
    # pack
    pack = jnp.concatenate([params["W_Dinv_y"], Ws], axis=1)
    Linv_pack = solve_triangular(params["L"], pack, lower=True)
    # unpack
    Linv_W_Dinv_y = Linv_pack[:, :params["W_Dinv_y"].shape[1]]
    Linv_Ws = Linv_pack[:, params["W_Dinv_y"].shape[1]:]
    return Ws, Linv_W_Dinv_y, Linv_Ws
def _triangular_solve(x, y, upper=False, transpose=False):
    assert np.ndim(x) >= 2 and np.ndim(y) >= 2
    n, m = x.shape[-2:]
    assert y.shape[-2:] == (n, n)
    # NB: JAX requires x and y have the same batch_shape
    batch_shape = lax.broadcast_shapes(x.shape[:-2], y.shape[:-2])
    x = np.broadcast_to(x, batch_shape + (n, m))
    if y.shape[:-2] == batch_shape:
        return solve_triangular(y, x, trans=int(transpose), lower=not upper)

    # The following procedure handles the case: y.shape = (i, 1, n, n), x.shape = (..., i, j, n, m)
    # because we don't want to broadcast y to the shape (i, j, n, n).
    # We are going to make x have shape (..., 1, j, i, 1, n) to apply batched triangular_solve
    dx = x.ndim
    prepend_ndim = dx - y.ndim  # ndim of ... part
    # Reshape x with the shape (..., 1, i, j, 1, n, m)
    x_new_shape = batch_shape[:prepend_ndim]
    for (sy, sx) in zip(y.shape[:-2], batch_shape[prepend_ndim:]):
        x_new_shape += (sx // sy, sy)
    x_new_shape += (n, m,)
    x = np.reshape(x, x_new_shape)
    # Permute x to make it have shape (..., 1, j, m, i, 1, n)
    batch_ndim = x.ndim - 2
    permute_dims = (tuple(range(prepend_ndim))
                    + tuple(range(prepend_ndim, batch_ndim, 2))
                    + (batch_ndim + 1,)
                    + tuple(range(prepend_ndim + 1, batch_ndim, 2))
                    + (batch_ndim,))
    x = np.transpose(x, permute_dims)
    x_permute_shape = x.shape
    # reshape to (-1, i, 1, n)
    x = np.reshape(x, (-1,) + y.shape[:-1])
    # permute to (i, 1, n, -1)
    x = np.moveaxis(x, 0, -1)
    sol = solve_triangular(y, x, trans=int(transpose), lower=not upper)  # shape: (i, 1, n, -1)
    sol = np.moveaxis(sol, -1, 0)  # shape: (-1, i, 1, n)
    sol = np.reshape(sol, x_permute_shape)  # shape: (..., 1, j, m, i, 1, n)
    # now we permute back to x_new_shape = (..., 1, i, j, 1, n, m)
    permute_inv_dims = tuple(range(prepend_ndim))
    for i in range(y.ndim - 2):
        permute_inv_dims += (prepend_ndim + i, dx + i - 1)
    permute_inv_dims += (sol.ndim - 1, prepend_ndim + y.ndim - 2)
    sol = np.transpose(sol, permute_inv_dims)
    return sol.reshape(batch_shape + (n, m))
def log_likelihood(self, params):
    self.model.set_parameters(params)
    kx = self.model.kernel.function(
        self.model.X, params) + jnp.eye(self.N) * (params["noise"] + 1e-8)
    L = cholesky(kx, lower=True)
    alpha = solve_triangular(L.T, solve_triangular(L, self.model.y, lower=True))
    W_logdet = 2. * jnp.sum(jnp.log(jnp.diag(L)))
    log_marginal = 0.5 * (-self.model.y.size * log_2_pi
                          - self.model.y.shape[1] * W_logdet
                          - jnp.sum(alpha * self.model.y))
    return log_marginal
def _predict(self, xtest, full_covariance: bool = False, noiseless: bool = True):
    # Calculate the mean
    K_x = self.kernel.cross_covariance(xtest, self.X)
    μ = jnp.dot(K_x, self.weights)
    # Calculate the covariance
    v = solve_triangular(self.Lff, K_x.T, lower=True)
    if full_covariance:
        K_xx = self.kernel.gram(xtest)
        if not noiseless:
            K_xx = add_to_diagonal(K_xx, self.obs_noise)
        Σ = K_xx - v.T @ v
        return μ, Σ
    else:
        K_xx = self.kernel.diag(xtest)
        σ = K_xx - jnp.sum(jnp.square(v), axis=0)
        if not noiseless:
            σ += self.obs_noise
        return μ, σ
def mll(params: dict,
        x: jnp.DeviceArray,
        y: jnp.DeviceArray,
        static_params: dict = None):
    params = transform(params)
    if static_params:
        params = concat_dictionaries(params, static_params)
    m = gp.prior.kernel.num_basis
    phi = gp.prior.kernel._build_phi(x, params)
    A = (params["variance"] / m) * jnp.matmul(
        jnp.transpose(phi), phi) + params["obs_noise"] * I(2 * m)
    RT = jnp.linalg.cholesky(A)
    R = jnp.transpose(RT)
    RtiPhit = solve_triangular(RT, jnp.transpose(phi))
    # Rtiphity=RtiPhit*y_tr;
    Rtiphity = jnp.matmul(RtiPhit, y)
    out = (0.5 / params["obs_noise"] *
           (jnp.sum(jnp.square(y)) -
            params["variance"] / m * jnp.sum(jnp.square(Rtiphity))))
    n = x.shape[0]
    out += (jnp.sum(jnp.log(jnp.diag(R))) +
            (n / 2.0 - m) * jnp.log(params["variance"]) +
            n / 2 * jnp.log(2 * jnp.pi))
    constant = jnp.array(-1.0) if negative else jnp.array(1.0)
    return constant * out.reshape()
def final_fn(state, regularize=False):
    """
    :param state: Current state of the scheme.
    :param bool regularize: Whether to adjust diagonal for numerical stability.
    :return: a triple of estimated covariance, the square root of precision,
        and the inverse of that square root.
    """
    mean, m2, n = state
    # XXX it is not necessary to check for the case n=1
    cov = m2 / (n - 1)
    if regularize:
        # Regularization from Stan
        scaled_cov = (n / (n + 5)) * cov
        shrinkage = 1e-3 * (5 / (n + 5))
        if diagonal:
            cov = scaled_cov + shrinkage
        else:
            cov = scaled_cov + shrinkage * jnp.identity(mean.shape[0])
    if jnp.ndim(cov) == 2:
        # copy the implementation of distributions.util.cholesky_of_inverse here
        tril_inv = jnp.swapaxes(
            jnp.linalg.cholesky(cov[..., ::-1, ::-1])[..., ::-1, ::-1], -2, -1)
        identity = jnp.identity(cov.shape[-1])
        cov_inv_sqrt = solve_triangular(tril_inv, identity, lower=True)
    else:
        tril_inv = jnp.sqrt(cov)
        cov_inv_sqrt = jnp.reciprocal(tril_inv)
    return cov, cov_inv_sqrt, tril_inv
def get_cond_params(learned_params: dict, x: Array, y: Array, jitter: float = 1e-5) -> dict:
    params = deepcopy(learned_params)
    n_samples = x.shape[0]

    # calculate the cholesky factorization
    Kuu = rbf_kernel(params["x_u"], params["x_u"], params["variance"], params["length_scale"])
    Kuu = add_to_diagonal(Kuu, jitter)
    Luu = cholesky(Kuu, lower=True)

    Kuf = rbf_kernel(params["x_u"], x, params["variance"], params["length_scale"])
    W = solve_triangular(Luu, Kuf, lower=True)

    D = np.ones(n_samples) * params["obs_noise"]
    W_Dinv = W / D
    K = W_Dinv @ W.T
    K = add_to_diagonal(K, 1.0)
    L = cholesky(K, lower=True)

    # mean function
    y_residual = y
    y_2D = y_residual.reshape(-1, n_samples).T
    W_Dinv_y = W_Dinv @ y_2D

    return {"Luu": Luu, "W_Dinv_y": W_Dinv_y, "L": L}
def solve_via_cholesky(k_chol, y):
    """Solves a positive definite linear system via a Cholesky decomposition.

    Args:
        k_chol: The Cholesky factor of the matrix to solve. A lower triangular
            matrix, perhaps more commonly known as L.
        y: The vector to solve.
    """
    # Solve L s = y
    s = spl.solve_triangular(k_chol, y, lower=True)
    # Solve L.T b = s
    b = spl.solve_triangular(k_chol.T, s)
    return b
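# Illustrative usage of solve_via_cholesky (added, not part of the original source):
# solve K x = y through the Cholesky factor and compare against a direct solve.
# Assumes scipy.linalg is available under the alias `spl`, which the function uses.
import numpy as onp
import scipy.linalg as spl

rng = onp.random.default_rng(0)
A = rng.normal(size=(5, 5))
K = A @ A.T + 5.0 * onp.eye(5)     # symmetric positive definite matrix
y = rng.normal(size=5)
L = onp.linalg.cholesky(K)         # lower-triangular factor, K = L @ L.T
x = solve_via_cholesky(L, y)
assert onp.allclose(K @ x, y)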
def sample_theta_space(X, Y, active_dims, msq, lam, eta1, xisq, c, sigma):
    P, N, M = X.shape[1], X.shape[0], len(active_dims)
    # the total number of coefficients we return
    num_coefficients = P + M * (M - 1) // 2

    probe = jnp.zeros((2 * P + 2 * M * (M - 1), P))
    vec = jnp.zeros((num_coefficients, 2 * P + 2 * M * (M - 1)))
    start1 = 0
    start2 = 0

    for dim in range(P):
        probe = jax.ops.index_update(probe, jax.ops.index[start1:start1 + 2, dim],
                                     jnp.array([1.0, -1.0]))
        vec = jax.ops.index_update(vec, jax.ops.index[start2, start1:start1 + 2],
                                   jnp.array([0.5, -0.5]))
        start1 += 2
        start2 += 1

    for dim1 in active_dims:
        for dim2 in active_dims:
            if dim1 >= dim2:
                continue
            probe = jax.ops.index_update(
                probe, jax.ops.index[start1:start1 + 4, dim1],
                jnp.array([1.0, 1.0, -1.0, -1.0]))
            probe = jax.ops.index_update(
                probe, jax.ops.index[start1:start1 + 4, dim2],
                jnp.array([1.0, -1.0, 1.0, -1.0]))
            vec = jax.ops.index_update(
                vec, jax.ops.index[start2, start1:start1 + 4],
                jnp.array([0.25, -0.25, -0.25, 0.25]))
            start1 += 4
            start2 += 1

    eta2 = jnp.square(eta1) * jnp.sqrt(xisq) / msq
    kappa = jnp.sqrt(msq) * lam / jnp.sqrt(msq + jnp.square(eta1 * lam))

    kX = kappa * X
    kprobe = kappa * probe

    k_xx = kernel(kX, kX, eta1, eta2, c) + sigma**2 * jnp.eye(N)
    L = cho_factor(k_xx, lower=True)[0]
    k_probeX = kernel(kprobe, kX, eta1, eta2, c)
    k_prbprb = kernel(kprobe, kprobe, eta1, eta2, c)

    mu = jnp.matmul(k_probeX, cho_solve((L, True), Y))
    mu = jnp.sum(mu * vec, axis=-1)

    Linv_k_probeX = solve_triangular(L, jnp.transpose(k_probeX), lower=True)
    covar = k_prbprb - jnp.matmul(jnp.transpose(Linv_k_probeX), Linv_k_probeX)
    covar = jnp.matmul(vec, jnp.matmul(covar, jnp.transpose(vec)))

    # sample from N(mu, covar)
    L = jnp.linalg.cholesky(covar)
    sample = mu + jnp.matmul(L, np.random.randn(num_coefficients))

    return sample
def update_posterior(gram, XT_y, prior):
    alpha, beta = prior
    L = cholesky(jnp.diag(alpha) + beta * gram)
    R = solve_triangular(L, jnp.eye(alpha.shape[0]), check_finite=False, lower=True)
    sigma = jnp.dot(R.T, R)
    mean = beta * jnp.dot(sigma, XT_y)
    return mean, sigma
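# A hedged usage sketch for update_posterior (added; not part of the original source):
# Bayesian linear regression y ≈ X w with prior precisions `alpha` and noise precision
# `beta`. It assumes `cholesky` yields the *lower* factor and that `solve_triangular`
# follows the scipy signature, which is how the function above calls them.
from functools import partial
import numpy as onp
import jax.numpy as jnp
from scipy.linalg import cholesky as _chol, solve_triangular

cholesky = partial(_chol, lower=True)  # assumption: lower-triangular Cholesky factor

rng = onp.random.default_rng(0)
X = rng.normal(size=(50, 3))
w_true = onp.array([0.5, -1.0, 2.0])
y = X @ w_true + 0.1 * rng.normal(size=50)

alpha = onp.ones(3)        # prior precisions of the weights
beta = 1.0 / 0.1 ** 2      # observation noise precision
mean, sigma = update_posterior(X.T @ X, X.T @ y, (alpha, beta))
assert onp.allclose(mean, w_true, atol=0.1)  # posterior mean should sit near w_true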
def cholesky_of_inverse(matrix):
    # This formulation only takes the inverse of a triangular matrix
    # which is more numerically stable.
    # Refer to:
    # https://nbviewer.jupyter.org/gist/fehiepsi/5ef8e09e61604f10607380467eb82006#Precision-to-scale_tril
    tril_inv = jnp.swapaxes(jnp.linalg.cholesky(matrix[..., ::-1, ::-1])[..., ::-1, ::-1], -2, -1)
    identity = jnp.broadcast_to(jnp.identity(matrix.shape[-1]), tril_inv.shape)
    return solve_triangular(tril_inv, identity, lower=True)
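# Quick illustrative check of cholesky_of_inverse (added for clarity, not from the original
# source): the returned lower-triangular factor L should satisfy L @ L.T ≈ inv(matrix).
# Only jax.numpy and jax.scipy.linalg.solve_triangular are assumed, matching the names above.
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

A = jnp.array([[4.0, 1.0, 0.5],
               [1.0, 3.0, 0.2],
               [0.5, 0.2, 2.0]])   # symmetric positive definite
L = cholesky_of_inverse(A)
assert jnp.allclose(L @ L.T, jnp.linalg.inv(A), atol=1e-4)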
def posterior_sample(self, key, sample, X_star, **kwargs):
    # Fetch training data
    batch = kwargs['batch']
    XL, XH = batch['XL'], batch['XH']
    NL, NH = XL.shape[0], XH.shape[0]
    # Fetch params
    var_L = sample['kernel_var_L']
    var_H = sample['kernel_var_H']
    length_L = sample['kernel_length_L']
    length_H = sample['kernel_length_H']
    beta_L = sample['beta_L']
    beta_H = sample['beta_H']
    eta_L = sample['eta_L']
    eta_H = sample['eta_H']
    rho = sample['rho']
    theta_L = np.concatenate([var_L, length_L])
    theta_H = np.concatenate([var_H, length_H])
    beta = np.concatenate([beta_L * np.ones(NL), beta_H * np.ones(NH)])
    eta = np.concatenate([eta_L, eta_H])
    # Compute kernels
    k_pp = rho**2 * self.kernel(X_star, X_star, theta_L) + \
           self.kernel(X_star, X_star, theta_H) + \
           np.eye(X_star.shape[0]) * 1e-8
    psi1 = rho * self.kernel(X_star, XL, theta_L)
    psi2 = rho**2 * self.kernel(X_star, XH, theta_L) + \
           self.kernel(X_star, XH, theta_H)
    k_pX = np.hstack((psi1, psi2))
    # Compute K_xx
    K_LL = self.kernel(XL, XL, theta_L) + np.eye(NL) * 1e-8
    K_LH = rho * self.kernel(XL, XH, theta_L)
    K_HH = rho**2 * self.kernel(XH, XH, theta_L) + \
           self.kernel(XH, XH, theta_H) + np.eye(NH) * 1e-8
    K_xx = np.vstack((np.hstack((K_LL, K_LH)),
                      np.hstack((K_LH.T, K_HH))))
    L = cholesky(K_xx, lower=True)
    # Sample latent function
    f = np.matmul(L, eta) + beta
    tmp_1 = solve_triangular(L.T, solve_triangular(L, f, lower=True))
    tmp_2 = solve_triangular(L.T, solve_triangular(L, k_pX.T, lower=True))
    # Compute predictive mean
    mu = np.matmul(k_pX, tmp_1)
    cov = k_pp - np.matmul(k_pX, tmp_2)
    std = np.sqrt(np.clip(np.diag(cov), a_min=0.))
    sample = mu + std * random.normal(key, mu.shape)
    return mu, sample
def build_rv(test_points: Array):
    Kfx = cross_covariance(gp.prior.kernel, X, test_points, params)
    Kxx = gram(gp.prior.kernel, test_points, params)
    A = solve_triangular(L, Kfx.T, lower=True)
    latent_var = Kxx - jnp.sum(jnp.square(A), -2)
    latent_mean = jnp.matmul(A.T, nu)
    lvar = jnp.diag(latent_var)
    moment_fn = predictive_moments(gp.likelihood)
    return moment_fn(latent_mean.ravel(), lvar)
def isPD_and_invert(M):
    L = np.linalg.cholesky(M)
    if np.isnan(np.sum(L)):
        return False, None
    L_inverse = sla.solve_triangular(L, np.eye(len(L)), lower=True, check_finite=False)
    return True, L_inverse.T.dot(L_inverse)
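# Illustrative check for isPD_and_invert (added; not in the original source). For a
# positive definite M the function should report True and return M^{-1}. Assumes numpy
# as `np` and scipy.linalg as `sla`, the aliases used above; the NaN check only catches
# non-PD inputs when the backend returns NaNs instead of raising.
import numpy as np
import scipy.linalg as sla

M = np.array([[2.0, 0.3], [0.3, 1.0]])
ok, M_inv = isPD_and_invert(M)
assert ok and np.allclose(M @ M_inv, np.eye(2))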
def random_variable(
    gp: SpectralPosterior,
    params: dict,
    train_inputs: Array,
    train_outputs: Array,
    test_inputs: Array,
    static_params: dict = None,
) -> tfd.Distribution:
    params = concat_dictionaries(params, static_params)
    m = gp.prior.kernel.num_basis
    w = params["basis_fns"] / params["lengthscale"]

    phi = gp.prior.kernel._build_phi(train_inputs, params)
    A = (params["variance"] / m) * jnp.matmul(jnp.transpose(phi), phi) + params["obs_noise"] * I(
        2 * m
    )

    RT = jnp.linalg.cholesky(A)
    R = jnp.transpose(RT)

    RtiPhit = solve_triangular(RT, jnp.transpose(phi))
    # Rtiphity=RtiPhit*y_tr;
    Rtiphity = jnp.matmul(RtiPhit, train_outputs)

    alpha = params["variance"] / m * solve_triangular(R, Rtiphity, lower=False)
    phistar = jnp.matmul(test_inputs, jnp.transpose(w))
    # phistar = [cos(phistar) sin(phistar)]; % test design matrix
    phistar = jnp.hstack([jnp.cos(phistar), jnp.sin(phistar)])
    # out1(beg_chunk:end_chunk) = phistar*alfa; % Predictive mean
    mean = jnp.matmul(phistar, alpha)
    RtiPhistart = solve_triangular(RT, jnp.transpose(phistar))
    PhiRistar = jnp.transpose(RtiPhistart)
    cov = (
        params["obs_noise"]
        * params["variance"]
        / m
        * jnp.matmul(PhiRistar, jnp.transpose(PhiRistar))
        + I(test_inputs.shape[0]) * 1e-6
    )
    return tfd.MultivariateNormalFullCovariance(mean.squeeze(), cov)
def log_mvnormal(x, mean, cov):
    L = jnp.linalg.cholesky(cov)
    dx = x - mean
    dx = solve_triangular(L, dx, lower=True)
    # maha = dx @ jnp.linalg.solve(cov, dx)
    maha = dx @ dx
    # half of logdet = jnp.log(jnp.linalg.det(cov)), via the Cholesky factor
    logdet = jnp.sum(jnp.log(jnp.diag(L)))
    log_prob = -0.5 * x.size * jnp.log(2. * jnp.pi) - logdet - 0.5 * maha
    return log_prob
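# Sanity check for log_mvnormal (added as an illustration, not from the original source):
# compare against jax.scipy.stats.multivariate_normal.logpdf on a small example. Assumes
# jax.numpy as `jnp` and jax.scipy.linalg.solve_triangular, matching the names above.
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular
from jax.scipy.stats import multivariate_normal

x = jnp.array([0.3, -1.2])
mean = jnp.array([0.0, 0.5])
cov = jnp.array([[1.0, 0.4], [0.4, 2.0]])
assert jnp.allclose(log_mvnormal(x, mean, cov),
                    multivariate_normal.logpdf(x, mean, cov), atol=1e-5)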
def _batch_mahalanobis(bL, bx):
    if bL.shape[:-1] == bx.shape:
        # no need to use the below optimization procedure
        solve_bL_bx = solve_triangular(bL, bx[..., None], lower=True).squeeze(-1)
        return np.sum(np.square(solve_bL_bx), -1)

    # NB: The following procedure handles the case: bL.shape = (i, 1, n, n), bx.shape = (i, j, n)
    # because we don't want to broadcast bL to the shape (i, j, n, n).
    # Assume that bL.shape = (i, 1, n, n), bx.shape = (..., i, j, n),
    # we are going to make bx have shape (..., 1, j, i, 1, n) to apply batched tril_solve
    sample_ndim = bx.ndim - bL.ndim + 1  # size of sample_shape
    out_shape = np.shape(bx)[:-1]  # shape of output
    # Reshape bx with the shape (..., 1, i, j, 1, n)
    bx_new_shape = out_shape[:sample_ndim]
    for (sL, sx) in zip(bL.shape[:-2], out_shape[sample_ndim:]):
        bx_new_shape += (sx // sL, sL)
    bx_new_shape += (-1,)
    bx = np.reshape(bx, bx_new_shape)
    # Permute bx to make it have shape (..., 1, j, i, 1, n)
    permute_dims = (tuple(range(sample_ndim))
                    + tuple(range(sample_ndim, bx.ndim - 1, 2))
                    + tuple(range(sample_ndim + 1, bx.ndim - 1, 2))
                    + (bx.ndim - 1,))
    bx = np.transpose(bx, permute_dims)

    # reshape to (-1, i, 1, n)
    xt = np.reshape(bx, (-1,) + bL.shape[:-1])
    # permute to (i, 1, n, -1)
    xt = np.moveaxis(xt, 0, -1)
    solve_bL_bx = solve_triangular(bL, xt, lower=True)  # shape: (i, 1, n, -1)
    M = np.sum(solve_bL_bx**2, axis=-2)  # shape: (i, 1, -1)
    # permute back to (-1, i, 1)
    M = np.moveaxis(M, -1, 0)
    # reshape back to (..., 1, j, i, 1)
    M = np.reshape(M, bx.shape[:-1])
    # permute back to (..., 1, i, j, 1)
    permute_inv_dims = tuple(range(sample_ndim))
    for i in range(bL.ndim - 2):
        permute_inv_dims += (sample_ndim + i, len(out_shape) + i)
    M = np.transpose(M, permute_inv_dims)
    return np.reshape(M, out_shape)
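# A small check of _batch_mahalanobis (added; not part of the original source), using the
# simple non-broadcast path: the result should equal dx^T @ inv(L @ L.T) @ dx. Assumes
# `np` is a NumPy-compatible module and `solve_triangular` follows the scipy signature.
import numpy as np
from scipy.linalg import solve_triangular

cov = np.array([[2.0, 0.5, 0.0],
                [0.5, 1.5, 0.3],
                [0.0, 0.3, 1.0]])
bL = np.linalg.cholesky(cov)   # (3, 3), so bL.shape[:-1] == bx.shape below
bx = np.array([0.7, -0.2, 1.1])
expected = bx @ np.linalg.solve(cov, bx)
assert np.allclose(_batch_mahalanobis(bL, bx), expected)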
def build_rv(test_points: Array):
    N = test_points.shape[0]
    phistar = jnp.matmul(test_points, jnp.transpose(w))
    phistar = jnp.hstack([jnp.cos(phistar), jnp.sin(phistar)])
    mean = jnp.matmul(phistar, alpha)
    RtiPhistart = solve_triangular(RT, jnp.transpose(phistar))
    PhiRistar = jnp.transpose(RtiPhistart)
    cov = (params["obs_noise"] * params["variance"] / m *
           jnp.matmul(PhiRistar, jnp.transpose(PhiRistar)) + I(N) * 1e-6)
    return tfd.MultivariateNormalFullCovariance(mean.squeeze(), cov)
def precision_matrix(self):
    # We use "Woodbury matrix identity" to take advantage of low rank form::
    #     inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D)
    # where :math:`C` is the capacitance matrix.
    Wt_Dinv = (np.swapaxes(self.cov_factor, -1, -2)
               / np.expand_dims(self.cov_diag, axis=-2))
    # A = inv(capacitance_tril) @ W.T @ inv(D); the capacitance Cholesky factor is the
    # triangular matrix, so it is the first argument of solve_triangular.
    A = solve_triangular(self._capacitance_tril, Wt_Dinv, lower=True)
    # TODO: find a better solution to create a diagonal matrix
    inverse_cov_diag = np.reciprocal(self.cov_diag)
    diag_embed = inverse_cov_diag[..., np.newaxis] * np.identity(self.loc.shape[-1])
    return diag_embed - np.matmul(np.swapaxes(A, -1, -2), A)
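# Since precision_matrix relies on the Woodbury identity, here is a standalone sketch
# (added; not from the original source) that reproduces the same computation with plain
# NumPy/SciPy names and checks it against a direct inverse of W @ W.T + D.
import numpy as onp
from scipy.linalg import solve_triangular

rng = onp.random.default_rng(0)
n, k = 5, 2
W = rng.normal(size=(n, k))                  # cov_factor
D = onp.array([0.5, 1.0, 1.5, 2.0, 0.8])     # cov_diag
Wt_Dinv = W.T / D                            # W^T @ inv(D), shape (k, n)
capacitance = onp.eye(k) + Wt_Dinv @ W       # C = I + W^T @ inv(D) @ W
capacitance_tril = onp.linalg.cholesky(capacitance)
A = solve_triangular(capacitance_tril, Wt_Dinv, lower=True)
precision = onp.diag(1.0 / D) - A.T @ A      # inv(D) - inv(D) W inv(C) W^T inv(D)
assert onp.allclose(precision, onp.linalg.inv(W @ W.T + onp.diag(D)))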
def meanf(test_inputs: Array) -> Array:
    Kfx = cross_covariance(gp.prior.kernel, X, test_inputs, param)
    Kxx = gram(gp.prior.kernel, test_inputs, param)
    A = solve_triangular(L, Kfx.T, lower=True)
    latent_var = Kxx - jnp.sum(jnp.square(A), -2)
    latent_mean = jnp.matmul(A.T, nu)
    lvar = jnp.diag(latent_var)
    moment_fn = predictive_moments(gp.likelihood)
    pred_rv = moment_fn(latent_mean.ravel(), lvar)
    return pred_rv.mean()
def predict(self, Xnew):
    Kx = self.model.kernel.cov(self.model.X, Xnew)
    mu = jnp.dot(Kx.T, self.alpha)
    Kxx = self.model.kernel.cov(Xnew, Xnew)
    tmp = solve_triangular(self.L, Kx, lower=True)
    var = Kxx - jnp.dot(tmp.T, tmp) + jnp.eye(Xnew.shape[0]) * self.model.variance
    return mu, var
def sample_initial_states(rng, data, num_chain=4, algorithm="chmc"):
    """Sample initial states from prior."""
    init_states = []
    for _ in range(num_chain):
        u = sample_from_prior(rng, data)
        if algorithm == "chmc":
            chol_covar = onp.linalg.cholesky(covar_func(u, data))
            n = sla.solve_triangular(chol_covar, data["y_obs"], lower=True)
            q = onp.concatenate((u, onp.asarray(n)))
        else:
            q = onp.asarray(u)
        init_states.append(q)
    return init_states