Example #1
0
 def logabsdet(self):
     """Return the log-absolute-determinant of this transform's weight matrix.

     Delegates to ``torchutils.logabsdet`` on ``self._weight``.

     Cost:
         O(D^3), where D is the number of features.
     """
     weight_matrix = self._weight
     return torchutils.logabsdet(weight_matrix)
Example #2
0
 def test_logabsdet(self):
     """torchutils.logabsdet must agree with log(|det(M)|) for a random square M."""
     dim = 10
     mat = torch.randn(dim, dim)
     actual = torchutils.logabsdet(mat)
     # Reference value computed the naive way: determinant first, then log|.|.
     expected = torch.log(torch.abs(mat.det()))
     # Comparison tolerance; presumably read by this test case's custom
     # assertEqual, so it must be set before the comparison below.
     self.eps = 1e-6
     self.assertEqual(actual, expected)
Example #3
0
    def setUp(self):
        # Comparison tolerance, presumably consumed by the project's assertEqual.
        self.eps = 1e-5

        self.features = 3
        self.transform = linear.NaiveLinear(features=self.features)

        # Reference matrix: the transform's own `_weight` tensor, plus the
        # quantities tests compare against (inverse and log|det|).
        weight = self.transform._weight
        self.weight = weight
        self.logabsdet = torchutils.logabsdet(weight)
        self.weight_inverse = torch.inverse(weight)
Example #4
0
    def setUp(self):
        # Comparison tolerance, presumably consumed by the project's assertEqual.
        self.eps = 1e-5

        self.features = 3
        self.transform = lu.LULinear(features=self.features)

        # Rebuild the reference weight from the transform's factors: W = L @ U.
        lower_tri, upper_tri = self.transform._create_lower_upper()
        self.weight = lower_tri @ upper_tri
        self.logabsdet = torchutils.logabsdet(self.weight)
        self.weight_inverse = torch.inverse(self.weight)
Example #5
0
    def setUp(self):
        # Comparison tolerance, presumably consumed by the project's assertEqual.
        self.eps = 1e-5

        self.features = 3
        self.transform = qr.QRLinear(features=self.features, num_householder=4)

        # Rebuild the reference weight from the QR parametrisation: W = Q @ R.
        r_factor = self.transform._create_upper()
        q_factor = self.transform.orthogonal.matrix()
        self.weight = q_factor @ r_factor
        self.logabsdet = torchutils.logabsdet(self.weight)
        self.weight_inverse = torch.inverse(self.weight)
Example #6
0
 def forward_no_cache(self, inputs):
     """Apply ``F.linear(inputs, W, b)`` and return it with log|det W| per row.

     The log-absolute-determinant does not depend on the inputs, so it is
     computed once from the weight matrix and repeated for every row of the
     batch.

     Cost:
         output = O(D^2N)
         logabsdet = O(D^3)
     where:
         D = num of features
         N = num of inputs

     Args:
         inputs: tensor whose first dimension is the batch size.

     Returns:
         Tuple ``(outputs, logabsdet)``. ``logabsdet`` has one entry per
         batch row (assuming ``torchutils.logabsdet`` returns a scalar for the
         2-D weight; TODO confirm).
     """
     batch_size = inputs.shape[0]
     outputs = F.linear(inputs, self._weight, self.bias)
     logabsdet = torchutils.logabsdet(self._weight)
     # Allocate the ones-vector on logabsdet's own device and dtype. A bare
     # torch.ones(batch_size) lives on the CPU in the default dtype, which
     # fails for CUDA weights and downcasts float64 results to float32.
     logabsdet = logabsdet * logabsdet.new_ones(batch_size)
     return outputs, logabsdet
Example #7
0
    def setUp(self):
        # Comparison tolerance, presumably consumed by the project's assertEqual.
        self.eps = 1e-5

        self.features = 3
        self.transform = SVDLinear(features=self.features, num_householder=4)
        # Randomise the bias so tests cannot pass merely because it is zero.
        self.transform.bias.data = torch.randn(self.features)

        # Rebuild the reference weight from the parametrisation:
        # W = orthogonal_1 @ diag(exp(log_diagonal)) @ orthogonal_2.
        singular_values = torch.diag(torch.exp(self.transform.log_diagonal))
        left = self.transform.orthogonal_1.matrix()
        right = self.transform.orthogonal_2.matrix()
        self.weight = left @ singular_values @ right
        self.logabsdet = torchutils.logabsdet(self.weight)
        self.weight_inverse = torch.inverse(self.weight)
Example #8
0
 def test_inverse_full_orthogonal(self):
     """FullOrthogonalTransform.inverse right-multiplies by its matrix.

     Also checks that the reported log|det| equals that of the matrix,
     repeated once per batch row.
     """
     n_features, n_rows = 100, 50
     xform = orthogonal.FullOrthogonalTransform(features=n_features)
     q = xform.matrix()
     x = torch.randn(n_rows, n_features)

     y, ladj = xform.inverse(x)
     self.assert_tensor_is_good(y, [n_rows, n_features])
     self.assert_tensor_is_good(ladj, [n_rows])

     # Comparison tolerance; presumably read by the custom assertEqual.
     self.eps = 1e-5
     self.assertEqual(y, x @ q)
     self.assertEqual(ladj, torchutils.logabsdet(q) * torch.ones(n_rows))
     self.assert_forward_inverse_are_consistent(xform, x)
Example #9
0
    def test_inverse(self):
        """Check HouseholderSequence.inverse for several sequence lengths.

        For each ``num_transforms``, the inverse should right-multiply the
        inputs by ``transform.matrix()``. It should also report that matrix's
        log-absolute-determinant once for every row of the batch.
        """
        features = 100
        batch_size = 50

        # Presumably chosen to cover both odd and even sequence lengths.
        for num_transforms in [1, 2, 11, 12]:
            with self.subTest(num_transforms=num_transforms):
                transform = orthogonal.HouseholderSequence(
                    features=features, num_transforms=num_transforms)
                matrix = transform.matrix()
                inputs = torch.randn(batch_size, features)
                outputs, logabsdet = transform.inverse(inputs)
                self.assert_tensor_is_good(outputs, [batch_size, features])
                self.assert_tensor_is_good(logabsdet, [batch_size])
                # Comparison tolerance; presumably read by the custom
                # assertEqual, so it is set before the comparisons below.
                self.eps = 1e-5
                self.assertEqual(outputs, inputs @ matrix)
                # Expected log|det| is the matrix's value repeated per row.
                self.assertEqual(
                    logabsdet,
                    torchutils.logabsdet(matrix) * torch.ones(batch_size))