def test_get_item_tensor_index(self):
        """Compare tensor-index lookups on a ZeroLazyTensor with its dense form.

        Exercises the default ``__getitem__`` behavior with three index shapes:
        a pair of index tensors, an index tensor followed by a full slice, and
        a full slice followed by an index tensor.
        """
        zero_lt = ZeroLazyTensor(5, 5)
        dense = zero_lt.evaluate()

        row_idx = torch.tensor([0, 0, 1, 2])
        everything = slice(None, None, None)
        index_cases = (
            (row_idx, torch.tensor([0, 1, 0, 2])),
            (row_idx, everything),
            (everything, row_idx),
        )
        for idx in index_cases:
            self.assertTrue(approx_equal(zero_lt[idx], dense[idx]))
 def test_evaluate(self):
     """evaluate() on a ZeroLazyTensor should match a dense zeros tensor."""
     expected = torch.zeros(5, 4, 3)
     result = ZeroLazyTensor(5, 4, 3).evaluate()
     self.assertLess(torch.norm(result - expected), 1e-4)
Example #3
0
    def forward(self, X, **kwargs):
        """Return a MultivariateNormal for the inputs ``X``.

        Training mode returns a placeholder distribution (see the TODO below):
        built from ``mean_module``/``covar_module`` when ``X`` is given, or from
        zero lazy tensors sized by ``num_data`` and the model batch shape when
        ``X`` is None.

        Eval mode interpolates the cached ``prediction_cache["pred_mean"]``
        using the interpolation indices/values of
        ``covar_module(X).evaluate_kernel()``. The predictive covariance is:

        * a ``ZeroLazyTensor`` when ``skip_posterior_variances`` is on;
        * ``W^T C W`` (``C`` = cached inner covariance, ``W`` from
          ``_get_wmat_from_kernel``) when ``fast_pred_samples`` is off;
        * a ``RootLazyTensor`` built from an interpolated Lanczos root of ``C``
          otherwise.

        When ``has_learnable_noise`` is set, both covariance paths are scaled by
        ``likelihood.second_noise_covar.noise``.

        Args:
            X: Input points. May be None only in training mode.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            MultivariateNormal: the training placeholder or the predictive
            distribution.
        """
        if self.training:
            # TODO: return a better dummy here.
            # This is a dummy because the real action happens in the MLL.
            if X is not None:
                mean = self.mean_module(X)
                covar = self.covar_module(X)
            else:
                # _batch_shape may be stored as something other than a
                # torch.Size (presumably a bare int); normalize it first.
                if isinstance(self._batch_shape, torch.Size):
                    batch_shape = self._batch_shape
                else:
                    batch_shape = torch.Size((self._batch_shape,))
                mean_shape = batch_shape + torch.Size((self.num_data,))
                mean = ZeroLazyTensor(*mean_shape)
                covar_shape = mean_shape + torch.Size((self.num_data,))
                covar = ZeroLazyTensor(*covar_shape)

            # Should hopefully only occur with batching issues: the mean
            # presumably lacks the covariance's leading batch dimension, so it
            # is repeated along a new dim 0.
            if (
                mean.ndimension() < covar.ndimension()
                and self._batch_shape != torch.Size()
                and mean.shape != covar.shape[:-1]
            ):
                if isinstance(mean, ZeroLazyTensor):
                    # Materialize before unsqueeze/repeat (presumably not
                    # supported on the lazy zero type).
                    mean = mean.evaluate()
                # NOTE(review): repeat() gets covar.ndimension() counts, which
                # can exceed the unsqueezed mean's rank; torch then prepends an
                # extra dimension. Confirm the resulting mean shape is intended.
                mean = mean.unsqueeze(0).repeat(
                    covar.batch_shape[0], *[1] * (covar.ndimension() - 1)
                )

            return MultivariateNormal(mean, covar)

        lazy_kernel = self.covar_module(X).evaluate_kernel()
        pred_mean = left_interp(
            lazy_kernel.left_interp_indices,
            lazy_kernel.left_interp_values,
            self.prediction_cache["pred_mean"],
        )

        if skip_posterior_variances.off():
            # Build the inner predictive covariance once and cache it.
            if "pred_cov" not in self.prediction_cache:
                self.prediction_cache["pred_cov"] = self._make_predictive_covar()
            inner_pred_cov = self.prediction_cache["pred_cov"]

            if fast_pred_samples.off():
                # Lazy product W^T C W.
                pred_wmat = _get_wmat_from_kernel(lazy_kernel)
                lazy_pred_wmat = lazify(pred_wmat)
                pred_cov = lazy_pred_wmat.transpose(-1, -2).matmul(
                    inner_pred_cov.matmul(lazy_pred_wmat)
                )
                if self.has_learnable_noise:
                    pred_cov = pred_cov * self.likelihood.second_noise_covar.noise.to(
                        pred_cov.device
                    )
            else:
                inner_pred_cov_root = inner_pred_cov.root_decomposition(
                    method="lanczos"
                ).root.evaluate()
                # Keep at most X.shape[-2] (the last ones) root columns.
                if inner_pred_cov_root.shape[-1] > X.shape[-2]:
                    inner_pred_cov_root = inner_pred_cov_root[..., -X.shape[-2]:]

                root_tensor = left_interp(
                    lazy_kernel.left_interp_indices,
                    lazy_kernel.left_interp_values,
                    inner_pred_cov_root,
                )
                # Scale the root by sqrt(noise) only when the noise is
                # learnable, mirroring the dense branch. Without learnable
                # noise the root is used unscaled (previously this path
                # referenced an unassigned `noise_root` and crashed).
                if self.has_learnable_noise:
                    noise_root = self.likelihood.second_noise_covar.noise.to(
                        root_tensor.device
                    ) ** 0.5
                    root_tensor = root_tensor * noise_root
                pred_cov = RootLazyTensor(root_tensor)
        else:
            pred_cov = ZeroLazyTensor(*lazy_kernel.size())

        pred_mean = pred_mean[..., 0]
        # Non-batched model with 2-D inputs: keep only index 0 of the leading
        # dimension of the mean (and of the covariance if it has one).
        if self._batch_shape == torch.Size() and X.ndimension() == 2:
            pred_mean = pred_mean[0]
            if pred_cov.ndimension() > 2:
                pred_cov = pred_cov[0]

        return MultivariateNormal(pred_mean, pred_cov)