class RegressionBaseExplainerTests(unittest.TestCase):
    """Smoke tests for the core ``RegressionExplainer`` API.

    Every test calls explainer methods or properties on the explainer built
    in ``setUp`` and checks the type (and in a few cases the shape or value)
    of what comes back.
    """

    def setUp(self):
        """Fit a small random forest and wrap it in a ``RegressionExplainer``."""
        X_train, y_train, X_test, y_test = titanic_fare()
        self.test_len = len(X_test)

        # Only the second element returned by titanic_names() is used: it
        # serves both as the explainer's ``idxs`` and as lookup keys in the
        # tests, so fetch it once instead of calling titanic_names() twice.
        _, test_names = titanic_names()
        self.names = test_names

        # Deliberately tiny forest (5 trees, depth 2).
        model = RandomForestRegressor(n_estimators=5, max_depth=2).fit(X_train, y_train)
        self.explainer = RegressionExplainer(
            model, X_test, y_test,
            cats=[{'Gender': ['Sex_female', 'Sex_male', 'Sex_nan']},
                  'Deck', 'Embarked'],
            cats_notencoded={'Gender': 'No Gender'},
            idxs=test_names,
            target='Fare',
            units='$')

    # ------------------------------------------------------------------
    # Index handling and basic properties
    # ------------------------------------------------------------------

    def test_explainer_len(self):
        self.assertEqual(len(self.explainer), self.test_len)

    def test_int_idx(self):
        # The first name passed as ``idxs`` must resolve to position 0.
        self.assertEqual(self.explainer.get_idx(self.names[0]), 0)

    def test_random_index(self):
        self.assertIsInstance(self.explainer.random_index(), int)
        self.assertIsInstance(self.explainer.random_index(return_str=True), str)

    def test_index_exists(self):
        self.assertTrue(self.explainer.index_exists(0))
        self.assertTrue(self.explainer.index_exists(self.explainer.idxs[0]))
        self.assertFalse(self.explainer.index_exists('bla'))

    def test_row_from_input(self):
        # Input given in the column order of X.
        input_row = self.explainer.get_row_from_input(
            self.explainer.X.iloc[[0]].values.tolist())
        self.assertIsInstance(input_row, pd.DataFrame)

        # Input given in the column order of X_merged.
        input_row = self.explainer.get_row_from_input(
            self.explainer.X_merged.iloc[[0]].values.tolist())
        self.assertIsInstance(input_row, pd.DataFrame)

        # Input given in X_merged columns reordered by columns_ranked_by_shap().
        input_row = self.explainer.get_row_from_input(
            self.explainer.X_merged
            [self.explainer.columns_ranked_by_shap()]
            .iloc[[0]].values.tolist(), ranked_by_shap=True)
        self.assertIsInstance(input_row, pd.DataFrame)

    def test_prediction_result_df(self):
        df = self.explainer.prediction_result_df(0)
        self.assertIsInstance(df, pd.DataFrame)

    def test_preds(self):
        self.assertIsInstance(self.explainer.preds, np.ndarray)

    def test_pred_percentiles(self):
        self.assertIsInstance(self.explainer.pred_percentiles(), np.ndarray)

    def test_columns_ranked_by_shap(self):
        self.assertIsInstance(self.explainer.columns_ranked_by_shap(), list)

    def test_get_col(self):
        gender = self.explainer.get_col("Gender")
        self.assertIsInstance(gender, pd.Series)
        # ``is_categorical_dtype`` is deprecated in pandas; the documented
        # replacement is an isinstance check against ``pd.CategoricalDtype``.
        self.assertIsInstance(gender.dtype, pd.CategoricalDtype)

        age = self.explainer.get_col("Age")
        self.assertIsInstance(age, pd.Series)
        self.assertTrue(is_numeric_dtype(age))

    def test_permutation_importances(self):
        self.assertIsInstance(self.explainer.permutation_importances(), pd.DataFrame)

    def test_X_cats(self):
        self.assertIsInstance(self.explainer.X_cats, pd.DataFrame)

    def test_metrics(self):
        self.assertIsInstance(self.explainer.metrics(), dict)
        self.assertIsInstance(self.explainer.metrics_descriptions(), dict)

    # ------------------------------------------------------------------
    # Importances, contributions and SHAP values
    # ------------------------------------------------------------------

    def test_mean_abs_shap_df(self):
        self.assertIsInstance(self.explainer.mean_abs_shap_df(), pd.DataFrame)

    def test_top_interactions(self):
        self.assertIsInstance(self.explainer.top_shap_interactions("Age"), list)
        self.assertIsInstance(self.explainer.top_shap_interactions("Age", topx=4), list)

    def test_permutation_importances_df(self):
        self.assertIsInstance(self.explainer.get_permutation_importances_df(), pd.DataFrame)
        self.assertIsInstance(self.explainer.get_permutation_importances_df(topx=3), pd.DataFrame)
        self.assertIsInstance(self.explainer.get_permutation_importances_df(cutoff=0.01), pd.DataFrame)

    def test_contrib_df(self):
        self.assertIsInstance(self.explainer.get_contrib_df(0), pd.DataFrame)
        self.assertIsInstance(self.explainer.get_contrib_df(0, topx=3), pd.DataFrame)
        self.assertIsInstance(self.explainer.get_contrib_df(0, sort='high-to-low'), pd.DataFrame)
        self.assertIsInstance(self.explainer.get_contrib_df(0, sort='low-to-high'), pd.DataFrame)
        self.assertIsInstance(self.explainer.get_contrib_df(0, sort='importance'), pd.DataFrame)
        # Explicit input row instead of an index.
        self.assertIsInstance(self.explainer.get_contrib_df(X_row=self.explainer.X.iloc[[0]]), pd.DataFrame)

    def test_contrib_summary_df(self):
        self.assertIsInstance(self.explainer.get_contrib_summary_df(0), pd.DataFrame)
        self.assertIsInstance(self.explainer.get_contrib_summary_df(0, topx=3), pd.DataFrame)
        self.assertIsInstance(self.explainer.get_contrib_summary_df(0, round=3), pd.DataFrame)
        self.assertIsInstance(self.explainer.get_contrib_summary_df(0, sort='high-to-low'), pd.DataFrame)
        self.assertIsInstance(self.explainer.get_contrib_summary_df(0, sort='low-to-high'), pd.DataFrame)
        self.assertIsInstance(self.explainer.get_contrib_summary_df(0, sort='importance'), pd.DataFrame)
        # Explicit input row instead of an index.
        self.assertIsInstance(self.explainer.get_contrib_summary_df(X_row=self.explainer.X.iloc[[0]]), pd.DataFrame)

    def test_shap_base_value(self):
        self.assertIsInstance(self.explainer.shap_base_value(), (np.floating, float))

    def test_shap_values_shape(self):
        # assertEqual (rather than assertTrue(a == b)) so that a failure
        # reports both the actual and the expected shape.
        self.assertEqual(
            self.explainer.get_shap_values_df().shape,
            (len(self.explainer), len(self.explainer.merged_cols)))

    def test_shap_values(self):
        self.assertIsInstance(self.explainer.get_shap_values_df(), pd.DataFrame)

    def test_shap_interaction_values(self):
        self.assertIsInstance(self.explainer.shap_interaction_values(), np.ndarray)

    def test_mean_abs_shap(self):
        # NOTE(review): same assertion as test_mean_abs_shap_df -- confirm
        # whether a different method was meant to be covered here.
        self.assertIsInstance(self.explainer.mean_abs_shap_df(), pd.DataFrame)

    def test_memory_usage(self):
        self.assertIsInstance(self.explainer.memory_usage(), pd.DataFrame)
        self.assertIsInstance(self.explainer.memory_usage(cutoff=1000), pd.DataFrame)

    def test_calculate_properties(self):
        # No assertion: the test passes as long as the call does not raise.
        self.explainer.calculate_properties()

    def test_shap_interaction_values_for_col(self):
        self.assertIsInstance(self.explainer.shap_interaction_values_for_col("Age"), np.ndarray)
        self.assertEqual(self.explainer.shap_interaction_values_for_col("Age").shape,
                         self.explainer.get_shap_values_df().shape)

    def test_pdp_df(self):
        self.assertIsInstance(self.explainer.pdp_df("Age"), pd.DataFrame)
        self.assertIsInstance(self.explainer.pdp_df("Gender"), pd.DataFrame)
        self.assertIsInstance(self.explainer.pdp_df("Deck"), pd.DataFrame)
        self.assertIsInstance(self.explainer.pdp_df("Age", index=0), pd.DataFrame)
        self.assertIsInstance(self.explainer.pdp_df("Gender", index=0), pd.DataFrame)

    # ------------------------------------------------------------------
    # Plots: every plot method is expected to return a ``go.Figure``
    # ------------------------------------------------------------------

    def test_plot_importances(self):
        fig = self.explainer.plot_importances()
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_importances(kind='permutation')
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_importances(topx=3)
        self.assertIsInstance(fig, go.Figure)

    def test_plot_interactions(self):
        fig = self.explainer.plot_interactions_importance("Age")
        self.assertIsInstance(fig, go.Figure)

        # NOTE(review): identical to the call above -- confirm whether a
        # different argument (e.g. topx) was intended.
        fig = self.explainer.plot_interactions_importance("Age")
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_interactions_importance("Gender")
        self.assertIsInstance(fig, go.Figure)

    def test_plot_shap_interactions(self):
        # NOTE(review): despite its name this test exercises
        # plot_contributions(); the name is kept so existing test selections
        # keep working.
        fig = self.explainer.plot_contributions(0)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_contributions(0, topx=3)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_contributions(0, cutoff=0.05)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_contributions(0, sort='high-to-low')
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_contributions(0, sort='low-to-high')
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_contributions(0, sort='importance')
        self.assertIsInstance(fig, go.Figure)

    def test_plot_shap_detailed(self):
        fig = self.explainer.plot_importances_detailed()
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_importances_detailed(topx=3)
        self.assertIsInstance(fig, go.Figure)

    def test_plot_interactions_detailed(self):
        fig = self.explainer.plot_interactions_detailed("Age")
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_interactions_detailed("Age", topx=3)
        self.assertIsInstance(fig, go.Figure)

    def test_plot_dependence(self):
        fig = self.explainer.plot_dependence("Age")
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_dependence("Gender")
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_dependence("Age", "Gender")
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_dependence("Age", highlight_index=0)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_dependence("Gender", highlight_index=0)
        self.assertIsInstance(fig, go.Figure)

    def test_plot_contributions(self):
        fig = self.explainer.plot_contributions(0)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_contributions(0, topx=3)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_contributions(0, sort='high-to-low')
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_contributions(0, sort='low-to-high')
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_contributions(0, sort='importance')
        self.assertIsInstance(fig, go.Figure)

        # Explicit input row instead of an index.
        fig = self.explainer.plot_contributions(X_row=self.explainer.X.iloc[[0]])
        self.assertIsInstance(fig, go.Figure)

    def test_plot_interaction(self):
        fig = self.explainer.plot_interaction("Age", "Gender")
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_interaction("Gender", "Age")
        self.assertIsInstance(fig, go.Figure)

        # NOTE(review): identical to the call above -- confirm whether a
        # different column pair was intended.
        fig = self.explainer.plot_interaction("Gender", "Age")
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_interaction("Age", "Gender", highlight_index=0)
        self.assertIsInstance(fig, go.Figure)

    def test_plot_pdp(self):
        fig = self.explainer.plot_pdp("Age")
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_pdp("Gender")
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_pdp("Gender", index=0)
        self.assertIsInstance(fig, go.Figure)

        fig = self.explainer.plot_pdp("Age", index=0)
        self.assertIsInstance(fig, go.Figure)

        # Explicit input row instead of an index.
        fig = self.explainer.plot_pdp("Age", X_row=self.explainer.X.iloc[[0]])
        self.assertIsInstance(fig, go.Figure)

    # ------------------------------------------------------------------
    # Serialisation
    # ------------------------------------------------------------------

    def test_yaml(self):
        yaml = self.explainer.to_yaml()
        self.assertIsInstance(yaml, str)
