def get_multiclass_explainer(xgboost=False, include_y=True):
    """Build a ClassifierExplainer for the three-label embarked dataset.

    Args:
        xgboost: when True, fit an ``XGBClassifier`` and pass
            ``model_output='logodds'`` to the explainer; otherwise fit a
            ``RandomForestClassifier(n_estimators=50, max_depth=10)`` and
            leave ``model_output`` at the explainer's default.
        include_y: when True, ``y_test`` is passed to the explainer as the
            third positional argument; when False it is omitted.

    Returns:
        The explainer, after ``calculate_properties()`` has been called on it.
    """
    X_train, y_train, X_test, y_test = titanic_embarked()

    if xgboost:
        model = XGBClassifier().fit(X_train, y_train)
    else:
        model = RandomForestClassifier(
            n_estimators=50, max_depth=10).fit(X_train, y_train)

    # The four original call sites differed only in whether y_test was passed
    # and whether model_output was set, so build those two pieces separately.
    data_args = (X_test, y_test) if include_y else (X_test,)
    explainer_kwargs = {
        'cats': ['Sex', 'Deck'],
        'labels': ['Queenstown', 'Southampton', 'Cherbourg'],
    }
    if xgboost:
        explainer_kwargs['model_output'] = 'logodds'

    multi_explainer = ClassifierExplainer(
        model, *data_args, **explainer_kwargs)
    multi_explainer.calculate_properties()
    return multi_explainer
class XGBMultiClassifierExplainerTests(unittest.TestCase):
    """Type-level checks of the tree views of a ClassifierExplainer.

    The explainer wraps an XGBClassifier fitted on the titanic_embarked()
    split and is rebuilt before every test.
    """

    def setUp(self):
        X_train, y_train, X_test, y_test = titanic_embarked()
        train_names, test_names = titanic_names()
        _, self.names = titanic_names()

        booster = XGBClassifier(n_estimators=5)
        booster.fit(X_train, y_train)

        # 'Gender' is passed in cats as a named group of the Sex_* columns.
        gender_group = {'Gender': ['Sex_female', 'Sex_male', 'Sex_nan']}
        self.explainer = ClassifierExplainer(
            booster, X_test, y_test,
            model_output='raw',
            cats=[gender_group, 'Deck'],
            idxs=test_names,
            labels=['Queenstown', 'Southampton', 'Cherbourg'])

    def test_graphviz_available(self):
        flag = self.explainer.graphviz_available
        self.assertIsInstance(flag, bool)

    def test_shadow_trees(self):
        trees = self.explainer.shadow_trees
        self.assertIsInstance(trees, list)
        first_tree = trees[0]
        self.assertIsInstance(
            first_tree, dtreeviz.models.shadow_decision_tree.ShadowDecTree)

    def test_decisionpath_df(self):
        # Same tree, looked up by row position, by name, and by name with
        # an explicit pos_label.
        lookups = (
            lambda: {'index': 0},
            lambda: {'index': self.names[0]},
            lambda: {'index': self.names[0], 'pos_label': 0},
        )
        for make_kwargs in lookups:
            path_df = self.explainer.get_decisionpath_df(
                tree_idx=0, **make_kwargs())
            self.assertIsInstance(path_df, pd.DataFrame)

    def test_plot_trees(self):
        variants = (
            lambda: {'index': 0},
            lambda: {'index': self.names[0]},
            lambda: {'index': self.names[0], 'highlight_tree': 0},
            lambda: {'index': self.names[0], 'pos_label': 0},
        )
        for make_kwargs in variants:
            figure = self.explainer.plot_trees(**make_kwargs())
            self.assertIsInstance(figure, go.Figure)

    def test_calculate_properties(self):
        # Passes as long as the call does not raise.
        self.explainer.calculate_properties()
Esempio n. 3
0
class SkorchClassifierTests(unittest.TestCase):
    """Return-type checks for a ClassifierExplainer built from
    get_skorch_classifier()."""

    def setUp(self):
        skorch_model, features, target = get_skorch_classifier()
        self.explainer = ClassifierExplainer(skorch_model, features, target)

    def test_preds(self):
        predictions = self.explainer.preds
        self.assertIsInstance(predictions, np.ndarray)

    def test_pred_probas(self):
        probabilities = self.explainer.pred_probas()
        self.assertIsInstance(probabilities, np.ndarray)

    def test_permutation_importances(self):
        importances = self.explainer.get_permutation_importances_df()
        self.assertIsInstance(importances, pd.DataFrame)

    def test_shap_base_value(self):
        base_value = self.explainer.shap_base_value()
        self.assertIsInstance(base_value, (np.floating, float))

    def test_shap_values_shape(self):
        # One row per explainer row, one column per merged column.
        actual_shape = self.explainer.get_shap_values_df().shape
        expected_shape = (len(self.explainer),
                          len(self.explainer.merged_cols))
        self.assertTrue(actual_shape == expected_shape)

    def test_shap_values(self):
        shap_df = self.explainer.get_shap_values_df()
        self.assertIsInstance(shap_df, pd.DataFrame)

    def test_mean_abs_shap(self):
        mean_abs_df = self.explainer.get_mean_abs_shap_df()
        self.assertIsInstance(mean_abs_df, pd.DataFrame)

    def test_calculate_properties(self):
        # Passes as long as the call does not raise.
        self.explainer.calculate_properties(include_interactions=False)

    def test_pdp_df(self):
        pdp = self.explainer.pdp_df("col1")
        self.assertIsInstance(pdp, pd.DataFrame)

    def test_metrics(self):
        scores = self.explainer.metrics()
        self.assertIsInstance(scores, dict)

    def test_precision_df(self):
        precision = self.explainer.get_precision_df()
        self.assertIsInstance(precision, pd.DataFrame)

    def test_lift_curve_df(self):
        lift = self.explainer.get_liftcurve_df()
        self.assertIsInstance(lift, pd.DataFrame)

    def test_prediction_result_df(self):
        result = self.explainer.prediction_result_df(0)
        self.assertIsInstance(result, pd.DataFrame)
Esempio n. 4
0
class ClassifierBunchTests(unittest.TestCase):
    """Type-level checks of the decision-tree helpers of a ClassifierExplainer
    wrapping a small RandomForestClassifier."""

    def setUp(self):
        X_train, y_train, X_test, y_test = titanic_survive()
        train_names, test_names = titanic_names()
        _, self.names = titanic_names()

        forest = RandomForestClassifier(n_estimators=5, max_depth=2)
        forest.fit(X_train, y_train)

        self.explainer = ClassifierExplainer(
            forest, X_test, y_test, roc_auc_score,
            shap='tree',
            cats=['Sex', 'Cabin', 'Embarked'],
            idxs=test_names,
            labels=['Not survived', 'Survived'])

    def test_graphviz_available(self):
        flag = self.explainer.graphviz_available
        self.assertIsInstance(flag, bool)

    def test_decision_trees(self):
        trees = self.explainer.decision_trees
        self.assertIsInstance(trees, list)
        first_tree = trees[0]
        self.assertIsInstance(
            first_tree, dtreeviz.models.shadow_decision_tree.ShadowDecTree)

    def test_decisiontree_df(self):
        # The same tree is requested by row position and by row name.
        lookups = (lambda: 0, lambda: self.names[0])
        for get_index in lookups:
            tree_df = self.explainer.decisiontree_df(
                tree_idx=0, index=get_index())
            self.assertIsInstance(tree_df, pd.DataFrame)

    def test_plot_trees(self):
        variants = (
            lambda: {'index': 0},
            lambda: {'index': self.names[0]},
            lambda: {'index': self.names[0], 'highlight_tree': 0},
        )
        for make_kwargs in variants:
            figure = self.explainer.plot_trees(**make_kwargs())
            self.assertIsInstance(figure, go.Figure)

    def test_calculate_properties(self):
        # Passes as long as the call does not raise.
        self.explainer.calculate_properties()
def get_catboost_classifier():
    """Return a ClassifierExplainer around a CatBoostClassifier fitted on
    the merged-category frame.

    A first model and explainer are built on the titanic_survive() split only
    to obtain ``X_merged`` and ``y``. A second CatBoostClassifier is then
    fitted on that frame and wrapped in the explainer that is returned, after
    ``calculate_properties(include_interactions=False)`` has run.
    """
    X_train, y_train, X_test, y_test = titanic_survive()
    _, test_names = titanic_names()

    first_model = CatBoostClassifier(iterations=100, verbose=0).fit(
        X_train, y_train)
    grouped_cats = [
        {'Gender': ['Sex_female', 'Sex_male', 'Sex_nan']},
        'Deck',
        'Embarked',
    ]
    first_explainer = ClassifierExplainer(
        first_model, X_test, y_test,
        cats=grouped_cats,
        labels=['Not survived', 'Survived'],
        idxs=test_names)

    X_cats = first_explainer.X_merged
    y_cats = first_explainer.y.astype("int")

    # NOTE(review): columns 5, 6 and 7 are presumably the merged categorical
    # columns of X_merged -- confirm against the dataset layout.
    cat_model = CatBoostClassifier(iterations=5, verbose=0).fit(
        X_cats, y_cats, cat_features=[5, 6, 7])

    explainer = ClassifierExplainer(
        cat_model, X_cats, y_cats, idxs=X_test.index)
    explainer.calculate_properties(include_interactions=False)
    return explainer
def get_classification_explainer(include_y=True):
    """Build a ClassifierExplainer around an XGBClassifier fitted on the
    titanic_survive() split.

    Args:
        include_y: when True, ``y_test`` is passed to the explainer as the
            third positional argument; when False it is omitted.

    Returns:
        The explainer, after ``calculate_properties()`` has been called on it.
    """
    X_train, y_train, X_test, y_test = titanic_survive()
    _, test_names = titanic_names()
    fitted = XGBClassifier().fit(X_train, y_train)

    # Only the presence of y_test differs between the two cases.
    data_args = (X_test, y_test) if include_y else (X_test,)
    explainer = ClassifierExplainer(
        fitted, *data_args,
        cats=['Sex', 'Cabin', 'Embarked'],
        labels=['Not survived', 'Survived'],
        idxs=test_names)

    explainer.calculate_properties()
    return explainer
