Example #1
0
 def test_get_base_sample_shape(self):
     """Check ``_get_base_sample_shape`` for non-batched and batched posteriors.

     A sampler built with only ``num_samples=4`` is checked for its
     construction flags, then its base-sample shape is compared against the
     expected shape for each test posterior.
     """
     qmc_sampler = PairwiseSobolQMCNormalSampler(num_samples=4)
     # construction flags: no resampling, 4 samples, batch dims collapsed
     self.assertFalse(qmc_sampler.resample)
     self.assertEqual(qmc_sampler.sample_shape, torch.Size([4]))
     self.assertTrue(qmc_sampler.collapse_batch_dims)
     # (extra kwargs for _get_test_posterior, expected base-sample shape)
     cases = (
         ({}, torch.Size([4, 3, 1])),  # non-batched posterior
         ({"batched": True}, torch.Size([4, 1, 3, 1])),  # batched posterior
     )
     for posterior_kwargs, expected_shape in cases:
         test_post = _get_test_posterior(self.device, **posterior_kwargs)
         base_shape = qmc_sampler._get_base_sample_shape(posterior=test_post)
         self.assertEqual(base_shape, expected_shape)
Example #2
0
    def test_fantasize(self):
        """Check that ``fantasize`` returns a model of the same class.

        Covers an empty and a size-2 batch shape, in float and double
        precision, calling ``fantasize`` both with its default arguments and
        with ``observation_noise=False``.
        """
        for batch_shape, dtype in itertools.product(
            (torch.Size(), torch.Size([2])), (torch.float, torch.double)
        ):
            tkwargs = {"device": self.device, "dtype": dtype}
            X_dim = 2

            # only the model is needed; the second return value is unused
            model, _ = self._get_model_and_data(
                batch_shape=batch_shape, X_dim=X_dim, **tkwargs
            )

            # fantasize at 4 random points per batch; torch.rand accepts the
            # sizes directly, so no torch.Size(...) wrapping is required
            X_f = torch.rand(*batch_shape, 4, X_dim, **tkwargs)
            sampler = PairwiseSobolQMCNormalSampler(num_samples=3)
            fm = model.fantasize(X=X_f, sampler=sampler)
            self.assertIsInstance(fm, model.__class__)
            fm = model.fantasize(
                X=X_f, sampler=sampler, observation_noise=False
            )
            self.assertIsInstance(fm, model.__class__)
Example #3
0
    def test_forward_no_collapse(self):
        """Exercise the sampler's forward pass with ``collapse_batch_dims=False``.

        For each dtype, three sampler configurations are checked:
          * fixed seed, no resampling: the seed advances on the first draw and
            on a change of posterior shape, and a repeated draw on the same
            posterior returns identical samples without touching the seed;
          * ``resample=True``: every draw advances the seed by one and a
            repeated draw returns different samples;
          * ``max_num_comparisons=2``: non-batched samples have shape
            ``[4, 2, 2]`` instead of ``[4, 3, 2]``.
        The first two configurations must raise ``RuntimeError`` for a
        posterior with fewer than two points (``n=1``).
        """

        def posterior_for(dtype, **extra_kwargs):
            # every test posterior is built on this test case's device
            return _get_test_posterior(
                device=self.device, dtype=dtype, **extra_kwargs
            )

        def check_single_point_raises(qmc_sampler, dtype):
            # fewer than two points must be rejected by the sampler
            lone_point = posterior_for(dtype, n=1)
            with self.assertRaises(RuntimeError):
                qmc_sampler(lone_point)

        for dtype in (torch.float, torch.double):
            # --- fixed seed, no resampling ---
            fixed = PairwiseSobolQMCNormalSampler(
                num_samples=4, seed=1234, collapse_batch_dims=False
            )
            self.assertFalse(fixed.resample)
            self.assertEqual(fixed.seed, 1234)
            self.assertFalse(fixed.collapse_batch_dims)
            flat_posterior = posterior_for(dtype)
            first_draw = fixed(flat_posterior)
            self.assertEqual(first_draw.shape, torch.Size([4, 3, 2]))
            self.assertEqual(fixed.seed, 1235)
            # repeating the draw must give the same samples and keep the seed
            second_draw = fixed(flat_posterior)
            self.assertTrue(torch.allclose(first_draw, second_draw))
            self.assertEqual(fixed.seed, 1235)
            # a differently shaped (batched) posterior must also work
            batched_draw = fixed(posterior_for(dtype, batched=True))
            self.assertEqual(batched_draw.shape, torch.Size([4, 3, 3, 2]))
            self.assertEqual(fixed.seed, 1236)
            check_single_point_raises(fixed, dtype)

            # --- resampling on every call ---
            fresh = PairwiseSobolQMCNormalSampler(
                num_samples=4, resample=True, collapse_batch_dims=False
            )
            self.assertTrue(fresh.resample)
            self.assertFalse(fresh.collapse_batch_dims)
            start_seed = fresh.seed
            flat_posterior = posterior_for(dtype)
            first_draw = fresh(posterior=flat_posterior)
            self.assertEqual(first_draw.shape, torch.Size([4, 3, 2]))
            self.assertEqual(fresh.seed, start_seed + 1)
            # repeating the draw must give new samples and bump the seed again
            second_draw = fresh(flat_posterior)
            self.assertFalse(torch.allclose(first_draw, second_draw))
            self.assertEqual(fresh.seed, start_seed + 2)
            # a differently shaped (batched) posterior must also work
            batched_draw = fresh(posterior_for(dtype, batched=True))
            self.assertEqual(batched_draw.shape, torch.Size([4, 3, 3, 2]))
            self.assertEqual(fresh.seed, start_seed + 3)
            check_single_point_raises(fresh, dtype)

            # --- max_num_comparisons cap ---
            capped = PairwiseSobolQMCNormalSampler(
                num_samples=4, max_num_comparisons=2, collapse_batch_dims=False
            )
            self.assertFalse(capped.resample)
            self.assertFalse(capped.collapse_batch_dims)
            capped_draw = capped(posterior_for(dtype))
            self.assertEqual(capped_draw.shape, torch.Size([4, 2, 2]))