# Standard-library and third-party imports used throughout this excerpt.
# Repo-local helpers (get_data_and_args, get_model, transform, train_logreg,
# get_top_k_neuron_weights, plot_logits, plot_weight_contribs_and_save, ...)
# are assumed to be imported from elsewhere in the repo.
import os
import sys
import time
import random
import pickle as pkl

import numpy as np
import torch
from tqdm import tqdm


def main():
    (train_data, val_data, test_data), tokenizer, args = get_data_and_args()
    model = get_model(args)

    # Derive an output directory from the loaded checkpoint path.
    save_root = '' if args.load is None else args.load
    save_root = save_root.replace('.current', '')
    save_root = os.path.splitext(save_root)[0]
    save_root += '_transfer'
    save_root = os.path.join(save_root, args.save_results)
    if not os.path.exists(save_root):
        os.makedirs(save_root)
    print('writing results to ' + save_root)

    # featurize train, val, test or use previously cached features if possible
    print('transforming train')
    if not (os.path.exists(os.path.join(save_root, 'trXt.npy')) and args.use_cached):
        trXt, trY = transform(model, train_data, args)
        np.save(os.path.join(save_root, 'trXt'), trXt)
        np.save(os.path.join(save_root, 'trY'), trY)
    else:
        trXt = np.load(os.path.join(save_root, 'trXt.npy'))
        trY = np.load(os.path.join(save_root, 'trY.npy'))

    vaXt, vaY = None, None
    if val_data is not None:
        print('transforming validation')
        if not (os.path.exists(os.path.join(save_root, 'vaXt.npy')) and args.use_cached):
            vaXt, vaY = transform(model, val_data, args)
            np.save(os.path.join(save_root, 'vaXt'), vaXt)
            np.save(os.path.join(save_root, 'vaY'), vaY)
        else:
            vaXt = np.load(os.path.join(save_root, 'vaXt.npy'))
            vaY = np.load(os.path.join(save_root, 'vaY.npy'))

    teXt, teY = None, None
    if test_data is not None:
        print('transforming test')
        if not (os.path.exists(os.path.join(save_root, 'teXt.npy')) and args.use_cached):
            teXt, teY = transform(model, test_data, args)
            np.save(os.path.join(save_root, 'teXt'), teXt)
            np.save(os.path.join(save_root, 'teY'), teY)
        else:
            teXt = np.load(os.path.join(save_root, 'teXt.npy'))
            teY = np.load(os.path.join(save_root, 'teY.npy'))

    # train logistic regression model of featurized text against labels
    start = time.time()
    metric = 'mcc' if args.mcc else 'acc'
    logreg_model, logreg_scores, logreg_probs, c, nnotzero = train_logreg(
        trXt, trY, vaXt, vaY, teXt, teY, max_iter=args.epochs,
        eval_test=not args.no_test_eval, seed=args.seed,
        report_metric=metric, threshold_metric=metric)
    end = time.time()
    elapsed_time = end - start

    with open(os.path.join(save_root, 'all_neurons_score.txt'), 'w') as f:
        f.write(str(logreg_scores))
    with open(os.path.join(save_root, 'all_neurons_probs.pkl'), 'wb') as f:
        pkl.dump(logreg_probs, f)
    with open(os.path.join(save_root, 'neurons.pkl'), 'wb') as f:
        pkl.dump(logreg_model.coef_, f)

    print('all neuron regression took %s seconds' % (str(elapsed_time)))
    print(', '.join([str(score) for score in logreg_scores]),
          'train, val, test accuracy for all neuron regression')
    print(str(c) + ' regularization coefficient used')
    print(str(nnotzero) + ' features used in all neuron regression\n')

    # save a sentiment classification pytorch model
    sd = {}
    if not args.fp16:
        clf_sd = {
            'weight': torch.from_numpy(logreg_model.coef_).float(),
            'bias': torch.from_numpy(logreg_model.intercept_).float()
        }
    else:
        clf_sd = {
            'weight': torch.from_numpy(logreg_model.coef_).half(),
            'bias': torch.from_numpy(logreg_model.intercept_).half()
        }
    sd['classifier'] = clf_sd
    model.float().cpu()
    sd['lm_encoder'] = model.state_dict()
    with open(os.path.join(save_root, 'classifier.pt'), 'wb') as f:
        torch.save(sd, f)
    model.half()
    sd['lm_encoder'] = model.state_dict()
    with open(os.path.join(save_root, 'classifier.pt.16'), 'wb') as f:
        torch.save(sd, f)

    # extract sentiment neuron indices
    sentiment_neurons = get_top_k_neuron_weights(logreg_model, args.neurons)
    print('using neuron(s) %s as features for regression' % (', '.join(
        [str(neuron) for neuron in list(sentiment_neurons.reshape(-1))])))

    # train logistic regression model of features corresponding to sentiment
    # neuron indices against labels
    start = time.time()
    logreg_neuron_model, logreg_neuron_scores, logreg_neuron_probs, neuron_c, neuron_nnotzero = train_logreg(
        trXt, trY, vaXt, vaY, teXt, teY, max_iter=args.epochs,
        eval_test=not args.no_test_eval, seed=args.seed,
        neurons=sentiment_neurons, drop_neurons=args.drop_neurons,
        report_metric=metric, threshold_metric=metric)
    end = time.time()

    if args.drop_neurons:
        with open(os.path.join(save_root, 'dropped_neurons_score.txt'), 'w') as f:
            f.write(str(logreg_neuron_scores))
        with open(os.path.join(save_root, 'dropped_neurons_probs.pkl'), 'wb') as f:
            pkl.dump(logreg_neuron_probs, f)

        print('%d dropped neuron regression took %s seconds' % (args.neurons, str(end - start)))
        print(', '.join([str(score) for score in logreg_neuron_scores]),
              'train, val, test accuracy for %d dropped neuron regression' % (args.neurons))
        print(str(neuron_c) + ' regularization coefficient used')

        # retrain on just the selected neurons (rather than dropping them)
        start = time.time()
        logreg_neuron_model, logreg_neuron_scores, logreg_neuron_probs, neuron_c, neuron_nnotzero = train_logreg(
            trXt, trY, vaXt, vaY, teXt, teY, max_iter=args.epochs,
            eval_test=not args.no_test_eval, seed=args.seed,
            neurons=sentiment_neurons, report_metric=metric,
            threshold_metric=metric)
        end = time.time()

    print('%d neuron regression took %s seconds' % (args.neurons, str(end - start)))
    print(', '.join([str(score) for score in logreg_neuron_scores]),
          'train, val, test accuracy for %d neuron regression' % (args.neurons))
    print(str(neuron_c) + ' regularization coefficient used')

    # log model accuracies, predicted probabilities, and weight/bias of regression model
    with open(os.path.join(save_root, 'all_neurons_score.txt'), 'w') as f:
        f.write(str(logreg_scores))
    with open(os.path.join(save_root, 'neurons_score.txt'), 'w') as f:
        f.write(str(logreg_neuron_scores))
    with open(os.path.join(save_root, 'all_neurons_probs.pkl'), 'wb') as f:
        pkl.dump(logreg_probs, f)
    with open(os.path.join(save_root, 'neurons_probs.pkl'), 'wb') as f:
        pkl.dump(logreg_neuron_probs, f)
    with open(os.path.join(save_root, 'neurons.pkl'), 'wb') as f:
        pkl.dump(logreg_model.coef_, f)
    with open(os.path.join(save_root, 'neuron_bias.pkl'), 'wb') as f:
        pkl.dump(logreg_model.intercept_, f)

    # plot feats
    use_feats, use_labels = teXt, teY
    if use_feats is None:
        use_feats, use_labels = vaXt, vaY
    if use_feats is None:
        use_feats, use_labels = trXt, trY
    try:
        plot_logits(save_root, use_feats, use_labels, sentiment_neurons)
    except Exception:
        print('no labels to plot logits for')

    plot_weight_contribs_and_save(logreg_model.coef_,
                                  os.path.join(save_root, 'weight_vis.png'))

    print('results successfully written to ' + save_root)

    if args.write_results == '':
        exit()

    def get_csv_writer(feats, top_neurons, all_proba, neuron_proba):
        """makes a generator to be used in data_utils.datasets.csv_dataset.write()"""
        header = ['prob w/ all', 'prob w/ %d neuron(s)' % (len(top_neurons),)]
        top_feats = feats[:, top_neurons]
        header += ['neuron %s' % (str(x),) for x in top_neurons]
        yield header

        for i, _ in enumerate(top_feats):
            row = []
            row.append(all_proba[i])
            row.append(neuron_proba[i])
            row.extend(list(top_feats[i].reshape(-1)))
            yield row

    data, use_feats = test_data, teXt
    if use_feats is None:
        data, use_feats = val_data, vaXt
    if use_feats is None:
        data, use_feats = train_data, trXt
    csv_writer = get_csv_writer(use_feats, sentiment_neurons,
                                logreg_probs[-1], logreg_neuron_probs[-1])
    data.dataset.write(csv_writer, path=args.write_results)
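
# `get_top_k_neuron_weights` used above is defined elsewhere in the repo. As a
# rough, hypothetical sketch (an assumption, not necessarily the repo's actual
# implementation), it could rank the features of the fitted scikit-learn
# logistic regression by the L1 norm of their weights and return the indices
# of the k largest, e.g.:
#
#     def get_top_k_neuron_weights(logreg_model, k=1):
#         weights = logreg_model.coef_.T                    # (n_features, n_classes)
#         penalties = np.squeeze(np.linalg.norm(weights, ord=1, axis=1))
#         return np.argsort(penalties)[-k:][::-1]           # indices, largest first
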
def main():
    (train_data, val_data, test_data), tokenizer, args = get_data_and_args()
    # Print args for logging & reproduction; we need the full configuration,
    # including default args.
    if test_data is None:
        test_data = val_data
    model, optim, LR = get_model_and_optim(args, train_data)

    # save_root = '' if args.load is None else args.load
    # save_root = save_root.replace('.current', '')
    # save_root = os.path.splitext(save_root)[0]
    # save_root += '_transfer'
    save_root = os.path.join('', args.model_version_name)
    if not os.path.exists(save_root):
        os.makedirs(save_root)
    print('writing results to ' + save_root)

    def clf_reg_loss(reg_penalty=.125, order=1):
        loss = 0
        for p in model.classifier.parameters():
            loss += torch.abs(p).sum() * reg_penalty
        return loss

    reg_loss = clf_reg_loss
    init_params = list(model.lm_encoder.parameters())

    if args.use_logreg:
        def transform_for_logreg(model, data, args, desc='train'):
            if data is None:
                return None, None

            X_out = []
            Y_out = []
            for i, batch in tqdm(enumerate(data), total=len(data), unit="batch",
                                 desc=desc, position=0, ncols=100):
                text_batch, labels_batch, length_batch = get_supervised_batch(
                    batch, args.cuda, model, args.max_seq_len, args,
                    heads_per_class=args.heads_per_class)
                # if args.non_binary_cols:
                #     labels_batch = labels_batch[:, 0] - labels_batch[:, 1] + 1
                _, (_, state) = transform(model, text_batch, labels_batch,
                                          length_batch, args)
                X_out.append(state.cpu().numpy())
                Y_out.append(labels_batch.cpu().numpy())
            X_out = np.concatenate(X_out)
            Y_out = np.concatenate(Y_out)
            return X_out, Y_out

        model.eval()
        trX, trY = transform_for_logreg(model, train_data, args, desc='train')
        vaX, vaY = transform_for_logreg(model, val_data, args, desc='val')
        teX, teY = transform_for_logreg(model, test_data, args, desc='test')

        logreg_model, logreg_scores, logreg_preds, c, nnotzero = train_logreg(
            trX, trY, vaX, vaY, teX, teY, eval_test=not args.no_test_eval,
            report_metric=args.report_metric,
            threshold_metric=args.threshold_metric,
            automatic_thresholding=args.automatic_thresholding,
            micro=args.micro)

        print(', '.join([str(score) for score in logreg_scores]),
              'train, val, test accuracy for all neuron regression')
        print(str(c) + ' regularization coefficient used')
        print(str(nnotzero) + ' features used in all neuron regression\n')
    else:
        best_vaY = 0
        vaT = []  # Current "best thresholds" so we can get reasonable estimates on training set

        for e in tqdm(range(args.epochs), unit="epoch", desc="epochs",
                      position=0, ncols=100):
            if args.use_softmax:
                vaT = []
            save_outputs = False
            report_metrics = (['jacc', 'acc', 'mcc', 'f1', 'recall', 'precision', 'var']
                              if args.all_metrics else [args.report_metric])

            print_str = ""
            trXt, trY, trC, _ = finetune(
                model, train_data, args, val_data=val_data, LR=LR,
                reg_loss=reg_loss, tqdm_desc='train',
                heads_per_class=args.heads_per_class,
                last_thresholds=vaT, threshold_validation=False)
            data_str_base = "Train Loss: {:4.2f} Train {:5s} (All): {:5.2f}, Train Class {:5s}: {}"
            for idx, m in enumerate(report_metrics):
                data_str = data_str_base.format(trXt, m, trY[idx] * 100, m, trC[idx])
                print_str += data_str + " " * max(0, 110 - len(data_str)) + "\n"

            vaXt, vaY = None, None
            if val_data is not None:
                vaXt, vaY, vaC, vaT = finetune(
                    model, val_data, args, tqdm_desc='val',
                    heads_per_class=args.heads_per_class, last_thresholds=vaT)
                # Use the command-line metric to decide which epoch performed best.
                # NOTE: F1, MCC, Jaccard are good measures. Accuracy is not -- since so skewed.
                selection_metric = ['jacc', 'acc', 'mcc', 'f1', 'recall',
                                    'precision', 'var'].index(args.threshold_metric)
                avg_Y = vaY[selection_metric]
                tqdm.write('avg ' + args.threshold_metric + ' metric ' + str(avg_Y))
                if avg_Y > best_vaY:
                    save_outputs = True
                    best_vaY = avg_Y
                elif avg_Y == best_vaY and random.random() > 0.5:
                    save_outputs = True
                    best_vaY = avg_Y

                data_str_base = "Val Loss: {:4.2f} Val {:5s} (All): {:5.2f}, Val Class {:5s}: {}"
                for idx, m in enumerate(report_metrics):
                    data_str = data_str_base.format(vaXt, m, vaY[idx] * 100, m, vaC[idx])
                    print_str += data_str + " " * max(0, 110 - len(data_str)) + "\n"

            tqdm.write(print_str[:-1])

            teXt, teY = None, None
            if test_data is not None:
                # Hardcode -- enable to always save outputs [regardless of metrics]
                # save_outputs = True
                if save_outputs:
                    tqdm.write('performing test eval')
                    try:
                        with torch.no_grad():
                            if not args.no_test_eval:
                                auto_thresholds = None
                                dual_thresholds = None
                                # NOTE -- we manually threshold to F1 [not necessarily good]
                                V_pred, V_label, V_std = generate_outputs(
                                    model, val_data, args)
                                if args.automatic_thresholding:
                                    if args.dual_thresh:
                                        # get dual threshold (do not call auto thresholds)
                                        # TODO: Handle multiple heads per class
                                        _, dual_thresholds = _neutral_threshold_two_output(
                                            V_pred.cpu().numpy(), V_label.cpu().numpy())
                                        model.set_thresholds(
                                            dual_thresholds,
                                            dual_threshold=args.dual_thresh and not args.joint_binary_train)
                                    else:
                                        # Use args.threshold_metric to choose which category to
                                        # threshold on. F1 and Jaccard are good options.
                                        # NOTE: For multiple heads per class, can threshold each
                                        # head (default) or single threshold. Little difference
                                        # once model converges.
                                        auto_thresholds = vaT
                                        # _, auto_thresholds, _, _ = _binary_threshold(
                                        #     V_pred.view(-1, int(model.out_dim / args.heads_per_class)).contiguous(),
                                        #     V_label.view(-1, int(model.out_dim / args.heads_per_class)).contiguous(),
                                        #     args.threshold_metric, args.micro,
                                        #     global_tweaks=args.global_tweaks)
                                        model.set_thresholds(auto_thresholds, args.double_thresh)
                                T_pred, T_label, T_std = generate_outputs(
                                    model, test_data, args, auto_thresholds)
                                if not args.use_softmax and int(model.out_dim / args.heads_per_class) > 1:
                                    keys = list(args.non_binary_cols)
                                    if args.dual_thresh:
                                        if len(keys) == len(dual_thresholds):
                                            tqdm.write('Dual thresholds: %s'
                                                       % str(list(zip(keys, dual_thresholds))))
                                        keys += ['neutral']
                                    else:
                                        tqdm.write('Class thresholds: %s'
                                                   % str(list(zip(keys, auto_thresholds))))
                                elif args.use_softmax:
                                    keys = [str(m) for m in range(model.out_dim)]
                                else:
                                    tqdm.write('Class threshold: %s'
                                               % str([args.label_key, auto_thresholds[0]]))
                                    keys = ['']
                                info_dicts = [{'fp': 0, 'tp': 0, 'fn': 0, 'tn': 0,
                                               'std': 0., 'metric': args.report_metric,
                                               'micro': True} for k in keys]
                                # perform dual threshold here, adding the neutral labels to
                                # T_label, thresholding existing predictions and adding
                                # neutral preds to T_pred
                                if args.dual_thresh:
                                    if dual_thresholds is None:
                                        dual_thresholds = [.5, .5]

                                    def make_onehot_w_neutral(label):
                                        rtn = [0] * 3
                                        rtn[label] = 1
                                        return rtn

                                    def get_label(pos_neg):
                                        thresholded = [pos_neg[0] >= dual_thresholds[0],
                                                       pos_neg[1] >= dual_thresholds[1]]
                                        if thresholded[0] == thresholded[1]:
                                            return 2
                                        return thresholded.index(1)

                                    def get_new_std(std):
                                        return std[0], std[1], (std[0] + std[1]) / 2

                                    new_labels = []
                                    new_preds = []
                                    T_std = torch.cat([T_std[:, :2],
                                                       T_std[:, :2].mean(-1).view(-1, 1)],
                                                      -1).cpu().numpy()
                                    for j, lab in enumerate(T_label):
                                        pred = T_pred[j]
                                        new_preds.append(make_onehot_w_neutral(get_label(pred)))
                                        new_labels.append(make_onehot_w_neutral(get_label(lab)))
                                    T_pred = np.array(new_preds)
                                    T_label = np.array(new_labels)
                                # HACK: If dual threshold, hardcoded -- assume positive,
                                # negative and neutral -- in that order. It's ok to train with
                                # other categories (after positive, neutral) as auxiliary loss
                                # -- but they won't be scored at test time.
                                if args.dual_thresh and args.joint_binary_train:
                                    keys = ['positive', 'negative', 'neutral']
                                    info_dicts = [{'fp': 0, 'tp': 0, 'fn': 0, 'tn': 0,
                                                   'std': 0., 'metric': args.report_metric,
                                                   'micro': True} for k in keys]
                                for j, k in enumerate(keys):
                                    update_info_dict(info_dicts[j], T_pred[:, j],
                                                     T_label[:, j], std=T_std[:, j])
                                total_metrics, metric_strings = get_metric_report(
                                    info_dicts, args, keys)
                                test_str = ''
                                test_str_base = "Test {:5s} (micro): {:5.2f}, Test Class {:5s}: {}"
                                for idx, m in enumerate(report_metrics):
                                    data_str = test_str_base.format(
                                        m, total_metrics[idx] * 100, m, metric_strings[idx])
                                    test_str += data_str + " " * max(0, 110 - len(data_str)) + "\n"
                                tqdm.write(test_str[:-1])
                                # tqdm.write(str(total_metrics))
                                # tqdm.write('; '.join(metric_strings))
                            else:
                                V_pred, V_label, V_std = generate_outputs(
                                    model, val_data, args)
                                T_pred, T_label, T_std = generate_outputs(
                                    model, test_data, args)
                            val_path = os.path.join(save_root, 'val_results.txt')
                            tqdm.write('Saving validation prediction results of size %s to %s'
                                       % (str(V_pred.shape[:]), val_path))
                            write_results(V_pred, V_label, val_path)
                            test_path = os.path.join(save_root, 'test_results.txt')
                            tqdm.write('Saving test prediction results of size %s to %s'
                                       % (str(T_pred.shape[:]), test_path))
                            write_results(T_pred, T_label, test_path)
                    except KeyboardInterrupt:
                        pass
                else:
                    pass

            # Save the model, upon request
            if args.save_finetune and save_outputs:
                # Save model if best so far. Note epoch number, and also keys
                # [what it is predicting], as well as an optional version number.
                # TODO: Add key string to handle multiple runs?
                if args.non_binary_cols:
                    keys = list(args.non_binary_cols)
                else:
                    keys = [args.label_key]

                # Also save args
                args_save_path = os.path.join(save_root, 'args.txt')
                tqdm.write('Saving commandline to %s' % args_save_path)
                with open(args_save_path, 'w') as f:
                    f.write(' '.join(sys.argv[1:]))

                # Save and add thresholds to arguments for easy reloading of model config
                if not args.no_test_eval and args.automatic_thresholding:
                    thresh_save_path = os.path.join(
                        save_root, 'thresh' + '_ep' + str(e) + '.npy')
                    tqdm.write('Saving thresh to %s' % thresh_save_path)
                    if args.dual_thresh:
                        np.save(thresh_save_path, list(zip(keys, dual_thresholds)))
                        args.thresholds = list(zip(keys, dual_thresholds))
                    else:
                        np.save(thresh_save_path, list(zip(keys, auto_thresholds)))
                        args.thresholds = list(zip(keys, auto_thresholds))
                else:
                    args.thresholds = None
                args.classes = keys

                # save full model with args to restore
                clf_save_path = os.path.join(save_root, 'model' + '_ep' + str(e) + '.clf')
                tqdm.write('Saving full classifier to %s' % clf_save_path)
                torch.save({'sd': model.state_dict(), 'args': args}, clf_save_path)
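
# Assumed entry point (not shown in the excerpt above); scripts like these are
# typically invoked directly from the command line.
if __name__ == '__main__':
    main()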