def test_multivariate_normal_batch_correlated_sampels(self, cuda=False):
        """Check sample shapes of a batched MultitaskMultivariateNormal driven by base samples.

        A batch of two identical distributions is built (the 2x2 mean and the
        diagonal covariance with variances 1, 2, 3, 4 are both repeated twice).
        This is done for both float and double, matching the other
        correlated-sample tests. Sampling with explicit base samples must yield:

        * shape (3, 4, 2, 2, 2) when the base samples carry a (3, 4) sample shape;
        * shape (2, 2, 2) when the default base samples are used.

        Args:
            cuda: run on the CUDA device instead of the CPU.
        """
        device = torch.device("cuda") if cuda else torch.device("cpu")
        # Cover both precisions, as the sibling correlated-sample tests do.
        for dtype in (torch.float, torch.double):
            mean = torch.tensor([[0, 1], [2, 3]], dtype=dtype, device=device).repeat(2, 1, 1)
            variance = 1 + torch.arange(4, dtype=dtype, device=device)
            covmat = torch.diag(variance).repeat(2, 1, 1)
            mtmvn = MultitaskMultivariateNormal(mean=mean, covariance_matrix=covmat)

            # An explicit (3, 4) sample shape is prepended to the (2, 2, 2) single-draw shape.
            base_samples = mtmvn.get_base_samples(torch.Size((3, 4)))
            self.assertEqual(mtmvn.sample(base_samples=base_samples).shape, torch.Size([3, 4, 2, 2, 2]))

            # Default base samples produce a single draw.
            base_samples = mtmvn.get_base_samples()
            self.assertEqual(mtmvn.sample(base_samples=base_samples).shape, torch.Size([2, 2, 2]))
Example #2
0
 def test_multitask_multivariate_normal_batch(self, cuda=False):
     """Check moments, arithmetic, densities and sampling of a batched MTMVN.

     The distribution is a batch of two identical copies of a 2x2 mean
     ``[[0, 1], [2, 3]]`` with a diagonal covariance whose variances are
     1, 2, 3, 4 (both tensors are repeated twice along a new leading dim).

     Args:
         cuda: run on the CUDA device instead of the CPU.
     """
     device = torch.device("cuda") if cuda else torch.device("cpu")
     mean = torch.tensor([[0, 1], [2, 3]], dtype=torch.float,
                         device=device).repeat(2, 1, 1)
     variance = 1 + torch.arange(4, dtype=torch.float, device=device)
     covmat = torch.diag(variance).repeat(2, 1, 1)
     mtmvn = MultitaskMultivariateNormal(mean=mean,
                                         covariance_matrix=covmat)
     self.assertTrue(torch.equal(mtmvn.mean, mean))
     self.assertTrue(
         approx_equal(mtmvn.variance,
                      variance.repeat(2, 1).view(2, 2, 2)))
     # scale_tril is obtained by factorizing the covariance, so compare it
     # with a tolerance instead of bitwise (the non-batched test does the same).
     # For a diagonal covariance the factor equals the elementwise sqrt.
     self.assertTrue(torch.allclose(mtmvn.scale_tril, covmat.sqrt()))

     # Affine arithmetic: shifting moves only the mean; scaling by c scales
     # the mean by c and the covariance by c**2.
     mvn_plus1 = mtmvn + 1
     self.assertTrue(torch.equal(mvn_plus1.mean, mtmvn.mean + 1))
     self.assertTrue(
         torch.equal(mvn_plus1.covariance_matrix, mtmvn.covariance_matrix))
     mvn_times2 = mtmvn * 2
     self.assertTrue(torch.equal(mvn_times2.mean, mtmvn.mean * 2))
     self.assertTrue(
         torch.equal(mvn_times2.covariance_matrix,
                     mtmvn.covariance_matrix * 4))
     mvn_divby2 = mtmvn / 2
     self.assertTrue(torch.equal(mvn_divby2.mean, mtmvn.mean / 2))
     self.assertTrue(
         torch.equal(mvn_divby2.covariance_matrix,
                     mtmvn.covariance_matrix / 4))

     # Entropy of a 4-d Gaussian with variances 1..4:
     # 0.5 * (4 * (1 + log(2*pi)) + log(1*2*3*4)) ~= 7.2648, one value per batch.
     self.assertTrue(
         approx_equal(mtmvn.entropy(),
                      7.2648 * torch.ones(2, device=device)))
     # log-density at the origin:
     # -0.5 * (4*log(2*pi) + log(24) + sum(mean**2 / variance)) ~= -7.3064.
     logprob = mtmvn.log_prob(torch.zeros(2, 2, 2, device=device))
     logprob_expected = -7.3064 * torch.ones(2, device=device)
     self.assertTrue(approx_equal(logprob, logprob_expected))
     # An extra leading dim of 3 on the input yields a (3, 2) result.
     logprob = mtmvn.log_prob(torch.zeros(3, 2, 2, 2, device=device))
     logprob_expected = -7.3064 * torch.ones(3, 2, device=device)
     self.assertTrue(approx_equal(logprob, logprob_expected))

     conf_lower, conf_upper = mtmvn.confidence_region()
     self.assertTrue(approx_equal(conf_lower,
                                  mtmvn.mean - 2 * mtmvn.stddev))
     self.assertTrue(approx_equal(conf_upper,
                                  mtmvn.mean + 2 * mtmvn.stddev))

     # Requested sample shapes are prepended to the (2, 2, 2) draw shape.
     self.assertEqual(mtmvn.sample().shape, torch.Size([2, 2, 2]))
     self.assertEqual(mtmvn.sample(torch.Size([3])).shape,
                      torch.Size([3, 2, 2, 2]))
     self.assertEqual(mtmvn.sample(torch.Size([3, 4])).shape,
                      torch.Size([3, 4, 2, 2, 2]))
