Ejemplo n.º 1
0
    def test_normal_prior_batch_log_prob(self, cuda=False):
        """Compare ``NormalPrior.log_prob`` with the reference ``Normal`` on batched parameters.

        Two parameterizations are exercised: a pair of 1-D tensors with two
        entries each, and a pair of 2x2 tensors. For each one, the prior and a
        ``Normal`` built from the same two tensors must return exactly equal
        log-probabilities for all-zero inputs of shape ``(2,)`` and ``(2, 2)``,
        and ``prior.log_prob`` must raise ``RuntimeError`` for the listed
        incompatible input shapes.

        Args:
            cuda: When true, every tensor is created on the ``"cuda"`` device;
                otherwise on ``"cpu"``.
        """
        device = torch.device("cuda" if cuda else "cpu")

        # Each row: (first argument, second argument, input shapes that must raise).
        # NOTE(review): the rejected shapes presumably fail to broadcast against
        # the parameters -- confirm against NormalPrior's implementation.
        scenarios = (
            ([0.0, 1.0], [1.0, 2.0], ((3,),)),
            (
                [[0.0, 1.0], [-1.0, 2.0]],
                [[1.0, 2.0], [0.5, 1.0]],
                ((3,), (2, 3)),
            ),
        )
        accepted_shapes = ((2,), (2, 2))

        for loc_values, scale_values, rejected_shapes in scenarios:
            loc = torch.tensor(loc_values, device=device)
            scale = torch.tensor(scale_values, device=device)
            prior = NormalPrior(loc, scale)
            reference = Normal(loc, scale)

            for shape in accepted_shapes:
                point = torch.zeros(*shape, device=device)
                self.assertTrue(
                    torch.equal(prior.log_prob(point), reference.log_prob(point))
                )

            for shape in rejected_shapes:
                with self.assertRaises(RuntimeError):
                    prior.log_prob(torch.zeros(*shape, device=device))
Ejemplo n.º 2
0
    def test_normal_prior_log_prob_log_transform(self, cuda=False):
        """Check that a prior built with ``transform=torch.exp`` scores ``exp(t)``.

        For a scalar, a 1-D and a 2-D input ``t``, ``prior.log_prob(t)`` must be
        exactly equal to the reference ``Normal``'s ``log_prob(t.exp())``, where
        both objects are built from the same scalar parameters (0.0 and 1.0).

        Args:
            cuda: When true, every tensor is created on the ``"cuda"`` device;
                otherwise on ``"cpu"``.
        """
        device = torch.device("cuda" if cuda else "cpu")
        loc = torch.tensor(0.0, device=device)
        scale = torch.tensor(1.0, device=device)
        prior = NormalPrior(loc, scale, transform=torch.exp)
        reference = Normal(loc, scale)

        # Scalar, vector and matrix inputs, in that order.
        raw_inputs = (0.0, [-1, 0.5], [[-1, 0.5], [0.1, -2.0]])
        for raw in raw_inputs:
            point = torch.tensor(raw, device=device)
            actual = prior.log_prob(point)
            expected = reference.log_prob(point.exp())
            self.assertTrue(torch.equal(actual, expected))
Ejemplo n.º 3
0
 def test_scalar_normal_prior_log_transform(self):
     """Check a scalar ``NormalPrior(0, 1, log_transform=True)``.

     The prior must report ``log_transform`` as truthy, and its log-probability
     at 0 must match, to 5 decimal places, the log of the standard-normal
     density evaluated at 1.
     """
     prior = NormalPrior(0, 1, log_transform=True)
     self.assertTrue(prior.log_transform)

     actual = prior.log_prob(prior.loc.new([0.0])).item()
     # N(0, 1) density at x = 1; consistent with the prior scoring exp(0) = 1.
     density_at_one = 1 / math.sqrt(2 * math.pi) * math.exp(-0.5)
     self.assertAlmostEqual(actual, math.log(density_at_one), places=5)
