def test_rai_insights_cancer(self, manager_type):
    """Run ``run_rai_insights`` on the cancer data for every created model.

    :param manager_type: Manager selector supplied by the test
        parametrization; forwarded unchanged to ``run_rai_insights``.
    """
    train_df, test_df, train_labels, test_labels, _, class_names = \
        create_cancer_data()

    # The models are created before the label column is added to the
    # frames below.
    classifiers = create_models_classification(train_df, train_labels)

    train_df[LABELS] = train_labels
    test_df[LABELS] = test_labels

    manager_kwargs = {
        ManagerParams.DESIRED_CLASS: 'opposite',
        ManagerParams.FEATURE_IMPORTANCE: False,
    }

    for classifier in classifiers:
        run_rai_insights(classifier, train_df, test_df, LABELS, None,
                         manager_type, manager_kwargs, class_names)
def test_no_model_but_serializer_provided(self):
    """Expect a validation error when a serializer is passed with no model."""
    train_df, test_df, train_labels, test_labels, _, _ = \
        create_cancer_data()
    train_df[TARGET] = train_labels
    test_df[TARGET] = test_labels

    expected_message = \
        'No valid model is specified but model serializer provided.'

    with pytest.raises(UserConfigValidationException) as exc_info:
        RAIInsights(model=None, train=train_df, test=test_df,
                    target_column=TARGET, task_type='classification',
                    serializer={})

    # Checked after the context manager exits; exc_info.value is the
    # captured exception.
    assert expected_message in str(exc_info.value)
def test_mismatch_train_test_features(self):
    """Expect a validation error when train and test columns differ."""
    train_df, test_df, train_labels, test_labels, _, _ = \
        create_cancer_data()
    classifier = create_lightgbm_classifier(train_df, train_labels)

    # Train receives the TARGET column while test receives a differently
    # named column, so the two frames no longer share the same columns.
    train_df[TARGET] = train_labels
    test_df['bad_target'] = test_labels

    expected_message = 'The features in train and test data do not match'

    with pytest.raises(UserConfigValidationException) as exc_info:
        RAIInsights(model=classifier, train=train_df, test=test_df,
                    target_column=TARGET, task_type='classification')

    assert expected_message in str(exc_info.value)
def test_classes_passes(self):
    """Check that the classes stored on a constructed RAIInsights are sorted."""
    train_df, test_df, train_labels, test_labels, _, _ = \
        create_cancer_data()
    classifier = create_lightgbm_classifier(train_df, train_labels)
    train_df[TARGET] = train_labels
    test_df[TARGET] = test_labels

    insights = RAIInsights(model=classifier, train=train_df, test=test_df,
                           target_column=TARGET,
                           task_type='classification')

    # validate classes are always sorted: every element must be <= its
    # successor.
    stored_classes = insights._classes
    predecessors = stored_classes[:-1]
    successors = stored_classes[1:]
    assert np.all(predecessors <= successors)
def test_classes_exceptions(self):
    """Expect validation errors when ``classes`` disagrees with the targets.

    Three cases run in order: ``classes=[0, 1, 2]`` against the unmodified
    labels, ``classes=[0, 1]`` after a train label is overwritten with 2,
    and ``classes=[0, 1]`` after the train label is reset to 0 and a test
    label is overwritten with 2.
    """
    train_df, test_df, train_labels, test_labels, _, _ = \
        create_cancer_data()
    classifier = create_lightgbm_classifier(train_df, train_labels)

    def refresh_targets():
        # Re-attach the (possibly mutated) label vectors to both frames.
        train_df[TARGET] = train_labels
        test_df[TARGET] = test_labels

    def expect_rejection(classes, expected_message):
        # Construction must fail and the message must contain the text.
        with pytest.raises(UserConfigValidationException) as exc_info:
            RAIInsights(model=classifier, train=train_df, test=test_df,
                        target_column=TARGET, task_type='classification',
                        classes=classes)
        assert expected_message in str(exc_info.value)

    train_mismatch = ('The train labels and distinct values in target '
                      '(train data) do not match')
    test_mismatch = ('The train labels and distinct values in target '
                     '(test data) do not match')

    # Case 1: unmodified labels, classes=[0, 1, 2].
    refresh_targets()
    expect_rejection([0, 1, 2], train_mismatch)

    # Case 2: a train label is overwritten with 2, classes=[0, 1].
    train_labels[0] = 2
    refresh_targets()
    expect_rejection([0, 1], train_mismatch)

    # Case 3: train label reset to 0, a test label overwritten with 2.
    train_labels[0] = 0
    test_labels[0] = 2
    refresh_targets()
    expect_rejection([0, 1], test_mismatch)
def test_unsupported_train_test_types(self):
    """Expect a validation error when train/test are passed as ``.values``."""
    train_df, test_df, train_labels, test_labels, _, _ = \
        create_cancer_data()
    classifier = create_lightgbm_classifier(train_df, train_labels)
    train_df[TARGET] = train_labels
    test_df[TARGET] = test_labels

    expected_message = ("Unsupported data type for either train or test. "
                        "Expecting pandas DataFrame for train and test.")

    # Pass the underlying ``.values`` instead of the frames themselves.
    raw_train = train_df.values
    raw_test = test_df.values

    with pytest.raises(UserConfigValidationException) as exc_info:
        RAIInsights(model=classifier, train=raw_train, test=raw_test,
                    target_column=TARGET, task_type='classification')

    assert expected_message in str(exc_info.value)
