Example #1
0
    def test_batch_left_interp_on_a_batch_matrix(self):
        """Batched left_interp on a (2, 6, 3) tensor matches torch.matmul with the interp matrix."""
        dense_batch = torch.randn(2, 6, 3)

        result = left_interp(
            self.batch_interp_indices,
            self.batch_interp_values,
            dense_batch,
        )
        expected = torch.matmul(self.batch_interp_matrix, dense_batch)
        self.assertTrue(test._utils.approx_equal(result, expected))
Example #2
0
    def test_batch_left_interp_on_a_vector(self):
        """Batched left_interp on a length-6 vector matches matmul against a (1, 6, 1) column."""
        vector = torch.randn(6)

        # Lift the vector to shape (1, 6, 1) so matmul broadcasts it over the
        # batch, then drop the trailing singleton column from the product.
        column = vector.unsqueeze(-1).unsqueeze(0)
        expected = torch.matmul(self.batch_interp_matrix, column).squeeze(-1)

        result = left_interp(
            self.batch_interp_indices,
            self.batch_interp_values,
            vector,
        )
        self.assertTrue(test._utils.approx_equal(result, expected))
Example #3
0
    def test_batch_left_interp_on_a_matrix(self):
        """Batched left_interp on an unbatched (6, 3) matrix matches matmul after adding a batch dim."""
        batch_matrix = torch.randn(6, 3)

        result = left_interp(
            self.batch_interp_indices,
            self.batch_interp_values,
            batch_matrix,
        )
        # Give the right-hand side a leading batch dimension of size 1.
        batched_rhs = batch_matrix.unsqueeze(0)
        expected = torch.matmul(self.batch_interp_matrix, batched_rhs)
        self.assertTrue(approx_equal(result, expected))
Example #4
0
    def test_interpolation(self):
        """Interpolating f(t) = t**2 from grid points onto x recovers x**2."""
        points = torch.linspace(0.01, 1, 100).unsqueeze(1)
        grid = torch.linspace(-0.05, 1.05, 50).unsqueeze(1)

        interp_indices, interp_values = Interpolation().interpolate(grid, points)
        # squeeze_(0) removes a leading size-1 dim in place, if present
        # (presumably a singleton batch dim from interpolate -- TODO confirm).
        interp_indices = interp_indices.squeeze_(0)
        interp_values = interp_values.squeeze_(0)

        grid_values = grid.squeeze(1).pow(2)
        expected = points.pow(2).squeeze(-1)

        interpolated = left_interp(
            interp_indices,
            interp_values,
            grid_values.unsqueeze(1),
        ).squeeze()

        self.assertTrue(approx_equal(interpolated, expected))
Example #5
0
    def test_left_interp_on_a_vector(self):
        """left_interp on a length-6 vector agrees with interp_matrix @ vector."""
        vec = torch.randn(6)

        result = left_interp(self.interp_indices, self.interp_values, vec)
        self.assertTrue(
            test._utils.approx_equal(result, torch.matmul(self.interp_matrix, vec))
        )
Example #6
0
    def test_left_interp_on_a_matrix(self):
        """left_interp on a (6, 3) matrix agrees with interp_matrix @ matrix."""
        rhs = torch.randn(6, 3)

        result = left_interp(self.interp_indices, self.interp_values, rhs)
        self.assertTrue(
            approx_equal(result, torch.matmul(self.interp_matrix, rhs))
        )
