def _build_marginal_likelihood_logp(self, y, X, Xu, sigma, jitter):
    """Return the symbolic log marginal likelihood of ``y`` under a sparse approximation.

    The approximation is selected by ``self.approx``: "FITC", "VFE", or any other
    value, which is treated as "DTC".

    Parameters
    ----------
    y : observed targets at ``X``.
    X : training inputs.
    Xu : inducing inputs; ``cov_func(Xu)`` is factorised by Cholesky.
    sigma : noise scale; its square is added to the diagonal term ``Lamd``.
    jitter : diagonal jitter passed to ``stabilize`` before factorising Kuu.

    Returns
    -------
    Symbolic scalar ``-1.0 * (constant + logdet + quadratic + trace)``.
    """
    noise_var = at.square(sigma)
    chol_uu = cholesky(stabilize(self.cov_func(Xu), jitter))
    # proj = Luu^{-1} Kuf, so the column sums of proj**2 are diag(Qff).
    proj = solve_lower(chol_uu, self.cov_func(Xu, X))
    q_diag = at.sum(proj * proj, 0)

    if self.approx == "FITC":
        # FITC keeps the diagonal correction Kff - Qff, clipped at zero.
        lam = at.clip(self.cov_func(X, diag=True) - q_diag, 0.0, np.inf) + noise_var
        trace_term = 0.0
    elif self.approx == "VFE":
        lam = at.ones_like(q_diag) * noise_var
        # Trace penalty tr(Kff - Qff) / (2 sigma^2).
        trace_term = (1.0 / (2.0 * noise_var)) * (
            at.sum(self.cov_func(X, diag=True)) - at.sum(q_diag)
        )
    else:  # DTC
        lam = at.ones_like(q_diag) * noise_var
        trace_term = 0.0

    # B = I + proj diag(1/lam) proj^T, factorised for the Woodbury-style terms.
    chol_b = cholesky(at.eye(Xu.shape[0]) + at.dot(proj / lam, at.transpose(proj)))
    resid = y - self.mean_func(X)
    resid_scaled = resid / lam
    c = solve_lower(chol_b, at.dot(proj, resid_scaled))

    constant = 0.5 * X.shape[0] * at.log(2.0 * np.pi)
    logdet = 0.5 * at.sum(at.log(lam)) + at.sum(at.log(at.diag(chol_b)))
    quadratic = 0.5 * (at.dot(resid, resid_scaled) - at.dot(c, c))
    return -1.0 * (constant + logdet + quadratic + trace_term)
def _build_conditional(
    self, Xnew, pred_noise, diag, X, Xu, y, sigma, cov_total, mean_total, jitter
):
    """Return the sparse-GP predictive mean and (co)variance at ``Xnew``.

    Training-side quantities (Kuu, Kuf, diag Kff, training mean) come from
    ``cov_total`` / ``mean_total``. Test-side quantities (Kus, Kss, test mean)
    come from ``self.cov_func`` / ``self.mean_func``.

    When ``self.approx == "FITC"``, the diagonal term is the clipped
    ``Kff - Qff`` plus ``sigma**2``. Otherwise (VFE/DTC) it is a constant
    ``sigma**2``.

    Returns
    -------
    ``(mu, var)`` if ``diag`` is true, else ``(mu, cov)``. With ``pred_noise``,
    ``sigma**2`` is added to the variance (or to the covariance diagonal).
    Without it, the full covariance is passed through ``stabilize(cov, jitter)``.
    """
    noise_var = at.square(sigma)
    chol_uu = cholesky(stabilize(cov_total(Xu), jitter))
    proj = solve_lower(chol_uu, cov_total(Xu, X))
    q_diag = at.sum(proj * proj, 0)

    if self.approx == "FITC":
        lam = at.clip(cov_total(X, diag=True) - q_diag, 0.0, np.inf) + noise_var
    else:  # VFE and DTC share the same predictive equations here
        lam = at.ones_like(q_diag) * noise_var

    chol_b = cholesky(at.eye(Xu.shape[0]) + at.dot(proj / lam, at.transpose(proj)))
    resid_scaled = (y - mean_total(X)) / lam
    c = solve_lower(chol_b, at.dot(proj, resid_scaled))

    proj_s = solve_lower(chol_uu, self.cov_func(Xu, Xnew))
    mu = self.mean_func(Xnew) + at.dot(
        at.transpose(proj_s), solve_upper(at.transpose(chol_b), c)
    )
    C = solve_lower(chol_b, proj_s)

    if diag:
        var = (
            self.cov_func(Xnew, diag=True)
            - at.sum(at.square(proj_s), 0)
            + at.sum(at.square(C), 0)
        )
        if pred_noise:
            var = var + noise_var
        return mu, var

    cov = (
        self.cov_func(Xnew)
        - at.dot(at.transpose(proj_s), proj_s)
        + at.dot(at.transpose(C), C)
    )
    if pred_noise:
        return mu, cov + noise_var * at.identity_like(cov)
    return mu, stabilize(cov, jitter)
