def test_raises_domain_exception(self):
    """Check that ``Tanh.inverse`` rejects inputs outside its domain.

    For each value in (-2, -1, 1, 2), a constant tensor of shape
    ``[2, 3, 4]`` filled with that value is passed to
    ``nl.Tanh().inverse``, which must raise ``InputOutsideDomain``.
    """
    shape = [2, 3, 4]
    transform = nl.Tanh()
    for value in (-2.0, -1.0, 1.0, 2.0):
        # One sub-test per value, so a failure on one value is reported
        # without hiding the outcome for the remaining values.
        with self.subTest(value=value):
            # Build the tensor outside assertRaises: only the inverse
            # call itself is expected to raise.
            inputs = torch.full(shape, value)
            with self.assertRaises(InputOutsideDomain):
                transform.inverse(inputs)
def test_forward_inverse_are_consistent(self):
    """Run the forward/inverse consistency check on each nonlinearity.

    A single random batch (10 items of shape ``[5, 10, 15]``, drawn with
    ``torch.rand``) is shared by all transforms. Each transform is checked
    in its own sub-test via ``assert_forward_inverse_are_consistent``.
    """
    n_batch = 10
    event_shape = [5, 10, 15]
    samples = torch.rand(n_batch, *event_shape)

    candidates = (
        nl.Tanh(),
        nl.LogTanh(),
        nl.LeakyReLU(),
        nl.Sigmoid(),
        nl.Logit(),
        nl.CompositeCDFTransform(nl.Sigmoid(), standard.IdentityTransform()),
    )

    # NOTE(review): presumably the tolerance read by
    # assert_forward_inverse_are_consistent -- confirm against the helper.
    self.eps = 1e-3

    for candidate in candidates:
        with self.subTest(transform=candidate):
            self.assert_forward_inverse_are_consistent(candidate, samples)
def test_inverse(self):
    """Check the outputs of ``inverse`` for each nonlinearity.

    Every transform's ``inverse`` is called on the same random batch
    (10 items of shape ``[5, 10, 15]``, drawn with ``torch.rand``). The
    returned outputs are validated with ``assert_tensor_is_good`` against
    the input shape, and the returned log-abs-det against ``[10]``.
    """
    n_batch = 10
    event_shape = [5, 10, 15]
    samples = torch.rand(n_batch, *event_shape)

    candidates = (
        nl.Tanh(),
        nl.LogTanh(),
        nl.LeakyReLU(),
        nl.Sigmoid(),
        nl.Logit(),
        nl.CompositeCDFTransform(nl.Sigmoid(), standard.IdentityTransform()),
    )

    for candidate in candidates:
        with self.subTest(transform=candidate):
            result = candidate.inverse(samples)
            outputs, logabsdet = result
            # Outputs keep the full input shape; logabsdet has one entry
            # per batch item.
            self.assert_tensor_is_good(outputs, [n_batch, *event_shape])
            self.assert_tensor_is_good(logabsdet, [n_batch])