Ejemplo n.º 1
0
    def assert_forward_inverse_are_consistent(self, transform, inputs):
        """Assert that running ``transform``'s inverse followed by ``transform`` is a no-op.

        The inverse and the forward transform are chained into a single
        ``CompositeTransform`` and applied to ``inputs``. The round trip must
        reproduce ``inputs`` exactly (per ``self.assertEqual``). Its
        log-abs-determinant must be zero for every entry along the leading
        (batch) dimension.

        Args:
            transform: the transform under test.
            inputs: tensor fed through the round trip; only ``inputs.shape[:1]``
                is treated as the batch shape for the logabsdet checks.
        """
        batch_shape = inputs.shape[:1]
        round_trip = base.CompositeTransform(
            [base.InverseTransform(transform), transform]
        )
        reconstructed, total_logabsdet = round_trip(inputs)

        # Shape/validity checks first, then value checks.
        self.assert_tensor_is_good(reconstructed, shape=inputs.shape)
        self.assert_tensor_is_good(total_logabsdet, shape=batch_shape)
        self.assertEqual(reconstructed, inputs)
        # Inverse then forward should contribute no net volume change.
        self.assertEqual(total_logabsdet, torch.zeros(batch_shape))
Ejemplo n.º 2
0
 def test_inverse(self):
     """Inverting a chain of scalar affine transforms matches a single reference.

     The composite is scale 2.0, then an identity, then scale -0.25. The test
     relies on these composing to the reference's single scale of -0.5. The
     inverse outputs and log-abs-determinants of both must agree.
     """
     batch_size, shape = 10, [2, 3, 4]
     inputs = torch.randn(batch_size, *shape)
     composite = base.CompositeTransform(
         [
             standard.AffineScalarTransform(scale=2.0),
             standard.IdentityTransform(),
             standard.AffineScalarTransform(scale=-0.25),
         ]
     )
     reference = standard.AffineScalarTransform(scale=-0.5)

     got, got_logabsdet = composite.inverse(inputs)
     expected, expected_logabsdet = reference.inverse(inputs)

     self.assert_tensor_is_good(got, [batch_size, *shape])
     self.assert_tensor_is_good(got_logabsdet, [batch_size])
     self.assertEqual(got, expected)
     self.assertEqual(got_logabsdet, expected_logabsdet)