def get_classification_explainer(xgboost=False, include_y=True):
    """Build a ClassifierExplainer on the titanic_survive() split.

    NOTE(review): a function with this same name is defined earlier in this
    file; if both live in one module this later definition is the one that
    remains bound -- confirm that is intended.

    Args:
        xgboost: when True, fit an ``XGBClassifier``; otherwise fit a
            ``RandomForestClassifier(n_estimators=50, max_depth=10)``.
        include_y: when True, ``y_test`` is passed to the explainer as the
            third positional argument; when False it is omitted.

    Returns:
        The explainer, after ``calculate_properties()`` has been called on it.
    """
    X_train, y_train, X_test, y_test = titanic_survive()

    estimator = (
        XGBClassifier() if xgboost
        else RandomForestClassifier(n_estimators=50, max_depth=10)
    )
    fitted = estimator.fit(X_train, y_train)

    data_args = (X_test, y_test) if include_y else (X_test,)
    explainer = ClassifierExplainer(
        fitted, *data_args,
        cats=['Sex', 'Deck', 'Embarked'],
        labels=['Not survived', 'Survived'])

    explainer.calculate_properties()
    return explainer
Esempio n. 8
0
class LogisticRegressionTests(unittest.TestCase):
    """Return-type checks for a ClassifierExplainer wrapping a
    LogisticRegression model (``shap='linear'``).

    The explainer is rebuilt from the titanic_survive() split before every
    test.
    """

    def setUp(self):
        X_train, y_train, X_test, y_test = titanic_survive()
        _, test_names = titanic_names()

        model = LogisticRegression()
        model.fit(X_train, y_train)

        self.explainer = ClassifierExplainer(
            model,
            X_test,
            y_test,
            shap='linear',
            cats=['Sex', 'Deck', 'Embarked'],
            labels=['Not survived', 'Survived'],
            idxs=test_names)

    def test_preds(self):
        self.assertIsInstance(self.explainer.preds, np.ndarray)

    def test_pred_percentiles(self):
        self.assertIsInstance(self.explainer.pred_percentiles(), np.ndarray)

    def test_columns_ranked_by_shap(self):
        self.assertIsInstance(self.explainer.columns_ranked_by_shap(), list)

    def test_permutation_importances(self):
        self.assertIsInstance(self.explainer.get_permutation_importances_df(),
                              pd.DataFrame)

    def test_metrics(self):
        # This class used to define test_metrics twice; the later definition
        # replaced the earlier one, so the metrics_descriptions() check never
        # ran. Both sets of assertions now live in this single test.
        self.assertIsInstance(self.explainer.metrics(), dict)
        self.assertIsInstance(self.explainer.metrics(cutoff=0.9), dict)
        self.assertIsInstance(self.explainer.metrics_descriptions(), dict)

    def test_mean_abs_shap_df(self):
        self.assertIsInstance(self.explainer.get_mean_abs_shap_df(),
                              pd.DataFrame)

    def test_contrib_df(self):
        self.assertIsInstance(self.explainer.get_contrib_df(0), pd.DataFrame)
        self.assertIsInstance(self.explainer.get_contrib_df(0, topx=3),
                              pd.DataFrame)

    def test_shap_base_value(self):
        self.assertIsInstance(self.explainer.shap_base_value(),
                              (np.floating, float))

    def test_shap_values_shape(self):
        # One row per explainer row, one column per merged column.
        self.assertTrue(self.explainer.get_shap_values_df().shape == (
            len(self.explainer), len(self.explainer.merged_cols)))

    def test_shap_values(self):
        self.assertIsInstance(self.explainer.get_shap_values_df(),
                              pd.DataFrame)

    def test_mean_abs_shap(self):
        self.assertIsInstance(self.explainer.get_mean_abs_shap_df(),
                              pd.DataFrame)

    def test_calculate_properties(self):
        # Passes as long as the call does not raise.
        self.explainer.calculate_properties(include_interactions=False)

    def test_pdp_df(self):
        self.assertIsInstance(self.explainer.pdp_df("Age"), pd.DataFrame)
        self.assertIsInstance(self.explainer.pdp_df("Sex"), pd.DataFrame)
        self.assertIsInstance(self.explainer.pdp_df("Deck"), pd.DataFrame)
        self.assertIsInstance(self.explainer.pdp_df("Age", index=0),
                              pd.DataFrame)
        self.assertIsInstance(self.explainer.pdp_df("Sex", index=0),
                              pd.DataFrame)

    def test_pos_label(self):
        # Set pos_label by integer first, then by label string; the string
        # assignment is the one the assertions below check.
        self.explainer.pos_label = 1
        self.explainer.pos_label = "Not survived"
        self.assertIsInstance(self.explainer.pos_label, int)
        self.assertIsInstance(self.explainer.pos_label_str, str)
        self.assertEqual(self.explainer.pos_label, 0)
        self.assertEqual(self.explainer.pos_label_str, "Not survived")

    def test_pred_probas(self):
        self.assertIsInstance(self.explainer.pred_probas(), np.ndarray)

    def test_precision_df(self):
        self.assertIsInstance(self.explainer.get_precision_df(), pd.DataFrame)
        self.assertIsInstance(self.explainer.get_precision_df(multiclass=True),
                              pd.DataFrame)
        self.assertIsInstance(self.explainer.get_precision_df(quantiles=4),
                              pd.DataFrame)

    def test_lift_curve_df(self):
        self.assertIsInstance(self.explainer.get_liftcurve_df(), pd.DataFrame)
