def test_degenerate_GPyTorchPosterior_Multitask(self):
    """Exercise GPyTorchPosterior built from a singular (non-p.d.) multitask MVN.

    Checks basic attributes, rsample shapes, determinism under fixed base
    samples, and the collapse of singleton batch dims for base samples.
    """
    for dtype in (torch.float, torch.double):
        tkwargs = {"dtype": dtype, "device": self.device}
        # Rank-deficient covariance: the first two rows/columns coincide.
        singular_covar = torch.tensor(
            [[1, 1, 0], [1, 1, 0], [0, 0, 2]], **tkwargs
        )
        mu = torch.rand(3, **tkwargs)
        base_mvn = MultivariateNormal(mu, lazify(singular_covar))
        mt_mvn = MultitaskMultivariateNormal.from_independent_mvns(
            [base_mvn, base_mvn]
        )
        posterior = GPyTorchPosterior(mvn=mt_mvn)
        # basic attributes
        self.assertEqual(posterior.device.type, self.device.type)
        self.assertTrue(posterior.dtype == dtype)
        self.assertEqual(posterior.event_shape, torch.Size([3, 2]))
        expected_mean = mu.unsqueeze(-1).repeat(1, 2)
        self.assertTrue(torch.equal(posterior.mean, expected_mean))
        expected_var = singular_covar.diag().unsqueeze(-1).repeat(1, 2)
        self.assertTrue(torch.equal(posterior.variance, expected_var))
        # rsample
        with warnings.catch_warnings(record=True) as recorded:
            # we check that the p.d. warning is emitted - this only
            # happens once per posterior, so we need to check only once
            draws = posterior.rsample(sample_shape=torch.Size([4]))
        self.assertTrue(
            any(issubclass(w.category, RuntimeWarning) for w in recorded)
        )
        self.assertTrue(any("not p.d" in str(w.message) for w in recorded))
        self.assertEqual(draws.shape, torch.Size([4, 3, 2]))
        draws2 = posterior.rsample(sample_shape=torch.Size([4, 2]))
        self.assertEqual(draws2.shape, torch.Size([4, 2, 3, 2]))
        # rsample w/ base samples: identical base samples => identical draws
        base = torch.randn(4, 3, 2, **tkwargs)
        first = posterior.rsample(
            sample_shape=torch.Size([4]), base_samples=base
        )
        second = posterior.rsample(
            sample_shape=torch.Size([4]), base_samples=base
        )
        self.assertTrue(torch.allclose(first, second))
        base2 = torch.randn(4, 2, 3, 2, **tkwargs)
        first2 = posterior.rsample(
            sample_shape=torch.Size([4, 2]), base_samples=base2
        )
        second2 = posterior.rsample(
            sample_shape=torch.Size([4, 2]), base_samples=base2
        )
        self.assertTrue(torch.allclose(first2, second2))
        # collapse_batch_dims
        batch_mean = torch.rand(2, 3, **tkwargs)
        batch_covar = singular_covar.expand(2, *singular_covar.shape)
        batch_mvn = MultivariateNormal(batch_mean, lazify(batch_covar))
        batch_mt_mvn = MultitaskMultivariateNormal.from_independent_mvns(
            [batch_mvn, batch_mvn]
        )
        batch_posterior = GPyTorchPosterior(mvn=batch_mt_mvn)
        batch_base = torch.randn(4, 1, 3, 2, **tkwargs)
        with warnings.catch_warnings(record=True) as recorded:
            batch_draws = batch_posterior.rsample(
                sample_shape=torch.Size([4]), base_samples=batch_base
            )
        self.assertTrue(
            any(issubclass(w.category, RuntimeWarning) for w in recorded)
        )
        self.assertTrue(any("not p.d" in str(w.message) for w in recorded))
        self.assertEqual(batch_draws.shape, torch.Size([4, 2, 3, 2]))
