def plot_k_graph():
    with open(os.path.join(get_repo_path(), '_evaluation', 'means.pkl'), 'rb') as f:
        means = pickle.load(f)
    fig = plt.figure()
    for i, metric in enumerate(['cosine', 'dice', 'jaccard', 'overlap']):  # , 'raw']:
        graph = fig.add_subplot(2, 2, i + 1)
        xs = []
        ys = []
        # one scatter point per similarity mean, grouped by k
        for k in [3, 5, 7]:
            for v in means[metric][k]:
                xs.append(k)
                ys.append(v)
        graph.set_title(metric)
        graph.set_xlabel('$k$')
        graph.set_ylabel('similarity mean')
        graph.scatter(xs, ys)
        # hide redundant axes on the 2x2 grid (x on the top row, y on the right column)
        if i + 1 in [1, 2]:
            graph.axes.xaxis.set_visible(False)
        if i + 1 in [2, 4]:
            graph.axes.yaxis.set_visible(False)
    current_fig = plt.gcf()
    plt.show()
    # the figure contains all four metrics, so save it under a metric-independent name
    current_fig.savefig(
        os.path.join(get_repo_path(), '_evaluation', 'graphs', 'plot_k_graph.png'))
def plot_k_box():
    with open(os.path.join(get_repo_path(), '_evaluation', 'means.pkl'), 'rb') as f:
        means = pickle.load(f)
    for metric in ['cosine', 'dice', 'jaccard', 'overlap', 'raw']:
        # plt.title(metric)
        plt.xlabel('$k$')
        plt.ylabel('similarity mean')
        plt.boxplot(list(means[metric].values()))  # one box per k
        plt.xticks([1, 2, 3], [3, 5, 7])  # boxplot positions 1..3 are k = 3, 5, 7
        current_fig = plt.gcf()
        plt.show()
        current_fig.savefig(
            os.path.join(get_repo_path(), '_evaluation', 'graphs',
                         f'plot_k_{metric}_box.png'))
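# For reference: a minimal sketch (inferred from create_similarity_graphs below,
# which writes means.pkl) of the structure both plot functions above expect --
# one list of similarity means per (metric, k) pair, three comparison pairs
# (LIME-SHAP, LIME-builtin, SHAP-builtin) times four experiments:
#
#   means = {
#       'cosine': {
#           3: [0.42, 0.39, ...],  # 12 floats: 3 pairs x 4 experiments
#           5: [...],
#           7: [...],
#       },
#       'dice': {...}, 'jaccard': {...}, 'overlap': {...}, 'raw': {...},
#   }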
def get_fastbert_model(experiment):
    data_path = os.path.join(get_repo_path(), '_data', 'as_csv', experiment.name)
    databunch = BertDataBunch(data_path,
                              data_path,
                              tokenizer='bert-base-uncased',
                              train_file='train.csv',
                              val_file='val.csv',
                              label_file='labels.csv',
                              text_col='text',
                              label_col='label',
                              batch_size_per_gpu=8,
                              max_seq_length=512,
                              multi_gpu=True,
                              multi_label=False,
                              model_type='bert')
    fastbert = BertLearner.from_pretrained_model(
        databunch,
        pretrained_path=os.path.join(get_experiment_path(experiment), 'models',
                                     'fastbert'),
        metrics=[{'name': 'accuracy', 'function': accuracy}],
        device=torch.device("cuda"),
        logger=logging.getLogger(),
        output_dir='output')
    return fastbert
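# Minimal usage sketch (assumption: fast-bert's BertLearner exposes predict_batch,
# returning one list of (label, probability) pairs per input text):
#
#   learner = get_fastbert_model(Experiments.Train_Orig_Test_Orig)
#   predictions = learner.predict_batch(['some tweet', 'another tweet'])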
def _load_raw_test_tweets(self):
    path = os.path.join(get_repo_path(), '_experiments', self.experiment.name)
    with open(os.path.join(path, 'used_data', 'X_test_raw.pkl'), 'rb') as f:
        tweets = pickle.load(f)
    return tweets
def plot(self):
    a = pd.DataFrame(self.all_data)
    group_size = [len(a[a['of_id'].isnull()]), len(a[a['of_id'].notnull()])]
    # combined boolean masks instead of chained indexing, which would reindex
    # the second boolean key and misalign the selection
    subgroup_size = [
        len(a[(a['of_id'].isnull()) & (a['sexist'] == True)]),
        len(a[(a['of_id'].isnull()) & (a['sexist'] == False)]),
        len(a[a['of_id'].notnull()])
    ]
    group_names = [f'Originals\n{group_size[0]}',
                   f'Adversarial Examples\n{group_size[1]}']
    subgroup_names = [f'sexist\n{subgroup_size[0]}',
                      f'non-sexist\n{subgroup_size[1]}',
                      f'non-sexist\n{subgroup_size[2]}']

    # Create colors
    b, r, g, y = [plt.cm.Blues, plt.cm.Reds, plt.cm.Greens, plt.cm.YlOrBr]

    # First Ring (outside) -- setp narrows the wedges into a donut ring
    fig, ax = plt.subplots()
    ax.axis('equal')
    mypie, _ = ax.pie(group_size, radius=1.3, labels=group_names,
                      colors=[b(0.6), y(0.2)])
    plt.setp(mypie, width=0.3, edgecolor='white')

    # Second Ring (Inside)
    mypie2, _ = ax.pie(subgroup_size, radius=1.3 - 0.3, labels=subgroup_names,
                       labeldistance=0.7, colors=[r(0.5), g(0.4), g(0.4)])
    plt.setp(mypie2, width=0.4, edgecolor='white')
    plt.margins(0, 0)

    # show it
    current_fig = plt.gcf()
    plt.show()
    current_fig.savefig(
        os.path.join(get_repo_path(), '_evaluation', 'graphs', 'unsex_data.png'))
def _load(self, model_name):
    path = os.path.join(get_repo_path(), '_experiments', self.experiment.name,
                        'models', f'{model_name}.pkl')
    with open(path, 'rb') as f:
        if 'svm' in model_name:
            # latin1 lets Python 3 unpickle models presumably saved under Python 2
            model = pickle.load(f, encoding='latin1')
        else:
            model = pickle.load(f)
    return model
def plot_model_performance(f):
    # the scorer's __name__ determines the metric label
    metric = None
    if 'accuracy' in f.__name__:
        metric = 'Accuracy'
    elif 'f1' in f.__name__:
        metric = 'F1-Score'
    models = ['lr', 'svm', 'xgboost', 'fast-bert']
    experiment_names = ['TOTO', 'TMTO', 'TOTM', 'TMTM']
    totals = np.zeros(len(experiment_names))  # running sum across models
    for model in models:
        values = [f(model, experiment) for experiment in Experiments]
        totals = totals + np.array(values)
        plt.plot(experiment_names, values, label=model)
        for x, y in zip(experiment_names, values):
            plt.annotate(
                "{:.2f}".format(y),  # this is the text
                (x, y),  # this is the point to label
                textcoords="offset points",  # how to position the text
                xytext=(0, 10),  # distance from text to points (x,y)
                ha='center')  # horizontal alignment can be left, right or center
    plt.plot(experiment_names, totals / len(models), label='mean')
    plt.legend()
    plt.xlabel('$experiment$')
    plt.ylabel(metric)
    current_fig = plt.gcf()
    plt.show()
    current_fig.savefig(
        os.path.join(
            get_repo_path(), '_evaluation', 'graphs',
            f'model_performance_{metric.lower().replace("-", "_")}.png'))
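# Minimal usage sketch: pass a scorer named after the metric it computes, with a
# (model, experiment) signature. get_f1_score is used elsewhere in this repo; an
# accuracy counterpart is assumed to exist analogously.
#
#   plot_model_performance(get_f1_score)  # one line per model, plus a mean line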
def plot_f1_to_similarity():
    with open(
            os.path.join(get_repo_path(), '_evaluation',
                         'similarity_backup.pkl'), 'rb') as f:
        backup = pickle.load(f)
    xs = []
    ys = []
    for model in ['lr', 'svm', 'xgboost']:
        for experiment in Experiments:
            f1 = get_f1_score(model, experiment)
            # pool similarity values over all k and all coefficients
            all_values_lime_shap = []
            all_values_lime_builtin = []
            all_values_shap_builtin = []
            for k in [3, 5, 7]:
                for coef in ['cosine', 'dice', 'jaccard', 'overlap', 'raw']:
                    all_values_lime_shap.append(
                        backup[model][coef][k][experiment]['values_lime_shap'])
                    all_values_lime_builtin.append(
                        backup[model][coef][k][experiment]['values_lime_builtin'])
                    all_values_shap_builtin.append(
                        backup[model][coef][k][experiment]['values_shap_builtin'])
            for sim_values in [
                    all_values_lime_shap, all_values_lime_builtin,
                    all_values_shap_builtin
            ]:
                xs.append(f1)
                ys.append(np.nanmean(sim_values))
    pearson_pvalue = format(pearsonr(xs, ys)[1], '.2f')
    spearman_pvalue = format(spearmanr(xs, ys).pvalue, '.2f')
    plt.text(0.73, 0.12,
             f'Pearson: p={pearson_pvalue}\nSpearman: p={spearman_pvalue}',
             fontsize=10)
    plt.xlabel('F1-Score')
    plt.ylabel('Similarity Mean')
    plt.scatter(xs, ys)
    current_fig = plt.gcf()
    plt.show()
    current_fig.savefig(
        os.path.join(get_repo_path(), '_evaluation', 'graphs',
                     'plot_f1_to_similarity.png'))
def plot_source():
    values = [678, 678, 678, 1280, 764]
    labels = [f'{l}\n{v}' for l, v in zip(
        ['benevolent', 'hostile', 'other', 'callme', 'scales'], values)]
    plt.pie(values, labels=labels)

    # show it
    current_fig = plt.gcf()
    plt.show()
    current_fig.savefig(
        os.path.join(get_repo_path(), '_evaluation', 'graphs',
                     'unsex_data_source.png'))
def amount_of_explainable_tweets():
    labels = [e.name for e in Experiments]
    values = [get_amount_of_explainable_tweets(e) for e in Experiments]
    plt.barh(labels, values, color='green')
    current_fig = plt.gcf()
    plt.show()
    current_fig.savefig(
        os.path.join(get_repo_path(), '_evaluation', 'graphs',
                     'amount_of_explainable_tweets.png'))
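# get_amount_of_explainable_tweets is not defined in this file. A hypothetical
# sketch, assuming it counts the pickled explainable tweets that the other plot
# functions load from the same per-experiment path:
#
#   def get_amount_of_explainable_tweets(experiment, k=5):
#       with open(os.path.join(get_experiment_path(experiment),
#                              f'explainable_tweets_k{k}.pkl'), 'rb') as f:
#           return len(pickle.load(f))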
def plot_per_experiment():
    for experiment in Experiments:
        group_size = [3262, 816]
        group_names = [f'Training\n{group_size[0]}', f'Test\n{group_size[1]}']
        subgroup_size = None
        subgroup_names = None

        # Create colors
        b, y, p = [plt.cm.Blues, plt.cm.YlOrBr, plt.cm.Purples]

        # First Ring (outside)
        fig, ax = plt.subplots()
        ax.axis('equal')
        mypie, _ = ax.pie(group_size, radius=1.3, labels=group_names,
                          colors=[p(0.6), p(0.4)])
        plt.setp(mypie, width=0.3, edgecolor='white')

        if experiment == Experiments.Train_Orig_Test_Orig:
            # train: 3262 orig
            # test: 816 orig
            subgroup_size = [3262, 816]
            subgroup_names = [f'Originals\n{subgroup_size[0]}',
                              f'Originals\n{subgroup_size[1]}']
            # Second Ring (Inside)
            mypie2, _ = ax.pie(subgroup_size, radius=1.3 - 0.3,
                               labels=subgroup_names, labeldistance=0.7,
                               colors=[b(0.6), b(0.6)])
        elif experiment == Experiments.Train_Orig_Test_Mixed:
            # train: 3262 orig
            # test: 530 orig + 286 mods
            subgroup_size = [3262, 530, 286]
            subgroup_names = [f'Originals\n{subgroup_size[0]}',
                              f'Originals\n{subgroup_size[1]}',
                              f'Adversarial Examples\n{subgroup_size[2]}']
            mypie2, _ = ax.pie(subgroup_size, radius=1.3 - 0.3,
                               labels=subgroup_names, labeldistance=0.7,
                               colors=[b(0.6), b(0.6), y(0.2)])
        elif experiment == Experiments.Train_Mixed_Test_Orig:
            # train: 2120 orig + 1142 mods
            # test: 816 orig
            subgroup_size = [2120, 1142, 816]
            subgroup_names = [f'Originals\n{subgroup_size[0]}',
                              f'Adversarial Examples\n{subgroup_size[1]}',
                              f'Originals\n{subgroup_size[2]}']
            mypie2, _ = ax.pie(subgroup_size, radius=1.3 - 0.3,
                               labels=subgroup_names, labeldistance=0.7,
                               colors=[b(0.6), y(0.2), b(0.6)])
        elif experiment == Experiments.Train_Mixed_Test_Mixed:
            # train: 2120 orig + 1142 mods
            # test: 530 orig + 286 mods
            subgroup_size = [2120, 1142, 530, 286]
            subgroup_names = [f'Originals\n{subgroup_size[0]}',
                              f'Adversarial Examples\n{subgroup_size[1]}',
                              f'Originals\n{subgroup_size[2]}',
                              f'Adversarial Examples\n{subgroup_size[3]}']
            mypie2, _ = ax.pie(subgroup_size, radius=1.3 - 0.3,
                               labels=subgroup_names, labeldistance=0.7,
                               colors=[b(0.6), y(0.2), b(0.6), y(0.2)])

        plt.setp(mypie2, width=0.4, edgecolor='white')
        plt.margins(0, 0)
        plt.title(experiment.name, loc='left')
        current_fig = plt.gcf()
        plt.show()
        current_fig.savefig(
            os.path.join(get_repo_path(), '_evaluation', 'graphs',
                         f'{experiment.name}_data.png'))
def __init__(self, experiment):
    self.experiment = experiment
    self.REPO_PATH = get_repo_path()
    self.DATA_PATH = os.path.join(self.REPO_PATH, '_data', 'icwsm2020_data')
    self.ALL_DATA_FILE_PATH = os.path.join(self.DATA_PATH, 'all_data.csv')
    self.all_data = pd.read_csv(self.ALL_DATA_FILE_PATH, sep='\t')

    # load data in dataframe
    orig_mods_df = pd.DataFrame(self.all_data)
    orig_mods_df.dropna(axis=0, subset=['sexist'], inplace=True)  # drop NAs
    # originals carry no of_id; adversarial modifications reference their original
    orig_df = orig_mods_df[orig_mods_df['of_id'].isnull()]
    mods_df = orig_mods_df[orig_mods_df['of_id'].notnull()]

    MAX_DATA_SIZE = len(orig_df)

    # proportions
    TRAIN_SIZE = round(MAX_DATA_SIZE * 0.8)  # 3262
    TEST_SIZE = round(MAX_DATA_SIZE * 0.2)  # 816
    TRAIN_SPLIT_ORG = round(TRAIN_SIZE * 0.65)  # 2120
    TRAIN_SPLIT_MOD = TRAIN_SIZE - TRAIN_SPLIT_ORG  # 1142
    TEST_SPLIT_ORG = round(TEST_SIZE * 0.65)  # 530
    TEST_SPLIT_MOD = TEST_SIZE - TEST_SPLIT_ORG  # 286

    if experiment == Experiments.Train_Orig_Test_Orig:
        # train: 3262 orig
        # test: 816 orig
        self.X = orig_df['text'].values.tolist()
        y = pd.DataFrame(orig_df['sexist'])
        y['sexist'] = LabelEncoder().fit_transform(y['sexist'])
        self.y = y.to_numpy().ravel()
        self.X_train, self.X_test, self.y_train, self.y_test = train_test_split(
            self.X, self.y, test_size=0.20)
    elif experiment == Experiments.Train_Orig_Test_Mixed:
        # train: 3262 orig
        # test: 530 orig + 286 mods
        train_df = orig_df.sample(TRAIN_SIZE)
        # test set: originals not already sampled for training, plus modifications
        test_org_df = orig_df[~orig_df.isin(train_df)].dropna(how='all').sample(TEST_SPLIT_ORG)
        test_mod_df = mods_df.sample(TEST_SPLIT_MOD)
        test_df = pd.concat([test_org_df, test_mod_df], ignore_index=True)
        self.X_train, self.X_test, self.y_train, self.y_test = self.process(train_df, test_df)
    elif experiment == Experiments.Train_Mixed_Test_Orig:
        # train: 2120 orig + 1142 mods
        # test: 816 orig
        test_df = orig_df.sample(TEST_SIZE)
        # training set: remaining originals, plus modifications
        train_org_df = orig_df[~orig_df.isin(test_df)].dropna(how='all').sample(TRAIN_SPLIT_ORG)
        train_mod_df = mods_df.sample(TRAIN_SPLIT_MOD)
        train_df = pd.concat([train_org_df, train_mod_df], ignore_index=True)
        self.X_train, self.X_test, self.y_train, self.y_test = self.process(train_df, test_df)
    elif experiment == Experiments.Train_Mixed_Test_Mixed:
        # train: 2120 orig + 1142 mods
        # test: 530 orig + 286 mods
        train_org_df = orig_df.sample(TRAIN_SPLIT_ORG)
        train_mod_df = mods_df.sample(TRAIN_SPLIT_MOD)
        train_df = pd.concat([train_org_df, train_mod_df], ignore_index=True)
        # test set: originals/modifications not already sampled for training
        test_org_df = orig_df[~orig_df.isin(train_org_df)].dropna(how='all').sample(TEST_SPLIT_ORG)
        test_mod_df = mods_df[~mods_df.isin(train_mod_df)].dropna(how='all').sample(TEST_SPLIT_MOD)
        test_df = pd.concat([test_org_df, test_mod_df], ignore_index=True)
        self.X_train, self.X_test, self.y_train, self.y_test = self.process(train_df, test_df)
    else:
        raise ValueError('Unknown Experiment')
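# self.process is not shown in this file. A hypothetical sketch, assuming it
# mirrors the Train_Orig_Test_Orig branch (text column in, label-encoded
# 'sexist' column out) for a pre-made train/test split:
#
#   def process(self, train_df, test_df):
#       X_train = train_df['text'].values.tolist()
#       X_test = test_df['text'].values.tolist()
#       encoder = LabelEncoder().fit(train_df['sexist'])
#       y_train = encoder.transform(train_df['sexist'])
#       y_test = encoder.transform(test_df['sexist'])
#       return X_train, X_test, y_train, y_test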
def which_datasets_are_explainable(k=5):
    labels = ['TOTO', 'TMTO', 'TOTM', 'TMTM']
    explainable_b = np.array([])
    explainable_h = np.array([])
    explainable_o = np.array([])
    explainable_c = np.array([])
    explainable_s = np.array([])
    for experiment in tqdm(Experiments):
        with open(
                os.path.join(get_experiment_path(experiment),
                             f'explainable_tweets_k{k}.pkl'), 'rb') as f:
            explainable = pickle.load(f)
        # classify each explainable tweet once, then count per source dataset,
        # normalizing by this experiment's own number of explainable tweets
        datasets = [_get_dataset_of_tweet(e) for e in tqdm(explainable)]
        n = len(explainable)
        explainable_b = np.append(explainable_b, datasets.count('benevolent') / n)
        explainable_h = np.append(explainable_h, datasets.count('hostile') / n)
        explainable_o = np.append(explainable_o, datasets.count('other') / n)
        explainable_c = np.append(explainable_c, datasets.count('callme') / n)
        explainable_s = np.append(explainable_s, datasets.count('scales') / n)

    c1, c2, c3, c4, c5 = plt.cm.Set1.colors[:5]
    # stacked bars: draw from the full bar (height 1) down to the innermost stack
    plt.bar(labels, [1] * len(explainable_b), color=c1, label='benevolent')
    plt.bar(labels, explainable_h + explainable_o + explainable_c + explainable_s,
            color=c2, label='hostile')
    plt.bar(labels, explainable_o + explainable_c + explainable_s,
            color=c3, label='other')
    plt.bar(labels, explainable_c + explainable_s, color=c4, label='callme')
    plt.bar(labels, explainable_s, color=c5, label='scales')
    plt.ylabel('Proportion of explainable tweets')
    plt.title('Which datasets are explainable?')
    plt.legend()
    current_fig = plt.gcf()
    plt.show()
    current_fig.savefig(
        os.path.join(get_repo_path(), '_evaluation', 'graphs',
                     'which_datasets_are_explainable.png'))
def which_length_is_explainable(k=5):
    fig = plt.figure()
    for i, experiment in enumerate(Experiments):
        graph = fig.add_subplot(2, 2, i + 1)
        with open(
                os.path.join(get_experiment_path(experiment), 'used_data',
                             'X_test_raw.pkl'), 'rb') as f:
            all_raw = pickle.load(f)
        with open(
                os.path.join(get_experiment_path(experiment), 'used_data',
                             'X_test.pkl'), 'rb') as f:
            all_tokens = pickle.load(f)
        with open(
                os.path.join(get_experiment_path(experiment),
                             f'explainable_tweets_k{k}.pkl'), 'rb') as f:
            explainable = pickle.load(f)
        explainable_raw = [et.raw for et in explainable]
        explainable_tokens = [et.tokens for et in explainable]
        not_explainable_raw = np.setdiff1d(all_raw, explainable_raw)
        not_explainable_tokens = np.setdiff1d(all_tokens, explainable_tokens)

        result = {}
        result['all_raw'] = {'amount': len(all_raw)}
        result['explainable'] = {
            'amount': len(explainable_raw),
            'min_tokens': min(len(t.split()) for t in explainable_tokens),
            'max_tokens': max(len(t.split()) for t in explainable_tokens),
            'min_raw_length': min(len(t) for t in explainable_raw),
            'max_raw_length': max(len(t) for t in explainable_raw),
        }
        # empirical CDF of tweet length (in tokens) for explainable tweets
        xs = np.array(range(len(explainable_tokens))) / len(explainable_tokens)
        exp_ys = sorted(len(t.split()) for t in explainable_tokens)
        graph.plot(xs, exp_ys, label='explainable')

        result['not_explainable'] = {
            'amount': len(not_explainable_raw),
            'min_tokens': min(len(t.split()) for t in not_explainable_tokens),
            'max_tokens': max(len(t.split()) for t in not_explainable_tokens),
            'min_raw_length': min(len(t) for t in not_explainable_raw),
            'max_raw_length': max(len(t) for t in not_explainable_raw),
        }
        xs = np.array(range(
            len(not_explainable_tokens))) / len(not_explainable_tokens)
        unexp_ys = sorted(len(t.split()) for t in not_explainable_tokens)

        ttest_p = round(ttest_ind(exp_ys, unexp_ys).pvalue, 4)
        mwu_p = round(mannwhitneyu(exp_ys, unexp_ys).pvalue, 4)
        print(f"\n{experiment.name} T-Test, P-Value: ", ttest_p)
        print(f"{experiment.name} Mann-Whitney U Test: P-Value: ", mwu_p)

        graph.set_xlabel('$tweets$')
        graph.set_ylabel(r'$number\_of\_tokens$')
        graph.text(0.48, 1.5, f'T-Test: p={ttest_p}\nMWU-Test: p={mwu_p}',
                   fontsize=8)
        if i + 1 in [1, 2]:
            graph.axes.xaxis.set_visible(False)
        if i + 1 in [2, 4]:
            graph.axes.yaxis.set_visible(False)
        graph.plot(xs, unexp_ys, label='unexplainable')
        graph.set_title(experiment.name, fontdict={'fontsize': 10})
    plt.legend()
    current_fig = plt.gcf()
    plt.show()
    current_fig.savefig(
        os.path.join(get_repo_path(), '_evaluation', 'graphs',
                     'which_length_is_explainable.png'))
def _load_labels(self):
    path = os.path.join(get_repo_path(), '_experiments', self.experiment.name)
    with open(os.path.join(path, 'used_data', 'y_test.pkl'), 'rb') as f:
        labels = pickle.load(f)
    return labels
import logging
import os
import pickle

import numpy as np
import pandas as pd
import spacy
from sklearn.model_selection import StratifiedKFold
from xgboost.sklearn import XGBClassifier
# from fastbert import FastBERT

from _utils.fastbert import get_fastbert_model
from _utils.pathfinder import get_repo_path, get_experiment_path

nlp = spacy.load("en_core_web_sm")

REPO_DIR = get_repo_path()
MODELS_DIR = os.path.join(REPO_DIR, '_trained_models')
REPORTS_DIR = os.path.join(REPO_DIR, '_classification_reports')


def train(model,
          X_train,
          X_test,
          y_train,
          y_test,
          save_as=None,
          report=None,
          experiment=None):
    """
    Args:
        model (Model object): model to train
def create_similarity_graphs(models=None, metrics=None, ks=None):
    if models is None:
        models = ['lr', 'svm', 'xgboost']
    if metrics is None:
        metrics = ['cosine', 'dice', 'jaccard', 'raw', 'overlap']
    if ks is None:
        ks = [3, 5, 7]
    ex_methods = ['lime', 'shap', 'builtin']
    means = {}
    backup = {}
    for model in models:
        backup[model] = {}
        for metric in metrics:
            # NOTE: means is re-initialized per model, so the pickled means
            # reflect only the last model in `models`
            means[metric] = {}
            backup[model][metric] = {}
            for k in ks:
                backup[model][metric][k] = {}
                fig = plt.figure()
                means[metric][k] = []
                for i, experiment in enumerate(Experiments):
                    exps = {}
                    backup[model][metric][k][experiment] = {}
                    graph = fig.add_subplot(2, 2, i + 1)
                    tweet_loader = TweetLoader(experiment)
                    explanation_loader = ExplanationLoader(
                        experiment, tweet_loader=tweet_loader)
                    explainable_tweets = explanation_loader.get_explainable_tweets(k)
                    # collect the top-k explanation terms per explanation method
                    for ex_method in ex_methods:
                        exps[ex_method] = []
                        for explainable_tweet in explainable_tweets:
                            exps[ex_method].append(
                                explainable_tweet.explanations[model][ex_method])
                        exps[ex_method] = np.array(
                            list(map(lambda x: ' '.join(x), exps[ex_method])))
                        exps[ex_method] = transform(model, experiment,
                                                    exps[ex_method])

                    # calculate metric
                    values_lime_shap, values_lime_builtin, values_shap_builtin = \
                        Metrics.my_coefficient(exps, METRIC_FUNCTIONS[metric])

                    # backup
                    backup[model][metric][k][experiment]['values_lime_shap'] = values_lime_shap
                    backup[model][metric][k][experiment]['values_lime_builtin'] = values_lime_builtin
                    backup[model][metric][k][experiment]['values_shap_builtin'] = values_shap_builtin

                    graph.set_xlabel(rf'${metric}\_coefficient$')
                    graph.set_ylabel('$p$')
                    if i + 1 in [1, 2]:
                        graph.axes.xaxis.set_visible(False)
                    if i + 1 in [2, 4]:
                        graph.axes.yaxis.set_visible(False)

                    # plot the CDF
                    y = np.array(range(len(values_lime_shap))) / float(
                        len(values_lime_shap))
                    graph.plot(values_lime_shap, y, color='blue',
                               label='LIME - SHAP')
                    graph.plot(values_lime_builtin, y, color='red',
                               label='LIME - builtin')
                    graph.plot(values_shap_builtin, y, color='green',
                               label='SHAP - builtin')

                    f1 = get_f1_score(model, experiment)
                    graph.set_title(f"{experiment.name} (F1: {f1})",
                                    fontdict={'fontsize': 10})
                    means[metric][k].append(
                        np.nanmean(values_lime_shap, dtype=np.float64))
                    means[metric][k].append(
                        np.nanmean(values_lime_builtin, dtype=np.float64))
                    means[metric][k].append(
                        np.nanmean(values_shap_builtin, dtype=np.float64))

                # fig.suptitle(f'{t(model)} (k={k}, measure={metric})', size=15)
                plt.legend()
                current_fig = plt.gcf()
                # plt.show()
                current_fig.savefig(
                    os.path.join(get_repo_path(), '_evaluation', 'graphs', metric,
                                 f'{model}_k{k}_{metric}_sim_cdf.png'))

    with open(os.path.join(get_repo_path(), '_evaluation', 'means.pkl'), 'wb') as f:
        pickle.dump(means, f)
    with open(os.path.join(get_repo_path(), '_evaluation',
                           'similarity_backup.pkl'), 'wb') as f:
        pickle.dump(backup, f)
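# METRIC_FUNCTIONS is defined elsewhere. A hypothetical sketch of such a mapping,
# assuming set-based similarity coefficients over two top-k token sets a and b
# ('raw' is omitted, since its definition is not recoverable from this code):
#
#   import math
#
#   METRIC_FUNCTIONS = {
#       'cosine': lambda a, b: len(a & b) / math.sqrt(len(a) * len(b)),  # Otsuka-Ochiai
#       'dice': lambda a, b: 2 * len(a & b) / (len(a) + len(b)),
#       'jaccard': lambda a, b: len(a & b) / len(a | b),
#       'overlap': lambda a, b: len(a & b) / min(len(a), len(b)),
#   }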
print("-" * 20, "SHAP eval begin", "-" * 20)
total_inst_num = all_input_ids.size()[0]
for i in range(total_inst_num):
    # one-hot encode the token ids so SHAP can attribute individual vocabulary entries
    eval_one_hot = id2onehot(all_input_ids[i:i + 1], model.config.vocab_size)
    shap_value = explainer.shap_values(X=eval_one_hot,
                                       eval_mask=all_input_mask[i:i + 1],
                                       seg=all_segment_ids[i:i + 1],
                                       tk_idx=all_input_ids[i:i + 1],
                                       ranked_outputs=None)
    values = []
    for lb in range(num_labels):
        # per label, keep only the SHAP value of the token actually present at
        # each position (instead of the full vocabulary dimension)
        tks = all_input_ids[i:i + 1]
        seq_len = tks.size()[1]
        right_value = shap_value[lb][0, torch.arange(0, seq_len).long(), tks[0, :]]
        values.append(right_value)
    shap_values.append(values)

with open(os.path.join(output_dir, 'fast-bert-shap.npy'), 'wb') as f:
    np.save(f, shap_values)


if __name__ == '__main__':
    compute_shap_for_fastbert(
        os.path.join(get_repo_path(), '_data', 'as_csv'),
        os.path.join(get_repo_path(), '_trained_models', 'fast-bert'),
        os.path.join(get_repo_path(), '_explanations'),
        no_cuda=True)
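# id2onehot is defined elsewhere. A plausible sketch (hypothetical), matching the
# (1, seq_len, vocab_size) shape that the indexing above expects:
#
#   def id2onehot(ids, vocab_size):
#       # ids: LongTensor of shape (1, seq_len)
#       return torch.nn.functional.one_hot(ids, num_classes=vocab_size).float()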