Esempio n. 9
0
class MultiClassClassifierBunchTests(unittest.TestCase):
    """Return-type checks for a three-label ClassifierExplainer.

    The explainer wraps a small RandomForestClassifier fitted on the
    titanic_embarked() split and is rebuilt before every test. Most tests
    only assert the type of what an explainer attribute or method returns.
    """

    def setUp(self):
        X_train, y_train, X_test, y_test = titanic_embarked()
        _, test_names = titanic_names()

        model = RandomForestClassifier(n_estimators=5, max_depth=2)
        model.fit(X_train, y_train)

        self.explainer = ClassifierExplainer(
            model,
            X_test,
            y_test,
            cats=[{
                'Gender': ['Sex_female', 'Sex_male', 'Sex_nan']
            }, 'Deck'],
            idxs=test_names,
            labels=['Queenstown', 'Southampton', 'Cherbourg'])

    def test_preds(self):
        self.assertIsInstance(self.explainer.preds, np.ndarray)

    def test_pred_percentiles(self):
        self.assertIsInstance(self.explainer.pred_percentiles, np.ndarray)

    def test_columns_ranked_by_shap(self):
        self.assertIsInstance(self.explainer.columns_ranked_by_shap(), list)
        self.assertIsInstance(self.explainer.columns_ranked_by_shap(cats=True),
                              list)

    def test_equivalent_col(self):
        self.assertEqual(self.explainer.equivalent_col("Sex_female"), "Gender")
        self.assertEqual(self.explainer.equivalent_col("Gender"), "Sex_female")
        self.assertIsNone(self.explainer.equivalent_col("random"))

    def test_get_col(self):
        self.assertIsInstance(self.explainer.get_col("Gender"), pd.Series)
        self.assertEqual(self.explainer.get_col("Gender").dtype, "object")

        self.assertIsInstance(self.explainer.get_col("Age"), pd.Series)
        # np.float was only an alias of the builtin float and has been
        # removed from NumPy; comparing against float is equivalent.
        self.assertEqual(self.explainer.get_col("Age").dtype, float)

    def test_permutation_importances(self):
        self.assertIsInstance(self.explainer.permutation_importances,
                              pd.DataFrame)
        self.assertIsInstance(self.explainer.permutation_importances_cats,
                              pd.DataFrame)

    def test_X_cats(self):
        self.assertIsInstance(self.explainer.X_cats, pd.DataFrame)

    def test_columns_cats(self):
        self.assertIsInstance(self.explainer.columns_cats, list)

    def test_metrics(self):
        # This class used to define test_metrics twice; the later definition
        # replaced the earlier one, so the metrics_descriptions() check never
        # ran. Both sets of assertions now live in this single test.
        self.assertIsInstance(self.explainer.metrics(), dict)
        self.assertIsInstance(self.explainer.metrics(cutoff=0.9), dict)
        self.assertIsInstance(self.explainer.metrics_descriptions(), dict)

    def test_mean_abs_shap_df(self):
        self.assertIsInstance(self.explainer.mean_abs_shap_df(), pd.DataFrame)

    def test_top_interactions(self):
        self.assertIsInstance(self.explainer.shap_top_interactions("Age"),
                              list)
        self.assertIsInstance(
            self.explainer.shap_top_interactions("Age", topx=4), list)
        self.assertIsInstance(
            self.explainer.shap_top_interactions("Age", cats=True), list)
        self.assertIsInstance(
            self.explainer.shap_top_interactions("Gender", cats=True), list)

    def test_permutation_importances_df(self):
        self.assertIsInstance(self.explainer.permutation_importances_df(),
                              pd.DataFrame)
        self.assertIsInstance(
            self.explainer.permutation_importances_df(topx=3), pd.DataFrame)
        self.assertIsInstance(
            self.explainer.permutation_importances_df(cats=True), pd.DataFrame)
        self.assertIsInstance(
            self.explainer.permutation_importances_df(cutoff=0.01),
            pd.DataFrame)

    def test_contrib_df(self):
        self.assertIsInstance(self.explainer.contrib_df(0), pd.DataFrame)
        self.assertIsInstance(self.explainer.contrib_df(0, cats=False),
                              pd.DataFrame)
        self.assertIsInstance(self.explainer.contrib_df(0, topx=3),
                              pd.DataFrame)

    def test_contrib_summary_df(self):
        self.assertIsInstance(self.explainer.contrib_summary_df(0),
                              pd.DataFrame)
        self.assertIsInstance(self.explainer.contrib_summary_df(0, cats=False),
                              pd.DataFrame)
        self.assertIsInstance(self.explainer.contrib_summary_df(0, topx=3),
                              pd.DataFrame)
        self.assertIsInstance(self.explainer.contrib_summary_df(0, round=3),
                              pd.DataFrame)

    def test_shap_base_value(self):
        self.assertIsInstance(self.explainer.shap_base_value,
                              (np.floating, float))

    def test_shap_values_shape(self):
        # One row per explainer row, one column per explainer column.
        self.assertTrue(
            self.explainer.shap_values.shape == (len(self.explainer),
                                                 len(self.explainer.columns)))

    def test_shap_values(self):
        self.assertIsInstance(self.explainer.shap_values, np.ndarray)
        self.assertIsInstance(self.explainer.shap_values_cats, np.ndarray)

    def test_shap_interaction_values(self):
        self.assertIsInstance(self.explainer.shap_interaction_values,
                              np.ndarray)
        self.assertIsInstance(self.explainer.shap_interaction_values_cats,
                              np.ndarray)

    def test_mean_abs_shap(self):
        self.assertIsInstance(self.explainer.mean_abs_shap, pd.DataFrame)
        self.assertIsInstance(self.explainer.mean_abs_shap_cats, pd.DataFrame)

    def test_calculate_properties(self):
        # Passes as long as the call does not raise. An identical second
        # definition of this test further down the class was removed.
        self.explainer.calculate_properties()

    def test_shap_interaction_values_by_col(self):
        self.assertIsInstance(
            self.explainer.shap_interaction_values_by_col("Age"), np.ndarray)
        self.assertEqual(
            self.explainer.shap_interaction_values_by_col("Age").shape,
            self.explainer.shap_values.shape)
        self.assertEqual(
            self.explainer.shap_interaction_values_by_col("Age",
                                                          cats=True).shape,
            self.explainer.shap_values_cats.shape)

    def test_pdp_df(self):
        self.assertIsInstance(self.explainer.pdp_df("Age"), pd.DataFrame)
        self.assertIsInstance(self.explainer.pdp_df("Gender"), pd.DataFrame)
        self.assertIsInstance(self.explainer.pdp_df("Deck"), pd.DataFrame)
        self.assertIsInstance(self.explainer.pdp_df("Age", index=0),
                              pd.DataFrame)
        self.assertIsInstance(self.explainer.pdp_df("Gender", index=0),
                              pd.DataFrame)

    def test_get_dfs(self):
        cols_df, shap_df, contribs_df = self.explainer.get_dfs()
        self.assertIsInstance(cols_df, pd.DataFrame)
        self.assertIsInstance(shap_df, pd.DataFrame)
        self.assertIsInstance(contribs_df, pd.DataFrame)

    def test_plot_importances(self):
        fig = self.explainer.plot_importances()
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_importances(kind='permutation')
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_importances(topx=3)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_importances(cats=True)
        self.assertIsInstance(fig, go.Figure)

    def test_plot_interactions(self):
        fig = self.explainer.plot_interactions("Age")
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_interactions("Sex_female")
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_interactions("Age")
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_interactions("Gender")
        self.assertIsInstance(fig, go.Figure)

    def test_plot_shap_contributions(self):
        fig = self.explainer.plot_shap_contributions(0)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_shap_contributions(0, cats=False)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_shap_contributions(0, topx=3)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_shap_contributions(0, cutoff=0.05)
        self.assertIsInstance(fig, go.Figure)

    def test_plot_shap_summary(self):
        fig = self.explainer.plot_shap_summary()
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_shap_summary(topx=3)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_shap_summary(cats=True)
        self.assertIsInstance(fig, go.Figure)

    def test_plot_shap_interaction_summary(self):
        fig = self.explainer.plot_shap_interaction_summary("Age")
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_shap_interaction_summary("Age", topx=3)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_shap_interaction_summary("Age")
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_shap_interaction_summary("Sex_female",
                                                           topx=3)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_shap_interaction_summary("Gender")
        self.assertIsInstance(fig, go.Figure)

    def test_plot_shap_dependence(self):
        fig = self.explainer.plot_shap_dependence("Age")
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_shap_dependence("Sex_female")
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_shap_dependence("Age", "Gender")
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_shap_dependence("Sex_female", "Age")
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_shap_dependence("Age", highlight_index=0)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_shap_dependence("Gender", highlight_index=0)
        self.assertIsInstance(fig, go.Figure)

    def test_plot_shap_interaction(self):
        # NOTE(review): despite its name this test exercises
        # plot_shap_dependence with an interaction column -- confirm whether
        # a dedicated interaction plot method was meant.
        fig = self.explainer.plot_shap_dependence("Age", "Sex_female")
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_shap_dependence("Sex_female", "Age")
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_shap_dependence("Gender", "Age")
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_shap_dependence("Age", "Gender")
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_shap_dependence("Age",
                                                  "Sex_female",
                                                  highlight_index=0)
        self.assertIsInstance(fig, go.Figure)

    def test_plot_pdp(self):
        fig = self.explainer.plot_pdp("Age")
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_pdp("Gender")
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_pdp("Gender", index=0)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_pdp("Age", index=0)
        self.assertIsInstance(fig, go.Figure)

    def test_pos_label(self):
        # Set pos_label by integer first, then by label string; the string
        # assignment is the one the assertions below check.
        self.explainer.pos_label = 1
        self.explainer.pos_label = "Southampton"
        self.assertIsInstance(self.explainer.pos_label, int)
        self.assertIsInstance(self.explainer.pos_label_str, str)
        self.assertEqual(self.explainer.pos_label, 1)
        self.assertEqual(self.explainer.pos_label_str, "Southampton")

    def test_get_prop_for_label(self):
        # Capture pred_percentiles while pos_label is 1, switch to label 0,
        # then check that asking explicitly for label 1 gives the same values.
        self.explainer.pos_label = 1
        tmp = self.explainer.pred_percentiles
        self.explainer.pos_label = 0
        # np.alltrue has been removed from NumPy; np.all is its replacement.
        self.assertTrue(
            np.all(
                self.explainer.get_prop_for_label("pred_percentiles", 1) ==
                tmp))

    def test_pred_probas(self):
        self.assertIsInstance(self.explainer.pred_probas, np.ndarray)

    def test_precision_df(self):
        self.assertIsInstance(self.explainer.precision_df(), pd.DataFrame)
        self.assertIsInstance(self.explainer.precision_df(multiclass=True),
                              pd.DataFrame)
        self.assertIsInstance(self.explainer.precision_df(quantiles=4),
                              pd.DataFrame)

    def test_lift_curve_df(self):
        self.assertIsInstance(self.explainer.lift_curve_df(), pd.DataFrame)

    def test_prediction_result_markdown(self):
        self.assertIsInstance(self.explainer.prediction_result_markdown(0),
                              str)

    def test_plot_precision(self):
        fig = self.explainer.plot_precision()
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_precision(multiclass=True)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_precision(quantiles=10, cutoff=0.5)
        self.assertIsInstance(fig, go.Figure)

    def test_plot_cumulative_precision(self):
        fig = self.explainer.plot_cumulative_precision()
        self.assertIsInstance(fig, go.Figure)

    def test_plot_confusion_matrix(self):
        fig = self.explainer.plot_confusion_matrix(normalized=False,
                                                   binary=False)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_confusion_matrix(normalized=False,
                                                   binary=True)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_confusion_matrix(normalized=True,
                                                   binary=False)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_confusion_matrix(normalized=True,
                                                   binary=True)
        self.assertIsInstance(fig, go.Figure)

    def test_plot_lift_curve(self):
        # An identical second definition of this test was removed.
        fig = self.explainer.plot_lift_curve()
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_lift_curve(percentage=True)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_lift_curve(cutoff=0.5)
        self.assertIsInstance(fig, go.Figure)

    def test_plot_classification(self):
        fig = self.explainer.plot_classification()
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_classification(percentage=True)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_classification(cutoff=0)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_classification(cutoff=1)
        self.assertIsInstance(fig, go.Figure)

    def test_plot_roc_auc(self):
        fig = self.explainer.plot_roc_auc(0.5)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_roc_auc(0.0)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_roc_auc(1.0)
        self.assertIsInstance(fig, go.Figure)

    def test_plot_pr_auc(self):
        fig = self.explainer.plot_pr_auc(0.5)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_pr_auc(0.0)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_pr_auc(1.0)
        self.assertIsInstance(fig, go.Figure)
