Exemplo n.º 1
0
    def test_sequential_model(self):
        model = nn.Sequential(nn.Conv2d(1, 20, 5), nn.ReLU(),
                              nn.Conv2d(20, 64, 5), nn.ReLU())

        attach_name_to_modules(model)

        for i in range(4):
            self.assertEqual(model[i]._extractor_fullname, str(i))
Exemplo n.º 2
0
    def test_sequential_model_with_dict(self):
        names = ["conv1", "relu1", "conv2", "relu2"]
        ops = [nn.Conv2d(1, 20, 5), nn.ReLU(), nn.Conv2d(20, 64, 5), nn.ReLU()]
        model = nn.Sequential(OrderedDict(zip(names, ops)))

        attach_name_to_modules(model)

        for i, name in enumerate(names):
            self.assertEqual(model[i]._extractor_fullname, name)
Exemplo n.º 3
0
    def test_module_not_found(self):
        names = ["conv1", "relu1", "conv2", "relu2"]
        ops = [nn.Conv2d(1, 20, 5), nn.ReLU(), nn.Conv2d(20, 64, 5), nn.ReLU()]
        model = nn.Sequential(OrderedDict(zip(names, ops)))
        attach_name_to_modules(model)

        search_names = ["conv1", "azertyuiop"]
        modules = find_modules_by_names(model, search_names)

        # Each name links to the right module
        self.assertFalse("azertyuiop" in modules)
Exemplo n.º 4
0
    def test_nested_modules(self):
        class MyModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.block1 = nn.Sequential(
                    nn.Sequential(nn.Linear(4, 4), nn.Sigmoid(),
                                  nn.Linear(4, 1), nn.Sigmoid()),
                    nn.Sigmoid(),
                )

        model = MyModel()
        attach_name_to_modules(model)

        self.assertEqual(model.block1[0][2]._extractor_fullname, "block1.0.2")
Exemplo n.º 5
0
    def test_nested_modules(self):
        class MyModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.block1 = nn.Sequential(
                    nn.Sequential(nn.Linear(4, 4), nn.Sigmoid(),
                                  nn.Linear(4, 1), nn.Sigmoid()),
                    nn.Sigmoid(),
                )

        model = MyModel()
        attach_name_to_modules(model)
        modules = find_modules_by_names(model, ["block1.0.2"])

        self.assertEqual(id(model.block1[0][2]), id(modules["block1.0.2"]))
Exemplo n.º 6
0
    def test_sequential_model(self):
        names = ["conv1", "relu1", "conv2", "relu2"]
        ops = [nn.Conv2d(1, 20, 5), nn.ReLU(), nn.Conv2d(20, 64, 5), nn.ReLU()]
        model = nn.Sequential(OrderedDict(zip(names, ops)))
        attach_name_to_modules(model)

        search_names = ["conv1", "relu2"]
        modules = find_modules_by_names(model, search_names)

        # All names have a match
        self.assertTrue(all([name in modules for name in search_names]))

        # Each name links to the right module
        self.assertEqual(id(modules["conv1"]), id(ops[0]))
        self.assertEqual(id(modules["relu2"]), id(ops[3]))
Exemplo n.º 7
0
    def test_sequential_module_inheritance(self):
        class MyModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = nn.Conv2d(1, 20, 5)
                self.conv2 = nn.Conv2d(20, 20, 5)

        model = MyModel()
        attach_name_to_modules(model)

        module_iter = model.children()
        module1 = next(module_iter)
        self.assertEqual(module1._extractor_fullname, "conv1")
        module2 = next(module_iter)
        self.assertEqual(module2._extractor_fullname, "conv2")
        # Only two modules
        self.assertIsNone(next(module_iter, None))