Example #1
0
 def test_batch_range(self):
     """Check ``batch_range``'s default, that it is settable, and that it resets base samples.

     After samples have been drawn, reassigning ``batch_range`` is expected
     to clear the cached ``base_samples`` (they become ``None`` again).
     """
     sampler = IIDNormalSampler(num_samples=4)
     # default value
     self.assertEqual(sampler.batch_range, (0, -2))
     # the attribute can be reassigned
     sampler.batch_range = (-3, -2)
     self.assertEqual(sampler.batch_range, (-3, -2))
     # drawing samples populates base_samples ...
     posterior = _get_test_posterior(self.device)
     _ = sampler(posterior)
     # identity check: `!= None` on a tensor is not a reliable None test
     self.assertIsNotNone(sampler.base_samples)
     # ... and setting batch_range clears them again
     sampler.batch_range = (0, -2)
     self.assertIsNone(sampler.base_samples)
Example #2
0
 def test_get_base_sample_shape_no_collapse(self):
     """Check ``_get_base_sample_shape`` when ``collapse_batch_dims=False``.

     The non-batched test posterior gives ``(4, 2, 1)``. The batched test
     posterior gives ``(4, 3, 2, 1)``, i.e. an extra dimension of size 3 after
     the sample dimension. That batched result is unchanged after switching
     ``batch_range`` to ``(-3, -1)``.
     """
     sampler = IIDNormalSampler(num_samples=4, collapse_batch_dims=False)
     # constructor-level flags and sample shape
     self.assertFalse(sampler.resample)
     self.assertEqual(sampler.sample_shape, torch.Size([4]))
     self.assertFalse(sampler.collapse_batch_dims)

     base_shape_of = sampler._get_base_sample_shape
     # (batch_range to set first or None, posterior factory, expected shape)
     cases = (
         (None, _get_test_posterior, torch.Size([4, 2, 1])),
         (None, _get_test_posterior_batched, torch.Size([4, 3, 2, 1])),
         ((-3, -1), _get_test_posterior_batched, torch.Size([4, 3, 2, 1])),
     )
     for new_range, make_posterior, expected in cases:
         if new_range is not None:
             sampler.batch_range = new_range
         shape = base_shape_of(posterior=make_posterior(self.device))
         self.assertEqual(shape, expected)
Example #3
0
    def test_forward(self):
        """Check ``IIDNormalSampler`` calls in float and double precision.

        With ``resample=False`` and ``seed=1234``, the first draw moves the
        seed to 1235. Repeated draws return the same samples. A batched
        posterior, or one with the other dtype, leaves the seed unchanged.
        Changing ``batch_range`` moves the seed to 1236.

        With ``resample=True``, repeated draws differ, and every call advances
        the seed by one from its initial value.
        """
        batched_shape = torch.Size([4, 3, 2, 1])

        def draw_batched(smp, dt):
            # build a fresh batched posterior, sample it, and check the shape
            post = _get_test_posterior_batched(device=self.device, dtype=dt)
            drawn = smp(post)
            self.assertEqual(drawn.shape, batched_shape)
            return drawn

        for dtype in (torch.float, torch.double):
            other_dtype = torch.float if dtype == torch.double else torch.double

            # --- cached base samples (resample=False, fixed seed) ---
            fixed = IIDNormalSampler(num_samples=4, seed=1234)
            self.assertFalse(fixed.resample)
            self.assertEqual(fixed.seed, 1234)
            self.assertTrue(fixed.collapse_batch_dims)
            single = _get_test_posterior(device=self.device, dtype=dtype)
            first = fixed(single)
            self.assertEqual(first.shape, torch.Size([4, 2, 1]))
            self.assertEqual(fixed.seed, 1235)
            # a second draw must reproduce the first one
            second = fixed(single)
            self.assertTrue(torch.allclose(first, second))
            self.assertEqual(fixed.seed, 1235)
            # batched posterior, same dtype then swapped dtype: seed stays put
            for dt in (dtype, other_dtype):
                draw_batched(fixed, dt)
                self.assertEqual(fixed.seed, 1235)
            # a new batch_range triggers a resample, so the seed goes up
            fixed.batch_range = (-3, -1)
            draw_batched(fixed, dtype)
            self.assertEqual(fixed.seed, 1236)

            # --- fresh base samples on every call (resample=True) ---
            fresh = IIDNormalSampler(num_samples=4, resample=True, seed=None)
            self.assertTrue(fresh.resample)
            self.assertTrue(fresh.collapse_batch_dims)
            seed0 = fresh.seed
            single = _get_test_posterior(device=self.device, dtype=dtype)
            first = fresh(single)
            self.assertEqual(first.shape, torch.Size([4, 2, 1]))
            self.assertEqual(fresh.seed, seed0 + 1)
            # repeated draws must differ
            second = fresh(single)
            self.assertFalse(torch.allclose(first, second))
            self.assertEqual(fresh.seed, seed0 + 2)
            draw_batched(fresh, dtype)
            self.assertEqual(fresh.seed, seed0 + 3)
            fresh.batch_range = (-3, -1)
            draw_batched(fresh, dtype)
            self.assertEqual(fresh.seed, seed0 + 4)