def setUp(self):
        """Build the CatBoost regression explainer under test.

        A first CatBoost model and RegressionExplainer (with 'Deck' and
        'Embarked' as cats) are built only to obtain ``explainer.X_merged``
        and ``explainer.y``. A second CatBoost model is then fit on that
        frame with ``cat_features`` and wrapped in ``self.explainer``,
        passing ``X_test.index`` as ``idxs``.
        """
        X_train, y_train, X_test, y_test = titanic_fare()
        self.test_len = len(X_test)

        # titanic_names() returns (train_names, test_names); only the test
        # names are needed, so call it once instead of twice.
        _, self.names = titanic_names()

        model = CatBoostRegressor(iterations=5, verbose=0).fit(X_train, y_train)
        explainer = RegressionExplainer(model, X_test, y_test, cats=['Deck', 'Embarked'])
        X_cats, y_cats = explainer.X_merged, explainer.y
        # NOTE(review): assumes positions 8 and 9 of X_merged are the
        # categorical columns (presumably Deck and Embarked) - confirm.
        model = CatBoostRegressor(iterations=5, verbose=0).fit(X_cats, y_cats, cat_features=[8, 9])
        self.explainer = RegressionExplainer(model, X_cats, y_cats, cats=['Sex'], idxs=X_test.index)
Exemplo n.º 2
0
            def dashboard_exp(model, X_data, y_data):
                """Build and serve a trimmed-down regression dashboard.

                Wraps ``model`` with ``X_data``/``y_data`` in a
                RegressionExplainer and runs an ExplainerDashboard with the
                SANDSTONE bootstrap theme that renders only the importances,
                contributions and what-if tabs, with most per-component
                controls hidden.
                """
                import dash_bootstrap_components as dbc

                from explainerdashboard import RegressionExplainer, ExplainerDashboard

                # Top-level tabs: True renders the tab, False leaves it out.
                tab_switches = {
                    'importances': True,
                    'model_summary': False,
                    'contributions': True,
                    'whatif': True,
                    'shap_dependence': False,
                    'shap_interaction': False,
                    'decision_trees': False,
                }

                # Per-component toggles inside the visible tabs.
                hide_options = {
                    'hide_whatifindexselector': True,
                    'hide_whatifprediction': True,
                    'hide_inputeditor': False,
                    'hide_whatifcontributiongraph': False,
                    'hide_whatifcontributiontable': True,
                    'hide_whatifpdp': False,
                    'hide_predindexselector': True,
                    'hide_predictionsummary': True,
                    'hide_contributiongraph': False,
                    'hide_pdp': False,
                    'hide_contributiontable': True,
                    'hide_dropna': True,
                    'hide_range': True,
                    'hide_depth': True,
                    'hide_sort': True,
                    'hide_sample': True,  # sample size input on pdp component
                    'hide_gridlines': True,  # gridlines on pdp component
                    'hide_gridpoints': True,
                    'hide_cats_sort': True,  # sorting option for categorical features
                    'hide_cutoff': True,  # cutoff selector on classification components
                    'hide_percentage': True,  # percentage toggle on classification components
                    'hide_log_x': True,  # x-axis logs toggle on regression plots
                    'hide_log_y': True,  # y-axis logs toggle on regression plots
                    'hide_ratio': True,  # residuals type dropdown
                    'hide_points': True,  # show violin scatter markers toggle
                    'hide_winsor': True,  # winsorize input
                    'hide_wizard': True,  # wizard toggle in lift curve component
                    'hide_star_explanation': True,
                }

                regression_explainer = RegressionExplainer(model, X_data, y_data)
                dashboard = ExplainerDashboard(
                    regression_explainer,
                    bootstrap=dbc.themes.SANDSTONE,
                    **tab_switches,
                    **hide_options,
                )
                dashboard.run()
