Example #1
def forward(self, inputs):
    # Evaluate the GP prior at `inputs` and return it as a multivariate normal.
    mean = self.mean_module(inputs)
    covar = self.covar_module(inputs)
    return distributions.MultivariateNormal(mean, covar)
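
This forward pass is a fragment of a GPyTorch model class. For context, here is a minimal self-contained sketch of the kind of class it could belong to; the class name ExactGPModel and the mean/kernel choices are assumptions for illustration, not taken from the original project:

import torch
import gpytorch
from gpytorch import distributions

class ExactGPModel(gpytorch.models.ExactGP):  # hypothetical host class
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        # Assumed modules; the original model may use different means/kernels.
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, inputs):
        mean = self.mean_module(inputs)
        covar = self.covar_module(inputs)
        return distributions.MultivariateNormal(mean, covar)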
Example #2
def forward(self, x):
    # Same pattern as Example #1, with differently named mean/covariance modules.
    mean_x = self.mean(x)
    covar_x = self.covariance(x)
    return distributions.MultivariateNormal(mean_x, covar_x)
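
In practice such a forward method is rarely called directly; GPyTorch routes through it via the model's __call__ during training. A usage sketch, reusing the hypothetical ExactGPModel above (the data and hyperparameters here are illustrative assumptions):

import torch
import gpytorch

train_x = torch.linspace(0, 1, 100)
train_y = torch.sin(2 * torch.pi * train_x) + 0.1 * torch.randn(100)

likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = ExactGPModel(train_x, train_y, likelihood)
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

model.train()
likelihood.train()
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
for _ in range(50):
    optimizer.zero_grad()
    output = model(train_x)      # calls forward(), returns a MultivariateNormal
    loss = -mll(output, train_y)
    loss.backward()
    optimizer.step()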
Example #3
    def __call__(self, x, y):
        # Requires: import torch; from gpytorch import distributions;
        #           from gpytorch.utils.cholesky import psd_safe_cholesky
        sigma2 = self.gp.likelihood.noise                    # observation noise sigma^2
        z_b = self.gp.variational_strategy.inducing_points   # current inducing inputs
        Kff = self.gp.covar_module(x).evaluate()             # K(x, x)
        Kbf = self.gp.covar_module(z_b, x).evaluate()        # K(z_b, x)
        Kbb = self.gp.covar_module(z_b).add_jitter(self.gp._jitter)  # K(z_b, z_b), kept lazy

        # Nystrom approximation Q1 = Kfb Kbb^{-1} Kbf, plus the noise covariance.
        Q1 = Kbf.transpose(-1, -2) @ Kbb.inv_matmul(Kbf)
        Sigma1 = sigma2 * torch.eye(Q1.size(-1)).to(Q1.device)

        # log-marginal term: log N(y | 0, Q1 + sigma^2 I), normalised per datum
        if self.gp._old_strat is None:
            num_data = y.size(-2)
            mean = torch.zeros(num_data).to(y.device)
            covar = (Q1 + Sigma1) + self.gp._jitter * torch.eye(Q1.size(-2)).to(Q1.device)
            dist = distributions.MultivariateNormal(mean, covar)
            logp_term = dist.log_prob(y.squeeze(-1)).sum() / y.size(-2)
        else:
            # Streaming case: fold in the summary carried by the old inducing points.
            z_a = self.gp._old_strat.inducing_points.detach()  # old inducing inputs
            Kba = self.gp.covar_module(z_b, z_a).evaluate()
            Kaa_old = self.gp._old_kernel(z_a).evaluate().detach()
            Q2 = Kba.transpose(-1, -2) @ Kbb.inv_matmul(Kba)
            # Assemble block-diagonal Q = diag(Q1, Q2).
            zero_1 = torch.zeros(Q1.size(-2), Q2.size(-1)).to(Q1.device)
            zero_2 = torch.zeros(Q2.size(-2), Q1.size(-1)).to(Q1.device)
            Q = torch.cat([
                torch.cat([Q1, zero_1], dim=-1),
                torch.cat([zero_2, Q2], dim=-1),
            ], dim=-2)

            # Old-data covariance Sigma2 = Kaa C_old^{-1} Kaa, then
            # block-diagonal Sigma = diag(Sigma1, Sigma2).
            C_old = self.gp._old_C_matrix.detach()
            Sigma2 = Kaa_old @ C_old.inv_matmul(Kaa_old)
            Sigma2 = Sigma2 + self.gp._jitter * torch.eye(Sigma2.size(-2)).to(Sigma2.device)
            Sigma = torch.cat([
                torch.cat([Sigma1, zero_1], dim=-1),
                torch.cat([zero_2, Sigma2], dim=-1),
            ], dim=-2)

            # Stack the new targets with the pseudo-targets from the old model.
            y_hat = torch.cat([y, self.gp.pseudotargets])
            mean = torch.zeros_like(y_hat.squeeze(-1))
            covar = (Q + Sigma) + self.gp._jitter * torch.eye(Q.size(-2)).to(Q.device)
            dist = distributions.MultivariateNormal(mean, covar)
            logp_term = dist.log_prob(y_hat.squeeze(-1)).sum() / y_hat.size(-2)
            num_data = y_hat.size(-2)

        # trace term: tr(Kff - Q1) / sigma^2, plus the old-data correction t2
        t1 = (Kff - Q1).diag().sum() / sigma2
        t2 = 0
        if self.gp._old_strat is not None:
            LSigma2 = psd_safe_cholesky(Sigma2, upper=False, jitter=self.gp._jitter)
            Kaa = self.gp.covar_module(z_a).evaluate().detach()
            Sigma2_inv_Kaa = torch.cholesky_solve(Kaa, LSigma2, upper=False)
            Sigma2_inv_Q2 = torch.cholesky_solve(Q2, LSigma2, upper=False)
            t2 = Sigma2_inv_Kaa.diag().sum() - Sigma2_inv_Q2.diag().sum()
        trace_term = -(t1 + t2) / 2 / num_data

        if self._combine_terms:
            return logp_term + trace_term
        else:
            return logp_term, trace_term, t1 / num_data, t2 / num_data
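
For reference, Example #3 appears to evaluate a per-datum normalisation of the collapsed sparse-GP evidence lower bound, with the streaming branch folding in the previous inducing summary; this reading is our gloss of the code, not a claim from the original source. In the branch without old data, writing $Q_{ff} = K_{fb} K_{bb}^{-1} K_{bf}$, the combined return value is

\mathcal{L} = \frac{1}{N} \left[ \log \mathcal{N}\!\left(y \mid 0,\; Q_{ff} + \sigma^2 I \right) - \frac{1}{2\sigma^2} \operatorname{tr}\!\left( K_{ff} - Q_{ff} \right) \right]

which corresponds to logp_term + trace_term when _combine_terms is set.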