def test_multivariate_normal_batch_correlated_sampels(self, cuda=False):
        """Check sample shapes of a batched ``MultitaskMultivariateNormal``.

        Builds a distribution from a ``(2, 2, 2)`` mean and a ``(2, 4, 4)``
        diagonal covariance, then verifies the shape of
        ``sample(base_samples=...)`` when the base samples are drawn with an
        explicit ``(3, 4)`` sample shape and when they are drawn with no
        sample shape at all.

        Args:
            cuda: When true, tensors are created on the ``"cuda"`` device;
                otherwise on ``"cpu"``.
        """
        # NOTE: the spelling "sampels" in the method name is kept on purpose;
        # renaming the test would change its externally visible identifier.
        device = torch.device("cuda" if cuda else "cpu")

        # (2, 2) mean repeated along a new leading dimension of size 2.
        mean = torch.tensor([[0, 1], [2, 3]], dtype=torch.float, device=device).repeat(2, 1, 1)
        # Variances 1..4 placed on the diagonal, repeated to match the mean's
        # leading dimension: covmat has shape (2, 4, 4).
        variance = 1 + torch.arange(4, dtype=torch.float, device=device)
        covmat = torch.diag(variance).repeat(2, 1, 1)
        mtmvn = MultitaskMultivariateNormal(mean=mean, covariance_matrix=covmat)

        # Explicit sample shape: it is expected to be prepended to the
        # (2, 2, 2) shape of the mean.
        base_samples = mtmvn.get_base_samples(torch.Size((3, 4)))
        samples = mtmvn.sample(base_samples=base_samples)
        # assertEqual (rather than assertTrue on ``==``) reports both shapes
        # when the check fails.
        self.assertEqual(samples.shape, torch.Size([3, 4, 2, 2, 2]))

        # No sample shape: the result is expected to have the mean's shape.
        base_samples = mtmvn.get_base_samples()
        samples = mtmvn.sample(base_samples=base_samples)
        self.assertEqual(samples.shape, torch.Size([2, 2, 2]))
예제 #2
0
 def test_multivariate_normal_correlated_sampels(self, cuda=False):
     """Check sample shapes of a non-batched ``MultitaskMultivariateNormal``.

     For both ``torch.float`` and ``torch.double``, builds a distribution
     from a ``(2, 2)`` mean and a ``(4, 4)`` diagonal covariance, then
     verifies the shape of ``sample(base_samples=...)`` when the base
     samples are drawn with an explicit ``(3, 4)`` sample shape and when
     they are drawn with no sample shape at all.

     Args:
         cuda: When true, tensors are created on the ``"cuda"`` device;
             otherwise on ``"cpu"``.
     """
     # NOTE: the spelling "sampels" in the method name is kept on purpose;
     # renaming the test would change its externally visible identifier.
     device = torch.device("cuda" if cuda else "cpu")
     for dtype in (torch.float, torch.double):
         mean = torch.tensor([[0, 1], [2, 3]], dtype=dtype, device=device)
         variance = torch.tensor([[1, 2], [3, 4]], dtype=dtype, device=device)
         # Flatten the four variances onto the diagonal of a (4, 4) matrix.
         covmat = variance.view(-1).diag()
         mtmvn = MultitaskMultivariateNormal(mean=mean, covariance_matrix=covmat)

         # Explicit sample shape: it is expected to be prepended to the
         # (2, 2) shape of the mean.
         base_samples = mtmvn.get_base_samples(torch.Size([3, 4]))
         samples = mtmvn.sample(base_samples=base_samples)
         # assertEqual (rather than assertTrue on ``==``) reports both
         # shapes when the check fails.
         self.assertEqual(samples.shape, torch.Size([3, 4, 2, 2]))

         # No sample shape: the result is expected to have the mean's shape.
         base_samples = mtmvn.get_base_samples()
         samples = mtmvn.sample(base_samples=base_samples)
         self.assertEqual(samples.shape, torch.Size([2, 2]))
 def test_multivariate_normal_batch_correlated_samples(self, cuda=False):
     """Check sample shapes of a batched ``MultitaskMultivariateNormal``.

     For both ``torch.float`` and ``torch.double``, builds a distribution
     from a ``(2, 3, 2)`` mean and a ``(2, 6, 6)`` diagonal covariance, then
     verifies the shape of ``sample(base_samples=...)`` when the base
     samples are drawn with an explicit ``(3, 4)`` sample shape and when
     they are drawn with no sample shape at all.

     Args:
         cuda: When true, tensors are created on the ``"cuda"`` device;
             otherwise on ``"cpu"``.
     """
     device = torch.device("cuda" if cuda else "cpu")
     for dtype in (torch.float, torch.double):
         # (3, 2) mean and variances, each repeated along a new leading
         # dimension of size 2.
         mean = torch.tensor(
             [[0, 1], [2, 3], [4, 5]], dtype=dtype, device=device
         ).repeat(2, 1, 1)
         variance = torch.tensor(
             [[1, 2], [3, 4], [5, 6]], dtype=dtype, device=device
         ).repeat(2, 1, 1)
         # Broadcasting the (2, 1, 6) variances against a (6, 6) identity
         # yields a (2, 6, 6) stack of diagonal matrices.
         identity = torch.eye(6, device=device, dtype=dtype)
         covmat = variance.view(2, 1, -1) * identity
         mtmvn = MultitaskMultivariateNormal(mean=mean, covariance_matrix=covmat)

         # Explicit sample shape: it is expected to be prepended to the
         # (2, 3, 2) shape of the mean.
         base_samples = mtmvn.get_base_samples(torch.Size((3, 4)))
         samples = mtmvn.sample(base_samples=base_samples)
         # assertEqual (rather than assertTrue on ``==``) reports both
         # shapes when the check fails.
         self.assertEqual(samples.shape, torch.Size([3, 4, 2, 3, 2]))

         # No sample shape: the result is expected to have the mean's shape.
         base_samples = mtmvn.get_base_samples()
         samples = mtmvn.sample(base_samples=base_samples)
         self.assertEqual(samples.shape, torch.Size([2, 3, 2]))