def test_multivariate_normal_batch_correlated_sampels(self, cuda=False):
    """Check sampling from explicit base samples for a batched MultitaskMultivariateNormal.

    A batch of two identical distributions with a 2x2 mean and a diagonal 4x4
    covariance is built. ``sample(base_samples=...)`` must return shape
    ``sample_shape + (2, 2, 2)``, both for an explicit ``(3, 4)`` sample shape
    and for the default one. Runs for float32 and float64, like the other
    correlated-samples tests.

    NOTE: the "sampels" typo is intentional to keep. Renaming this method to
    ``test_multivariate_normal_batch_correlated_samples`` would collide with an
    existing test of that name, and one of the two would silently shadow the other.
    """
    device = torch.device("cuda") if cuda else torch.device("cpu")
    for dtype in (torch.float, torch.double):
        # Batch of 2 copies of the same 2x2 mean / diagonal covariance.
        mean = torch.tensor([[0, 1], [2, 3]], dtype=dtype, device=device).repeat(2, 1, 1)
        variance = 1 + torch.arange(4, dtype=dtype, device=device)
        covmat = torch.diag(variance).repeat(2, 1, 1)
        mtmvn = MultitaskMultivariateNormal(mean=mean, covariance_matrix=covmat)

        # Explicit sample shape is prepended to batch + event shape.
        base_samples = mtmvn.get_base_samples(torch.Size((3, 4)))
        self.assertEqual(
            mtmvn.sample(base_samples=base_samples).shape, torch.Size([3, 4, 2, 2, 2])
        )

        # Default base samples yield a single draw of batch + event shape.
        base_samples = mtmvn.get_base_samples()
        self.assertEqual(mtmvn.sample(base_samples=base_samples).shape, torch.Size([2, 2, 2]))
def test_multitask_multivariate_normal_batch(self, cuda=False):
    """Exercise a batched MultitaskMultivariateNormal end to end.

    A batch of two identical distributions (2x2 mean, diagonal covariance with
    entries 1..4) is checked for:
    - moments and shapes;
    - affine arithmetic (``+ 1``, ``* 2``, ``/ 2``);
    - entropy and log_prob against precomputed constants;
    - confidence regions;
    - sample shapes.

    Runs for float32 and float64, like the non-batch test.
    """
    device = torch.device("cuda") if cuda else torch.device("cpu")
    for dtype in (torch.float, torch.double):
        mean = torch.tensor([[0, 1], [2, 3]], dtype=dtype, device=device).repeat(2, 1, 1)
        variance = 1 + torch.arange(4, dtype=dtype, device=device)
        covmat = torch.diag(variance).repeat(2, 1, 1)
        mtmvn = MultitaskMultivariateNormal(mean=mean, covariance_matrix=covmat)

        # Moments and shapes.
        self.assertTrue(torch.equal(mtmvn.mean, mean))
        self.assertTrue(approx_equal(mtmvn.variance, variance.repeat(2, 1).view(2, 2, 2)))
        # scale_tril comes out of a floating-point factorization (presumably
        # Cholesky), so compare with a tolerance instead of bit-for-bit.
        self.assertTrue(torch.allclose(mtmvn.scale_tril, covmat.sqrt()))
        self.assertEqual(mtmvn.event_shape, torch.Size([2, 2]))
        self.assertEqual(mtmvn.batch_shape, torch.Size([2]))

        # Affine arithmetic: shifting leaves covariance unchanged, scaling by c
        # scales covariance by c**2.
        mvn_plus1 = mtmvn + 1
        self.assertTrue(torch.equal(mvn_plus1.mean, mtmvn.mean + 1))
        self.assertTrue(torch.equal(mvn_plus1.covariance_matrix, mtmvn.covariance_matrix))
        mvn_times2 = mtmvn * 2
        self.assertTrue(torch.equal(mvn_times2.mean, mtmvn.mean * 2))
        self.assertTrue(torch.equal(mvn_times2.covariance_matrix, mtmvn.covariance_matrix * 4))
        mvn_divby2 = mtmvn / 2
        self.assertTrue(torch.equal(mvn_divby2.mean, mtmvn.mean / 2))
        self.assertTrue(torch.equal(mvn_divby2.covariance_matrix, mtmvn.covariance_matrix / 4))

        # Entropy / log-density: one value per batch element.
        self.assertTrue(
            approx_equal(mtmvn.entropy(), 7.2648 * torch.ones(2, device=device, dtype=dtype))
        )
        logprob = mtmvn.log_prob(torch.zeros(2, 2, 2, device=device, dtype=dtype))
        logprob_expected = -7.3064 * torch.ones(2, device=device, dtype=dtype)
        self.assertTrue(approx_equal(logprob, logprob_expected))
        # An extra leading sample dimension is carried through to the result.
        logprob = mtmvn.log_prob(torch.zeros(3, 2, 2, 2, device=device, dtype=dtype))
        logprob_expected = -7.3064 * torch.ones(3, 2, device=device, dtype=dtype)
        self.assertTrue(approx_equal(logprob, logprob_expected))

        # Confidence region is mean +/- 2 standard deviations.
        conf_lower, conf_upper = mtmvn.confidence_region()
        self.assertTrue(approx_equal(conf_lower, mtmvn.mean - 2 * mtmvn.stddev))
        self.assertTrue(approx_equal(conf_upper, mtmvn.mean + 2 * mtmvn.stddev))

        # Sample shapes: sample_shape + batch_shape + event_shape.
        self.assertEqual(mtmvn.sample().shape, torch.Size([2, 2, 2]))
        self.assertEqual(mtmvn.sample(torch.Size([3])).shape, torch.Size([3, 2, 2, 2]))
        self.assertEqual(mtmvn.sample(torch.Size([3, 4])).shape, torch.Size([3, 4, 2, 2, 2]))
