def test_get_item_tensor_index(self):
        # Tests the default LV.__getitem__ behavior
        lazy_tensor = ZeroLazyTensor(5, 5)
        evaluated = lazy_tensor.evaluate()

        index = (torch.tensor([0, 0, 1, 2]), torch.tensor([0, 1, 0, 2]))
        self.assertTrue(approx_equal(lazy_tensor[index], evaluated[index]))
        index = (torch.tensor([0, 0, 1, 2]), slice(None, None, None))
        self.assertTrue(approx_equal(lazy_tensor[index], evaluated[index]))
        index = (slice(None, None, None), torch.tensor([0, 0, 1, 2]))
        self.assertTrue(approx_equal(lazy_tensor[index], evaluated[index]))
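
A minimal standalone sketch of what the test above exercises (assuming a gpytorch version that still exposes gpytorch.lazy.ZeroLazyTensor and LazyTensor.evaluate()): indexing the lazy zero tensor should match indexing its dense counterpart, without materializing the zeros up front.

import torch
from gpytorch.lazy import ZeroLazyTensor

lazy_zeros = ZeroLazyTensor(5, 5)
dense_zeros = lazy_zeros.evaluate()  # an ordinary 5 x 5 tensor of zeros

rows = torch.tensor([0, 0, 1, 2])
cols = torch.tensor([0, 1, 0, 2])

result = lazy_zeros[rows, cols]
# depending on the gpytorch version the result may come back lazy or dense,
# so evaluate defensively before comparing
if hasattr(result, "evaluate"):
    result = result.evaluate()
assert torch.allclose(result, dense_zeros[rows, cols])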
Example #2
    def forward(self, x1, x2, diag=False, are_equal=True, **params):
        batch_shape = self.batch_shape
        leading_dim = x1.size()[:-2]

        if self.constant_noise:
            _are_equal = (x1.shape == x2.shape)
        else:
            _are_equal = torch.equal(x1, x2) and are_equal

        if _are_equal:
            noise_var = torch.exp(self.noise_log_var).expand(-1, x1.size(-2))
            K = DiagLazyTensor(noise_var)
        else:
            K = ZeroLazyTensor(*leading_dim,
                               x1.size(-2),
                               x2.size(-2),
                               dtype=x1.dtype,
                               device=x1.device)

        if diag:
            K = K.diag()
            if not leading_dim:
                K = K.unsqueeze(0)
            return K  # return a torch.Tensor rather than a LazyTensor, consistent with other kernels' behavior

        return K
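
A hypothetical usage sketch of the pattern above (the helper noise_covar and the parameter noise_log_var below are illustrative stand-ins, not part of the original module): identical inputs get a diagonal noise covariance, while the cross-covariance between distinct inputs is represented by a cheap ZeroLazyTensor of the right shape, dtype, and device.

import torch
from gpytorch.lazy import DiagLazyTensor, ZeroLazyTensor

noise_log_var = torch.zeros(1, 1)  # stand-in for the learned log-noise parameter

def noise_covar(x1, x2):
    if torch.equal(x1, x2):
        noise_var = torch.exp(noise_log_var).expand(-1, x1.size(-2))
        return DiagLazyTensor(noise_var)
    return ZeroLazyTensor(*x1.shape[:-2], x1.size(-2), x2.size(-2),
                          dtype=x1.dtype, device=x1.device)

x_train = torch.randn(1, 10, 3)  # batch x n x d
x_test = torch.randn(1, 4, 3)
print(noise_covar(x_train, x_train).shape)  # torch.Size([1, 10, 10])
print(noise_covar(x_train, x_test).shape)   # torch.Size([1, 10, 4])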
Example #3
    def exact_predictive_covar(self, test_test_covar, test_train_covar):
        """
        Computes the posterior predictive covariance of a GP
        Args:
            test_train_covar (:obj:`gpytorch.lazy.LazyTensor`): Covariance matrix between test and train inputs
            test_test_covar (:obj:`gpytorch.lazy.LazyTensor`): Covariance matrix between test inputs
        Returns:
            :obj:`gpytorch.lazy.LazyTensor`: A LazyTensor representing the predictive posterior covariance of the
                                               test points
        """
        if settings.fast_pred_var.on():
            self._last_test_train_covar = test_train_covar

        if settings.skip_posterior_variances.on():
            return ZeroLazyTensor(*test_test_covar.size())

        if settings.fast_pred_var.off():
            return super().exact_predictive_covar(test_test_covar, test_train_covar)
        else:
            features_xstar = test_train_covar.evaluate_kernel().get_root(
                dim=-2)

            # compute J^T Cache as our root tensor
            j_star_covar = features_xstar.t() @ self.covar_cache

            covar_expanded = RootLazyTensor(j_star_covar)
            return self.noise * covar_expanded
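
A short sketch of the skip_posterior_variances shortcut used above (assuming the gpytorch.settings API): when posterior variances are skipped, a ZeroLazyTensor with the shape of the test-test covariance stands in for the predictive covariance at essentially no cost.

import gpytorch
from gpytorch.lazy import ZeroLazyTensor

test_test_covar_size = (4, 4)  # illustrative size of K(X*, X*)
with gpytorch.settings.skip_posterior_variances(True):
    assert gpytorch.settings.skip_posterior_variances.on()
    # the expensive predictive covariance is replaced by a structural zero
    covar = ZeroLazyTensor(*test_test_covar_size)
print(covar.shape)  # torch.Size([4, 4])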
    def test_add_diag(self):
        diag = torch.tensor(1.5)
        res = ZeroLazyTensor(5, 5).add_diag(diag).evaluate()
        actual = torch.eye(5).mul(1.5)
        self.assertTrue(approx_equal(res, actual))

        diag = torch.tensor([1.5])
        res = ZeroLazyTensor(5, 5).add_diag(diag).evaluate()
        actual = torch.eye(5).mul(1.5)
        self.assertTrue(approx_equal(res, actual))

        diag = torch.tensor([1.5, 1.3, 1.2, 1.1, 2.])
        res = ZeroLazyTensor(5, 5).add_diag(diag).evaluate()
        actual = diag.diag()
        self.assertTrue(approx_equal(res, actual))

        diag = torch.tensor(1.5)
        res = ZeroLazyTensor(2, 5, 5).add_diag(diag).evaluate()
        actual = torch.eye(5).unsqueeze(0).repeat(2, 1, 1).mul(1.5)
        self.assertTrue(approx_equal(res, actual))

        diag = torch.tensor([1.5])
        res = ZeroLazyTensor(2, 5, 5).add_diag(diag).evaluate()
        actual = torch.eye(5).unsqueeze(0).repeat(2, 1, 1).mul(1.5)
        self.assertTrue(approx_equal(res, actual))

        diag = torch.tensor([1.5, 1.3, 1.2, 1.1, 2.])
        res = ZeroLazyTensor(2, 5, 5).add_diag(diag).evaluate()
        actual = diag.diag().unsqueeze(0).repeat(2, 1, 1)
        self.assertTrue(approx_equal(res, actual))

        diag = torch.tensor([[1.5, 1.3, 1.2, 1.1, 2.], [0, 1, 2, 1, 1]])
        res = ZeroLazyTensor(2, 5, 5).add_diag(diag).evaluate()
        actual = torch.cat([diag[0].diag().unsqueeze(0), diag[1].diag().unsqueeze(0)])
        self.assertTrue(approx_equal(res, actual))
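
The dense analogue of the assertions above, as a sketch (dense_add_diag is a hypothetical helper that mimics add_diag on an ordinary zero matrix): scalar and length-1 diagonals broadcast across the whole diagonal, and batched zero matrices accept batched diagonals.

import torch

def dense_add_diag(zero_matrix, diag):
    # broadcast the diagonal to the matrix's batch shape, then embed it
    n = zero_matrix.size(-1)
    return zero_matrix + torch.diag_embed(diag.expand(*zero_matrix.shape[:-2], n))

print(dense_add_diag(torch.zeros(5, 5), torch.tensor(1.5)))
print(dense_add_diag(torch.zeros(2, 5, 5), torch.tensor([1.5, 1.3, 1.2, 1.1, 2.0])))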
    def test_getitem_ellipsis(self):
        lv = ZeroLazyTensor(5, 4, 3)

        res_one = lv[[0, 1]].evaluate()
        self.assertLess(torch.norm(res_one - torch.zeros(2, 4, 3)), 1e-4)
        res_two = lv[:, [0, 1], ...].evaluate()
        self.assertLess(torch.norm(res_two - torch.zeros(5, 2, 3)), 1e-4)
        res_three = lv[..., [0, 2]].evaluate()
        self.assertLess(torch.norm(res_three - torch.zeros(5, 4, 2)), 1e-4)
    def test_getitem(self):
        lv = ZeroLazyTensor(5, 4, 3)

        res_one = lv[0].evaluate()
        self.assertLess(torch.norm(res_one - torch.zeros(4, 3)), 1e-4)
        res_two = lv[:, 1, :]
        self.assertLess(torch.norm(res_two - torch.zeros(5, 3)), 1e-4)
        res_three = lv[:, :, 2]
        self.assertLess(torch.norm(res_three - torch.zeros(5, 4)), 1e-4)
Example #7
    def test_getitem_complex(self):
        lv = ZeroLazyTensor(5, 4, 3)

        res_one = lv[[0, 1]].evaluate()
        res_two = lv[:, [0, 1], :]
        res_three = lv[:, :, [0, 2]]

        self.assertLess(torch.norm(res_one - torch.zeros(2, 4, 3)), 1e-4)
        self.assertLess(torch.norm(res_two - torch.zeros(5, 2, 3)), 1e-4)
        self.assertLess(torch.norm(res_three - torch.zeros(5, 4, 2)), 1e-4)
Example #8
    def forward(self, x1, x2):
        if self.training and torch.equal(x1, x2):
            # Reshape into a batch of batch_size diagonal matrices, each of which is
            # (data_size * task_size) x (data_size * task_size)
            return DiagLazyTensor(
                self.variances.view(self.variances.size(0), -1))
        elif (x1.size(-2) == x2.size(-2)
              and x1.size(-2) == self.variances.size(1)
              and torch.equal(x1, x2)):
            return DiagLazyTensor(
                self.variances.view(self.variances.size(0), -1))
        else:
            return ZeroLazyTensor(x1.size(-3), x1.size(-2), x2.size(-2))
Example #9
    def test_matmul(self):
        zero = ZeroLazyTensor(5, 4, 3)
        lazy_square = ZeroLazyTensor(5, 3, 3)
        actual = torch.zeros(5, 4, 3)
        product = zero.matmul(lazy_square)
        self.assertTrue(approx_equal(product, actual))

        tensor_square = torch.eye(3, dtype=int).repeat(5, 1, 1)
        product = zero._matmul(tensor_square)
        self.assertTrue(approx_equal(product, actual))
        self.assertEqual(product.dtype, tensor_square.dtype)

        tensor_square = torch.eye(4).repeat(5, 1, 1)
        actual = torch.zeros(5, 3, 4)
        product = zero._t_matmul(tensor_square)
        self.assertTrue(approx_equal(product, actual))
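
A plain-torch sketch of the algebra this test relies on: multiplying by a zero matrix yields zeros of the broadcast shape, and the dtype of the dense operand carries through.

import torch

A = torch.zeros(5, 4, 3, dtype=torch.float64)          # plays the role of the zero operator
B = torch.eye(3, dtype=torch.float64).repeat(5, 1, 1)
C = A @ B
print(C.shape, C.dtype)      # torch.Size([5, 4, 3]) torch.float64
print(bool((C == 0).all()))  # True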
Example #10
    def forward(self, x1, x2):
        res = ZeroLazyTensor()
        for kern in self.kernels:
            res = res + kern(x1, x2).evaluate_kernel()

        return res
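
A sketch of why an empty ZeroLazyTensor works as the starting value of that sum (assuming gpytorch's lazy addition, where adding to a zero lazy tensor returns an equivalent of the other operand): the zero tensor acts as the additive identity for lazy covariances.

import torch
from gpytorch.lazy import ZeroLazyTensor, lazify

K = lazify(torch.eye(3))
total = ZeroLazyTensor() + K  # zero is the additive identity
print(torch.allclose(total.evaluate(), K.evaluate()))  # True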
    def test_evaluate(self):
        lv = ZeroLazyTensor(5, 4, 3)
        actual = torch.zeros(5, 4, 3)
        res = lv.evaluate()
        self.assertLess(torch.norm(res - actual), 1e-4)
Example #12
    def posterior(
        self,
        X: Tensor,
        output_indices: Optional[List[int]] = None,
        observation_noise: Union[bool, Tensor] = False,
        **kwargs: Any,
    ) -> GPyTorchPosterior:
        self.eval()  # make sure we're calling a posterior

        no_pred_variance = skip_posterior_variances._state

        with ExitStack() as es:
            es.enter_context(gpt_posterior_settings())
            es.enter_context(fast_pred_var(True))

            # we need to skip posterior variances here
            es.enter_context(skip_posterior_variances(True))
            mvn = self(X)
            if observation_noise is not False:
                # TODO: implement Kronecker + diagonal solves so that this is possible.
                # if torch.is_tensor(observation_noise):
                #     # TODO: Validate noise shape
                #     # make observation_noise `batch_shape x q x n`
                #     obs_noise = observation_noise.transpose(-1, -2)
                #     mvn = self.likelihood(mvn, X, noise=obs_noise)
                # elif isinstance(self.likelihood, FixedNoiseGaussianLikelihood):
                #     noise = self.likelihood.noise.mean().expand(X.shape[:-1])
                #     mvn = self.likelihood(mvn, X, noise=noise)
                # else:
                mvn = self.likelihood(mvn, X)

            # lazy covariance matrix includes the interpolated version of the full
            # covariance matrix so we can actually grab that instead.
            if X.ndimension() > self.train_inputs[0].ndimension():
                X_batch_shape = X.shape[:-2]
                train_inputs = self.train_inputs[0].reshape(
                    *[1] * len(X_batch_shape), *self.train_inputs[0].shape
                )
                train_inputs = train_inputs.repeat(
                    *X_batch_shape, *[1] * self.train_inputs[0].ndimension()
                )
            else:
                train_inputs = self.train_inputs[0]
            full_covar = self.covar_modules[0](torch.cat((train_inputs, X), dim=-2))

            if no_pred_variance:
                pred_variance = mvn.variance
            else:
                joint_covar = self._get_joint_covariance([X])
                pred_variance = self.make_posterior_variances(joint_covar)

                full_covar = KroneckerProductLazyTensor(
                    full_covar, *joint_covar.lazy_tensors[1:]
                )

            joint_covar_list = [self.covar_modules[0](X, train_inputs)]
            batch_shape = joint_covar_list[0].batch_shape
            for cm, param in zip(self.covar_modules[1:], self.latent_parameters):
                covar = cm(param)
                if covar.batch_shape != batch_shape:
                    covar = BatchRepeatLazyTensor(covar, batch_shape)
                joint_covar_list.append(covar)

            test_train_covar = KroneckerProductLazyTensor(*joint_covar_list)

            # mean and variance get reshaped into the target shape
            new_mean = mvn.mean.reshape(*X.shape[:-1], *self.target_shape)
            if not no_pred_variance:
                new_variance = pred_variance.reshape(*X.shape[:-1], *self.target_shape)
                new_variance = DiagLazyTensor(new_variance)
            else:
                new_variance = ZeroLazyTensor(
                    *X.shape[:-1], *self.target_shape, self.target_shape[-1]
                )

            mvn = MultivariateNormal(new_mean, new_variance)

            # return a specialized Posterior to allow for sampling
            posterior = HigherOrderGPPosterior(
                mvn=mvn,
                train_targets=self.train_targets.unsqueeze(-1),
                train_train_covar=self.prediction_strategy.lik_train_train_covar,
                test_train_covar=test_train_covar,
                joint_covariance_matrix=full_covar,
                output_shape=Size(
                    (
                        *X.shape[:-1],
                        *self.target_shape,
                    )
                ),
                num_outputs=self._num_outputs,
            )
            if hasattr(self, "outcome_transform"):
                posterior = self.outcome_transform.untransform_posterior(posterior)

            return posterior
Example #13
    def forward(self, x):
        """Forward propagate the module.

        This method determines how to marginalize out the inducing function values.
        Specifically, forward defines how to transform a variational distribution over
        the inducing point values, q(u), into a variational distribution over
        the function values at specified locations x, q(f|x), by computing the
        integral ∫ p(f|x, u) q(u) du.

        Parameters
        ----------
        x (torch.Tensor):
            Locations x at which to evaluate the variational posterior over the
            function values.

        Returns
        -------
            The distribution q(f|x)
        """
        variational_dist = self.variational_distribution.approx_variational_distribution
        inducing_points = self.inducing_points
        inducing_batch_shape = inducing_points.shape[:-2]
        if inducing_batch_shape < x.shape[:-2] or len(
                inducing_batch_shape) < len(x.shape[:-2]):
            batch_shape = _mul_broadcast_shape(inducing_points.shape[:-2],
                                               x.shape[:-2])
            inducing_points = inducing_points.expand(
                *batch_shape, *inducing_points.shape[-2:])
            x = x.expand(*batch_shape, *x.shape[-2:])
            variational_dist = variational_dist.expand(batch_shape)

        # If our points equal the inducing points, we're done
        if torch.equal(x, inducing_points):
            return variational_dist

        # Otherwise, we have to marginalize
        else:
            num_induc = inducing_points.size(-2)
            full_inputs = torch.cat([inducing_points, x], dim=-2)
            full_output = self.model.forward(full_inputs)
            full_mean, full_covar = full_output.mean, full_output.lazy_covariance_matrix

            # Mean terms
            test_mean = full_mean[..., num_induc:]
            induc_mean = full_mean[..., :num_induc]
            mean_diff = (variational_dist.mean - induc_mean).unsqueeze(-1)

            # Covariance terms
            induc_induc_covar = full_covar[
                ..., :num_induc, :num_induc].add_jitter()
            induc_data_covar = full_covar[..., :num_induc,
                                          num_induc:].evaluate()
            data_data_covar = full_covar[..., num_induc:, num_induc:]
            aux = variational_dist.lazy_covariance_matrix.root_decomposition()
            root_variational_covar = aux.root.evaluate()

            # If we had to expand the inducing points,
            # shrink the inducing mean and induc_induc_covar dimension
            # This makes everything more computationally efficient
            if len(inducing_batch_shape) < len(induc_induc_covar.batch_shape):
                index = tuple(0 for _ in range(
                    len(induc_induc_covar.batch_shape) -
                    len(inducing_batch_shape)))
                repeat_size = torch.Size(
                    (tuple(induc_induc_covar.batch_shape[:len(index)]) + tuple(
                        1
                        for _ in induc_induc_covar.batch_shape[len(index):])))
                induc_induc_covar = BatchRepeatLazyTensor(
                    induc_induc_covar.__getitem__(index), repeat_size)

            # If we're less than a certain size, we'll compute the Cholesky
            # decomposition of induc_induc_covar
            cholesky = False
            if settings.fast_computations.log_prob.off() or (
                    num_induc <= settings.max_cholesky_size.value()):
                induc_induc_covar = CholLazyTensor(
                    induc_induc_covar.cholesky())
                cholesky = True

            # If we are making predictions and don't need variances, we can do things
            # very quickly.
            if not self.training and settings.skip_posterior_variances.on():
                if not hasattr(self, "_mean_cache"):
                    self._mean_cache = induc_induc_covar.inv_matmul(
                        mean_diff).detach()

                predictive_mean = torch.add(
                    test_mean,
                    induc_data_covar.transpose(-2, -1).matmul(
                        self._mean_cache).squeeze(-1))

                predictive_covar = ZeroLazyTensor(test_mean.size(-1),
                                                  test_mean.size(-1))

                return MultivariateNormal(predictive_mean, predictive_covar)

            # Cache the CG results
            # For now: run variational inference without a preconditioner
            # The preconditioner screws things up for some reason
            with settings.max_preconditioner_size(0):
                # Cache the CG results
                left_tensors = torch.cat([mean_diff, root_variational_covar],
                                         -1)
                with torch.no_grad():
                    eager_rhs = torch.cat([left_tensors, induc_data_covar], -1)
                    solve, probe_vecs, probe_vec_norms, probe_vec_solves, tmats = \
                        CachedCGLazyTensor.precompute_terms(
                            induc_induc_covar, eager_rhs.detach(),
                            logdet_terms=(not cholesky),
                            include_tmats=(not settings.skip_logdet_forward.on() and
                                           not cholesky)
                        )
                    eager_rhss = [
                        eager_rhs.detach(),
                        eager_rhs[..., left_tensors.size(-1):].detach(),
                        eager_rhs[..., :left_tensors.size(-1)].detach()
                    ]
                    solves = [
                        solve.detach(), solve[...,
                                              left_tensors.size(-1):].detach(),
                        solve[..., :left_tensors.size(-1)].detach()
                    ]
                    if settings.skip_logdet_forward.on():
                        eager_rhss.append(
                            torch.cat([probe_vecs, left_tensors], -1))
                        solves.append(
                            torch.cat([
                                probe_vec_solves,
                                solve[..., :left_tensors.size(-1)]
                            ], -1))
                induc_induc_covar = CachedCGLazyTensor(
                    induc_induc_covar,
                    eager_rhss=eager_rhss,
                    solves=solves,
                    probe_vectors=probe_vecs,
                    probe_vector_norms=probe_vec_norms,
                    probe_vector_solves=probe_vec_solves,
                    probe_vector_tmats=tmats,
                )

            if self.training:
                self._memoize_cache[
                    "prior_distribution_memo"] = MultivariateNormal(
                        induc_mean, induc_induc_covar)

            # Compute predictive mean/covariance
            inv_products = induc_induc_covar.inv_matmul(
                induc_data_covar, left_tensors.transpose(-1, -2))
            predictive_mean = torch.add(test_mean, inv_products[..., 0, :])
            predictive_covar = RootLazyTensor(inv_products[...,
                                                           1:, :].transpose(
                                                               -1, -2))
            if self.training:
                interp_data_data_var, _ = induc_induc_covar.inv_quad_logdet(
                    induc_data_covar, logdet=False, reduce_inv_quad=False)
                data_covariance = DiagLazyTensor(
                    (data_data_covar.diag() - interp_data_data_var).clamp(
                        0, math.inf))
            else:
                neg_induc_data_data_covar = torch.matmul(
                    induc_data_covar.transpose(-1, -2).mul(-1),
                    induc_induc_covar.inv_matmul(induc_data_covar))
                data_covariance = data_data_covar + neg_induc_data_data_covar
            predictive_covar = PsdSumLazyTensor(predictive_covar,
                                                data_covariance)

            return MultivariateNormal(predictive_mean, predictive_covar)
Example #14
    def forward(self, X, **kwargs):
        if self.training:
            # TODO: return a better dummy here
            # is a dummy b/c the real action happens in the MLL
            if X is not None:
                mean = self.mean_module(X)
                covar = self.covar_module(X)
            else:
                if type(self._batch_shape) is not torch.Size:
                    batch_shape = torch.Size((self._batch_shape, ))
                else:
                    batch_shape = self._batch_shape
                mean_shape = batch_shape + torch.Size((self.num_data, ))
                mean = ZeroLazyTensor(*mean_shape)
                covar_shape = mean_shape + torch.Size((self.num_data, ))
                covar = ZeroLazyTensor(*covar_shape)

            # should hopefully only occur with batching issues
            if (mean.ndimension() < covar.ndimension()
                    and (self._batch_shape != torch.Size()
                         and mean.shape != covar.shape[:-1])):
                if type(mean) is ZeroLazyTensor:
                    mean = mean.evaluate()
                mean = mean.unsqueeze(0)
                mean = mean.repeat(covar.batch_shape[0],
                                   *[1] * (covar.ndimension() - 1))

            return MultivariateNormal(mean, covar)
        else:
            lazy_kernel = self.covar_module(X).evaluate_kernel()
            pred_mean = left_interp(
                lazy_kernel.left_interp_indices,
                lazy_kernel.left_interp_values,
                self.prediction_cache["pred_mean"],
            )

            if skip_posterior_variances.off():
                # init predictive covariance if it's not in the prediction cache
                if "pred_cov" in self.prediction_cache.keys():
                    inner_pred_cov = self.prediction_cache["pred_cov"]
                else:
                    self.prediction_cache[
                        "pred_cov"] = self._make_predictive_covar()
                    inner_pred_cov = self.prediction_cache["pred_cov"]

                if fast_pred_samples.off():
                    pred_wmat = _get_wmat_from_kernel(lazy_kernel)
                    lazy_pred_wmat = lazify(pred_wmat)
                    pred_cov = lazy_pred_wmat.transpose(-1, -2).matmul(
                        (inner_pred_cov.matmul(lazy_pred_wmat)))

                    if self.has_learnable_noise:
                        pred_cov = pred_cov * self.likelihood.second_noise_covar.noise.to(
                            pred_cov.device)
                else:
                    # inner_pred_cov_root = inner_pred_cov.root_decomposition().root.evaluate()
                    inner_pred_cov_root = inner_pred_cov.root_decomposition(
                        method="lanczos").root.evaluate()
                    if inner_pred_cov_root.shape[-1] > X.shape[-2]:
                        inner_pred_cov_root = inner_pred_cov_root[
                            ..., -X.shape[-2]:]

                    root_tensor = left_interp(
                        lazy_kernel.left_interp_indices,
                        lazy_kernel.left_interp_values,
                        inner_pred_cov_root,
                    )

                    if self.has_learnable_noise:
                        noise_root = self.likelihood.second_noise_covar.noise.to(
                            root_tensor.device)**0.5
                        pred_cov = RootLazyTensor(root_tensor * noise_root)
                    else:
                        pred_cov = RootLazyTensor(root_tensor)

            else:
                pred_cov = ZeroLazyTensor(*lazy_kernel.size())

            pred_mean = pred_mean[..., 0]
            if self._batch_shape == torch.Size() and X.ndimension() == 2:
                pred_mean = pred_mean[0]
                if pred_cov.ndimension() > 2:
                    pred_cov = pred_cov[0]

            dist = MultivariateNormal(pred_mean, pred_cov)

            return dist
Example #15
    def test_representation(self):
        lv = ZeroLazyTensor(5, 4, 3)
        representation = lv.representation()
        self.assertIsInstance(representation, torch.Tensor)
Example #16
    def posterior(
        self,
        X: Tensor,
        output_indices: Optional[List[int]] = None,
        observation_noise: Union[bool, Tensor] = False,
        **kwargs: Any,
    ) -> GPyTorchPosterior:
        self.eval()  # make sure we're calling a posterior
        # input transforms are applied at `posterior` in `eval` mode, and at
        # `model.forward()` during training
        X = self.transform_inputs(X)
        no_pred_variance = skip_posterior_variances._state

        with ExitStack() as es:
            es.enter_context(gpt_posterior_settings())
            es.enter_context(fast_pred_var(True))

            # we need to skip posterior variances here
            es.enter_context(skip_posterior_variances(True))
            mvn = self(X)
            if observation_noise is not False:
                # TODO: ensure that this still works for structured noise solves.
                mvn = self.likelihood(mvn, X)

            # lazy covariance matrix includes the interpolated version of the full
            # covariance matrix so we can actually grab that instead.
            if X.ndimension() > self.train_inputs[0].ndimension():
                X_batch_shape = X.shape[:-2]
                train_inputs = self.train_inputs[0].reshape(
                    *[1] * len(X_batch_shape), *self.train_inputs[0].shape
                )
                train_inputs = train_inputs.repeat(
                    *X_batch_shape, *[1] * self.train_inputs[0].ndimension()
                )
            else:
                train_inputs = self.train_inputs[0]

            # we now compute the data covariances for the training data, the testing
            # data, the joint covariances, and the test train cross-covariance
            train_train_covar = self.prediction_strategy.lik_train_train_covar.detach()
            base_train_train_covar = train_train_covar.lazy_tensor

            data_train_covar = base_train_train_covar.lazy_tensors[0]
            data_covar = self.covar_modules[0]
            data_train_test_covar = data_covar(X, train_inputs)
            data_test_test_covar = data_covar(X)
            data_joint_covar = data_train_covar.cat_rows(
                cross_mat=data_train_test_covar,
                new_mat=data_test_test_covar,
            )

            # we detach the latents so that they don't cause gradient errors
            # TODO: Can we enable backprop through the latent covariances?
            batch_shape = data_train_test_covar.batch_shape
            latent_covar_list = []
            for latent_covar in base_train_train_covar.lazy_tensors[1:]:
                if latent_covar.batch_shape != batch_shape:
                    latent_covar = BatchRepeatLazyTensor(latent_covar, batch_shape)
                latent_covar_list.append(latent_covar.detach())

            joint_covar = KroneckerProductLazyTensor(
                data_joint_covar, *latent_covar_list
            )
            test_train_covar = KroneckerProductLazyTensor(
                data_train_test_covar, *latent_covar_list
            )

            # compute the posterior variance if necessary
            if no_pred_variance:
                pred_variance = mvn.variance
            else:
                pred_variance = self.make_posterior_variances(joint_covar)

            # mean and variance get reshaped into the target shape
            new_mean = mvn.mean.reshape(*X.shape[:-1], *self.target_shape)
            if not no_pred_variance:
                new_variance = pred_variance.reshape(*X.shape[:-1], *self.target_shape)
                new_variance = DiagLazyTensor(new_variance)
            else:
                new_variance = ZeroLazyTensor(
                    *X.shape[:-1], *self.target_shape, self.target_shape[-1]
                )

            mvn = MultivariateNormal(new_mean, new_variance)

            # return a specialized Posterior to allow for sampling
            # cloning the full covar allows backpropagation through it
            posterior = HigherOrderGPPosterior(
                mvn=mvn,
                train_targets=self.train_targets.unsqueeze(-1),
                train_train_covar=train_train_covar,
                test_train_covar=test_train_covar,
                joint_covariance_matrix=joint_covar.clone(),
                output_shape=X.shape[:-1] + self.target_shape,
                num_outputs=self._num_outputs,
            )
            if hasattr(self, "outcome_transform"):
                posterior = self.outcome_transform.untransform_posterior(posterior)

            return posterior