Esempio n. 10
0
class ClassifierBaseExplainerTests(unittest.TestCase):
    """Smoke tests for a ClassifierExplainer built on a Titanic survival model.

    Most tests only assert the *type* of what an explainer attribute or method
    returns (ndarray, DataFrame, go.Figure, ...), not the values themselves.
    """

    def setUp(self):
        # Fit a small random forest (5 trees, max depth 2) and wrap it in the
        # explainer that every test in this class uses.
        X_train, y_train, X_test, y_test = titanic_survive()
        _, test_names = titanic_names()

        model = RandomForestClassifier(n_estimators=5, max_depth=2)
        model.fit(X_train, y_train)

        self.explainer = ClassifierExplainer(
            model,
            X_test,
            y_test,
            roc_auc_score,
            cats=[{
                'Gender': ['Sex_female', 'Sex_male', 'Sex_nan']
            }, 'Deck', 'Embarked'],
            target='Survival',
            labels=['Not survived', 'Survived'],
            idxs=test_names)

    def test_explainer_len(self):
        # len(explainer) must equal the number of rows of X_test.
        self.assertEqual(len(self.explainer), len(titanic_survive()[2]))

    def test_int_idx(self):
        # The first test-set name must map to integer position 0.
        self.assertEqual(self.explainer.get_int_idx(titanic_names()[1][0]), 0)

    def test_random_index(self):
        self.assertIsInstance(self.explainer.random_index(), int)
        self.assertIsInstance(self.explainer.random_index(return_str=True),
                              str)

    def test_preds(self):
        self.assertIsInstance(self.explainer.preds, np.ndarray)

    def test_pred_percentiles(self):
        self.assertIsInstance(self.explainer.pred_percentiles, np.ndarray)

    def test_columns_ranked_by_shap(self):
        self.assertIsInstance(self.explainer.columns_ranked_by_shap(), list)
        self.assertIsInstance(self.explainer.columns_ranked_by_shap(cats=True),
                              list)

    def test_equivalent_col(self):
        # A one-hot column maps to its group name, a group name maps back to
        # a one-hot column, and an unknown name yields None.
        self.assertEqual(self.explainer.equivalent_col("Sex_female"), "Gender")
        self.assertEqual(self.explainer.equivalent_col("Gender"), "Sex_female")
        self.assertEqual(self.explainer.equivalent_col("Deck"), "Deck_A")
        self.assertEqual(self.explainer.equivalent_col("Deck_A"), "Deck")
        self.assertIsNone(self.explainer.equivalent_col("random"))

    def test_get_col(self):
        # Grouped categorical columns come back with object dtype, Age as
        # float.
        self.assertIsInstance(self.explainer.get_col("Gender"), pd.Series)
        self.assertEqual(self.explainer.get_col("Gender").dtype, "object")

        self.assertIsInstance(self.explainer.get_col("Deck"), pd.Series)
        self.assertEqual(self.explainer.get_col("Deck").dtype, "object")

        self.assertIsInstance(self.explainer.get_col("Age"), pd.Series)
        # `np.float` was only an alias of the builtin `float` and was removed
        # in NumPy 1.24, so compare against the builtin directly.
        self.assertEqual(self.explainer.get_col("Age").dtype, float)

    def test_permutation_importances(self):
        self.assertIsInstance(self.explainer.permutation_importances,
                              pd.DataFrame)
        self.assertIsInstance(self.explainer.permutation_importances_cats,
                              pd.DataFrame)

    def test_X_cats(self):
        self.assertIsInstance(self.explainer.X_cats, pd.DataFrame)

    def test_columns_cats(self):
        self.assertIsInstance(self.explainer.columns_cats, list)

    def test_metrics(self):
        self.assertIsInstance(self.explainer.metrics(), dict)
        self.assertIsInstance(self.explainer.metrics_markdown(), str)

    def test_mean_abs_shap_df(self):
        self.assertIsInstance(self.explainer.mean_abs_shap_df(), pd.DataFrame)

    def test_top_interactions(self):
        explainer = self.explainer
        self.assertIsInstance(explainer.shap_top_interactions("Age"), list)
        self.assertIsInstance(
            explainer.shap_top_interactions("Age", topx=4), list)
        self.assertIsInstance(
            explainer.shap_top_interactions("Age", cats=True), list)
        self.assertIsInstance(
            explainer.shap_top_interactions("Gender", cats=True), list)

    def test_permutation_importances_df(self):
        explainer = self.explainer
        self.assertIsInstance(explainer.permutation_importances_df(),
                              pd.DataFrame)
        self.assertIsInstance(explainer.permutation_importances_df(topx=3),
                              pd.DataFrame)
        self.assertIsInstance(explainer.permutation_importances_df(cats=True),
                              pd.DataFrame)
        self.assertIsInstance(
            explainer.permutation_importances_df(cutoff=0.01), pd.DataFrame)

    def test_contrib_df(self):
        # Exercise row 0 by index with each option, then pass the row
        # explicitly through X_row (raw and grouped-categorical frames).
        explainer = self.explainer
        self.assertIsInstance(explainer.contrib_df(0), pd.DataFrame)
        self.assertIsInstance(explainer.contrib_df(0, cats=False),
                              pd.DataFrame)
        self.assertIsInstance(explainer.contrib_df(0, topx=3), pd.DataFrame)
        self.assertIsInstance(explainer.contrib_df(0, sort='high-to-low'),
                              pd.DataFrame)
        self.assertIsInstance(explainer.contrib_df(0, sort='low-to-high'),
                              pd.DataFrame)
        self.assertIsInstance(explainer.contrib_df(0, sort='importance'),
                              pd.DataFrame)
        self.assertIsInstance(
            explainer.contrib_df(X_row=explainer.X.iloc[[0]]), pd.DataFrame)
        self.assertIsInstance(
            explainer.contrib_df(X_row=explainer.X_cats.iloc[[0]]),
            pd.DataFrame)

    def test_contrib_summary_df(self):
        explainer = self.explainer
        self.assertIsInstance(explainer.contrib_summary_df(0), pd.DataFrame)
        self.assertIsInstance(explainer.contrib_summary_df(0, cats=False),
                              pd.DataFrame)
        self.assertIsInstance(explainer.contrib_summary_df(0, topx=3),
                              pd.DataFrame)
        self.assertIsInstance(explainer.contrib_summary_df(0, round=3),
                              pd.DataFrame)
        self.assertIsInstance(
            explainer.contrib_summary_df(0, sort='low-to-high'), pd.DataFrame)
        self.assertIsInstance(
            explainer.contrib_summary_df(0, sort='high-to-low'), pd.DataFrame)
        self.assertIsInstance(
            explainer.contrib_summary_df(0, sort='importance'), pd.DataFrame)
        self.assertIsInstance(
            explainer.contrib_summary_df(X_row=explainer.X.iloc[[0]]),
            pd.DataFrame)
        self.assertIsInstance(
            explainer.contrib_summary_df(X_row=explainer.X_cats.iloc[[0]]),
            pd.DataFrame)

    def test_shap_base_value(self):
        self.assertIsInstance(self.explainer.shap_base_value,
                              (np.floating, float))

    def test_shap_values_shape(self):
        # One SHAP value per (row, column) pair.
        self.assertTrue(
            self.explainer.shap_values.shape == (len(self.explainer),
                                                 len(self.explainer.columns)))

    def test_shap_values(self):
        self.assertIsInstance(self.explainer.shap_values, np.ndarray)
        self.assertIsInstance(self.explainer.shap_values_cats, np.ndarray)

    def test_shap_interaction_values(self):
        self.assertIsInstance(self.explainer.shap_interaction_values,
                              np.ndarray)
        self.assertIsInstance(self.explainer.shap_interaction_values_cats,
                              np.ndarray)

    def test_mean_abs_shap(self):
        self.assertIsInstance(self.explainer.mean_abs_shap, pd.DataFrame)
        self.assertIsInstance(self.explainer.mean_abs_shap_cats, pd.DataFrame)

    def test_calculate_properties(self):
        # Only checks that the call completes without raising.
        self.explainer.calculate_properties()

    def test_shap_interaction_values_by_col(self):
        # Interaction values for a single column must have the same shape as
        # the matching SHAP value array.
        explainer = self.explainer
        self.assertIsInstance(
            explainer.shap_interaction_values_by_col("Age"), np.ndarray)
        # assertEqual replaces the assertEquals alias removed in Python 3.12.
        self.assertEqual(
            explainer.shap_interaction_values_by_col("Age").shape,
            explainer.shap_values.shape)
        self.assertEqual(
            explainer.shap_interaction_values_by_col("Age", cats=True).shape,
            explainer.shap_values_cats.shape)

    def test_pdp_result(self):
        explainer = self.explainer
        self.assertIsInstance(explainer.get_pdp_result("Age"),
                              pdpbox.pdp.PDPIsolate)
        self.assertIsInstance(explainer.get_pdp_result("Gender"),
                              pdpbox.pdp.PDPIsolate)
        self.assertIsInstance(explainer.get_pdp_result("Age", index=0),
                              pdpbox.pdp.PDPIsolate)
        self.assertIsInstance(explainer.get_pdp_result("Gender", index=0),
                              pdpbox.pdp.PDPIsolate)
        self.assertIsInstance(
            explainer.get_pdp_result("Age", X_row=explainer.X.iloc[[0]]),
            pdpbox.pdp.PDPIsolate)
        self.assertIsInstance(
            explainer.get_pdp_result("Age", X_row=explainer.X_cats.iloc[[0]]),
            pdpbox.pdp.PDPIsolate)

    def test_get_dfs(self):
        cols_df, shap_df, contribs_df = self.explainer.get_dfs()
        self.assertIsInstance(cols_df, pd.DataFrame)
        self.assertIsInstance(shap_df, pd.DataFrame)
        self.assertIsInstance(contribs_df, pd.DataFrame)

    def test_plot_importances(self):
        explainer = self.explainer
        self.assertIsInstance(explainer.plot_importances(), go.Figure)
        self.assertIsInstance(explainer.plot_importances(kind='permutation'),
                              go.Figure)
        self.assertIsInstance(explainer.plot_importances(topx=3), go.Figure)
        self.assertIsInstance(explainer.plot_importances(cats=True), go.Figure)

    def test_plot_interactions(self):
        # NOTE(review): "Age" is plotted twice with identical arguments; the
        # repeat is kept as in the original -- confirm whether a different
        # variant was intended.
        explainer = self.explainer
        self.assertIsInstance(explainer.plot_interactions("Age"), go.Figure)
        self.assertIsInstance(explainer.plot_interactions("Sex_female"),
                              go.Figure)
        self.assertIsInstance(explainer.plot_interactions("Age"), go.Figure)
        self.assertIsInstance(explainer.plot_interactions("Gender"), go.Figure)

    def test_plot_shap_contributions(self):
        explainer = self.explainer
        self.assertIsInstance(explainer.plot_shap_contributions(0), go.Figure)
        self.assertIsInstance(
            explainer.plot_shap_contributions(0, cats=False), go.Figure)
        self.assertIsInstance(
            explainer.plot_shap_contributions(0, topx=3), go.Figure)
        self.assertIsInstance(
            explainer.plot_shap_contributions(0, cutoff=0.05), go.Figure)
        self.assertIsInstance(
            explainer.plot_shap_contributions(0, sort='high-to-low'),
            go.Figure)
        self.assertIsInstance(
            explainer.plot_shap_contributions(0, sort='low-to-high'),
            go.Figure)
        self.assertIsInstance(
            explainer.plot_shap_contributions(0, sort='importance'),
            go.Figure)

        fig = explainer.plot_shap_contributions(
            X_row=explainer.X.iloc[[0]], sort='importance')
        self.assertIsInstance(fig, go.Figure)

        fig = explainer.plot_shap_contributions(
            X_row=explainer.X_cats.iloc[[0]], sort='importance')
        self.assertIsInstance(fig, go.Figure)

    def test_plot_shap_summary(self):
        explainer = self.explainer
        self.assertIsInstance(explainer.plot_shap_summary(), go.Figure)
        self.assertIsInstance(explainer.plot_shap_summary(topx=3), go.Figure)
        self.assertIsInstance(explainer.plot_shap_summary(cats=True),
                              go.Figure)

    def test_plot_shap_interaction_summary(self):
        # NOTE(review): the plain "Age" call appears twice, as in the
        # original.
        explainer = self.explainer
        self.assertIsInstance(
            explainer.plot_shap_interaction_summary("Age"), go.Figure)
        self.assertIsInstance(
            explainer.plot_shap_interaction_summary("Age", topx=3), go.Figure)
        self.assertIsInstance(
            explainer.plot_shap_interaction_summary("Age"), go.Figure)
        self.assertIsInstance(
            explainer.plot_shap_interaction_summary("Sex_female", topx=3),
            go.Figure)
        self.assertIsInstance(
            explainer.plot_shap_interaction_summary("Gender"), go.Figure)

    def test_plot_shap_dependence(self):
        explainer = self.explainer
        self.assertIsInstance(explainer.plot_shap_dependence("Age"),
                              go.Figure)
        self.assertIsInstance(explainer.plot_shap_dependence("Sex_female"),
                              go.Figure)
        self.assertIsInstance(explainer.plot_shap_dependence("Age", "Gender"),
                              go.Figure)
        self.assertIsInstance(
            explainer.plot_shap_dependence("Sex_female", "Age"), go.Figure)
        self.assertIsInstance(
            explainer.plot_shap_dependence("Age", highlight_index=0),
            go.Figure)
        self.assertIsInstance(
            explainer.plot_shap_dependence("Gender", highlight_index=0),
            go.Figure)

    def test_plot_shap_interaction(self):
        # NOTE(review): despite its name this test calls plot_shap_dependence
        # with an interaction column; the calls are kept unchanged -- confirm
        # whether a dedicated interaction plot method was meant.
        explainer = self.explainer
        self.assertIsInstance(
            explainer.plot_shap_dependence("Age", "Sex_female"), go.Figure)
        self.assertIsInstance(
            explainer.plot_shap_dependence("Sex_female", "Age"), go.Figure)
        self.assertIsInstance(
            explainer.plot_shap_dependence("Gender", "Age"), go.Figure)
        self.assertIsInstance(
            explainer.plot_shap_dependence("Age", "Gender"), go.Figure)
        self.assertIsInstance(
            explainer.plot_shap_dependence("Age", "Sex_female",
                                           highlight_index=0),
            go.Figure)

    def test_plot_pdp(self):
        explainer = self.explainer
        self.assertIsInstance(explainer.plot_pdp("Age"), go.Figure)
        self.assertIsInstance(explainer.plot_pdp("Gender"), go.Figure)
        self.assertIsInstance(explainer.plot_pdp("Gender", index=0), go.Figure)
        self.assertIsInstance(explainer.plot_pdp("Age", index=0), go.Figure)

        fig = explainer.plot_pdp("Age", X_row=explainer.X.iloc[[0]])
        self.assertIsInstance(fig, go.Figure)

        fig = explainer.plot_pdp("Age", X_row=explainer.X_cats.iloc[[0]])
        self.assertIsInstance(fig, go.Figure)

    def test_yaml(self):
        # Local renamed from `yaml` so it cannot shadow a module of that name.
        yaml_text = self.explainer.to_yaml()
        self.assertIsInstance(yaml_text, str)
