class LogisticRegressionKernelTests(unittest.TestCase): def setUp(self): X_train, y_train, X_test, y_test = titanic_survive() train_names, test_names = titanic_names() model = LogisticRegression() model.fit(X_train, y_train) self.explainer = ClassifierExplainer( model, X_test.iloc[:20], y_test.iloc[:20], shap='kernel', model_output='probability', X_background=shap.sample(X_train, 5), cats=[{ 'Gender': ['Sex_female', 'Sex_male', 'Sex_nan'] }, 'Deck', 'Embarked'], labels=['Not survived', 'Survived']) def test_shap_values(self): self.assertIsInstance(self.explainer.shap_base_value(), (np.floating, float)) self.assertTrue(self.explainer.get_shap_values_df().shape == ( len(self.explainer), len(self.explainer.merged_cols))) self.assertIsInstance(self.explainer.get_shap_values_df(), pd.DataFrame)
def setUp(self):
    """Fit a preprocessing + RandomForest pipeline on the CSV asset and wrap it in an explainer.

    The CSV is resolved relative to the current working directory
    (tests/test_assets/pipeline_data.csv).
    """
    csv_path = Path.cwd() / "tests" / "test_assets" / "pipeline_data.csv"
    data = pd.read_csv(csv_path)

    numeric_features = ['age', 'fare']
    categorical_features = ['embarked', 'sex', 'pclass']

    X = data[numeric_features + categorical_features]
    y = data['survived'].astype(int)

    # numeric columns: median imputation followed by standard scaling
    numeric_transformer = Pipeline(steps=[
        ('imputer', SimpleImputer(strategy='median')),
        ('scaler', StandardScaler()),
    ])
    # categorical columns: most-frequent imputation followed by ordinal encoding
    categorical_transformer = Pipeline(steps=[
        ('imputer', SimpleImputer(strategy='most_frequent')),
        ('ordinal', OrdinalEncoder()),
    ])
    preprocessor = ColumnTransformer(transformers=[
        ('num', numeric_transformer, numeric_features),
        ('cat', categorical_transformer, categorical_features),
    ])

    # preprocessing and classifier chained into one prediction pipeline
    clf = Pipeline(steps=[
        ('preprocessor', preprocessor),
        ('classifier', RandomForestClassifier()),
    ])

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
    clf.fit(X_train, y_train)

    self.explainer = ClassifierExplainer(clf, X_test, y_test)
def setUp(self):
    """Build an explainer on the first 100 test rows and serve the remaining rows via callbacks.

    The test frame is re-indexed with string positions; rows from position
    100 onward are only reachable through the registered lookup functions.
    """
    X_train, y_train, X_test, y_test = titanic_survive()
    forest = RandomForestClassifier(n_estimators=50, max_depth=4)
    model = forest.fit(X_train, y_train)

    X_test.reset_index(drop=True, inplace=True)
    X_test.index = X_test.index.astype(str)

    X_loaded, y_loaded = X_test.iloc[:100], y_test.iloc[:100]
    X_external, y_external = X_test.iloc[100:], y_test.iloc[100:]

    self.explainer = ClassifierExplainer(
        model, X_loaded, y_loaded, cats=['Sex', 'Deck'])

    def index_exists_func(index):
        return index in X_external.index

    def index_list_func():
        # only the first 50 external indexes are listed
        return list(X_external.index[:50])

    def y_func(index):
        return y_external.iloc[[X_external.index.get_loc(index)]]

    def X_func(index):
        return X_external.iloc[[X_external.index.get_loc(index)]]

    self.explainer.set_index_exists_func(index_exists_func)
    self.explainer.set_index_list_func(index_list_func)
    self.explainer.set_X_row_func(X_func)
    self.explainer.set_y_func(y_func)
def get_multiclass_explainer(xgboost=False, include_y=True):
    """Return a three-label ClassifierExplainer on the titanic_embarked data.

    Args:
        xgboost: when true, fit an XGBClassifier and pass
            model_output='logodds' to the explainer; otherwise fit a
            RandomForestClassifier(n_estimators=50, max_depth=10).
        include_y: when true, pass y_test to the explainer; otherwise
            construct it from X_test only.

    Returns:
        The explainer, after calculate_properties() has been called on it.
    """
    X_train, y_train, X_test, y_test = titanic_embarked()

    if xgboost:
        model = XGBClassifier().fit(X_train, y_train)
    else:
        model = RandomForestClassifier(
            n_estimators=50, max_depth=10).fit(X_train, y_train)

    # y is passed positionally only when requested
    data_args = (X_test, y_test) if include_y else (X_test,)

    explainer_kwargs = dict(
        cats=['Sex', 'Deck'],
        labels=['Queenstown', 'Southampton', 'Cherbourg'])
    if xgboost:
        explainer_kwargs['model_output'] = 'logodds'

    multi_explainer = ClassifierExplainer(
        model, *data_args, **explainer_kwargs)
    multi_explainer.calculate_properties()
    return multi_explainer
def setUp(self):
    """Fit a small random forest and build an explainer with n_jobs=-1."""
    X_train, y_train, X_test, y_test = titanic_survive()

    model = RandomForestClassifier(n_estimators=5, max_depth=2)
    model.fit(X_train, y_train)

    # roc_auc_score is passed as the fourth positional argument
    self.explainer = ClassifierExplainer(
        model, X_test, y_test, roc_auc_score, n_jobs=-1)
def setUp(self):
    """Build an explainer with cv=3 over the first 50 training rows."""
    X_train, y_train, _X_test, _y_test = titanic_survive()

    forest = RandomForestClassifier(n_estimators=5, max_depth=2)
    forest.fit(X_train, y_train)

    merged_cats = [
        {'Gender': ['Sex_female', 'Sex_male', 'Sex_nan']},
        'Deck',
        'Embarked',
    ]
    self.explainer = ClassifierExplainer(
        forest,
        X_train.iloc[:50],
        y_train.iloc[:50],
        cats=merged_cats,
        cv=3,
    )
def setUp(self):
    """Fit a default XGBClassifier and build a labelled explainer on the test split."""
    X_train, y_train, X_test, y_test = titanic_survive()

    model = XGBClassifier()
    model.fit(X_train, y_train)

    self.explainer = ClassifierExplainer(
        model, X_test, y_test,
        cats=[{'Gender': ['Sex_female', 'Sex_male', 'Sex_nan']},
              'Deck', 'Embarked'],
        labels=['Not survived', 'Survived'])
def setUp(self):
    """Fit a small random forest, build a labelled explainer and keep the test names."""
    X_train, y_train, X_test, y_test = titanic_survive()
    # single load: only the second element of titanic_names() is kept
    _, self.names = titanic_names()

    model = RandomForestClassifier(n_estimators=5, max_depth=2)
    model.fit(X_train, y_train)

    self.explainer = ClassifierExplainer(
        model, X_test, y_test,
        cats=['Sex', 'Deck', 'Embarked'],
        labels=['Not survived', 'Survived'])
def setUp(self):
    """Fit a small random forest on titanic_embarked and build a three-label explainer."""
    X_train, y_train, X_test, y_test = titanic_embarked()
    _, test_names = titanic_names()

    forest = RandomForestClassifier(n_estimators=5, max_depth=2)
    forest.fit(X_train, y_train)

    explainer_options = dict(
        cats=[{'Gender': ['Sex_female', 'Sex_male', 'Sex_nan']}, 'Deck'],
        idxs=test_names,
        labels=['Queenstown', 'Southampton', 'Cherbourg'],
    )
    self.explainer = ClassifierExplainer(
        forest, X_test, y_test, **explainer_options)
