def test_construct_base_samples_from_posterior(self):  # noqa: C901
    """Check base samples drawn for GPyTorch posteriors of several layouts.

    Four posteriors are built for each dtype:
    - single-output
    - single-output, batched
    - multi-output
    - multi-output, batched

    For every combination of ``sample_shape``, ``qmc`` and ``seed`` (and, for
    batched posteriors, ``collapse_batch_dims``), the samples returned by
    ``construct_base_samples_from_posterior`` must have the expected shape,
    live on ``self.device`` and carry the posterior's dtype.
    """
    sample_shapes = (torch.Size([5]), torch.Size([5, 3]))

    def check_samples(samples, expected_shape, dtype):
        # The same three assertions apply to every posterior layout.
        self.assertEqual(samples.shape, expected_shape)
        self.assertEqual(samples.device.type, self.device.type)
        self.assertEqual(samples.dtype, dtype)

    def check_unbatched(posterior, event_shape, dtype):
        # ``collapse_batch_dims`` is deliberately not passed here, so the
        # function's default is exercised (as in the per-case loops before).
        for sample_shape, qmc, seed in itertools.product(
            sample_shapes, (False, True), (None, 1234)
        ):
            samples = construct_base_samples_from_posterior(
                posterior=posterior, sample_shape=sample_shape, qmc=qmc, seed=seed
            )
            check_samples(samples, sample_shape + event_shape, dtype)

    def check_batched(posterior, event_shape, collapsed_event_shape, dtype):
        for sample_shape, qmc, seed, collapse_batch_dims in itertools.product(
            sample_shapes, (False, True), (None, 1234), (False, True)
        ):
            # With collapsing, the expected batch dimension is a singleton.
            expected = collapsed_event_shape if collapse_batch_dims else event_shape
            samples = construct_base_samples_from_posterior(
                posterior=posterior,
                sample_shape=sample_shape,
                qmc=qmc,
                collapse_batch_dims=collapse_batch_dims,
                seed=seed,
            )
            check_samples(samples, sample_shape + expected, dtype)

    for dtype in (torch.float, torch.double):
        tkwargs = {"device": self.device, "dtype": dtype}
        # single-output
        mean = torch.zeros(2, **tkwargs)
        cov = torch.eye(2, **tkwargs)
        mvn = MultivariateNormal(mean=mean, covariance_matrix=cov)
        check_unbatched(GPyTorchPosterior(mvn=mvn), torch.Size([2, 1]), dtype)
        # single-output, batch mode
        mean = torch.zeros(2, 2, **tkwargs)
        cov = torch.eye(2, **tkwargs).expand(2, 2, 2)
        mvn = MultivariateNormal(mean=mean, covariance_matrix=cov)
        check_batched(
            GPyTorchPosterior(mvn=mvn),
            torch.Size([2, 2, 1]),
            torch.Size([1, 2, 1]),
            dtype,
        )
        # multi-output
        mean = torch.zeros(2, 2, **tkwargs)
        cov = torch.eye(4, **tkwargs)
        mtmvn = MultitaskMultivariateNormal(mean=mean, covariance_matrix=cov)
        check_unbatched(GPyTorchPosterior(mvn=mtmvn), torch.Size([2, 2]), dtype)
        # multi-output, batch mode
        mean = torch.zeros(2, 2, 2, **tkwargs)
        cov = torch.eye(4, **tkwargs).expand(2, 4, 4)
        mtmvn = MultitaskMultivariateNormal(mean=mean, covariance_matrix=cov)
        check_batched(
            GPyTorchPosterior(mvn=mtmvn),
            torch.Size([2, 2, 2]),
            torch.Size([1, 2, 2]),
            dtype,
        )