def posterior(
    self,
    X: Tensor,
    output_indices: Optional[List[int]] = None,
    observation_noise: Union[bool, Tensor] = False,
    **kwargs: Any,
) -> GPyTorchPosterior:
    r"""Computes the posterior over model outputs at the provided points.

    Args:
        X: A `(batch_shape) x q x d`-dim Tensor, where `d` is the dimension
            of the feature space and `q` is the number of points considered
            jointly.
        output_indices: A list of indices, corresponding to the outputs over
            which to compute the posterior (if the model is multi-output).
            Can be used to speed up computation if only a subset of the
            model's outputs are required for optimization. If omitted,
            computes the posterior over all model outputs.
        observation_noise: If True, add the observation noise from the
            likelihood to the posterior. If a Tensor, use it directly as the
            observation noise (must be of shape `(batch_shape) x q x m`).

    Returns:
        A `GPyTorchPosterior` object, representing `batch_shape` joint
        distributions over `q` points and the outputs selected by
        `output_indices` each. Includes observation noise if specified.
    """
    self.eval()  # make sure model is in eval mode
    with gpt_posterior_settings():
        # insert a dimension for the output dimension
        if self._num_outputs > 1:
            # output_dim_idx marks where the inserted output dim lives, so
            # the per-output MVNs can be sliced back out below.
            X, output_dim_idx = add_output_dim(
                X=X, original_batch_shape=self._input_batch_shape
            )
        mvn = self(X)
        if observation_noise is not False:
            if torch.is_tensor(observation_noise):
                # TODO: Validate noise shape
                # make observation_noise `batch_shape x q x n`
                # NOTE(review): the transpose swaps the last two dims of the
                # `(batch_shape) x q x m` noise tensor — confirm the target
                # layout the likelihood expects.
                obs_noise = observation_noise.transpose(-1, -2)
                mvn = self.likelihood(mvn, X, noise=obs_noise)
            elif isinstance(self.likelihood, FixedNoiseGaussianLikelihood):
                # Use the mean of the previous noise values (TODO: be smarter here).
                noise = self.likelihood.noise.mean().expand(X.shape[:-1])
                mvn = self.likelihood(mvn, X, noise=noise)
            else:
                # Default: let the likelihood add its own observation noise.
                mvn = self.likelihood(mvn, X)
        if self._num_outputs > 1:
            mean_x = mvn.mean
            covar_x = mvn.covariance_matrix
            output_indices = output_indices or range(self._num_outputs)
            # Slice out one MVN per requested output along the output dim,
            # then re-assemble them as an independent multitask MVN.
            mvns = [
                MultivariateNormal(
                    mean_x.select(dim=output_dim_idx, index=t),
                    lazify(covar_x.select(dim=output_dim_idx, index=t)),
                )
                for t in output_indices
            ]
            mvn = MultitaskMultivariateNormal.from_independent_mvns(mvns=mvns)
    return GPyTorchPosterior(mvn=mvn)
def posterior(
    self,
    X: Tensor,
    output_indices: Optional[List[int]] = None,
    observation_noise: bool = False,
    **kwargs: Any,
) -> GPyTorchPosterior:
    r"""Computes the posterior over model outputs at the provided points.

    Args:
        X: A `(batch_shape) x q x d`-dim Tensor, where `d` is the dimension
            of the feature space and `q` is the number of points considered
            jointly.
        output_indices: A list of indices, corresponding to the outputs over
            which to compute the posterior (if the model is multi-output).
            Can be used to speed up computation if only a subset of the
            model's outputs are required for optimization. If omitted,
            computes the posterior over all model outputs.
        observation_noise: If True, add observation noise to the posterior.
        propagate_grads: If True, do not detach GPyTorch's test caches when
            computing the posterior. Required for being able to compute
            derivatives with respect to training inputs at test time (used
            e.g. by qNoisyExpectedImprovement). Defaults to `False`.

    Returns:
        A `GPyTorchPosterior` object, representing `batch_shape` joint
        distributions over `q` points and the outputs selected by
        `output_indices` each. Includes observation noise if
        `observation_noise=True`.
    """
    self.eval()  # make sure model is in eval mode
    # Detach test caches unless the caller explicitly asks to propagate grads.
    keep_caches_detached = not kwargs.get("propagate_grads", False)
    with ExitStack() as stack:
        stack.enter_context(settings.debug(False))
        stack.enter_context(settings.fast_pred_var())
        stack.enter_context(settings.detach_test_caches(keep_caches_detached))
        if self._num_outputs > 1:
            # Insert an explicit output dimension for multi-output models.
            X, output_dim_idx = add_output_dim(
                X=X, original_batch_shape=self._input_batch_shape
            )
        mvn = self(X)
        if observation_noise:
            mvn = self.likelihood(mvn, X)
        if self._num_outputs > 1:
            output_indices = output_indices or range(self._num_outputs)
            joint_mean = mvn.mean
            joint_covar = mvn.covariance_matrix
            # One independent MVN per selected output, sliced along the
            # inserted output dimension.
            per_output_mvns = []
            for idx in output_indices:
                per_output_mvns.append(
                    MultivariateNormal(
                        joint_mean.select(dim=output_dim_idx, index=idx),
                        lazify(joint_covar.select(dim=output_dim_idx, index=idx)),
                    )
                )
            mvn = MultitaskMultivariateNormal.from_independent_mvns(
                mvns=per_output_mvns
            )
    return GPyTorchPosterior(mvn=mvn)