# Example no. 11
# 0
class LogisticRegressionTests(unittest.TestCase):
    """Smoke tests for a ClassifierExplainer wrapping a LogisticRegression.

    The explainer is created with shap='linear'; the tests mostly assert the
    type of what each attribute or method returns.
    """

    def setUp(self):
        # Fit a default LogisticRegression on the Titanic survival data and
        # build the explainer shared by all tests.
        X_train, y_train, X_test, y_test = titanic_survive()
        _, test_names = titanic_names()

        model = LogisticRegression()
        model.fit(X_train, y_train)

        self.explainer = ClassifierExplainer(
            model,
            X_test,
            y_test,
            roc_auc_score,
            shap='linear',
            cats=['Sex', 'Cabin', 'Embarked'],
            labels=['Not survived', 'Survived'],
            idxs=test_names)

    def test_preds(self):
        self.assertIsInstance(self.explainer.preds, np.ndarray)

    def test_pred_percentiles(self):
        self.assertIsInstance(self.explainer.pred_percentiles, np.ndarray)

    def test_columns_ranked_by_shap(self):
        self.assertIsInstance(self.explainer.columns_ranked_by_shap(), list)
        self.assertIsInstance(self.explainer.columns_ranked_by_shap(cats=True),
                              list)

    def test_permutation_importances(self):
        self.assertIsInstance(self.explainer.permutation_importances,
                              pd.DataFrame)
        self.assertIsInstance(self.explainer.permutation_importances_cats,
                              pd.DataFrame)

    def test_metrics_markdown(self):
        # This assertion used to live in a first `test_metrics` definition
        # that was shadowed by a second method of the same name further down,
        # so it never ran. It now has its own test name.
        self.assertIsInstance(self.explainer.metrics_markdown(), str)

    def test_mean_abs_shap_df(self):
        self.assertIsInstance(self.explainer.mean_abs_shap_df(), pd.DataFrame)

    def test_contrib_df(self):
        self.assertIsInstance(self.explainer.contrib_df(0), pd.DataFrame)
        self.assertIsInstance(self.explainer.contrib_df(0, cats=False),
                              pd.DataFrame)
        self.assertIsInstance(self.explainer.contrib_df(0, topx=3),
                              pd.DataFrame)

    def test_shap_base_value(self):
        self.assertIsInstance(self.explainer.shap_base_value,
                              (np.floating, float))

    def test_shap_values_shape(self):
        # One SHAP value per (row, column) pair.
        self.assertTrue(
            self.explainer.shap_values.shape == (len(self.explainer),
                                                 len(self.explainer.columns)))

    def test_shap_values(self):
        self.assertIsInstance(self.explainer.shap_values, np.ndarray)
        self.assertIsInstance(self.explainer.shap_values_cats, np.ndarray)

    def test_mean_abs_shap(self):
        self.assertIsInstance(self.explainer.mean_abs_shap, pd.DataFrame)
        self.assertIsInstance(self.explainer.mean_abs_shap_cats, pd.DataFrame)

    def test_calculate_properties(self):
        # Only checks that the call completes without raising.
        self.explainer.calculate_properties(include_interactions=False)

    def test_pdp_result(self):
        self.assertIsInstance(self.explainer.get_pdp_result("Age"),
                              pdpbox.pdp.PDPIsolate)
        self.assertIsInstance(self.explainer.get_pdp_result("Sex"),
                              pdpbox.pdp.PDPIsolate)
        self.assertIsInstance(self.explainer.get_pdp_result("Age", index=0),
                              pdpbox.pdp.PDPIsolate)
        self.assertIsInstance(self.explainer.get_pdp_result("Sex", index=0),
                              pdpbox.pdp.PDPIsolate)

    def test_pos_label(self):
        # Setting pos_label by label string must resolve to the label's
        # integer position (0 for "Not survived").
        self.explainer.pos_label = 1
        self.explainer.pos_label = "Not survived"
        self.assertIsInstance(self.explainer.pos_label, int)
        self.assertIsInstance(self.explainer.pos_label_str, str)
        # assertEqual replaces the assertEquals alias removed in Python 3.12.
        self.assertEqual(self.explainer.pos_label, 0)
        self.assertEqual(self.explainer.pos_label_str, "Not survived")

    def test_get_prop_for_label(self):
        # Percentiles read while pos_label == 1 must equal what
        # get_prop_for_label(..., 1) returns after switching pos_label to 0.
        self.explainer.pos_label = 1
        tmp = self.explainer.pred_percentiles
        self.explainer.pos_label = 0
        # np.all replaces np.alltrue, which was removed in NumPy 2.0.
        self.assertTrue(
            np.all(
                self.explainer.get_prop_for_label("pred_percentiles", 1) ==
                tmp))

    def test_pred_probas(self):
        self.assertIsInstance(self.explainer.pred_probas, np.ndarray)

    def test_metrics(self):
        self.assertIsInstance(self.explainer.metrics(), dict)
        self.assertIsInstance(self.explainer.metrics(cutoff=0.9), dict)

    def test_precision_df(self):
        self.assertIsInstance(self.explainer.precision_df(), pd.DataFrame)
        self.assertIsInstance(self.explainer.precision_df(multiclass=True),
                              pd.DataFrame)
        self.assertIsInstance(self.explainer.precision_df(quantiles=4),
                              pd.DataFrame)

    def test_lift_curve_df(self):
        self.assertIsInstance(self.explainer.lift_curve_df(), pd.DataFrame)

    def test_prediction_result_markdown(self):
        self.assertIsInstance(self.explainer.prediction_result_markdown(0),
                              str)
# Example no. 12
# 0
class CatBoostClassifierTests(unittest.TestCase):
    """Smoke tests for a ClassifierExplainer wrapping a CatBoostClassifier."""

    def setUp(self):
        # Train a 100-iteration CatBoost model on the Titanic survival data
        # and build the explainer used by every test.
        X_train, y_train, X_test, y_test = titanic_survive()
        _, test_names = titanic_names()

        booster = CatBoostClassifier(iterations=100, learning_rate=0.1, verbose=0)
        booster.fit(X_train, y_train)

        gender_group = {'Gender': ['Sex_female', 'Sex_male', 'Sex_nan']}
        self.explainer = ClassifierExplainer(
            booster,
            X_test,
            y_test,
            cats=[gender_group, 'Deck', 'Embarked'],
            labels=['Not survived', 'Survived'],
            idxs=test_names,
        )

    def test_preds(self):
        predictions = self.explainer.preds
        self.assertIsInstance(predictions, np.ndarray)

    def test_pred_probas(self):
        probabilities = self.explainer.pred_probas()
        self.assertIsInstance(probabilities, np.ndarray)

    def test_permutation_importances(self):
        importances = self.explainer.get_permutation_importances_df()
        self.assertIsInstance(importances, pd.DataFrame)

    def test_shap_base_value(self):
        base_value = self.explainer.shap_base_value()
        self.assertIsInstance(base_value, (np.floating, float))

    def test_shap_values_shape(self):
        # One SHAP value per (row, merged column) pair.
        explainer = self.explainer
        actual_shape = explainer.get_shap_values_df().shape
        expected_shape = (len(explainer), len(explainer.merged_cols))
        self.assertTrue(actual_shape == expected_shape)

    def test_shap_values(self):
        shap_frame = self.explainer.get_shap_values_df()
        self.assertIsInstance(shap_frame, pd.DataFrame)

    def test_shap_values_all_probabilities(self):
        # The base value, and the base value plus each row's summed SHAP
        # values, must all lie within [0, 1].
        explainer = self.explainer

        def row_totals():
            return (explainer.get_shap_values_df().sum(axis=1)
                    + explainer.shap_base_value())

        self.assertTrue(explainer.shap_base_value() >= 0)
        self.assertTrue(explainer.shap_base_value() <= 1)
        self.assertTrue(np.all(row_totals() >= 0))
        self.assertTrue(np.all(row_totals() <= 1))

    def test_mean_abs_shap(self):
        mean_abs = self.explainer.get_mean_abs_shap_df()
        self.assertIsInstance(mean_abs, pd.DataFrame)

    def test_calculate_properties(self):
        # Only checks that the call completes without raising.
        self.explainer.calculate_properties(include_interactions=False)

    def test_pdp_df(self):
        explainer = self.explainer
        for column in ("Age", "Gender", "Deck"):
            self.assertIsInstance(explainer.pdp_df(column), pd.DataFrame)
        for column in ("Age", "Gender"):
            self.assertIsInstance(explainer.pdp_df(column, index=0),
                                  pd.DataFrame)

    def test_metrics(self):
        for kwargs in ({}, {"cutoff": 0.9}):
            self.assertIsInstance(self.explainer.metrics(**kwargs), dict)

    def test_precision_df(self):
        for kwargs in ({}, {"multiclass": True}, {"quantiles": 4}):
            frame = self.explainer.get_precision_df(**kwargs)
            self.assertIsInstance(frame, pd.DataFrame)

    def test_lift_curve_df(self):
        lift_frame = self.explainer.get_liftcurve_df()
        self.assertIsInstance(lift_frame, pd.DataFrame)

    def test_prediction_result_df(self):
        result_frame = self.explainer.prediction_result_df(0)
        self.assertIsInstance(result_frame, pd.DataFrame)



        
