def test_running_stats(self):
        """Check InstanceNorm1d against every affine / track_running_stats combination.

        Of the four combinations exercised below, only affine=True together
        with track_running_stats=True is reported as invalid.
        """
        validate = dp_inspector.DPModelInspector(should_throw=False).validate

        self.assertTrue(validate(nn.InstanceNorm1d(16)))
        self.assertTrue(validate(nn.InstanceNorm1d(16, affine=True)))
        # NOTE(review): with affine=False the layer has no parameters, which is
        # presumably why running stats alone are tolerated — confirm in the
        # inspector's predicate.
        self.assertTrue(validate(nn.InstanceNorm1d(16, track_running_stats=True)))
        self.assertFalse(
            validate(nn.InstanceNorm1d(16, affine=True, track_running_stats=True))
        )

    def test_unsupported_layer(self):
        """A model that contains an nn.Transformer submodule fails validation."""

        class SampleNetWithTransformer(nn.Module):
            def __init__(self):
                super().__init__()
                self.fc = nn.Linear(8, 16)
                self.encoder = nn.Transformer()

            def forward(self, x):
                # NOTE(review): validate() is handed only the model (no input),
                # so this is presumably never run. As written it would fail
                # anyway: nn.Transformer.forward requires (src, tgt), and the
                # 16-dim Linear output does not match Transformer's default
                # d_model.
                return self.encoder(self.fc(x))

        net = SampleNetWithTransformer()
        verdict = dp_inspector.DPModelInspector(should_throw=False).validate(net)
        self.assertFalse(verdict)

    def test_conv2d(self):
        """Check Conv2d validation across different ``groups`` settings.

        groups=1 and groups == in_channels both pass; a 6-channel
        convolution split into 2 groups is reported as invalid.
        """
        validate = dp_inspector.DPModelInspector(should_throw=False).validate

        # Ordinary (ungrouped) convolution.
        self.assertTrue(
            validate(nn.Conv2d(in_channels=3, out_channels=6, kernel_size=1, groups=1))
        )
        # One group per input channel.
        self.assertTrue(
            validate(nn.Conv2d(in_channels=3, out_channels=6, kernel_size=1, groups=3))
        )
        # Grouped, but with more than one input channel per group.
        self.assertFalse(
            validate(nn.Conv2d(in_channels=6, out_channels=6, kernel_size=1, groups=2))
        )

    def test_extra_param(self):
        """A trainable nn.Parameter attached directly to a module fails validation.

        The same model passes once that parameter stops requiring gradients,
        so the check depends on ``requires_grad``.
        """
        inspector = dp_inspector.DPModelInspector(should_throw=False)

        class SampleNetWithExtraParam(nn.Module):
            def __init__(self):
                super().__init__()

                self.fc = nn.Linear(8, 16)
                # torch.empty is the documented replacement for the legacy
                # torch.Tensor(*sizes) constructor. Both allocate uninitialized
                # storage; the values are presumably never read because
                # validate() is handed only the model, not an input.
                self.extra_param = nn.Parameter(torch.empty(16, 2))

            def forward(self, x):
                x = self.fc(x)
                x = x.matmul(self.extra_param)
                return x

        model = SampleNetWithExtraParam()
        self.assertFalse(inspector.validate(model))

        # Freezing the extra parameter makes the same model acceptable.
        model.extra_param.requires_grad = False
        self.assertTrue(inspector.validate(model))

 def test_convert_batchnorm(self):
     """resnet50 validates after being passed through convert_batchnorm_modules.

     Uses the default inspector. For the unconverted resnet50, that inspector
     raises IncompatibleModuleException rather than returning False, so a
     regression here presumably surfaces as an error, not an assertion failure.
     """
     strict_inspector = dp_inspector.DPModelInspector()
     backbone = models.resnet50()
     converted = convert_batchnorm_modules(backbone)
     self.assertTrue(strict_inspector.validate(converted))

 def test_returns_False(self):
     """With should_throw=False, unconverted resnet50 is reported as invalid (False)."""
     lenient_inspector = dp_inspector.DPModelInspector(should_throw=False)
     unconverted = models.resnet50()
     is_valid = lenient_inspector.validate(unconverted)
     self.assertFalse(is_valid)

 def test_raises_exception(self):
     """The default inspector raises IncompatibleModuleException for unconverted resnet50."""
     strict_inspector = dp_inspector.DPModelInspector()
     unconverted = models.resnet50()
     self.assertRaises(
         dp_inspector.IncompatibleModuleException,
         strict_inspector.validate,
         unconverted,
     )