def posterior(
    self,
    X: Tensor,
    output_indices: Optional[List[int]] = None,
    observation_noise: bool = False,
    **kwargs: Any,
) -> GPyTorchPosterior:
    r"""Computes the posterior over model outputs at the provided points.

    Args:
        X: A `(batch_shape) x q x d`-dim Tensor, where `d` is the dimension
            of the feature space and `q` is the number of points considered
            jointly.
        output_indices: A list of indices, corresponding to the outputs over
            which to compute the posterior (if the model is multi-output).
            Can be used to speed up computation if only a subset of the
            model's outputs are required for optimization. If omitted,
            computes the posterior over all model outputs.
        observation_noise: If True, add observation noise to the posterior.
            NOTE(review): this flag is currently never applied in the body
            (no likelihood call) — confirm whether that is intentional.
        detach_test_caches: If True, detach GPyTorch test caches during
            computation of the posterior. Required for being able to compute
            derivatives with respect to training inputs at test time (used
            e.g. by qNoisyExpectedImprovement). Defaults to `True`.

    Returns:
        A `GPyTorchPosterior` object, representing `batch_shape` joint
        distributions over `q` points and the outputs selected by
        `output_indices` each. Includes observation noise if
        `observation_noise=True`.
    """
    self.eval()  # make sure model is in eval mode
    detach_test_caches = kwargs.get("detach_test_caches", True)
    with ExitStack() as es:
        es.enter_context(settings.debug(False))
        es.enter_context(settings.fast_pred_var())
        es.enter_context(settings.detach_test_caches(detach_test_caches))
        # insert a dimension for the output dimension
        if self._num_outputs > 1:
            X, output_dim_idx = add_output_dim(
                X=X, original_batch_shape=self._input_batch_shape
            )
        mvn = self(X)
        if self._num_outputs > 1:
            # Evaluate mean/covariance only on the multi-output path:
            # `mvn.covariance_matrix` densifies the lazy covariance, which
            # previously ran unconditionally and was wasted work (and
            # memory) in the single-output case where the values are unused.
            mean_x = mvn.mean
            covar_x = mvn.covariance_matrix
            output_indices = output_indices or range(self._num_outputs)
            mvns = [
                MultivariateNormal(
                    mean_x.select(dim=output_dim_idx, index=t),
                    lazify(covar_x.select(dim=output_dim_idx, index=t)),
                )
                for t in output_indices
            ]
            mvn = MultitaskMultivariateNormal.from_independent_mvns(mvns=mvns)
    return GPyTorchPosterior(mvn=mvn)