class ExternalSourceRegressionTests(unittest.TestCase):
    """Tests a ``RegressionExplainer`` backed by external lookup callbacks.

    The explainer itself only holds the first 100 test rows; the remaining
    rows are reachable solely through the callbacks registered in ``setUp``.
    """

    def setUp(self):
        """Build the explainer on 100 rows and register the four callbacks."""
        X_train, y_train, X_test, y_test = titanic_fare()

        model = RandomForestRegressor(n_estimators=50, max_depth=4)
        model.fit(X_train, y_train)

        # Replace the test index by string labels "0", "1", ... (in place).
        X_test.reset_index(drop=True, inplace=True)
        X_test.index = X_test.index.astype(str)

        # Rows held by the explainer itself.
        X_internal = X_test.iloc[:100]
        y_internal = y_test.iloc[:100]
        # Rows that are only served through the callbacks below.
        X_external = X_test.iloc[100:]
        y_external = y_test.iloc[100:]

        self.explainer = RegressionExplainer(
            model, X_internal, y_internal, cats=['Sex', 'Deck'])

        def has_external_index(index):
            return index in X_external.index

        def external_index_list():
            # only returns first 50 indexes
            return list(X_external.index[:50])

        def external_y(index):
            # y is looked up by the position of the label in the X index.
            return y_external.iloc[[X_external.index.get_loc(index)]]

        def external_X_row(index):
            return X_external.iloc[[X_external.index.get_loc(index)]]

        self.explainer.set_index_exists_func(has_external_index)
        self.explainer.set_index_list_func(external_index_list)
        self.explainer.set_X_row_func(external_X_row)
        self.explainer.set_y_func(external_y)

    def test_get_X_row(self):
        # Positional int, internal label, then two external labels.
        for index in (0, "0", "120", "150"):
            self.assertIsInstance(self.explainer.get_X_row(index), pd.DataFrame)

    def test_get_shap_row(self):
        for index in (0, "0", "120", "150"):
            self.assertIsInstance(self.explainer.get_shap_row(index), pd.DataFrame)

    def test_get_y(self):
        for index in (0, "0", "120", "150"):
            self.assertIsInstance(self.explainer.get_y(index), float)

    def test_index_list(self):
        listed = self.explainer.get_index_list()
        # "100" is among the first 50 external labels, "160" is not.
        self.assertIn('100', listed)
        self.assertNotIn('160', listed)

    def test_index_exists(self):
        for index in ("0", "100", "160", 0):
            self.assertTrue(self.explainer.index_exists(index))
        for index in (-1, 120, "wrong index"):
            self.assertFalse(self.explainer.index_exists(index))