def _test_constructor(self, cls, cls_tall):
    """Check constructor argument validation for ``cls``, ``cls_tall`` and ``SO``.

    Args:
        cls: Manifold class under test.
        cls_tall: Second manifold class under test, checked the same way.

    The test fails if a constructor accepts an unknown ``triv`` string or a
    one-dimensional ``size``, or if ``cls`` / ``cls_tall`` raise
    ``ValueError`` when given a callable ``triv``.
    """
    # An unrecognised trivialization name must be rejected by every class
    for manifold in (cls, cls_tall, SO):
        with self.assertRaises(ValueError):
            manifold(size=(3, 3), triv="wrong")

    # A callable trivialization must not be rejected at construction time
    for manifold in (cls, cls_tall):
        try:
            manifold(size=(3, 3), triv=lambda: 3)
        except ValueError:
            self.fail("{} raised ValueError unexpectedly!".format(manifold))

    # Try to instantiate it in a vector rather than a matrix
    for manifold in (cls, cls_tall, SO):
        with self.assertRaises(ValueError):
            manifold(size=(7,))
def test_errors(self):
    """Check that invalid constructor arguments and early calls raise.

    Covers ``constructions.FiberedSpace`` with ``total_space=None``,
    ``update_base()`` on a freshly built ``SO``, and
    ``constructions.AbstractManifold`` with invalid ``dimensions`` or a
    non-sequence ``size``.
    """
    # Passing something that is not a manifold raises
    with self.assertRaises(TypeError):
        constructions.FiberedSpace(dimensions=2, size=(2, 4), total_space=None)

    # update_base before registering it should throw
    manifold = SO((3, 3))
    with self.assertRaises(ValueError):
        manifold.update_base()

    # Missing (None), negative and zero dimensions must all raise
    for bad_dimensions in (None, -1, 0):
        with self.assertRaises(ValueError):
            constructions.AbstractManifold(dimensions=bad_dimensions, size=(2, 4))

    # Passing a non-sequence size raises
    with self.assertRaises(ValueError):
        constructions.AbstractManifold(2, size=2)
def test_product_manifold(self):
    """Build a ``ProductManifold`` from a list and a tuple, then call it.

    The test passes as long as neither construction nor the forward call
    raises and the forward call returns exactly two values.
    """
    # Should not throw when given a list of manifolds
    ProductManifold([SO((3, 3)), SO((3, 3))])

    # A tuple should work as well
    factors = (SO((3, 3)), SO((3, 3)))
    product = ProductManifold(factors)

    # Forward should work on a pair of inputs; unpacking into two names
    # also checks that two outputs come back
    inputs = (torch.rand(3, 3), torch.rand(3, 3))
    out_first, out_second = product(inputs)
def _test_layers(self, cls, cls_tall):
    """Yield lists of layers whose ``weight`` has been parametrized.

    For every ``(n, k)`` size and trivialization, an ``nn.Linear(n, k)`` and
    an ``nn.Conv2d(n, 4, k)`` are each turned into a list of layers:

    * index 0: the layer, registered with ``cls``;
    * index 1: a deep copy, registered with ``cls_tall``;
    * index 2 (only when ``n == k`` and ``cls`` is not ``Grassmannian``):
      another deep copy, registered with ``SO``.

    When the extra ``SO`` layer is not built and ``n != k``, the method also
    asserts that constructing ``SO`` with trailing dimensions ``(n, k)``
    raises ``ValueError``.

    Args:
        cls: Manifold class used for the first layer.
        cls_tall: Manifold class used for the second layer.

    Yields:
        list: The parametrized layers described above.
    """
    shapes = [
        (8, 1),
        (8, 3),
        (8, 4),
        (8, 8),
        (7, 1),
        (7, 3),
        (7, 4),
        (7, 7),
        (1, 7),
        (2, 7),
        (1, 1),
        (1, 2),
    ]
    trivializations = ["expm"]

    for (n, k), triv in itertools.product(shapes, trivializations):
        # SO is only exercised on square sizes, and never for Grassmannian
        include_so = cls != Grassmannian and n == k

        for base in [nn.Linear(n, k), nn.Conv2d(n, 4, k)]:
            group = [base, deepcopy(base)]

            if include_so:
                so_layer = deepcopy(base)
                group.append(so_layer)
                P.register_parametrization(
                    so_layer, "weight", SO(size=so_layer.weight.size(), triv=triv)
                )
            elif n != k:
                # If it's not square it should throw
                with self.assertRaises(ValueError):
                    non_square = base.weight.size()[:-2] + (n, k)
                    SO(size=non_square, triv=triv)

            first, second = group[0], group[1]
            P.register_parametrization(
                first, "weight", cls(size=first.weight.size(), triv=triv)
            )
            P.register_parametrization(
                second, "weight", cls_tall(size=second.weight.size(), triv=triv)
            )
            yield group
def test_product_manifold(self):
    """Exercise the container protocol of ``constructions.ProductManifold``.

    Checks ``len``, ``dir``, indexing, iteration and printing on a product of
    two ``SO((3, 3))`` manifolds, then checks that ``update_base()`` raises
    ``ValueError`` and that a non-manifold factor raises ``TypeError``.
    """
    product = constructions.ProductManifold([SO((3, 3)), SO((3, 3))])

    # Len
    self.assertEqual(len(product), 2)

    # Dir (printed only to make sure the call itself does not raise)
    print(dir(product))

    # Get item: assertIsInstance reports the offending object on failure
    self.assertIsInstance(product[0], SO)

    # Iter
    for factor in product:
        self.assertIsInstance(factor, SO)

    # repr
    print(product)

    # update_base is expected to raise on this freshly built product
    # NOTE(review): presumably because it has not been registered on a
    # module yet -- confirm against the library
    with self.assertRaises(ValueError):
        product.update_base()

    # Passing something that is not a manifold raises
    with self.assertRaises(TypeError):
        constructions.ProductManifold([SO((3, 3)), 3])
def _test_constructor(self, cls):
    """Check constructor argument validation for ``cls`` and ``SO``.

    Args:
        cls: Manifold class under test.

    The test fails if a constructor accepts an unknown ``triv`` string, if a
    one-dimensional ``size`` does not raise ``VectorError``, or if ``SO``
    with ``size=(7, 3, 2)`` does not raise ``NonSquareError``.
    """
    # An unrecognised trivialization name must be rejected
    for manifold in (cls, SO):
        with self.assertRaises(ValueError):
            manifold(size=(3, 3), triv="wrong")

    # Try a custom trivialization (it should break in the forward)
    cls(size=(3, 3), triv=lambda: 3)

    # Try to instantiate it in a vector rather than a matrix
    for manifold in (cls, SO):
        with self.assertRaises(VectorError):
            manifold(size=(7,))

    # Try to instantiate it in a non-square matrix
    with self.assertRaises(NonSquareError):
        SO(size=(7, 3, 2))