def setUp(self):
    """Create the model, explainer and dashboard fixture and write it to disk.

    A small random forest is fitted on the Titanic survival data and wrapped
    in a ``ClassifierExplainer`` and an ``ExplainerDashboard``. The explainer
    (joblib and yaml) and the dashboard (yaml) are then written to
    ``<cwd>/tests/cli_assets``.
    """
    train_X, train_y, test_X, test_y = titanic_survive()
    train_names, test_names = titanic_names()

    self.model = RandomForestClassifier(n_estimators=5, max_depth=2)
    self.model.fit(train_X, train_y)

    # The three Sex_* columns are grouped under one "Gender" feature.
    gender_group = {'Gender': ['Sex_female', 'Sex_male', 'Sex_nan']}
    self.explainer = ClassifierExplainer(
        self.model,
        test_X,
        test_y,
        cats=[gender_group, 'Deck', 'Embarked'],
        labels=['Not survived', 'Survived'],
    )

    # Tabs are given as an instance, a class and a string name.
    tabs = [
        ShapDependenceTab(self.explainer, title="Test Tab!"),
        ShapDependenceTab,
        "importances",
    ]
    self.dashboard = ExplainerDashboard(self.explainer, tabs,
                                        title="Test Title!")

    self.pkl_dir = Path.cwd() / "tests" / "cli_assets"
    joblib_path = self.pkl_dir / "explainer.joblib"
    self.explainer.dump(joblib_path)
    self.explainer.to_yaml(self.pkl_dir / "explainer.yaml")
    self.dashboard.to_yaml(self.pkl_dir / "dashboard.yaml",
                           explainerfile=str(joblib_path))
def predict(model_path: Path, test_df_path: Path, predict_path: Path) -> None:
    """Score the test dataset with a pickled model and persist the results.

    The CSV at ``test_df_path`` is split into the ``target`` column and the
    remaining feature columns; only feature columns of dtype ``int64`` are
    passed to the model. Two files are written under ``predict_path``:
    ``explainer.joblib`` (a dumped ``ClassifierExplainer``) and ``preds.pkl``
    (the pickled output of ``predict_proba``).

    Args:
        model_path (Path): Location of the pickled, trained model.
        test_df_path (Path): Location of the test dataframe (CSV).
        predict_path (Path): Directory in which the outputs are stored.
    """
    # Load the test data and separate labels from features.
    frame = pd.read_csv(test_df_path)
    y_test = frame.loc[:, "target"]
    features = frame.drop("target", axis=1)

    # Keep only the int64 feature columns.
    int_columns = (features.dtypes == "int64").tolist()
    x_test = features.loc[:, int_columns]

    # Load the trained model and compute class probabilities.
    with open(model_path, "rb") as model_file:
        model = pickle.load(model_file)
    preds = model.predict_proba(x_test)

    # Build and dump an explainer for the same model and data.
    explainer = ClassifierExplainer(model, x_test, y_test, descriptions=None)
    explainer.dump(predict_path / "explainer.joblib")

    # Persist the predictions.
    with open(predict_path / "preds.pkl", "wb") as preds_file:
        pickle.dump(preds, preds_file, pickle.HIGHEST_PROTOCOL)
def generate_assets():
    """Write the explainer and dashboard config files used by the tests.

    Fits a small random forest on the Titanic survival data, wraps it in a
    ``ClassifierExplainer`` and an ``ExplainerDashboard``, and writes
    ``explainer.yaml`` and ``dashboard.yaml`` to ``<cwd>/tests/test_assets``.
    The dashboard yaml is written with ``dump_explainer=True`` and points at
    ``explainer.joblib`` in the same directory.

    Returns:
        None
    """
    train_X, train_y, test_X, test_y = titanic_survive()

    forest = RandomForestClassifier(n_estimators=5, max_depth=2)
    forest.fit(train_X, train_y)

    # The three Sex_* columns are grouped under one "Gender" feature.
    gender_group = {'Gender': ['Sex_female', 'Sex_male', 'Sex_nan']}
    explainer = ClassifierExplainer(
        forest,
        test_X,
        test_y,
        cats=[gender_group, 'Deck', 'Embarked'],
        labels=['Not survived', 'Survived'],
    )

    # Tabs are given as an instance, a class and a string name.
    tabs = [
        ShapDependenceComposite(explainer, title="Test Tab!"),
        ShapDependenceComposite,
        "importances",
    ]
    dashboard = ExplainerDashboard(explainer, tabs, title="Test Title!")

    asset_dir = Path.cwd() / "tests" / "test_assets"
    explainer.to_yaml(asset_dir / "explainer.yaml")
    dashboard.to_yaml(
        asset_dir / "dashboard.yaml",
        explainerfile=str(asset_dir / "explainer.joblib"),
        dump_explainer=True,
    )