Example #7
0
    def forward(self, X, **kwargs):
        """Return a placeholder prior (training) or the predictive distribution (eval) at ``X``.

        Training mode: the returned MultivariateNormal is only a dummy; the
        real computation happens in the MLL. If ``X`` is None, zero mean and
        covariance tensors of length ``self.num_data`` (with
        ``self._batch_shape`` prepended) are used.

        Eval mode: the predictive mean and covariance are computed by applying
        the kernel's left-interpolation indices/values for ``X`` to the cached
        ``self.prediction_cache["pred_mean"]`` and ``"pred_cov"`` entries.
        ``"pred_cov"`` is built lazily via ``self._make_predictive_covar()``.

        Args:
            X: Input points (may be None only in training mode).
            **kwargs: Unused.

        Returns:
            MultivariateNormal: the placeholder prior or predictive distribution.
        """
        if self.training:
            # TODO: return a better dummy here
            # It is a dummy because the real action happens in the MLL.
            if X is not None:
                mean = self.mean_module(X)
                covar = self.covar_module(X)
            else:
                batch_shape = self._batch_shape
                if not isinstance(batch_shape, torch.Size):
                    # A scalar batch size (presumably an int) is wrapped
                    # into a one-element torch.Size.
                    batch_shape = torch.Size((batch_shape, ))
                mean_shape = batch_shape + torch.Size((self.num_data, ))
                mean = ZeroLazyTensor(*mean_shape)
                covar_shape = mean_shape + torch.Size((self.num_data, ))
                covar = ZeroLazyTensor(*covar_shape)

            # Should hopefully only occur in batching issues: the mean lacks
            # the covariance's leading batch dimension, so repeat it along a
            # new leading dim of size covar.batch_shape[0].
            if (mean.ndimension() < covar.ndimension()
                    and self._batch_shape != torch.Size()
                    and mean.shape != covar.shape[:-1]):
                if type(mean) is ZeroLazyTensor:
                    mean = mean.evaluate()
                mean = mean.unsqueeze(0)
                mean = mean.repeat(covar.batch_shape[0],
                                   *[1] * (covar.ndimension() - 1))

            return MultivariateNormal(mean, covar)

        lazy_kernel = self.covar_module(X).evaluate_kernel()
        pred_mean = left_interp(
            lazy_kernel.left_interp_indices,
            lazy_kernel.left_interp_values,
            self.prediction_cache["pred_mean"],
        )

        if skip_posterior_variances.off():
            # Build the inner predictive covariance once and reuse it from
            # the prediction cache on later calls.
            if "pred_cov" not in self.prediction_cache:
                self.prediction_cache["pred_cov"] = self._make_predictive_covar()
            inner_pred_cov = self.prediction_cache["pred_cov"]

            if fast_pred_samples.off():
                pred_wmat = _get_wmat_from_kernel(lazy_kernel)
                lazy_pred_wmat = lazify(pred_wmat)
                pred_cov = lazy_pred_wmat.transpose(-1, -2).matmul(
                    inner_pred_cov.matmul(lazy_pred_wmat))

                if self.has_learnable_noise:
                    pred_cov = pred_cov * self.likelihood.second_noise_covar.noise.to(
                        pred_cov.device)
            else:
                inner_pred_cov_root = inner_pred_cov.root_decomposition(
                    method="lanczos").root.evaluate()
                # Keep only the trailing X.shape[-2] columns when the root
                # is wider than the number of input points.
                if inner_pred_cov_root.shape[-1] > X.shape[-2]:
                    inner_pred_cov_root = inner_pred_cov_root[..., -X.shape[-2]:]

                root_tensor = left_interp(
                    lazy_kernel.left_interp_indices,
                    lazy_kernel.left_interp_values,
                    inner_pred_cov_root,
                )

                # Scale the root only when there is learnable noise. The
                # original code referenced noise_root unconditionally, which
                # raised NameError without learnable noise. Assuming
                # RootLazyTensor(R) represents R @ R^T, multiplying R by
                # sqrt(noise) matches the noise scaling in the branch above.
                if self.has_learnable_noise:
                    noise_root = self.likelihood.second_noise_covar.noise.to(
                        root_tensor.device)**0.5
                    root_tensor = root_tensor * noise_root
                pred_cov = RootLazyTensor(root_tensor)
        else:
            pred_cov = ZeroLazyTensor(*lazy_kernel.size())

        pred_mean = pred_mean[..., 0]
        if self._batch_shape == torch.Size() and X.ndimension() == 2:
            # Non-batch model and non-batch input: take the first slice along
            # the leading dimension of the mean (and of the covariance if it
            # has more than two dims).
            pred_mean = pred_mean[0]
            if pred_cov.ndimension() > 2:
                pred_cov = pred_cov[0]

        return MultivariateNormal(pred_mean, pred_cov)