Ejemplo n.º 1
0
 def setUp(self):
     """Prepare shared fixtures before each test (unittest ``setUp`` hook).

     Stores a fixed ``shape`` and ``batch_size`` on the test case and builds
     one instance of each piecewise CDF transform, constructed with only that
     shape, into ``self.transforms`` (linear, quadratic, cubic, then
     rational-quadratic).
     """
     self.shape = [2, 3, 4]
     self.batch_size = 10
     cdf_classes = (
         nl.PiecewiseLinearCDF,
         nl.PiecewiseQuadraticCDF,
         nl.PiecewiseCubicCDF,
         nl.PiecewiseRationalQuadraticCDF,
     )
     self.transforms = [cdf_class(self.shape) for cdf_class in cdf_classes]
Ejemplo n.º 2
0
    def test_forward_inverse_are_consistent(self):
        """Check forward/inverse consistency of each piecewise CDF with linear tails.

        Each transform is built over shape ``[2, 3, 4]`` with
        ``tails="linear"`` and fed a batch of 10 inputs drawn from
        ``3 * torch.randn``. Each transform is checked inside its own
        ``subTest``, so one failing transform does not stop the remaining
        ones from being checked.
        """
        shape = [2, 3, 4]
        batch_size = 10
        transforms = [
            nl.PiecewiseLinearCDF(shape, tails="linear"),
            nl.PiecewiseQuadraticCDF(shape, tails="linear"),
            nl.PiecewiseCubicCDF(shape, tails="linear"),
            nl.PiecewiseRationalQuadraticCDF(shape, tails="linear"),
        ]

        # Loop-invariant: assign once rather than on every iteration.
        # NOTE(review): presumably this is the tolerance read by
        # assert_forward_inverse_are_consistent — confirm in the base class.
        self.eps = 1e-4

        for transform in transforms:
            with self.subTest(transform=transform):
                # NOTE(review): the factor of 3 presumably pushes some samples
                # outside the default tail bound so the linear tails are
                # exercised — confirm the transforms' default tail_bound.
                inputs = 3 * torch.randn(batch_size, *shape)
                self.assert_forward_inverse_are_consistent(transform, inputs)