Exemplo n.º 3
0
if days_since_update > 7:
    # Stale by more than a week: rebuild the explainer from the saved model and
    # data, then write a fresh dashboard.yaml (+ dumped explainer.joblib).

    # NOTE(review): pickle.load executes arbitrary code embedded in the file;
    # only load model files produced by this project.
    with open(MODELS_DIR / 'general_model.pkl', 'rb') as model_file:
        model = pickle.load(model_file)
    y = pd.read_csv(DATA_DIR / 'general_target.csv',
                    index_col=['Ticker']).drop(columns=['Date'])
    X = pd.read_csv(DATA_DIR / 'general_features.csv',
                    index_col=['Ticker']).drop(columns=['Date'])

    # Dashboard Explainer is fussy about column names: strip dots from both the
    # feature frame and the booster's stored feature names so they stay
    # aligned. regex=False makes the literal match explicit ('.' is a regex
    # wildcard, and pandas' default for `regex` changed between versions).
    X.columns = X.columns.str.replace('.', '', regex=False)
    feature_names = model.get_booster().feature_names
    feature_names = [x.replace('.', '') for x in feature_names]
    model.get_booster().feature_names = feature_names

    explainer = RegressionExplainer(model, X, y)

    db = ExplainerDashboard(
        explainer,
        title="Stock Valuation Explainer",
        description=
        "Visit https://share.streamlit.io/gardnmi/fundamental-stock-prediction to see the model in use,",
        shap_interaction=False,
        precision='float32',
        decision_trees=False)

    db.to_yaml("dashboard.yaml",
               explainerfile="explainer.joblib",
               dump_explainer=True)

# Always (re)load the dashboard from the saved config, whether it was just
# regenerated above or left over from a previous run.
db = ExplainerDashboard.from_config("dashboard.yaml")
Exemplo n.º 4
0
from unittest.mock import MagicMock
import sys
# Register a stand-in 'xgboost' module BEFORE explainerdashboard is imported:
# any later `import xgboost` (including one triggered while loading the
# explainer) finds this MagicMock in sys.modules instead of the real package.
# NOTE(review): presumably the serving host does not have xgboost installed;
# this only works if no real xgboost functionality is exercised - confirm.
sys.modules["xgboost"] = MagicMock()

from explainerdashboard import RegressionExplainer, ExplainerDashboard

# Load the previously dumped explainer from explainer.joblib instead of
# rebuilding it.
explainer = RegressionExplainer.from_file("explainer.joblib")
# you can override params during load from_config:
db = ExplainerDashboard.from_config(explainer, "dashboard.yaml", title="Test")

# Module-level app object for an external WSGI server (see command below).
app = db.flask_server()