def _build_prior(self, name, Xs, jitter, **kwargs):
    """Build a Kronecker-structured GP prior over the Cartesian grid of ``Xs``.

    Side effect: sets ``self.N`` to the number of grid points (the product of
    ``len(X)`` over ``Xs``).

    Each factor ``cov_funcs[i](Xs[i])`` is stabilised with ``jitter`` and
    Cholesky-factorised. ``f`` is ``mean(grid) + flatten(kron_dot(chols, v))``,
    where ``v`` is a standard-normal vector of length ``self.N``. ``**kwargs``
    are forwarded to that normal variable.

    Returns
    -------
    The deterministic variable named ``name``.
    """
    self.N = int(np.prod([len(Xi) for Xi in Xs]))
    mu = self.mean_func(cartesian(*Xs))
    factors = [
        cholesky(stabilize(kernel(Xi), jitter))
        for kernel, Xi in zip(self.cov_funcs, Xs)
    ]
    rotated = pm.Normal(name + "_rotated_", mu=0.0, sigma=1.0, size=self.N, **kwargs)
    return pm.Deterministic(name, mu + at.flatten(kron_dot(factors, rotated)))
def _build_conditional(self, Xnew, X, f, cov_total, mean_total, jitter):
    """Return the conditional mean and covariance at ``Xnew`` given latent ``f`` at ``X``.

    The training covariance and mean come from ``cov_total`` / ``mean_total``.
    The cross-covariance, test covariance and test mean come from this
    object's ``cov_func`` / ``mean_func``.

    Returns
    -------
    ``(mu, cov)``: symbolic mean vector and full covariance matrix (not stabilised).
    """
    chol = cholesky(stabilize(cov_total(X), jitter))
    # Whitened cross-covariance L^{-1} Kxs and whitened residual L^{-1} (f - m(X)).
    w_cross = solve_lower(chol, self.cov_func(X, Xnew))
    w_resid = solve_lower(chol, f - mean_total(X))
    w_cross_t = at.transpose(w_cross)
    mu = self.mean_func(Xnew) + at.dot(w_cross_t, w_resid)
    cov = self.cov_func(Xnew) - at.dot(w_cross_t, w_cross)
    return mu, cov
def _build_prior(self, name, X, reparameterize=True, jitter=JITTER_DEFAULT, **kwargs):
    """Build the Student-t process prior over ``f`` at inputs ``X``.

    Parameters
    ----------
    name : str
        Name of the returned variable.
    X : array-like
        Inputs at which the prior is evaluated.
    reparameterize : bool
        If True, build ``f`` deterministically from auxiliary variables.
        Otherwise create ``MvStudentT`` directly.
    jitter : float
        Diagonal jitter passed to ``stabilize`` for the covariance.
    **kwargs
        Reparameterized branch: a ``size`` entry (if present) is consumed by
        ``infer_size``, and the rest is forwarded to the rotated normal
        variable. Direct branch: everything is forwarded to ``MvStudentT``.

    Returns
    -------
    The variable ``f`` named ``name``.
    """
    mu = self.mean_func(X)
    cov = stabilize(self.cov_func(X), jitter)
    if reparameterize:
        size = infer_size(X, kwargs.pop("size", None))
        # A multivariate Student-t is a Gaussian scale mixture with ONE shared
        # mixing variable:
        #   f = mu + sqrt(nu / u) * L z,  z ~ N(0, I),  u ~ ChiSquared(nu).
        # Multiplying L by independent Student-t draws does not give a
        # multivariate Student-t, so it would not match the MvStudentT branch
        # (assumes MvStudentT treats ``cov`` as the scale matrix -- TODO confirm).
        chi2 = pm.ChiSquared(name + "_chi2_", self.nu)
        v = pm.Normal(name + "_rotated_", mu=0.0, sigma=1.0, size=size, **kwargs)
        f = pm.Deterministic(name, mu + at.sqrt(self.nu / chi2) * cholesky(cov).dot(v))
    else:
        f = pm.MvStudentT(name, nu=self.nu, mu=mu, cov=cov, **kwargs)
    return f
