def __call__(self, function, *params, **kwargs):
    """Evaluate the likelihood, upgrading legacy batch-MVN inputs first.

    If ``function`` is a ``Distribution`` that is not already a
    ``MultitaskMultivariateNormal``, a ``DeprecationWarning`` is emitted and the
    input is converted with ``MultitaskMultivariateNormal.from_batch_mvn``
    before being forwarded. Any other input is passed through unchanged.

    Args:
        function: The latent function distribution (or other input) on which
            the likelihood is evaluated.
        *params: Extra positional arguments forwarded to the parent ``__call__``.
        **kwargs: Extra keyword arguments forwarded to the parent ``__call__``.

    Returns:
        Whatever the parent class's ``__call__`` returns for the (possibly
        converted) input.
    """
    if isinstance(function, Distribution) and not isinstance(
        function, MultitaskMultivariateNormal
    ):
        # stacklevel=2 attributes the warning to the caller's line rather than
        # to this method, which is what a deprecation notice should point at.
        warnings.warn(
            "The input to DeepGaussianLikelihood should be a MultitaskMultivariateNormal (num_data x num_tasks). "
            "Batch MultivariateNormal inputs (num_tasks x num_data) will be deprecated.",
            DeprecationWarning,
            stacklevel=2,
        )
        function = MultitaskMultivariateNormal.from_batch_mvn(function)
    return super().__call__(function, *params, **kwargs)
def test_multitask_from_batch(self):
    """Check ``MultitaskMultivariateNormal.from_batch_mvn`` for two task layouts.

    Case 1: means of shape (2, 3) folded along ``task_dim=-1`` must give a
    multitask MVN with no batch dims, event shape (3, 2) and a 6 x 6 covariance.
    Case 2: means of shape (2, 4, 3) folded along ``task_dim=0`` must give
    batch shape (4,), event shape (3, 2) and a (4, 6, 6) covariance.
    In both cases mean and variance must equal the inputs with the task axis
    moved to the end.
    """
    # (input shape, task_dim, expected batch, expected covar shape, axis order)
    cases = (
        ((2, 3), -1, [], [6, 6], (1, 0)),
        ((2, 4, 3), 0, [4], [4, 6, 6], (1, 2, 0)),
    )
    for shape, task_dim, batch, covar_shape, order in cases:
        loc = torch.randn(*shape)
        var = torch.randn(*shape).clamp_min(1e-6)
        batch_mvn = MultivariateNormal(loc, DiagLazyTensor(var))
        result = MultitaskMultivariateNormal.from_batch_mvn(batch_mvn, task_dim=task_dim)
        self.assertTrue(isinstance(result, MultitaskMultivariateNormal))
        self.assertEqual(result.batch_shape, torch.Size(batch))
        self.assertEqual(result.event_shape, torch.Size([3, 2]))
        self.assertEqual(result.covariance_matrix.shape, torch.Size(covar_shape))
        self.assertEqual(result.mean, loc.permute(*order))
        self.assertEqual(result.variance, var.permute(*order))
def forward(self, X, **kwargs):
    """Map index inputs through the latent layer and GP, optionally projecting.

    Args:
        X: A 1-D tensor of integer-valued indices selecting latent inputs.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        The distribution from ``super().forward`` on the latent inputs. If
        ``self.L`` is set, it is instead a ``MultitaskMultivariateNormal``
        built from the mean and (diagonal) variance projected by ``self.L``.

    Raises:
        ValueError: If ``X`` is not 1-D or contains non-integer values.
    """
    # We require the provided X to be a set of indices that maps to our
    # latent input. Raise explicitly: an ``assert`` is stripped under ``-O``.
    if X.dim() != 1 or not torch.all(X == X.to(torch.int64)):
        raise ValueError("X must be a 1-D tensor of integer-valued indices")
    X = self.latent_layer(indices=X.to(torch.int64))
    mvn = super().forward(X)

    # Optionally make projection.
    # TODO covariance has not been implemented.
    if self.L is not None:
        proj_mean = (mvn.mean @ self.L).T
        # Only a diagonal covariance is produced, i.e. latent outputs are
        # treated as uncorrelated, so Var[sum_k x_k L_kj] = sum_k L_kj**2 Var[x_k].
        # The former (stddev @ L)**2 implicitly assumed perfectly correlated
        # outputs and could cancel to zero variance when L has mixed signs.
        proj_var = (mvn.variance @ self.L.pow(2)).T
        mvn = MultivariateNormal(proj_mean, DiagLazyTensor(proj_var))
        mvn = MultitaskMultivariateNormal.from_batch_mvn(mvn)
    return mvn