# run waitress-serve --port=8070 dashboard:app in command line
Exemplo n.º 5
0
class CatBoostRegressionTests(unittest.TestCase):
    """Smoke tests for RegressionExplainer wrapping a CatBoost regressor.

    The model is refit on the explainer's X_cats frame with CatBoost
    categorical features; most tests only check that each explainer
    property, method or plot returns an object of the expected type.
    """

    def setUp(self):
        X_train, y_train, X_test, y_test = titanic_fare()
        self.test_len = len(X_test)

        # titanic_names() returns (train_names, test_names); only the test
        # names are used, so call it once.
        _, self.names = titanic_names()

        # First pass: a throwaway model/explainer used only to obtain X_cats.
        model = CatBoostRegressor(iterations=5,
                                  verbose=0).fit(X_train, y_train)
        explainer = RegressionExplainer(model,
                                        X_test,
                                        y_test,
                                        cats=['Deck', 'Embarked'])
        X_cats, y_cats = explainer.X_cats, explainer.y
        # Second pass: refit on X_cats with CatBoost categorical features.
        # NOTE(review): assumes positions 8 and 9 of X_cats are the
        # categorical columns (presumably Deck and Embarked) - confirm.
        model = CatBoostRegressor(iterations=5,
                                  verbose=0).fit(X_cats,
                                                 y_cats,
                                                 cat_features=[8, 9])
        self.explainer = RegressionExplainer(model,
                                             X_cats,
                                             y_cats,
                                             cats=['Sex'])

    def test_explainer_len(self):
        self.assertEqual(len(self.explainer), self.test_len)

    def test_int_idx(self):
        self.assertEqual(self.explainer.get_int_idx(self.names[0]), 0)

    def test_random_index(self):
        self.assertIsInstance(self.explainer.random_index(), int)
        self.assertIsInstance(self.explainer.random_index(return_str=True),
                              str)

    def test_row_from_input(self):
        input_row = self.explainer.get_row_from_input(
            self.explainer.X.iloc[[0]].values.tolist())
        self.assertIsInstance(input_row, pd.DataFrame)

        input_row = self.explainer.get_row_from_input(
            self.explainer.X_cats.iloc[[0]].values.tolist())
        self.assertIsInstance(input_row, pd.DataFrame)

        input_row = self.explainer.get_row_from_input(
            self.explainer.X_cats[self.explainer.columns_ranked_by_shap(
                cats=True)].iloc[[0]].values.tolist(),
            ranked_by_shap=True)
        self.assertIsInstance(input_row, pd.DataFrame)

        input_row = self.explainer.get_row_from_input(
            self.explainer.X[self.explainer.columns_ranked_by_shap(
                cats=False)].iloc[[0]].values.tolist(),
            ranked_by_shap=True)
        self.assertIsInstance(input_row, pd.DataFrame)

    def test_prediction_result_df(self):
        df = self.explainer.prediction_result_df(0)
        self.assertIsInstance(df, pd.DataFrame)

    def test_preds(self):
        self.assertIsInstance(self.explainer.preds, np.ndarray)

    def test_pred_percentiles(self):
        self.assertIsInstance(self.explainer.pred_percentiles, np.ndarray)

    def test_columns_ranked_by_shap(self):
        self.assertIsInstance(self.explainer.columns_ranked_by_shap(), list)
        self.assertIsInstance(self.explainer.columns_ranked_by_shap(cats=True),
                              list)

    def test_equivalent_col(self):
        self.assertEqual(self.explainer.equivalent_col("Sex_female"), "Sex")
        self.assertEqual(self.explainer.equivalent_col("Sex"), "Sex_female")
        self.assertIsNone(self.explainer.equivalent_col("random"))

    def test_ordered_cats(self):
        self.assertEqual(self.explainer.ordered_cats("Sex"),
                         ['Sex_female', 'Sex_male'])
        self.assertEqual(
            self.explainer.ordered_cats("Deck", topx=2, sort='alphabet'),
            ['Deck_A', 'Deck_B'])

        self.assertIsInstance(self.explainer.ordered_cats("Deck", sort='freq'),
                              list)
        self.assertIsInstance(
            self.explainer.ordered_cats("Deck", topx=3, sort='freq'), list)
        self.assertIsInstance(self.explainer.ordered_cats("Deck", sort='shap'),
                              list)
        self.assertIsInstance(
            self.explainer.ordered_cats("Deck", topx=3, sort='shap'), list)

    def test_get_col(self):
        self.assertIsInstance(self.explainer.get_col("Sex"), pd.Series)
        self.assertEqual(self.explainer.get_col("Sex").dtype, "object")

        self.assertIsInstance(self.explainer.get_col("Age"), pd.Series)
        # np.float was merely an alias of the builtin float and was removed in
        # NumPy 1.24; compare against the concrete float64 dtype instead.
        self.assertEqual(self.explainer.get_col("Age").dtype, np.float64)

    def test_permutation_importances(self):
        self.assertIsInstance(self.explainer.permutation_importances,
                              pd.DataFrame)
        self.assertIsInstance(self.explainer.permutation_importances_cats,
                              pd.DataFrame)

    def test_X_cats(self):
        self.assertIsInstance(self.explainer.X_cats, pd.DataFrame)

    def test_columns_cats(self):
        self.assertIsInstance(self.explainer.columns_cats, list)

    def test_mean_abs_shap_df(self):
        self.assertIsInstance(self.explainer.mean_abs_shap_df(), pd.DataFrame)

    def test_permutation_importances_df(self):
        self.assertIsInstance(self.explainer.permutation_importances_df(),
                              pd.DataFrame)
        self.assertIsInstance(
            self.explainer.permutation_importances_df(topx=3), pd.DataFrame)
        self.assertIsInstance(
            self.explainer.permutation_importances_df(cats=True), pd.DataFrame)
        self.assertIsInstance(
            self.explainer.permutation_importances_df(cutoff=0.01),
            pd.DataFrame)

    def test_contrib_df(self):
        self.assertIsInstance(self.explainer.contrib_df(0), pd.DataFrame)
        self.assertIsInstance(self.explainer.contrib_df(0, cats=False),
                              pd.DataFrame)
        self.assertIsInstance(self.explainer.contrib_df(0, topx=3),
                              pd.DataFrame)
        self.assertIsInstance(self.explainer.contrib_df(0, sort='high-to-low'),
                              pd.DataFrame)
        self.assertIsInstance(self.explainer.contrib_df(0, sort='low-to-high'),
                              pd.DataFrame)
        self.assertIsInstance(self.explainer.contrib_df(0, sort='importance'),
                              pd.DataFrame)
        self.assertIsInstance(
            self.explainer.contrib_df(X_row=self.explainer.X.iloc[[0]]),
            pd.DataFrame)
        self.assertIsInstance(
            self.explainer.contrib_df(X_row=self.explainer.X_cats.iloc[[0]]),
            pd.DataFrame)

    def test_contrib_summary_df(self):
        self.assertIsInstance(self.explainer.contrib_summary_df(0),
                              pd.DataFrame)
        self.assertIsInstance(self.explainer.contrib_summary_df(0, cats=False),
                              pd.DataFrame)
        self.assertIsInstance(self.explainer.contrib_summary_df(0, topx=3),
                              pd.DataFrame)
        self.assertIsInstance(self.explainer.contrib_summary_df(0, round=3),
                              pd.DataFrame)
        self.assertIsInstance(
            self.explainer.contrib_summary_df(0, sort='high-to-low'),
            pd.DataFrame)
        self.assertIsInstance(
            self.explainer.contrib_summary_df(0, sort='low-to-high'),
            pd.DataFrame)
        self.assertIsInstance(
            self.explainer.contrib_summary_df(0, sort='importance'),
            pd.DataFrame)
        self.assertIsInstance(
            self.explainer.contrib_summary_df(
                X_row=self.explainer.X.iloc[[0]]), pd.DataFrame)
        self.assertIsInstance(
            self.explainer.contrib_summary_df(
                X_row=self.explainer.X_cats.iloc[[0]]), pd.DataFrame)

    def test_shap_base_value(self):
        self.assertIsInstance(self.explainer.shap_base_value,
                              (np.floating, float))

    def test_shap_values_shape(self):
        # assertEqual reports both shapes on failure (assertTrue would not).
        self.assertEqual(
            self.explainer.shap_values.shape,
            (len(self.explainer), len(self.explainer.columns)))

    def test_shap_values(self):
        self.assertIsInstance(self.explainer.shap_values, np.ndarray)
        self.assertIsInstance(self.explainer.shap_values_cats, np.ndarray)

    def test_mean_abs_shap(self):
        self.assertIsInstance(self.explainer.mean_abs_shap, pd.DataFrame)
        self.assertIsInstance(self.explainer.mean_abs_shap_cats, pd.DataFrame)

    def test_calculate_properties(self):
        # Passes as long as calculate_properties() does not raise.
        self.explainer.calculate_properties()

    def test_pdp_df(self):
        self.assertIsInstance(self.explainer.pdp_df("Age"), pd.DataFrame)
        self.assertIsInstance(self.explainer.pdp_df("Sex"), pd.DataFrame)
        self.assertIsInstance(self.explainer.pdp_df("Deck"), pd.DataFrame)
        self.assertIsInstance(self.explainer.pdp_df("Age", index=0),
                              pd.DataFrame)
        self.assertIsInstance(self.explainer.pdp_df("Sex", index=0),
                              pd.DataFrame)

    def test_get_dfs(self):
        cols_df, shap_df, contribs_df = self.explainer.get_dfs()
        self.assertIsInstance(cols_df, pd.DataFrame)
        self.assertIsInstance(shap_df, pd.DataFrame)
        self.assertIsInstance(contribs_df, pd.DataFrame)

    def test_plot_importances(self):
        fig = self.explainer.plot_importances()
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_importances(kind='permutation')
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_importances(topx=3)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_importances(cats=True)
        self.assertIsInstance(fig, go.Figure)

    def test_plot_shap_summary(self):
        fig = self.explainer.plot_shap_summary()
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_shap_summary(topx=3)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_shap_summary(cats=True)
        self.assertIsInstance(fig, go.Figure)

    def test_plot_shap_dependence(self):
        fig = self.explainer.plot_shap_dependence("Age")
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_shap_dependence("Sex")
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_shap_dependence("Age", "Sex")
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_shap_dependence("Sex_female", "Age")
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_shap_dependence("Age", highlight_index=0)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_shap_dependence("Sex", highlight_index=0)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_shap_dependence("Deck", topx=3, sort="freq")
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_shap_dependence("Deck", topx=3, sort="shap")
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_shap_dependence("Deck", sort="freq")
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_shap_dependence("Deck", sort="shap")
        self.assertIsInstance(fig, go.Figure)

    def test_plot_shap_contributions(self):
        fig = self.explainer.plot_shap_contributions(0)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_shap_contributions(0, cats=False)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_shap_contributions(0, topx=3)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_shap_contributions(0, sort='high-to-low')
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_shap_contributions(0, sort='low-to-high')
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_shap_contributions(0, sort='importance')
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_shap_contributions(
            X_row=self.explainer.X.iloc[[0]])
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_shap_contributions(
            X_row=self.explainer.X_cats.iloc[[0]])
        self.assertIsInstance(fig, go.Figure)

    def test_plot_pdp(self):
        fig = self.explainer.plot_pdp("Age")
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_pdp("Sex")
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_pdp("Sex", index=0)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_pdp("Age", index=0)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_pdp("Age", X_row=self.explainer.X.iloc[[0]])
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_pdp("Age",
                                      X_row=self.explainer.X_cats.iloc[[0]])
        self.assertIsInstance(fig, go.Figure)

    def test_yaml(self):
        yaml = self.explainer.to_yaml()
        self.assertIsInstance(yaml, str)

    def test_residuals(self):
        self.assertIsInstance(self.explainer.residuals, pd.Series)

    def test_prediction_result_markdown(self):
        result_index = self.explainer.prediction_result_markdown(0)
        self.assertIsInstance(result_index, str)
        result_name = self.explainer.prediction_result_markdown(self.names[0])
        self.assertIsInstance(result_name, str)

    def test_metrics(self):
        # Single definition: an earlier, weaker duplicate of this method was
        # silently shadowed by this one and never ran.
        metrics_dict = self.explainer.metrics()
        self.assertIsInstance(metrics_dict, dict)
        self.assertIn('root_mean_squared_error', metrics_dict)
        self.assertIn('mean_absolute_error', metrics_dict)
        self.assertIn('R-squared', metrics_dict)
        self.assertIsInstance(self.explainer.metrics_descriptions(), dict)

    def test_plot_predicted_vs_actual(self):
        fig = self.explainer.plot_predicted_vs_actual(logs=False)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_predicted_vs_actual(logs=True)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_predicted_vs_actual(log_x=True, log_y=True)
        self.assertIsInstance(fig, go.Figure)

    def test_plot_residuals(self):
        fig = self.explainer.plot_residuals()
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_residuals(vs_actual=True)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_residuals(residuals='ratio')
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_residuals(residuals='log-ratio')
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_residuals(residuals='log-ratio',
                                            vs_actual=True)
        self.assertIsInstance(fig, go.Figure)

    def test_plot_residuals_vs_feature(self):
        fig = self.explainer.plot_residuals_vs_feature("Age")
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_residuals_vs_feature("Age",
                                                       residuals='log-ratio')
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_residuals_vs_feature("Age", dropna=True)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_residuals_vs_feature("Sex", points=False)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_residuals_vs_feature("Sex", winsor=10)
        self.assertIsInstance(fig, go.Figure)

    def test_plot_y_vs_feature(self):
        fig = self.explainer.plot_y_vs_feature("Age")
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_y_vs_feature("Age", dropna=True)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_y_vs_feature("Sex", points=False)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_y_vs_feature("Sex", winsor=10)
        self.assertIsInstance(fig, go.Figure)

    def test_plot_preds_vs_feature(self):
        fig = self.explainer.plot_preds_vs_feature("Age")
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_preds_vs_feature("Age", dropna=True)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_preds_vs_feature("Sex", points=False)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_preds_vs_feature("Sex", winsor=10)
        self.assertIsInstance(fig, go.Figure)
