Example #1
    def __call__(self, function, *params, **kwargs):
        if isinstance(function, Distribution) and not isinstance(
                function, MultitaskMultivariateNormal):
            warnings.warn(
                "The input to DeepGaussianLikelihood should be a MultitaskMultivariateNormal (num_data x num_tasks). "
                "Batch MultivariateNormal inputs (num_tasks x num_data) will be deprecated.",
                DeprecationWarning,
            )
            function = MultitaskMultivariateNormal.from_batch_mvn(function)
        return super().__call__(function, *params, **kwargs)
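
The deprecation warning above rests on a shape convention: a batch MultivariateNormal lays the outputs out as num_tasks x num_data, while a MultitaskMultivariateNormal uses num_data x num_tasks. Below is a minimal standalone sketch (not part of the original project, assuming the same GPyTorch version as these snippets, which still exposes gpytorch.lazy.DiagLazyTensor) that makes the conversion and the resulting shapes explicit:

import torch
from gpytorch.distributions import MultivariateNormal, MultitaskMultivariateNormal
from gpytorch.lazy import DiagLazyTensor

num_tasks, num_data = 4, 10

# A batch MVN: one independent Gaussian per task, laid out as num_tasks x num_data.
batch_mvn = MultivariateNormal(
    torch.randn(num_tasks, num_data),
    DiagLazyTensor(torch.rand(num_tasks, num_data) + 1e-6),
)
print(batch_mvn.batch_shape, batch_mvn.event_shape)  # torch.Size([4]) torch.Size([10])

# Reinterpret the batch dimension as the task dimension: num_data x num_tasks.
mmvn = MultitaskMultivariateNormal.from_batch_mvn(batch_mvn, task_dim=0)
print(mmvn.event_shape)  # torch.Size([10, 4])
print(mmvn.mean.shape)   # torch.Size([10, 4])
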
Example #2
    def test_multitask_from_batch(self):
        # Non-batch case: the batch dimension of size 2 becomes the task dimension.
        mean = torch.randn(2, 3)
        variance = torch.randn(2, 3).clamp_min(1e-6)
        mvn = MultivariateNormal(mean, DiagLazyTensor(variance))
        mmvn = MultitaskMultivariateNormal.from_batch_mvn(mvn, task_dim=-1)
        self.assertTrue(isinstance(mmvn, MultitaskMultivariateNormal))
        self.assertEqual(mmvn.batch_shape, torch.Size([]))
        self.assertEqual(mmvn.event_shape, torch.Size([3, 2]))
        self.assertEqual(mmvn.covariance_matrix.shape, torch.Size([6, 6]))
        self.assertTrue(torch.equal(mmvn.mean, mean.transpose(-1, -2)))
        self.assertTrue(torch.equal(mmvn.variance, variance.transpose(-1, -2)))

        # Batched case: task_dim=0 picks the first batch dimension as the tasks.
        mean = torch.randn(2, 4, 3)
        variance = torch.randn(2, 4, 3).clamp_min(1e-6)
        mvn = MultivariateNormal(mean, DiagLazyTensor(variance))
        mmvn = MultitaskMultivariateNormal.from_batch_mvn(mvn, task_dim=0)
        self.assertTrue(isinstance(mmvn, MultitaskMultivariateNormal))
        self.assertEqual(mmvn.batch_shape, torch.Size([4]))
        self.assertEqual(mmvn.event_shape, torch.Size([3, 2]))
        self.assertEqual(mmvn.covariance_matrix.shape, torch.Size([4, 6, 6]))
        self.assertTrue(torch.equal(mmvn.mean, mean.permute(1, 2, 0)))
        self.assertTrue(torch.equal(mmvn.variance, variance.permute(1, 2, 0)))
Example #3
    def forward(self, X, **kwargs):
        # We require the provided X to be a set of indices that maps to our
        # latent input.
        assert X.dim() == 1 and torch.all(X == X.to(torch.int64))
        X = self.latent_layer(indices=X.to(torch.int64))
        mvn = super().forward(X)

        # Optionally project the outputs through L.
        # TODO: projecting the full covariance has not been implemented
        # (only the mean and marginal stddev are projected).
        if self.L is not None:
            proj_mean = (mvn.mean @ self.L).T
            proj_std = (mvn.stddev @ self.L).T
            mvn = MultivariateNormal(proj_mean, DiagLazyTensor(proj_std**2))
            mvn = MultitaskMultivariateNormal.from_batch_mvn(mvn)

        return mvn
Example #4
    def forward(self, x):
        """Forward pass for making predictions through the model. The mean and
        covariance are computed for each batch element and combined into a
        single multitask distribution.

        Parameters:
            x (torch.Tensor): The tensor for which we predict a mean and
                covariance using the BatchedGP model.

        Returns:
            mv_normal (gpytorch.distributions.MultitaskMultivariateNormal): A
                multitask multivariate normal distribution with mean and
                covariance evaluated at x.
        """
        mean_x = self.mean_module(x)  # Compute the mean at x
        covar_x = self.covar_module(x)  # Compute the covariance at x
        return MultitaskMultivariateNormal.from_batch_mvn(
            MultivariateNormal(mean_x, covar_x))
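
The forward above (and the nearly identical ones in the later examples) follows the batch-independent multi-output pattern: mean and kernel modules built with a batch_shape produce a batch MVN whose leading dimension indexes the tasks, and from_batch_mvn collapses it into one multitask distribution. Below is a minimal, self-contained sketch of such a model; the class name BatchIndependentGP, the toy data, and the untrained eval call are illustrative assumptions, not code from any of the projects above:

import torch
import gpytorch
from gpytorch.distributions import MultivariateNormal, MultitaskMultivariateNormal


class BatchIndependentGP(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood, num_tasks):
        super().__init__(train_x, train_y, likelihood)
        batch_shape = torch.Size([num_tasks])
        # One independent mean/kernel per task via the batch_shape argument.
        self.mean_module = gpytorch.means.ConstantMean(batch_shape=batch_shape)
        self.covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.RBFKernel(batch_shape=batch_shape),
            batch_shape=batch_shape,
        )

    def forward(self, x):
        mean_x = self.mean_module(x)    # num_tasks x N
        covar_x = self.covar_module(x)  # num_tasks x N x N
        return MultitaskMultivariateNormal.from_batch_mvn(
            MultivariateNormal(mean_x, covar_x))


# Toy usage: train_y is N x num_tasks, matching the multitask likelihood.
train_x = torch.linspace(0, 1, 20)
train_y = torch.stack([torch.sin(6 * train_x), torch.cos(6 * train_x)], dim=-1)
likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=2)
model = BatchIndependentGP(train_x, train_y, likelihood, num_tasks=2)

model.eval()
likelihood.eval()
with torch.no_grad():
    pred = likelihood(model(train_x))
    print(pred.mean.shape)  # torch.Size([20, 2])
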
Example #5
    def __call__(self, X, Y=None, likelihood=None, prior=False):
        strat = self.variational_strategy
        if isinstance(strat, CollapsedStrategy):
            res = strat(X, Y, likelihood=likelihood, prior=prior)
        else:
            res = strat(X, prior=prior)

        # Single-output models return a plain MultivariateNormal over the N inputs;
        # multi-output models reinterpret the leading batch dimension (size D) as
        # the task dimension of a MultitaskMultivariateNormal.
        N, D = len(X), self.output_dims
        if D == 1:
            mean = res.mean.reshape(N)
            assert res._covar.shape == (N, N)
            res = MultivariateNormal(mean, res._covar)
        else:
            assert res.mean.size(0) == D
            res = MultitaskMultivariateNormal.from_batch_mvn(res)

        return res
Example #6
    def _create_marginal_input(self, batch_shape=torch.Size([])):
        # Random batch MVN (a batch of 6 Gaussians over 5 points) converted
        # into a 6-task multitask distribution.
        mat = torch.randn(*batch_shape, 6, 5, 5)
        return MultitaskMultivariateNormal.from_batch_mvn(
            MultivariateNormal(torch.randn(*batch_shape, 6, 5),
                               mat @ mat.transpose(-1, -2)))
Example #7
    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return MultitaskMultivariateNormal.from_batch_mvn(
            MultivariateNormal(mean_x, covar_x))
Example #8
File: gp_list_torch.py Project: pnickl/reg
    def forward(self, input):
        mean = self.mean_module(input)
        covar = self.covar_module(input)
        return MultitaskMultivariateNormal.from_batch_mvn(MultivariateNormal(mean, covar))