Example #1
    def forward(self, x1, x2=None, diag=False, last_dim_is_batch=False, **params):
        # TODO: figure out how to have multiple outputs - I believe the relevant
        # method is num_outputs_per_input, but this also applies to the kernel model itself
        #print('ldib', last_dim_is_batch)
        if len(x1.shape) == 3:
            x1 = x1[0]
        if x2 is not None:
            if len(x2.shape) == 3:
                x2 = x2[0]

        #print(self.batch_shape, x2 is None)        
        if len(x1.shape) == 2:
            x1 = x1.view(-1, *self.shape)
        

        # TODO: figure out how to have batched dimensional outputs
        if x2 is None:
            kernel = lazify(self.model(x1, diag=diag))
        else:
            x2 = x2.view(-1, *self.shape)
            kernel = lazify(self.model(x1, x2, diag=diag))

        #print('shape of kernel is: ', kernel.shape)
        res = BatchRepeatLazyTensor(kernel, batch_repeat=self.batch_shape)

        if last_dim_is_batch:
            res = res.permute(1, 2, 0)

        return res
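A minimal, self-contained sketch of the pattern used above, assuming the pre-1.9 gpytorch.lazy API; the dense matrix stands in for whatever self.model returns:

import torch
from gpytorch.lazy import lazify, BatchRepeatLazyTensor

# Dense 5 x 5 covariance produced by some inner model (stand-in value).
dense_kernel = torch.eye(5)

# lazify wraps the dense tensor in a NonLazyTensor so downstream code can
# treat it like any other LazyTensor.
kernel = lazify(dense_kernel)

# Repeat the single kernel across a batch, as the forward above does with
# self.batch_shape.
batched = BatchRepeatLazyTensor(kernel, batch_repeat=torch.Size([3]))
print(batched.shape)  # torch.Size([3, 5, 5])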
Example #2
    def forward(self,
                mxu1,
                mxu2,
                diag=False,
                last_dim_is_batch=False,
                **params):
        assert not torch.isnan(mxu1).any()
        assert not torch.isnan(mxu2).any()
        M1, X1, U1 = self.decoder.decode(mxu1)
        M2, X2, U2 = self.decoder.decode(mxu2)

        if last_dim_is_batch:
            raise RuntimeError(
                "HetergeneousCoregionalizationKernel does not accept the last_dim_is_batch argument."
            )
        #covar_i = self.mask_dependent_covar(self, M1, U1, M2, U2)
        covar_x = lazify(self.data_covar_module.forward(X1, X2, **params))
        #if torch.rand(1).item() < 0.1:
        #    for name, value in self.data_covar_module.named_parameters():
        #            print(name, value)

        for name, value in self.data_covar_module.named_parameters():
            assert not torch.isnan(value).any()
        res = self.mask_dependent_covar(M1, U1, M2, U2, covar_x)
        return res.diag() if diag else res
Example #3
File: GPModels.py  Project: LLNL/MTLRecSys
    def forward(self, x1, x2, diag=False, last_dim_is_batch=False, **params):
        if last_dim_is_batch:
            raise RuntimeError(
                "MultitaskKernel does not accept the last_dim_is_batch argument."
            )
        covar_i = self.task_covar_module.covar_matrix
        if len(x1.shape[:-2]):
            covar_i = covar_i.repeat(*x1.shape[:-2], 1, 1)
        if self.bias_only:
            # Task covariance is all ones, so every task shares the same data
            # covariance while still getting its own multitask mean.
            covar_i = lazify(torch.ones_like(covar_i.evaluate()))
        covar_x = lazify(self.data_covar_module.forward(x1, x2, **params))
        res = KroneckerProductLazyTensor(covar_x, covar_i)
        return res.diag() if diag else res
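The MultitaskKernel-style forward above reduces to a Kronecker product of a data covariance and a task covariance. A small stand-alone sketch with toy matrices (values are illustrative, not taken from the LLNL/MTLRecSys project):

import torch
from gpytorch.lazy import lazify, KroneckerProductLazyTensor

# Data covariance over n = 4 points and task covariance over t = 2 tasks.
covar_x = lazify(torch.eye(4))
covar_i = lazify(torch.tensor([[1.0, 0.5], [0.5, 1.0]]))

# The multitask covariance is the Kronecker product of the two,
# an (n*t) x (n*t) matrix; diag() is what the `diag=True` branch returns.
res = KroneckerProductLazyTensor(covar_x, covar_i)
print(res.shape)         # torch.Size([8, 8])
print(res.diag().shape)  # torch.Size([8])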
Example #4
    def posterior(
        self,
        X: Tensor,
        output_indices: Optional[List[int]] = None,
        observation_noise: Union[bool, Tensor] = False,
        **kwargs: Any,
    ) -> GPyTorchPosterior:
        r"""Computes the posterior over model outputs at the provided points.

        Args:
            X: A `(batch_shape) x q x d`-dim Tensor, where `d` is the dimension
                of the feature space and `q` is the number of points considered
                jointly.
            output_indices: A list of indices, corresponding to the outputs over
                which to compute the posterior (if the model is multi-output).
                Can be used to speed up computation if only a subset of the
                model's outputs are required for optimization. If omitted,
                computes the posterior over all model outputs.
            observation_noise: If True, add the observation noise from the
                likelihood to the posterior. If a Tensor, use it directly as the
                observation noise (must be of shape `(batch_shape) x q x m`).

        Returns:
            A `GPyTorchPosterior` object, representing `batch_shape` joint
            distributions over `q` points and the outputs selected by
            `output_indices` each. Includes observation noise if specified.
        """
        self.eval()  # make sure model is in eval mode
        with gpt_posterior_settings():
            # insert a dimension for the output dimension
            if self._num_outputs > 1:
                X, output_dim_idx = add_output_dim(
                    X=X, original_batch_shape=self._input_batch_shape)
            mvn = self(X)
            if observation_noise is not False:
                if torch.is_tensor(observation_noise):
                    # TODO: Validate noise shape
                    # make observation_noise `batch_shape x q x n`
                    obs_noise = observation_noise.transpose(-1, -2)
                    mvn = self.likelihood(mvn, X, noise=obs_noise)
                elif isinstance(self.likelihood, FixedNoiseGaussianLikelihood):
                    # Use the mean of the previous noise values (TODO: be smarter here).
                    noise = self.likelihood.noise.mean().expand(X.shape[:-1])
                    mvn = self.likelihood(mvn, X, noise=noise)
                else:
                    mvn = self.likelihood(mvn, X)
            if self._num_outputs > 1:
                mean_x = mvn.mean
                covar_x = mvn.covariance_matrix
                output_indices = output_indices or range(self._num_outputs)
                mvns = [
                    MultivariateNormal(
                        mean_x.select(dim=output_dim_idx, index=t),
                        lazify(covar_x.select(dim=output_dim_idx, index=t)),
                    ) for t in output_indices
                ]
                mvn = MultitaskMultivariateNormal.from_independent_mvns(
                    mvns=mvns)
        return GPyTorchPosterior(mvn=mvn)
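The multi-output branch above stitches per-output posteriors back together with MultitaskMultivariateNormal.from_independent_mvns. A toy sketch of just that step (values are made up):

import torch
from gpytorch.distributions import MultivariateNormal, MultitaskMultivariateNormal

# Two independent single-output distributions over q = 3 points.
mvn_a = MultivariateNormal(torch.zeros(3), torch.eye(3))
mvn_b = MultivariateNormal(torch.ones(3), 2 * torch.eye(3))

# Stack them into one multitask distribution with a block-diagonal covariance,
# as the posterior above does before wrapping the result in GPyTorchPosterior.
mtmvn = MultitaskMultivariateNormal.from_independent_mvns(mvns=[mvn_a, mvn_b])
print(mtmvn.mean.shape)  # torch.Size([3, 2])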
Example #5
    def posterior(
        self,
        X: Tensor,
        output_indices: Optional[List[int]] = None,
        observation_noise: bool = False,
        **kwargs: Any,
    ) -> GPyTorchPosterior:
        r"""Computes the posterior over model outputs at the provided points.

        Args:
            X: A `(batch_shape) x q x d`-dim Tensor, where `d` is the dimension of the
                feature space and `q` is the number of points considered jointly.
            output_indices: A list of indices, corresponding to the outputs over
                which to compute the posterior (if the model is multi-output).
                Can be used to speed up computation if only a subset of the
                model's outputs are required for optimization. If omitted,
                computes the posterior over all model outputs.
            observation_noise: If True, add observation noise to the posterior.

        Returns:
            A `GPyTorchPosterior` object, representing `batch_shape` joint
            distributions over `q` points and the outputs selected by
            `output_indices` each. Includes observation noise if
            `observation_noise=True`.
        """
        self.eval()  # make sure model is in eval mode
        with ExitStack() as es:
            es.enter_context(gpt_settings.debug(False))
            es.enter_context(gpt_settings.fast_pred_var())
            es.enter_context(
                gpt_settings.detach_test_caches(
                    settings.propagate_grads.off()))
            # insert a dimension for the output dimension
            if self._num_outputs > 1:
                X, output_dim_idx = add_output_dim(
                    X=X, original_batch_shape=self._input_batch_shape)
            mvn = self(X)
            if observation_noise:
                if isinstance(self.likelihood, FixedNoiseGaussianLikelihood):
                    # Use the mean of the previous noise values (TODO: be smarter here).
                    noise = self.likelihood.noise.mean().expand(X.shape[:-1])
                    mvn = self.likelihood(mvn, X, noise=noise)
                else:
                    mvn = self.likelihood(mvn, X)
            if self._num_outputs > 1:
                mean_x = mvn.mean
                covar_x = mvn.covariance_matrix
                output_indices = output_indices or range(self._num_outputs)
                mvns = [
                    MultivariateNormal(
                        mean_x.select(dim=output_dim_idx, index=t),
                        lazify(covar_x.select(dim=output_dim_idx, index=t)),
                    ) for t in output_indices
                ]
                mvn = MultitaskMultivariateNormal.from_independent_mvns(
                    mvns=mvns)
        return GPyTorchPosterior(mvn=mvn)
Example #6
    def posterior(
        self,
        X: Tensor,
        output_indices: Optional[List[int]] = None,
        observation_noise: bool = False,
        **kwargs: Any,
    ) -> GPyTorchPosterior:
        r"""Computes the posterior over model outputs at the provided points.

        Args:
            X: A `(batch_shape) x q x d`-dim Tensor, where `d` is the dimension of the
                feature space and `q` is the number of points considered jointly.
            output_indices: A list of indices, corresponding to the outputs over
                which to compute the posterior (if the model is multi-output).
                Can be used to speed up computation if only a subset of the
                model's outputs are required for optimization. If omitted,
                computes the posterior over all model outputs.
            observation_noise: If True, add observation noise to the posterior.
            propagate_grads: If True, do not detach GPyTorch's test caches when
                computing of the posterior. Required for being able to compute
                derivatives with respect to training inputs at test time (used
                e.g. by qNoisyExpectedImprovement). Defaults to `False`.

        Returns:
            A `GPyTorchPosterior` object, representing `batch_shape` joint
            distributions over `q` points and the outputs selected by
            `output_indices` each. Includes observation noise if
            `observation_noise=True`.
        """
        self.eval()  # make sure model is in eval mode
        detach_test_caches = not kwargs.get("propagate_grads", False)
        with ExitStack() as es:
            es.enter_context(settings.debug(False))
            es.enter_context(settings.fast_pred_var())
            es.enter_context(settings.detach_test_caches(detach_test_caches))
            # insert a dimension for the output dimension
            if self._num_outputs > 1:
                X, output_dim_idx = add_output_dim(
                    X=X, original_batch_shape=self._input_batch_shape
                )
            mvn = self(X)
            if observation_noise:
                mvn = self.likelihood(mvn, X)
            if self._num_outputs > 1:
                mean_x = mvn.mean
                covar_x = mvn.covariance_matrix
                output_indices = output_indices or range(self._num_outputs)
                mvns = [
                    MultivariateNormal(
                        mean_x.select(dim=output_dim_idx, index=t),
                        lazify(covar_x.select(dim=output_dim_idx, index=t)),
                    )
                    for t in output_indices
                ]
                mvn = MultitaskMultivariateNormal.from_independent_mvns(mvns=mvns)
        return GPyTorchPosterior(mvn=mvn)
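The posterior variants above differ mainly in how the gpytorch settings stack is assembled. The context-manager pattern itself can be reproduced in isolation; debug, fast_pred_var, and detach_test_caches are real gpytorch settings, and the body here is only a placeholder:

from contextlib import ExitStack
import gpytorch.settings as gpt_settings

detach_test_caches = True  # e.g. not kwargs.get("propagate_grads", False)
with ExitStack() as es:
    es.enter_context(gpt_settings.debug(False))
    es.enter_context(gpt_settings.fast_pred_var())
    es.enter_context(gpt_settings.detach_test_caches(detach_test_caches))
    # ... evaluate the model and likelihood here ...
    pass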
Example #7
    def test_broadcast_lazy_shape(self):
        test1 = lazify(torch.randn(30, 1))

        test2 = torch.randn(30, 30)
        res = test1 + test2
        final_res = res + test2

        torch_res = res.evaluate() + test2

        self.assertEqual(final_res.shape, torch_res.shape)
        self.assertEqual((final_res.evaluate() - torch_res).sum(), 0.0)
Example #8
    def posterior(
        self,
        X: Tensor,
        output_indices: Optional[List[int]] = None,
        observation_noise: bool = False,
        **kwargs: Any,
    ) -> GPyTorchPosterior:
        r"""Computes the posterior over model outputs at the provided points.

        Args:
            X: A `(batch_shape) x q x d`-dim Tensor, where `d` is the dimension of the
                feature space and `q` is the number of points considered jointly.
            output_indices: A list of indices, corresponding to the outputs over
                which to compute the posterior (if the model is multi-output).
                Can be used to speed up computation if only a subset of the
                model's outputs are required for optimization. If omitted,
                computes the posterior over all model outputs.
            observation_noise: If True, add observation noise to the posterior.
            detach_test_caches: If True, detach GPyTorch test caches during
                computation of the posterior. Required for being able to compute
                derivatives with respect to training inputs at test time (used
                e.g. by qNoisyExpectedImprovement). Defaults to `True`.

        Returns:
            A `GPyTorchPosterior` object, representing `batch_shape` joint
            distributions over `q` points and the outputs selected by
            `output_indices` each. Includes observation noise if
            `observation_noise=True`.
        """
        self.eval()  # make sure model is in eval mode
        detach_test_caches = kwargs.get("detach_test_caches", True)
        with ExitStack() as es:
            es.enter_context(settings.debug(False))
            es.enter_context(settings.fast_pred_var())
            es.enter_context(settings.detach_test_caches(detach_test_caches))
            # insert a dimension for the output dimension
            if self._num_outputs > 1:
                X, output_dim_idx = add_output_dim(
                    X=X, original_batch_shape=self._input_batch_shape
                )
            mvn = self(X)
            mean_x = mvn.mean
            covar_x = mvn.covariance_matrix
            if self._num_outputs > 1:
                output_indices = output_indices or range(self._num_outputs)
                mvns = [
                    MultivariateNormal(
                        mean_x.select(dim=output_dim_idx, index=t),
                        lazify(covar_x.select(dim=output_dim_idx, index=t)),
                    )
                    for t in output_indices
                ]
                mvn = MultitaskMultivariateNormal.from_independent_mvns(mvns=mvns)
        return GPyTorchPosterior(mvn=mvn)
Example #9
    def test_base_sample_shape(self):
        a = torch.randn(5, 10)
        lazy_square_a = RootLazyTensor(lazify(a))
        dist = MultivariateNormal(torch.zeros(5), lazy_square_a)

        # check that providing the base samples is okay
        samples = dist.rsample(torch.Size((16, )),
                               base_samples=torch.randn(16, 10))
        self.assertEqual(samples.shape, torch.Size((16, 5)))

        # check that an event shape of base samples fails
        self.assertRaises(RuntimeError,
                          dist.rsample,
                          torch.Size((16, )),
                          base_samples=torch.randn(16, 5))

        # check that the proper event shape of base samples is okay for
        # a non-root lazy tensor
        nonlazy_square_a = lazify(lazy_square_a.evaluate())
        dist = MultivariateNormal(torch.zeros(5), nonlazy_square_a)

        samples = dist.rsample(torch.Size((16, )),
                               base_samples=torch.randn(16, 5))
        self.assertEqual(samples.shape, torch.Size((16, 5)))
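RootLazyTensor(A) represents A A^T, so the base-sample dimension is the number of columns of the root rather than the event size, which is what the test above exercises. A stand-alone sketch:

import torch
from gpytorch.lazy import lazify, RootLazyTensor
from gpytorch.distributions import MultivariateNormal

# A 5 x 10 root gives a 5 x 5 covariance, but sampling consumes 10 base samples.
a = torch.randn(5, 10)
cov = RootLazyTensor(lazify(a))
print(cov.shape)  # torch.Size([5, 5])

dist = MultivariateNormal(torch.zeros(5), cov)
samples = dist.rsample(torch.Size((16,)), base_samples=torch.randn(16, 10))
print(samples.shape)  # torch.Size([16, 5])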
Example #10
    def test_extract_batch_covar(self):
        tkwargs = {"device": self.device}
        for dtype in (torch.float, torch.double):
            tkwargs["dtype"] = dtype
            base_covar = torch.tensor(
                [[1.0, 0.6, 0.9], [0.6, 1.0, 0.5], [0.9, 0.5, 1.0]], **tkwargs)
            lazy_covar = lazify(
                torch.stack([base_covar, base_covar * 2], dim=0))
            block_diag_covar = BlockDiagLazyTensor(lazy_covar)
            mt_mvn = MultitaskMultivariateNormal(torch.zeros(3, 2, **tkwargs),
                                                 block_diag_covar)
            batch_covar = extract_batch_covar(mt_mvn=mt_mvn)
            self.assertTrue(
                torch.equal(batch_covar.evaluate(), lazy_covar.evaluate()))
            # test non BlockDiagLazyTensor
            mt_mvn = MultitaskMultivariateNormal(torch.zeros(3, 2, **tkwargs),
                                                 block_diag_covar.evaluate())
            with self.assertRaises(BotorchError):
                extract_batch_covar(mt_mvn=mt_mvn)
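extract_batch_covar above relies on the multitask covariance being a BlockDiagLazyTensor, whose batch of blocks it recovers. A minimal sketch of that structure:

import torch
from gpytorch.lazy import lazify, BlockDiagLazyTensor

# A batch of two 3 x 3 blocks becomes a single 6 x 6 block-diagonal matrix.
base = torch.eye(3)
lazy_batch = lazify(torch.stack([base, 2 * base], dim=0))
block_diag = BlockDiagLazyTensor(lazy_batch)
print(block_diag.shape)  # torch.Size([6, 6])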
Example #11
    def posterior(
        self,
        X: Tensor,
        output_indices: Optional[List[int]] = None,
        observation_noise: Union[bool, Tensor] = False,
        posterior_transform: Optional[PosteriorTransform] = None,
        **kwargs: Any,
    ) -> MultitaskGPPosterior:
        self.eval()

        if posterior_transform is not None:
            # this could be very costly, disallow for now
            raise NotImplementedError(
                "Posterior transforms currently not supported for "
                f"{self.__class__.__name__}")

        X = self.transform_inputs(X)
        train_x = self.transform_inputs(self.train_inputs[0])

        # construct Ktt
        task_covar = self._task_covar_matrix
        task_rootlt = self._task_covar_matrix.root_decomposition(
            method="diagonalization")
        task_root = task_rootlt.root
        if task_covar.batch_shape != X.shape[:-2]:
            task_covar = BatchRepeatLazyTensor(task_covar,
                                               batch_repeat=X.shape[:-2])
            task_root = BatchRepeatLazyTensor(lazify(task_root),
                                              batch_repeat=X.shape[:-2])

        task_covar_rootlt = RootLazyTensor(task_root)

        # construct RR' \approx Kxx
        data_data_covar = self.train_full_covar.lazy_tensors[0]
        # populate the diagonalization caches for the root and inverse root
        # decomposition
        data_data_evals, data_data_evecs = data_data_covar.diagonalization()

        # pad the eigenvalue and eigenvectors with zeros if we are using lanczos
        if data_data_evecs.shape[-1] < data_data_evecs.shape[-2]:
            cols_to_add = data_data_evecs.shape[-2] - data_data_evecs.shape[-1]
            zero_evecs = torch.zeros(
                *data_data_evecs.shape[:-1],
                cols_to_add,
                dtype=data_data_evals.dtype,
                device=data_data_evals.device,
            )
            zero_evals = torch.zeros(
                *data_data_evecs.shape[:-2],
                cols_to_add,
                dtype=data_data_evals.dtype,
                device=data_data_evals.device,
            )
            data_data_evecs = CatLazyTensor(
                data_data_evecs,
                lazify(zero_evecs),
                dim=-1,
                output_device=data_data_evals.device,
            )
            data_data_evals = torch.cat((data_data_evals, zero_evals), dim=-1)

        # construct K_{xt, x}
        test_data_covar = self.covar_module.data_covar_module(X, train_x)
        # construct K_{xt, xt}
        test_test_covar = self.covar_module.data_covar_module(X)

        # now update root so that \tilde{R}\tilde{R}' \approx K_{(x,xt), (x,xt)}
        # cloning preserves the gradient history
        updated_lazy_tensor = data_data_covar.cat_rows(
            cross_mat=test_data_covar.clone(),
            new_mat=test_test_covar,
            method="diagonalization",
        )
        updated_root = updated_lazy_tensor.root_decomposition().root
        # occasionally there are device errors, so enforce that this comes out right
        updated_root = updated_root.to(data_data_covar.device)

        # build a root decomposition of the joint train/test covariance matrix
        # construct (\tilde{R} \otimes M)(\tilde{R} \otimes M)' \approx
        # (K_{(x,xt), (x,xt)} \otimes Ktt)
        joint_covar = RootLazyTensor(
            KroneckerProductLazyTensor(updated_root,
                                       task_covar_rootlt.root.detach()))

        # construct K_{xt, x} \otimes Ktt
        test_obs_kernel = KroneckerProductLazyTensor(test_data_covar,
                                                     task_covar)

        # collect y - \mu(x) and \mu(X)
        train_diff = self.train_targets - self.mean_module(train_x)
        if detach_test_caches.on():
            train_diff = train_diff.detach()
        test_mean = self.mean_module(X)

        train_noise = self.likelihood._shaped_noise_covar(train_x.shape)
        diagonal_noise = isinstance(train_noise, DiagLazyTensor)
        if detach_test_caches.on():
            train_noise = train_noise.detach()
        test_noise = (self.likelihood._shaped_noise_covar(X.shape)
                      if observation_noise else None)

        # predictive mean and variance for the mvn
        # first the predictive mean
        pred_mean = (test_obs_kernel.matmul(
            self.predictive_mean_cache).reshape_as(test_mean) + test_mean)
        # next the predictive variance, assume diagonal noise
        test_var_term = KroneckerProductLazyTensor(test_test_covar,
                                                   task_covar).diag()

        if diagonal_noise:
            task_evals, task_evecs = self._task_covar_matrix.diagonalization()
            # TODO: make this be the default KPMatmulLT diagonal method in gpytorch
            full_data_inv_evals = (KroneckerProductDiagLazyTensor(
                DiagLazyTensor(data_data_evals), DiagLazyTensor(task_evals)) +
                                   train_noise).inverse()
            test_train_hadamard = KroneckerProductLazyTensor(
                test_data_covar.matmul(data_data_evecs).evaluate()**2,
                task_covar.matmul(task_evecs).evaluate()**2,
            )
            data_var_term = test_train_hadamard.matmul(
                full_data_inv_evals).sum(dim=-1)
        else:
            # if non-diagonal noise (but still kronecker structured), we have to pull
            # across the noise because the inverse is not closed form
            # should be a kronecker lt, R = \Sigma_X^{-1/2} \kron \Sigma_T^{-1/2}
            # TODO: enforce the diagonalization to return a KPLT for all shapes in
            # gpytorch or dense linear algebra for small shapes
            data_noise, task_noise = train_noise.lazy_tensors
            data_noise_root = data_noise.root_inv_decomposition(
                method="diagonalization")
            task_noise_root = task_noise.root_inv_decomposition(
                method="diagonalization")

            # ultimately we need to compute the diagonal of
            # (K_{x* X} \kron K_T)(K_{XX} \kron K_T + \Sigma_X \kron \Sigma_T)^{-1}
            #                           (K_{x* X} \kron K_T)^T
            # = (K_{x* X} \Sigma_X^{-1/2} Q_R)(\Lambda_R + I)^{-1}
            #                       (K_{x* X} \Sigma_X^{-1/2} Q_R)^T
            # where R = (\Sigma_X^{-1/2T}K_{XX}\Sigma_X^{-1/2} \kron
            #                   \Sigma_T^{-1/2T}K_{T}\Sigma_T^{-1/2})
            # first we construct the components of R's eigen-decomposition
            # TODO: make this be the default KPMatmulLT diagonal method in gpytorch
            whitened_data_covar = (data_noise_root.transpose(
                -1, -2).matmul(data_data_covar).matmul(data_noise_root))
            w_data_evals, w_data_evecs = whitened_data_covar.diagonalization()
            whitened_task_covar = (task_noise_root.transpose(-1, -2).matmul(
                self._task_covar_matrix).matmul(task_noise_root))
            w_task_evals, w_task_evecs = whitened_task_covar.diagonalization()

            # we add one to the eigenvalues as above (not just for stability)
            full_data_inv_evals = (KroneckerProductDiagLazyTensor(
                DiagLazyTensor(w_data_evals),
                DiagLazyTensor(w_task_evals)).add_jitter(1.0).inverse())

            test_data_comp = (test_data_covar.matmul(data_noise_root).matmul(
                w_data_evecs).evaluate()**2)
            task_comp = (task_covar.matmul(task_noise_root).matmul(
                w_task_evecs).evaluate()**2)

            test_train_hadamard = KroneckerProductLazyTensor(
                test_data_comp, task_comp)
            data_var_term = test_train_hadamard.matmul(
                full_data_inv_evals).sum(dim=-1)

        pred_variance = test_var_term - data_var_term
        specialized_mvn = MultitaskMultivariateNormal(
            pred_mean, DiagLazyTensor(pred_variance))
        if observation_noise:
            specialized_mvn = self.likelihood(specialized_mvn)

        posterior = MultitaskGPPosterior(
            mvn=specialized_mvn,
            joint_covariance_matrix=joint_covar,
            test_train_covar=test_obs_kernel,
            train_diff=train_diff,
            test_mean=test_mean,
            train_train_covar=self.train_full_covar,
            train_noise=train_noise,
            test_noise=test_noise,
        )

        if hasattr(self, "outcome_transform"):
            posterior = self.outcome_transform.untransform_posterior(posterior)
        return posterior
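The predictive-variance computation above leans on Kronecker identities; in particular, the diagonal of a Kronecker product is the flattened outer product of the two diagonals, which is why the code can work with diag() terms instead of dense (n*t) x (n*t) matrices. A toy check of that identity with made-up matrices:

import torch
from gpytorch.lazy import lazify, KroneckerProductLazyTensor

A = torch.tensor([[2.0, 0.3], [0.3, 1.0]])
B = torch.tensor([[1.5, 0.1], [0.1, 0.5]])
kron = KroneckerProductLazyTensor(lazify(A), lazify(B))

# diag(A kron B)[i*m + j] == A[i, i] * B[j, j]
expected = (A.diag().unsqueeze(-1) * B.diag().unsqueeze(0)).reshape(-1)
assert torch.allclose(kron.diag(), expected)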
Example #12
    def forward(self, X, **kwargs):
        if self.training:
            # TODO: return a better dummy here
            # this is a dummy because the real action happens in the MLL
            if X is not None:
                mean = self.mean_module(X)
                covar = self.covar_module(X)
            else:
                if type(self._batch_shape) is not torch.Size:
                    batch_shape = torch.Size((self._batch_shape, ))
                else:
                    batch_shape = self._batch_shape
                mean_shape = batch_shape + torch.Size((self.num_data, ))
                mean = ZeroLazyTensor(*mean_shape)
                covar_shape = mean_shape + torch.Size((self.num_data, ))
                covar = ZeroLazyTensor(*covar_shape)

            # should hopefully only occur with batching issues
            if (mean.ndimension() < covar.ndimension()
                    and (self._batch_shape != torch.Size()
                         and mean.shape != covar.shape[:-1])):
                if type(mean) is ZeroLazyTensor:
                    mean = mean.evaluate()
                mean = mean.unsqueeze(0)
                mean = mean.repeat(covar.batch_shape[0],
                                   *[1] * (covar.ndimension() - 1))

            return MultivariateNormal(mean, covar)
        else:
            lazy_kernel = self.covar_module(X).evaluate_kernel()
            pred_mean = left_interp(
                lazy_kernel.left_interp_indices,
                lazy_kernel.left_interp_values,
                self.prediction_cache["pred_mean"],
            )

            if skip_posterior_variances.off():
                # init predictive covariance if it's not in the prediction cache
                if "pred_cov" in self.prediction_cache.keys():
                    inner_pred_cov = self.prediction_cache["pred_cov"]
                else:
                    self.prediction_cache[
                        "pred_cov"] = self._make_predictive_covar()
                    inner_pred_cov = self.prediction_cache["pred_cov"]

                if fast_pred_samples.off():
                    pred_wmat = _get_wmat_from_kernel(lazy_kernel)
                    lazy_pred_wmat = lazify(pred_wmat)
                    pred_cov = lazy_pred_wmat.transpose(-1, -2).matmul(
                        (inner_pred_cov.matmul(lazy_pred_wmat)))

                    if self.has_learnable_noise:
                        pred_cov = pred_cov * self.likelihood.second_noise_covar.noise.to(
                            pred_cov.device)
                else:
                    # inner_pred_cov_root = inner_pred_cov.root_decomposition().root.evaluate()
                    inner_pred_cov_root = inner_pred_cov.root_decomposition(
                        method="lanczos").root.evaluate()
                    if inner_pred_cov_root.shape[-1] > X.shape[-2]:
                        inner_pred_cov_root = inner_pred_cov_root[
                            ..., -X.shape[-2]:]

                    root_tensor = left_interp(
                        lazy_kernel.left_interp_indices,
                        lazy_kernel.left_interp_values,
                        inner_pred_cov_root,
                    )

                    if self.has_learnable_noise:
                        noise_root = self.likelihood.second_noise_covar.noise.to(
                            root_tensor.device)**0.5
                    else:
                        # without learnable noise, leave the root unscaled
                        noise_root = 1.0
                    pred_cov = RootLazyTensor(root_tensor * noise_root)

            else:
                pred_cov = ZeroLazyTensor(*lazy_kernel.size())

            pred_mean = pred_mean[..., 0]
            if self._batch_shape == torch.Size() and X.ndimension() == 2:
                pred_mean = pred_mean[0]
                if pred_cov.ndimension() > 2:
                    pred_cov = pred_cov[0]

            dist = MultivariateNormal(pred_mean, pred_cov)

            return dist