def get_interaction_values(self, n_samples_max=None, selection=None):
    """
    Compute shap interaction values for each row of x_init.

    This function is only available for explainer of type TreeExplainer
    (used for tree based models).
    Please refer to the official tree shap paper for more information :
    https://arxiv.org/pdf/1802.03888.pdf

    The result is cached on the instance (``self.x_interaction`` /
    ``self.interaction_values``). If the requested rows are equal
    (``DataFrame.equals``) to the ones used in the previous call, the cached
    values are returned without recomputation.

    Parameters
    ----------
    n_samples_max : int, optional
        Limit the number of points for which we compute the interactions.
        The first ``n_samples_max`` rows (by position, after ``selection``)
        are kept; ``None`` keeps every row.
    selection : list or array-like, optional
        Contains list of index, subset of the input DataFrame that we want
        to plot. Any array-like of index labels (list, numpy array, pandas
        Index) is accepted; ``None`` or an empty selection keeps every row.

    Returns
    -------
    np.ndarray
        Shap interaction values for each sample as an array of shape
        (# samples x # features x # features).
    """
    x = self.x_init
    # Use ``len`` rather than truthiness: ``bool()`` of a multi-element numpy
    # array or pandas Index raises ValueError. An empty selection keeps all
    # rows.
    if selection is not None and len(selection) > 0:
        x = x.loc[selection]
    # Positional truncation; ``iloc[:None]`` keeps every row.
    x = x.iloc[:n_samples_max]

    # Reuse the cached result when the requested rows are unchanged.
    cached_x = getattr(self, 'x_interaction', None)
    if cached_x is not None and cached_x.equals(x):
        return self.interaction_values

    # Store an independent copy of only the rows used, so the cache never
    # aliases self.x_init. This replaces deep-copying the whole of x_init.
    self.x_interaction = x.copy()
    self.interaction_values = get_shap_interaction_values(self.x_interaction, self.explainer)
    return self.interaction_values
def test_shap_interaction_values_2(self):
    """
    Unit test shap_interaction_values function

    Fit a one-tree random forest regressor and a one-tree random forest
    classifier on the fixture data, then check that
    get_shap_interaction_values returns one (n_features x n_features)
    matrix per sample.
    """
    # Cast once, outside the loop: the conversion does not depend on the
    # model, and using a local avoids rebinding the fixture attribute.
    x_df = self.x_df.astype(float)
    n_samples, n_features = x_df.shape
    for model in [
        ske.RandomForestRegressor(n_estimators=1),
        ske.RandomForestClassifier(n_estimators=1),
    ]:
        model.fit(x_df, self.y_df)
        explainer = shap.TreeExplainer(model)
        interaction_values = get_shap_interaction_values(x_df, explainer)
        # Only the leading three axes are compared. The model is included in
        # the failure message instead of being printed on every iteration.
        assert interaction_values.shape[:3] == (n_samples, n_features, n_features), model