def test_running_stats(self):
    """Validate InstanceNorm1d under each affine / running-stats combination.

    The inspector is built with ``should_throw=False`` so ``validate``
    reports its verdict as a boolean. Only the configuration enabling both
    ``affine`` and ``track_running_stats`` is expected to be rejected.
    """
    checker = dp_inspector.DPModelInspector(should_throw=False)

    plain = nn.InstanceNorm1d(16)
    self.assertTrue(checker.validate(plain))

    affine_only = nn.InstanceNorm1d(16, affine=True)
    self.assertTrue(checker.validate(affine_only))

    stats_only = nn.InstanceNorm1d(16, track_running_stats=True)
    self.assertTrue(checker.validate(stats_only))

    affine_with_stats = nn.InstanceNorm1d(
        16, affine=True, track_running_stats=True
    )
    self.assertFalse(checker.validate(affine_with_stats))
def test_unsupported_layer(self):
    """A network containing ``nn.Transformer`` is expected to fail validation."""

    class SampleNetWithTransformer(nn.Module):
        # Minimal network: a linear layer followed by a Transformer.
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(8, 16)
            self.encoder = nn.Transformer()

        def forward(self, x):
            return self.encoder(self.fc(x))

    net = SampleNetWithTransformer()
    # should_throw=False: the verdict comes back as a boolean.
    checker = dp_inspector.DPModelInspector(should_throw=False)
    is_valid = checker.validate(net)
    self.assertFalse(is_valid)
def test_conv2d(self):
    """Validate 1x1 Conv2d layers under three ``groups`` settings.

    With 3 input channels, ``groups=1`` and ``groups=3`` are expected to
    pass; with 6 input channels, ``groups=2`` is expected to be rejected.
    All layers use 6 output channels.
    """
    checker = dp_inspector.DPModelInspector(should_throw=False)

    def conv_is_valid(in_channels, groups):
        # Build the layer under test and return the inspector's verdict.
        layer = nn.Conv2d(
            in_channels=in_channels,
            out_channels=6,
            kernel_size=1,
            groups=groups,
        )
        return checker.validate(layer)

    self.assertTrue(conv_is_valid(3, 1))
    self.assertTrue(conv_is_valid(3, 3))
    self.assertFalse(conv_is_valid(6, 2))
def test_extra_param(self):
    """A parameter held directly on the module affects validation.

    The model is expected to be rejected while ``extra_param`` requires
    grad, and accepted once ``requires_grad`` is switched off.
    """

    class SampleNetWithExtraParam(nn.Module):
        # Linear layer plus a raw nn.Parameter attached to the module itself.
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(8, 16)
            self.extra_param = nn.Parameter(torch.Tensor(16, 2))

        def forward(self, x):
            return self.fc(x).matmul(self.extra_param)

    checker = dp_inspector.DPModelInspector(should_throw=False)
    net = SampleNetWithExtraParam()

    # Extra parameter still requires grad.
    self.assertFalse(checker.validate(net))

    # Freeze the extra parameter and validate again.
    net.extra_param.requires_grad = False
    self.assertTrue(checker.validate(net))
def test_convert_batchnorm(self):
    """ResNet-50 passed through ``convert_batchnorm_modules`` should validate."""
    checker = dp_inspector.DPModelInspector()
    resnet = models.resnet50()
    converted = convert_batchnorm_modules(resnet)
    verdict = checker.validate(converted)
    self.assertTrue(verdict)
def test_returns_False(self):
    """With ``should_throw=False``, an unconverted ResNet-50 yields a falsy verdict."""
    resnet = models.resnet50()
    checker = dp_inspector.DPModelInspector(should_throw=False)
    verdict = checker.validate(resnet)
    self.assertFalse(verdict)
def test_raises_exception(self):
    """A default-constructed inspector raises on an unconverted ResNet-50."""
    checker = dp_inspector.DPModelInspector()
    resnet = models.resnet50()
    # Callable form of assertRaises: invokes checker.validate(resnet) and
    # expects IncompatibleModuleException to propagate.
    self.assertRaises(
        dp_inspector.IncompatibleModuleException,
        checker.validate,
        resnet,
    )