Example #1
0
 def test_get_base_sample_shape_no_collapse(self):
     """Check base sample shapes for an IID normal sampler with batch dims kept.

     The sampler is built with ``collapse_batch_dims=False``. Its base sample
     shape is then checked against a non-batched and a batched test posterior.
     """
     sampler = IIDNormalSampler(num_samples=4, collapse_batch_dims=False)
     # Sanity-check the sampler's configuration before probing shapes.
     self.assertFalse(sampler.resample)
     self.assertEqual(sampler.sample_shape, torch.Size([4]))
     self.assertFalse(sampler.collapse_batch_dims)
     # Pairs of (posterior factory, expected base sample shape), checked in
     # order: non-batched first, then batched. For the batched posterior, the
     # extra size-3 dimension is expected to survive because batch dims are
     # not collapsed.
     shape_cases = (
         (_get_posterior, torch.Size([4, 2, 1])),
         (_get_posterior_batched, torch.Size([4, 3, 2, 1])),
     )
     for make_posterior, expected_shape in shape_cases:
         base_shape = sampler._get_base_sample_shape(posterior=make_posterior())
         self.assertEqual(base_shape, expected_shape)
Example #2
0
 def test_get_base_sample_shape(self):
     """Check base sample shapes for an IID normal sampler with default settings.

     With the defaults, batch dims are collapsed. The batched posterior is
     checked twice: once with the default ``batch_range`` and once after
     ``batch_range`` is set to ``(-3, -1)``.
     """
     sampler = IIDNormalSampler(num_samples=4)
     # Defaults: no resampling, four samples, batch dims collapsed.
     self.assertFalse(sampler.resample)
     self.assertEqual(sampler.sample_shape, torch.Size([4]))
     self.assertTrue(sampler.collapse_batch_dims)

     def base_shape_for(make_posterior):
         # Build a fresh posterior on the test device and query the sampler.
         # The sampler is looked up at call time, so later changes to it
         # (e.g. batch_range) are taken into account.
         posterior = make_posterior(self.device)
         return sampler._get_base_sample_shape(posterior=posterior)

     # Non-batched posterior.
     self.assertEqual(
         base_shape_for(_get_test_posterior), torch.Size([4, 2, 1])
     )
     # Batched posterior, default batch range: the batch dimension is
     # expected to appear with size 1.
     self.assertEqual(
         base_shape_for(_get_test_posterior_batched), torch.Size([4, 1, 2, 1])
     )
     # Batched posterior with a wider batch range. The expected shape now has
     # size 1 in every dim after the sample dim (presumably every dim inside
     # batch_range is collapsed).
     sampler.batch_range = (-3, -1)
     self.assertEqual(
         base_shape_for(_get_test_posterior_batched), torch.Size([4, 1, 1, 1])
     )