Ejemplo n.º 1
0
 def test_replace_basic_case(self):
     """Replace a BatchNorm1d that is itself the root module with a BatchNorm2d."""

     def to_batchnorm2d(_module):
         # The converter ignores the module it is given and builds a fresh one.
         return nn.BatchNorm2d(10)

     original = nn.BatchNorm1d(10)
     replaced = mm.replace_all_modules(original, nn.BatchNorm1d, to_batchnorm2d)
     self.checkModulePresent(replaced, nn.BatchNorm2d)
     self.checkModuleNotPresent(replaced, nn.BatchNorm1d)
Ejemplo n.º 2
0
    def test_replace_sequential_case(self):
        """Replace a Conv2d nested inside Sequential containers with a Linear layer."""
        # Build the layers in the same order as a single nested expression would,
        # so parameter initialisation consumes the RNG identically.
        conv1d = nn.Conv1d(1, 2, 3)
        nested = nn.Sequential(nn.Conv2d(4, 5, 6))
        model = nn.Sequential(conv1d, nested)

        # Each matched Conv2d is swapped for a brand-new Linear(4, 5); the
        # matched module itself is not used by the converter.
        model = mm.replace_all_modules(
            model, nn.Conv2d, lambda _conv2d: nn.Linear(4, 5)
        )
        self.checkModulePresent(model, nn.Linear)
        self.checkModuleNotPresent(model, nn.Conv2d)

    def test_module_modification_replace_example(self):
        """Replacing every BatchNorm2d in a ResNet-18 with Identity.

        Builds ``resnet18()``, checks that ``layer1[0].bn1`` starts out as a
        ``BatchNorm2d``, replaces all ``BatchNorm2d`` modules with
        ``nn.Identity()``, and checks that the same submodule is now an
        ``Identity``.
        """
        # IMPORTANT: When changing this code you also need to update
        # the docstring for opacus.utils.module_modification.replace_all_modules()
        from torchvision.models import resnet18

        model = resnet18()
        # assertIsInstance reports the actual type on failure, unlike
        # assertTrue(isinstance(...)), which only reports "False is not true".
        self.assertIsInstance(model.layer1[0].bn1, nn.BatchNorm2d)

        model = replace_all_modules(model, nn.BatchNorm2d, lambda _: nn.Identity())
        self.assertIsInstance(model.layer1[0].bn1, nn.Identity)