def forward(self, x):
    """Forward pass method for making predictions through the model.

    The mean and covariance are each computed at ``x`` and combined into a
    multitask multivariate normal distribution.

    Parameters:
        x (torch.tensor): The tensor at which the BatchedGP model predicts a
            mean and covariance.

    Returns:
        mv_normal (gpytorch.distributions.MultitaskMultivariateNormal): A
            multitask multivariate normal distribution, obtained by passing a
            ``MultivariateNormal`` with the mean and covariance computed at
            ``x`` through ``MultitaskMultivariateNormal.from_batch_mvn``.
    """
    mean_x = self.mean_module(x)  # Compute the mean at x
    covar_x = self.covar_module(x)  # Compute the covariance at x
    return MultitaskMultivariateNormal.from_batch_mvn(
        MultivariateNormal(mean_x, covar_x))
def __call__(self, X, Y=None, likelihood=None, prior=False):
    """Run the variational strategy on ``X`` and normalise the output type.

    A ``CollapsedStrategy`` additionally receives ``Y`` and ``likelihood``;
    any other strategy is called with ``X`` and ``prior`` only.

    Returns:
        A ``MultivariateNormal`` over ``len(X)`` points when
        ``self.output_dims == 1``. Otherwise, the strategy's batch output is
        converted with ``MultitaskMultivariateNormal.from_batch_mvn``.
    """
    strategy = self.variational_strategy
    if isinstance(strategy, CollapsedStrategy):
        output = strategy(X, Y, likelihood=likelihood, prior=prior)
    else:
        output = strategy(X, prior=prior)

    num_points, num_outputs = len(X), self.output_dims
    if num_outputs != 1:
        # Leading dimension of the mean is expected to index the outputs.
        assert output.mean.size(0) == num_outputs
        return MultitaskMultivariateNormal.from_batch_mvn(output)

    # Single output: flatten the mean and reuse the (N x N) covariance as-is.
    flat_mean = output.mean.reshape(num_points)
    assert output._covar.shape == (num_points, num_points)
    return MultivariateNormal(flat_mean, output._covar)
def _create_marginal_input(self, batch_shape=torch.Size([])):
    """Build a random ``MultitaskMultivariateNormal`` for marginal tests.

    Six random 5 x 5 factors ``A`` (with optional leading ``batch_shape``)
    give covariances ``A @ A^T``. These are paired with random means of shape
    ``(*batch_shape, 6, 5)`` and converted via ``from_batch_mvn``.
    """
    factor = torch.randn(*batch_shape, 6, 5, 5)
    loc = torch.randn(*batch_shape, 6, 5)
    covariance = factor @ factor.transpose(-1, -2)
    batch_mvn = MultivariateNormal(loc, covariance)
    return MultitaskMultivariateNormal.from_batch_mvn(batch_mvn)
def forward(self, x):
    """Evaluate the model at ``x`` and return a multitask MVN.

    The mean and covariance modules are applied to ``x``. The resulting
    ``MultivariateNormal`` is converted with
    ``MultitaskMultivariateNormal.from_batch_mvn``.
    """
    prior_mean = self.mean_module(x)
    prior_covar = self.covar_module(x)
    batch_mvn = MultivariateNormal(prior_mean, prior_covar)
    return MultitaskMultivariateNormal.from_batch_mvn(batch_mvn)
def forward(self, input):
    """Evaluate the model at ``input`` and return a multitask MVN.

    The mean module is evaluated before the covariance module. Their outputs
    form a ``MultivariateNormal``, which is then converted with
    ``MultitaskMultivariateNormal.from_batch_mvn``.
    """
    batch_mvn = MultivariateNormal(self.mean_module(input), self.covar_module(input))
    return MultitaskMultivariateNormal.from_batch_mvn(batch_mvn)