Example #1
0
 def test_raises_domain_exception(self):
     """Check that ``Tanh.inverse`` raises ``InputOutsideDomain`` at or beyond +/-1.

     Each constant-valued input tensor is built *before* entering
     ``assertRaises`` so that only the ``inverse`` call is under test; any
     error from tensor construction can never be mistaken for the expected
     domain error. Every value runs in its own ``subTest`` so a failure
     reports which value was (wrongly) accepted.
     """
     shape = [2, 3, 4]
     transform = nl.Tanh()
     # The boundary values themselves (+/-1) as well as values beyond them.
     for value in [-2.0, -1.0, 1.0, 2.0]:
         inputs = torch.full(shape, value)
         with self.subTest(value=value):
             with self.assertRaises(InputOutsideDomain):
                 transform.inverse(inputs)
 def test_forward_inverse_are_consistent(self):
     """Run the forward/inverse consistency check on every nonlinearity.

     The actual comparison is delegated to
     ``assert_forward_inverse_are_consistent``. ``self.eps`` is set first,
     presumably as the tolerance that helper reads (confirm in the base
     test case).
     """
     self.eps = 1e-3
     num_samples = 10
     event_shape = (5, 10, 15)
     # Uniform samples in [0, 1), one leading batch dimension.
     x = torch.rand(num_samples, *event_shape)
     candidates = (
         nl.Tanh(),
         nl.LogTanh(),
         nl.LeakyReLU(),
         nl.Sigmoid(),
         nl.Logit(),
         nl.CompositeCDFTransform(nl.Sigmoid(), standard.IdentityTransform()),
     )
     for candidate in candidates:
         with self.subTest(transform=candidate):
             self.assert_forward_inverse_are_consistent(candidate, x)
 def test_inverse(self):
     """Apply ``inverse`` to uniform samples and validate its two outputs.

     For each transform, the inverted tensor is checked against the full
     ``[batch, *event_shape]`` shape and the log-abs-det against ``[batch]``
     via ``assert_tensor_is_good`` (a helper defined elsewhere — presumably
     it also checks for non-finite values; verify).
     """
     num_samples = 10
     event_shape = [5, 10, 15]
     x = torch.rand(num_samples, *event_shape)
     candidates = (
         nl.Tanh(),
         nl.LogTanh(),
         nl.LeakyReLU(),
         nl.Sigmoid(),
         nl.Logit(),
         nl.CompositeCDFTransform(nl.Sigmoid(), standard.IdentityTransform()),
     )
     for candidate in candidates:
         with self.subTest(transform=candidate):
             y, log_det = candidate.inverse(x)
             self.assert_tensor_is_good(y, [num_samples, *event_shape])
             self.assert_tensor_is_good(log_det, [num_samples])