def main():
    """Explain the classifier's prediction for every text and save the key words.

    Steps, in order:

    1. Read ``data/mr_vs_fr_30.xlsx``, shuffle the rows with ``seed`` and add a
       ``text_lemmatized`` column by applying ``morphText`` to ``text``.
    2. Make a stratified 70/30 split and call ``get_pipe`` on it with the test
       flag set; the returned value is discarded.
    3. Call ``get_pipe`` again on the full data with the test flag cleared and
       keep the returned pipeline.
    4. For each row, fit a ``TextExplainer`` backed by a depth-5 decision tree
       against ``pipe.predict_proba``, print its weight table, and collect
       the features whose weight is positive as a comma-separated string.
    5. Store those strings in a new ``words`` column and write the frame to
       ``mr_vs_fr_words_30.xlsx``.
    """
    df = pd.read_excel('data/mr_vs_fr_30.xlsx')
    df = df.sample(frac=1, random_state=seed)
    df['text_lemmatized'] = df['text'].apply(morphText)

    X_train, X_test, y_train, y_test = train_test_split(
        df['text_lemmatized'], df['label'],
        test_size=0.3, random_state=42, stratify=df['label'])

    # Third positional argument is the test flag. NOTE(review): presumably the
    # first call reports hold-out metrics as a side effect -- confirm in
    # get_pipe; its return value is intentionally unused here.
    get_pipe(X_train, y_train, True, X_test, y_test)
    pipe = get_pipe(df['text_lemmatized'], df['label'], False)

    words = []
    # Only two columns are needed, so iterate over them directly instead of
    # building a Series per row with iterrows().
    for text, label in zip(df['text_lemmatized'], df['label']):
        # A fresh explainer per text: each one is fitted to a single document.
        explainer = TextExplainer(clf=DecisionTreeClassifier(max_depth=5),
                                  random_state=seed)
        explainer.fit(text, pipe.predict_proba)
        weights = eli5.format_as_dataframe(explainer.explain_weights())

        print('class {}'.format('male' if label == 0 else 'woman'))
        print('predict:')
        print(weights)
        print(100*'*')

        positive = weights[weights['weight'] > 0]['feature'].tolist()
        # Joining an empty list already yields '', so no special case is needed.
        words.append(', '.join(positive))

    df['words'] = words
    df.to_excel('mr_vs_fr_words_30.xlsx', index=False)
def test_text_explainer_char_based(token_pattern):
    """Check a character-based TextExplainer on the text "Hello, world!".

    The explainer is built with ``char_based=True`` and the given
    ``token_pattern``, then fitted against the function returned by
    ``substring_presence_predict_proba('lo')``. The test requires a score
    above 0.95, a mean KL divergence below 0.1, and 'lo' as the first
    positive feature in both the prediction and the weights explanations.
    """
    black_box = substring_presence_predict_proba('lo')
    explainer = TextExplainer(char_based=True, token_pattern=token_pattern)
    explainer.fit("Hello, world!", black_box)

    metrics = explainer.metrics_
    print(metrics)
    assert metrics['score'] > 0.95
    assert metrics['mean_KL_divergence'] < 0.1

    prediction_expl = explainer.explain_prediction()
    format_as_all(prediction_expl, explainer.clf_)
    check_targets_scores(prediction_expl)
    top_positive = prediction_expl.targets[0].feature_weights.pos[0]
    assert top_positive.feature == 'lo'

    # The weights view is an alternative look at the same result
    # (less informative for char n-grams).
    weights_expl = explainer.explain_weights()
    top_positive = weights_expl.targets[0].feature_weights.pos[0]
    assert top_positive.feature == 'lo'
def test_text_explainer_custom_classifier():
    """Check TextExplainer with a user-supplied decision tree as the explainer.

    A ``DecisionTreeClassifier(max_depth=2)`` is passed as ``clf`` and the
    explainer is fitted on "foo-bar baz egg-spam" against the function
    returned by ``substring_presence_predict_proba('bar')``. The test
    requires a score above 0.99, a mean KL divergence below 0.01, and a tree
    whose root feature is named "bar".
    """
    black_box = substring_presence_predict_proba('bar')

    # A decision tree serves as the explaining model here.
    surrogate = DecisionTreeClassifier(max_depth=2)
    explainer = TextExplainer(clf=surrogate)
    explainer.fit("foo-bar baz egg-spam", black_box)

    metrics = explainer.metrics_
    print(metrics)
    assert metrics['score'] > 0.99
    assert metrics['mean_KL_divergence'] < 0.01

    prediction_expl = explainer.explain_prediction()
    format_as_all(prediction_expl, explainer.clf_)

    # explain_weights exposes the fitted tree itself.
    weights_expl = explainer.explain_weights()
    root = weights_expl.decision_tree.tree
    print(root)
    assert root.feature_name == "bar"
    format_as_all(weights_expl, explainer.clf_)