def test_laplace_shape_tensor_params(self): laplace = Laplace(torch.Tensor([0, 0]), torch.Tensor([1, 1])) self.assertEqual(laplace._batch_shape, torch.Size((2,))) self.assertEqual(laplace._event_shape, torch.Size(())) self.assertEqual(laplace.sample().size(), torch.Size((2,))) self.assertEqual(laplace.sample((3, 2)).size(), torch.Size((3, 2, 2))) self.assertEqual(laplace.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))) self.assertRaises(ValueError, laplace.log_prob, self.tensor_sample_2)
def test_laplace_shape_scalar_params(self):
    """Check shapes for a Laplace built from Python scalar loc/scale.

    Asserts empty batch/event shapes and the sizes of drawn samples. Also
    asserts that ``log_prob`` on ``self.scalar_sample`` raises ``ValueError``,
    while the two tensor samples are scored with output sizes (3, 2) and
    (3, 2, 3).

    NOTE(review): the fixtures ``self.scalar_sample``,
    ``self.tensor_sample_1`` and ``self.tensor_sample_2`` are assumed to be
    prepared elsewhere in the test class (e.g. in ``setUp``). The expected
    ``sample()`` size of (1,) likely reflects the torch version this test was
    written against. Confirm it against the pinned torch version.
    """
    dist = Laplace(0, 1)

    # Scalar parameters: both shapes are expected to be empty.
    self.assertEqual(dist._batch_shape, torch.Size([]))
    self.assertEqual(dist._event_shape, torch.Size([]))

    # Sampling sizes.
    one_draw = dist.sample()
    self.assertEqual(one_draw.size(), torch.Size([1]))
    grid_draw = dist.sample((3, 2))
    self.assertEqual(grid_draw.size(), torch.Size([3, 2]))

    # Scoring: the scalar fixture is rejected; tensor fixtures keep their
    # own shape in the log_prob output.
    self.assertRaises(ValueError, dist.log_prob, self.scalar_sample)
    first_scores = dist.log_prob(self.tensor_sample_1)
    self.assertEqual(first_scores.size(), torch.Size([3, 2]))
    second_scores = dist.log_prob(self.tensor_sample_2)
    self.assertEqual(second_scores.size(), torch.Size([3, 2, 3]))