def test_normal_prior_batch_log_prob(self, cuda=False):
    """NormalPrior.log_prob must agree exactly with torch's Normal when batched.

    Two parameterizations are exercised: a 1-d pair of loc/scale tensors with
    two entries, and a 2x2 pair. For each one:

    * evaluation points of shape (2,) and (2, 2) must yield log-probabilities
      identical (``torch.equal``) to ``Normal.log_prob``;
    * evaluation points of the listed incompatible shapes must raise
      ``RuntimeError``.

    Args:
        cuda: when True every tensor is created on the CUDA device,
            otherwise on the CPU.
    """
    device = torch.device("cuda" if cuda else "cpu")

    def _check(loc_values, scale_values, incompatible_shapes):
        # The second positional argument is handed to both NormalPrior and
        # torch.distributions.Normal; Normal interprets it as the standard
        # deviation (scale), not the variance.
        loc = torch.tensor(loc_values, device=device)
        scale = torch.tensor(scale_values, device=device)
        prior = NormalPrior(loc, scale)
        reference = Normal(loc, scale)
        for shape in ((2,), (2, 2)):
            point = torch.zeros(*shape, device=device)
            self.assertTrue(torch.equal(prior.log_prob(point), reference.log_prob(point)))
        for shape in incompatible_shapes:
            with self.assertRaises(RuntimeError):
                prior.log_prob(torch.zeros(*shape, device=device))

    _check([0.0, 1.0], [1.0, 2.0], [(3,)])
    _check(
        [[0.0, 1.0], [-1.0, 2.0]],
        [[1.0, 2.0], [0.5, 1.0]],
        [(3,), (2, 3)],
    )
def test_normal_prior_log_prob_log_transform(self, cuda=False):
    """With ``transform=torch.exp``, prior.log_prob(t) == Normal.log_prob(exp(t)).

    The equality is checked exactly (``torch.equal``) for a scalar, a 1-d and
    a 2-d input ``t``, against a reference Normal(0, 1).

    Args:
        cuda: when True every tensor is created on the CUDA device,
            otherwise on the CPU.
    """
    device = torch.device("cuda" if cuda else "cpu")
    loc = torch.tensor(0.0, device=device)
    scale = torch.tensor(1.0, device=device)
    prior = NormalPrior(loc, scale, transform=torch.exp)
    reference = Normal(loc, scale)

    inputs = (
        0.0,
        [-1, 0.5],
        [[-1, 0.5], [0.1, -2.0]],
    )
    for values in inputs:
        t = torch.tensor(values, device=device)
        self.assertTrue(torch.equal(prior.log_prob(t), reference.log_prob(t.exp())))
def test_scalar_normal_prior_log_transform(self):
    """Scalar NormalPrior(0, 1) constructed with ``log_transform=True``.

    Checks that the flag is exposed as ``prior.log_transform`` and that
    ``log_prob`` at 0.0 equals, to 5 decimal places, the standard-normal
    log-density at exp(0) = 1, i.e. log(exp(-1/2) / sqrt(2*pi)).
    """
    # NOTE(review): this test passes ``log_transform=True`` while
    # test_normal_prior_log_prob_log_transform uses ``transform=torch.exp``;
    # confirm the NormalPrior under test accepts both keywords.
    prior = NormalPrior(0, 1, log_transform=True)
    self.assertTrue(prior.log_transform)

    density_at_one = 1 / math.sqrt(2 * math.pi) * math.exp(-0.5)
    observed = prior.log_prob(prior.loc.new([0.0])).item()
    self.assertAlmostEqual(observed, math.log(density_at_one), places=5)
def test_scalar_normal_prior(self):
    """Default scalar NormalPrior(0, 1): flags, shape, parameters, density at 0.

    Expects no log transform, a support check that accepts a random value,
    shape ``torch.Size([1])``, loc 0 and scale 1, and
    log_prob(0) == log(1 / sqrt(2*pi)) to 5 decimal places.
    """
    prior = NormalPrior(0, 1)
    self.assertFalse(prior.log_transform)
    self.assertTrue(prior.is_in_support(torch.rand(1)))
    self.assertEqual(prior.shape, torch.Size([1]))
    for attribute, expected in (("loc", 0.0), ("scale", 1.0)):
        self.assertEqual(getattr(prior, attribute).item(), expected)

    origin = prior.loc.new([0.0])
    expected_log_prob = math.log(1 / math.sqrt(2 * math.pi))
    self.assertAlmostEqual(prior.log_prob(origin).item(), expected_log_prob, places=5)
def test_vector_normal_prior_size(self):
    """NormalPrior(0, 1, size=2) expands to length-2 loc/scale tensors.

    Also checks that ``log_prob`` of a length-2 point is the sum of two
    standard-normal log-densities, to 5 decimal places.
    """
    prior = NormalPrior(0, 1, size=2)
    self.assertFalse(prior.log_transform)
    self.assertTrue(prior.is_in_support(torch.zeros(1)))
    self.assertEqual(prior.shape, torch.Size([2]))
    self.assertTrue(torch.equal(prior.loc, torch.tensor([0.0, 0.0])))
    self.assertTrue(torch.equal(prior.scale, torch.tensor([1.0, 1.0])))

    point = torch.tensor([1.0, 2.0])
    # Two independent N(0, 1) terms: 2 * log(1 / sqrt(2*pi)) - 0.5 * ||x||^2.
    normalizer = 2 * math.log(1 / math.sqrt(2 * math.pi))
    quadratic = 0.5 * (point ** 2).sum().item()
    self.assertAlmostEqual(
        prior.log_prob(point).item(),
        normalizer - quadratic,
        places=5,
    )
def test_vector_normal_prior(self):
    """Vector NormalPrior with loc=[-0.5, 0.5] and scale=[0.5, 1.0].

    Checks the flags, shape and stored parameters. Also checks that
    ``log_prob`` of a length-2 point equals, to 5 decimal places, the sum of
    the element-wise normal log-densities
    log(1 / (sqrt(2*pi) * scale)) - (x - loc)^2 / (2 * scale^2).
    """
    prior = NormalPrior(torch.tensor([-0.5, 0.5]), torch.tensor([0.5, 1.0]))
    self.assertFalse(prior.log_transform)
    self.assertTrue(prior.is_in_support(torch.rand(1)))
    self.assertEqual(prior.shape, torch.Size([2]))
    # new_tensor(list) builds a tensor with loc/scale's dtype and device
    # (documented replacement for the legacy Tensor.new(list)).
    self.assertTrue(torch.equal(prior.loc, prior.loc.new_tensor([-0.5, 0.5])))
    self.assertTrue(torch.equal(prior.scale, prior.scale.new_tensor([0.5, 1.0])))

    # ``parameter`` is already a tensor, so it is used directly. Re-wrapping it
    # with ``new_tensor`` would make a redundant copy and emit a PyTorch
    # UserWarning that recommends clone().detach().
    parameter = prior.loc.new_tensor([1.0, 2.0])
    expected_log_prob = (
        (1 / math.sqrt(2 * math.pi) / prior.scale).log()
        - 0.5 / prior.scale**2 * (parameter - prior.loc) ** 2
    ).sum().item()
    self.assertAlmostEqual(prior.log_prob(parameter).item(), expected_log_prob, places=5)