Exemplo n.º 6
0
from explainerdashboard import RegressionExplainer, ExplainerDashboard
from explainerdashboard.custom import *
from joblib import load
import pandas as pd
from sklearn.model_selection import train_test_split

# Load the fitted decision tree and its dataset (presumably already encoded,
# per the '_enc' filename).
dec_tree = load('dec_tree_v3.joblib')
df_data = pd.read_csv('data_v3_enc.csv')

# Target is the 'Duration' column; 'Timestamp' is excluded from the features.
X = df_data.drop(columns=['Duration', 'Timestamp'])
y = df_data['Duration']

# Only a 1% hold-out (X_small / y_small) feeds the explainer; the 99% part
# (Xt / yt) is not used further in this script.
Xt, X_small, yt, y_small = train_test_split(
    X,
    y,
    test_size=0.01,
    random_state=0,
)

exp = RegressionExplainer(
    dec_tree,
    X_small,
    y_small,
    cats=['Day_of_week', 'Hour', 'Vehicle', 'Position'],
)

# Dashboard limited to the SHAP-dependence and what-if composites, with the
# what-if partial-dependence plot hidden.
db = ExplainerDashboard(
    exp,
    [ShapDependenceComposite, WhatIfComposite],
    hide_whatifpdp=True,
)

# Persist the explainer and the dashboard configuration.
exp.dump("explainer.joblib")
db.to_yaml("dashboard.yaml")
Exemplo n.º 7
0
                         n_jobs=8,
                         num_parallel_tree=1,
                         objective='reg:squarederror',
                         random_state=42,
                         reg_alpha=1,
                         reg_lambda=0,
                         scale_pos_weight=1.0,
                         seed=42,
                         subsample=0.6000000000000001,
                         tree_method='exact',
                         validate_parameters=1,
                         verbosity=None)

