def test_batch_range(self):
    """Check ``batch_range``: its default, reassignment, and base-sample reset.

    Uses the canonical ``unittest`` assertion names; the ``assertEquals`` /
    ``assertNotEquals`` aliases were removed in Python 3.12.
    """
    # check batch_range default and can be changed
    sampler = IIDNormalSampler(num_samples=4)
    self.assertEqual(sampler.batch_range, (0, -2))
    sampler.batch_range = (-3, -2)
    self.assertEqual(sampler.batch_range, (-3, -2))
    # check that base samples are cleared after batch_range set
    posterior = _get_test_posterior(self.device)
    _ = sampler(posterior)
    # identity checks against None avoid depending on tensor `!=` semantics
    self.assertIsNotNone(sampler.base_samples)
    sampler.batch_range = (0, -2)
    self.assertIsNone(sampler.base_samples)
def test_get_base_sample_shape_no_collapse(self):
    """Base-sample shapes of a sampler built with ``collapse_batch_dims=False``.

    Covers non-batched and batched test posteriors, plus a batched posterior
    after ``batch_range`` has been reassigned.
    """
    sampler = IIDNormalSampler(num_samples=4, collapse_batch_dims=False)
    # constructor flags and sample shape
    self.assertFalse(sampler.resample)
    self.assertEqual(sampler.sample_shape, torch.Size([4]))
    self.assertFalse(sampler.collapse_batch_dims)

    def base_shape_for(make_posterior):
        # build a fresh test posterior on the test device and query its shape
        return sampler._get_base_sample_shape(posterior=make_posterior(self.device))

    # non-batched posterior
    self.assertEqual(base_shape_for(_get_test_posterior), torch.Size([4, 2, 1]))
    # batched posterior: the batch dim is kept in the base-sample shape
    self.assertEqual(
        base_shape_for(_get_test_posterior_batched), torch.Size([4, 3, 2, 1])
    )
    # a different batch_range yields the same shape when nothing is collapsed
    sampler.batch_range = (-3, -1)
    self.assertEqual(
        base_shape_for(_get_test_posterior_batched), torch.Size([4, 3, 2, 1])
    )
def test_forward(self):
    """Exercise drawing samples with and without resampling, for two dtypes.

    The ``seed`` counter shows when fresh base samples are drawn. With
    ``resample=False`` it advances on the first draw and again after
    ``batch_range`` changes, but not on repeat draws, on a batched posterior, or
    when the posterior dtype changes. With ``resample=True`` it advances on
    every call.
    """
    shape_plain = torch.Size([4, 2, 1])
    shape_batched = torch.Size([4, 3, 2, 1])
    for dt in (torch.float, torch.double):
        other_dt = torch.float if dt == torch.double else torch.double

        # --- fixed base samples (resample=False) ---
        fixed = IIDNormalSampler(num_samples=4, seed=1234)
        self.assertFalse(fixed.resample)
        self.assertEqual(fixed.seed, 1234)
        self.assertTrue(fixed.collapse_batch_dims)

        # first draw on a non-batched posterior advances the seed once
        post = _get_test_posterior(device=self.device, dtype=dt)
        first = fixed(post)
        self.assertEqual(first.shape, shape_plain)
        self.assertEqual(fixed.seed, 1235)

        # a repeat draw returns the same samples and leaves the seed alone
        second = fixed(post)
        self.assertTrue(torch.allclose(first, second))
        self.assertEqual(fixed.seed, 1235)

        # a batched posterior, first in this dtype and then in the other one,
        # does not advance the seed either
        for post_dt in (dt, other_dt):
            batched = fixed(
                _get_test_posterior_batched(device=self.device, dtype=post_dt)
            )
            self.assertEqual(batched.shape, shape_batched)
            self.assertEqual(fixed.seed, 1235)

        # changing batch_range should trigger a resample, so the seed goes up
        fixed.batch_range = (-3, -1)
        batched = fixed(_get_test_posterior_batched(device=self.device, dtype=dt))
        self.assertEqual(batched.shape, shape_batched)
        self.assertEqual(fixed.seed, 1236)

        # --- fresh base samples on every call (resample=True) ---
        fresh = IIDNormalSampler(num_samples=4, resample=True, seed=None)
        self.assertTrue(fresh.resample)
        self.assertTrue(fresh.collapse_batch_dims)
        start = fresh.seed

        post = _get_test_posterior(device=self.device, dtype=dt)
        first = fresh(post)
        self.assertEqual(first.shape, shape_plain)
        self.assertEqual(fresh.seed, start + 1)

        # each further call redraws: results differ and the seed keeps moving
        second = fresh(post)
        self.assertFalse(torch.allclose(first, second))
        self.assertEqual(fresh.seed, start + 2)

        batched = fresh(_get_test_posterior_batched(device=self.device, dtype=dt))
        self.assertEqual(batched.shape, shape_batched)
        self.assertEqual(fresh.seed, start + 3)

        fresh.batch_range = (-3, -1)
        batched = fresh(_get_test_posterior_batched(device=self.device, dtype=dt))
        self.assertEqual(batched.shape, shape_batched)
        self.assertEqual(fresh.seed, start + 4)