class DashboardTests(unittest.TestCase):
    def setUp(self):
        X_train, y_train, X_test, y_test = titanic_survive()
        train_names, test_names = titanic_names()

        self.model = RandomForestClassifier(n_estimators=5, max_depth=2)
        self.model.fit(X_train, y_train)

        self.explainer = ClassifierExplainer(
                            self.model, X_test, y_test, 
                            cats=[{'Gender': ['Sex_female', 'Sex_male', 'Sex_nan']}, 
                                                'Deck', 'Embarked'],
                            labels=['Not survived', 'Survived'])

        self.dashboard = ExplainerDashboard(self.explainer, 
            [
                ShapDependenceTab(self.explainer, title="Test Tab!"),
                ShapDependenceTab, 
                "importances"
            ], title="Test Title!")

        self.pkl_dir = Path.cwd() / "tests" / "cli_assets" 
        self.explainer.dump(self.pkl_dir / "explainer.joblib")
        self.explainer.to_yaml(self.pkl_dir / "explainer.yaml")
        self.dashboard.to_yaml(self.pkl_dir / "dashboard.yaml", 
                    explainerfile=str(self.pkl_dir / "explainer.joblib"))

    def test_yaml(self):
        yaml = self.dashboard.to_yaml()
        self.assertIsInstance(yaml, str)

    def test_yaml_dict(self):
        yaml_dict = self.dashboard.to_yaml(return_dict=True)
        self.assertIsInstance(yaml_dict, dict)
        self.assertIn("dashboard", yaml_dict)

    def test_load_config_joblib(self):
        db = ExplainerDashboard.from_config(
            self.pkl_dir / "explainer.joblib",
            self.pkl_dir / "dashboard.yaml")
        self.assertIsInstance(db, ExplainerDashboard)

    def test_load_config_yaml(self):
        db = ExplainerDashboard.from_config(
            self.pkl_dir / "dashboard.yaml")
        self.assertIsInstance(db, ExplainerDashboard)

    def test_load_config_explainer(self):
        db = ExplainerDashboard.from_config(
            self.explainer, self.pkl_dir / "dashboard.yaml")
        self.assertIsInstance(db, ExplainerDashboard)
        
    def setUp(self):
        X_train, y_train, X_test, y_test = titanic_survive()
        train_names, test_names = titanic_names()

        model = CatBoostClassifier(iterations=100, verbose=0).fit(X_train, y_train)
        explainer = ClassifierExplainer(
                            model, X_test, y_test, 
                            cats=['Deck', 'Embarked'],
                            labels=['Not survived', 'Survived'])

        X_cats, y_cats = explainer.X_merged, explainer.y.astype("int")
        model = CatBoostClassifier(iterations=5, verbose=0).fit(X_cats, y_cats, cat_features=[8, 9])
        self.explainer = ClassifierExplainer(model, X_cats, y_cats, 
                                cats=['Sex'], 
                                labels=['Not survived', 'Survived'],
                                idxs=X_test.index)
