Example #1
0
 def logabsdet(self):
     """Delegate to ``utils.logabsdet`` on the weight matrix ``self._weight``.

     Cost:
         O(D^3), with D the number of features.
     """
     weight_matrix = self._weight
     return utils.logabsdet(weight_matrix)
Example #2
0
    def setUp(self):
        """Create a 3-feature NaiveLinear and the reference values tests compare to.

        The reference weight is the transform's own ``_weight``; its inverse
        and ``utils.logabsdet`` are stored as expected values.
        """
        n_features = 3
        self.features = n_features
        self.transform = linear.NaiveLinear(features=n_features)

        weight = self.transform._weight
        self.weight, self.weight_inverse, self.logabsdet = (
            weight,
            torch.inverse(weight),
            utils.logabsdet(weight),
        )

        # Comparison tolerance; presumably read by the tests' tensor assertEqual.
        self.eps = 1e-5
Example #3
0
    def setUp(self):
        """Create a 3-feature LULinear and the reference values tests compare to.

        The dense reference weight is the product of the two factors returned
        by ``_create_lower_upper``; its inverse and ``utils.logabsdet`` are
        stored as expected values.
        """
        n_features = 3
        self.features = n_features
        self.transform = lu.LULinear(features=n_features)

        lower_factor, upper_factor = self.transform._create_lower_upper()
        dense = lower_factor @ upper_factor
        self.weight, self.weight_inverse, self.logabsdet = (
            dense,
            torch.inverse(dense),
            utils.logabsdet(dense),
        )

        # Comparison tolerance; presumably read by the tests' tensor assertEqual.
        self.eps = 1e-5
Example #4
0
    def setUp(self):
        """Create a 3-feature QRLinear (4 Householder reflections) and references.

        The dense reference weight is ``orthogonal.matrix() @ _create_upper()``;
        its inverse and ``utils.logabsdet`` are stored as expected values.
        """
        n_features = 3
        self.features = n_features
        self.transform = qr.QRLinear(features=n_features, num_householder=4)

        # Keep the original call order: upper factor first, then Q.
        r_factor = self.transform._create_upper()
        q_factor = self.transform.orthogonal.matrix()
        dense = q_factor @ r_factor
        self.weight, self.weight_inverse, self.logabsdet = (
            dense,
            torch.inverse(dense),
            utils.logabsdet(dense),
        )

        # Comparison tolerance; presumably read by the tests' tensor assertEqual.
        self.eps = 1e-5
Example #5
0
    def setUp(self):
        """Create a 3-feature SVDLinear with a random bias, plus references.

        The dense reference weight is
        ``orthogonal_1.matrix() @ diag(exp(log_diagonal)) @ orthogonal_2.matrix()``;
        its inverse and ``utils.logabsdet`` are stored as expected values.
        """
        n_features = 3
        self.features = n_features
        self.transform = svd.SVDLinear(features=n_features, num_householder=4)
        # Randomize the bias so the tests do not run with an all-zero bias.
        self.transform.bias.data = torch.randn(n_features)

        scales = torch.exp(self.transform.log_diagonal)
        left = self.transform.orthogonal_1.matrix()
        right = self.transform.orthogonal_2.matrix()
        dense = left @ torch.diag(scales) @ right
        self.weight, self.weight_inverse, self.logabsdet = (
            dense,
            torch.inverse(dense),
            utils.logabsdet(dense),
        )

        # Comparison tolerance; presumably read by the tests' tensor assertEqual.
        self.eps = 1e-5
Example #6
0
 def forward_no_cache(self, inputs):
     """Apply ``F.linear(inputs, self._weight, self.bias)`` and its log|det|.

     Cost:
         output = O(D^2N)
         logabsdet = O(D^3)
     where:
         D = num of features
         N = num of inputs

     Returns:
         outputs: ``F.linear(inputs, self._weight, self.bias)``.
         logabsdet: ``utils.logabsdet(self._weight)`` multiplied by a
             length-``inputs.shape[0]`` vector of ones (presumably a scalar
             broadcast to one value per input row).
     """
     batch_size = inputs.shape[0]
     outputs = F.linear(inputs, self._weight, self.bias)
     logabsdet = utils.logabsdet(self._weight)
     # ``new_ones`` allocates on the same device/dtype as ``outputs``; a bare
     # ``torch.ones`` would land on the default (CPU) device and fail to
     # combine with a weight that lives on an accelerator.
     logabsdet = logabsdet * outputs.new_ones(batch_size)
     return outputs, logabsdet
Example #7
0
    def test_inverse(self):
        """Check ``HouseholderSequence.inverse`` against its explicit matrix.

        For several sequence lengths (odd and even), the inverse outputs must
        equal ``inputs @ matrix`` and the returned logabsdet must equal
        ``utils.logabsdet(matrix)`` repeated for every sample in the batch.
        """
        features = 100
        batch_size = 50
        # Comparison tolerance (presumably read by the tensor-aware
        # assertEqual). It does not depend on the loop variable, so set it once.
        self.eps = 1e-5

        for num_transforms in [1, 2, 11, 12]:
            with self.subTest(num_transforms=num_transforms):
                transform = orthogonal.HouseholderSequence(
                    features=features, num_transforms=num_transforms)
                matrix = transform.matrix()
                inputs = torch.randn(batch_size, features)
                outputs, logabsdet = transform.inverse(inputs)
                self.assert_tensor_is_good(outputs, [batch_size, features])
                self.assert_tensor_is_good(logabsdet, [batch_size])
                self.assertEqual(outputs, inputs @ matrix)
                self.assertEqual(
                    logabsdet,
                    utils.logabsdet(matrix) * torch.ones(batch_size))