Пример #1
0
 def test_laplace_shape_tensor_params(self):
     """Check the shapes reported by a Laplace built from two-element tensors.

     Covers the batch/event shape attributes, the size of drawn samples
     (with and without an explicit sample shape), and the size of
     ``log_prob`` results. ``self.tensor_sample_1`` and
     ``self.tensor_sample_2`` are fixtures defined outside this method;
     the first is expected to score to a (3, 2) result, the second is
     expected to be rejected with ``ValueError``.
     """
     loc = torch.Tensor([0, 0])
     scale = torch.Tensor([1, 1])
     dist = Laplace(loc, scale)

     # Two parameters -> a batch of two univariate distributions.
     self.assertEqual(dist._batch_shape, torch.Size((2,)))
     self.assertEqual(dist._event_shape, torch.Size(()))

     # Sample sizes: the requested sample shape is prepended to the batch shape.
     default_draw = dist.sample()
     self.assertEqual(default_draw.size(), torch.Size((2,)))
     shaped_draw = dist.sample((3, 2))
     self.assertEqual(shaped_draw.size(), torch.Size((3, 2, 2)))

     # log_prob on the first fixture yields a (3, 2) result.
     scores = dist.log_prob(self.tensor_sample_1)
     self.assertEqual(scores.size(), torch.Size((3, 2)))

     # The second fixture must be refused by log_prob.
     bad_value = self.tensor_sample_2
     with self.assertRaises(ValueError):
         dist.log_prob(bad_value)
Пример #2
0
 def test_laplace_shape_scalar_params(self):
     """Check the shapes reported by a Laplace built from plain Python scalars.

     Covers the (empty) batch/event shape attributes, the size of drawn
     samples, rejection of ``self.scalar_sample`` by ``log_prob``, and the
     size of ``log_prob`` results for the two tensor fixtures. All three
     fixtures are defined outside this method.
     """
     dist = Laplace(0, 1)
     empty_shape = torch.Size()

     # Scalar parameters: neither a batch nor an event dimension.
     self.assertEqual(dist._batch_shape, empty_shape)
     self.assertEqual(dist._event_shape, empty_shape)

     # A default draw is expected to come back as a one-element tensor.
     default_draw = dist.sample()
     self.assertEqual(default_draw.size(), torch.Size((1,)))
     shaped_draw = dist.sample((3, 2))
     self.assertEqual(shaped_draw.size(), torch.Size((3, 2)))

     # The scalar fixture must be refused by log_prob.
     bad_value = self.scalar_sample
     with self.assertRaises(ValueError):
         dist.log_prob(bad_value)

     # Tensor fixtures are scored elementwise, so the result sizes are
     # expected to be (3, 2) and (3, 2, 3) respectively.
     first_scores = dist.log_prob(self.tensor_sample_1)
     self.assertEqual(first_scores.size(), torch.Size((3, 2)))
     second_scores = dist.log_prob(self.tensor_sample_2)
     self.assertEqual(second_scores.size(), torch.Size((3, 2, 3)))