Beispiel #6
0
def build_explainer(explainer_config):
    """Build an explainer from an explainer config.

    Args:
        explainer_config: Either a path (``Path`` or ``str``) ending in
            ``.yaml`` or an already-loaded config ``dict``. The config must
            have a top-level ``explainer`` key whose value provides
            ``modelfile``, ``datafile``, ``data_target``, ``data_index``,
            ``explainer_type`` and ``params``.

    Returns:
        A ``ClassifierExplainer`` when ``explainer_type`` is ``"classifier"``
        or a ``RegressionExplainer`` when it is ``"regression"``.

    Raises:
        ValueError: If ``explainer_config`` is neither a ``.yaml`` path nor a
            dict, if ``datafile`` is not a ``.csv`` or ``.parquet`` file, or
            if ``explainer_type`` is not recognised.
        AssertionError: If the config has no ``explainer`` section or the
            ``data_index`` column is missing from the data.
    """
    if isinstance(explainer_config,
                  (Path, str)) and str(explainer_config).endswith(".yaml"):
        with open(str(explainer_config), "r") as config_file:
            config = yaml.safe_load(config_file)
    elif isinstance(explainer_config, dict):
        config = explainer_config
    else:
        raise ValueError(
            "explainer_config should either be a path to a .yaml file "
            "or a dict!")
    assert isinstance(config, dict) and 'explainer' in config, \
        "Please pass a proper explainer.yaml config file that starts with `explainer:`!"
    # Index the loaded config, not the raw argument (which may be a path).
    config = config['explainer']

    print(f"explainerdashboard ===> Loading model from {config['modelfile']}")
    # NOTE(security): pickle can execute arbitrary code while loading; only
    # point modelfile at files from a trusted source.
    with open(config['modelfile'], "rb") as model_file:
        model = pickle.load(model_file)

    print(f"explainerdashboard ===> Loading data from {config['datafile']}")
    datafile = str(config['datafile'])
    if datafile.endswith('.csv'):
        df = pd.read_csv(config['datafile'])
    elif datafile.endswith('.parquet'):
        df = pd.read_parquet(config['datafile'])
    else:
        raise ValueError("datafile should either be a .csv or .parquet!")

    print(
        f"explainerdashboard ===> Using column {config['data_target']} to generate X, y "
    )
    target_col = config['data_target']
    X = df.drop(target_col, axis=1)
    y = df[target_col]

    if config['data_index'] is not None:
        print(
            f"explainerdashboard ===> Generating index from column {config['data_index']}"
        )
        assert config['data_index'] in X.columns, \
            (f"Cannot find data_index column ({config['data_index']})"
             f" in datafile ({config['datafile']})!"
              "Please set it to the proper index column name, or set it to null")
        X = X.set_index(config['data_index'])

    params = config['params']

    if config['explainer_type'] == "classifier":
        print("explainerdashboard ===> Generating ClassifierExplainer...")
        explainer = ClassifierExplainer(model, X, y, **params)
    elif config['explainer_type'] == "regression":
        # Imported here so the block does not rely on a file-level import.
        from explainerdashboard import RegressionExplainer
        print("explainerdashboard ===> Generating RegressionExplainer...")
        explainer = RegressionExplainer(model, X, y, **params)
    else:
        raise ValueError(
            "explainer_type should either be 'classifier' or 'regression', "
            f"got {config['explainer_type']!r}!")
    return explainer
Beispiel #7
0
# Must be run on http://127.0.0.1:8050/ to work; still need to figure out whether it can work with Streamlit.

# Import libraries here
from sklearn.ensemble import RandomForestClassifier

from explainerdashboard import ClassifierExplainer, ExplainerDashboard
from explainerdashboard.datasets import titanic_survive, feature_descriptions

# Load the Titanic survival train/test split.
X_train, y_train, X_test, y_test = titanic_survive()

# Fit a random forest on the training split.
model = RandomForestClassifier(n_estimators=50, max_depth=10)
model.fit(X_train, y_train)

# Wrap the fitted model and the test split in an explainer.
explainer = ClassifierExplainer(
    model,
    X_test,
    y_test,
    cats=['Sex', 'Deck', 'Embarked'],
    descriptions=feature_descriptions,
    labels=['Not survived', 'Survived'],
)