def test_multivariate_normal_correlated_sampels(self, cuda=False):
    """Check sampling from explicit base samples for a non-batched MultitaskMultivariateNormal.

    The distribution has a 2x2 mean and a diagonal covariance built from the
    flattened per-entry variances. The sample shape must be
    ``sample_shape + (2, 2)``. Runs for float32 and float64.
    """
    device = torch.device("cuda" if cuda else "cpu")
    for dt in (torch.float, torch.double):
        mu = torch.tensor([[0, 1], [2, 3]], dtype=dt, device=device)
        var = torch.tensor([[1, 2], [3, 4]], dtype=dt, device=device)
        dist = MultitaskMultivariateNormal(mean=mu, covariance_matrix=torch.diag(var.view(-1)))

        # Two leading sample dimensions in front of the (2, 2) event.
        eps = dist.get_base_samples(torch.Size([3, 4]))
        drawn = dist.sample(base_samples=eps)
        self.assertEqual(drawn.shape, torch.Size([3, 4, 2, 2]))

        # Default base samples: just the event shape.
        eps = dist.get_base_samples()
        drawn = dist.sample(base_samples=eps)
        self.assertEqual(drawn.shape, torch.Size([2, 2]))
def test_multivariate_normal_batch_correlated_samples(self, cuda=False):
    """Check sampling from explicit base samples for a batch of 3x2 MultitaskMultivariateNormals.

    Two identical distributions are stacked. Each has a 3x2 mean and a 6x6
    diagonal covariance, built by broadcasting the flattened variances against
    an identity matrix. The sample shape must be ``sample_shape + (2, 3, 2)``.
    Runs for float32 and float64.
    """
    device = torch.device("cuda" if cuda else "cpu")
    for dt in (torch.float, torch.double):
        mu = torch.tensor([[0, 1], [2, 3], [4, 5]], dtype=dt, device=device).repeat(2, 1, 1)
        var = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device).repeat(2, 1, 1)
        eye = torch.eye(6, device=device, dtype=dt)
        # (2, 1, 6) * (6, 6) broadcasts to a batch of two 6x6 diagonal matrices.
        cov = var.view(2, 1, -1) * eye
        dist = MultitaskMultivariateNormal(mean=mu, covariance_matrix=cov)

        eps = dist.get_base_samples(torch.Size((3, 4)))
        drawn = dist.sample(base_samples=eps)
        self.assertEqual(drawn.shape, torch.Size([3, 4, 2, 3, 2]))

        eps = dist.get_base_samples()
        drawn = dist.sample(base_samples=eps)
        self.assertEqual(drawn.shape, torch.Size([2, 3, 2]))
