def test_get_item_tensor_index(self):
    # Tests the default LazyTensor.__getitem__ behavior
    lazy_tensor = ZeroLazyTensor(5, 5)
    evaluated = lazy_tensor.evaluate()

    index = (torch.tensor([0, 0, 1, 2]), torch.tensor([0, 1, 0, 2]))
    self.assertTrue(approx_equal(lazy_tensor[index], evaluated[index]))
    index = (torch.tensor([0, 0, 1, 2]), slice(None, None, None))
    self.assertTrue(approx_equal(lazy_tensor[index], evaluated[index]))
    index = (slice(None, None, None), torch.tensor([0, 0, 1, 2]))
    self.assertTrue(approx_equal(lazy_tensor[index], evaluated[index]))
def test_evaluate(self):
    lv = ZeroLazyTensor(5, 4, 3)
    actual = torch.zeros(5, 4, 3)
    res = lv.evaluate()
    self.assertLess(torch.norm(res - actual), 1e-4)
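The two tests above exercise ZeroLazyTensor's lazy-evaluation contract: indexing the lazy tensor must agree with indexing the dense zero tensor it represents, and evaluate() must materialize exact zeros. A minimal standalone sketch of that contract follows; it assumes a pre-1.9 gpytorch where ZeroLazyTensor is importable from gpytorch.lazy, and the hasattr guard is there because __getitem__ may return either a torch.Tensor or another LazyTensor depending on the index type:

import torch
from gpytorch.lazy import ZeroLazyTensor

lt = ZeroLazyTensor(5, 5)
dense = lt.evaluate()  # materializes the dense 5 x 5 zero tensor

idx = (torch.tensor([0, 0, 1, 2]), torch.tensor([0, 1, 0, 2]))
res = lt[idx]
# Normalize to a plain tensor before comparing, since the return type
# of __getitem__ depends on the index.
res = res.evaluate() if hasattr(res, "evaluate") else res
assert torch.equal(res, dense[idx])  # four zeros either way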
def forward(self, X, **kwargs):
    if self.training:
        # TODO: return a better dummy here -- this is a dummy because the
        # real work happens in the MLL
        if X is not None:
            mean = self.mean_module(X)
            covar = self.covar_module(X)
        else:
            if not isinstance(self._batch_shape, torch.Size):
                batch_shape = torch.Size((self._batch_shape,))
            else:
                batch_shape = self._batch_shape
            mean_shape = batch_shape + torch.Size((self.num_data,))
            mean = ZeroLazyTensor(*mean_shape)
            covar_shape = mean_shape + torch.Size((self.num_data,))
            covar = ZeroLazyTensor(*covar_shape)

        # should hopefully only occur with batching issues: broadcast the
        # mean up to the covariance's batch shape
        if mean.ndimension() < covar.ndimension() and (
            self._batch_shape != torch.Size() and mean.shape != covar.shape[:-1]
        ):
            if isinstance(mean, ZeroLazyTensor):
                mean = mean.evaluate()
            mean = mean.unsqueeze(0)
            mean = mean.repeat(covar.batch_shape[0], *[1] * (covar.ndimension() - 1))

        return MultivariateNormal(mean, covar)
    else:
        lazy_kernel = self.covar_module(X).evaluate_kernel()

        # interpolate the cached predictive mean onto the test points
        pred_mean = left_interp(
            lazy_kernel.left_interp_indices,
            lazy_kernel.left_interp_values,
            self.prediction_cache["pred_mean"],
        )

        if skip_posterior_variances.off():
            # init the predictive covariance if it's not in the prediction cache
            if "pred_cov" not in self.prediction_cache:
                self.prediction_cache["pred_cov"] = self._make_predictive_covar()
            inner_pred_cov = self.prediction_cache["pred_cov"]

            if fast_pred_samples.off():
                # full predictive covariance: W^T (inner covariance) W
                pred_wmat = _get_wmat_from_kernel(lazy_kernel)
                lazy_pred_wmat = lazify(pred_wmat)
                pred_cov = lazy_pred_wmat.transpose(-1, -2).matmul(
                    inner_pred_cov.matmul(lazy_pred_wmat)
                )
                if self.has_learnable_noise:
                    pred_cov = pred_cov * self.likelihood.second_noise_covar.noise.to(
                        pred_cov.device
                    )
            else:
                # low-rank root of the predictive covariance for fast sampling
                # (the default root decomposition could be used instead of Lanczos)
                inner_pred_cov_root = inner_pred_cov.root_decomposition(
                    method="lanczos"
                ).root.evaluate()
                if inner_pred_cov_root.shape[-1] > X.shape[-2]:
                    inner_pred_cov_root = inner_pred_cov_root[..., -X.shape[-2]:]
                root_tensor = left_interp(
                    lazy_kernel.left_interp_indices,
                    lazy_kernel.left_interp_values,
                    inner_pred_cov_root,
                )
                # scale the root by the noise root first, so pred_cov is
                # defined whether or not the noise is learnable
                if self.has_learnable_noise:
                    noise_root = self.likelihood.second_noise_covar.noise.to(
                        root_tensor.device
                    ) ** 0.5
                    root_tensor = root_tensor * noise_root
                pred_cov = RootLazyTensor(root_tensor)
        else:
            pred_cov = ZeroLazyTensor(*lazy_kernel.size())

        pred_mean = pred_mean[..., 0]
        if self._batch_shape == torch.Size() and X.ndimension() == 2:
            pred_mean = pred_mean[0]
            if pred_cov.ndimension() > 2:
                pred_cov = pred_cov[0]

        dist = MultivariateNormal(pred_mean, pred_cov)
        return dist
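In the eval branch above, left_interp applies the sparse SKI interpolation matrix W(X) to cached inducing-point quantities (the predictive mean and the covariance root). A dense re-implementation sketch of that contraction follows; left_interp_dense and the shapes n, k, m, t are illustrative assumptions, not the gpytorch API, and gpytorch's actual left_interp additionally handles batch dimensions and keeps W sparse:

import torch

def left_interp_dense(interp_indices, interp_values, rhs):
    # out[i] = sum_j interp_values[i, j] * rhs[interp_indices[i, j]]
    gathered = rhs[interp_indices]  # (n, k, t): k candidate rows per test point
    return (interp_values.unsqueeze(-1) * gathered).sum(dim=-2)  # (n, t)

n, k, m, t = 4, 2, 6, 1  # test points, weights per point, inducing points, rhs columns
indices = torch.randint(0, m, (n, k))
values = torch.rand(n, k)
rhs = torch.randn(m, t)

# Equivalent dense computation: build the n x m interpolation matrix W
# explicitly and multiply. scatter_add_ accumulates duplicate indices,
# matching the gather-and-sum above.
W = torch.zeros(n, m).scatter_add_(1, indices, values)
assert torch.allclose(left_interp_dense(indices, values, rhs), W @ rhs)

Keeping W sparse is the point of the SKI construction: each test point touches only k inducing points, so the interpolation costs O(n k) rather than the O(n m) of the dense product shown here.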