Ejemplo n.º 4
0
 def test_scalar_normal_prior(self):
     """Check the basic attributes and log-density of ``NormalPrior(0, 1)``.

     Verifies that ``log_transform`` is falsy, that a random sample from
     ``torch.rand(1)`` is reported as in support, that ``shape`` is
     ``torch.Size([1])``, that ``loc`` and ``scale`` hold 0.0 and 1.0, and that
     the log-probability at 0 matches ``log(1 / sqrt(2 * pi))`` to 5 places.
     """
     prior = NormalPrior(0, 1)

     self.assertFalse(prior.log_transform)
     sample = torch.rand(1)
     self.assertTrue(prior.is_in_support(sample))

     expected_shape = torch.Size([1])
     self.assertEqual(prior.shape, expected_shape)
     self.assertEqual(prior.loc.item(), 0.0)
     self.assertEqual(prior.scale.item(), 1.0)

     origin = prior.loc.new([0.0])
     actual = prior.log_prob(origin).item()
     # Log of the standard-normal density at its mode.
     expected = math.log(1 / math.sqrt(2 * math.pi))
     self.assertAlmostEqual(actual, expected, places=5)
Ejemplo n.º 5
0
 def test_vector_normal_prior_size(self):
     """Check ``NormalPrior(0, 1, size=2)`` behaves as a two-element prior.

     Verifies that ``log_transform`` is falsy, that ``torch.zeros(1)`` is in
     support, that ``shape`` is ``torch.Size([2])``, that ``loc`` and ``scale``
     equal ``[0, 0]`` and ``[1, 1]``, and that ``log_prob`` of ``[1, 2]`` is a
     single value matching the sum of two standard-normal log-densities to
     5 decimal places.
     """
     prior = NormalPrior(0, 1, size=2)

     self.assertFalse(prior.log_transform)
     self.assertTrue(prior.is_in_support(torch.zeros(1)))
     self.assertEqual(prior.shape, torch.Size([2]))

     expected_loc = torch.tensor([0.0, 0.0])
     self.assertTrue(torch.equal(prior.loc, expected_loc))
     expected_scale = torch.tensor([1.0, 1.0])
     self.assertTrue(torch.equal(prior.scale, expected_scale))

     point = torch.tensor([1.0, 2.0])
     actual = prior.log_prob(point).item()
     # Two independent N(0, 1) terms: shared normalizer plus a quadratic penalty.
     log_normalizer = math.log(1 / math.sqrt(2 * math.pi))
     squared_norm = (point ** 2).sum().item()
     expected = 2 * log_normalizer - 0.5 * squared_norm
     self.assertAlmostEqual(actual, expected, places=5)
Ejemplo n.º 6
0
 def test_vector_normal_prior(self):
     """Check a ``NormalPrior`` built from explicit two-element tensors.

     Verifies that ``log_transform`` is falsy, that a random sample from
     ``torch.rand(1)`` is in support, that ``shape`` is ``torch.Size([2])``,
     that ``loc`` and ``scale`` hold the values passed to the constructor, and
     that ``log_prob`` of ``[1, 2]`` is a single value matching the summed
     per-element normal log-densities to 5 decimal places.
     """
     prior = NormalPrior(torch.tensor([-0.5, 0.5]), torch.tensor([0.5, 1.0]))

     self.assertFalse(prior.log_transform)
     self.assertTrue(prior.is_in_support(torch.rand(1)))
     self.assertEqual(prior.shape, torch.Size([2]))
     self.assertTrue(torch.equal(prior.loc, prior.loc.new([-0.5, 0.5])))
     self.assertTrue(torch.equal(prior.scale, prior.scale.new([0.5, 1.0])))

     # Created from prior.loc, so it is already a tensor of the matching type.
     # Use it directly: re-wrapping it with ``new_tensor`` only makes a
     # redundant copy and triggers PyTorch's copy-construct UserWarning.
     parameter = prior.loc.new([1.0, 2.0])

     # Per-element log N(x; loc, scale) = log(1 / (sqrt(2*pi) * scale))
     #                                    - (x - loc)^2 / (2 * scale^2)
     log_normalizer = (1 / math.sqrt(2 * math.pi) / prior.scale).log()
     quadratic = 0.5 / prior.scale ** 2 * (parameter - prior.loc) ** 2
     expected_log_prob = (log_normalizer - quadratic).sum().item()

     self.assertAlmostEqual(
         prior.log_prob(parameter).item(),
         expected_log_prob,
         places=5,
     )