def posterior(
    self,
    X: Tensor,
    output_indices: Optional[List[int]] = None,
    observation_noise: bool = False,
    **kwargs: Any,
) -> GPyTorchPosterior:
    r"""Computes the posterior over model outputs at the provided points.

    Args:
        X: A `b x q x d`-dim Tensor, where `d` is the dimension of the
            feature space, `q` is the number of points considered jointly,
            and `b` is the batch dimension.
        output_indices: A list of indices, corresponding to the outputs over
            which to compute the posterior (if the model is multi-output).
            Can be used to speed up computation if only a subset of the
            model's outputs are required for optimization. If omitted,
            computes the posterior over all model outputs.
        observation_noise: If True, add observation noise to the posterior.
        detach_test_caches: If True, detach GPyTorch test caches during
            computation of the posterior. Required for being able to compute
            derivatives with respect to training inputs at test time (used
            e.g. by qNoisyExpectedImprovement).

    Returns:
        A `GPyTorchPosterior` object, representing `batch_shape` joint
        distributions over `q` points and the outputs selected by
        `output_indices` each. Includes measurement noise if
        `observation_noise=True`.
    """
    detach_test_caches = kwargs.get("detach_test_caches", True)
    self.eval()  # make sure model is in eval mode
    with ExitStack() as es:
        es.enter_context(settings.debug(False))
        es.enter_context(settings.fast_pred_var())
        es.enter_context(settings.detach_test_caches(detach_test_caches))
        if output_indices is not None:
            # Evaluate only the requested sub-models individually.
            mvns = [self.forward_i(i, X) for i in output_indices]
            if observation_noise:
                # Apply each sub-model's own likelihood to its MVN.
                mvns = [
                    self.likelihood_i(i, mvn, X)
                    for i, mvn in zip(output_indices, mvns)
                ]
        else:
            # Evaluate all sub-models jointly; the model takes one X per output.
            mvns = self(*[X for _ in range(self.num_outputs)])
            if observation_noise:
                # TODO: Allow passing in observation noise via kwarg
                # NOTE(review): the likelihood is called with one (mvn, X)
                # tuple per output — this relies on the likelihood-list
                # calling convention; confirm against the likelihood class.
                mvns = self.likelihood(*[(mvn, X) for mvn in mvns])
        if len(mvns) == 1:
            # Single output: return a plain (non-multitask) posterior.
            return GPyTorchPosterior(mvn=mvns[0])
        else:
            return GPyTorchPosterior(
                mvn=MultitaskMultivariateNormal.from_independent_mvns(mvns=mvns)
            )
def posterior(
    self,
    X: Tensor,
    observation_noise: bool = False,
    posterior_transform: Optional[PosteriorTransform] = None,
) -> GPyTorchPosterior:
    """Return a standard-normal posterior over the given points.

    Builds a zero-mean, identity-covariance MVN matching `X`'s batch shape,
    dtype, and device (only `X`'s shape/dtype/device are used), replicated
    across outputs when `self.num_outputs > 1`.

    Args:
        X: A `(batch_shape) x q x d`-dim input Tensor.
        observation_noise: Unused; accepted for interface compatibility.
        posterior_transform: Optional transform applied to the posterior
            before returning.

    Returns:
        A `GPyTorchPosterior` (or the result of `posterior_transform` applied
        to it). Note: the return annotation previously claimed
        `MockPosterior`, but every path returns a `GPyTorchPosterior`.
    """
    mean_shape = X.shape[:-1]
    # Repeat the q x q identity over the leading batch dims.
    repeat_shape = list(X.shape[:-2]) + [1, 1]
    mvn = MultivariateNormal(
        mean=torch.zeros(mean_shape, dtype=X.dtype, device=X.device),
        covariance_matrix=torch.eye(
            mean_shape[-1], dtype=X.dtype, device=X.device
        ).repeat(repeat_shape),
    )
    if self.num_outputs > 1:
        # Fixed redundant double assignment (`mvn = mvn = ...`) here.
        mvn = MultitaskMultivariateNormal.from_independent_mvns(
            mvns=[mvn] * self.num_outputs
        )
    posterior = GPyTorchPosterior(mvn)
    if posterior_transform is not None:
        return posterior_transform(posterior)
    return posterior
def forward(self, indices=None):
    """
    Return the variational posterior for the latent variables, pertaining
    to provided indices.

    Args:
        indices: Optional index (or index tensor) selecting a subset of the
            stored variational parameters; if None, all are used.

    Returns:
        A `MultivariateNormal` with diagonal covariance when
        `self.output_dims == 1`, otherwise a `MultitaskMultivariateNormal`
        built from one independent MVN per output column.
    """
    if indices is None:
        ms = self.variational_mean
        vs = self.variational_variance
    else:
        ms = self.variational_mean[indices]
        vs = self.variational_variance[indices]
    # Broadcast variances across output dimensions.
    # NOTE(review): assumes `vs` is expandable to
    # (len(vs), output_dims) — confirm the stored parameter shapes.
    vs = vs.expand(len(vs), self.output_dims)
    if self.output_dims == 1:
        # Single output: unpack the lone mean/variance row.
        m, = ms
        v, = vs
        return MultivariateNormal(m, DiagLazyTensor(v))
    else:
        # One independent diagonal MVN per output column (hence the .T).
        mvns = [
            MultivariateNormal(m, DiagLazyTensor(v))
            for m, v in zip(ms.T, vs.T)
        ]
        return MultitaskMultivariateNormal.from_independent_mvns(mvns)
