def test_multivariate_normal_batch_correlated_sampels(self, cuda=False):
    """Check sample shapes of a batched MultitaskMultivariateNormal with diagonal covariance.

    NOTE(review): "sampels" in the method name is a typo. The name is kept
    unchanged so the test ID (and any wrapper that calls it) keeps working.

    Args:
        cuda: if True, all tensors are created on the CUDA device, else on CPU.
    """
    device = torch.device("cuda") if cuda else torch.device("cpu")
    # Cover both single and double precision, matching the sibling tests.
    for dtype in (torch.float, torch.double):
        # A 2x2 mean repeated into a batch of 2 -> mean shape (2, 2, 2).
        # Presumably the inner 2x2 is (points, tasks) -- TODO confirm.
        mean = torch.tensor([[0, 1], [2, 3]], dtype=dtype, device=device).repeat(2, 1, 1)
        # 4 strictly positive variances (1..4), one per entry of the inner 2x2 mean.
        variance = 1 + torch.arange(4, dtype=dtype, device=device)
        # Diagonal 4x4 covariance, repeated into a batch of 2 -> shape (2, 4, 4).
        covmat = torch.diag(variance).repeat(2, 1, 1)
        mtmvn = MultitaskMultivariateNormal(mean=mean, covariance_matrix=covmat)

        # An explicit sample shape of (3, 4) is expected to be prepended to (2, 2, 2).
        base_samples = mtmvn.get_base_samples(torch.Size((3, 4)))
        self.assertEqual(
            mtmvn.sample(base_samples=base_samples).shape, torch.Size([3, 4, 2, 2, 2])
        )
        # With default base samples, the sample has the same shape as the mean.
        base_samples = mtmvn.get_base_samples()
        self.assertEqual(mtmvn.sample(base_samples=base_samples).shape, torch.Size([2, 2, 2]))
def test_multivariate_normal_correlated_sampels(self, cuda=False):
    """Check sample shapes of a non-batched MultitaskMultivariateNormal.

    NOTE(review): "sampels" in the method name is a typo. The name is kept
    unchanged so the test ID (and any wrapper that calls it) keeps working.

    Args:
        cuda: if True, all tensors are created on the CUDA device, else on CPU.
    """
    device = torch.device("cuda") if cuda else torch.device("cpu")
    for dtype in (torch.float, torch.double):
        # 2x2 mean (presumably points x tasks -- TODO confirm).
        mean = torch.tensor([[0, 1], [2, 3]], dtype=dtype, device=device)
        # One variance per mean entry; flattened onto the diagonal of a 4x4 covariance.
        variance = torch.tensor([[1, 2], [3, 4]], dtype=dtype, device=device)
        covmat = variance.view(-1).diag()
        mtmvn = MultitaskMultivariateNormal(mean=mean, covariance_matrix=covmat)

        # An explicit sample shape of (3, 4) is expected to be prepended to (2, 2).
        base_samples = mtmvn.get_base_samples(torch.Size([3, 4]))
        self.assertEqual(mtmvn.sample(base_samples=base_samples).shape, torch.Size([3, 4, 2, 2]))
        # With default base samples, the sample has the same shape as the mean.
        base_samples = mtmvn.get_base_samples()
        self.assertEqual(mtmvn.sample(base_samples=base_samples).shape, torch.Size([2, 2]))
def test_multivariate_normal_batch_correlated_samples(self, cuda=False):
    """Check sample shapes of a batched MultitaskMultivariateNormal with a non-square mean.

    Args:
        cuda: if True, all tensors are created on the CUDA device, else on CPU.
    """
    device = torch.device("cuda") if cuda else torch.device("cpu")
    for dtype in (torch.float, torch.double):
        # A 3x2 mean repeated into a batch of 2 -> mean shape (2, 3, 2).
        mean = torch.tensor(
            [[0, 1], [2, 3], [4, 5]], dtype=dtype, device=device
        ).repeat(2, 1, 1)
        # One variance per mean entry, batched the same way -> shape (2, 3, 2).
        variance = torch.tensor(
            [[1, 2], [3, 4], [5, 6]], dtype=dtype, device=device
        ).repeat(2, 1, 1)
        # (2, 1, 6) * (6, 6) broadcasts to (2, 6, 6). Each batch member gets a
        # diagonal matrix whose diagonal is its 6 flattened variances.
        covmat = variance.view(2, 1, -1) * torch.eye(6, device=device, dtype=dtype)
        mtmvn = MultitaskMultivariateNormal(mean=mean, covariance_matrix=covmat)

        # An explicit sample shape of (3, 4) is expected to be prepended to (2, 3, 2).
        base_samples = mtmvn.get_base_samples(torch.Size((3, 4)))
        self.assertEqual(
            mtmvn.sample(base_samples=base_samples).shape, torch.Size([3, 4, 2, 3, 2])
        )
        # With default base samples, the sample has the same shape as the mean.
        base_samples = mtmvn.get_base_samples()
        self.assertEqual(mtmvn.sample(base_samples=base_samples).shape, torch.Size([2, 3, 2]))