def posterior(
    self,
    X: Tensor,
    output_indices: Optional[List[int]] = None,
    observation_noise: Union[bool, Tensor] = False,
    **kwargs: Any,
) -> GPyTorchPosterior:
    """Compute the posterior over model outputs at `X`.

    Switches the model to eval mode, evaluates it at `X` with posterior
    variances skipped, and packages the predictive mean together with the
    train/test covariance pieces into a `HigherOrderGPPosterior`.

    If `skip_posterior_variances` is already active when this method is
    called, the returned distribution carries a `ZeroLazyTensor` covariance;
    otherwise it carries a diagonal covariance built from
    `self.make_posterior_variances`.

    Args:
        X: The points at which to evaluate the posterior. `X.shape[:-2]` is
            treated as the batch shape and `X.shape[:-1]` prefixes the output
            shape (presumably `batch_shape x n x d` -- confirm with callers).
        output_indices: Not used by this implementation.
        observation_noise: Any value other than `False` causes the likelihood
            to be applied to the predictive distribution. A tensor value is
            not itself forwarded to the likelihood (see the TODO below).
        **kwargs: Not used by this implementation.

    Returns:
        A `HigherOrderGPPosterior`, or, if the model has an
        `outcome_transform` attribute, the result of passing that posterior
        through `outcome_transform.untransform_posterior`.
    """
    self.eval()  # make sure we're calling a posterior

    # Capture the caller's setting now: it is forced to True inside the
    # context below, so it could not be recovered afterwards.
    no_pred_variance = skip_posterior_variances._state

    with ExitStack() as es:
        es.enter_context(gpt_posterior_settings())
        es.enter_context(fast_pred_var(True))
        # we need to skip posterior variances here
        es.enter_context(skip_posterior_variances(True))

        mvn = self(X)
        if observation_noise is not False:
            # TODO: implement Kronecker + diagonal solves so that a
            # tensor-valued `observation_noise` (after validating its shape)
            # or the noise of a FixedNoiseGaussianLikelihood can be passed to
            # the likelihood via `noise=...`. Until then the likelihood is
            # applied without an explicit noise argument.
            mvn = self.likelihood(mvn, X)

        # lazy covariance matrix includes the interpolated version of the full
        # covariance matrix so we can actually grab that instead.
        train_X = self.train_inputs[0]
        if X.ndimension() > train_X.ndimension():
            # X has extra leading batch dimensions: prepend singleton dims to
            # the training inputs and tile them so both share a batch shape.
            X_batch_shape = X.shape[:-2]
            train_inputs = train_X.reshape(
                *[1] * len(X_batch_shape), *train_X.shape
            )
            train_inputs = train_inputs.repeat(
                *X_batch_shape, *[1] * train_X.ndimension()
            )
        else:
            train_inputs = train_X

        data_covar_module = self.covar_modules[0]
        full_covar = data_covar_module(torch.cat((train_inputs, X), dim=-2))

        # NOTE(review): when variances are skipped, `full_covar` stays the
        # data-only covariance (no Kronecker product with the latent
        # factors), so `joint_covariance_matrix` is structured differently
        # between the two paths -- confirm this is intended.
        if not no_pred_variance:
            # we detach all of the latent dimension posteriors which precludes
            # computing quantities computed on the posterior wrt latents as
            # this reduces the memory overhead somewhat
            # TODO: add these back in if necessary
            joint_covar = self._get_joint_covariance([X])
            pred_variance = self.make_posterior_variances(joint_covar)
            full_covar = KroneckerProductLazyTensor(
                full_covar, *[x.detach() for x in joint_covar.lazy_tensors[1:]]
            )

        # Test/train cross-covariance: data factor first, then one detached
        # factor per latent covariance module.
        joint_covar_list = [data_covar_module(X, train_inputs)]
        batch_shape = joint_covar_list[0].batch_shape
        for cm, param in zip(self.covar_modules[1:], self.latent_parameters):
            covar = cm(param).detach()
            if covar.batch_shape != batch_shape:
                covar = BatchRepeatLazyTensor(covar, batch_shape)
            joint_covar_list.append(covar)
        test_train_covar = KroneckerProductLazyTensor(*joint_covar_list)

        # mean and variance get reshaped into the target shape
        new_mean = mvn.mean.reshape(*X.shape[:-1], *self.target_shape)
        if no_pred_variance:
            new_variance = ZeroLazyTensor(
                *X.shape[:-1], *self.target_shape, self.target_shape[-1]
            )
        else:
            new_variance = DiagLazyTensor(
                pred_variance.reshape(*X.shape[:-1], *self.target_shape)
            )
        mvn = MultivariateNormal(new_mean, new_variance)

        train_train_covar = self.prediction_strategy.lik_train_train_covar.detach()

        # return a specialized Posterior to allow for sampling
        # cloning the full covar allows backpropagation through it
        posterior = HigherOrderGPPosterior(
            mvn=mvn,
            train_targets=self.train_targets.unsqueeze(-1),
            train_train_covar=train_train_covar,
            test_train_covar=test_train_covar,
            joint_covariance_matrix=full_covar.clone(),
            output_shape=X.shape[:-1] + self.target_shape,
            num_outputs=self._num_outputs,
        )
        if hasattr(self, "outcome_transform"):
            posterior = self.outcome_transform.untransform_posterior(posterior)

        return posterior
