def test_batch_left_interp_on_a_batch_matrix(self):
    """Batched ``left_interp`` on a batch of matrices matches a dense matmul."""
    rhs = torch.randn(2, 6, 3)
    interpolated = left_interp(
        self.batch_interp_indices,
        self.batch_interp_values,
        rhs,
    )
    # Reference: multiply by the explicit batch interpolation matrix.
    expected = torch.matmul(self.batch_interp_matrix, rhs)
    matches = test._utils.approx_equal(interpolated, expected)
    self.assertTrue(matches)
def test_batch_left_interp_on_a_vector(self):
    """Batched ``left_interp`` on a 1-D vector matches a dense matmul."""
    vec = torch.randn(6)
    # Lift the vector to a (1, 6, 1) column so it can be multiplied by the
    # batch interpolation matrix, then drop the trailing singleton again.
    column = vec.unsqueeze(-1).unsqueeze(0)
    expected = torch.matmul(self.batch_interp_matrix, column).squeeze(-1)
    interpolated = left_interp(
        self.batch_interp_indices,
        self.batch_interp_values,
        vec,
    )
    matches = test._utils.approx_equal(interpolated, expected)
    self.assertTrue(matches)
def test_batch_left_interp_on_a_matrix(self):
    """Batched ``left_interp`` on a single 2-D matrix matches a dense matmul."""
    matrix = torch.randn(6, 3)
    interpolated = left_interp(
        self.batch_interp_indices,
        self.batch_interp_values,
        matrix,
    )
    # Reference: add a leading dimension of size one before multiplying by
    # the explicit batch interpolation matrix.
    expected = torch.matmul(self.batch_interp_matrix, matrix.unsqueeze(0))
    matches = approx_equal(interpolated, expected)
    self.assertTrue(matches)
def test_interpolation(self):
    """Interpolating ``t**2`` from a 50-point grid reproduces it at 100 points."""
    x = torch.linspace(0.01, 1, 100).unsqueeze(1)
    grid = torch.linspace(-0.05, 1.05, 50).unsqueeze(1)

    indices, values = Interpolation().interpolate(grid, x)
    # Drop the leading dimension (in place) from both interpolation tensors.
    indices, values = indices.squeeze_(0), values.squeeze_(0)

    # Test function: the square, evaluated on the grid and at the targets.
    grid_values = grid.squeeze(1).pow(2)
    expected = x.pow(2).squeeze(-1)

    grid_column = grid_values.unsqueeze(1)
    interpolated = left_interp(indices, values, grid_column).squeeze()
    matches = approx_equal(interpolated, expected)
    self.assertTrue(matches)
def test_left_interp_on_a_vector(self):
    """``left_interp`` on a 1-D vector matches a dense matrix-vector product."""
    vec = torch.randn(6)
    interpolated = left_interp(self.interp_indices, self.interp_values, vec)
    # Reference: multiply by the explicit interpolation matrix.
    expected = torch.matmul(self.interp_matrix, vec)
    matches = test._utils.approx_equal(interpolated, expected)
    self.assertTrue(matches)
def test_left_interp_on_a_matrix(self):
    """``left_interp`` on a 2-D matrix matches a dense matrix-matrix product."""
    rhs = torch.randn(6, 3)
    interpolated = left_interp(self.interp_indices, self.interp_values, rhs)
    # Reference: multiply by the explicit interpolation matrix.
    expected = torch.matmul(self.interp_matrix, rhs)
    matches = approx_equal(interpolated, expected)
    self.assertTrue(matches)