# Example no. 13
# 0
class XGBCLassifierTests(unittest.TestCase):
    """Smoke tests for a ClassifierExplainer wrapping a default XGBClassifier."""

    def setUp(self):
        # Train a default XGBClassifier on the Titanic survival data and wrap
        # it in the explainer used by every test. The passenger names are
        # fetched as before but not passed on to the explainer.
        X_train, y_train, X_test, y_test = titanic_survive()
        _, _ = titanic_names()

        booster = XGBClassifier()
        booster.fit(X_train, y_train)

        gender_group = {'Gender': ['Sex_female', 'Sex_male', 'Sex_nan']}
        self.explainer = ClassifierExplainer(
            booster,
            X_test,
            y_test,
            cats=[gender_group, 'Deck', 'Embarked'],
            labels=['Not survived', 'Survived'],
        )

    def test_preds(self):
        predictions = self.explainer.preds
        self.assertIsInstance(predictions, np.ndarray)

    def test_pred_probas(self):
        probabilities = self.explainer.pred_probas
        self.assertIsInstance(probabilities, np.ndarray)

    def test_permutation_importances(self):
        explainer = self.explainer
        self.assertIsInstance(explainer.permutation_importances, pd.DataFrame)
        self.assertIsInstance(explainer.permutation_importances_cats,
                              pd.DataFrame)

    def test_shap_base_value(self):
        base_value = self.explainer.shap_base_value
        self.assertIsInstance(base_value, (np.floating, float))

    def test_shap_values_shape(self):
        # One SHAP value per (row, column) pair.
        explainer = self.explainer
        actual_shape = explainer.shap_values.shape
        expected_shape = (len(explainer), len(explainer.columns))
        self.assertTrue(actual_shape == expected_shape)

    def test_shap_values(self):
        explainer = self.explainer
        self.assertIsInstance(explainer.shap_values, np.ndarray)
        self.assertIsInstance(explainer.shap_values_cats, np.ndarray)

    def test_shap_values_all_probabilities(self):
        # The base value, and the base value plus each row's summed SHAP
        # values, must all lie within [0, 1].
        explainer = self.explainer

        def row_totals():
            return explainer.shap_values.sum(axis=1) + explainer.shap_base_value

        self.assertTrue(explainer.shap_base_value >= 0)
        self.assertTrue(explainer.shap_base_value <= 1)
        self.assertTrue(np.all(row_totals() >= 0))
        self.assertTrue(np.all(row_totals() <= 1))

    # NOTE: the interaction-values test below is commented out in the
    # original file; it is kept here, still disabled, for reference.
    # def test_shap_interaction_values(self):
    #     self.assertIsInstance(self.explainer.shap_interaction_values, np.ndarray)
    #     self.assertIsInstance(self.explainer.shap_interaction_values_cats, np.ndarray)

    def test_mean_abs_shap(self):
        explainer = self.explainer
        self.assertIsInstance(explainer.mean_abs_shap, pd.DataFrame)
        self.assertIsInstance(explainer.mean_abs_shap_cats, pd.DataFrame)

    def test_calculate_properties(self):
        # Only checks that the call completes without raising.
        self.explainer.calculate_properties(include_interactions=False)

    def test_pdp_result(self):
        # Same call order as before: both columns without an index, then
        # both with index=0.
        for kwargs in ({}, {"index": 0}):
            for column in ("Age", "Gender"):
                result = self.explainer.get_pdp_result(column, **kwargs)
                self.assertIsInstance(result, pdpbox.pdp.PDPIsolate)

    def test_metrics(self):
        for kwargs in ({}, {"cutoff": 0.9}):
            self.assertIsInstance(self.explainer.metrics(**kwargs), dict)

    def test_precision_df(self):
        for kwargs in ({}, {"multiclass": True}, {"quantiles": 4}):
            frame = self.explainer.precision_df(**kwargs)
            self.assertIsInstance(frame, pd.DataFrame)

    def test_lift_curve_df(self):
        lift_frame = self.explainer.lift_curve_df()
        self.assertIsInstance(lift_frame, pd.DataFrame)

    def test_prediction_result_markdown(self):
        markdown = self.explainer.prediction_result_markdown(0)
        self.assertIsInstance(markdown, str)
class ClassifierExplainerTests(unittest.TestCase):
    """Smoke tests for the classifier-specific ClassifierExplainer API.

    Covers pos_label handling, custom metrics, the precision / lift-curve
    data frames and the classification plots, mostly by asserting the type
    of what each call returns.
    """

    def setUp(self):
        # Fit a small random forest (5 trees, max depth 2) on the Titanic
        # survival data and build the explainer shared by all tests.
        X_train, y_train, X_test, y_test = titanic_survive()
        _, test_names = titanic_names()

        model = RandomForestClassifier(n_estimators=5, max_depth=2)
        model.fit(X_train, y_train)

        self.explainer = ClassifierExplainer(
            model,
            X_test,
            y_test,
            cats=[{
                'Gender': ['Sex_female', 'Sex_male', 'Sex_nan']
            }, 'Deck', 'Embarked'],
            cats_notencoded={'Gender': 'No Gender'},
            idxs=test_names,
            labels=['Not survived', 'Survived'])

    def test_pos_label(self):
        # Setting pos_label by label string must resolve to the label's
        # integer position (0 for "Not survived").
        self.explainer.pos_label = 1
        self.explainer.pos_label = "Not survived"
        self.assertIsInstance(self.explainer.pos_label, int)
        self.assertIsInstance(self.explainer.pos_label_str, str)
        self.assertEqual(self.explainer.pos_label, 0)
        self.assertEqual(self.explainer.pos_label_str, "Not survived")

    def test_custom_metrics(self):
        # Four custom metrics with different signatures (with and without
        # `cutoff` / `pos_label`) are passed to metrics(); the test asserts
        # that all of them come back with the same value.
        def meandiff_metric1(y_true, y_pred):
            return np.mean(y_true) - np.mean(y_pred)

        def meandiff_metric2(y_true, y_pred, cutoff):
            return np.mean(y_true) - np.mean(np.where(y_pred > cutoff, 1, 0))

        def meandiff_metric3(y_true, y_pred, pos_label):
            return np.mean(np.where(y_true == pos_label, 1, 0)) - np.mean(
                y_pred[:, pos_label])

        def meandiff_metric4(y_true, y_pred, cutoff, pos_label):
            return np.mean(np.where(y_true == pos_label, 1, 0)) - np.mean(
                np.where(y_pred[:, pos_label] > cutoff, 1, 0))

        metrics = np.array(
            list(
                self.explainer.metrics(show_metrics=[
                    meandiff_metric1, meandiff_metric2, meandiff_metric3,
                    meandiff_metric4
                ]).values()))
        self.assertTrue(np.all(metrics == metrics[0]))

    def test_pred_probas(self):
        # pred_probas accepts no argument, an integer label or a label name.
        self.assertIsInstance(self.explainer.pred_probas(), np.ndarray)
        self.assertIsInstance(self.explainer.pred_probas(1), np.ndarray)
        self.assertIsInstance(self.explainer.pred_probas("Survived"),
                              np.ndarray)

    def test_metrics(self):
        self.assertIsInstance(self.explainer.metrics(), dict)
        self.assertIsInstance(self.explainer.metrics(cutoff=0.9), dict)
        self.assertIsInstance(self.explainer.metrics_descriptions(cutoff=0.9),
                              dict)

    def test_precision_df(self):
        self.assertIsInstance(self.explainer.get_precision_df(), pd.DataFrame)
        self.assertIsInstance(self.explainer.get_precision_df(multiclass=True),
                              pd.DataFrame)
        self.assertIsInstance(self.explainer.get_precision_df(quantiles=4),
                              pd.DataFrame)

    def test_lift_curve_df(self):
        self.assertIsInstance(self.explainer.get_liftcurve_df(), pd.DataFrame)

    def test_calculate_properties(self):
        # Only checks that the call completes without raising.
        self.explainer.calculate_properties()

    def test_plot_precision(self):
        fig = self.explainer.plot_precision()
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_precision(multiclass=True)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_precision(quantiles=10, cutoff=0.5)
        self.assertIsInstance(fig, go.Figure)

    def test_plot_cumulutive_precision(self):
        fig = self.explainer.plot_cumulative_precision()
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_cumulative_precision(percentile=0.5)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_cumulative_precision(percentile=0.1,
                                                       pos_label=0)
        self.assertIsInstance(fig, go.Figure)

    def test_plot_confusion_matrix(self):
        # All four combinations of `normalized` and `binary`.
        fig = self.explainer.plot_confusion_matrix(normalized=False,
                                                   binary=False)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_confusion_matrix(normalized=False,
                                                   binary=True)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_confusion_matrix(normalized=True,
                                                   binary=False)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_confusion_matrix(normalized=True,
                                                   binary=True)
        self.assertIsInstance(fig, go.Figure)

    def test_plot_lift_curve(self):
        # The class used to define this test twice. The later, shorter copy
        # shadowed this one, so the add_wizard/round variant never ran. Only
        # the complete version is kept.
        fig = self.explainer.plot_lift_curve()
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_lift_curve(percentage=True)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_lift_curve(cutoff=0.5)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_lift_curve(add_wizard=False, round=3)
        self.assertIsInstance(fig, go.Figure)

    def test_plot_classification(self):
        fig = self.explainer.plot_classification()
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_classification(percentage=True)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_classification(cutoff=0)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_classification(cutoff=1)
        self.assertIsInstance(fig, go.Figure)

    def test_plot_roc_auc(self):
        fig = self.explainer.plot_roc_auc(0.5)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_roc_auc(0.0)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_roc_auc(1.0)
        self.assertIsInstance(fig, go.Figure)

    def test_plot_pr_auc(self):
        fig = self.explainer.plot_pr_auc(0.5)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_pr_auc(0.0)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_pr_auc(1.0)
        self.assertIsInstance(fig, go.Figure)

    def test_plot_prediction_result(self):
        fig = self.explainer.plot_prediction_result(0)
        self.assertIsInstance(fig, go.Figure)