def test_model_analysis_incorrect_task_type(self):
    """Expect a validation error for a classifier with task_type='regression'.

    A LightGBM classifier is passed to ``RAIInsights`` with
    ``task_type='regression'``; construction must raise
    ``UserConfigValidationException`` carrying the expected message.
    """
    # Local import keeps this method self-contained.
    import re

    X_train, X_test, y_train, y_test, _, _ = \
        create_cancer_data()
    model = create_lightgbm_classifier(X_train, y_train)
    X_train[TARGET] = y_train
    X_test[TARGET] = y_test

    # NOTE(review): there is no space between 'model' and 'provided';
    # presumably this mirrors the message raised by the library -- confirm
    # before changing either side.
    err_msg = ('The regression model'
               'provided has a predict_proba function. '
               'Please check the task_type.')

    # ``match`` is applied as a regular expression, so the message is
    # escaped to make '.' match only a literal dot.
    with pytest.raises(UserConfigValidationException,
                       match=re.escape(err_msg)):
        RAIInsights(model=model, train=X_train, test=X_test,
                    target_column=TARGET, task_type='regression')
def test_model_predictions_predict(self):
    """Expect a validation error when the model's ``predict`` raises."""
    train_df, test_df, train_labels, test_labels, _, _ = \
        create_cancer_data()
    train_df[TARGET] = train_labels
    test_df[TARGET] = test_labels

    # A mock whose predict() raises on every call.
    failing_model = MagicMock()
    failing_model.predict.side_effect = Exception()

    expected_message = ('The model passed cannot be used for getting '
                        'predictions via predict()')

    with pytest.raises(UserConfigValidationException) as exc_info:
        RAIInsights(model=failing_model, train=train_df, test=test_df,
                    target_column=TARGET, task_type='classification')

    assert expected_message in str(exc_info.value)
def test_feature_metadata(self):
    """Expect a validation error for an identity feature absent from the data.

    ``FeatureMetadata`` is built with ``identity_feature_name='id'`` and
    passed to ``RAIInsights``; construction must raise
    ``UserConfigValidationException`` carrying the expected message.
    """
    # Local import keeps this method self-contained.
    import re

    from responsibleai.feature_metadata import FeatureMetadata

    X_train, X_test, y_train, y_test, _, _ = \
        create_cancer_data()
    model = create_lightgbm_classifier(X_train, y_train)
    X_train[TARGET] = y_train
    X_test[TARGET] = y_test

    feature_metadata = FeatureMetadata(identity_feature_name='id')

    err_msg = ('The given identity feature name id is not present'
               ' in user features.')

    # ``match`` is applied as a regular expression, so the message is
    # escaped to make the trailing '.' match only a literal dot.
    with pytest.raises(UserConfigValidationException,
                       match=re.escape(err_msg)):
        RAIInsights(model=model, train=X_train, test=X_test,
                    target_column=TARGET, task_type='classification',
                    feature_metadata=feature_metadata)
def test_validate_serializer(self):
    """Expect validation errors for three kinds of unusable serializer.

    Cases, in order: a serializer with ``load`` but no ``save``, one with
    ``save`` but no ``load``, and one with both methods that is constructed
    with a logger and is expected to be rejected as not serializable via
    pickle.

    The serializer classes and instances are built outside the
    ``pytest.raises`` blocks so that only the ``RAIInsights`` call can
    satisfy the expected exception.
    """
    X_train, X_test, y_train, y_test, _, _ = \
        create_cancer_data()
    model = create_lightgbm_classifier(X_train, y_train)
    X_train[TARGET] = y_train
    X_test[TARGET] = y_test

    # Case 1: serializer lacking save().
    class LoadOnlySerializer:
        def __init__(self, logger=None):
            self._logger = logger

        def load(self):
            pass

    serializer = LoadOnlySerializer()
    with pytest.raises(UserConfigValidationException) as ucve:
        RAIInsights(model=model, train=X_train, test=X_test,
                    target_column=TARGET, task_type='classification',
                    serializer=serializer)
    assert 'The serializer does not implement save()' in str(ucve.value)

    # Case 2: serializer lacking load().
    class SaveOnlySerializer:
        def __init__(self, logger=None):
            self._logger = logger

        def save(self):
            pass

    serializer = SaveOnlySerializer()
    with pytest.raises(UserConfigValidationException) as ucve:
        RAIInsights(model=model, train=X_train, test=X_test,
                    target_column=TARGET, task_type='classification',
                    serializer=serializer)
    assert 'The serializer does not implement load()' in str(ucve.value)

    # Case 3: both methods present, instance constructed with a logger.
    class Serializer:
        def __init__(self, logger=None):
            self._logger = logger

        def save(self):
            pass

        def load(self):
            pass

    serializer = Serializer(logger=logging.getLogger('some logger'))
    with pytest.raises(UserConfigValidationException) as ucve:
        RAIInsights(model=model, train=X_train, test=X_test,
                    target_column=TARGET, task_type='classification',
                    serializer=serializer)
    assert 'The serializer should be serializable via pickle' in \
        str(ucve.value)