def logabsdet(self):
    """Return the log absolute determinant of this transform's weight.

    Cost:
        logabsdet = O(D^3)
    where:
        D = num of features
    """
    weight = self._weight
    return torchutils.logabsdet(weight)
def test_logabsdet(self):
    """Compare torchutils.logabsdet with log(abs(det)) on a random matrix."""
    dim = 10
    square = torch.randn(dim, dim)

    actual = torchutils.logabsdet(square)
    # Reference value computed directly from the determinant.
    expected = square.det().abs().log()

    self.eps = 1e-6
    self.assertEqual(actual, expected)
def setUp(self):
    """Create a NaiveLinear transform plus reference quantities of its weight.

    Stores the transform, its weight matrix, the weight's inverse and the
    weight's log absolute determinant on the test case.
    """
    num_features = 3
    naive = linear.NaiveLinear(features=num_features)
    weight = naive._weight

    self.features = num_features
    self.transform = naive
    self.weight = weight
    self.weight_inverse = torch.inverse(weight)
    self.logabsdet = torchutils.logabsdet(weight)
    # NOTE(review): presumably the tolerance used by assertEqual -- confirm.
    self.eps = 1e-5
def setUp(self):
    """Create an LULinear transform plus reference quantities of its weight.

    The reference weight is the product of the lower and upper factors
    returned by the transform's ``_create_lower_upper``.
    """
    num_features = 3
    lu_transform = lu.LULinear(features=num_features)
    lower_factor, upper_factor = lu_transform._create_lower_upper()
    weight = lower_factor @ upper_factor

    self.features = num_features
    self.transform = lu_transform
    self.weight = weight
    self.weight_inverse = torch.inverse(weight)
    self.logabsdet = torchutils.logabsdet(weight)
    # NOTE(review): presumably the tolerance used by assertEqual -- confirm.
    self.eps = 1e-5
def setUp(self):
    """Create a QRLinear transform plus reference quantities of its weight.

    The reference weight is the transform's orthogonal matrix multiplied by
    the upper factor returned by ``_create_upper``.
    """
    num_features = 3
    qr_transform = qr.QRLinear(features=num_features, num_householder=4)
    upper_factor = qr_transform._create_upper()
    q_matrix = qr_transform.orthogonal.matrix()
    weight = q_matrix @ upper_factor

    self.features = num_features
    self.transform = qr_transform
    self.weight = weight
    self.weight_inverse = torch.inverse(weight)
    self.logabsdet = torchutils.logabsdet(weight)
    # NOTE(review): presumably the tolerance used by assertEqual -- confirm.
    self.eps = 1e-5
def forward_no_cache(self, inputs):
    """Apply the linear map and return its per-input log absolute determinant.

    Cost:
        output = O(D^2N)
        logabsdet = O(D^3)
    where:
        D = num of features
        N = num of inputs

    Args:
        inputs: tensor whose first dimension is taken as the batch size.

    Returns:
        A tuple ``(outputs, logabsdet)``: ``outputs`` is
        ``F.linear(inputs, self._weight, self.bias)`` and ``logabsdet`` is
        the log absolute determinant of ``self._weight`` repeated once per
        batch element, on the same device and with the same dtype as
        ``outputs``.
    """
    batch_size = inputs.shape[0]
    outputs = F.linear(inputs, self._weight, self.bias)
    logabsdet = torchutils.logabsdet(self._weight)
    # Derive the ones vector from `outputs` so it inherits device and dtype.
    # A bare torch.ones(batch_size) is a default-dtype CPU tensor: it cannot
    # be multiplied with a CUDA scalar and would downcast a float64 result.
    logabsdet = logabsdet * outputs.new_ones(batch_size)
    return outputs, logabsdet
def setUp(self):
    """Create an SVDLinear transform plus reference quantities of its weight.

    The reference weight is ``orthogonal_1 @ diag(exp(log_diagonal)) @
    orthogonal_2`` assembled from the transform's own components.
    """
    self.features = 3
    svd = SVDLinear(features=self.features, num_householder=4)
    # Randomise the bias just so it isn't zero.
    svd.bias.data = torch.randn(self.features)
    self.transform = svd

    scale = svd.log_diagonal.exp().diag()
    left = svd.orthogonal_1.matrix()
    right = svd.orthogonal_2.matrix()
    weight = left @ scale @ right

    self.weight = weight
    self.weight_inverse = torch.inverse(weight)
    self.logabsdet = torchutils.logabsdet(weight)
    # NOTE(review): presumably the tolerance used by assertEqual -- confirm.
    self.eps = 1e-5
def test_inverse_full_orthogonal(self):
    """Check FullOrthogonalTransform.inverse against its explicit matrix.

    The inverse outputs are compared with ``inputs @ matrix`` and the
    returned logabsdet with the matrix's log absolute determinant repeated
    over the batch; forward/inverse consistency is checked last.
    """
    features, batch_size = 100, 50
    transform = orthogonal.FullOrthogonalTransform(features=features)
    matrix = transform.matrix()
    inputs = torch.randn(batch_size, features)

    outputs, logabsdet = transform.inverse(inputs)

    self.assert_tensor_is_good(outputs, [batch_size, features])
    self.assert_tensor_is_good(logabsdet, [batch_size])

    self.eps = 1e-5
    expected_outputs = inputs @ matrix
    self.assertEqual(outputs, expected_outputs)
    expected_logabsdet = torchutils.logabsdet(matrix) * torch.ones(batch_size)
    self.assertEqual(logabsdet, expected_logabsdet)

    self.assert_forward_inverse_are_consistent(transform, inputs)
def test_inverse(self):
    """Check HouseholderSequence.inverse for several sequence lengths.

    For each number of transforms, the inverse outputs are compared with
    ``inputs @ matrix`` and the returned logabsdet with the matrix's log
    absolute determinant repeated over the batch.
    """
    features, batch_size = 100, 50
    for num_transforms in (1, 2, 11, 12):
        with self.subTest(num_transforms=num_transforms):
            transform = orthogonal.HouseholderSequence(
                features=features, num_transforms=num_transforms
            )
            matrix = transform.matrix()
            inputs = torch.randn(batch_size, features)

            outputs, logabsdet = transform.inverse(inputs)

            self.assert_tensor_is_good(outputs, [batch_size, features])
            self.assert_tensor_is_good(logabsdet, [batch_size])

            self.eps = 1e-5
            expected_outputs = inputs @ matrix
            self.assertEqual(outputs, expected_outputs)
            expected_logabsdet = (
                torchutils.logabsdet(matrix) * torch.ones(batch_size)
            )
            self.assertEqual(logabsdet, expected_logabsdet)