def test_validate_mix_predicate(self):
    """Validate a Conv2d + BatchNorm2d model against the mixed predicate.

    ``self.pred_mix`` is defined outside this method; per the original
    note it accepts a module that does not require grad or is not
    unsupported. The BatchNorm2d parameters are frozen before validating,
    and validation is expected to succeed.
    """
    inspector = mi.ModelInspector("pred1", self.pred_mix)
    model = nn.Sequential(nn.Conv2d(1, 1, 1), nn.BatchNorm2d(1))
    # Freeze only the batch-norm layer (index 1); the conv layer is left
    # untouched.
    for param in model[1].parameters():
        param.requires_grad = False
    valid = inspector.validate(model)
    # Report the offending modules if validation unexpectedly fails, in the
    # same "violators = ..." format the sibling tests use.
    self.assertTrue(valid, f"violators = {inspector.violators}")
def test_validate_negative_predicate_False(self):
    """Negative predicate (e.g. "not unsupported") that rejects the model.

    A Conv2d followed by a BatchNorm2d is validated with
    ``self.pred_not_unsupported`` (defined outside this method). The test
    expects validation to fail and exactly one violator to be recorded.
    """
    checker = mi.ModelInspector("pred", self.pred_not_unsupported)
    net = nn.Sequential(nn.Conv2d(1, 1, 1), nn.BatchNorm2d(1))
    self.assertFalse(checker.validate(net))
    self.assertEqual(
        len(checker.violators),
        1,
        f"violators = {checker.violators}",
    )
def test_validate_negative_predicate_ture(self):
    """Negative predicate (e.g. "not unsupported") that accepts the model.

    A Conv2d followed by a Linear layer is validated with
    ``self.pred_not_unsupported`` (defined outside this method). The test
    expects validation to succeed and no violators to be recorded.
    """
    inspector = mi.ModelInspector("pred1", self.pred_not_unsupported)
    model = nn.Sequential(nn.Conv2d(1, 1, 1), nn.Linear(1, 1))
    valid = inspector.validate(model)
    # Attach the violator list to both assertions so a failure names the
    # offending modules, matching the diagnostics in the sibling tests.
    details = f"violators = {inspector.violators}"
    self.assertTrue(valid, details)
    self.assertEqual(len(inspector.violators), 0, details)
def test_validate_positive_predicate_invalid(self):
    """Positive predicate (e.g. "supported") that rejects the model.

    A lone Conv1d is validated with ``self.pred_supported`` (defined
    outside this method). The test expects validation to fail and exactly
    one violator to be recorded.
    """
    checker = mi.ModelInspector("pred", self.pred_supported)
    net = nn.Conv1d(1, 1, 1)
    self.assertFalse(checker.validate(net))
    self.assertEqual(
        len(checker.violators),
        1,
        f"violators = {checker.violators}",
    )
def test_complicated_case(self):
    """Run two inspectors with opposite expectations over a ResNet-50.

    The first predicate accepts Conv2d, Linear and batch-norm modules and
    is expected to produce no violators. The second predicate rejects
    batch-norm modules and is expected to fail with 53 violators. Which
    modules the inspector actually hands to the predicate is decided by
    ``ModelInspector`` and is not visible from this test.
    """

    def is_conv_or_linear(module):
        return isinstance(module, (nn.Conv2d, nn.Linear))

    def is_batchnorm(module):
        return isinstance(module, nn.modules.batchnorm._BatchNorm)

    permissive = mi.ModelInspector(
        "good_or_bad",
        lambda x: is_conv_or_linear(x) or is_batchnorm(x),
    )
    strict = mi.ModelInspector("not_bad", lambda x: not is_batchnorm(x))
    net = models.resnet50()

    # Phase 1: the permissive predicate should accept the whole network.
    accepted = permissive.validate(net)
    self.assertTrue(accepted, f"violators = {permissive.violators}")
    self.assertEqual(
        len(permissive.violators),
        0,
        f"violators = {permissive.violators}",
    )

    # Phase 2: the strict predicate should flag the batch-norm modules.
    # NOTE(review): 53 is presumably the number of batch-norm layers in
    # resnet50 -- confirm against torchvision.
    accepted = strict.validate(net)
    self.assertFalse(accepted, f"violators = {strict.violators}")
    self.assertEqual(
        len(strict.violators),
        53,
        f"violators = {strict.violators}",
    )
def test_check_everything_flag(self):
    """Check that a model containing an nn.Sequential is rejected.

    The inspector is built with ``check_leaf_nodes_only=False`` and a
    predicate that fails for ``nn.Sequential`` instances. Validation of a
    Sequential wrapping a single Conv1d is expected to fail.
    NOTE(review): presumably the flag makes the inspector visit container
    modules as well as leaves -- confirm against ModelInspector.
    """
    checker = mi.ModelInspector(
        "pred",
        lambda module: not isinstance(module, nn.Sequential),
        check_leaf_nodes_only=False,
    )
    net = nn.Sequential(nn.Conv1d(1, 1, 1))
    outcome = checker.validate(net)
    self.assertFalse(outcome, f"violators = {checker.violators}")
def test_validate_basic(self):
    """Validate a lone Conv1d against a Linear-only predicate.

    The predicate accepts only ``nn.Linear`` instances, so validation of a
    Conv1d model is expected to fail.
    """
    inspector = mi.ModelInspector(
        "pred",
        lambda model: isinstance(model, nn.Linear),
    )
    model = nn.Conv1d(1, 1, 1)
    valid = inspector.validate(model)
    # Use the same "violators = ..." failure message as the sibling tests
    # rather than passing the raw list as the assertion message.
    self.assertFalse(valid, f"violators = {inspector.violators}")