Example #1
 def test_inverse(self):
     batch_size = 10
     shape = [2, 3, 4]
     inputs = torch.randn(batch_size, *shape)
     transform = standard.IdentityTransform()
     outputs, logabsdet = transform.inverse(inputs)
     self.assert_tensor_is_good(outputs, [batch_size] + shape)
     self.assert_tensor_is_good(logabsdet, [batch_size])
     self.assertEqual(outputs, inputs)
     self.assertEqual(logabsdet, torch.zeros(batch_size))
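
This test pins down the contract of standard.IdentityTransform: both directions return the inputs unchanged and a zero log-absolute-determinant per batch element. A minimal sketch of a transform with that behaviour, assuming the (outputs, logabsdet) interface used above (the class name and the optional context argument are illustrative, not the project's actual code):

 import torch
 from torch import nn

 class IdentityTransform(nn.Module):
     """Hypothetical stand-in: y = x in both directions."""

     def forward(self, inputs, context=None):
         # The Jacobian of the identity map is I, so log|det J| = 0.
         return inputs, inputs.new_zeros(inputs.shape[0])

     def inverse(self, inputs, context=None):
         # The identity map is its own inverse.
         return self.forward(inputs, context)
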
Example #2
 def test_forward_inverse_are_consistent(self):
     batch_size = 10
     shape = [5, 10, 15]
     inputs = torch.rand(batch_size, *shape)
     transforms = [
         nl.Tanh(),
         nl.LogTanh(),
         nl.LeakyReLU(),
         nl.Sigmoid(),
         nl.Logit(),
         nl.CompositeCDFTransform(nl.Sigmoid(),
                                  standard.IdentityTransform()),
     ]
     self.eps = 1e-3
     for transform in transforms:
         with self.subTest(transform=transform):
             self.assert_forward_inverse_are_consistent(transform, inputs)
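
The helper assert_forward_inverse_are_consistent comes from the shared test base class and is not shown on this page; self.eps = 1e-3 presumably loosens its comparison tolerance for the less numerically exact nonlinearities. A plausible sketch of what it checks, assuming round-trip reconstruction plus cancelling log-determinants (the project's actual helper may differ):

 import unittest
 import torch

 class TransformTest(unittest.TestCase):
     """Hypothetical base class carrying the shared assertions."""

     def assert_forward_inverse_are_consistent(self, transform, inputs):
         eps = getattr(self, "eps", 1e-6)
         outputs, logabsdet_fwd = transform(inputs)
         reconstruction, logabsdet_inv = transform.inverse(outputs)
         # The round trip should reproduce the inputs...
         self.assertTrue(torch.allclose(reconstruction, inputs, atol=eps))
         # ...and the forward/inverse log|det J| terms should cancel.
         zeros = torch.zeros(inputs.shape[0])
         self.assertTrue(torch.allclose(logabsdet_fwd + logabsdet_inv, zeros, atol=eps))
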
Example #3
 def test_inverse(self):
     batch_size = 10
     shape = [2, 3, 4]
     inputs = torch.randn(batch_size, *shape)
     transforms = [
         standard.AffineScalarTransform(scale=2.0),
         standard.IdentityTransform(),
         standard.AffineScalarTransform(scale=-0.25),
     ]
     composite = base.CompositeTransform(transforms)
     reference = standard.AffineScalarTransform(scale=-0.5)
     outputs, logabsdet = composite.inverse(inputs)
     outputs_ref, logabsdet_ref = reference.inverse(inputs)
     self.assert_tensor_is_good(outputs, [batch_size] + shape)
     self.assert_tensor_is_good(logabsdet, [batch_size])
     self.assertEqual(outputs, outputs_ref)
     self.assertEqual(logabsdet, logabsdet_ref)
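
The reference works because chaining scalar scalings multiplies their factors: 2.0 · 1.0 · (-0.25) = -0.5, so the composite is equivalent to a single AffineScalarTransform(scale=-0.5). A minimal sketch of an inverse consistent with this test, assuming base.CompositeTransform applies its members' inverses in reverse order and sums their log-determinants (names and details are illustrative):

 import torch
 from torch import nn

 class CompositeTransform(nn.Module):
     """Hypothetical sketch: chains transforms sharing the (outputs, logabsdet) interface."""

     def __init__(self, transforms):
         super().__init__()
         self._transforms = nn.ModuleList(transforms)

     def inverse(self, inputs, context=None):
         outputs = inputs
         total_logabsdet = inputs.new_zeros(inputs.shape[0])
         # Undo the chain back to front, accumulating log|det J| along the way.
         for transform in reversed(self._transforms):
             outputs, logabsdet = transform.inverse(outputs)
             total_logabsdet += logabsdet
         return outputs, total_logabsdet
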
Example #4
 def test_forward(self):
     batch_size = 10
     shape = [5, 10, 15]
     inputs = torch.rand(batch_size, *shape)
     transforms = [
         nl.Tanh(),
         nl.LogTanh(),
         nl.LeakyReLU(),
         nl.Sigmoid(),
         nl.Logit(),
         nl.CompositeCDFTransform(nl.Sigmoid(),
                                  standard.IdentityTransform()),
     ]
     for transform in transforms:
         with self.subTest(transform=transform):
             outputs, logabsdet = transform(inputs)
             self.assert_tensor_is_good(outputs, [batch_size] + shape)
             self.assert_tensor_is_good(logabsdet, [batch_size])
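
assert_tensor_is_good is the other unshown helper from the test base class. A plausible sketch, assuming it checks type, finiteness, and the expected shape (the actual implementation may check more or less):

 import unittest
 import torch

 class TransformTest(unittest.TestCase):
     """Hypothetical base class; sketch of the tensor sanity check."""

     def assert_tensor_is_good(self, tensor, shape=None):
         self.assertIsInstance(tensor, torch.Tensor)
         self.assertFalse(torch.isnan(tensor).any())
         self.assertFalse(torch.isinf(tensor).any())
         if shape is not None:
             self.assertEqual(tensor.shape, torch.Size(shape))
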
Example #5
 def test_forward_inverse_are_consistent(self):
     batch_size = 10
     shape = [2, 3, 4]
     inputs = torch.randn(batch_size, *shape)
     transform = standard.IdentityTransform()
     self.assert_forward_inverse_are_consistent(transform, inputs)