def test_degenerate_GPyTorchPosterior_Multitask(self, cuda=False):
    """Test GPyTorchPosterior built from a singular (non-p.d.) multitask MVN.

    Uses `any(...)` over recorded warnings instead of asserting exactly one
    warning was emitted (`len(w) == 1`), so the test no longer breaks when
    unrelated warnings fire alongside the p.d. RuntimeWarning — consistent
    with the sibling variant of this test in this file.
    """
    device = torch.device("cuda") if cuda else torch.device("cpu")
    for dtype in (torch.float, torch.double):
        # singular covariance matrix
        degenerate_covar = torch.tensor(
            [[1, 1, 0], [1, 1, 0], [0, 0, 2]], dtype=dtype, device=device
        )
        mean = torch.rand(3, dtype=dtype, device=device)
        mvn = MultivariateNormal(mean, lazify(degenerate_covar))
        mvn = MultitaskMultivariateNormal.from_independent_mvns([mvn, mvn])
        posterior = GPyTorchPosterior(mvn=mvn)
        # basics
        self.assertEqual(posterior.device.type, device.type)
        self.assertTrue(posterior.dtype == dtype)
        self.assertEqual(posterior.event_shape, torch.Size([3, 2]))
        mean_exp = mean.unsqueeze(-1).repeat(1, 2)
        self.assertTrue(torch.equal(posterior.mean, mean_exp))
        variance_exp = degenerate_covar.diag().unsqueeze(-1).repeat(1, 2)
        self.assertTrue(torch.equal(posterior.variance, variance_exp))
        # rsample
        with warnings.catch_warnings(record=True) as w:
            # we check that the p.d. warning is emitted - this only
            # happens once per posterior, so we need to check only once
            samples = posterior.rsample(sample_shape=torch.Size([4]))
        self.assertTrue(
            any(issubclass(wi.category, RuntimeWarning) for wi in w)
        )
        self.assertTrue(any("not p.d." in str(wi.message) for wi in w))
        self.assertEqual(samples.shape, torch.Size([4, 3, 2]))
        samples2 = posterior.rsample(sample_shape=torch.Size([4, 2]))
        self.assertEqual(samples2.shape, torch.Size([4, 2, 3, 2]))
        # rsample w/ base samples
        base_samples = torch.randn(4, 3, 2, device=device, dtype=dtype)
        samples_b1 = posterior.rsample(
            sample_shape=torch.Size([4]), base_samples=base_samples
        )
        samples_b2 = posterior.rsample(
            sample_shape=torch.Size([4]), base_samples=base_samples
        )
        self.assertTrue(torch.allclose(samples_b1, samples_b2))
        base_samples2 = torch.randn(4, 2, 3, 2, device=device, dtype=dtype)
        samples2_b1 = posterior.rsample(
            sample_shape=torch.Size([4, 2]), base_samples=base_samples2
        )
        samples2_b2 = posterior.rsample(
            sample_shape=torch.Size([4, 2]), base_samples=base_samples2
        )
        self.assertTrue(torch.allclose(samples2_b1, samples2_b2))
        # collapse_batch_dims
        b_mean = torch.rand(2, 3, dtype=dtype, device=device)
        b_degenerate_covar = degenerate_covar.expand(2, *degenerate_covar.shape)
        b_mvn = MultivariateNormal(b_mean, lazify(b_degenerate_covar))
        b_mvn = MultitaskMultivariateNormal.from_independent_mvns([b_mvn, b_mvn])
        b_posterior = GPyTorchPosterior(mvn=b_mvn)
        b_base_samples = torch.randn(4, 1, 3, 2, device=device, dtype=dtype)
        with warnings.catch_warnings(record=True) as w:
            b_samples = b_posterior.rsample(
                sample_shape=torch.Size([4]), base_samples=b_base_samples
            )
        self.assertTrue(
            any(issubclass(wi.category, RuntimeWarning) for wi in w)
        )
        self.assertTrue(any("not p.d." in str(wi.message) for wi in w))
        self.assertEqual(b_samples.shape, torch.Size([4, 2, 3, 2]))