def _build_prior(self, name, X, reparameterize=True, **kwargs):
    """Build the Gaussian-process prior over ``f`` at inputs ``X``.

    A ``shape`` entry in ``**kwargs`` is consumed by ``infer_shape``. The
    resulting value is passed as ``size``, and the remaining kwargs are
    forwarded to the created distribution.

    If ``reparameterize`` is true, ``f = mu + chol(cov) @ v`` with ``v``
    standard normal (stored as ``name + "_rotated_"``). Otherwise ``f`` is an
    ``MvNormal`` named ``name``.

    Returns
    -------
    The variable ``f``.
    """
    mu = self.mean_func(X)
    cov = stabilize(self.cov_func(X))
    shape = infer_shape(X, kwargs.pop("shape", None))
    if not reparameterize:
        return pm.MvNormal(name, mu=mu, cov=cov, size=shape, **kwargs)
    rotated = pm.Normal(name + "_rotated_", mu=0.0, sigma=1.0, size=shape, **kwargs)
    return pm.Deterministic(name, mu + cholesky(cov).dot(rotated))
def _build_prior(self, name, X, reparameterize=True, **kwargs):
    """Build the Student-t process prior over ``f`` at inputs ``X``.

    A ``shape`` entry in ``**kwargs`` is consumed by ``infer_shape``. The
    resulting value is passed as ``size``, and the remaining kwargs are
    forwarded to the rotated normal variable (reparameterized branch) or to
    ``MvStudentT``.

    Parameters
    ----------
    name : str
        Name of the returned variable.
    X : array-like
        Inputs at which the prior is evaluated.
    reparameterize : bool
        If True, ``f`` is built from a ChiSquared mixing variable and a
        standard-normal vector. Otherwise ``MvStudentT`` is used directly.

    Returns
    -------
    The variable ``f`` named ``name``.
    """
    mu = self.mean_func(X)
    cov = stabilize(self.cov_func(X))
    shape = infer_shape(X, kwargs.pop("shape", None))
    if reparameterize:
        chi2 = pm.ChiSquared(name + "_chi2_", self.nu)
        v = pm.Normal(name + "_rotated_", mu=0.0, sigma=1.0, size=shape, **kwargs)
        # Gaussian scale mixture: f = mu + sqrt(nu / u) * L z, u ~ ChiSquared(nu).
        # The factor is a square root, and only the zero-mean part L z is
        # scaled; the mean must stay unscaled.
        f = pm.Deterministic(name, mu + at.sqrt(self.nu / chi2) * cholesky(cov).dot(v))
    else:
        f = pm.MvStudentT(name, nu=self.nu, mu=mu, cov=cov, size=shape, **kwargs)
    return f
def _build_conditional(self, Xnew, X, f, jitter):
    """Return ``(nu2, mu, covT)`` for the Student-t process conditional at ``Xnew``.

    Conditions on latent values ``f`` at ``X`` using this object's
    ``cov_func`` / ``mean_func``.

    Returns
    -------
    nu2
        Updated degrees of freedom, ``self.nu + X.shape[0]``.
    mu
        Conditional mean, ``m(Xnew) + Ksx Kxx^{-1} (f - m(X))``.
    covT
        The GP conditional covariance scaled by ``(nu + beta - 2) / (nu2 - 2)``,
        where ``beta = (f - m)^T Kxx^{-1} (f - m)``.
    """
    chol = cholesky(stabilize(self.cov_func(X), jitter))
    cross = solve_lower(chol, self.cov_func(X, Xnew))
    cross_t = at.transpose(cross)
    white = solve_lower(chol, f - self.mean_func(X))

    mu = self.mean_func(Xnew) + at.dot(cross_t, white)
    gp_cov = self.cov_func(Xnew) - at.dot(cross_t, cross)

    beta = at.dot(white, white)
    nu2 = self.nu + X.shape[0]
    # NOTE(review): (nu + beta - 2) / (nu2 - 2) is the conditional factor for a
    # *covariance*-parameterised multivariate t. If the prior/consumer treats
    # the matrix as a *scale* (as the reparameterised priors do), the matching
    # factor would be (nu + beta) / nu2 -- confirm which parameterisation the
    # caller's distribution uses.
    scale = (self.nu + beta - 2) / (nu2 - 2)
    return nu2, mu, scale * gp_cov