def posterior(
    self,
    X: Tensor,
    output_indices: Optional[List[int]] = None,
    observation_noise: Union[bool, Tensor] = False,
    **kwargs: Any,
) -> GPyTorchPosterior:
    """Compute the posterior over model outputs at `X`.

    Switches the model to eval mode, applies the input transforms to `X`,
    evaluates the model with posterior variances skipped, and packages the
    predictive mean together with the train/test covariance pieces into a
    `HigherOrderGPPosterior`.

    If `skip_posterior_variances` is already active when this method is
    called, the returned distribution carries a `ZeroLazyTensor` covariance;
    otherwise it carries a diagonal covariance built from
    `self.make_posterior_variances`.

    Args:
        X: The points at which to evaluate the posterior. After input
            transforms, `X.shape[:-2]` is treated as the batch shape and
            `X.shape[:-1]` prefixes the output shape (presumably
            `batch_shape x n x d` -- confirm with callers).
        output_indices: Not used by this implementation.
        observation_noise: Any value other than `False` causes the likelihood
            to be applied to the predictive distribution. A tensor value is
            not itself forwarded to the likelihood.
        **kwargs: Not used by this implementation.

    Returns:
        A `HigherOrderGPPosterior`, or, if the model has an
        `outcome_transform` attribute, the result of passing that posterior
        through `outcome_transform.untransform_posterior`.
    """
    self.eval()  # make sure we're calling a posterior

    # input transforms are applied at `posterior` in `eval` mode, and at
    # `model.forward()` at the training time
    X = self.transform_inputs(X)

    # Capture the caller's setting now: it is forced to True inside the
    # context below, so it could not be recovered afterwards.
    no_pred_variance = skip_posterior_variances._state

    with ExitStack() as es:
        es.enter_context(gpt_posterior_settings())
        es.enter_context(fast_pred_var(True))
        # we need to skip posterior variances here
        es.enter_context(skip_posterior_variances(True))

        mvn = self(X)
        if observation_noise is not False:
            # TODO: ensure that this still works for structured noise solves.
            mvn = self.likelihood(mvn, X)

        # lazy covariance matrix includes the interpolated version of the full
        # covariance matrix so we can actually grab that instead.
        train_X = self.train_inputs[0]
        if X.ndimension() > train_X.ndimension():
            # X has extra leading batch dimensions: prepend singleton dims to
            # the training inputs and tile them so both share a batch shape.
            X_batch_shape = X.shape[:-2]
            train_inputs = train_X.reshape(
                *[1] * len(X_batch_shape), *train_X.shape
            )
            train_inputs = train_inputs.repeat(
                *X_batch_shape, *[1] * train_X.ndimension()
            )
        else:
            train_inputs = train_X

        # we now compute the data covariances for the training data, the testing
        # data, the joint covariances, and the test train cross-covariance
        train_train_covar = self.prediction_strategy.lik_train_train_covar.detach()
        # NOTE(review): assumes the train-train covariance wraps a Kronecker
        # product whose factors are exposed via `.lazy_tensor.lazy_tensors`,
        # with the data factor first -- confirm against prediction_strategy.
        base_train_train_covar = train_train_covar.lazy_tensor

        data_train_covar = base_train_train_covar.lazy_tensors[0]
        data_covar = self.covar_modules[0]
        data_train_test_covar = data_covar(X, train_inputs)
        data_test_test_covar = data_covar(X)
        data_joint_covar = data_train_covar.cat_rows(
            cross_mat=data_train_test_covar,
            new_mat=data_test_test_covar,
        )

        # we detach the latents so that they don't cause gradient errors
        # TODO: Can we enable backprop through the latent covariances?
        batch_shape = data_train_test_covar.batch_shape
        latent_covar_list = []
        for latent_covar in base_train_train_covar.lazy_tensors[1:]:
            if latent_covar.batch_shape != batch_shape:
                latent_covar = BatchRepeatLazyTensor(latent_covar, batch_shape)
            latent_covar_list.append(latent_covar.detach())

        joint_covar = KroneckerProductLazyTensor(
            data_joint_covar, *latent_covar_list
        )
        test_train_covar = KroneckerProductLazyTensor(
            data_train_test_covar, *latent_covar_list
        )

        # mean and variance get reshaped into the target shape; the posterior
        # variance is only computed when the caller has not asked to skip it
        new_mean = mvn.mean.reshape(*X.shape[:-1], *self.target_shape)
        if no_pred_variance:
            new_variance = ZeroLazyTensor(
                *X.shape[:-1], *self.target_shape, self.target_shape[-1]
            )
        else:
            pred_variance = self.make_posterior_variances(joint_covar)
            new_variance = DiagLazyTensor(
                pred_variance.reshape(*X.shape[:-1], *self.target_shape)
            )
        mvn = MultivariateNormal(new_mean, new_variance)

        # return a specialized Posterior to allow for sampling
        # cloning the full covar allows backpropagation through it
        posterior = HigherOrderGPPosterior(
            mvn=mvn,
            train_targets=self.train_targets.unsqueeze(-1),
            train_train_covar=train_train_covar,
            test_train_covar=test_train_covar,
            joint_covariance_matrix=joint_covar.clone(),
            output_shape=X.shape[:-1] + self.target_shape,
            num_outputs=self._num_outputs,
        )
        if hasattr(self, "outcome_transform"):
            posterior = self.outcome_transform.untransform_posterior(posterior)

        return posterior