# Imports assumed by this listing (the enclosing module is not shown);
# `distributions` is taken to be gpytorch.distributions.
import torch
from gpytorch import distributions
from gpytorch.utils.cholesky import psd_safe_cholesky


def forward(self, inputs):
    mean = self.mean_module(inputs)
    covar = self.covar_module(inputs)
    return distributions.MultivariateNormal(mean, covar)

def forward(self, x):
    mean_x = self.mean(x)
    covar_x = self.covariance(x)
    return distributions.MultivariateNormal(mean_x, covar_x)

def __call__(self, x, y):
    # Likelihood noise variance and current inducing inputs z_b.
    sigma2 = self.gp.likelihood.noise
    z_b = self.gp.variational_strategy.inducing_points
    # Kernel matrices between the data x and the inducing inputs z_b.
    Kff = self.gp.covar_module(x).evaluate()
    Kbf = self.gp.covar_module(z_b, x).evaluate()
    Kbb = self.gp.covar_module(z_b).add_jitter(self.gp._jitter)
    # Nystrom approximation Q1 = Kfb Kbb^{-1} Kbf of Kff.
    Q1 = Kbf.transpose(-1, -2) @ Kbb.inv_matmul(Kbf)
    Sigma1 = sigma2 * torch.eye(Q1.size(-1)).to(Q1.device)

    # logp term
    if self.gp._old_strat is None:
        # First batch: marginal N(y | 0, Q1 + sigma2 I) over the new data only.
        num_data = y.size(-2)
        mean = torch.zeros(num_data).to(y.device)
        covar = (Q1 + Sigma1) + self.gp._jitter * torch.eye(Q1.size(-2)).to(Q1.device)
        dist = distributions.MultivariateNormal(mean, covar)
        logp_term = dist.log_prob(y.squeeze(-1)).sum() / y.size(-2)
    else:
        # Later batches: augment the new data with the previous inducing inputs z_a
        # and their pseudo-targets, using the old kernel and old variational quantities.
        z_a = self.gp._old_strat.inducing_points.detach()
        Kba = self.gp.covar_module(z_b, z_a).evaluate()
        Kaa_old = self.gp._old_kernel(z_a).evaluate().detach()
        Q2 = Kba.transpose(-1, -2) @ Kbb.inv_matmul(Kba)
        # Block-diagonal Q = diag(Q1, Q2).
        zero_1 = torch.zeros(Q1.size(-2), Q2.size(-1)).to(Q1.device)
        zero_2 = torch.zeros(Q2.size(-2), Q1.size(-1)).to(Q1.device)
        Q = torch.cat([
            torch.cat([Q1, zero_1], dim=-1),
            torch.cat([zero_2, Q2], dim=-1)
        ], dim=-2)
        # Block-diagonal noise: sigma2 I for the new data and
        # Kaa_old C_old^{-1} Kaa_old for the pseudo-targets.
        C_old = self.gp._old_C_matrix.detach()
        Sigma2 = Kaa_old @ C_old.inv_matmul(Kaa_old)
        Sigma2 = Sigma2 + self.gp._jitter * torch.eye(Sigma2.size(-2)).to(Sigma2.device)
        Sigma = torch.cat([
            torch.cat([Sigma1, zero_1], dim=-1),
            torch.cat([zero_2, Sigma2], dim=-1)
        ], dim=-2)
        y_hat = torch.cat([y, self.gp.pseudotargets])
        mean = torch.zeros_like(y_hat.squeeze(-1))
        covar = (Q + Sigma) + self.gp._jitter * torch.eye(Q.size(-2)).to(Q.device)
        dist = distributions.MultivariateNormal(mean, covar)
        logp_term = dist.log_prob(y_hat.squeeze(-1)).sum() / y_hat.size(-2)
        num_data = y_hat.size(-2)

    # trace term
    # t1 = tr(Kff - Q1) / sigma2 for the new data.
    t1 = (Kff - Q1).diag().sum() / sigma2
    t2 = 0
    if self.gp._old_strat is not None:
        # t2 = tr(Sigma2^{-1} Kaa) - tr(Sigma2^{-1} Q2), computed via Cholesky solves.
        LSigma2 = psd_safe_cholesky(Sigma2, upper=False, jitter=self.gp._jitter)
        Kaa = self.gp.covar_module(z_a).evaluate().detach()
        Sigma2_inv_Kaa = torch.cholesky_solve(Kaa, LSigma2, upper=False)
        Sigma2_inv_Q2 = torch.cholesky_solve(Q2, LSigma2, upper=False)
        t2 = Sigma2_inv_Kaa.diag().sum() - Sigma2_inv_Q2.diag().sum()
    trace_term = -(t1 + t2) / 2 / num_data

    if self._combine_terms:
        return logp_term + trace_term
    else:
        return logp_term, trace_term, t1 / num_data, t2 / num_data
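

# Usage sketch (an assumption, not part of the original code): the value returned
# by __call__ above is treated as an objective to be maximized, so its negative is
# used as the training loss. The names `loss_fn`, `params`, `x_batch`, `y_batch`
# and the optimizer settings are hypothetical.
import torch


def fit_batch(loss_fn, params, x_batch, y_batch, num_steps=200, lr=1e-2):
    optimizer = torch.optim.Adam(params, lr=lr)
    for _ in range(num_steps):
        optimizer.zero_grad()
        # Scalar objective: logp_term + trace_term when _combine_terms is True.
        bound = loss_fn(x_batch, y_batch)
        (-bound).backward()  # gradient ascent on the bound
        optimizer.step()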