Example #3
0
 def test_multivariate_normal_correlated_sampels(self, cuda=False):
     """Check sample shapes of a non-batched MTMVN driven by base samples.

     For float and double, a 2x2 mean ``[[0, 1], [2, 3]]`` with a diagonal
     covariance built from the variances 1, 2, 3, 4 is sampled twice. The
     first draw uses base samples with a (3, 4) sample shape, and the second
     uses the default base samples.

     Args:
         cuda: run on the CUDA device instead of the CPU.
     """
     device = torch.device("cuda" if cuda else "cpu")
     for dt in (torch.float, torch.double):
         opts = dict(dtype=dt, device=device)
         loc = torch.tensor([[0, 1], [2, 3]], **opts)
         var = torch.tensor([[1, 2], [3, 4]], **opts)
         dist = MultitaskMultivariateNormal(
             mean=loc, covariance_matrix=torch.diag(var.view(-1)))

         batched_noise = dist.get_base_samples(torch.Size([3, 4]))
         batched_draw = dist.sample(base_samples=batched_noise)
         self.assertTrue(batched_draw.shape == torch.Size([3, 4, 2, 2]))

         single_noise = dist.get_base_samples()
         single_draw = dist.sample(base_samples=single_noise)
         self.assertTrue(single_draw.shape == torch.Size([2, 2]))
 def test_multivariate_normal_batch_correlated_samples(self, cuda=False):
     """Check sample shapes of a batched MTMVN driven by base samples.

     For float and double, a 3x2 mean and 3x2 variances (1..6) are each
     repeated into a batch of two. A 6x6 diagonal covariance is formed per
     batch entry. Sampling with base samples of sample shape (3, 4) and with
     the default base samples must give (3, 4, 2, 3, 2) and (2, 3, 2).

     Args:
         cuda: run on the CUDA device instead of the CPU.
     """
     device = torch.device("cuda" if cuda else "cpu")
     for dt in (torch.float, torch.double):
         opts = dict(dtype=dt, device=device)
         loc = torch.tensor([[0, 1], [2, 3], [4, 5]], **opts).repeat(2, 1, 1)
         var = torch.tensor([[1, 2], [3, 4], [5, 6]], **opts).repeat(2, 1, 1)
         # Broadcasting each flattened row of variances (shape (2, 1, 6))
         # against the 6x6 identity yields a batch of diagonal matrices.
         eye = torch.eye(6, device=device, dtype=dt)
         cov = var.view(2, 1, -1) * eye
         dist = MultitaskMultivariateNormal(mean=loc, covariance_matrix=cov)

         batched_noise = dist.get_base_samples(torch.Size((3, 4)))
         batched_draw = dist.sample(base_samples=batched_noise)
         self.assertTrue(batched_draw.shape == torch.Size([3, 4, 2, 3, 2]))

         single_noise = dist.get_base_samples()
         single_draw = dist.sample(base_samples=single_noise)
         self.assertTrue(single_draw.shape == torch.Size([2, 3, 2]))
    def test_multitask_multivariate_normal(self, cuda=False):
        """Check a non-batched MultitaskMultivariateNormal with a 3x2 mean.

        For float and double, the distribution is built from a diagonal
        covariance in two layouts:

        * interleaved (the default), where the covariance diagonal follows
          ``variance`` flattened row by row;
        * non-interleaved (``interleaved=False``), where the diagonal follows
          ``variance`` flattened column by column.

        In both layouts ``mean``/``variance`` must round-trip unchanged.

        Args:
            cuda: run on the CUDA device instead of the CPU.
        """
        device = torch.device("cuda") if cuda else torch.device("cpu")
        for dtype in (torch.float, torch.double):
            mean = torch.tensor([[0, 1], [2, 3], [4, 5]],
                                dtype=dtype,
                                device=device)
            variance = torch.tensor([[1, 2], [3, 4], [5, 6]],
                                    dtype=dtype,
                                    device=device)

            # interleaved
            # Row-major flattening gives the diagonal [1, 2, 3, 4, 5, 6].
            covmat = variance.view(-1).diag()
            mtmvn = MultitaskMultivariateNormal(mean=mean,
                                                covariance_matrix=covmat)
            self.assertTrue(torch.equal(mtmvn.mean, mean))
            self.assertTrue(torch.allclose(mtmvn.variance, variance))
            self.assertTrue(torch.allclose(mtmvn.scale_tril, covmat.sqrt()))
            self.assertTrue(mtmvn.event_shape == torch.Size([3, 2]))
            self.assertTrue(mtmvn.batch_shape == torch.Size())
            # Affine arithmetic: shifting changes only the mean; scaling by c
            # scales the mean by c and the covariance by c**2.
            mvn_plus1 = mtmvn + 1
            self.assertTrue(torch.equal(mvn_plus1.mean, mtmvn.mean + 1))
            self.assertTrue(
                torch.equal(mvn_plus1.covariance_matrix,
                            mtmvn.covariance_matrix))
            mvn_times2 = mtmvn * 2
            self.assertTrue(torch.equal(mvn_times2.mean, mtmvn.mean * 2))
            self.assertTrue(
                torch.equal(mvn_times2.covariance_matrix,
                            mtmvn.covariance_matrix * 4))
            mvn_divby2 = mtmvn / 2
            self.assertTrue(torch.equal(mvn_divby2.mean, mtmvn.mean / 2))
            self.assertTrue(
                torch.equal(mvn_divby2.covariance_matrix,
                            mtmvn.covariance_matrix / 4))
            # Entropy of a 6-d Gaussian with variances 1..6:
            # 0.5 * (6 * (1 + log(2*pi)) + log(6!)) ~= 11.80326.
            self.assertAlmostEqual(mtmvn.entropy().item(), 11.80326, places=4)
            # log-density at the origin:
            # -0.5 * (6*log(2*pi) + log(6!) + sum(mean**2 / variance)) ~= -14.52826.
            self.assertAlmostEqual(mtmvn.log_prob(
                torch.zeros(3, 2, device=device, dtype=dtype)).item(),
                                   -14.52826,
                                   places=4)
            # An extra leading dim of 2 on the input yields one value per slice.
            logprob = mtmvn.log_prob(
                torch.zeros(2, 3, 2, device=device, dtype=dtype))
            logprob_expected = -14.52826 * torch.ones(
                2, device=device, dtype=dtype)
            self.assertTrue(torch.allclose(logprob, logprob_expected))
            # The confidence region is mean +/- 2 standard deviations.
            conf_lower, conf_upper = mtmvn.confidence_region()
            self.assertTrue(
                torch.allclose(conf_lower, mtmvn.mean - 2 * mtmvn.stddev))
            self.assertTrue(
                torch.allclose(conf_upper, mtmvn.mean + 2 * mtmvn.stddev))
            # Requested sample shapes are prepended to the (3, 2) draw shape.
            self.assertTrue(mtmvn.sample().shape == torch.Size([3, 2]))
            self.assertTrue(
                mtmvn.sample(torch.Size([3])).shape == torch.Size([3, 3, 2]))
            self.assertTrue(
                mtmvn.sample(torch.Size([3, 4])).shape == torch.Size(
                    [3, 4, 3, 2]))

            # non-interleaved
            # Column-major flattening gives the diagonal [1, 3, 5, 2, 4, 6];
            # with interleaved=False the mean/variance must still round-trip.
            covmat = variance.transpose(-1, -2).reshape(-1).diag()
            mtmvn = MultitaskMultivariateNormal(mean=mean,
                                                covariance_matrix=covmat,
                                                interleaved=False)
            self.assertTrue(torch.equal(mtmvn.mean, mean))
            self.assertTrue(torch.allclose(mtmvn.variance, variance))
            self.assertTrue(torch.allclose(mtmvn.scale_tril, covmat.sqrt()))
            self.assertTrue(mtmvn.event_shape == torch.Size([3, 2]))
            self.assertTrue(mtmvn.batch_shape == torch.Size())