class XGBMultiClassifierExplainerTests(unittest.TestCase): def setUp(self): X_train, y_train, X_test, y_test = titanic_embarked() train_names, test_names = titanic_names() _, self.names = titanic_names() model = XGBClassifier(n_estimators=5) model.fit(X_train, y_train) self.explainer = ClassifierExplainer( model, X_test, y_test, model_output='raw', cats=[{ 'Gender': ['Sex_female', 'Sex_male', 'Sex_nan'] }, 'Deck'], idxs=test_names, labels=['Queenstown', 'Southampton', 'Cherbourg']) def test_graphviz_available(self): self.assertIsInstance(self.explainer.graphviz_available, bool) def test_shadow_trees(self): dt = self.explainer.shadow_trees self.assertIsInstance(dt, list) self.assertIsInstance( dt[0], dtreeviz.models.shadow_decision_tree.ShadowDecTree) def test_decisionpath_df(self): df = self.explainer.get_decisionpath_df(tree_idx=0, index=0) self.assertIsInstance(df, pd.DataFrame) df = self.explainer.get_decisionpath_df(tree_idx=0, index=self.names[0]) self.assertIsInstance(df, pd.DataFrame) df = self.explainer.get_decisionpath_df(tree_idx=0, index=self.names[0], pos_label=0) self.assertIsInstance(df, pd.DataFrame) def test_plot_trees(self): fig = self.explainer.plot_trees(index=0) self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_trees(index=self.names[0]) self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_trees(index=self.names[0], highlight_tree=0) self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_trees(index=self.names[0], pos_label=0) self.assertIsInstance(fig, go.Figure) def test_calculate_properties(self): self.explainer.calculate_properties()
def setUp(self):
    """Train a CatBoostClassifier and build a shap='tree' explainer indexed by the test names."""
    X_train, y_train, X_test, y_test = titanic_survive()
    _, test_names = titanic_names()

    booster = CatBoostClassifier(iterations=100, learning_rate=0.1, verbose=0)
    booster.fit(X_train, y_train)

    explainer_options = dict(
        shap='tree',
        cats=['Sex', 'Cabin', 'Embarked'],
        labels=['Not survived', 'Survived'],
        idxs=test_names,
    )
    # roc_auc_score stays the fourth positional argument
    self.explainer = ClassifierExplainer(
        booster, X_test, y_test, roc_auc_score, **explainer_options)
def setUp(self):
    """Train a default LGBMClassifier and build a shap='tree' explainer indexed by the test names."""
    X_train, y_train, X_test, y_test = titanic_survive()
    _, test_names = titanic_names()

    booster = LGBMClassifier()
    booster.fit(X_train, y_train)

    explainer_options = dict(
        shap='tree',
        cats=['Sex', 'Cabin', 'Embarked'],
        labels=['Not survived', 'Survived'],
        idxs=test_names,
    )
    # roc_auc_score stays the fourth positional argument
    self.explainer = ClassifierExplainer(
        booster, X_test, y_test, roc_auc_score, **explainer_options)
def setUp(self):
    """Train a CatBoostClassifier and build a labelled explainer with merged categorical columns."""
    X_train, y_train, X_test, y_test = titanic_survive()
    _, test_names = titanic_names()

    booster = CatBoostClassifier(iterations=100, learning_rate=0.1, verbose=0)
    booster.fit(X_train, y_train)

    merged_cats = [
        {'Gender': ['Sex_female', 'Sex_male', 'Sex_nan']},
        'Deck',
        'Embarked',
    ]
    self.explainer = ClassifierExplainer(
        booster,
        X_test,
        y_test,
        cats=merged_cats,
        labels=['Not survived', 'Survived'],
        idxs=test_names,
    )
def setUp(self):
    """Fit a 5-estimator XGBClassifier and build a labelled explainer indexed by the test names."""
    X_train, y_train, X_test, y_test = titanic_survive()
    # load the names once; they index the explainer and are kept on the test case
    _, test_names = titanic_names()
    self.names = test_names

    model = XGBClassifier(n_estimators=5)
    model.fit(X_train, y_train)

    self.explainer = ClassifierExplainer(
        model, X_test, y_test,
        cats=['Sex', 'Cabin', 'Embarked'],
        idxs=test_names,
        labels=['Not survived', 'Survived'])
def setUp(self):
    """Fit a LogisticRegression and build a shap='linear' explainer indexed by the test names."""
    X_train, y_train, X_test, y_test = titanic_survive()
    _, test_names = titanic_names()

    regression = LogisticRegression()
    regression.fit(X_train, y_train)

    explainer_options = dict(
        shap='linear',
        cats=['Sex', 'Deck', 'Embarked'],
        labels=['Not survived', 'Survived'],
        idxs=test_names,
    )
    self.explainer = ClassifierExplainer(
        regression, X_test, y_test, **explainer_options)
def setUp(self):
    """Fit a small random forest and build an explainer with target='Survival'."""
    X_train, y_train, X_test, y_test = titanic_survive()
    _, test_names = titanic_names()

    forest = RandomForestClassifier(n_estimators=5, max_depth=2)
    forest.fit(X_train, y_train)

    merged_cats = [
        {'Gender': ['Sex_female', 'Sex_male', 'Sex_nan']},
        'Deck',
        'Embarked',
    ]
    self.explainer = ClassifierExplainer(
        forest,
        X_test,
        y_test,
        cats=merged_cats,
        target='Survival',
        labels=['Not survived', 'Survived'],
        idxs=test_names,
    )
def setUp(self):
    """Fit a LogisticRegression and build a shap='kernel' explainer on 20 test rows."""
    X_train, y_train, X_test, y_test = titanic_survive()

    model = LogisticRegression()
    model.fit(X_train, y_train)

    self.explainer = ClassifierExplainer(
        model, X_test.iloc[:20], y_test.iloc[:20],
        shap='kernel',
        model_output='probability',
        # five sampled training rows serve as the kernel background set
        X_background=shap.sample(X_train, 5),
        cats=[{'Gender': ['Sex_female', 'Sex_male', 'Sex_nan']},
              'Deck', 'Embarked'],
        labels=['Not survived', 'Survived'])
def setUp(self):
    """Fit a 5-estimator XGBClassifier on titanic_embarked with model_output='raw'."""
    X_train, y_train, X_test, y_test = titanic_embarked()
    # load the names once; they index the explainer and are kept on the test case
    _, test_names = titanic_names()
    self.names = test_names

    model = XGBClassifier(n_estimators=5)
    model.fit(X_train, y_train)

    self.explainer = ClassifierExplainer(
        model, X_test, y_test,
        model_output='raw',
        cats=[{'Gender': ['Sex_female', 'Sex_male', 'Sex_nan']},
              'Deck', 'Embarked'],
        idxs=test_names,
        labels=['Queenstown', 'Southampton', 'Cherbourg'])
class ClassifierCVTests(unittest.TestCase):
    """Checks permutation importances and metrics of an explainer built with cv=3."""

    def setUp(self):
        """Build the explainer over the first 50 training rows."""
        X_train, y_train, _X_test, _y_test = titanic_survive()

        forest = RandomForestClassifier(n_estimators=5, max_depth=2)
        forest.fit(X_train, y_train)

        merged_cats = [
            {'Gender': ['Sex_female', 'Sex_male', 'Sex_nan']},
            'Deck',
            'Embarked',
        ]
        self.explainer = ClassifierExplainer(
            forest,
            X_train.iloc[:50],
            y_train.iloc[:50],
            cats=merged_cats,
            cv=3,
        )

    def test_cv_permutation_importances(self):
        default_result = self.explainer.permutation_importances()
        self.assertIsInstance(default_result, pd.DataFrame)

        label0_result = self.explainer.permutation_importances(pos_label=0)
        self.assertIsInstance(label0_result, pd.DataFrame)

    def test_cv_metrics(self):
        default_result = self.explainer.metrics()
        self.assertIsInstance(default_result, dict)

        label0_result = self.explainer.metrics(pos_label=0)
        self.assertIsInstance(label0_result, dict)
