def _test_orthogonality(self, cls, cls_tall):
    r"""Test that we may instantiate the parametrizations and register them
    in modules of several sizes. Check that the results are orthogonal and
    equal in the three cases.
    """
    with torch.random.fork_rng(devices=range(torch.cuda.device_count())):
        torch.random.manual_seed(8888)
        for layers in self._test_layers(cls, cls_tall):
            # Check that the initialization of the layers is orthogonal
            for layer in layers:
                layer.parametrizations.weight.uniform_init_()
                self.assertIsOrthogonal(layer.weight)
                self.assertIsOrthogonal(layer.parametrizations.weight.base)

            # Make the initialization the same
            X = (
                layers[0].weight.t()
                if layers[0].parametrizations.weight.transpose
                else layers[0].weight
            )
            for layer in layers[1:]:
                with torch.no_grad():
                    layer.parametrizations.weight.base.copy_(X)
                self.assertAlmostEqual(
                    torch.norm(layers[0].weight - layer.weight).item(),
                    0.0,
                    places=5,
                )
                self.assertIsOrthogonal(layer.parametrizations.weight.base)

            if isinstance(layers[0], nn.Linear):
                input_ = torch.rand(5, layers[0].in_features)
            elif isinstance(layers[0], nn.Conv2d):
                # batch x in_channel x in_length x in_width
                input_ = torch.rand(6, layers[0].in_channels, 9, 8)

            results = []
            for layer in layers:
                # Take a couple of SGD steps
                optim = torch.optim.SGD(layer.parameters(), lr=0.1)
                results.append([])
                for _ in range(2):
                    with P.cached():
                        self.assertIsOrthogonal(layer.weight)
                        loss = layer(input_).sum()
                    optim.zero_grad()
                    loss.backward()
                    optim.step()
                    results[-1].append(layer.weight)

                # If we change the base, the forward pass should give the same output
                prev_out = layer(input_)
                layer.parametrizations.weight.update_base()
                new_out = layer(input_)
                self.assertAlmostEqual(
                    torch.norm(prev_out - new_out).item(), 0.0, places=3
                )
            self.assertPairwiseEqual(results)
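# assertIsOrthogonal is used above but not defined in this section. A minimal
# sketch of the check it performs, assuming a 2-D matrix and the convention
# that a tall matrix X is orthogonal when X^T X = Id (the name, signature, and
# tolerance here are illustrative, not the actual helper):
def _assert_is_orthogonal_sketch(test_case, X, atol=1e-5):
    if X.size(0) < X.size(1):
        X = X.t()  # work with a tall matrix so X^T X is the small Gram matrix
    Id = torch.eye(X.size(1), dtype=X.dtype, device=X.device)
    test_case.assertTrue(torch.allclose(X.t() @ X, Id, atol=atol))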
def test_backprop(self):
    r"""Test that we may instantiate the parametrizations and register them
    in modules of several sizes. Check that the results are symmetric.
    """
    sizes = [1, 2, 3, 8]
    for n, lower in itertools.product(sizes, [True, False]):
        layer = nn.Linear(n, n)
        P.register_parametrization(
            layer, "weight", Symmetric(size=layer.weight.size(), lower=lower)
        )

        input_ = torch.rand(5, n)
        optim = torch.optim.SGD(layer.parameters(), lr=1.0)

        # Assert that it stays in Sym(n) after some optimiser steps
        for _ in range(2):
            with P.cached():
                self.assertIsSymmetric(layer.weight)
                loss = layer(input_).sum()
            optim.zero_grad()
            loss.backward()
            optim.step()
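# assertIsSymmetric is likewise not defined in this section. Symmetry of a
# square matrix is a direct comparison with its transpose; a minimal sketch
# under that assumption (name and tolerance are illustrative):
def _assert_is_symmetric_sketch(test_case, X, atol=1e-6):
    test_case.assertTrue(torch.allclose(X, X.t(), atol=atol))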
def test_backprop(self):
    r"""Test that we may instantiate the parametrizations and register them
    in modules of several sizes. Check that the results are on the sphere.
    """
    sizes = [1, 2, 3, 4, 7, 8]
    with torch.random.fork_rng(devices=range(torch.cuda.device_count())):
        torch.random.manual_seed(8888)
        for n in sizes:
            for cls in [Sphere, SphereEmbedded]:
                layer = nn.Linear(n, 4)
                P.register_parametrization(
                    layer, "bias", cls(size=layer.bias.size())
                )
                P.register_parametrization(
                    layer, "weight", cls(size=layer.weight.size())
                )
                with torch.no_grad():
                    layer.parametrizations.weight.uniform_init_()
                    layer.parametrizations.bias.uniform_init_()
                self.assertInSn(layer.weight)
                self.assertInSn(layer.bias)

                input_ = torch.rand(5, n)
                optim = torch.optim.SGD(layer.parameters(), lr=1.0)

                # Assert that it stays in S^n after some optimiser steps
                with torch.autograd.set_detect_anomaly(True):
                    for _ in range(2):
                        with P.cached():
                            self.assertInSn(layer.weight)
                            self.assertInSn(layer.bias)
                            loss = layer(input_).sum()
                        optim.zero_grad()
                        loss.backward()
                        optim.step()

                # If we change the base, the forward pass should give the same
                # output. SphereEmbedded does not have a base.
                if cls != SphereEmbedded:
                    for w in ["weight", "bias"]:
                        with torch.no_grad():
                            out_old = layer(input_)
                            getattr(layer.parametrizations, w).update_base()
                            out_new = layer(input_)
                            self.assertAlmostEqual(
                                (out_old - out_new).abs().max().item(),
                                0.0,
                                places=5,
                            )
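# assertInSn is not shown in this section either. One plausible reading, given
# that the whole weight and bias tensors are parametrized at once, is that the
# tensor is treated as a single vector of unit norm; the actual helper may
# instead normalize along the last dimension. A sketch under the first
# assumption (name and tolerance are illustrative):
def _assert_in_sn_sketch(test_case, X, atol=1e-5):
    test_case.assertAlmostEqual(X.norm().item(), 1.0, delta=atol)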
def _test_training(self, layer, args_sample, input_, initialize):
    msg = f"{layer}\n{args_sample}"
    M = layer.parametrizations.weight[0]
    if initialize:
        initial_size = layer.weight.size()
        X = M.sample(**args_sample)
        self.assertTrue(M.in_manifold(X), msg=msg)
        layer.weight = X
        with P.cached():
            # Compute the product if it is factorized
            X_matrix = self.matrix_from_factor(X, M).to(layer.weight.device)
            # The sampled matrix should not have a gradient
            self.assertFalse(X_matrix.requires_grad)
            # Size does not change
            self.assertEqual(initial_size, layer.weight.size(), msg=msg)
            # The initialization is equal to what we passed
            self.assertTrue(
                torch.allclose(layer.weight, X_matrix, atol=1e-5), msg=msg
            )

    # Take a couple of SGD steps
    optim = torch.optim.SGD(layer.parameters(), lr=1e-3)
    for i in range(3):
        with P.cached():
            loss = layer(input_).mean()
        optim.zero_grad()
        loss.backward()
        optim.step()
        # The layer stays in the manifold while being optimised
        self.assertTrue(M.in_manifold(layer.weight), msg=f"i: {i}\n" + msg)

    with P.cached():
        weight_old = layer.weight
    update_base(layer, "weight")
    # After changing the base, the weight stays the same
    self.assertTrue(
        torch.allclose(layer.weight, weight_old, atol=1e-6), msg=msg
    )
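# matrix_from_factor is referenced above but not shown. Some manifolds sample
# a factorized representation (for instance a (U, S, V) triple for low-rank
# matrices) instead of a plain matrix; a minimal sketch of the reconstruction
# under that assumption (the factorization convention is a guess, not the
# actual helper):
def _matrix_from_factor_sketch(X, M):
    # M (the manifold) could be used to dispatch on the factorization type;
    # here we only distinguish plain tensors from (U, S, V) triples.
    if isinstance(X, tuple):
        U, S, V = X
        return U @ torch.diag_embed(S) @ V.transpose(-2, -1)
    return X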
def _test_custom_trivialization(self, cls):
    def qr(X):
        # Note that qr is not an analytic function.
        # As such, it may not be used with StiefelTall.
        return torch.qr(X).Q

    layer = nn.Linear(5, 3)
    P.register_parametrization(
        layer, "weight", cls(size=layer.weight.size(), triv=qr)
    )

    optim = torch.optim.SGD(layer.parameters(), lr=0.1)
    input_ = torch.rand(5, layer.in_features)
    for _ in range(2):
        with P.cached():
            self.assertIsOrthogonal(layer.weight)
            loss = layer(input_).sum()
        optim.zero_grad()
        loss.backward()
        optim.step()
def _test_custom_trivialization(self, cls):
    def cayley(X):
        n = X.size(0)
        Id = torch.eye(n, dtype=X.dtype, device=X.device)
        # Cayley map: X |-> (Id + X)^{-1} (Id - X)
        return torch.solve(Id - X, Id + X)[0]

    layer = nn.Linear(5, 3)
    P.register_parametrization(
        layer, "weight", cls(size=layer.weight.size(), triv=cayley)
    )

    optim = torch.optim.SGD(layer.parameters(), lr=0.1)
    input_ = torch.rand(5, layer.in_features)
    for _ in range(2):
        with P.cached():
            self.assertIsOrthogonal(layer.weight)
            loss = layer(input_).sum()
        optim.zero_grad()
        loss.backward()
        optim.step()
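# Why the Cayley map is a valid trivialization: if A is skew-symmetric
# (A^T = -A), then Q = (Id + A)^{-1} (Id - A) satisfies Q^T Q = Id, since
# (Id - A)^T = Id + A and the two factors commute. A quick self-contained
# sanity check of this property (illustrative, not part of the test suite):
def _check_cayley_is_orthogonal(n=5):
    A = torch.rand(n, n)
    A = A - A.t()  # skew-symmetric
    Id = torch.eye(n)
    Q = torch.solve(Id - A, Id + A)[0]
    return torch.allclose(Q.t() @ Q, Id, atol=1e-5)  # expected: True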