def test_multivariate_normal_batch_correlated_samples(self, cuda=False):
    """Check sample shapes drawn from a batched MVN via explicit base samples.

    A batch of two identical 3-dimensional MVNs is built by repeating the
    mean and the covariance. Sampling from base samples requested with
    sample shape ``(3, 4)`` must give shape ``(3, 4, 2, 3)``. Sampling from
    default base samples must give shape ``(2, 3)``. Both single and double
    precision are exercised, matching the non-batched test.

    Args:
        cuda: if True, run on the CUDA device instead of the CPU.
    """
    device = torch.device("cuda") if cuda else torch.device("cpu")
    for dtype in (torch.float, torch.double):
        mean = torch.tensor([0, 1, 2], device=device, dtype=dtype)
        # Give the covariance the same dtype as the mean so that each
        # precision is genuinely tested end to end.
        covmat = torch.diag(
            torch.tensor([1, 0.75, 1.5], device=device, dtype=dtype)
        )
        mvn = MultivariateNormal(
            mean=mean.repeat(2, 1),  # (3,) -> (2, 3): a batch of two means
            covariance_matrix=NonLazyTensor(covmat).repeat(2, 1, 1),
        )
        base_samples = mvn.get_base_samples(torch.Size([3, 4]))
        self.assertEqual(
            mvn.sample(base_samples=base_samples).shape,
            torch.Size([3, 4, 2, 3]),
        )
        # With no sample shape requested, only the batch/event dims remain.
        base_samples = mvn.get_base_samples()
        self.assertEqual(
            mvn.sample(base_samples=base_samples).shape,
            torch.Size([2, 3]),
        )
def test_multivariate_normal_correlated_samples(self, cuda=False):
    """Check sample shapes drawn from a single MVN via explicit base samples.

    A 3-dimensional MVN with a diagonal covariance is built in both single
    and double precision. Sampling from base samples requested with sample
    shape ``(3, 4)`` must give shape ``(3, 4, 3)``. Sampling from default
    base samples must give shape ``(3,)``.

    Args:
        cuda: if True, run on the CUDA device instead of the CPU.
    """
    device = torch.device("cuda") if cuda else torch.device("cpu")
    for dtype in (torch.float, torch.double):
        mean = torch.tensor([0, 1, 2], device=device, dtype=dtype)
        covmat = torch.diag(
            torch.tensor([1, 0.75, 1.5], device=device, dtype=dtype)
        )
        mvn = MultivariateNormal(
            mean=mean, covariance_matrix=NonLazyTensor(covmat)
        )
        base_samples = mvn.get_base_samples(torch.Size([3, 4]))
        # assertEqual (rather than assertTrue(a == b)) reports the actual
        # shape when the check fails.
        self.assertEqual(
            mvn.sample(base_samples=base_samples).shape,
            torch.Size([3, 4, 3]),
        )
        base_samples = mvn.get_base_samples()
        self.assertEqual(
            mvn.sample(base_samples=base_samples).shape,
            torch.Size([3]),
        )