class NJobsMinusOneExplainerTests(unittest.TestCase): def setUp(self): X_train, y_train, X_test, y_test = titanic_survive() train_names, test_names = titanic_names() model = RandomForestClassifier(n_estimators=5, max_depth=2) model.fit(X_train, y_train) self.explainer = ClassifierExplainer( model, X_test, y_test, roc_auc_score, n_jobs=-1) def test_permutation_importances(self): self.assertIsInstance(self.explainer.get_permutation_importances_df(), pd.DataFrame)
class ClassifierBunchTests(unittest.TestCase): def setUp(self): X_train, y_train, X_test, y_test = titanic_survive() train_names, test_names = titanic_names() _, self.names = titanic_names() model = RandomForestClassifier(n_estimators=5, max_depth=2) model.fit(X_train, y_train) self.explainer = ClassifierExplainer( model, X_test, y_test, roc_auc_score, shap='tree', cats=['Sex', 'Cabin', 'Embarked'], idxs=test_names, labels=['Not survived', 'Survived']) def test_graphviz_available(self): self.assertIsInstance(self.explainer.graphviz_available, bool) def test_decision_trees(self): dt = self.explainer.decision_trees self.assertIsInstance(dt, list) self.assertIsInstance( dt[0], dtreeviz.models.shadow_decision_tree.ShadowDecTree) def test_decisiontree_df(self): df = self.explainer.decisiontree_df(tree_idx=0, index=0) self.assertIsInstance(df, pd.DataFrame) df = self.explainer.decisiontree_df(tree_idx=0, index=self.names[0]) self.assertIsInstance(df, pd.DataFrame) def test_plot_trees(self): fig = self.explainer.plot_trees(index=0) self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_trees(index=self.names[0]) self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_trees(index=self.names[0], highlight_tree=0) self.assertIsInstance(fig, go.Figure) def test_calculate_properties(self): self.explainer.calculate_properties()
def get_catboost_classifier():
    """Return an explainer for a CatBoostClassifier fit on merged categorical columns.

    A first explainer is built on the test split only to read its
    X_merged and y attributes. A second CatBoostClassifier is then fit on
    those with cat_features=[5, 6, 7] and wrapped in the explainer that is
    returned, with properties calculated and interactions excluded.
    """
    X_train, y_train, X_test, y_test = titanic_survive()
    _, test_names = titanic_names()

    first_model = CatBoostClassifier(iterations=100, verbose=0).fit(
        X_train, y_train)
    first_explainer = ClassifierExplainer(
        first_model,
        X_test,
        y_test,
        cats=[{'Gender': ['Sex_female', 'Sex_male', 'Sex_nan']},
              'Deck', 'Embarked'],
        labels=['Not survived', 'Survived'],
        idxs=test_names,
    )

    X_cats = first_explainer.X_merged
    y_cats = first_explainer.y.astype("int")

    cat_model = CatBoostClassifier(iterations=5, verbose=0).fit(
        X_cats, y_cats, cat_features=[5, 6, 7])

    explainer = ClassifierExplainer(
        cat_model, X_cats, y_cats, idxs=X_test.index)
    explainer.calculate_properties(include_interactions=False)
    return explainer
def get_classification_explainer(include_y=True):
    """Return an XGBClassifier explainer on the titanic_survive test split.

    Args:
        include_y: when true, y_test is passed to the explainer; otherwise
            it is built from X_test only.

    Returns:
        The explainer, after calculate_properties() has been called on it.
    """
    X_train, y_train, X_test, y_test = titanic_survive()
    _, test_names = titanic_names()
    model = XGBClassifier().fit(X_train, y_train)

    shared_options = dict(
        cats=['Sex', 'Cabin', 'Embarked'],
        labels=['Not survived', 'Survived'],
        idxs=test_names,
    )
    if not include_y:
        explainer = ClassifierExplainer(model, X_test, **shared_options)
    else:
        explainer = ClassifierExplainer(
            model, X_test, y_test, **shared_options)

    explainer.calculate_properties()
    return explainer
def get_classification_explainer(xgboost=False, include_y=True):
    """Return a binary ClassifierExplainer on the titanic_survive test split.

    Args:
        xgboost: when true, fit an XGBClassifier; otherwise fit a
            RandomForestClassifier(n_estimators=50, max_depth=10).
        include_y: when true, y_test is passed to the explainer; otherwise
            it is built from X_test only.

    Returns:
        The explainer, after calculate_properties() has been called on it.
    """
    X_train, y_train, X_test, y_test = titanic_survive()

    if not xgboost:
        model = RandomForestClassifier(
            n_estimators=50, max_depth=10).fit(X_train, y_train)
    else:
        model = XGBClassifier().fit(X_train, y_train)

    shared_options = dict(
        cats=['Sex', 'Deck', 'Embarked'],
        labels=['Not survived', 'Survived'],
    )
    if not include_y:
        explainer = ClassifierExplainer(model, X_test, **shared_options)
    else:
        explainer = ClassifierExplainer(
            model, X_test, y_test, **shared_options)

    explainer.calculate_properties()
    return explainer
