Example #1
0
 def test_forward_inverse_are_consistent(self):
     """Check forward/inverse consistency of HouseholderSequence.

     Runs one sub-test per reflection count (1, 2, 11, 12) on a random
     (batch_size, features) batch. Sub-tests are labelled by
     ``num_transforms`` so a failure names the offending configuration.
     Each transform is built inside its own sub-test, so a construction
     error is reported for that configuration instead of aborting the test.
     """
     features = 100
     batch_size = 50
     inputs = torch.randn(batch_size, features)
     # Tolerance presumably read by the test base class's comparison
     # helpers (sibling tests set it the same way); constant, so set once.
     self.eps = 1e-5
     for num_transforms in [1, 2, 11, 12]:
         with self.subTest(num_transforms=num_transforms):
             transform = orthogonal.HouseholderSequence(
                 features=features, num_transforms=num_transforms)
             self.assert_forward_inverse_are_consistent(transform, inputs)
Example #2
0
    def test_matrix(self):
        """Check that HouseholderSequence.matrix() is a proper orthogonal matrix.

        For each reflection count the matrix must have shape
        (features, features) and satisfy Q Q^T = I, Q^T Q = I and
        Q^T = Q^{-1}. Its determinant must be (-1) ** num_transforms,
        because every Householder reflection contributes a factor of -1.
        """
        features = 100
        identity = torch.eye(features, features)

        for num_transforms in [1, 2, 11, 12]:
            with self.subTest(num_transforms=num_transforms):
                transform = orthogonal.HouseholderSequence(
                    features=features, num_transforms=num_transforms)
                q = transform.matrix()
                self.assert_tensor_is_good(q, [features, features])
                # Tolerance presumably consumed by the base class's
                # tensor-aware assertEqual; set after the shape check.
                self.eps = 1e-5
                q_t = q.t()
                self.assertEqual(q @ q_t, identity)
                self.assertEqual(q_t @ q, identity)
                self.assertEqual(q_t, torch.inverse(q))
                # Odd number of reflections -> det -1, even -> det +1.
                expected_sign = -1.0 if num_transforms % 2 else 1.0
                self.assertEqual(q.det(), torch.tensor(expected_sign))
Example #3
0
    def test_inverse(self):
        """Check HouseholderSequence.inverse against its explicit matrix.

        On a random (batch_size, features) batch, the inverse pass must
        equal right-multiplication by ``transform.matrix()``. The reported
        log-abs-determinant must equal ``torchutils.logabsdet`` of that
        matrix, repeated once per batch element.
        """
        features, batch_size = 100, 50

        for num_transforms in [1, 2, 11, 12]:
            with self.subTest(num_transforms=num_transforms):
                transform = orthogonal.HouseholderSequence(
                    features=features, num_transforms=num_transforms)
                q = transform.matrix()
                x = torch.randn(batch_size, features)
                y, logabsdet = transform.inverse(x)
                self.assert_tensor_is_good(y, [batch_size, features])
                self.assert_tensor_is_good(logabsdet, [batch_size])
                # Tolerance presumably consumed by the base class's
                # tensor-aware assertEqual.
                self.eps = 1e-5
                expected_y = x @ q
                self.assertEqual(y, expected_y)
                # Same scalar log|det| broadcast to every batch element.
                expected_logabsdet = (
                    torchutils.logabsdet(q) * torch.ones(batch_size))
                self.assertEqual(logabsdet, expected_logabsdet)