def test_construct_base_samples_from_posterior(self, cuda=False):
    """Check base samples drawn for GPyTorch posteriors of several layouts.

    Four posteriors are built for each dtype:
    - single-output
    - single-output, batched
    - multi-output
    - multi-output, batched

    For every combination of ``sample_shape``, ``qmc`` and ``seed`` (and, for
    batched posteriors, ``collapse_batch_dims``), the samples returned by
    ``construct_base_samples_from_posterior`` must have the expected shape,
    live on the requested device and carry the posterior's dtype.

    Args:
        cuda: If True, build the posteriors on ``torch.device("cuda")``;
            otherwise on the CPU.
    """
    # NOTE(review): this file also defines an earlier test with this exact
    # name; if both are in the same class, this later definition silently
    # replaces the earlier one — confirm and rename one of them.
    import itertools  # local import: the file's import block is not visible here

    device = torch.device("cuda") if cuda else torch.device("cpu")
    sample_shapes = (torch.Size([5]), torch.Size([5, 3]))

    def check_samples(samples, expected_shape, dtype):
        # The same three assertions apply to every posterior layout.
        self.assertEqual(samples.shape, expected_shape)
        self.assertEqual(samples.device.type, device.type)
        self.assertEqual(samples.dtype, dtype)

    def check_unbatched(posterior, event_shape, dtype):
        # ``collapse_batch_dims`` is deliberately not passed here, so the
        # function's default is exercised (as in the nested loops before).
        for sample_shape, qmc, seed in itertools.product(
            sample_shapes, (False, True), (None, 1234)
        ):
            samples = construct_base_samples_from_posterior(
                posterior=posterior,
                sample_shape=sample_shape,
                qmc=qmc,
                seed=seed,
            )
            check_samples(samples, sample_shape + event_shape, dtype)

    def check_batched(posterior, event_shape, collapsed_event_shape, dtype):
        for sample_shape, qmc, seed, collapse_batch_dims in itertools.product(
            sample_shapes, (False, True), (None, 1234), (False, True)
        ):
            # With collapsing, the expected batch dimension is a singleton.
            expected = collapsed_event_shape if collapse_batch_dims else event_shape
            samples = construct_base_samples_from_posterior(
                posterior=posterior,
                sample_shape=sample_shape,
                qmc=qmc,
                collapse_batch_dims=collapse_batch_dims,
                seed=seed,
            )
            check_samples(samples, sample_shape + expected, dtype)

    for dtype in (torch.float, torch.double):
        tkwargs = {"device": device, "dtype": dtype}
        # single-output
        mean = torch.zeros(2, **tkwargs)
        cov = torch.eye(2, **tkwargs)
        mvn = MultivariateNormal(mean=mean, covariance_matrix=cov)
        check_unbatched(GPyTorchPosterior(mvn=mvn), torch.Size([2, 1]), dtype)
        # single-output, batch mode
        mean = torch.zeros(2, 2, **tkwargs)
        cov = torch.eye(2, **tkwargs).expand(2, 2, 2)
        mvn = MultivariateNormal(mean=mean, covariance_matrix=cov)
        check_batched(
            GPyTorchPosterior(mvn=mvn),
            torch.Size([2, 2, 1]),
            torch.Size([1, 2, 1]),
            dtype,
        )
        # multi-output
        mean = torch.zeros(2, 2, **tkwargs)
        cov = torch.eye(4, **tkwargs)
        mtmvn = MultitaskMultivariateNormal(mean=mean, covariance_matrix=cov)
        check_unbatched(GPyTorchPosterior(mvn=mtmvn), torch.Size([2, 2]), dtype)
        # multi-output, batch mode
        mean = torch.zeros(2, 2, 2, **tkwargs)
        cov = torch.eye(4, **tkwargs).expand(2, 4, 4)
        mtmvn = MultitaskMultivariateNormal(mean=mean, covariance_matrix=cov)
        check_batched(
            GPyTorchPosterior(mvn=mtmvn),
            torch.Size([2, 2, 2]),
            torch.Size([1, 2, 2]),
            dtype,
        )