예제 #1
0
    def test_nested_model(self):
        """Check that ``models.replace_bn`` converts BatchNorm at every depth.

        The model nests ``Block`` modules and bare ``BatchNorm2d`` layers
        inside two levels of ``torch.nn.Sequential``.  After the call, every
        layer that was built as ``BatchNorm2d`` must be a ``GroupNorm`` and
        every ``Conv2d`` must still be a ``Conv2d``.
        """
        class Block(torch.nn.Module):
            """Submodule holding one conv layer and one batch-norm layer."""

            def __init__(self):
                super(Block, self).__init__()
                self.conv = torch.nn.Conv2d(1, 2, 3)
                self.bn = torch.nn.BatchNorm2d(1)

        class Model(torch.nn.Module):
            """Container with batch-norm layers at three nesting levels."""

            def __init__(self):
                super(Model, self).__init__()
                self.layers = torch.nn.Sequential(
                    Block(),
                    torch.nn.Sequential(
                        Block(),
                        torch.nn.BatchNorm2d(1),
                    ),
                    torch.nn.BatchNorm2d(1),
                )

        model = Model()
        opts = TestGroupNormConversion.create_opts()
        models.replace_bn(model, opts)

        # One assert per layer so a failure identifies the offending layer
        # instead of reporting a single opaque boolean expression.
        assert isinstance(model.layers[0].conv, torch.nn.Conv2d)
        assert isinstance(model.layers[0].bn, torch.nn.GroupNorm)
        assert isinstance(model.layers[1][0].conv, torch.nn.Conv2d)
        assert isinstance(model.layers[1][0].bn, torch.nn.GroupNorm)
        # The bare BatchNorm2d inside the inner Sequential was previously
        # never checked, so a conversion that skipped it went unnoticed.
        assert isinstance(model.layers[1][1], torch.nn.GroupNorm)
        assert isinstance(model.layers[2], torch.nn.GroupNorm)
예제 #2
0
    def test_single_element_model(self):
        """A module whose only child is a BatchNorm2d gets it replaced.

        Builds a module with a single ``bn`` attribute, runs
        ``models.replace_bn`` on it, and checks that ``bn`` is a
        ``GroupNorm`` afterwards.
        """
        class SingleNormNet(torch.nn.Module):
            """Module containing nothing but one batch-norm layer."""

            def __init__(self):
                super(SingleNormNet, self).__init__()
                self.bn = torch.nn.BatchNorm2d(1)

        net = SingleNormNet()
        replacement = TestGroupNormConversion.create_norm()
        models.replace_bn(net, replacement)

        converted = net.bn
        assert isinstance(converted, torch.nn.GroupNorm)
예제 #3
0
    def test_sequential_model(self):
        """BatchNorm inside a Sequential is replaced; the conv is kept.

        Builds a module wrapping ``Sequential(Conv2d, BatchNorm2d)``, runs
        ``models.replace_bn`` on it, and checks the type of each element.
        """
        class ConvNormNet(torch.nn.Module):
            """Module wrapping a conv layer followed by a batch-norm layer."""

            def __init__(self):
                super(ConvNormNet, self).__init__()
                conv = torch.nn.Conv2d(1, 2, 3)
                norm = torch.nn.BatchNorm2d(1)
                self.layers = torch.nn.Sequential(conv, norm)

        net = ConvNormNet()
        replacement = TestGroupNormConversion.create_norm()
        models.replace_bn(net, replacement)

        # Same checks, same order as a single chained condition would use.
        assert isinstance(net.layers[1], torch.nn.GroupNorm)
        assert isinstance(net.layers[0], torch.nn.Conv2d)