Example #1
    def _test_orthogonality(self, cls, cls_tall):
        r"""Test that we may instantiate the parametrizations and
        register them in modules of several sizes. Check that the
        results are orthogonal and equal in the three cases.
        """
        with torch.random.fork_rng(devices=range(torch.cuda.device_count())):
            torch.random.manual_seed(8888)
            for layers in self._test_layers(cls, cls_tall):
                # Check that the initialization of the layers is orthogonal
                for layer in layers:
                    layer.parametrizations.weight.uniform_init_()
                    self.assertIsOrthogonal(layer.weight)
                    self.assertIsOrthogonal(layer.parametrizations.weight.base)

                # Make the initialization the same
                X = (
                    layers[0].weight.t()
                    if layers[0].parametrizations.weight.transpose
                    else layers[0].weight
                )
                for layer in layers[1:]:
                    with torch.no_grad():
                        layer.parametrizations.weight.base.copy_(X)
                    self.assertAlmostEqual(
                        torch.norm(layers[0].weight - layer.weight).item(),
                        0.0,
                        places=5,
                    )
                    self.assertIsOrthogonal(layer.parametrizations.weight.base)

                if isinstance(layers[0], nn.Linear):
                    input_ = torch.rand(5, layers[0].in_features)
                elif isinstance(layers[0], nn.Conv2d):
                    # batch x in_channel x in_length x in_width
                    input_ = torch.rand(6, layers[0].in_channels, 9, 8)

                results = []
                for layer in layers:
                    print(layer)
                    # Take a couple of SGD steps
                    optim = torch.optim.SGD(layer.parameters(), lr=0.1)
                    results.append([])

                    for _ in range(2):
                        with P.cached():
                            self.assertIsOrthogonal(layer.weight)
                            loss = layer(input_).sum()
                        optim.zero_grad()
                        loss.backward()
                        optim.step()
                        results[-1].append(layer.weight)
                    # If we change the base, the forward pass should give the same output
                    prev_out = layer(input_)
                    layer.parametrizations.weight.update_base()
                    new_out = layer(input_)
                    self.assertAlmostEqual(
                        torch.norm(prev_out - new_out).abs().max().item(), 0.0, places=3
                    )

                self.assertPairwiseEqual(results)
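The assertIsOrthogonal helper used throughout this test is not shown in the excerpt. A minimal sketch of what such a check could look like, assuming it only verifies that the columns (or rows, for wide matrices) are orthonormal up to a small tolerance:

    import torch

    def is_orthogonal(X, eps=1e-5):
        # Orient X so it has at least as many rows as columns, then
        # check that its columns are orthonormal: X^T X should be the identity.
        if X.size(-2) < X.size(-1):
            X = X.transpose(-2, -1)
        Id = torch.eye(X.size(-1), dtype=X.dtype, device=X.device)
        return torch.allclose(X.transpose(-2, -1) @ X, Id, atol=eps)

The name is_orthogonal and the 1e-5 tolerance are illustrative; the actual assertion in the test suite may differ.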
Example #2
    def test_backprop(self):
        r"""Test that we may instantiate the parametrizations and
        register them in modules of several sizes. Check that the
        results are on the sphere
        """
        sizes = [1, 2, 3, 8]

        for n, lower in itertools.product(sizes, [True, False]):
            layer = nn.Linear(n, n)
            P.register_parametrization(
                layer, "weight",
                Symmetric(size=layer.weight.size(), lower=lower))

            input_ = torch.rand(5, n)
            optim = torch.optim.SGD(layer.parameters(), lr=1.0)

            # Assert that it stays in Sym(n) after some optimiser steps
            for i in range(2):
                print(i)
                with P.cached():
                    self.assertIsSymmetric(layer.weight)
                    loss = layer(input_).sum()
                optim.zero_grad()
                loss.backward()
                optim.step()
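assertIsSymmetric is likewise a custom assertion. Presumably it only needs to compare the weight with its transpose; a hedged sketch:

    import torch

    def is_symmetric(X, eps=1e-5):
        # A square matrix is symmetric when it equals its own transpose.
        return torch.allclose(X, X.transpose(-2, -1), atol=eps)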
Example #3
    def test_backprop(self):
        r"""Test that we may instantiate the parametrizations and
        register them in modules of several sizes. Check that the
        results are on the sphere
        """
        sizes = [1, 2, 3, 4, 7, 8]

        with torch.random.fork_rng(devices=range(torch.cuda.device_count())):
            torch.random.manual_seed(8888)
            for n in sizes:
                for cls in [Sphere, SphereEmbedded]:
                    layer = nn.Linear(n, 4)
                    P.register_parametrization(layer, "bias",
                                               cls(size=layer.bias.size()))
                    P.register_parametrization(layer, "weight",
                                               cls(size=layer.weight.size()))

                    with torch.no_grad():
                        layer.parametrizations.weight.uniform_init_()
                        layer.parametrizations.bias.uniform_init_()
                        self.assertInSn(layer.weight)
                        self.assertInSn(layer.bias)

                    input_ = torch.rand(5, n)
                    optim = torch.optim.SGD(layer.parameters(), lr=1.0)

                    # Assert that it stays in S^n after some optimiser steps
                    with torch.autograd.set_detect_anomaly(True):
                        for i in range(2):
                            print(i)
                            with P.cached():
                                self.assertInSn(layer.weight)
                                self.assertInSn(layer.bias)
                                loss = layer(input_).sum()
                            optim.zero_grad()
                            loss.backward()
                            optim.step()

                    # If we change the base, the forward pass should give the same output
                    # SphereEmbedded does not have a base
                    if cls != SphereEmbedded:
                        for w in ["weight", "bias"]:
                            with torch.no_grad():
                                out_old = layer(input_)
                                getattr(layer.parametrizations,
                                        w).update_base()
                                out_new = layer(input_)
                                self.assertAlmostEqual(
                                    (out_old - out_new).abs().max().item(),
                                    0.0,
                                    places=5,
                                )
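assertInSn is another helper not shown here. Assuming membership in S^n is taken to mean that the tensor has unit Frobenius norm, a sketch could be:

    import torch

    def in_sphere(X, eps=1e-5):
        # Unit Frobenius norm; the real assertInSn may instead check
        # per-row (or per-column) norms depending on the parametrization.
        return abs(X.norm().item() - 1.0) < eps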
Example #4
    def _test_training(self, layer, args_sample, input_, initialize):
        msg = f"{layer}\n{args_sample}"
        M = layer.parametrizations.weight[0]
        if initialize:
            initial_size = layer.weight.size()
            X = M.sample(**args_sample)
            self.assertTrue(M.in_manifold(X), msg=msg)
            layer.weight = X
            with P.cached():
                # Compute the product if it is factorized
                X_matrix = self.matrix_from_factor(X, M).to(layer.weight.device)
                # The sampled matrix should not have a gradient
                self.assertFalse(X_matrix.requires_grad)
                # Size does not change
                self.assertEqual(initial_size, layer.weight.size(), msg=msg)
                # The initialisation is equal to what we passed
                self.assertTrue(
                    torch.allclose(layer.weight, X_matrix, atol=1e-5), msg=msg
                )

        # Take a couple SGD steps
        optim = torch.optim.SGD(layer.parameters(), lr=1e-3)
        for i in range(3):
            with P.cached():
                loss = layer(input_).mean()
            optim.zero_grad()
            loss.backward()
            optim.step()
            # The layer stays in the manifold while being optimised
            self.assertTrue(M.in_manifold(layer.weight), msg=f"i:{i}\n" + msg)

        with P.cached():
            weight_old = layer.weight
            update_base(layer, "weight")
            # After changing the base, the weight stays the same
            self.assertTrue(
                torch.allclose(layer.weight, weight_old, atol=1e-6), msg=msg
            )
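matrix_from_factor is not shown either. A purely hypothetical sketch, assuming factorized manifolds sample a (U, S, V) triple while the remaining manifolds sample a full matrix (the M argument is kept only for signature parity):

    import torch

    def matrix_from_factor(X, M):
        # Hypothetical: rebuild U @ diag(S) @ V^T from an (U, S, V) sample,
        # e.g. for a fixed-rank factorization; otherwise X is already a matrix.
        if isinstance(X, (tuple, list)) and len(X) == 3:
            U, S, V = X
            return U @ torch.diag_embed(S) @ V.transpose(-2, -1)
        return X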
Example #5
    def _test_custom_trivialization(self, cls):
        def qr(X):
            return torch.qr(X).Q

        # Note that qr is not an analytic function. As such, it may not be used with StiefelTall
        layer = nn.Linear(5, 3)
        P.register_parametrization(layer, "weight",
                                   cls(size=layer.weight.size(), triv=qr))

        optim = torch.optim.SGD(layer.parameters(), lr=0.1)
        input_ = torch.rand(5, layer.in_features)
        for _ in range(2):
            with P.cached():
                self.assertIsOrthogonal(layer.weight)
                loss = layer(input_).sum()
            optim.zero_grad()
            loss.backward()
            optim.step()
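torch.qr is deprecated on recent PyTorch releases in favour of torch.linalg.qr. Assuming the same triv interface, an equivalent trivialization could be written as:

    import torch

    def qr(X):
        # torch.linalg.qr returns a named tuple (Q, R); Q has orthonormal columns.
        return torch.linalg.qr(X).Q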
Example #6
    def _test_custom_trivialization(self, cls):
        def cayley(X):
            n = X.size(0)
            Id = torch.eye(n, dtype=X.dtype, device=X.device)
            return torch.solve(Id - X, Id + X)[0]

        layer = nn.Linear(5, 3)
        P.register_parametrization(layer, "weight",
                                   cls(size=layer.weight.size(), triv=cayley))

        optim = torch.optim.SGD(layer.parameters(), lr=0.1)
        input_ = torch.rand(5, layer.in_features)
        for _ in range(2):
            with P.cached():
                self.assertIsOrthogonal(layer.weight)
                loss = layer(input_).sum()
            optim.zero_grad()
            loss.backward()
            optim.step()
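The Cayley map (I + X)^{-1}(I - X) sends a skew-symmetric X to an orthogonal matrix (I + X is always invertible when X is skew-symmetric), which is what makes it a valid custom trivialization here. torch.solve is deprecated on recent PyTorch; assuming the parametrization feeds the trivialization skew-symmetric inputs, an equivalent version could use torch.linalg.solve:

    import torch

    def cayley(X):
        n = X.size(0)
        Id = torch.eye(n, dtype=X.dtype, device=X.device)
        # torch.linalg.solve(A, B) solves A @ Y = B, i.e. Y = A^{-1} @ B,
        # so this computes (I + X)^{-1} (I - X).
        return torch.linalg.solve(Id + X, Id - X)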