def test_inverse(self):
    """Inverting the identity transform returns the inputs and a zero log-det.

    A batch of 10 standard-normal samples with event shape (2, 3, 4) is
    passed through ``IdentityTransform.inverse``. The outputs must be
    well-formed and equal to the inputs, and the per-sample log-abs-det must
    be all zeros.
    """
    n = 10
    event_shape = [2, 3, 4]
    x = torch.randn(n, *event_shape)
    identity = standard.IdentityTransform()

    y, log_det = identity.inverse(x)

    # Shape / finiteness checks come first, then the value comparisons.
    self.assert_tensor_is_good(y, [n] + event_shape)
    self.assert_tensor_is_good(log_det, [n])
    self.assertEqual(y, x)
    self.assertEqual(log_det, torch.zeros(n))
def test_inverse(self):
    """Inverting a chain of scalar affine maps matches one equivalent affine map.

    The chain is scale 2.0, then the identity, then scale -0.25. Since
    2.0 * -0.25 == -0.5, its inverse is compared (outputs and log-abs-det)
    against the inverse of ``AffineScalarTransform(scale=-0.5)``.
    """
    n = 10
    event_shape = [2, 3, 4]
    x = torch.randn(n, *event_shape)

    chain = base.CompositeTransform(
        [
            standard.AffineScalarTransform(scale=2.0),
            standard.IdentityTransform(),
            standard.AffineScalarTransform(scale=-0.25),
        ]
    )
    expected_transform = standard.AffineScalarTransform(scale=-0.5)

    y, log_det = chain.inverse(x)
    y_expected, log_det_expected = expected_transform.inverse(x)

    self.assert_tensor_is_good(y, [n] + event_shape)
    self.assert_tensor_is_good(log_det, [n])
    self.assertEqual(y, y_expected)
    self.assertEqual(log_det, log_det_expected)
def test_forward_inverse_are_consistent(self):
    """Each nonlinearity, composed with its inverse, reproduces its inputs.

    Inputs come from ``torch.rand`` (uniform on [0, 1)) with batch size 10
    and event shape (5, 10, 15). Each transform is checked in its own
    ``subTest`` so that one failure does not hide the others.
    """
    n = 10
    event_shape = [5, 10, 15]
    x = torch.rand(n, *event_shape)
    candidates = [
        nl.Tanh(),
        nl.LogTanh(),
        nl.LeakyReLU(),
        nl.Sigmoid(),
        nl.Logit(),
        nl.CompositeCDFTransform(nl.Sigmoid(), standard.IdentityTransform()),
    ]
    # NOTE(review): presumably loosens the tolerance used by the base test
    # class's tensor assertEqual for these numerically lossy round-trips — confirm.
    self.eps = 1e-3
    for candidate in candidates:
        with self.subTest(transform=candidate):
            self.assert_forward_inverse_are_consistent(candidate, x)
def test_forward(self):
    """Each nonlinearity's forward pass yields well-formed outputs and log-dets.

    Inputs are drawn from ``torch.rand`` (uniform on [0, 1)) with batch size
    10 and event shape (5, 10, 15). For every transform, the outputs must keep
    the input shape and the log-abs-det must have one entry per batch element.
    """
    n = 10
    event_shape = [5, 10, 15]
    x = torch.rand(n, *event_shape)
    candidates = [
        nl.Tanh(),
        nl.LogTanh(),
        nl.LeakyReLU(),
        nl.Sigmoid(),
        nl.Logit(),
        nl.CompositeCDFTransform(nl.Sigmoid(), standard.IdentityTransform()),
    ]
    expected_output_shape = [n] + event_shape
    for candidate in candidates:
        with self.subTest(transform=candidate):
            y, log_det = candidate(x)
            self.assert_tensor_is_good(y, expected_output_shape)
            self.assert_tensor_is_good(log_det, [n])
def test_forward_inverse_are_consistent(self):
    """Round-tripping through the identity transform reproduces the inputs.

    Uses a batch of 10 standard-normal samples with event shape (2, 3, 4).
    """
    # Draw the samples before building the transform, keeping the original order.
    x = torch.randn(10, 2, 3, 4)
    self.assert_forward_inverse_are_consistent(standard.IdentityTransform(), x)