# Features and target, each read with the first CSV column as the index.
X = pd.read_csv('X.csv', index_col=0)
Y = pd.read_csv('Y.csv', index_col=0)

REmodel = model.fit(X, Y)
explainer = RegressionExplainer(
    REmodel,
    X,
    Y,
    cats=cats,
    descriptions=feature_descriptions,
    units="$",
)

# Serve on $PORT when set (e.g. by a hosting platform), else on 5000.
ExplainerDashboard(
    explainer,
    title='XGBoost Regression Model Explainer: Predicting House Prices',
    description='This dashboard shows the inner workings of a fitted machine learning model, and explains its predictions.',
    shap_interaction=False,
    decision_trees=False,
).run(port=int(os.environ.get('PORT', 5000)))
Exemplo n.º 8
0
# classifier (model / X_test / y_test come from earlier in the script)
clas_explainer = ClassifierExplainer(
    model,
    X_test,
    y_test,
    cats=['Sex', 'Deck', 'Embarked'],
    descriptions=feature_descriptions,
    labels=['Not survived', 'Survived'],
)
# The dashboard object itself is discarded. NOTE(review): presumably it is
# built so the explainer computes what a dashboard needs before being
# dumped - confirm.
_ = ExplainerDashboard(clas_explainer)
clas_explainer.dump(pkl_dir / "clas_explainer.joblib")