# Build the dashboard for the explainer and run it.
ExplainerDashboard(explainer).run()
Beispiel #8
0
class CatBoostClassifierTests(unittest.TestCase):
    def setUp(self):
        X_train, y_train, X_test, y_test = titanic_survive()
        train_names, test_names = titanic_names()

        model = CatBoostClassifier(iterations=100,
                                   verbose=0).fit(X_train, y_train)
        explainer = ClassifierExplainer(model,
                                        X_test,
                                        y_test,
                                        cats=['Deck', 'Embarked'],
                                        labels=['Not survived', 'Survived'])

        X_cats, y_cats = explainer.X_cats, explainer.y
        model = CatBoostClassifier(iterations=5,
                                   verbose=0).fit(X_cats,
                                                  y_cats,
                                                  cat_features=[8, 9])
        self.explainer = ClassifierExplainer(
            model,
            X_cats,
            y_cats,
            cats=['Sex'],
            labels=['Not survived', 'Survived'])

    def test_explainer_len(self):
        self.assertEqual(len(self.explainer), len(titanic_survive()[2]))

    def test_int_idx(self):
        self.assertEqual(self.explainer.get_int_idx(titanic_names()[1][0]), 0)

    def test_random_index(self):
        self.assertIsInstance(self.explainer.random_index(), int)
        self.assertIsInstance(self.explainer.random_index(return_str=True),
                              str)

    def test_ordered_cats(self):
        self.assertEqual(self.explainer.ordered_cats("Sex"),
                         ['Sex_female', 'Sex_male'])
        self.assertEqual(
            self.explainer.ordered_cats("Deck", topx=2, sort='alphabet'),
            ['Deck_A', 'Deck_B'])

        self.assertIsInstance(self.explainer.ordered_cats("Deck", sort='freq'),
                              list)
        self.assertIsInstance(
            self.explainer.ordered_cats("Deck", topx=3, sort='freq'), list)
        self.assertIsInstance(self.explainer.ordered_cats("Deck", sort='shap'),
                              list)
        self.assertIsInstance(
            self.explainer.ordered_cats("Deck", topx=3, sort='shap'), list)

    def test_preds(self):
        self.assertIsInstance(self.explainer.preds, np.ndarray)

    def test_row_from_input(self):
        input_row = self.explainer.get_row_from_input(
            self.explainer.X.iloc[[0]].values.tolist())
        self.assertIsInstance(input_row, pd.DataFrame)

        input_row = self.explainer.get_row_from_input(
            self.explainer.X_cats.iloc[[0]].values.tolist())
        self.assertIsInstance(input_row, pd.DataFrame)

        input_row = self.explainer.get_row_from_input(
            self.explainer.X_cats[self.explainer.columns_ranked_by_shap(
                cats=True)].iloc[[0]].values.tolist(),
            ranked_by_shap=True)
        self.assertIsInstance(input_row, pd.DataFrame)

        input_row = self.explainer.get_row_from_input(
            self.explainer.X[self.explainer.columns_ranked_by_shap(
                cats=False)].iloc[[0]].values.tolist(),
            ranked_by_shap=True)
        self.assertIsInstance(input_row, pd.DataFrame)

    def test_pred_percentiles(self):
        self.assertIsInstance(self.explainer.pred_percentiles, np.ndarray)

    def test_columns_ranked_by_shap(self):
        self.assertIsInstance(self.explainer.columns_ranked_by_shap(), list)
        self.assertIsInstance(self.explainer.columns_ranked_by_shap(cats=True),
                              list)

    def test_equivalent_col(self):
        self.assertEqual(self.explainer.equivalent_col("Sex_female"), "Sex")
        self.assertEqual(self.explainer.equivalent_col("Sex"), "Sex_female")
        self.assertIsNone(self.explainer.equivalent_col("random"))

    def test_get_col(self):
        self.assertIsInstance(self.explainer.get_col("Sex"), pd.Series)
        self.assertEqual(self.explainer.get_col("Sex").dtype, "object")

        self.assertIsInstance(self.explainer.get_col("Deck"), pd.Series)
        self.assertEqual(self.explainer.get_col("Deck").dtype, "object")

        self.assertIsInstance(self.explainer.get_col("Age"), pd.Series)
        self.assertEqual(self.explainer.get_col("Age").dtype, np.float)

    def test_permutation_importances(self):
        self.assertIsInstance(self.explainer.permutation_importances,
                              pd.DataFrame)
        self.assertIsInstance(self.explainer.permutation_importances_cats,
                              pd.DataFrame)

    def test_X_cats(self):
        self.assertIsInstance(self.explainer.X_cats, pd.DataFrame)

    def test_columns_cats(self):
        self.assertIsInstance(self.explainer.columns_cats, list)

    def test_metrics(self):
        self.assertIsInstance(self.explainer.metrics(), dict)

    def test_mean_abs_shap_df(self):
        self.assertIsInstance(self.explainer.mean_abs_shap_df(), pd.DataFrame)

    def test_top_interactions(self):
        self.assertIsInstance(self.explainer.shap_top_interactions("Age"),
                              list)
        self.assertIsInstance(
            self.explainer.shap_top_interactions("Age", topx=4), list)
        self.assertIsInstance(
            self.explainer.shap_top_interactions("Age", cats=True), list)
        self.assertIsInstance(
            self.explainer.shap_top_interactions("Sex", cats=True), list)

    def test_permutation_importances_df(self):
        self.assertIsInstance(self.explainer.permutation_importances_df(),
                              pd.DataFrame)
        self.assertIsInstance(
            self.explainer.permutation_importances_df(topx=3), pd.DataFrame)
        self.assertIsInstance(
            self.explainer.permutation_importances_df(cats=True), pd.DataFrame)
        self.assertIsInstance(
            self.explainer.permutation_importances_df(cutoff=0.01),
            pd.DataFrame)

    def test_contrib_df(self):
        self.assertIsInstance(self.explainer.contrib_df(0), pd.DataFrame)
        self.assertIsInstance(self.explainer.contrib_df(0, cats=False),
                              pd.DataFrame)
        self.assertIsInstance(self.explainer.contrib_df(0, topx=3),
                              pd.DataFrame)
        self.assertIsInstance(self.explainer.contrib_df(0, sort='high-to-low'),
                              pd.DataFrame)
        self.assertIsInstance(self.explainer.contrib_df(0, sort='low-to-high'),
                              pd.DataFrame)
        self.assertIsInstance(self.explainer.contrib_df(0, sort='importance'),
                              pd.DataFrame)
        self.assertIsInstance(
            self.explainer.contrib_df(X_row=self.explainer.X.iloc[[0]]),
            pd.DataFrame)
        self.assertIsInstance(
            self.explainer.contrib_df(X_row=self.explainer.X_cats.iloc[[0]]),
            pd.DataFrame)

    def test_contrib_summary_df(self):
        self.assertIsInstance(self.explainer.contrib_summary_df(0),
                              pd.DataFrame)
        self.assertIsInstance(self.explainer.contrib_summary_df(0, cats=False),
                              pd.DataFrame)
        self.assertIsInstance(self.explainer.contrib_summary_df(0, topx=3),
                              pd.DataFrame)
        self.assertIsInstance(self.explainer.contrib_summary_df(0, round=3),
                              pd.DataFrame)
        self.assertIsInstance(
            self.explainer.contrib_summary_df(0, sort='low-to-high'),
            pd.DataFrame)
        self.assertIsInstance(
            self.explainer.contrib_summary_df(0, sort='high-to-low'),
            pd.DataFrame)
        self.assertIsInstance(
            self.explainer.contrib_summary_df(0, sort='importance'),
            pd.DataFrame)
        self.assertIsInstance(
            self.explainer.contrib_summary_df(
                X_row=self.explainer.X.iloc[[0]]), pd.DataFrame)
        self.assertIsInstance(
            self.explainer.contrib_summary_df(
                X_row=self.explainer.X_cats.iloc[[0]]), pd.DataFrame)

    def test_shap_base_value(self):
        self.assertIsInstance(self.explainer.shap_base_value,
                              (np.floating, float))

    def test_shap_values_shape(self):
        self.assertTrue(
            self.explainer.shap_values.shape == (len(self.explainer),
                                                 len(self.explainer.columns)))

    def test_shap_values(self):
        self.assertIsInstance(self.explainer.shap_values, np.ndarray)
        self.assertIsInstance(self.explainer.shap_values_cats, np.ndarray)

    def test_mean_abs_shap(self):
        self.assertIsInstance(self.explainer.mean_abs_shap, pd.DataFrame)
        self.assertIsInstance(self.explainer.mean_abs_shap_cats, pd.DataFrame)

    def test_calculate_properties(self):
        self.explainer.calculate_properties()

    def test_prediction_result_df(self):
        df = self.explainer.prediction_result_df(0)
        self.assertIsInstance(df, pd.DataFrame)

    def test_pdp_df(self):
        self.assertIsInstance(self.explainer.pdp_df("Age"), pd.DataFrame)
        self.assertIsInstance(self.explainer.pdp_df("Sex"), pd.DataFrame)
        self.assertIsInstance(self.explainer.pdp_df("Deck"), pd.DataFrame)
        self.assertIsInstance(self.explainer.pdp_df("Age", index=0),
                              pd.DataFrame)
        self.assertIsInstance(self.explainer.pdp_df("Sex_male", index=0),
                              pd.DataFrame)
        self.assertIsInstance(
            self.explainer.pdp_df("Age", X_row=self.explainer.X.iloc[[0]]),
            pd.DataFrame)
        self.assertIsInstance(
            self.explainer.pdp_df("Age",
                                  X_row=self.explainer.X_cats.iloc[[0]]),
            pd.DataFrame)

    def test_plot_importances(self):
        fig = self.explainer.plot_importances()
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_importances(kind='permutation')
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_importances(topx=3)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_importances(cats=True)
        self.assertIsInstance(fig, go.Figure)

    def test_plot_shap_contributions(self):
        fig = self.explainer.plot_shap_contributions(0)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_shap_contributions(0, cats=False)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_shap_contributions(0, topx=3)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_shap_contributions(0, cutoff=0.05)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_shap_contributions(0, sort='high-to-low')
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_shap_contributions(0, sort='low-to-high')
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_shap_contributions(0, sort='importance')
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_shap_contributions(
            X_row=self.explainer.X.iloc[[0]], sort='importance')
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_shap_contributions(
            X_row=self.explainer.X_cats.iloc[[0]], sort='importance')
        self.assertIsInstance(fig, go.Figure)

    def test_plot_shap_summary(self):
        fig = self.explainer.plot_shap_summary()
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_shap_summary(topx=3)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_shap_summary(cats=True)
        self.assertIsInstance(fig, go.Figure)

    def test_plot_shap_dependence(self):
        fig = self.explainer.plot_shap_dependence("Age")
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_shap_dependence("Sex_female")
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_shap_dependence("Age", "Sex")
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_shap_dependence("Sex_female", "Age")
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_shap_dependence("Age", highlight_index=0)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_shap_dependence("Sex", highlight_index=0)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_shap_dependence("Deck", topx=3, sort="freq")
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_shap_dependence("Deck", topx=3, sort="shap")
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_shap_dependence("Deck", sort="freq")
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_shap_dependence("Deck", sort="shap")
        self.assertIsInstance(fig, go.Figure)

    def test_plot_pdp(self):
        fig = self.explainer.plot_pdp("Age")
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_pdp("Sex")
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_pdp("Sex", index=0)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_pdp("Age", index=0)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_pdp("Age", X_row=self.explainer.X.iloc[[0]])
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_pdp("Age",
                                      X_row=self.explainer.X_cats.iloc[[0]])
        self.assertIsInstance(fig, go.Figure)

    def test_yaml(self):
        yaml = self.explainer.to_yaml()
        self.assertIsInstance(yaml, str)

    def test_pos_label(self):
        self.explainer.pos_label = 1
        self.explainer.pos_label = "Not survived"
        self.assertIsInstance(self.explainer.pos_label, int)
        self.assertIsInstance(self.explainer.pos_label_str, str)
        self.assertEqual(self.explainer.pos_label, 0)
        self.assertEqual(self.explainer.pos_label_str, "Not survived")

    def test_get_prop_for_label(self):
        self.explainer.pos_label = 1
        tmp = self.explainer.pred_percentiles
        self.explainer.pos_label = 0
        self.assertTrue(
            np.alltrue(
                self.explainer.get_prop_for_label("pred_percentiles", 1) ==
                tmp))

    def test_pred_probas(self):
        self.assertIsInstance(self.explainer.pred_probas, np.ndarray)

    def test_metrics(self):
        self.assertIsInstance(self.explainer.metrics(), dict)
        self.assertIsInstance(self.explainer.metrics(cutoff=0.9), dict)
        self.assertIsInstance(self.explainer.metrics_descriptions(cutoff=0.9),
                              dict)

    def test_precision_df(self):
        self.assertIsInstance(self.explainer.precision_df(), pd.DataFrame)
        self.assertIsInstance(self.explainer.precision_df(multiclass=True),
                              pd.DataFrame)
        self.assertIsInstance(self.explainer.precision_df(quantiles=4),
                              pd.DataFrame)

    def test_lift_curve_df(self):
        self.assertIsInstance(self.explainer.lift_curve_df(), pd.DataFrame)

    def test_prediction_result_markdown(self):
        self.assertIsInstance(self.explainer.prediction_result_markdown(0),
                              str)

    def test_calculate_properties(self):
        self.explainer.calculate_properties()

    def test_plot_precision(self):
        fig = self.explainer.plot_precision()
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_precision(multiclass=True)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_precision(quantiles=10, cutoff=0.5)
        self.assertIsInstance(fig, go.Figure)

    def test_plot_cumulative_precision(self):
        fig = self.explainer.plot_cumulative_precision()
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_cumulative_precision(percentile=0.5)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_cumulative_precision(percentile=0.1,
                                                       pos_label=0)
        self.assertIsInstance(fig, go.Figure)

    def test_plot_confusion_matrix(self):
        fig = self.explainer.plot_confusion_matrix(normalized=False,
                                                   binary=False)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_confusion_matrix(normalized=False,
                                                   binary=True)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_confusion_matrix(normalized=True,
                                                   binary=False)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_confusion_matrix(normalized=True,
                                                   binary=True)
        self.assertIsInstance(fig, go.Figure)

    def test_plot_lift_curve(self):
        fig = self.explainer.plot_lift_curve()
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_lift_curve(percentage=True)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_lift_curve(cutoff=0.5)
        self.assertIsInstance(fig, go.Figure)

    def test_plot_lift_curve(self):
        fig = self.explainer.plot_lift_curve()
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_lift_curve(percentage=True)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_lift_curve(cutoff=0.5)
        self.assertIsInstance(fig, go.Figure)

    def test_plot_classification(self):
        fig = self.explainer.plot_classification()
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_classification(percentage=True)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_classification(cutoff=0)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_classification(cutoff=1)
        self.assertIsInstance(fig, go.Figure)

    def test_plot_roc_auc(self):
        fig = self.explainer.plot_roc_auc(0.5)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_roc_auc(0.0)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_roc_auc(1.0)
        self.assertIsInstance(fig, go.Figure)

    def test_plot_pr_auc(self):
        fig = self.explainer.plot_pr_auc(0.5)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_pr_auc(0.0)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_pr_auc(1.0)
        self.assertIsInstance(fig, go.Figure)

    def test_plot_prediction_result(self):
        fig = self.explainer.plot_prediction_result(0)
        self.assertIsInstance(fig, go.Figure)
