コード例 #1
0
 def test_validate_mix_predicate(self):
     # Mixed predicate (per the fixture name ``self.pred_mix``; presumably
     # "does not require grad OR is not unsupported" -- confirm against the
     # fixture). The BatchNorm parameters are frozen before validating, and
     # the model is expected to pass.
     inspector = mi.ModelInspector("pred1", self.pred_mix)
     conv_layer = nn.Conv2d(1, 1, 1)
     norm_layer = nn.BatchNorm2d(1)
     net = nn.Sequential(conv_layer, norm_layer)
     # norm_layer is the module stored at index 1 of the Sequential.
     for param in norm_layer.parameters():
         param.requires_grad = False
     self.assertTrue(inspector.validate(net))
コード例 #2
0
 def test_validate_negative_predicate_False(self):
     # Negative predicate (e.g. "not unsupported") that evaluates to false:
     # a Conv2d + BatchNorm2d model is expected to fail validation and to
     # leave exactly one entry in ``inspector.violators``.
     inspector = mi.ModelInspector("pred", self.pred_not_unsupported)
     net = nn.Sequential(nn.Conv2d(1, 1, 1), nn.BatchNorm2d(1))
     self.assertFalse(inspector.validate(net))
     self.assertEqual(
         len(inspector.violators),
         1,
         f"violators = {inspector.violators}",
     )
コード例 #3
0
 def test_validate_negative_predicate_ture(self):
     # Negative predicate (e.g. "not unsupported") that evaluates to true:
     # a Conv2d + Linear model is expected to pass validation with an
     # empty ``inspector.violators`` list.
     inspector = mi.ModelInspector("pred1", self.pred_not_unsupported)
     layers = (nn.Conv2d(1, 1, 1), nn.Linear(1, 1))
     net = nn.Sequential(*layers)
     self.assertTrue(inspector.validate(net))
     self.assertEqual(len(inspector.violators), 0)
コード例 #4
0
 def test_validate_positive_predicate_invalid(self):
     # Positive predicate (e.g. "supported") that evaluates to false:
     # a lone Conv1d is expected to fail validation and be reported as the
     # single violator.
     inspector = mi.ModelInspector("pred", self.pred_supported)
     net = nn.Conv1d(1, 1, 1)
     self.assertFalse(inspector.validate(net))
     violator_count = len(inspector.violators)
     self.assertEqual(
         violator_count, 1, f"violators = {inspector.violators}"
     )
コード例 #5
0
    def test_complicated_case(self):
        # Runs two inspectors over resnet50:
        #   * "good_or_bad" accepts Conv2d/Linear or any _BatchNorm module and
        #     is expected to pass with no violators;
        #   * "not_bad" rejects _BatchNorm modules and is expected to fail
        #     with 53 violators (presumably one per BatchNorm layer in
        #     resnet50 -- confirm against the torchvision definition).
        def good(module):
            return isinstance(module, (nn.Conv2d, nn.Linear))

        def bad(module):
            return isinstance(module, nn.modules.batchnorm._BatchNorm)

        def good_or_bad(module):
            return good(module) or bad(module)

        def not_bad(module):
            return not bad(module)

        inspector1 = mi.ModelInspector("good_or_bad", good_or_bad)
        inspector2 = mi.ModelInspector("not_bad", not_bad)
        net = models.resnet50()

        passed = inspector1.validate(net)
        self.assertTrue(passed, f"violators = {inspector1.violators}")
        self.assertEqual(
            len(inspector1.violators),
            0,
            f"violators = {inspector1.violators}",
        )

        passed = inspector2.validate(net)
        self.assertFalse(passed, f"violators = {inspector2.violators}")
        self.assertEqual(
            len(inspector2.violators),
            53,
            f"violators = {inspector2.violators}",
        )
コード例 #6
0
 def test_check_everything_flag(self):
     # Check that a model does not contain nn.Sequential. The inspector is
     # built with check_leaf_nodes_only=False, and validating a Sequential
     # container is expected to fail.
     def not_sequential(module):
         return not isinstance(module, nn.Sequential)

     inspector = mi.ModelInspector(
         "pred", not_sequential, check_leaf_nodes_only=False
     )
     net = nn.Sequential(nn.Conv1d(1, 1, 1))
     passed = inspector.validate(net)
     self.assertFalse(passed, f"violators = {inspector.violators}")
コード例 #7
0
 def test_validate_basic(self):
     # Basic case: the predicate only accepts nn.Linear, so validating a
     # Conv1d is expected to fail; the violators are passed as the
     # assertion message.
     def is_linear(module):
         return isinstance(module, nn.Linear)

     inspector = mi.ModelInspector("pred", is_linear)
     net = nn.Conv1d(1, 1, 1)
     passed = inspector.validate(net)
     self.assertFalse(passed, inspector.violators)