# regression
X_train, y_train, X_test, y_test = titanic_fare()
model = RandomForestRegressor(n_estimators=50, max_depth=5)
model.fit(X_train, y_train)
reg_explainer = RegressionExplainer(
    model,
    X_test,
    y_test,
    cats=['Sex', 'Deck', 'Embarked'],
    descriptions=feature_descriptions,
    units="$",
)
_ = ExplainerDashboard(reg_explainer)
reg_explainer.dump(pkl_dir / "reg_explainer.joblib")

# multiclass
X_train, y_train, X_test, y_test = titanic_embarked()
model = RandomForestClassifier(n_estimators=50,
                               max_depth=5).fit(X_train, y_train)
multi_explainer = ClassifierExplainer(
    model,
    X_test,
    y_test,
    cats=['Sex', 'Deck'],
    descriptions=feature_descriptions,
Exemplo n.º 9
0
from explainerdashboard.explainers import RandomForestRegressionExplainer
from sklearn.ensemble import RandomForestRegressor
from sklearn import tree
from explainerdashboard import ClassifierExplainer, ExplainerDashboard, RegressionExplainer
from explainerdashboard.datasets import titanic_survive, feature_descriptions
from sklearn.model_selection import train_test_split
import pandas as pd
import numpy as np


imdb = pd.read_csv("Amélioration/Data/movies2.csv", encoding="latin-1")
# The file's header appears to be a data row (e.g. '11:14', 'Greg Marcks'),
# so map those labels to meaningful column names.
# NOTE(review): if so, that first movie is lost as a sample; consider
# header=None with explicit names - confirm against the CSV.
imdb = imdb.rename(columns={
    '11:14': 'film',
    '7.2': 'metascore',
    'Crime': 'genre',
    'Greg Marcks': 'realisateur',
    'Henry Thomas': 'acteur_1',
    'Colin Hanks': 'acteur_2',
    '6000000': 'budget',
    '0': 'votes2',
    '0.1': 'vote',
    'Aug 12, 2005': 'date',
})
# One-hot encode the categorical text columns.
colonne = ['genre', 'acteur_1', 'acteur_2', 'realisateur']
imdb = pd.get_dummies(imdb, columns=colonne)
# `columns=` already targets columns; the extra axis=1 was redundant.
imdb = imdb.drop(columns=['film', 'budget', 'votes2', 'vote', 'date'])

# Features: every column except the target.
X = imdb.loc[:, imdb.columns != 'metascore']
# Target as a 1-D Series. The previous boolean-mask selection returned a
# one-column DataFrame, which scikit-learn regressors accept only with a
# DataConversionWarning ("column-vector y").
y = imdb['metascore']

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

model = RandomForestRegressor().fit(X_train, y_train)
explainer = RegressionExplainer(model, X_test, y_test)

db = ExplainerDashboard(explainer, title="Metascore de film",
                    whatif=False, # you can switch off tabs with bools
                    shap_interaction=False,
                    decision_trees=False)

# Run the dashboard configured above; previously a second, default-configured
# ExplainerDashboard was built and run, silently discarding `db`'s settings.
db.run()