def forward(self, X, **kwargs):
    """Return the model's output distribution at ``X``.

    In training mode a placeholder ``MultivariateNormal`` is returned: it is
    built from ``self.mean_module`` / ``self.covar_module`` when ``X`` is
    given, or from all-zero lazy tensors sized by ``self._batch_shape`` and
    ``self.num_data`` when ``X`` is ``None``.

    In evaluation mode the mean is obtained by applying ``left_interp`` with
    the evaluated kernel's left interpolation indices/values to the cached
    ``"pred_mean"``. The covariance is derived from the cached ``"pred_cov"``
    (created via ``self._make_predictive_covar()`` on first use), or is a
    ``ZeroLazyTensor`` when ``skip_posterior_variances`` is on.

    Args:
        X: Inputs to evaluate at. May be ``None`` only in training mode.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        A ``MultivariateNormal`` built from the computed mean and covariance.
    """
    if self.training:
        # TODO: return a better dummy here
        # This is a dummy because the real action happens in the MLL.
        if X is not None:
            mean = self.mean_module(X)
            covar = self.covar_module(X)
        else:
            batch_shape = self._batch_shape
            if not isinstance(batch_shape, torch.Size):
                batch_shape = torch.Size((batch_shape,))
            mean_shape = batch_shape + torch.Size((self.num_data,))
            covar_shape = mean_shape + torch.Size((self.num_data,))
            mean = ZeroLazyTensor(*mean_shape)
            covar = ZeroLazyTensor(*covar_shape)

        # Should hopefully only occur in batching issues: the mean is missing
        # a leading dimension relative to the covariance, so add one and tile
        # it to the covariance's first batch dimension.
        needs_batch_dim = (
            mean.ndimension() < covar.ndimension()
            and self._batch_shape != torch.Size()
            and mean.shape != covar.shape[:-1]
        )
        if needs_batch_dim:
            if isinstance(mean, ZeroLazyTensor):
                mean = mean.evaluate()
            mean = mean.unsqueeze(0)
            mean = mean.repeat(
                covar.batch_shape[0], *[1] * (covar.ndimension() - 1)
            )
        return MultivariateNormal(mean, covar)

    lazy_kernel = self.covar_module(X).evaluate_kernel()
    pred_mean = left_interp(
        lazy_kernel.left_interp_indices,
        lazy_kernel.left_interp_values,
        self.prediction_cache["pred_mean"],
    )

    if skip_posterior_variances.off():
        # Initialise the predictive covariance if it is not cached yet.
        if "pred_cov" not in self.prediction_cache:
            self.prediction_cache["pred_cov"] = self._make_predictive_covar()
        inner_pred_cov = self.prediction_cache["pred_cov"]

        if fast_pred_samples.off():
            lazy_pred_wmat = lazify(_get_wmat_from_kernel(lazy_kernel))
            pred_cov = lazy_pred_wmat.transpose(-1, -2).matmul(
                inner_pred_cov.matmul(lazy_pred_wmat)
            )
            if self.has_learnable_noise:
                noise = self.likelihood.second_noise_covar.noise.to(
                    pred_cov.device
                )
                pred_cov = pred_cov * noise
        else:
            inner_pred_cov_root = inner_pred_cov.root_decomposition(
                method="lanczos"
            ).root.evaluate()
            # Keep only the trailing X.shape[-2] columns of the root when it
            # has more columns than that.
            if inner_pred_cov_root.shape[-1] > X.shape[-2]:
                inner_pred_cov_root = inner_pred_cov_root[..., -X.shape[-2]:]
            root_tensor = left_interp(
                lazy_kernel.left_interp_indices,
                lazy_kernel.left_interp_values,
                inner_pred_cov_root,
            )
            # Scale by sqrt(noise) only when the noise is learnable. Without
            # it the root is used as is, mirroring the non-fast branch above;
            # previously this path hit an unbound local variable.
            if self.has_learnable_noise:
                noise_root = self.likelihood.second_noise_covar.noise.to(
                    root_tensor.device
                ) ** 0.5
                root_tensor = root_tensor * noise_root
            pred_cov = RootLazyTensor(root_tensor)
    else:
        pred_cov = ZeroLazyTensor(*lazy_kernel.size())

    pred_mean = pred_mean[..., 0]
    # For a non-batch model called with 2-D inputs, strip the leading
    # dimension from the mean (and from the covariance if it has one).
    if self._batch_shape == torch.Size() and X.ndimension() == 2:
        pred_mean = pred_mean[0]
        if pred_cov.ndimension() > 2:
            pred_cov = pred_cov[0]

    return MultivariateNormal(pred_mean, pred_cov)