class ClassifierBaseExplainerTests(unittest.TestCase):
    """Smoke tests for the core ClassifierExplainer API on a random forest.

    Each test calls one explainer method (usually with a few keyword
    variations) and checks the type, shape or a simple expected value of
    the result.
    """

    def setUp(self):
        # A fresh, deliberately small model and explainer are built per test.
        X_train, y_train, X_test, y_test = titanic_survive()
        _, test_names = titanic_names()  # only the second value is used (idxs)

        model = RandomForestClassifier(n_estimators=5, max_depth=2)
        model.fit(X_train, y_train)

        self.explainer = ClassifierExplainer(
            model,
            X_test,
            y_test,
            # 'Gender' groups the three listed Sex_* columns under one name.
            cats=[{
                'Gender': ['Sex_female', 'Sex_male', 'Sex_nan']
            }, 'Deck', 'Embarked'],
            target='Survival',
            labels=['Not survived', 'Survived'],
            idxs=test_names)

    def test_explainer_len(self):
        # len(explainer) must match the number of rows in X_test.
        self.assertEqual(len(self.explainer), len(titanic_survive()[2]))

    def test_int_idx(self):
        # The first test name must map to positional index 0.
        self.assertEqual(self.explainer.get_idx(titanic_names()[1][0]), 0)

    def test_random_index(self):
        self.assertIsInstance(self.explainer.random_index(), int)
        self.assertIsInstance(self.explainer.random_index(return_str=True),
                              str)

    def test_preds(self):
        self.assertIsInstance(self.explainer.preds, np.ndarray)

    def test_row_from_input(self):
        # Input values taken from the raw feature frame.
        input_row = self.explainer.get_row_from_input(
            self.explainer.X.iloc[[0]].values.tolist())
        self.assertIsInstance(input_row, pd.DataFrame)

        # Input values taken from the merged-categoricals frame.
        input_row = self.explainer.get_row_from_input(
            self.explainer.X_merged.iloc[[0]].values.tolist())
        self.assertIsInstance(input_row, pd.DataFrame)

        # Input values ordered by the explainer's shap ranking of the columns.
        ranked_cols = self.explainer.columns_ranked_by_shap()
        input_row = self.explainer.get_row_from_input(
            self.explainer.X_merged[ranked_cols].iloc[[0]].values.tolist(),
            ranked_by_shap=True)
        self.assertIsInstance(input_row, pd.DataFrame)

    def test_pred_percentiles(self):
        self.assertIsInstance(self.explainer.pred_percentiles(), np.ndarray)

    def test_columns_ranked_by_shap(self):
        self.assertIsInstance(self.explainer.columns_ranked_by_shap(), list)

    def test_get_col(self):
        # Grouped / categorical features come back as categorical Series,
        # a plain numeric feature as a numeric Series.
        self.assertIsInstance(self.explainer.get_col("Gender"), pd.Series)
        self.assertTrue(is_categorical_dtype(self.explainer.get_col("Gender")))

        self.assertIsInstance(self.explainer.get_col("Deck"), pd.Series)
        self.assertTrue(is_categorical_dtype(self.explainer.get_col("Deck")))

        self.assertIsInstance(self.explainer.get_col("Age"), pd.Series)
        self.assertTrue(is_numeric_dtype(self.explainer.get_col("Age")))

    def test_permutation_importances(self):
        self.assertIsInstance(self.explainer.permutation_importances(),
                              pd.DataFrame)

    def test_X_cats(self):
        self.assertIsInstance(self.explainer.X_cats, pd.DataFrame)

    def test_metrics(self):
        self.assertIsInstance(self.explainer.metrics(), dict)

    def test_mean_abs_shap_df(self):
        # NOTE: this test used to be defined twice with identical bodies; the
        # later definition silently replaced this one, so only one is kept.
        self.assertIsInstance(self.explainer.mean_abs_shap_df(), pd.DataFrame)

    def test_top_interactions(self):
        self.assertIsInstance(self.explainer.top_shap_interactions("Age"),
                              list)
        self.assertIsInstance(
            self.explainer.top_shap_interactions("Age", topx=4), list)

    def test_permutation_importances_df(self):
        self.assertIsInstance(self.explainer.get_permutation_importances_df(),
                              pd.DataFrame)
        self.assertIsInstance(
            self.explainer.get_permutation_importances_df(topx=3),
            pd.DataFrame)
        self.assertIsInstance(
            self.explainer.get_permutation_importances_df(cutoff=0.01),
            pd.DataFrame)

    def test_contrib_df(self):
        self.assertIsInstance(self.explainer.get_contrib_df(0), pd.DataFrame)
        self.assertIsInstance(self.explainer.get_contrib_df(0, topx=3),
                              pd.DataFrame)
        self.assertIsInstance(
            self.explainer.get_contrib_df(0, sort='high-to-low'), pd.DataFrame)
        self.assertIsInstance(
            self.explainer.get_contrib_df(0, sort='low-to-high'), pd.DataFrame)
        self.assertIsInstance(
            self.explainer.get_contrib_df(0, sort='importance'), pd.DataFrame)
        # Same query, but for an explicit single-row frame instead of an index.
        self.assertIsInstance(
            self.explainer.get_contrib_df(X_row=self.explainer.X.iloc[[0]]),
            pd.DataFrame)

    def test_contrib_summary_df(self):
        self.assertIsInstance(self.explainer.get_contrib_summary_df(0),
                              pd.DataFrame)
        self.assertIsInstance(self.explainer.get_contrib_summary_df(0, topx=3),
                              pd.DataFrame)
        self.assertIsInstance(
            self.explainer.get_contrib_summary_df(0, round=3), pd.DataFrame)
        self.assertIsInstance(
            self.explainer.get_contrib_summary_df(0, sort='low-to-high'),
            pd.DataFrame)
        self.assertIsInstance(
            self.explainer.get_contrib_summary_df(0, sort='high-to-low'),
            pd.DataFrame)
        self.assertIsInstance(
            self.explainer.get_contrib_summary_df(0, sort='importance'),
            pd.DataFrame)
        # Same query, but for an explicit single-row frame instead of an index.
        self.assertIsInstance(
            self.explainer.get_contrib_summary_df(
                X_row=self.explainer.X.iloc[[0]]), pd.DataFrame)

    def test_shap_base_value(self):
        self.assertIsInstance(self.explainer.shap_base_value(),
                              (np.floating, float))

    def test_shap_values_shape(self):
        # One row per explained sample, one column per merged feature.
        self.assertTrue(self.explainer.get_shap_values_df().shape == (
            len(self.explainer), len(self.explainer.merged_cols)))

    def test_shap_values(self):
        self.assertIsInstance(self.explainer.get_shap_values_df(),
                              pd.DataFrame)

    def test_shap_interaction_values(self):
        self.assertIsInstance(self.explainer.shap_interaction_values(),
                              np.ndarray)

    def test_calculate_properties(self):
        # Passes as long as the call does not raise.
        self.explainer.calculate_properties()

    def test_shap_interaction_values_by_col(self):
        self.assertIsInstance(
            self.explainer.shap_interaction_values_for_col("Age"), np.ndarray)
        # The per-column interaction array must have the shape of the shap frame.
        self.assertEqual(
            self.explainer.shap_interaction_values_for_col("Age").shape,
            self.explainer.get_shap_values_df().shape)

    def test_prediction_result_df(self):
        df = self.explainer.prediction_result_df(0)
        self.assertIsInstance(df, pd.DataFrame)

    def test_pdp_df(self):
        self.assertIsInstance(self.explainer.pdp_df("Age"), pd.DataFrame)
        self.assertIsInstance(self.explainer.pdp_df("Gender"), pd.DataFrame)
        self.assertIsInstance(self.explainer.pdp_df("Deck"), pd.DataFrame)
        self.assertIsInstance(self.explainer.pdp_df("Age", index=0),
                              pd.DataFrame)
        self.assertIsInstance(self.explainer.pdp_df("Gender", index=0),
                              pd.DataFrame)
        self.assertIsInstance(
            self.explainer.pdp_df("Age", X_row=self.explainer.X.iloc[[0]]),
            pd.DataFrame)

    def test_memory_usage(self):
        self.assertIsInstance(self.explainer.memory_usage(), pd.DataFrame)
        self.assertIsInstance(self.explainer.memory_usage(cutoff=1000),
                              pd.DataFrame)

    def test_plot_importances(self):
        fig = self.explainer.plot_importances()
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_importances(kind='permutation')
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_importances(topx=3)
        self.assertIsInstance(fig, go.Figure)

    def test_plot_interactions(self):
        fig = self.explainer.plot_interactions_importance("Age")
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_interactions_importance("Gender")
        self.assertIsInstance(fig, go.Figure)

    def test_plot_contributions(self):
        fig = self.explainer.plot_contributions(0)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_contributions(0, topx=3)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_contributions(0, cutoff=0.05)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_contributions(0, sort='high-to-low')
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_contributions(0, sort='low-to-high')
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_contributions(0, sort='importance')
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_contributions(
            X_row=self.explainer.X.iloc[[0]], sort='importance')
        self.assertIsInstance(fig, go.Figure)

    def test_plot_shap_detailed(self):
        fig = self.explainer.plot_importances_detailed()
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_importances_detailed(topx=3)
        self.assertIsInstance(fig, go.Figure)

    def test_plot_interactions_detailed(self):
        fig = self.explainer.plot_interactions_detailed("Age")
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_interactions_detailed("Age", topx=3)
        self.assertIsInstance(fig, go.Figure)

        # Repeat of the first call, kept from the original test sequence.
        fig = self.explainer.plot_interactions_detailed("Age")
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_interactions_detailed("Gender")
        self.assertIsInstance(fig, go.Figure)

    def test_plot_dependence(self):
        fig = self.explainer.plot_dependence("Age")
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_dependence("Age", "Gender")
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_dependence("Age", highlight_index=0)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_dependence("Gender", highlight_index=0)
        self.assertIsInstance(fig, go.Figure)

    def test_plot_interaction(self):
        fig = self.explainer.plot_interaction("Gender", "Age")
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_interaction("Age",
                                              "Gender",
                                              highlight_index=0)
        self.assertIsInstance(fig, go.Figure)

    def test_plot_pdp(self):
        fig = self.explainer.plot_pdp("Age")
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_pdp("Gender")
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_pdp("Gender", index=0)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_pdp("Age", index=0)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_pdp("Age", X_row=self.explainer.X.iloc[[0]])
        self.assertIsInstance(fig, go.Figure)

    def test_yaml(self):
        yaml = self.explainer.to_yaml()
        self.assertIsInstance(yaml, str)