def _build_conditional(self, Xnew, jitter):
    """Return the conditional mean and covariance at ``Xnew`` for a Kronecker GP.

    Reads ``self.Xs`` (per-dimension input sets) and ``self.f`` (latent values
    on their Cartesian grid). Per-dimension covariances
    ``cov_funcs[i](Xs[i])`` are stabilised with ``jitter`` and
    Cholesky-factorised. The returned covariance is also passed through
    ``stabilize(., jitter)``.

    Returns
    -------
    ``(mu, cov)``.
    """
    Xs, f = self.Xs, self.f
    X = cartesian(*Xs)
    lowers = [cholesky(stabilize(kernel(Xi), jitter)) for kernel, Xi in zip(self.cov_funcs, Xs)]
    uppers = [at.transpose(Lk) for Lk in lowers]
    Kxs = self.cov_func(X, Xnew)
    # Two Kronecker-structured triangular solves; presumably alpha = K^{-1} (f - m(X))
    # with K the Kronecker product of the per-dimension covariances.
    alpha = kron_solve_upper(uppers, kron_solve_lower(lowers, f - self.mean_func(X)))
    mu = at.dot(at.transpose(Kxs), alpha).ravel() + self.mean_func(Xnew)
    A = kron_solve_lower(lowers, Kxs)
    cov = stabilize(self.cov_func(Xnew) - at.dot(at.transpose(A), A), jitter)
    return mu, cov
def _build_prior(self, name, X, reparameterize=True, jitter=JITTER_DEFAULT, **kwargs):
    """Build the Gaussian-process prior over ``f`` at inputs ``X``.

    If ``reparameterize`` is true:
      * ``f = mu + chol(cov) @ v``, with ``v`` a standard normal of length
        ``np.shape(X)[0]``;
      * ``**kwargs`` go to ``v``;
      * a ``dims`` kwarg is also reused for the deterministic ``f``.

    Otherwise ``f`` is an ``MvNormal`` receiving ``**kwargs``.

    Returns
    -------
    The variable ``f`` named ``name``.
    """
    mu = self.mean_func(X)
    cov = stabilize(self.cov_func(X), jitter)
    if not reparameterize:
        return pm.MvNormal(name, mu=mu, cov=cov, **kwargs)
    n_points = np.shape(X)[0]
    rotated = pm.Normal(name + "_rotated_", mu=0.0, sigma=1.0, size=n_points, **kwargs)
    return pm.Deterministic(
        name, mu + cholesky(cov).dot(rotated), dims=kwargs.get("dims", None)
    )
def _build_conditional(self, Xnew, pred_noise, diag, X, y, noise, cov_total, mean_total):
    """Return the marginal-GP predictive mean and (co)variance at ``Xnew``.

    The training covariance is ``stabilize(cov_total(X)) + noise(X)``, and
    the residual is ``y - mean_total(X)``. Test-side terms use this object's
    ``cov_func`` / ``mean_func``.

    Returns
    -------
    ``(mu, var)`` if ``diag`` is true, else ``(mu, cov)``. With
    ``pred_noise``, the noise term at ``Xnew`` is added. Without it, the full
    covariance is passed through ``stabilize``.
    """
    chol = cholesky(stabilize(cov_total(X)) + noise(X))
    w_cross = solve_lower(chol, self.cov_func(X, Xnew))
    w_resid = solve_lower(chol, y - mean_total(X))
    mu = self.mean_func(Xnew) + at.dot(at.transpose(w_cross), w_resid)

    if diag:
        var = self.cov_func(Xnew, diag=True) - at.sum(at.square(w_cross), 0)
        if pred_noise:
            var = var + noise(Xnew, diag=True)
        return mu, var

    cov = self.cov_func(Xnew) - at.dot(at.transpose(w_cross), w_cross)
    if pred_noise:
        return mu, cov + noise(Xnew)
    return mu, stabilize(cov)