Beispiel #9
0
from pathlib import Path
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor

from explainerdashboard import ClassifierExplainer, RegressionExplainer, ExplainerDashboard
from explainerdashboard.datasets import *

# Directory (relative to the current working directory) that receives the
# dumped explainer files.
pkl_dir = Path.cwd() / "pkls"

# Classifier: fit a random forest on the Titanic survival data, wrap it in a
# ClassifierExplainer and dump the explainer to pkl_dir.
X_train, y_train, X_test, y_test = titanic_survive()
model = RandomForestClassifier(n_estimators=50,
                               max_depth=5).fit(X_train, y_train)
clas_explainer = ClassifierExplainer(model,
                                     X_test,
                                     y_test,
                                     cats=['Sex', 'Deck', 'Embarked'],
                                     descriptions=feature_descriptions,
                                     labels=['Not survived', 'Survived'])
# The dashboard object itself is discarded.
# NOTE(review): presumably it is built so the explainer's properties are
# calculated before the dump below — confirm.
_ = ExplainerDashboard(clas_explainer)
clas_explainer.dump(pkl_dir / "clas_explainer.joblib")

# Regression: fit a random forest on the Titanic fare data and wrap it in a
# RegressionExplainer. X_train, y_train, X_test, y_test and model are rebound
# here, replacing the classifier values above.
X_train, y_train, X_test, y_test = titanic_fare()
model = RandomForestRegressor(n_estimators=50,
                              max_depth=5).fit(X_train, y_train)
reg_explainer = RegressionExplainer(model,
                                    X_test,
                                    y_test,
                                    cats=['Sex', 'Deck', 'Embarked'],
                                    descriptions=feature_descriptions,
                                    units="$")