class ClassifierBaseExplainerTestsPipeline(unittest.TestCase): def setUp(self): #X, y = fetch_openml("titanic", version=1, as_frame=True, return_X_y=True) df = pd.read_csv(Path.cwd() / "tests" / "test_assets" / "pipeline_data.csv") X = df[['age', 'fare', 'embarked', 'sex', 'pclass']] y = df['survived'].astype(int) numeric_features = ['age', 'fare'] numeric_transformer = Pipeline( steps=[('imputer', SimpleImputer( strategy='median')), ('scaler', StandardScaler())]) categorical_features = ['embarked', 'sex', 'pclass'] categorical_transformer = Pipeline( steps=[('imputer', SimpleImputer( strategy='most_frequent')), ('ordinal', OrdinalEncoder())]) preprocessor = ColumnTransformer( transformers=[('num', numeric_transformer, numeric_features), ('cat', categorical_transformer, categorical_features)]) # Append classifier to preprocessing pipeline. # Now we have a full prediction pipeline. clf = Pipeline( steps=[('preprocessor', preprocessor), ('classifier', RandomForestClassifier())]) X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2) clf.fit(X_train, y_train) self.explainer = ClassifierExplainer(clf, X_test, y_test) def test_columns_ranked_by_shap(self): self.assertIsInstance(self.explainer.columns_ranked_by_shap(), list) def test_permutation_importances(self): self.assertIsInstance(self.explainer.get_permutation_importances_df(), pd.DataFrame) def test_metrics(self): self.assertIsInstance(self.explainer.metrics(), dict) self.assertIsInstance(self.explainer.metrics_descriptions(), dict) def test_mean_abs_shap_df(self): self.assertIsInstance(self.explainer.get_mean_abs_shap_df(), pd.DataFrame) def test_contrib_df(self): self.assertIsInstance(self.explainer.get_contrib_df(0), pd.DataFrame) self.assertIsInstance( self.explainer.get_contrib_df(X_row=self.explainer.X.iloc[[0]]), pd.DataFrame) def test_shap_base_value(self): self.assertIsInstance(self.explainer.shap_base_value(), (np.floating, float)) def test_shap_values_shape(self): 
self.assertTrue(self.explainer.get_shap_values_df().shape == ( len(self.explainer), len(self.explainer.merged_cols))) def test_shap_values(self): self.assertIsInstance(self.explainer.get_shap_values_df(), pd.DataFrame) def test_pdp_df(self): self.assertIsInstance(self.explainer.pdp_df("age"), pd.DataFrame) self.assertIsInstance(self.explainer.pdp_df("sex"), pd.DataFrame) self.assertIsInstance(self.explainer.pdp_df("age", index=0), pd.DataFrame) self.assertIsInstance(self.explainer.pdp_df("sex", index=0), pd.DataFrame)
class ClassifierBunchTests(unittest.TestCase): def setUp(self): X_train, y_train, X_test, y_test = titanic_survive() train_names, test_names = titanic_names() model = RandomForestClassifier(n_estimators=5, max_depth=2) model.fit(X_train, y_train) self.explainer = ClassifierExplainer( model, X_test, y_test, cats=[{ 'Gender': ['Sex_female', 'Sex_male', 'Sex_nan'] }, 'Deck', 'Embarked'], idxs=test_names, labels=['Not survived', 'Survived']) def test_pos_label(self): self.explainer.pos_label = 1 self.explainer.pos_label = "Not survived" self.assertIsInstance(self.explainer.pos_label, int) self.assertIsInstance(self.explainer.pos_label_str, str) self.assertEquals(self.explainer.pos_label, 0) self.assertEquals(self.explainer.pos_label_str, "Not survived") def test_get_prop_for_label(self): self.explainer.pos_label = 1 tmp = self.explainer.pred_percentiles self.explainer.pos_label = 0 self.assertTrue( np.alltrue( self.explainer.get_prop_for_label("pred_percentiles", 1) == tmp)) def test_pred_probas(self): self.assertIsInstance(self.explainer.pred_probas, np.ndarray) def test_metrics(self): self.assertIsInstance(self.explainer.metrics(), dict) self.assertIsInstance(self.explainer.metrics(cutoff=0.9), dict) def test_precision_df(self): self.assertIsInstance(self.explainer.precision_df(), pd.DataFrame) self.assertIsInstance(self.explainer.precision_df(multiclass=True), pd.DataFrame) self.assertIsInstance(self.explainer.precision_df(quantiles=4), pd.DataFrame) def test_lift_curve_df(self): self.assertIsInstance(self.explainer.lift_curve_df(), pd.DataFrame) def test_prediction_result_markdown(self): self.assertIsInstance(self.explainer.prediction_result_markdown(0), str) def test_calculate_properties(self): self.explainer.calculate_properties() def test_plot_precision(self): fig = self.explainer.plot_precision() self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_precision(multiclass=True) self.assertIsInstance(fig, go.Figure) fig = 
self.explainer.plot_precision(quantiles=10, cutoff=0.5) self.assertIsInstance(fig, go.Figure) def test_plot_cumulutive_precision(self): fig = self.explainer.plot_cumulative_precision() self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_cumulative_precision(percentile=0.5) self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_cumulative_precision(percentile=0.1, pos_label=0) self.assertIsInstance(fig, go.Figure) def test_plot_confusion_matrix(self): fig = self.explainer.plot_confusion_matrix(normalized=False, binary=False) self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_confusion_matrix(normalized=False, binary=True) self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_confusion_matrix(normalized=True, binary=False) self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_confusion_matrix(normalized=True, binary=True) self.assertIsInstance(fig, go.Figure) def test_plot_lift_curve(self): fig = self.explainer.plot_lift_curve() self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_lift_curve(percentage=True) self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_lift_curve(cutoff=0.5) self.assertIsInstance(fig, go.Figure) def test_plot_lift_curve(self): fig = self.explainer.plot_lift_curve() self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_lift_curve(percentage=True) self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_lift_curve(cutoff=0.5) self.assertIsInstance(fig, go.Figure) def test_plot_classification(self): fig = self.explainer.plot_classification() self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_classification(percentage=True) self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_classification(cutoff=0) self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_classification(cutoff=1) self.assertIsInstance(fig, go.Figure) def test_plot_roc_auc(self): fig = self.explainer.plot_roc_auc(0.5) self.assertIsInstance(fig, go.Figure) 
fig = self.explainer.plot_roc_auc(0.0) self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_roc_auc(1.0) self.assertIsInstance(fig, go.Figure) def test_plot_pr_auc(self): fig = self.explainer.plot_pr_auc(0.5) self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_pr_auc(0.0) self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_pr_auc(1.0) self.assertIsInstance(fig, go.Figure)
class ClassifierBaseExplainerTests(unittest.TestCase): def setUp(self): X_train, y_train, X_test, y_test = titanic_survive() train_names, test_names = titanic_names() model = RandomForestClassifier(n_estimators=5, max_depth=2) model.fit(X_train, y_train) self.explainer = ClassifierExplainer( model, X_test, y_test, roc_auc_score, cats=[{ 'Gender': ['Sex_female', 'Sex_male', 'Sex_nan'] }, 'Deck', 'Embarked'], target='Survival', labels=['Not survived', 'Survived'], idxs=test_names) def test_explainer_len(self): self.assertEqual(len(self.explainer), len(titanic_survive()[2])) def test_int_idx(self): self.assertEqual(self.explainer.get_int_idx(titanic_names()[1][0]), 0) def test_random_index(self): self.assertIsInstance(self.explainer.random_index(), int) self.assertIsInstance(self.explainer.random_index(return_str=True), str) def test_preds(self): self.assertIsInstance(self.explainer.preds, np.ndarray) def test_pred_percentiles(self): self.assertIsInstance(self.explainer.pred_percentiles, np.ndarray) def test_columns_ranked_by_shap(self): self.assertIsInstance(self.explainer.columns_ranked_by_shap(), list) self.assertIsInstance(self.explainer.columns_ranked_by_shap(cats=True), list) def test_equivalent_col(self): self.assertEqual(self.explainer.equivalent_col("Sex_female"), "Gender") self.assertEqual(self.explainer.equivalent_col("Gender"), "Sex_female") self.assertEqual(self.explainer.equivalent_col("Deck"), "Deck_A") self.assertEqual(self.explainer.equivalent_col("Deck_A"), "Deck") self.assertIsNone(self.explainer.equivalent_col("random")) def test_get_col(self): self.assertIsInstance(self.explainer.get_col("Gender"), pd.Series) self.assertEqual(self.explainer.get_col("Gender").dtype, "object") self.assertIsInstance(self.explainer.get_col("Deck"), pd.Series) self.assertEqual(self.explainer.get_col("Deck").dtype, "object") self.assertIsInstance(self.explainer.get_col("Age"), pd.Series) self.assertEqual(self.explainer.get_col("Age").dtype, np.float) def 
test_permutation_importances(self): self.assertIsInstance(self.explainer.permutation_importances, pd.DataFrame) self.assertIsInstance(self.explainer.permutation_importances_cats, pd.DataFrame) def test_X_cats(self): self.assertIsInstance(self.explainer.X_cats, pd.DataFrame) def test_columns_cats(self): self.assertIsInstance(self.explainer.columns_cats, list) def test_metrics(self): self.assertIsInstance(self.explainer.metrics(), dict) self.assertIsInstance(self.explainer.metrics_markdown(), str) def test_mean_abs_shap_df(self): self.assertIsInstance(self.explainer.mean_abs_shap_df(), pd.DataFrame) def test_top_interactions(self): self.assertIsInstance(self.explainer.shap_top_interactions("Age"), list) self.assertIsInstance( self.explainer.shap_top_interactions("Age", topx=4), list) self.assertIsInstance( self.explainer.shap_top_interactions("Age", cats=True), list) self.assertIsInstance( self.explainer.shap_top_interactions("Gender", cats=True), list) def test_permutation_importances_df(self): self.assertIsInstance(self.explainer.permutation_importances_df(), pd.DataFrame) self.assertIsInstance( self.explainer.permutation_importances_df(topx=3), pd.DataFrame) self.assertIsInstance( self.explainer.permutation_importances_df(cats=True), pd.DataFrame) self.assertIsInstance( self.explainer.permutation_importances_df(cutoff=0.01), pd.DataFrame) def test_contrib_df(self): self.assertIsInstance(self.explainer.contrib_df(0), pd.DataFrame) self.assertIsInstance(self.explainer.contrib_df(0, cats=False), pd.DataFrame) self.assertIsInstance(self.explainer.contrib_df(0, topx=3), pd.DataFrame) self.assertIsInstance(self.explainer.contrib_df(0, sort='high-to-low'), pd.DataFrame) self.assertIsInstance(self.explainer.contrib_df(0, sort='low-to-high'), pd.DataFrame) self.assertIsInstance(self.explainer.contrib_df(0, sort='importance'), pd.DataFrame) self.assertIsInstance( self.explainer.contrib_df(X_row=self.explainer.X.iloc[[0]]), pd.DataFrame) self.assertIsInstance( 
self.explainer.contrib_df(X_row=self.explainer.X_cats.iloc[[0]]), pd.DataFrame) def test_contrib_summary_df(self): self.assertIsInstance(self.explainer.contrib_summary_df(0), pd.DataFrame) self.assertIsInstance(self.explainer.contrib_summary_df(0, cats=False), pd.DataFrame) self.assertIsInstance(self.explainer.contrib_summary_df(0, topx=3), pd.DataFrame) self.assertIsInstance(self.explainer.contrib_summary_df(0, round=3), pd.DataFrame) self.assertIsInstance( self.explainer.contrib_summary_df(0, sort='low-to-high'), pd.DataFrame) self.assertIsInstance( self.explainer.contrib_summary_df(0, sort='high-to-low'), pd.DataFrame) self.assertIsInstance( self.explainer.contrib_summary_df(0, sort='importance'), pd.DataFrame) self.assertIsInstance( self.explainer.contrib_summary_df( X_row=self.explainer.X.iloc[[0]]), pd.DataFrame) self.assertIsInstance( self.explainer.contrib_summary_df( X_row=self.explainer.X_cats.iloc[[0]]), pd.DataFrame) def test_shap_base_value(self): self.assertIsInstance(self.explainer.shap_base_value, (np.floating, float)) def test_shap_values_shape(self): self.assertTrue( self.explainer.shap_values.shape == (len(self.explainer), len(self.explainer.columns))) def test_shap_values(self): self.assertIsInstance(self.explainer.shap_values, np.ndarray) self.assertIsInstance(self.explainer.shap_values_cats, np.ndarray) def test_shap_interaction_values(self): self.assertIsInstance(self.explainer.shap_interaction_values, np.ndarray) self.assertIsInstance(self.explainer.shap_interaction_values_cats, np.ndarray) def test_mean_abs_shap(self): self.assertIsInstance(self.explainer.mean_abs_shap, pd.DataFrame) self.assertIsInstance(self.explainer.mean_abs_shap_cats, pd.DataFrame) def test_calculate_properties(self): self.explainer.calculate_properties() def test_shap_interaction_values_by_col(self): self.assertIsInstance( self.explainer.shap_interaction_values_by_col("Age"), np.ndarray) self.assertEquals( self.explainer.shap_interaction_values_by_col("Age").shape, 
self.explainer.shap_values.shape) self.assertEquals( self.explainer.shap_interaction_values_by_col("Age", cats=True).shape, self.explainer.shap_values_cats.shape) def test_pdp_result(self): self.assertIsInstance(self.explainer.get_pdp_result("Age"), pdpbox.pdp.PDPIsolate) self.assertIsInstance(self.explainer.get_pdp_result("Gender"), pdpbox.pdp.PDPIsolate) self.assertIsInstance(self.explainer.get_pdp_result("Age", index=0), pdpbox.pdp.PDPIsolate) self.assertIsInstance(self.explainer.get_pdp_result("Gender", index=0), pdpbox.pdp.PDPIsolate) self.assertIsInstance( self.explainer.get_pdp_result("Age", X_row=self.explainer.X.iloc[[0]]), pdpbox.pdp.PDPIsolate) self.assertIsInstance( self.explainer.get_pdp_result( "Age", X_row=self.explainer.X_cats.iloc[[0]]), pdpbox.pdp.PDPIsolate) def test_get_dfs(self): cols_df, shap_df, contribs_df = self.explainer.get_dfs() self.assertIsInstance(cols_df, pd.DataFrame) self.assertIsInstance(shap_df, pd.DataFrame) self.assertIsInstance(contribs_df, pd.DataFrame) def test_plot_importances(self): fig = self.explainer.plot_importances() self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_importances(kind='permutation') self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_importances(topx=3) self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_importances(cats=True) self.assertIsInstance(fig, go.Figure) def test_plot_interactions(self): fig = self.explainer.plot_interactions("Age") self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_interactions("Sex_female") self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_interactions("Age") self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_interactions("Gender") self.assertIsInstance(fig, go.Figure) def test_plot_shap_contributions(self): fig = self.explainer.plot_shap_contributions(0) self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_shap_contributions(0, cats=False) self.assertIsInstance(fig, go.Figure) fig = 
self.explainer.plot_shap_contributions(0, topx=3) self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_shap_contributions(0, cutoff=0.05) self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_shap_contributions(0, sort='high-to-low') self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_shap_contributions(0, sort='low-to-high') self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_shap_contributions(0, sort='importance') self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_shap_contributions( X_row=self.explainer.X.iloc[[0]], sort='importance') self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_shap_contributions( X_row=self.explainer.X_cats.iloc[[0]], sort='importance') self.assertIsInstance(fig, go.Figure) def test_plot_shap_summary(self): fig = self.explainer.plot_shap_summary() self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_shap_summary(topx=3) self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_shap_summary(cats=True) self.assertIsInstance(fig, go.Figure) def test_plot_shap_interaction_summary(self): fig = self.explainer.plot_shap_interaction_summary("Age") self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_shap_interaction_summary("Age", topx=3) self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_shap_interaction_summary("Age") self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_shap_interaction_summary("Sex_female", topx=3) self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_shap_interaction_summary("Gender") self.assertIsInstance(fig, go.Figure) def test_plot_shap_dependence(self): fig = self.explainer.plot_shap_dependence("Age") self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_shap_dependence("Sex_female") self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_shap_dependence("Age", "Gender") self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_shap_dependence("Sex_female", "Age") 
self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_shap_dependence("Age", highlight_index=0) self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_shap_dependence("Gender", highlight_index=0) self.assertIsInstance(fig, go.Figure) def test_plot_shap_interaction(self): fig = self.explainer.plot_shap_dependence("Age", "Sex_female") self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_shap_dependence("Sex_female", "Age") self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_shap_dependence("Gender", "Age") self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_shap_dependence("Age", "Gender") self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_shap_dependence("Age", "Sex_female", highlight_index=0) self.assertIsInstance(fig, go.Figure) def test_plot_pdp(self): fig = self.explainer.plot_pdp("Age") self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_pdp("Gender") self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_pdp("Gender", index=0) self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_pdp("Age", index=0) self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_pdp("Age", X_row=self.explainer.X.iloc[[0]]) self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_pdp("Age", X_row=self.explainer.X_cats.iloc[[0]]) self.assertIsInstance(fig, go.Figure) def test_yaml(self): yaml = self.explainer.to_yaml() self.assertIsInstance(yaml, str)
class MultiClassClassifierBunchTests(unittest.TestCase): def setUp(self): X_train, y_train, X_test, y_test = titanic_embarked() train_names, test_names = titanic_names() model = RandomForestClassifier(n_estimators=5, max_depth=2) model.fit(X_train, y_train) self.explainer = ClassifierExplainer( model, X_test, y_test, cats=[{ 'Gender': ['Sex_female', 'Sex_male', 'Sex_nan'] }, 'Deck'], idxs=test_names, labels=['Queenstown', 'Southampton', 'Cherbourg']) def test_preds(self): self.assertIsInstance(self.explainer.preds, np.ndarray) def test_pred_percentiles(self): self.assertIsInstance(self.explainer.pred_percentiles, np.ndarray) def test_columns_ranked_by_shap(self): self.assertIsInstance(self.explainer.columns_ranked_by_shap(), list) self.assertIsInstance(self.explainer.columns_ranked_by_shap(cats=True), list) def test_equivalent_col(self): self.assertEqual(self.explainer.equivalent_col("Sex_female"), "Gender") self.assertEqual(self.explainer.equivalent_col("Gender"), "Sex_female") self.assertIsNone(self.explainer.equivalent_col("random")) def test_get_col(self): self.assertIsInstance(self.explainer.get_col("Gender"), pd.Series) self.assertEqual(self.explainer.get_col("Gender").dtype, "object") self.assertIsInstance(self.explainer.get_col("Age"), pd.Series) self.assertEqual(self.explainer.get_col("Age").dtype, np.float) def test_permutation_importances(self): self.assertIsInstance(self.explainer.permutation_importances, pd.DataFrame) self.assertIsInstance(self.explainer.permutation_importances_cats, pd.DataFrame) def test_X_cats(self): self.assertIsInstance(self.explainer.X_cats, pd.DataFrame) def test_columns_cats(self): self.assertIsInstance(self.explainer.columns_cats, list) def test_metrics(self): self.assertIsInstance(self.explainer.metrics(), dict) self.assertIsInstance(self.explainer.metrics_descriptions(), dict) def test_mean_abs_shap_df(self): self.assertIsInstance(self.explainer.mean_abs_shap_df(), pd.DataFrame) def test_top_interactions(self): 
self.assertIsInstance(self.explainer.shap_top_interactions("Age"), list) self.assertIsInstance( self.explainer.shap_top_interactions("Age", topx=4), list) self.assertIsInstance( self.explainer.shap_top_interactions("Age", cats=True), list) self.assertIsInstance( self.explainer.shap_top_interactions("Gender", cats=True), list) def test_permutation_importances_df(self): self.assertIsInstance(self.explainer.permutation_importances_df(), pd.DataFrame) self.assertIsInstance( self.explainer.permutation_importances_df(topx=3), pd.DataFrame) self.assertIsInstance( self.explainer.permutation_importances_df(cats=True), pd.DataFrame) self.assertIsInstance( self.explainer.permutation_importances_df(cutoff=0.01), pd.DataFrame) def test_contrib_df(self): self.assertIsInstance(self.explainer.contrib_df(0), pd.DataFrame) self.assertIsInstance(self.explainer.contrib_df(0, cats=False), pd.DataFrame) self.assertIsInstance(self.explainer.contrib_df(0, topx=3), pd.DataFrame) def test_contrib_summary_df(self): self.assertIsInstance(self.explainer.contrib_summary_df(0), pd.DataFrame) self.assertIsInstance(self.explainer.contrib_summary_df(0, cats=False), pd.DataFrame) self.assertIsInstance(self.explainer.contrib_summary_df(0, topx=3), pd.DataFrame) self.assertIsInstance(self.explainer.contrib_summary_df(0, round=3), pd.DataFrame) def test_shap_base_value(self): self.assertIsInstance(self.explainer.shap_base_value, (np.floating, float)) def test_shap_values_shape(self): self.assertTrue( self.explainer.shap_values.shape == (len(self.explainer), len(self.explainer.columns))) def test_shap_values(self): self.assertIsInstance(self.explainer.shap_values, np.ndarray) self.assertIsInstance(self.explainer.shap_values_cats, np.ndarray) def test_shap_interaction_values(self): self.assertIsInstance(self.explainer.shap_interaction_values, np.ndarray) self.assertIsInstance(self.explainer.shap_interaction_values_cats, np.ndarray) def test_mean_abs_shap(self): 
self.assertIsInstance(self.explainer.mean_abs_shap, pd.DataFrame) self.assertIsInstance(self.explainer.mean_abs_shap_cats, pd.DataFrame) def test_calculate_properties(self): self.explainer.calculate_properties() def test_shap_interaction_values_by_col(self): self.assertIsInstance( self.explainer.shap_interaction_values_by_col("Age"), np.ndarray) self.assertEqual( self.explainer.shap_interaction_values_by_col("Age").shape, self.explainer.shap_values.shape) self.assertEqual( self.explainer.shap_interaction_values_by_col("Age", cats=True).shape, self.explainer.shap_values_cats.shape) def test_pdp_df(self): self.assertIsInstance(self.explainer.pdp_df("Age"), pd.DataFrame) self.assertIsInstance(self.explainer.pdp_df("Gender"), pd.DataFrame) self.assertIsInstance(self.explainer.pdp_df("Deck"), pd.DataFrame) self.assertIsInstance(self.explainer.pdp_df("Age", index=0), pd.DataFrame) self.assertIsInstance(self.explainer.pdp_df("Gender", index=0), pd.DataFrame) def test_get_dfs(self): cols_df, shap_df, contribs_df = self.explainer.get_dfs() self.assertIsInstance(cols_df, pd.DataFrame) self.assertIsInstance(shap_df, pd.DataFrame) self.assertIsInstance(contribs_df, pd.DataFrame) def test_plot_importances(self): fig = self.explainer.plot_importances() self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_importances(kind='permutation') self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_importances(topx=3) self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_importances(cats=True) self.assertIsInstance(fig, go.Figure) def test_plot_interactions(self): fig = self.explainer.plot_interactions("Age") self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_interactions("Sex_female") self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_interactions("Age") self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_interactions("Gender") self.assertIsInstance(fig, go.Figure) def test_plot_shap_contributions(self): fig = 
self.explainer.plot_shap_contributions(0) self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_shap_contributions(0, cats=False) self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_shap_contributions(0, topx=3) self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_shap_contributions(0, cutoff=0.05) self.assertIsInstance(fig, go.Figure) def test_plot_shap_summary(self): fig = self.explainer.plot_shap_summary() self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_shap_summary(topx=3) self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_shap_summary(cats=True) self.assertIsInstance(fig, go.Figure) def test_plot_shap_interaction_summary(self): fig = self.explainer.plot_shap_interaction_summary("Age") self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_shap_interaction_summary("Age", topx=3) self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_shap_interaction_summary("Age") self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_shap_interaction_summary("Sex_female", topx=3) self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_shap_interaction_summary("Gender") self.assertIsInstance(fig, go.Figure) def test_plot_shap_dependence(self): fig = self.explainer.plot_shap_dependence("Age") self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_shap_dependence("Sex_female") self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_shap_dependence("Age", "Gender") self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_shap_dependence("Sex_female", "Age") self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_shap_dependence("Age", highlight_index=0) self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_shap_dependence("Gender", highlight_index=0) self.assertIsInstance(fig, go.Figure) def test_plot_shap_interaction(self): fig = self.explainer.plot_shap_dependence("Age", "Sex_female") self.assertIsInstance(fig, go.Figure) fig = 
self.explainer.plot_shap_dependence("Sex_female", "Age") self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_shap_dependence("Gender", "Age") self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_shap_dependence("Age", "Gender") self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_shap_dependence("Age", "Sex_female", highlight_index=0) self.assertIsInstance(fig, go.Figure) def test_plot_pdp(self): fig = self.explainer.plot_pdp("Age") self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_pdp("Gender") self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_pdp("Gender", index=0) self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_pdp("Age", index=0) self.assertIsInstance(fig, go.Figure) def test_pos_label(self): self.explainer.pos_label = 1 self.explainer.pos_label = "Southampton" self.assertIsInstance(self.explainer.pos_label, int) self.assertIsInstance(self.explainer.pos_label_str, str) self.assertEqual(self.explainer.pos_label, 1) self.assertEqual(self.explainer.pos_label_str, "Southampton") def test_get_prop_for_label(self): self.explainer.pos_label = 1 tmp = self.explainer.pred_percentiles self.explainer.pos_label = 0 self.assertTrue( np.alltrue( self.explainer.get_prop_for_label("pred_percentiles", 1) == tmp)) def test_pred_probas(self): self.assertIsInstance(self.explainer.pred_probas, np.ndarray) def test_metrics(self): self.assertIsInstance(self.explainer.metrics(), dict) self.assertIsInstance(self.explainer.metrics(cutoff=0.9), dict) def test_precision_df(self): self.assertIsInstance(self.explainer.precision_df(), pd.DataFrame) self.assertIsInstance(self.explainer.precision_df(multiclass=True), pd.DataFrame) self.assertIsInstance(self.explainer.precision_df(quantiles=4), pd.DataFrame) def test_lift_curve_df(self): self.assertIsInstance(self.explainer.lift_curve_df(), pd.DataFrame) def test_prediction_result_markdown(self): self.assertIsInstance(self.explainer.prediction_result_markdown(0), str) 
def test_calculate_properties(self): self.explainer.calculate_properties() def test_plot_precision(self): fig = self.explainer.plot_precision() self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_precision(multiclass=True) self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_precision(quantiles=10, cutoff=0.5) self.assertIsInstance(fig, go.Figure) def test_plot_cumulative_precision(self): fig = self.explainer.plot_cumulative_precision() self.assertIsInstance(fig, go.Figure) def test_plot_confusion_matrix(self): fig = self.explainer.plot_confusion_matrix(normalized=False, binary=False) self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_confusion_matrix(normalized=False, binary=True) self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_confusion_matrix(normalized=True, binary=False) self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_confusion_matrix(normalized=True, binary=True) self.assertIsInstance(fig, go.Figure) def test_plot_lift_curve(self): fig = self.explainer.plot_lift_curve() self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_lift_curve(percentage=True) self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_lift_curve(cutoff=0.5) self.assertIsInstance(fig, go.Figure) def test_plot_lift_curve(self): fig = self.explainer.plot_lift_curve() self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_lift_curve(percentage=True) self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_lift_curve(cutoff=0.5) self.assertIsInstance(fig, go.Figure) def test_plot_classification(self): fig = self.explainer.plot_classification() self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_classification(percentage=True) self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_classification(cutoff=0) self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_classification(cutoff=1) self.assertIsInstance(fig, go.Figure) def test_plot_roc_auc(self): fig = 
self.explainer.plot_roc_auc(0.5) self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_roc_auc(0.0) self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_roc_auc(1.0) self.assertIsInstance(fig, go.Figure) def test_plot_pr_auc(self): fig = self.explainer.plot_pr_auc(0.5) self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_pr_auc(0.0) self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_pr_auc(1.0) self.assertIsInstance(fig, go.Figure)
class LogisticRegressionTests(unittest.TestCase): def setUp(self): X_train, y_train, X_test, y_test = titanic_survive() train_names, test_names = titanic_names() model = LogisticRegression() model.fit(X_train, y_train) self.explainer = ClassifierExplainer( model, X_test, y_test, shap='linear', cats=['Sex', 'Deck', 'Embarked'], labels=['Not survived', 'Survived'], idxs=test_names) def test_preds(self): self.assertIsInstance(self.explainer.preds, np.ndarray) def test_pred_percentiles(self): self.assertIsInstance(self.explainer.pred_percentiles(), np.ndarray) def test_columns_ranked_by_shap(self): self.assertIsInstance(self.explainer.columns_ranked_by_shap(), list) def test_permutation_importances(self): self.assertIsInstance(self.explainer.get_permutation_importances_df(), pd.DataFrame) def test_metrics(self): self.assertIsInstance(self.explainer.metrics(), dict) self.assertIsInstance(self.explainer.metrics_descriptions(), dict) def test_mean_abs_shap_df(self): self.assertIsInstance(self.explainer.get_mean_abs_shap_df(), pd.DataFrame) def test_contrib_df(self): self.assertIsInstance(self.explainer.get_contrib_df(0), pd.DataFrame) self.assertIsInstance(self.explainer.get_contrib_df(0, topx=3), pd.DataFrame) def test_shap_base_value(self): self.assertIsInstance(self.explainer.shap_base_value(), (np.floating, float)) def test_shap_values_shape(self): self.assertTrue(self.explainer.get_shap_values_df().shape == ( len(self.explainer), len(self.explainer.merged_cols))) def test_shap_values(self): self.assertIsInstance(self.explainer.get_shap_values_df(), pd.DataFrame) def test_mean_abs_shap(self): self.assertIsInstance(self.explainer.get_mean_abs_shap_df(), pd.DataFrame) def test_calculate_properties(self): self.explainer.calculate_properties(include_interactions=False) def test_pdp_df(self): self.assertIsInstance(self.explainer.pdp_df("Age"), pd.DataFrame) self.assertIsInstance(self.explainer.pdp_df("Sex"), pd.DataFrame) 
self.assertIsInstance(self.explainer.pdp_df("Deck"), pd.DataFrame) self.assertIsInstance(self.explainer.pdp_df("Age", index=0), pd.DataFrame) self.assertIsInstance(self.explainer.pdp_df("Sex", index=0), pd.DataFrame) def test_pos_label(self): self.explainer.pos_label = 1 self.explainer.pos_label = "Not survived" self.assertIsInstance(self.explainer.pos_label, int) self.assertIsInstance(self.explainer.pos_label_str, str) self.assertEqual(self.explainer.pos_label, 0) self.assertEqual(self.explainer.pos_label_str, "Not survived") def test_pred_probas(self): self.assertIsInstance(self.explainer.pred_probas(), np.ndarray) def test_metrics(self): self.assertIsInstance(self.explainer.metrics(), dict) self.assertIsInstance(self.explainer.metrics(cutoff=0.9), dict) def test_precision_df(self): self.assertIsInstance(self.explainer.get_precision_df(), pd.DataFrame) self.assertIsInstance(self.explainer.get_precision_df(multiclass=True), pd.DataFrame) self.assertIsInstance(self.explainer.get_precision_df(quantiles=4), pd.DataFrame) def test_lift_curve_df(self): self.assertIsInstance(self.explainer.get_liftcurve_df(), pd.DataFrame)
class ClassifierBaseExplainerTests(unittest.TestCase): def setUp(self): X_train, y_train, X_test, y_test = titanic_survive() train_names, test_names = titanic_names() model = RandomForestClassifier(n_estimators=5, max_depth=2) model.fit(X_train, y_train) self.explainer = ClassifierExplainer( model, X_test, y_test, cats=[{ 'Gender': ['Sex_female', 'Sex_male', 'Sex_nan'] }, 'Deck', 'Embarked'], target='Survival', labels=['Not survived', 'Survived'], idxs=test_names) def test_explainer_len(self): self.assertEqual(len(self.explainer), len(titanic_survive()[2])) def test_int_idx(self): self.assertEqual(self.explainer.get_idx(titanic_names()[1][0]), 0) def test_random_index(self): self.assertIsInstance(self.explainer.random_index(), int) self.assertIsInstance(self.explainer.random_index(return_str=True), str) def test_preds(self): self.assertIsInstance(self.explainer.preds, np.ndarray) def test_row_from_input(self): input_row = self.explainer.get_row_from_input( self.explainer.X.iloc[[0]].values.tolist()) self.assertIsInstance(input_row, pd.DataFrame) input_row = self.explainer.get_row_from_input( self.explainer.X_merged.iloc[[0]].values.tolist()) self.assertIsInstance(input_row, pd.DataFrame) input_row = self.explainer.get_row_from_input(self.explainer.X_merged[ self.explainer.columns_ranked_by_shap()].iloc[[0]].values.tolist(), ranked_by_shap=True) self.assertIsInstance(input_row, pd.DataFrame) def test_pred_percentiles(self): self.assertIsInstance(self.explainer.pred_percentiles(), np.ndarray) def test_columns_ranked_by_shap(self): self.assertIsInstance(self.explainer.columns_ranked_by_shap(), list) def test_get_col(self): self.assertIsInstance(self.explainer.get_col("Gender"), pd.Series) self.assertTrue(is_categorical_dtype(self.explainer.get_col("Gender"))) self.assertIsInstance(self.explainer.get_col("Deck"), pd.Series) self.assertTrue(is_categorical_dtype(self.explainer.get_col("Deck"))) self.assertIsInstance(self.explainer.get_col("Age"), pd.Series) 
self.assertTrue(is_numeric_dtype(self.explainer.get_col("Age"))) def test_permutation_importances(self): self.assertIsInstance(self.explainer.permutation_importances(), pd.DataFrame) def test_X_cats(self): self.assertIsInstance(self.explainer.X_cats, pd.DataFrame) def test_metrics(self): self.assertIsInstance(self.explainer.metrics(), dict) def test_mean_abs_shap_df(self): self.assertIsInstance(self.explainer.mean_abs_shap_df(), pd.DataFrame) def test_top_interactions(self): self.assertIsInstance(self.explainer.top_shap_interactions("Age"), list) self.assertIsInstance( self.explainer.top_shap_interactions("Age", topx=4), list) def test_permutation_importances_df(self): self.assertIsInstance(self.explainer.get_permutation_importances_df(), pd.DataFrame) self.assertIsInstance( self.explainer.get_permutation_importances_df(topx=3), pd.DataFrame) self.assertIsInstance( self.explainer.get_permutation_importances_df(cutoff=0.01), pd.DataFrame) def test_contrib_df(self): self.assertIsInstance(self.explainer.get_contrib_df(0), pd.DataFrame) self.assertIsInstance(self.explainer.get_contrib_df(0, topx=3), pd.DataFrame) self.assertIsInstance( self.explainer.get_contrib_df(0, sort='high-to-low'), pd.DataFrame) self.assertIsInstance( self.explainer.get_contrib_df(0, sort='low-to-high'), pd.DataFrame) self.assertIsInstance( self.explainer.get_contrib_df(0, sort='importance'), pd.DataFrame) self.assertIsInstance( self.explainer.get_contrib_df(X_row=self.explainer.X.iloc[[0]]), pd.DataFrame) def test_contrib_summary_df(self): self.assertIsInstance(self.explainer.get_contrib_summary_df(0), pd.DataFrame) self.assertIsInstance(self.explainer.get_contrib_summary_df(0, topx=3), pd.DataFrame) self.assertIsInstance( self.explainer.get_contrib_summary_df(0, round=3), pd.DataFrame) self.assertIsInstance( self.explainer.get_contrib_summary_df(0, sort='low-to-high'), pd.DataFrame) self.assertIsInstance( self.explainer.get_contrib_summary_df(0, sort='high-to-low'), pd.DataFrame) 
self.assertIsInstance( self.explainer.get_contrib_summary_df(0, sort='importance'), pd.DataFrame) self.assertIsInstance( self.explainer.get_contrib_summary_df( X_row=self.explainer.X.iloc[[0]]), pd.DataFrame) def test_shap_base_value(self): self.assertIsInstance(self.explainer.shap_base_value(), (np.floating, float)) def test_shap_values_shape(self): self.assertTrue(self.explainer.get_shap_values_df().shape == ( len(self.explainer), len(self.explainer.merged_cols))) def test_shap_values(self): self.assertIsInstance(self.explainer.get_shap_values_df(), pd.DataFrame) def test_shap_interaction_values(self): self.assertIsInstance(self.explainer.shap_interaction_values(), np.ndarray) def test_mean_abs_shap_df(self): self.assertIsInstance(self.explainer.mean_abs_shap_df(), pd.DataFrame) def test_calculate_properties(self): self.explainer.calculate_properties() def test_shap_interaction_values_by_col(self): self.assertIsInstance( self.explainer.shap_interaction_values_for_col("Age"), np.ndarray) self.assertEqual( self.explainer.shap_interaction_values_for_col("Age").shape, self.explainer.get_shap_values_df().shape) def test_prediction_result_df(self): df = self.explainer.prediction_result_df(0) self.assertIsInstance(df, pd.DataFrame) def test_pdp_df(self): self.assertIsInstance(self.explainer.pdp_df("Age"), pd.DataFrame) self.assertIsInstance(self.explainer.pdp_df("Gender"), pd.DataFrame) self.assertIsInstance(self.explainer.pdp_df("Deck"), pd.DataFrame) self.assertIsInstance(self.explainer.pdp_df("Age", index=0), pd.DataFrame) self.assertIsInstance(self.explainer.pdp_df("Gender", index=0), pd.DataFrame) self.assertIsInstance( self.explainer.pdp_df("Age", X_row=self.explainer.X.iloc[[0]]), pd.DataFrame) def test_memory_usage(self): self.assertIsInstance(self.explainer.memory_usage(), pd.DataFrame) self.assertIsInstance(self.explainer.memory_usage(cutoff=1000), pd.DataFrame) def test_plot_importances(self): fig = self.explainer.plot_importances() self.assertIsInstance(fig, 
go.Figure) fig = self.explainer.plot_importances(kind='permutation') self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_importances(topx=3) self.assertIsInstance(fig, go.Figure) def test_plot_interactions(self): fig = self.explainer.plot_interactions_importance("Age") self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_interactions_importance("Gender") self.assertIsInstance(fig, go.Figure) def test_plot_contributions(self): fig = self.explainer.plot_contributions(0) self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_contributions(0, topx=3) self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_contributions(0, cutoff=0.05) self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_contributions(0, sort='high-to-low') self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_contributions(0, sort='low-to-high') self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_contributions(0, sort='importance') self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_contributions( X_row=self.explainer.X.iloc[[0]], sort='importance') self.assertIsInstance(fig, go.Figure) def test_plot_shap_detailed(self): fig = self.explainer.plot_importances_detailed() self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_importances_detailed(topx=3) self.assertIsInstance(fig, go.Figure) def test_plot_interactions_detailed(self): fig = self.explainer.plot_interactions_detailed("Age") self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_interactions_detailed("Age", topx=3) self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_interactions_detailed("Age") self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_interactions_detailed("Gender") self.assertIsInstance(fig, go.Figure) def test_plot_dependence(self): fig = self.explainer.plot_dependence("Age") self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_dependence("Age", "Gender") self.assertIsInstance(fig, go.Figure) 
fig = self.explainer.plot_dependence("Age", highlight_index=0) self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_dependence("Gender", highlight_index=0) self.assertIsInstance(fig, go.Figure) def test_plot_interaction(self): fig = self.explainer.plot_interaction("Gender", "Age") self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_interaction("Age", "Gender", highlight_index=0) self.assertIsInstance(fig, go.Figure) def test_plot_pdp(self): fig = self.explainer.plot_pdp("Age") self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_pdp("Gender") self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_pdp("Gender", index=0) self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_pdp("Age", index=0) self.assertIsInstance(fig, go.Figure) fig = self.explainer.plot_pdp("Age", X_row=self.explainer.X.iloc[[0]]) self.assertIsInstance(fig, go.Figure) def test_yaml(self): yaml = self.explainer.to_yaml() self.assertIsInstance(yaml, str)