Exemplo n.º 1
0
    def _test_constructor(self, cls, cls_tall):
        """Check constructor argument validation for ``cls``, ``cls_tall`` and ``SO``.

        Asserts that:

        * a string trivialization that is not recognised raises ``ValueError``
          for all three classes;
        * a callable trivialization does not raise ``ValueError`` when
          constructing ``cls`` or ``cls_tall``;
        * a one-element ``size`` (a vector rather than a matrix) raises
          ``ValueError`` for all three classes.
        """
        square = (3, 3)
        vector = (7,)

        # Unknown trivialization names must be rejected
        for manifold in (cls, cls_tall, SO):
            with self.assertRaises(ValueError):
                manifold(size=square, triv="wrong")

        # A custom callable must be accepted at construction time
        for manifold in (cls, cls_tall):
            try:
                manifold(size=square, triv=lambda: 3)
            except ValueError:
                self.fail("{} raised ValueError unexpectedly!".format(manifold))

        # Instantiating on a vector rather than a matrix must be rejected
        for manifold in (cls, cls_tall, SO):
            with self.assertRaises(ValueError):
                manifold(size=vector)
Exemplo n.º 2
0
    def test_errors(self):
        """Check that invalid constructions and premature calls raise.

        Covers ``FiberedSpace`` with a non-manifold total space, calling
        ``update_base`` on an ``SO`` instance that has not been registered,
        and ``AbstractManifold`` with invalid ``dimensions`` or ``size``.
        """
        # A total space that is not a manifold is a type error
        with self.assertRaises(TypeError):
            constructions.FiberedSpace(
                dimensions=2, size=(2, 4), total_space=None
            )

        # Calling update_base before the manifold is registered should throw
        unregistered = SO((3, 3))
        with self.assertRaises(ValueError):
            unregistered.update_base()

        # Missing, negative and zero dimensions are all rejected
        for bad_dimensions in (None, -1, 0):
            with self.assertRaises(ValueError):
                constructions.AbstractManifold(
                    dimensions=bad_dimensions, size=(2, 4)
                )

        # A size that is not a sequence is rejected
        with self.assertRaises(ValueError):
            constructions.AbstractManifold(2, size=2)
Exemplo n.º 3
0
    def test_product_manifold(self):
        """Smoke-test ``ProductManifold`` construction and its forward call.

        Builds the product from a list and from a tuple of ``SO`` manifolds,
        then calls the tuple-built product on a pair of random 3x3 tensors
        and unpacks the result into exactly two values.
        """
        # Construction from a list should not throw
        from_list = ProductManifold([SO((3, 3)), SO((3, 3))])

        # Construction from a tuple should work as well
        from_tuple = ProductManifold((SO((3, 3)), SO((3, 3))))

        # The forward pass should accept a pair and return a pair
        inputs = (torch.rand(3, 3), torch.rand(3, 3))
        first, second = from_tuple(inputs)
Exemplo n.º 4
0
    def _test_layers(self, cls, cls_tall):
        """Yield lists of parametrized layers for a range of weight sizes.

        For every ``(n, k)`` size and trivialization, and for both an
        ``nn.Linear(n, k)`` and an ``nn.Conv2d(n, 4, k)`` layer, yields a list
        whose entries are copies of that layer with ``weight`` parametrized:

        * index 0 with ``cls``;
        * index 1 with ``cls_tall``;
        * index 2 (only when ``cls`` is not ``Grassmannian`` and ``n == k``)
          with ``SO``.

        When ``n != k`` the method also asserts that constructing ``SO`` with
        a non-square trailing shape raises ``ValueError``.
        """
        shapes = [
            (8, 1),
            (8, 3),
            (8, 4),
            (8, 8),
            (7, 1),
            (7, 3),
            (7, 4),
            (7, 7),
            (1, 7),
            (2, 7),
            (1, 1),
            (1, 2),
        ]
        trivializations = ["expm"]

        for (n, k), triv in itertools.product(shapes, trivializations):
            for base in (nn.Linear(n, k), nn.Conv2d(n, 4, k)):
                with_so = cls != Grassmannian and n == k
                modules = [base, deepcopy(base)]
                if with_so:
                    # Third copy is parametrized with SO before the other two
                    so_module = deepcopy(base)
                    modules.append(so_module)
                    P.register_parametrization(
                        so_module,
                        "weight",
                        SO(size=so_module.weight.size(), triv=triv),
                    )
                elif n != k:
                    # SO on a non-square shape should throw
                    with self.assertRaises(ValueError):
                        non_square = base.weight.size()[:-2] + (n, k)
                        SO(size=non_square, triv=triv)

                # zip stops after the first two modules: cls, then cls_tall
                for module, manifold in zip(modules, (cls, cls_tall)):
                    P.register_parametrization(
                        module,
                        "weight",
                        manifold(size=module.weight.size(), triv=triv),
                    )
                yield modules
Exemplo n.º 5
0
    def test_product_manifold(self):
        """Exercise the container protocol and error paths of ``ProductManifold``.

        Checks ``len``, ``dir``, indexing, iteration and printing on a product
        of two ``SO`` manifolds, that ``update_base`` raises ``ValueError`` on
        this product (which has not been registered on any module here), and
        that a non-manifold factor raises ``TypeError``.
        """
        SO3SO3 = constructions.ProductManifold([SO((3, 3)), SO((3, 3))])
        # Len
        self.assertEqual(len(SO3SO3), 2)
        # Dir
        print(dir(SO3SO3))
        # Get item (assertIsInstance reports the offending object on failure)
        self.assertIsInstance(SO3SO3[0], SO)
        # Iter
        for M in SO3SO3:
            self.assertIsInstance(M, SO)
        # repr
        print(SO3SO3)
        with self.assertRaises(ValueError):
            SO3SO3.update_base()

        # Passing something that is not a manifold raises; the result is never
        # bound because the constructor is expected to fail
        with self.assertRaises(TypeError):
            constructions.ProductManifold([SO((3, 3)), 3])
Exemplo n.º 6
0
    def _test_constructor(self, cls):
        """Check constructor argument validation for ``cls`` and ``SO``.

        Asserts that an unrecognised trivialization name raises
        ``ValueError``, that a callable trivialization is accepted at
        construction time, that a one-element ``size`` raises ``VectorError``,
        and that ``SO`` with a non-square ``size`` raises ``NonSquareError``.
        """
        checked = (cls, SO)

        # Unknown trivialization names must be rejected
        for manifold in checked:
            with self.assertRaises(ValueError):
                manifold(size=(3, 3), triv="wrong")

        # A custom trivialization is accepted here (it should break in the forward)
        cls(size=(3, 3), triv=lambda: 3)

        # Instantiating on a vector rather than a matrix must be rejected
        for manifold in checked:
            with self.assertRaises(VectorError):
                manifold(size=(7, ))

        # Instantiating SO on a non-square matrix must be rejected
        with self.assertRaises(NonSquareError):
            SO(size=(7, 3, 2))