Exemple #1
0
    def test_multivariate_normal_batch_correlated_samples(self, cuda=False):
        """Check sample shapes of a batched MultivariateNormal fed explicit base samples.

        Builds a batch of two identical 3-dimensional distributions with a
        diagonal covariance matrix and asserts that
        ``sample(base_samples=...)`` has shape ``sample_shape + (2, 3)``, both
        for an explicit sample shape of ``(3, 4)`` and for the default one.

        Args:
            cuda: If True, create the tensors on the CUDA device instead of
                the CPU.
        """
        device = torch.device("cuda") if cuda else torch.device("cpu")
        # Exercise both floating-point precisions.
        for dtype in (torch.float, torch.double):
            mean = torch.tensor([0, 1, 2], dtype=dtype, device=device)
            covmat = torch.diag(torch.tensor([1, 0.75, 1.5], dtype=dtype, device=device))
            mvn = MultivariateNormal(
                mean=mean.repeat(2, 1),
                covariance_matrix=NonLazyTensor(covmat).repeat(2, 1, 1),
            )

            # assertEqual (rather than assertTrue(a == b)) reports both shapes
            # when the check fails.
            base_samples = mvn.get_base_samples(torch.Size((3, 4)))
            self.assertEqual(
                mvn.sample(base_samples=base_samples).shape,
                torch.Size([3, 4, 2, 3]),
            )

            base_samples = mvn.get_base_samples()
            self.assertEqual(
                mvn.sample(base_samples=base_samples).shape,
                torch.Size([2, 3]),
            )
Exemple #2
0
 def test_multivariate_normal_correlated_samples(self, cuda=False):
     """Check sample shapes of a MultivariateNormal fed explicit base samples.

     For both ``torch.float`` and ``torch.double``, builds a 3-dimensional
     distribution with a diagonal covariance matrix and asserts that
     ``sample(base_samples=...)`` has shape ``sample_shape + (3,)``, both
     for an explicit sample shape of ``(3, 4)`` and for the default one.

     Args:
         cuda: If True, create the tensors on the CUDA device instead of
             the CPU.
     """
     device = torch.device("cuda") if cuda else torch.device("cpu")
     for dtype in (torch.float, torch.double):
         mean = torch.tensor([0, 1, 2], device=device, dtype=dtype)
         covmat = torch.diag(
             torch.tensor([1, 0.75, 1.5], device=device, dtype=dtype))
         mvn = MultivariateNormal(mean=mean,
                                  covariance_matrix=NonLazyTensor(covmat))

         # assertEqual (rather than assertTrue(a == b)) reports both shapes
         # when the check fails.
         base_samples = mvn.get_base_samples(torch.Size([3, 4]))
         self.assertEqual(
             mvn.sample(base_samples=base_samples).shape,
             torch.Size([3, 4, 3]))

         base_samples = mvn.get_base_samples()
         self.assertEqual(
             mvn.sample(base_samples=base_samples).shape,
             torch.Size([3]))