# Example no. 16 (separator left over from the source listing; it was followed by a bare "0")


class ClassifierBunchTests(unittest.TestCase):
    """Smoke tests for the classifier-specific parts of ClassifierExplainer.

    Covers positive-label handling, metrics/precision/lift tables and the
    classifier plots, using a small random forest on the titanic survival
    data. Most tests only check the type of the returned object.
    """

    def setUp(self):
        # A fresh, deliberately small model and explainer are built per test.
        X_train, y_train, X_test, y_test = titanic_survive()
        _, test_names = titanic_names()  # only the second value is used (idxs)

        model = RandomForestClassifier(n_estimators=5, max_depth=2)
        model.fit(X_train, y_train)

        self.explainer = ClassifierExplainer(
            model,
            X_test,
            y_test,
            # 'Gender' groups the three listed Sex_* columns under one name.
            cats=[{
                'Gender': ['Sex_female', 'Sex_male', 'Sex_nan']
            }, 'Deck', 'Embarked'],
            idxs=test_names,
            labels=['Not survived', 'Survived'])

    def test_pos_label(self):
        # Setting pos_label by label string must be reflected as both the
        # integer position and the string form.
        self.explainer.pos_label = 1
        self.explainer.pos_label = "Not survived"
        self.assertIsInstance(self.explainer.pos_label, int)
        self.assertIsInstance(self.explainer.pos_label_str, str)
        # assertEqual replaces the deprecated alias that Python 3.12 removed.
        self.assertEqual(self.explainer.pos_label, 0)
        self.assertEqual(self.explainer.pos_label_str, "Not survived")

    def test_get_prop_for_label(self):
        # Capture the property while label 1 is active, switch to label 0,
        # then ask explicitly for label 1 again: both must agree.
        self.explainer.pos_label = 1
        tmp = self.explainer.pred_percentiles
        self.explainer.pos_label = 0
        # np.all replaces the alias that NumPy 2.0 removed.
        self.assertTrue(
            np.all(
                self.explainer.get_prop_for_label("pred_percentiles", 1) ==
                tmp))

    def test_pred_probas(self):
        self.assertIsInstance(self.explainer.pred_probas, np.ndarray)

    def test_metrics(self):
        self.assertIsInstance(self.explainer.metrics(), dict)
        self.assertIsInstance(self.explainer.metrics(cutoff=0.9), dict)

    def test_precision_df(self):
        self.assertIsInstance(self.explainer.precision_df(), pd.DataFrame)
        self.assertIsInstance(self.explainer.precision_df(multiclass=True),
                              pd.DataFrame)
        self.assertIsInstance(self.explainer.precision_df(quantiles=4),
                              pd.DataFrame)

    def test_lift_curve_df(self):
        self.assertIsInstance(self.explainer.lift_curve_df(), pd.DataFrame)

    def test_prediction_result_markdown(self):
        self.assertIsInstance(self.explainer.prediction_result_markdown(0),
                              str)

    def test_calculate_properties(self):
        # Passes as long as the call does not raise.
        self.explainer.calculate_properties()

    def test_plot_precision(self):
        fig = self.explainer.plot_precision()
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_precision(multiclass=True)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_precision(quantiles=10, cutoff=0.5)
        self.assertIsInstance(fig, go.Figure)

    def test_plot_cumulutive_precision(self):
        # Method name (including its spelling) is kept so existing test
        # selections by name keep working.
        fig = self.explainer.plot_cumulative_precision()
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_cumulative_precision(percentile=0.5)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_cumulative_precision(percentile=0.1,
                                                       pos_label=0)
        self.assertIsInstance(fig, go.Figure)

    def test_plot_confusion_matrix(self):
        # All four combinations of the normalized / binary flags.
        fig = self.explainer.plot_confusion_matrix(normalized=False,
                                                   binary=False)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_confusion_matrix(normalized=False,
                                                   binary=True)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_confusion_matrix(normalized=True,
                                                   binary=False)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_confusion_matrix(normalized=True,
                                                   binary=True)
        self.assertIsInstance(fig, go.Figure)

    def test_plot_lift_curve(self):
        # NOTE: this test used to be defined twice with identical bodies; the
        # later definition silently replaced this one, so only one is kept.
        fig = self.explainer.plot_lift_curve()
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_lift_curve(percentage=True)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_lift_curve(cutoff=0.5)
        self.assertIsInstance(fig, go.Figure)

    def test_plot_classification(self):
        fig = self.explainer.plot_classification()
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_classification(percentage=True)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_classification(cutoff=0)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_classification(cutoff=1)
        self.assertIsInstance(fig, go.Figure)

    def test_plot_roc_auc(self):
        fig = self.explainer.plot_roc_auc(0.5)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_roc_auc(0.0)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_roc_auc(1.0)
        self.assertIsInstance(fig, go.Figure)

    def test_plot_pr_auc(self):
        fig = self.explainer.plot_pr_auc(0.5)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_pr_auc(0.0)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_pr_auc(1.0)
        self.assertIsInstance(fig, go.Figure)
class CatBoostClassifierTests(unittest.TestCase):
    def setUp(self):
        # Fit a CatBoost model on the titanic survival split and wrap it in the
        # ClassifierExplainer that the tests read through self.explainer.
        X_train, y_train, X_test, y_test = titanic_survive()
        _, test_names = titanic_names()

        model = CatBoostClassifier(iterations=100, learning_rate=0.1, verbose=0)
        model.fit(X_train, y_train)

        explainer_options = dict(
            shap='tree',
            cats=['Sex', 'Cabin', 'Embarked'],
            labels=['Not survived', 'Survived'],
            idxs=test_names,
        )
        self.explainer = ClassifierExplainer(
            model, X_test, y_test, roc_auc_score, **explainer_options)

    def test_preds(self):
        # The explainer must expose its predictions as a numpy array.
        predictions = self.explainer.preds
        self.assertIsInstance(predictions, np.ndarray)

    def test_pred_probas(self):
        # The explainer must expose pred_probas as a numpy array.
        probabilities = self.explainer.pred_probas
        self.assertIsInstance(probabilities, np.ndarray)

    def test_permutation_importances(self):
        # Both the plain and the "_cats" variant must be DataFrames.
        for attr in ('permutation_importances', 'permutation_importances_cats'):
            self.assertIsInstance(getattr(self.explainer, attr), pd.DataFrame)

    def test_shap_base_value(self):
        # Accept either a numpy floating scalar or a builtin float.
        base_value = self.explainer.shap_base_value
        accepted_types = (np.floating, float)
        self.assertIsInstance(base_value, accepted_types)

    def test_shap_values_all_probabilities(self):
        # The base value itself must lie in [0, 1] ...
        base_value = self.explainer.shap_base_value
        self.assertTrue(base_value >= 0)
        self.assertTrue(base_value <= 1)

        # ... and so must base value + summed shap values for every row.
        row_totals = self.explainer.shap_values.sum(axis=1) + base_value
        self.assertTrue(np.all(row_totals >= 0))
        self.assertTrue(np.all(row_totals <= 1))

    def test_shap_values_shape(self):
        # One row per explained sample, one column per explainer column.
        actual_shape = self.explainer.shap_values.shape
        expected_shape = (len(self.explainer), len(self.explainer.columns))
        self.assertTrue(actual_shape == expected_shape)

    def test_shap_values(self):
        # Both the plain and the "_cats" variant must be numpy arrays.
        for attr in ('shap_values', 'shap_values_cats'):
            self.assertIsInstance(getattr(self.explainer, attr), np.ndarray)

    # @unittest.expectedFailure
    # def test_shap_interaction_values(self):
    #     self.assertIsInstance(self.explainer.shap_interaction_values, np.ndarray)
    #     self.assertIsInstance(self.explainer.shap_interaction_values_cats, np.ndarray)

    def test_mean_abs_shap(self):
        # Both the plain and the "_cats" variant must be DataFrames.
        for attr in ('mean_abs_shap', 'mean_abs_shap_cats'):
            self.assertIsInstance(getattr(self.explainer, attr), pd.DataFrame)

    def test_calculate_properties(self):
        # Passes as long as the call does not raise; interactions are
        # explicitly switched off for this explainer.
        options = {'include_interactions': False}
        self.explainer.calculate_properties(**options)

    def test_pdp_result(self):
        # First without an index, then for index 0 -- each time for "Age" and
        # "Sex"; every result must be a pdpbox PDPIsolate object.
        calls = (
            ("Age", {}),
            ("Sex", {}),
            ("Age", {'index': 0}),
            ("Sex", {'index': 0}),
        )
        for column, kwargs in calls:
            result = self.explainer.get_pdp_result(column, **kwargs)
            self.assertIsInstance(result, pdpbox.pdp.PDPIsolate)

    def test_metrics(self):
        # metrics() must return a dict with and without an explicit cutoff.
        for kwargs in ({}, {'cutoff': 0.9}):
            self.assertIsInstance(self.explainer.metrics(**kwargs), dict)

    def test_precision_df(self):
        # Default call, multiclass mode and a custom quantile count must all
        # return a DataFrame.
        for kwargs in ({}, {'multiclass': True}, {'quantiles': 4}):
            self.assertIsInstance(
                self.explainer.precision_df(**kwargs), pd.DataFrame)

    def test_lift_curve_df(self):
        # lift_curve_df() must return a DataFrame.
        lift_table = self.explainer.lift_curve_df()
        self.assertIsInstance(lift_table, pd.DataFrame)

    def test_prediction_result_markdown(self):
        self.assertIsInstance(self.explainer.prediction_result_markdown(0), str)