def test_multitask_multivariate_normal(self, cuda=False): device = torch.device("cuda") if cuda else torch.device("cpu") for dtype in (torch.float, torch.double): mean = torch.tensor([[0, 1], [2, 3], [4, 5]], dtype=dtype, device=device) variance = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dtype, device=device) # interleaved covmat = variance.view(-1).diag() mtmvn = MultitaskMultivariateNormal(mean=mean, covariance_matrix=covmat) self.assertTrue(torch.equal(mtmvn.mean, mean)) self.assertTrue(torch.allclose(mtmvn.variance, variance)) self.assertTrue(torch.allclose(mtmvn.scale_tril, covmat.sqrt())) self.assertTrue(mtmvn.event_shape == torch.Size([3, 2])) self.assertTrue(mtmvn.batch_shape == torch.Size()) mvn_plus1 = mtmvn + 1 self.assertTrue(torch.equal(mvn_plus1.mean, mtmvn.mean + 1)) self.assertTrue( torch.equal(mvn_plus1.covariance_matrix, mtmvn.covariance_matrix)) mvn_times2 = mtmvn * 2 self.assertTrue(torch.equal(mvn_times2.mean, mtmvn.mean * 2)) self.assertTrue( torch.equal(mvn_times2.covariance_matrix, mtmvn.covariance_matrix * 4)) mvn_divby2 = mtmvn / 2 self.assertTrue(torch.equal(mvn_divby2.mean, mtmvn.mean / 2)) self.assertTrue( torch.equal(mvn_divby2.covariance_matrix, mtmvn.covariance_matrix / 4)) self.assertAlmostEqual(mtmvn.entropy().item(), 11.80326, places=4) self.assertAlmostEqual(mtmvn.log_prob( torch.zeros(3, 2, device=device, dtype=dtype)).item(), -14.52826, places=4) logprob = mtmvn.log_prob( torch.zeros(2, 3, 2, device=device, dtype=dtype)) logprob_expected = -14.52826 * torch.ones( 2, device=device, dtype=dtype) self.assertTrue(torch.allclose(logprob, logprob_expected)) conf_lower, conf_upper = mtmvn.confidence_region() self.assertTrue( torch.allclose(conf_lower, mtmvn.mean - 2 * mtmvn.stddev)) self.assertTrue( torch.allclose(conf_upper, mtmvn.mean + 2 * mtmvn.stddev)) self.assertTrue(mtmvn.sample().shape == torch.Size([3, 2])) self.assertTrue( mtmvn.sample(torch.Size([3])).shape == torch.Size([3, 3, 2])) self.assertTrue( 
mtmvn.sample(torch.Size([3, 4])).shape == torch.Size( [3, 4, 3, 2])) # non-interleaved covmat = variance.transpose(-1, -2).reshape(-1).diag() mtmvn = MultitaskMultivariateNormal(mean=mean, covariance_matrix=covmat, interleaved=False) self.assertTrue(torch.equal(mtmvn.mean, mean)) self.assertTrue(torch.allclose(mtmvn.variance, variance)) self.assertTrue(torch.allclose(mtmvn.scale_tril, covmat.sqrt())) self.assertTrue(mtmvn.event_shape == torch.Size([3, 2])) self.assertTrue(mtmvn.batch_shape == torch.Size())