def nll(self, loghyperparams):
    """Negative log likelihood

    Parameters
    ----------
    loghyperparams : ndarray
        array with logarithm of hyperparameters

    Returns
    -------
    nll : float
        negative log likelihood (up to a scale and offset)
    """
    hyperparams = jnp.exp(loghyperparams)
    lambdam = self.getlambda(hyperparams)
    sn = hyperparams[-1]
    # build the pd Z-matrix
    Z = self.PhiTPhi + jnp.diag(sn / lambdam)
    # use cholesky for numerical stability
    Zchol, low = jscl.cho_factor(Z)
    ZiPhiT = jscl.cho_solve((Zchol, low), self.Phi.T)
    # compute the log likelihood components
    logQ = (self.nobs - self.m) * loghyperparams[-1] + \
        2 * jnp.sum(jnp.log(jnp.diag(Zchol))) + jnp.sum(jnp.log(lambdam))
    yTinvQy = 1 / sn * self.y @ (self.y - self.Phi @ (ZiPhiT @ self.y))
    return logQ + yTinvQy
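# A minimal optimisation sketch for the method above, assuming `model` is an
# instance of the surrounding class with `getlambda`, `PhiTPhi`, `Phi`, `y`,
# `nobs` and `m` already set up (all hypothetical here), and that the model
# has three hyperparameters (magnitude, lengthscale, noise variance).
import jax
import numpy as onp
from scipy.optimize import minimize

nll_value_and_grad = jax.jit(jax.value_and_grad(model.nll))

def objective(loghyp):
    value, grad = nll_value_and_grad(loghyp)
    return float(value), onp.asarray(grad)

result = minimize(objective, x0=onp.zeros(3), jac=True, method="L-BFGS-B")
hyperparams = onp.exp(result.x)  # back-transform from log space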
def compute_log_lik(self, pseudo_y=None, pseudo_var=None):
    """
    Compute the log marginal likelihood of the pseudo model, i.e. the log
    normaliser of the approximate posterior
    """
    dim = 1  # TODO: implement multivariate case
    # TODO: won't match MarkovGP for batching with missings or for multidim input
    X = self.X[self.obs_ind]  # only compute log lik for observed values
    # TODO: check use of obs_ind (remove?)
    if pseudo_y is None:
        pseudo_y = self.pseudo_y.value
        pseudo_var = self.pseudo_var.value
    pseudo_y = pseudo_y[self.obs_ind]
    pseudo_var = pseudo_var[self.obs_ind]
    Knn = self.kernel(X, X)
    Ky = Knn + np.diag(np.squeeze(pseudo_var))  # TODO: this will break for multi-latents
    # ---- compute the marginal likelihood, i.e. the normaliser, of the pseudo model ----
    pseudo_y = diag(pseudo_y)
    Ly, low = cho_factor(Ky)
    log_lik_pseudo = (
        - 0.5 * np.sum(pseudo_y.T @ cho_solve((Ly, low), pseudo_y))
        - np.sum(np.log(np.diag(Ly)))
        - 0.5 * pseudo_y.shape[0] * dim * LOG2PI
    )
    return log_lik_pseudo
def fit(self, X, Y, rng_key, n_step):
    self.X_train = X
    # store moments of training y (to normalize)
    self.y_mean = jnp.mean(Y)
    self.y_std = jnp.std(Y)
    # normalize y
    Y = (Y - self.y_mean) / self.y_std
    # set up optimizer and SVI
    optim = numpyro.optim.Adam(step_size=0.005, b1=0.5)
    svi = SVI(
        model,
        guide=AutoDelta(model),
        optim=optim,
        loss=Trace_ELBO(),
        X=X,
        Y=Y,
    )
    params, _ = svi.run(rng_key, n_step)
    # get kernel parameters from guide with proper names
    self.kernel_params = svi.guide.median(params)
    # store Cholesky factor of prior covariance
    self.L = linalg.cho_factor(self.kernel(X, X, **self.kernel_params))
    # store inverted prior covariance multiplied by y
    self.alpha = linalg.cho_solve(self.L, Y)
    return self.kernel_params
def gaussian_conditional(kernel, y, noise_cov, X, X_star=None):
    """
    Compute the GP posterior / predictive distribution using standard Gaussian identities

    :param kernel: an instantiation of the kernel class
    :param y: observations [N, 1]
    :param noise_cov: observation noise covariance [N, 1]
    :param X: training inputs [N, D]
    :param X_star: test inputs [N*, D]
    :return:
        mean: posterior mean [N, 1]
        covariance: posterior covariance [N, N]
    """
    Kff = kernel(X, X)
    if X_star is None:  # inference / learning
        Kfs = Kff
        Kss = Kff
    else:  # prediction
        Kfs = kernel(X, X_star)
        Kss = kernel(X_star, X_star)
    Ky = Kff + np.diag(np.squeeze(noise_cov))  # TODO: will break for multi-latents
    # ---- compute approximate posterior using standard Gaussian conditional formula ----
    Ly, low = cho_factor(Ky)
    Kfs_iKy = cho_solve((Ly, low), Kfs).T
    mean = Kfs_iKy @ diag(y)
    covariance = Kss - Kfs_iKy @ Kfs
    return mean, covariance
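# Usage sketch for gaussian_conditional (hypothetical data; assumes `kernel`
# is a kernel instance with the call signature used above and that `np` is
# bound to jax.numpy, as elsewhere in this module):
import jax.numpy as np

X = np.linspace(0.0, 1.0, 20).reshape(-1, 1)       # training inputs [N, 1]
X_star = np.linspace(0.0, 1.0, 50).reshape(-1, 1)  # test inputs [N*, 1]
y = np.sin(6.0 * X)                                # observations [N, 1]
noise_cov = 0.1 * np.ones_like(y)                  # per-point noise variances [N, 1]

# posterior at the training points (X_star=None), then predictive at X_star
post_mean, post_cov = gaussian_conditional(kernel, y, noise_cov, X)
pred_mean, pred_cov = gaussian_conditional(kernel, y, noise_cov, X, X_star)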
def sample_theta_space(X, Y, active_dims, msq, lam, eta1, xisq, c, sigma):
    P, N, M = X.shape[1], X.shape[0], len(active_dims)
    # the total number of coefficients we return
    num_coefficients = P + M * (M - 1) // 2

    probe = jnp.zeros((2 * P + 2 * M * (M - 1), P))
    vec = jnp.zeros((num_coefficients, 2 * P + 2 * M * (M - 1)))
    start1 = 0
    start2 = 0

    for dim in range(P):
        probe = probe.at[start1:start1 + 2, dim].set(jnp.array([1.0, -1.0]))
        vec = vec.at[start2, start1:start1 + 2].set(jnp.array([0.5, -0.5]))
        start1 += 2
        start2 += 1

    for dim1 in active_dims:
        for dim2 in active_dims:
            if dim1 >= dim2:
                continue
            probe = probe.at[start1:start1 + 4, dim1].set(
                jnp.array([1.0, 1.0, -1.0, -1.0]))
            probe = probe.at[start1:start1 + 4, dim2].set(
                jnp.array([1.0, -1.0, 1.0, -1.0]))
            vec = vec.at[start2, start1:start1 + 4].set(
                jnp.array([0.25, -0.25, -0.25, 0.25]))
            start1 += 4
            start2 += 1

    eta2 = jnp.square(eta1) * jnp.sqrt(xisq) / msq
    kappa = jnp.sqrt(msq) * lam / jnp.sqrt(msq + jnp.square(eta1 * lam))

    kX = kappa * X
    kprobe = kappa * probe

    k_xx = kernel(kX, kX, eta1, eta2, c) + sigma ** 2 * jnp.eye(N)
    L = cho_factor(k_xx, lower=True)[0]
    k_probeX = kernel(kprobe, kX, eta1, eta2, c)
    k_prbprb = kernel(kprobe, kprobe, eta1, eta2, c)

    mu = jnp.matmul(k_probeX, cho_solve((L, True), Y))
    mu = jnp.sum(mu * vec, axis=-1)

    Linv_k_probeX = solve_triangular(L, jnp.transpose(k_probeX), lower=True)
    covar = k_prbprb - jnp.matmul(jnp.transpose(Linv_k_probeX), Linv_k_probeX)
    covar = jnp.matmul(vec, jnp.matmul(covar, jnp.transpose(vec)))

    # sample from N(mu, covar)
    L = jnp.linalg.cholesky(covar)
    sample = mu + jnp.matmul(L, np.random.randn(num_coefficients))

    return sample
def inducing_precision(self):
    """
    Compute the covariance and precision of the inducing spatial points
    to be used during filtering
    """
    Kzz = self.spatial_kernel(self.z.value, self.z.value)
    Lzz, low = cho_factor(Kzz, lower=True)  # K_zz^(1/2)
    Qzz = cho_solve((Lzz, low), np.eye(self.M))  # K_zz^(-1)
    return Qzz, Lzz
def _project_to_f_given_kernels_old(knn, kmm, knm, m, S, diag_only):
    kmm_chol, lower = cho_factor(kmm)
    pred_mean = knm @ cho_solve((kmm_chol, lower), m)
    D = cho_solve((kmm_chol, lower), S) - jnp.eye(m.shape[0])
    B = (cho_solve((kmm_chol, lower), D.T)).T
    if diag_only:
        cov = knn + diag_elts_of_triple_matmul(knm, B, knm.T)
    else:
        cov = knn + knm @ B @ knm.T
    return jnp.squeeze(pred_mean), cov
def cholesky_factorization(K: Array, Y: Array) -> Tuple[Tuple[Array, bool], Array]:
    """Cholesky factorization of the kernel matrix K, plus the weights K⁻¹Y."""
    # factor once; the (factor, lower) tuple can be reused for many solves
    L = cho_factor(K, lower=True)
    # weights = K⁻¹ Y
    weights = cho_solve(L, Y)
    return L, weights
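# A small sanity check for cholesky_factorization (toy kernel matrix,
# hypothetical): the returned weights should satisfy K @ weights ≈ Y, and the
# returned factor can be reused by cho_solve for further right-hand sides.
import jax.numpy as jnp

K_toy = jnp.eye(5) + 0.5 * jnp.ones((5, 5))   # positive-definite kernel matrix
Y_toy = jnp.arange(5.0).reshape(-1, 1)

L_toy, weights_toy = cholesky_factorization(K_toy, Y_toy)
assert jnp.allclose(K_toy @ weights_toy, Y_toy, atol=1e-5)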
def lmult_by_inv_gram(dc_du, dc_dv, chol_C, chol_D, vct):
    """Left-multiply vector by inverse Gram matrix."""
    vct_parts = split(
        vct, (dc_du[0].shape[0], dc_du[1].shape[0] * dc_du[1].shape[1]))
    vct_parts[1] = np.reshape(vct_parts[1], dc_du[1].shape[:2])
    D_inv_vct = [
        sla.cho_solve((chol_D[i], True), vct_parts[i]) for i in range(3)
    ]
    dc_du_T_D_inv_vct = sum(
        np.einsum('...jk,...j->k', dc_du[i], D_inv_vct[i]) for i in range(3))
    C_inv_dc_du_T_D_inv_vct = sla.cho_solve((chol_C, True), dc_du_T_D_inv_vct)
    return np.concatenate([
        sla.cho_solve((chol_D[i], True),
                      vct_parts[i] - dc_du[i] @ C_inv_dc_du_T_D_inv_vct).flatten()
        for i in range(3)
    ])
def compute_log_lik(self, pseudo_y=None, pseudo_var=None):
    """
    log int p(u) prod_n N(pseudo_y_n | u, pseudo_var_n) du
    """
    dim = 1  # TODO: implement multivariate case
    Kuu = self.kernel(self.Z.value, self.Z.value)
    Ky = Kuu + pseudo_var
    Ly, low = cho_factor(Ky)
    log_lik_pseudo = (
        # this term depends on the prior
        - 0.5 * np.sum(pseudo_y.T @ cho_solve((Ly, low), pseudo_y))
        - np.sum(np.log(np.diag(Ly)))
        - 0.5 * pseudo_y.shape[0] * dim * LOG2PI
    )
    return log_lik_pseudo
def update(self, likelihood, y, post_mean, post_cov, hyp=None, site_params=None):
    """
    The update function takes a likelihood as input and uses analytical
    linearisation (first-order Taylor series expansion) to update the site
    parameters.
    """
    power = 1. if site_params is None else self.power
    if (site_params is None) or (power == 0):
        # avoid cavity calc if power is 0
        cav_mean, cav_cov = post_mean, post_cov
    else:
        site_mean, site_cov = site_params
        # --- Compute the cavity distribution ---
        cav_mean, cav_cov = compute_cavity(post_mean, post_cov, site_mean, site_cov, power)
    # calculate the Jacobian of the observation model w.r.t. function fₙ and noise term rₙ
    Jf, Jsigma = likelihood.analytical_linearisation(
        cav_mean, np.zeros_like(y), hyp)  # evaluate at mean
    obs_cov = np.eye(y.shape[0])  # observation noise scale is w.l.o.g. 1
    likelihood_expectation, _ = likelihood.conditional_moments(cav_mean, hyp)
    residual = y - likelihood_expectation  # residual, yₙ-E[yₙ|fₙ]
    sigma = Jsigma @ obs_cov @ Jsigma.T + power * Jf @ cav_cov @ Jf.T
    site_nat2 = Jf.T @ inv(Jsigma @ obs_cov @ Jsigma.T) @ Jf
    site_cov = inv(site_nat2 + 1e-10 * np.eye(Jf.shape[1]))
    site_mean = cav_mean + (site_cov + power * cav_cov) @ Jf.T @ inv(sigma) @ residual
    # now compute the marginal likelihood approx.
    sigma_marg_lik = Jsigma @ obs_cov @ Jsigma.T + Jf @ cav_cov @ Jf.T
    chol_sigma, low = cho_factor(sigma_marg_lik)
    log_marg_lik = -1 * (
        .5 * site_cov.shape[0] * np.log(2 * pi)
        + np.sum(np.log(np.diag(chol_sigma)))
        + .5 * (residual.T @ cho_solve((chol_sigma, low), residual)))
    if (site_params is not None) and (self.damping != 1.):
        site_mean_prev, site_cov_prev = site_params  # previous site params
        site_nat2_prev = inv(site_cov_prev + 1e-10 * np.eye(Jf.shape[1]))
        site_nat1, site_nat1_prev = site_nat2 @ site_mean, site_nat2_prev @ site_mean_prev
        site_cov = inv((1. - self.damping) * site_nat2_prev
                       + self.damping * site_nat2 + 1e-10 * np.eye(Jf.shape[1]))
        site_mean = site_cov @ ((1. - self.damping) * site_nat1_prev
                                + self.damping * site_nat1)
    return log_marg_lik, site_mean, site_cov
def _linear_regression_gibbs_fn(X, XX, XY, Y, rng_key, gibbs_sites, hmc_sites):
    N, P = X.shape

    sigma = jnp.exp(hmc_sites['log_sigma']) if 'log_sigma' in hmc_sites else hmc_sites['sigma']
    sigma_sq = jnp.square(sigma)
    covar_inv = XX / sigma_sq + jnp.eye(P)

    L = cho_factor(covar_inv, lower=True)[0]
    L_inv = solve_triangular(L, jnp.eye(P), lower=True)
    loc = cho_solve((L, True), XY) / sigma_sq

    beta_proposal = dist.MultivariateNormal(loc=loc, scale_tril=L_inv).sample(rng_key)
    return {'beta': beta_proposal}
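# Sketch of wiring this Gibbs step into NumPyro's HMC-within-Gibbs sampler
# (assumes a `model` with a 'beta' site and precomputed X, XX = X.T @ X,
# XY = X.T @ Y; HMCGibbs expects a gibbs_fn with signature
# (rng_key, gibbs_sites, hmc_sites), hence the partial application):
from functools import partial
import jax
from numpyro.infer import MCMC, NUTS, HMCGibbs

gibbs_fn = partial(_linear_regression_gibbs_fn, X, XX, XY, Y)
kernel = HMCGibbs(NUTS(model), gibbs_fn=gibbs_fn, gibbs_sites=['beta'])
mcmc = MCMC(kernel, num_warmup=500, num_samples=1000)
mcmc.run(jax.random.PRNGKey(0))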
def mean(gp: ConjugatePosterior, param: dict, training: Dataset) -> Callable:
    X, y = training.X, training.y
    sigma = param["obs_noise"]
    n_train = training.n

    # Precompute covariance matrices
    Kff = gram(gp.prior.kernel, X, param)
    prior_mean = gp.prior.mean_function(X)
    L = cho_factor(Kff + I(n_train) * sigma, lower=True)

    prior_distance = y - prior_mean
    weights = cho_solve(L, prior_distance)

    def meanf(test_inputs: Array) -> Array:
        prior_mean_at_test_inputs = gp.prior.mean_function(test_inputs)
        Kfx = cross_covariance(gp.prior.kernel, X, test_inputs, param)
        return prior_mean_at_test_inputs + jnp.dot(Kfx, weights)

    return meanf
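# Hypothetical usage: the closure factors the training covariance once, so
# the O(n³) Cholesky cost is amortised across all subsequent evaluations.
# `posterior`, `params`, `training_data` and `X_test` are assumed set up in
# the library's usual way.
meanf = mean(posterior, params, training_data)
mu_star = meanf(X_test)  # posterior mean at the test inputs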
def predict(self, X, return_std=False):
    # compute kernels between train and test data, etc.
    k_pp = self.kernel(X, X, **self.kernel_params)
    k_pX = self.kernel(X, self.X_train, **self.kernel_params, jitter=0.0)
    # compute posterior covariance
    K = k_pp - k_pX @ linalg.cho_solve(self.L, k_pX.T)
    # compute posterior mean
    mean = k_pX @ self.alpha
    # we return both the mean function and the standard deviation
    if return_std:
        return (
            (mean * self.y_std) + self.y_mean,
            jnp.sqrt(jnp.diag(K * self.y_std ** 2)),
        )
    else:
        return (mean * self.y_std) + self.y_mean, K * self.y_std ** 2
def _newton_iteration(y_train, K, f):
    pi = expit(f)
    W = pi * (1 - pi)  # Line 5
    W_sr = np.sqrt(W)
    W_sr_K = W_sr[:, np.newaxis] * K
    B = np.eye(W.shape[0]) + W_sr_K * W_sr
    L = cholesky(B, lower=True)  # Line 6
    b = W * f + (y_train - pi)  # Line 7
    a = b - W_sr * cho_solve((L, True), W_sr_K.dot(b))  # Line 8
    f = K.dot(a)

    # Line 10: Compute log marginal likelihood in loop and use as
    # convergence criterion
    lml = -0.5 * a.T.dot(f) \
        - np.log1p(np.exp(-(y_train * 2 - 1) * f)).sum() \
        - np.log(np.diag(L)).sum()
    return lml, f, (pi, W_sr, L, b, a)
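# Sketch of the surrounding mode-finding loop (Rasmussen & Williams,
# Algorithm 3.1, whose line numbers the comments above reference): iterate
# the Newton step until the log marginal likelihood stops improving.
# `K` and `y_train` (labels in {0, 1}) are assumed given.
import numpy as np

f = np.zeros_like(y_train, dtype=float)   # start the latent function at zero
lml_prev = -np.inf
for _ in range(100):
    lml, f, _ = _newton_iteration(y_train, K, f)
    if lml - lml_prev < 1e-10:            # converged
        break
    lml_prev = lml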
def variance(
    gp: ConjugatePosterior,
    param: dict,
    test_inputs: Array,
    train_inputs: Array,
    train_outputs: Array,
) -> Array:
    assert (
        train_outputs.ndim == 2
    ), f"2-dimensional training outputs are required. Current dimension: {train_outputs.ndim}."
    ell, alpha = param["lengthscale"], param["variance"]
    sigma = param["obs_noise"]
    n_train = train_inputs.shape[0]

    Kff = gram(gp.prior.kernel, train_inputs, param)
    Kfx = cross_covariance(gp.prior.kernel, train_inputs, test_inputs, param)
    Kxx = gram(gp.prior.kernel, test_inputs, param)
    L = cho_factor(Kff + I(n_train) * sigma, lower=True)
    latents = cho_solve(L, Kfx.T)
    return Kxx - jnp.dot(Kfx, latents)
def mvn_logpdf(x, mean, cov, mask=None):
    """
    evaluate a multivariate Gaussian (log) pdf
    """
    if mask is not None:
        # build a mask for computing the log likelihood of a partially observed multivariate Gaussian
        maskv = mask.reshape(-1, 1)
        mean = np.where(maskv, x, mean)
        cov_masked = np.where(maskv + maskv.T, 0., cov)  # ensure masked entries are independent
        cov = np.where(np.diag(mask), INV2PI, cov_masked)  # ensure masked entries return log like of 0

    n = mean.shape[0]
    cho, low = cho_factor(cov)
    log_det = 2 * np.sum(np.log(np.abs(np.diag(cho))))
    diff = x - mean
    scaled_diff = cho_solve((cho, low), diff)
    distance = diff.T @ scaled_diff
    return np.squeeze(-0.5 * (distance + n * LOG2PI + log_det))
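# Hedged cross-check (assumes LOG2PI = np.log(2 * np.pi) and that mvn_logpdf
# is importable here): with no mask the result should match scipy's reference
# implementation.
import numpy as onp
from scipy.stats import multivariate_normal

mean_ref = onp.zeros(3)
cov_ref = onp.diag(onp.array([1.0, 2.0, 3.0]))
x_ref = onp.array([0.5, -1.0, 2.0])

reference = multivariate_normal(mean_ref, cov_ref).logpdf(x_ref)
# mvn_logpdf(x_ref, mean_ref, cov_ref) should agree with `reference`
# up to floating-point precision.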
def mean(
    gp: ConjugatePosterior,
    param: dict,
    test_inputs: Array,
    train_inputs: Array,
    train_outputs: Array,
):
    assert (
        train_outputs.ndim == 2
    ), f"2-dimensional training outputs are required. Current dimension: {train_outputs.ndim}."
    ell, alpha = param["lengthscale"], param["variance"]
    sigma = param["obs_noise"]
    n_train = train_inputs.shape[0]

    Kff = gram(gp.prior.kernel, train_inputs, param)
    Kfx = cross_covariance(gp.prior.kernel, train_inputs, test_inputs, param)
    prior_mean = gp.prior.mean_function(train_inputs)
    L = cho_factor(Kff + I(n_train) * sigma, lower=True)

    prior_distance = train_outputs - prior_mean
    weights = cho_solve(L, prior_distance)
    return jnp.dot(Kfx, weights)
def get_cond_params(kernel,
                    params: dict,
                    x: Array,
                    y: Array,
                    jitter: float = 1e-5) -> dict:
    params = deepcopy(params)
    obs_noise = params.pop("obs_noise")
    kernel = kernel(**params)

    # calculate the cholesky factorization
    Lff = precompute(x, obs_noise, kernel, jitter=jitter)
    weights = cho_solve((Lff, True), y)

    return {
        "X": jnp.array(x),
        "y": jnp.array(y),
        "Lff": jnp.array(Lff),
        "obs_noise": jnp.array(obs_noise),
        "kernel": kernel,
        "weights": jnp.array(weights),
    }
def chol_gram_blocks(dc_du, dc_dv):
    """Calculate Cholesky factors of decomposition of Gram matrix."""
    if isinstance(metric, IdentityMatrix):
        D = tuple(
            np.einsum('...ij,...kj', dc_dv[i], dc_dv[i]) for i in range(3))
    else:
        m_v = split(
            metric_2_diag,
            (dc_dv[0].shape[1], dc_dv[1].shape[0] * dc_dv[1].shape[2]))
        m_v[1] = m_v[1].reshape((dc_dv[1].shape[0], dc_dv[1].shape[2]))
        D = tuple(
            np.einsum('...ij,...kj', dc_dv[i] / m_v[i][..., None, :], dc_dv[i])
            for i in range(3))
    chol_D = tuple(nla.cholesky(D[i]) for i in range(3))
    D_inv_dc_du = tuple(
        sla.cho_solve((chol_D[i], True), dc_du[i]) for i in range(3))
    chol_C = nla.cholesky(metric_1 + (
        dc_du[0].T @ D_inv_dc_du[0] +
        np.einsum('ijk,ijl->kl', dc_du[1], D_inv_dc_du[1]) +
        dc_du[2].T @ D_inv_dc_du[2]))
    return chol_C, chol_D
def predict(params, Xtest, noiseless: bool = False):
    # weights
    K = rbf_kernel(X, X, params["variance"], params["length_scale"])
    K = add_to_diagonal(K, params["obs_noise"])
    L, alpha = cholesky_factorization(K, y)

    # projection kernel
    K_xf = rbf_kernel(Xtest, X, params["variance"], params["length_scale"])

    # dot product (mean predictions)
    mu_y = K_xf @ alpha

    # covariance
    v = cho_solve(L, K_xf.T)
    K_xx = rbf_kernel(Xtest, Xtest, params["variance"], params["length_scale"])
    cov_y = K_xx - jnp.dot(K_xf, v)

    if not noiseless:
        cov_y = add_to_diagonal(cov_y, params["obs_noise"])

    return mu_y, cov_y
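# Usage sketch (hypothetical: `X` and `y` are the training data closed over
# by `predict`, and the hyperparameter values below are placeholders):
params = {"variance": 1.0, "length_scale": 0.5, "obs_noise": 0.1}
mu_y, cov_y = predict(params, Xtest)                  # predictive for y* (noise added)
mu_f, cov_f = predict(params, Xtest, noiseless=True)  # predictive for the latent f*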
def mvn_logpdf_and_derivs(x, mean, cov, mask=None):
    """
    evaluate a multivariate Gaussian (log) pdf and compute its derivatives w.r.t. the mean
    """
    if mask is not None:
        # build a mask for computing the log likelihood of a partially observed multivariate Gaussian
        maskv = mask.reshape(-1, 1)
        mean = np.where(maskv, x, mean)
        cov_masked = np.where(maskv + maskv.T, 0., cov)  # ensure masked entries are independent
        cov = np.where(np.diag(mask), INV2PI, cov_masked)  # ensure masked entries return log like of 0

    n = mean.shape[0]
    cho, low = cho_factor(cov)
    precision = cho_solve((cho, low), np.eye(cho.shape[1]))  # second derivative
    log_det = 2 * np.sum(np.log(np.abs(np.diag(cho))))
    diff = x - mean
    scaled_diff = precision @ diff  # first derivative
    distance = diff.T @ scaled_diff
    return np.squeeze(-0.5 * (distance + n * LOG2PI + log_det)), scaled_diff, -precision
def compute_conditional_statistics(x_test, x, kernel, ind):
    """
    This version uses cho_factor and cho_solve - much more efficient when using JAX

    Predicts marginal states at new time points. (new time points should be sorted)
    Calculates the conditional density:
        p(xₙ|u₋, u₊) = 𝓝(Pₙ @ [u₋, u₊], Tₙ)

    :param x_test: time points to generate observations for [N]
    :param x: inducing state input locations [M]
    :param kernel: prior object providing access to state transition functions
    :param ind: an array containing the index of the inducing state to the left of every input [N]
    :return: parameters for the conditional mean and covariance
        P: [N, D, 2*D]
        T: [N, D, D]
    """
    dt_fwd = x_test[..., 0] - x[ind, 0]
    dt_back = x[ind + 1, 0] - x_test[..., 0]
    A_fwd = kernel.state_transition(dt_fwd)
    A_back = kernel.state_transition(dt_back)
    Pinf = kernel.stationary_covariance()
    Q_fwd = Pinf - A_fwd @ Pinf @ A_fwd.T
    Q_back = Pinf - A_back @ Pinf @ A_back.T
    A_back_Q_fwd = A_back @ Q_fwd
    Q_mp = Q_back + A_back @ A_back_Q_fwd.T

    jitter = 1e-8 * np.eye(Q_mp.shape[0])
    chol_Q_mp = cho_factor(Q_mp + jitter)
    Q_mp_inv_A_back = cho_solve(chol_Q_mp, A_back)  # V = Q₋₊⁻¹ Aₜ₊

    # the conditional covariance T = Q₋ₜ - Q₋ₜAₜ₊ᵀQ₋₊⁻¹Aₜ₊Q₋ₜ == Q₋ₜ - Q₋ₜᵀAₜ₊ᵀL⁻ᵀL⁻¹Aₜ₊Q₋ₜ
    T = Q_fwd - A_back_Q_fwd.T @ Q_mp_inv_A_back @ Q_fwd

    # W = Q₋ₜAₜ₊ᵀQ₋₊⁻¹
    W = Q_fwd @ Q_mp_inv_A_back.T
    P = np.concatenate([A_fwd - W @ A_back @ A_fwd, W], axis=-1)
    return P, T
def precision_matrix(self):
    identity = np.broadcast_to(
        np.eye(self.scale_tril.shape[-1]), self.scale_tril.shape)
    return cho_solve((self.scale_tril, True), identity)
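# Quick property check (standalone sketch): solving against the lower
# Cholesky factor inverts the full covariance Σ = scale_tril @ scale_tril.T,
# so the result times Σ should recover the identity.
import numpy as np
from scipy.linalg import cho_solve

scale_tril = np.array([[2.0, 0.0], [0.5, 1.0]])
cov = scale_tril @ scale_tril.T
precision = cho_solve((scale_tril, True), np.eye(2))
assert np.allclose(precision @ cov, np.eye(2))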
def posterior_neg_log_dens(u, data):
    covar = covar_func(u, data)
    chol_covar = np.linalg.cholesky(covar)
    return prior_neg_log_dens(u, data) + (
        data["y_obs"] @ sla.cho_solve((chol_covar, True), data["y_obs"]) / 2
        + np.log(chol_covar.diagonal()).sum())
def _gaussian_kernel_convolve(chol, norm, target, weights, mean):
    diff = target - mean[:, None]
    alpha = linalg.cho_solve(chol, diff)
    arg = 0.5 * jnp.sum(diff * alpha, axis=0)
    return norm * jnp.sum(jnp.exp(-arg) * weights)
def init_gp_predictive(kernel,
                       params: dict,
                       x: Array,
                       y: Array,
                       jitter: float = 1e-5) -> PredictTuple:
    params = deepcopy(params)
    obs_noise = params.pop("obs_noise")
    kernel = kernel(**params)

    # calculate the cholesky factorization
    Lff = precompute(x, obs_noise, kernel, jitter=jitter)
    weights = cho_solve((Lff, True), y)

    def predict_mean(xtest):
        K_x = kernel.cross_covariance(xtest, x)
        μ = jnp.dot(K_x, weights)
        return μ

    def predict_cov(xtest, noiseless=False):
        # calculate the covariance
        K_x = kernel.cross_covariance(xtest, x)
        v = solve_triangular(Lff, K_x.T, lower=True)
        K_xx = kernel.gram(xtest)
        Σ = K_xx - v.T @ v
        if not noiseless:
            Σ = add_to_diagonal(Σ, obs_noise)
        return Σ

    def _predict(xtest, full_covariance: bool = False, noiseless: bool = True):
        # calculate the mean
        K_x = kernel.cross_covariance(xtest, x)
        μ = jnp.dot(K_x, weights)
        # calculate the covariance
        v = solve_triangular(Lff, K_x.T, lower=True)
        if full_covariance:
            K_xx = kernel.gram(xtest)
            if not noiseless:
                K_xx = add_to_diagonal(K_xx, obs_noise)
            Σ = K_xx - v.T @ v
            return μ, Σ
        else:
            K_xx = kernel.diag(xtest)
            σ = K_xx - jnp.sum(jnp.square(v), axis=0)
            if not noiseless:
                σ += obs_noise
            return μ, σ[:, None]

    def predict_var(xtest, noiseless=False):
        # calculate the marginal variance
        K_x = kernel.cross_covariance(xtest, x)
        v = solve_triangular(Lff, K_x.T, lower=True)
        K_xx = kernel.diag(xtest)
        σ = K_xx - jnp.sum(jnp.square(v), axis=0)
        if not noiseless:
            σ += obs_noise
        return σ[:, None]

    def predict_f(xtest, full_covariance: bool = False):
        return _predict(xtest, full_covariance=full_covariance, noiseless=True)

    def predict_y(xtest, full_covariance: bool = False):
        return _predict(xtest, full_covariance=full_covariance, noiseless=False)

    return PredictTuple(
        predict_var=predict_var,
        predict_mean=predict_mean,
        predict_cov=predict_cov,
        predict_f=predict_f,
        predict_y=predict_y,
    )
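# Hypothetical usage sketch (assumes a kernel class such as `RBF` accepting
# the remaining params, plus training arrays `X_train`, `y_train` and test
# inputs `X_test`):
params = {"variance": 1.0, "lengthscale": 0.5, "obs_noise": 0.1}
pred = init_gp_predictive(RBF, params, X_train, y_train)

mu = pred.predict_mean(X_test)                    # posterior mean
var = pred.predict_var(X_test, noiseless=False)   # per-point predictive variance
mu_y, cov_y = pred.predict_y(X_test, full_covariance=True)  # noisy predictive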
def varf(test_inputs: Array) -> Array:
    Kfx = cross_covariance(gp.prior.kernel, X, test_inputs, param)
    Kxx = gram(gp.prior.kernel, test_inputs, param)
    latents = cho_solve(L, Kfx.T)
    return Kxx - jnp.dot(Kfx, latents)
def moment_match_cubature(self, y, cav_mean, cav_cov, hyp=None, power=1.0, cubature_func=None):
    """
    TODO: N.B. THIS VERSION IS SUPERSEDED BY THE FUNCTION BELOW. HOWEVER THIS ONE MAY BE MORE STABLE.

    Perform moment matching via cubature.
    Moment matching involves computing the log partition function, logZₙ, and its derivatives w.r.t. the cavity mean
        logZₙ = log ∫ pᵃ(yₙ|fₙ) 𝓝(fₙ|mₙ,vₙ) dfₙ
    with EP power a.

    :param y: observed data (yₙ) [scalar]
    :param cav_mean: cavity mean (mₙ) [scalar]
    :param cav_cov: cavity covariance (cₙ) [scalar]
    :param hyp: likelihood hyperparameter [scalar]
    :param power: EP power / fraction (a) [scalar]
    :param cubature_func: the function to compute sigma points and weights to use during cubature
    :return:
        lZ: the log partition function, logZₙ [scalar]
        dlZ: first derivative of logZₙ w.r.t. mₙ (if derivatives=True) [scalar]
        d2lZ: second derivative of logZₙ w.r.t. mₙ (if derivatives=True) [scalar]
    """
    if cubature_func is None:
        x, w = gauss_hermite(cav_mean.shape[0], 20)  # Gauss-Hermite sigma points and weights
    else:
        x, w = cubature_func(cav_mean.shape[0])
    cav_cho, low = cho_factor(cav_cov)
    # fsigᵢ=xᵢ√cₙ + mₙ: scale locations according to cavity dist.
    sigma_points = cav_cho @ np.atleast_2d(x) + cav_mean
    # pre-compute wᵢ pᵃ(yₙ|xᵢ√(2vₙ) + mₙ)
    weighted_likelihood_eval = w * self.evaluate_likelihood(y, sigma_points, hyp) ** power

    # a different approach, based on the log-likelihood, which can be more stable:
    # ll = self.evaluate_log_likelihood(y, sigma_points)
    # lmax = np.max(ll)
    # weighted_likelihood_eval = np.exp(lmax * power) * w * np.exp(power * (ll - lmax))

    # Compute partition function via cubature:
    # Zₙ = ∫ pᵃ(yₙ|fₙ) 𝓝(fₙ|mₙ,vₙ) dfₙ ≈ ∑ᵢ wᵢ pᵃ(yₙ|fsigᵢ)
    Z = np.sum(weighted_likelihood_eval, axis=-1)
    lZ = np.log(Z)
    Zinv = 1.0 / Z

    # Compute derivative of partition function via cubature:
    # dZₙ/dmₙ = ∫ (fₙ-mₙ) vₙ⁻¹ pᵃ(yₙ|fₙ) 𝓝(fₙ|mₙ,vₙ) dfₙ
    #         ≈ ∑ᵢ wᵢ (fₙ-mₙ) vₙ⁻¹ pᵃ(yₙ|fsigᵢ)
    covinv_f_m = cho_solve((cav_cho, low), sigma_points - cav_mean)
    dZ = np.sum(covinv_f_m * weighted_likelihood_eval, axis=-1)
    # dlogZₙ/dmₙ = (dZₙ/dmₙ) / Zₙ
    dlZ = Zinv * dZ

    # Compute second derivative of partition function via cubature:
    # d²Zₙ/dmₙ² = ∫ [(fₙ-mₙ)² vₙ⁻² - vₙ⁻¹] pᵃ(yₙ|fₙ) 𝓝(fₙ|mₙ,vₙ) dfₙ
    #           ≈ ∑ᵢ wᵢ [(fₙ-mₙ)² vₙ⁻² - vₙ⁻¹] pᵃ(yₙ|fsigᵢ)
    d2Z = np.sum(((sigma_points - cav_mean) ** 2 / cav_cov ** 2 - 1.0 / cav_cov)
                 * weighted_likelihood_eval)

    # d²logZₙ/dmₙ² = d[(dZₙ/dmₙ) / Zₙ]/dmₙ
    #             = (d²Zₙ/dmₙ² * Zₙ - (dZₙ/dmₙ)²) / Zₙ²
    #             = d²Zₙ/dmₙ² / Zₙ - (dlogZₙ/dmₙ)²
    d2lZ = -dlZ @ dlZ.T + Zinv * d2Z
    id2lZ = inv(ensure_positive_precision(-d2lZ) - 1e-10 * np.eye(d2lZ.shape[0]))
    site_mean = cav_mean + id2lZ @ dlZ  # approx. likelihood (site) mean (see Rasmussen & Williams p75)
    site_cov = power * (-cav_cov + id2lZ)  # approx. likelihood (site) variance
    return lZ, site_mean, site_cov
def _cholesky_solve(x, y):
    return cho_solve((y, True), x)
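# A standalone sanity check for the helper above: cho_solve with a lower
# Cholesky factor should match a direct solve against the full matrix.
import numpy as np
from scipy.linalg import cho_factor

rng = np.random.default_rng(0)
A = rng.standard_normal((4, 4))
K = A @ A.T + 4.0 * np.eye(4)     # symmetric positive definite
x = rng.standard_normal(4)

L, _ = cho_factor(K, lower=True)
assert np.allclose(_cholesky_solve(x, L), np.linalg.solve(K, x))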