def _binary_threshold(preds, labels, metric, micro, global_tweaks=1000, debug=False, heads_per_class=1, class_single_threshold=False):
    avg_metric = 0
    best_thresholds = []
    info_dicts = []
    category_metrics = []
    # Compute threshold per class... *unless* multiple heads per class and one threshold required.
    num_categories = labels.shape[1]
    for category in range(num_categories):
        category_best_threshold = category_best_metric = 0
        for threshold in np.linspace(0.005, 1, 200):
            if heads_per_class > 1 and class_single_threshold:
                info_dict = update_info_dict(defaultdict(int),
                                             labels[:, (category * heads_per_class):(category + 1) * heads_per_class],
                                             preds[:, (category * heads_per_class):(category + 1) * heads_per_class],
                                             threshold=threshold)
            else:
                info_dict = update_info_dict(defaultdict(int), labels[:, category], preds[:, category], threshold=threshold)
            metric_score = get_metric(info_dict, metric, micro)
            if metric_score > category_best_metric or category_best_metric == 0:
                category_best_metric, category_best_threshold, category_best_info_dict = metric_score, threshold, info_dict
        info_dicts.append(category_best_info_dict)
        category_metrics.append(category_best_metric)
        best_thresholds.append(category_best_threshold)
    # HACK -- use micro average here, even if not elsewhere
    micro = True
    best_metric = get_metric(info_dicts, metric, micro)
    # HACK: Attempt to tune thresholds simultaneously... for overall micro average
    if num_categories < 2:
        global_tweaks = 0
    if debug and global_tweaks > 0:
        print('best after individual thresholds (micro %s)' % micro)
        print(best_thresholds)
        print(get_metric(info_dicts, metric, micro))
    for i in range(global_tweaks):
        # Choose random category
        category = np.random.randint(num_categories)
        curr_threshold = best_thresholds[category]
        # tweak randomly
        new_threshold = curr_threshold + (0.08 * (np.random.random() - 0.5))
        if heads_per_class > 1 and class_single_threshold:
            info_dict = update_info_dict(defaultdict(int),
                                         labels[:, (category * heads_per_class):(category + 1) * heads_per_class],
                                         preds[:, (category * heads_per_class):(category + 1) * heads_per_class],
                                         threshold=new_threshold)
        else:
            info_dict = update_info_dict(defaultdict(int), labels[:, category], preds[:, category], threshold=new_threshold)
        old_dict = info_dicts[category]
        info_dicts[category] = info_dict
        # compute *global* metrics
        metric_score = get_metric(info_dicts, metric, micro)
        # save new threshold if global metrics improve
        if metric_score > best_metric:
            # print('Better threshold %.3f for category %d' % (new_threshold, category))
            best_thresholds[category] = round(new_threshold, 3)
            best_metric = metric_score
        else:
            info_dicts[category] = old_dict
    if debug and global_tweaks > 0:
        print('final thresholds')
        print(best_thresholds)
        print(get_metric(info_dicts, metric, micro))
    # OK, now *if* we used multiple heads per class (same threshold), copy these back out to final answers
    if heads_per_class > 1 and class_single_threshold:
        best_thresholds = np.concatenate([[best_thresholds[i]] * heads_per_class for i in range(num_categories)])
    else:
        best_thresholds = np.array(best_thresholds)
    # print(best_thresholds)
    return get_metric(info_dicts, metric, micro), best_thresholds, category_metrics, info_dicts
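
# Illustrative usage sketch (not part of the original training flow): a minimal, synthetic
# example of the shapes `_binary_threshold` expects. The helper name `_demo_binary_threshold`
# and the random data are assumptions for demonstration only.
def _demo_binary_threshold():
    rng = np.random.RandomState(0)
    val_preds = rng.rand(256, 3)                        # predicted probabilities for 3 binary classes
    val_labels = (rng.rand(256, 3) > 0.5).astype(int)   # 0/1 ground-truth labels, same shape
    micro_score, thresholds, per_class_scores, _ = _binary_threshold(
        val_preds, val_labels, metric='f1', micro=False, global_tweaks=200)
    print('tuned per-class thresholds:', thresholds)
    print('micro-averaged score at those thresholds:', micro_score)
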
def main():
    (train_data, val_data, test_data), tokenizer, args = get_data_and_args()
    # Print args for logging & reproduction. Need to know, including default args
    if test_data is None:
        test_data = val_data
    model, optim, LR = get_model_and_optim(args, train_data)

    # save_root = '' if args.load is None else args.load
    # save_root = save_root.replace('.current', '')
    # save_root = os.path.splitext(save_root)[0]
    # save_root += '_transfer'
    save_root = os.path.join('', args.model_version_name)
    if not os.path.exists(save_root):
        os.makedirs(save_root)
    print('writing results to ' + save_root)

    def clf_reg_loss(reg_penalty=.125, order=1):
        loss = 0
        for p in model.classifier.parameters():
            loss += torch.abs(p).sum() * reg_penalty
        return loss
    reg_loss = clf_reg_loss

    init_params = list(model.lm_encoder.parameters())

    if args.use_logreg:
        def transform_for_logreg(model, data, args, desc='train'):
            if data is None:
                return None, None

            X_out = []
            Y_out = []
            for i, batch in tqdm(enumerate(data), total=len(data), unit="batch", desc=desc, position=0, ncols=100):
                text_batch, labels_batch, length_batch = get_supervised_batch(
                    batch, args.cuda, model, args.max_seq_len, args, heads_per_class=args.heads_per_class)
                # if args.non_binary_cols:
                #     labels_batch = labels_batch[:,0]-labels_batch[:,1]+1
                _, (_, state) = transform(model, text_batch, labels_batch, length_batch, args)
                X_out.append(state.cpu().numpy())
                Y_out.append(labels_batch.cpu().numpy())
            X_out = np.concatenate(X_out)
            Y_out = np.concatenate(Y_out)
            return X_out, Y_out

        model.eval()
        trX, trY = transform_for_logreg(model, train_data, args, desc='train')
        vaX, vaY = transform_for_logreg(model, val_data, args, desc='val')
        teX, teY = transform_for_logreg(model, test_data, args, desc='test')

        logreg_model, logreg_scores, logreg_preds, c, nnotzero = train_logreg(
            trX, trY, vaX, vaY, teX, teY, eval_test=not args.no_test_eval,
            report_metric=args.report_metric, threshold_metric=args.threshold_metric,
            automatic_thresholding=args.automatic_thresholding, micro=args.micro)

        print(', '.join([str(score) for score in logreg_scores]), 'train, val, test accuracy for all neuron regression')
        print(str(c) + ' regularization coefficient used')
        print(str(nnotzero) + ' features used in all neuron regression\n')
    else:
        best_vaY = 0
        vaT = []  # Current "best thresholds" so we can get reasonable estimates on training set
        for e in tqdm(range(args.epochs), unit="epoch", desc="epochs", position=0, ncols=100):
            if args.use_softmax:
                vaT = []
            save_outputs = False
            report_metrics = ['jacc', 'acc', 'mcc', 'f1', 'recall', 'precision', 'var'] if args.all_metrics else [args.report_metric]

            print_str = ""
            trXt, trY, trC, _ = finetune(model, train_data, args, val_data=val_data, LR=LR, reg_loss=reg_loss,
                                         tqdm_desc='train', heads_per_class=args.heads_per_class,
                                         last_thresholds=vaT, threshold_validation=False)
            data_str_base = "Train Loss: {:4.2f} Train {:5s} (All): {:5.2f}, Train Class {:5s}: {}"
            for idx, m in enumerate(report_metrics):
                data_str = data_str_base.format(trXt, m, trY[idx] * 100, m, trC[idx])
                print_str += data_str + " " * max(0, 110 - len(data_str)) + "\n"

            vaXt, vaY = None, None
            if val_data is not None:
                vaXt, vaY, vaC, vaT = finetune(model, val_data, args, tqdm_desc='val',
                                               heads_per_class=args.heads_per_class, last_thresholds=vaT)
                # Take command line, for metric for which to measure best performance against.
                # NOTE: F1, MCC, Jaccard are good measures. Accuracy is not -- since so skewed.
                selection_metric = ['jacc', 'acc', 'mcc', 'f1', 'recall', 'precision', 'var'].index(args.threshold_metric)
                avg_Y = vaY[selection_metric]
                tqdm.write('avg ' + args.threshold_metric + ' metric ' + str(avg_Y))
                if avg_Y > best_vaY:
                    save_outputs = True
                    best_vaY = avg_Y
                elif avg_Y == best_vaY and random.random() > 0.5:
                    save_outputs = True
                    best_vaY = avg_Y

                data_str_base = "Val Loss: {:4.2f} Val {:5s} (All): {:5.2f}, Val Class {:5s}: {}"
                for idx, m in enumerate(report_metrics):
                    data_str = data_str_base.format(vaXt, m, vaY[idx] * 100, m, vaC[idx])
                    print_str += data_str + " " * max(0, 110 - len(data_str)) + "\n"

            tqdm.write(print_str[:-1])

            teXt, teY = None, None
            if test_data is not None:
                # Hardcode -- enable to always save outputs [regardless of metrics]
                # save_outputs = True
                if save_outputs:
                    tqdm.write('performing test eval')
                    try:
                        with torch.no_grad():
                            if not args.no_test_eval:
                                auto_thresholds = None
                                dual_thresholds = None
                                # NOTE -- we manually threshold to F1 [not necessarily good]
                                V_pred, V_label, V_std = generate_outputs(model, val_data, args)
                                if args.automatic_thresholding:
                                    if args.dual_thresh:
                                        # get dual threshold (do not call auto thresholds)
                                        # TODO: Handle multiple heads per class
                                        _, dual_thresholds = _neutral_threshold_two_output(
                                            V_pred.cpu().numpy(), V_label.cpu().numpy())
                                        model.set_thresholds(dual_thresholds,
                                                             dual_threshold=args.dual_thresh and not args.joint_binary_train)
                                    else:
                                        # Use args.threshold_metric to choose which category to threshold on. F1 and Jaccard are good options
                                        # NOTE: For multiple heads per class, can threshold each head (default) or single threshold. Little difference once model converges.
                                        auto_thresholds = vaT
                                        # _, auto_thresholds, _, _ = _binary_threshold(V_pred.view(-1, int(model.out_dim/args.heads_per_class)).contiguous(),
                                        #     V_label.view(-1, int(model.out_dim/args.heads_per_class)).contiguous(),
                                        #     args.threshold_metric, args.micro, global_tweaks=args.global_tweaks)
                                        model.set_thresholds(auto_thresholds, args.double_thresh)
                                T_pred, T_label, T_std = generate_outputs(model, test_data, args, auto_thresholds)
                                if not args.use_softmax and int(model.out_dim / args.heads_per_class) > 1:
                                    keys = list(args.non_binary_cols)
                                    if args.dual_thresh:
                                        if len(keys) == len(dual_thresholds):
                                            tqdm.write('Dual thresholds: %s' % str(list(zip(keys, dual_thresholds))))
                                        keys += ['neutral']
                                    else:
                                        tqdm.write('Class thresholds: %s' % str(list(zip(keys, auto_thresholds))))
                                elif args.use_softmax:
                                    keys = [str(m) for m in range(model.out_dim)]
                                else:
                                    tqdm.write('Class threshold: %s' % str([args.label_key, auto_thresholds[0]]))
                                    keys = ['']
                                info_dicts = [{'fp': 0, 'tp': 0, 'fn': 0, 'tn': 0, 'std': 0.,
                                               'metric': args.report_metric, 'micro': True} for k in keys]
                                # perform dual threshold here, adding the neutral labels to T_label,
                                # thresholding existing predictions and adding neutral preds to T_pred
                                if args.dual_thresh:
                                    if dual_thresholds is None:
                                        dual_thresholds = [.5, .5]

                                    def make_onehot_w_neutral(label):
                                        rtn = [0] * 3
                                        rtn[label] = 1
                                        return rtn

                                    def get_label(pos_neg):
                                        thresholded = [pos_neg[0] >= dual_thresholds[0], pos_neg[1] >= dual_thresholds[1]]
                                        if thresholded[0] == thresholded[1]:
                                            return 2
                                        return thresholded.index(1)

                                    def get_new_std(std):
                                        return std[0], std[1], (std[0] + std[1]) / 2

                                    new_labels = []
                                    new_preds = []
                                    T_std = torch.cat([T_std[:, :2], T_std[:, :2].mean(-1).view(-1, 1)], -1).cpu().numpy()
                                    for j, lab in enumerate(T_label):
                                        pred = T_pred[j]
                                        new_preds.append(make_onehot_w_neutral(get_label(pred)))
                                        new_labels.append(make_onehot_w_neutral(get_label(lab)))
                                    T_pred = np.array(new_preds)
                                    T_label = np.array(new_labels)
                                # HACK: If dual threshold, hardcoded -- assume positive, negative and neutral -- in that order
                                # It's ok to train with other categories (after positive, neutral) as auxiliary loss -- but won't calculate in test
                                if args.dual_thresh and args.joint_binary_train:
                                    keys = ['positive', 'negative', 'neutral']
                                    info_dicts = [{'fp': 0, 'tp': 0, 'fn': 0, 'tn': 0, 'std': 0.,
                                                   'metric': args.report_metric, 'micro': True} for k in keys]
                                for j, k in enumerate(keys):
                                    update_info_dict(info_dicts[j], T_pred[:, j], T_label[:, j], std=T_std[:, j])
                                total_metrics, metric_strings = get_metric_report(info_dicts, args, keys)
                                test_str = ''
                                test_str_base = "Test {:5s} (micro): {:5.2f}, Test Class {:5s}: {}"
                                for idx, m in enumerate(report_metrics):
                                    data_str = test_str_base.format(m, total_metrics[idx] * 100, m, metric_strings[idx])
                                    test_str += data_str + " " * max(0, 110 - len(data_str)) + "\n"
                                tqdm.write(test_str[:-1])
                                # tqdm.write(str(total_metrics))
                                # tqdm.write('; '.join(metric_strings))
                            else:
                                V_pred, V_label, V_std = generate_outputs(model, val_data, args)
                                T_pred, T_label, T_std = generate_outputs(model, test_data, args)
                            val_path = os.path.join(save_root, 'val_results.txt')
                            tqdm.write('Saving validation prediction results of size %s to %s' % (str(T_pred.shape[:]), val_path))
                            write_results(V_pred, V_label, val_path)
                            test_path = os.path.join(save_root, 'test_results.txt')
                            tqdm.write('Saving test prediction results of size %s to %s' % (str(T_pred.shape[:]), test_path))
                            write_results(T_pred, T_label, test_path)
                    except KeyboardInterrupt:
                        pass
                else:
                    pass

            # Save the model, upon request
            if args.save_finetune and save_outputs:
                # Save model if best so far. Note epoch number, and also keys [what is it predicting], as well as optional version number
                # TODO: Add key string to handle multiple runs?
                if args.non_binary_cols:
                    keys = list(args.non_binary_cols)
                else:
                    keys = [args.label_key]
                # Also save args
                args_save_path = os.path.join(save_root, 'args.txt')
                tqdm.write('Saving commandline to %s' % args_save_path)
                with open(args_save_path, 'w') as f:
                    f.write(' '.join(sys.argv[1:]))
                # Save and add thresholds to arguments for easy reloading of model config
                if not args.no_test_eval and args.automatic_thresholding:
                    thresh_save_path = os.path.join(save_root, 'thresh' + '_ep' + str(e) + '.npy')
                    tqdm.write('Saving thresh to %s' % thresh_save_path)
                    if args.dual_thresh:
                        np.save(thresh_save_path, list(zip(keys, dual_thresholds)))
                        args.thresholds = list(zip(keys, dual_thresholds))
                    else:
                        np.save(thresh_save_path, list(zip(keys, auto_thresholds)))
                        args.thresholds = list(zip(keys, auto_thresholds))
                else:
                    args.thresholds = None
                args.classes = keys
                # save full model with args to restore
                clf_save_path = os.path.join(save_root, 'model' + '_ep' + str(e) + '.clf')
                tqdm.write('Saving full classifier to %s' % clf_save_path)
                torch.save({'sd': model.state_dict(), 'args': args}, clf_save_path)
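
# Reload sketch (illustrative only): main() above saves the classifier as a dict of
# {'sd': state_dict, 'args': args}, so restoring into a compatible model looks roughly like
# this. The helper name `_demo_load_classifier` and its arguments are assumptions, not part
# of the original code.
def _demo_load_classifier(clf_save_path, model):
    checkpoint = torch.load(clf_save_path, map_location='cpu')
    model.load_state_dict(checkpoint['sd'])   # weights were saved under the 'sd' key
    return model, checkpoint['args']          # args carry classes/thresholds when saved
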
def train_logreg(trX, trY, vaX=None, vaY=None, teX=None, teY=None, penalty='l1', max_iter=100,
                 C=2**np.arange(-8, 1).astype(np.float), seed=42, model=None, eval_test=True,
                 neurons=None, drop_neurons=False, report_metric='acc', automatic_thresholding=False,
                 threshold_metric='acc', micro=False):
    # if only an integer is provided for C make it iterable so we can loop over it
    if not isinstance(C, collections.Iterable):
        C = list([C])

    # extract features for given neuron indices
    if neurons is not None:
        if drop_neurons:
            all_neurons = set(list(range(trX.shape[-1])))
            neurons = set(list(neurons))
            neurons = list(all_neurons - neurons)
        trX = trX[:, neurons]
        if vaX is not None:
            vaX = vaX[:, neurons]
        if teX is not None:
            teX = teX[:, neurons]

    # Cross validation over C
    n_classes = 1
    if len(trY.shape) > 1:
        n_classes = trY.shape[-1]
    scores = []
    if model is None:
        for i, c in enumerate(C):
            if n_classes <= 1:
                model = LogisticRegression(C=c, penalty=penalty, max_iter=max_iter, random_state=seed)
                model.fit(trX, trY)
                blank_info_dict = {'fp': 0, 'tp': 0, 'fn': 0, 'tn': 0, 'std': 0.,
                                   'metric': threshold_metric, 'micro': micro}
                if vaX is not None:
                    info_dict = update_info_dict(blank_info_dict.copy(), vaY, model.predict_proba(vaX)[:, -1])
                else:
                    info_dict = update_info_dict(blank_info_dict.copy(), trY, model.predict_proba(trX)[:, -1])
                scores.append(get_metric(info_dict))
                print(scores[-1])
                del model
            else:
                info_dicts = []
                model = []
                for cls in range(n_classes):
                    _model = LogisticRegression(C=c, penalty=penalty, max_iter=max_iter, random_state=seed)
                    _model.fit(trX, trY[:, cls])
                    blank_info_dict = {'fp': 0, 'tp': 0, 'fn': 0, 'tn': 0, 'std': 0.,
                                       'metric': threshold_metric, 'micro': micro}
                    if vaX is not None:
                        info_dict = update_info_dict(blank_info_dict.copy(), vaY[:, cls], _model.predict_proba(vaX)[:, -1])
                    else:
                        info_dict = update_info_dict(blank_info_dict.copy(), trY[:, cls], _model.predict_proba(trX)[:, -1])
                    info_dicts.append(info_dict)
                    model.append(_model)
                scores.append(get_metric(info_dicts))
                print(scores[-1])
                del model
        c = C[np.argmax(scores)]
        if n_classes <= 1:
            model = LogisticRegression(C=c, penalty=penalty, max_iter=max_iter, random_state=seed)
            model.fit(trX, trY)
        else:
            model = []
            for cls in range(n_classes):
                _model = LogisticRegression(C=c, penalty=penalty, max_iter=max_iter, random_state=seed)
                _model.fit(trX, trY[:, cls])
                model.append(_model)
    else:
        c = model.C

    # predict probabilities and get accuracy of regression model on train, val, test as appropriate
    # also get number of regression weights that are not zero (number of features used for modeling)
    scores = []
    if n_classes == 1:
        nnotzero = np.sum(model.coef_ != 0)
        preds = model.predict_proba(trX)[:, -1]
        train_score = get_metric(update_info_dict(blank_info_dict.copy(), trY, preds), report_metric)
    else:
        nnotzero = 0
        preds = []
        info_dicts = []
        for cls in range(n_classes):
            nnotzero += np.sum(model[cls].coef_ != 0)
            _preds = model[cls].predict_proba(trX)[:, -1]
            info_dicts.append(update_info_dict(blank_info_dict.copy(), trY[:, cls], _preds))
            preds.append(_preds)
        nnotzero /= n_classes
        train_score = get_metric(info_dicts, report_metric)
        preds = np.concatenate([p.reshape((-1, 1)) for p in preds], axis=1)
    scores.append(train_score * 100)

    if vaX is None:
        eval_data = trX
        eval_labels = trY
        val_score = train_score
    else:
        eval_data = vaX
        eval_labels = vaY
        if n_classes == 1:
            preds = model.predict_proba(vaX)[:, -1]
            val_score = get_metric(update_info_dict(blank_info_dict.copy(), vaY, preds), report_metric)
        else:
            preds = []
            info_dicts = []
            for cls in range(n_classes):
                _preds = model[cls].predict_proba(vaX)[:, -1]
                info_dicts.append(update_info_dict(blank_info_dict.copy(), vaY[:, cls], _preds))
                preds.append(_preds)
            val_score = get_metric(info_dicts, report_metric)
            preds = np.concatenate([p.reshape((-1, 1)) for p in preds], axis=1)
    val_preds = preds
    val_labels = eval_labels
    scores.append(val_score * 100)

    eval_score = val_score
    threshold = np.array([.5] * n_classes)
    if automatic_thresholding:
        # Keep thresholds as an array of per-class values; the test-eval branch below
        # converts to a scalar only in the single-class case.
        _, threshold, _, _ = _binary_threshold(preds.reshape(-1, n_classes),
                                               eval_labels.reshape(-1, n_classes),
                                               threshold_metric, micro)
    if teX is not None and teY is not None and eval_test:
        eval_data = teX
        eval_labels = teY
        if n_classes == 1:
            preds = model.predict_proba(eval_data)[:, -1]
        else:
            preds = []
            for cls in range(n_classes):
                _preds = model[cls].predict_proba(eval_data)[:, -1]
                preds.append(_preds)
            preds = np.concatenate([p.reshape((-1, 1)) for p in preds], axis=1)
        if n_classes == 1:
            threshold = float(threshold.squeeze())
            eval_score = get_metric(update_info_dict(blank_info_dict.copy(), eval_labels, preds, threshold=threshold),
                                    report_metric)
        else:
            info_dicts = []
            for cls in range(n_classes):
                info_dicts.append(update_info_dict(blank_info_dict.copy(), eval_labels[:, cls], preds[:, cls],
                                                   threshold=threshold[cls]))
            eval_score = get_metric(info_dicts, report_metric)
    scores.append(eval_score * 100)

    return model, scores, preds, c, nnotzero
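
# Illustrative usage sketch (assumptions: synthetic features/labels and the helper name
# `_demo_train_logreg`; not part of the original pipeline). Shows the expected call pattern
# for a single binary class. Note that penalty='l1' requires a solver that supports it
# (e.g. liblinear) on newer scikit-learn releases.
def _demo_train_logreg():
    rng = np.random.RandomState(0)
    trX, vaX, teX = rng.randn(512, 64), rng.randn(128, 64), rng.randn(128, 64)
    trY = (rng.rand(512) > 0.5).astype(int)
    vaY = (rng.rand(128) > 0.5).astype(int)
    teY = (rng.rand(128) > 0.5).astype(int)
    logreg_model, scores, test_preds, c, nnotzero = train_logreg(
        trX, trY, vaX, vaY, teX, teY, report_metric='acc', threshold_metric='acc')
    print('train/val/test scores:', scores)
    print('chosen C:', c, 'nonzero features:', nnotzero)
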
def finetune(model, text, args, val_data=None, LR=None, reg_loss=None, tqdm_desc='nvidia',
             save_outputs=False, heads_per_class=1, default_threshold=0.5, last_thresholds=[],
             threshold_validation=True, debug=False):
    '''
    Run one epoch of classifier training (if `LR` is provided) or evaluation (if `LR` is None)
    with featurization `model` over the `text` data loader.
    Featurization model should return cell state not hidden state.
    `text` data loader should return tuples of ((text, text length), text label).
    Returns average loss, overall metrics, per-class metric strings, and the thresholds used.
    '''
    # NOTE: If in training mode, do not run in .eval() mode. Bug fixed.
    if LR is None:
        model.lm_encoder.eval()
        model.classifier.eval()
    else:
        # Very important to reset back to train mode for future epochs!
        model.lm_encoder.train()
        model.classifier.train()

    # Optionally, freeze language model (train MLP only)
    # NOTE: un-freeze gradients if they ever need to be tweaked in future iterations
    if args.freeze_lm:
        for param in model.lm_encoder.parameters():
            param.requires_grad = False

    # Choose which losses to implement
    if args.use_softmax:
        if heads_per_class > 1:
            clf_loss_fn = M.MultiHeadCrossEntropyLoss(heads_per_class=heads_per_class)
        else:
            clf_loss_fn = torch.nn.CrossEntropyLoss()
    else:
        if heads_per_class > 1:
            clf_loss_fn = M.MultiHeadBCELoss(heads_per_class=heads_per_class)
        else:
            clf_loss_fn = torch.nn.BCELoss()
    if args.aux_lm_loss:
        aux_loss_fn = torch.nn.CrossEntropyLoss(reduce=False)
    else:
        aux_loss_fn = None

    if args.thresh_test_preds:
        thresholds = model.get_thresholds()
    elif len(last_thresholds) > 0:
        # Re-use previous thresholds, if provided.
        # Why? More accurate reporting, and not that slow. Don't compute thresholds on training, for example -- but can recycle val threshold
        thresholds = last_thresholds
    else:
        # Default thresholds -- faster, but less accurate
        thresholds = np.array([default_threshold for _ in range(int(model.out_dim / heads_per_class))])

    total_loss = 0
    total_classifier_loss = 0
    total_lm_loss = 0
    total_multihead_variance_loss = 0
    class_accuracies = torch.zeros(model.out_dim).cuda()

    if model.out_dim / heads_per_class > 1 and not args.use_softmax:
        keys = list(args.non_binary_cols)
    elif args.use_softmax:
        keys = [str(m) for m in range(model.out_dim)]
    else:
        keys = ['']
    info_dicts = [{'fp': 0, 'tp': 0, 'fn': 0, 'tn': 0, 'std': 0,
                   'metric': args.report_metric, 'micro': args.micro} for k in keys]

    # Sanity check -- should do this sooner. Does #classes match expected output?
    assert model.out_dim == len(keys) * heads_per_class, \
        "model.out_dim does not match keys (%s) x heads_per_class (%d)" % (keys, heads_per_class)

    batch_adjustment = 1. / len(text)

    # Save all outputs *IF* small enough, and requested for thresholding -- basically, on validation
    # if threshold_validation and LR is not None:
    all_batches = []
    all_stds = []
    all_labels = []
    for i, data in tqdm(enumerate(text), total=len(text), unit="batch", desc=tqdm_desc, position=1, ncols=100):
        text_batch, labels_batch, length_batch = get_supervised_batch(
            data, args.cuda, model, args.max_seq_len, args, heads_per_class=args.heads_per_class)
        class_out, (lm_out, _) = transform(model, text_batch, labels_batch, length_batch, args, LR)
        class_std = None
        if heads_per_class > 1:
            all_heads, class_out, class_std, clf_out = class_out
            classifier_loss = clf_loss_fn(all_heads, labels_batch)
        else:
            class_out, clf_out = class_out
            if args.dual_thresh:
                class_out = class_out[:, :-1]
            classifier_loss = clf_loss_fn(class_out, labels_batch)
            if args.use_softmax:
                class_out = F.softmax(class_out, -1)

        loss = classifier_loss
        classifier_loss = classifier_loss.clone()  # save for reporting
        # Also compute multihead variance loss -- from classifier [divide by output size since it scales linearly]
        if args.aux_head_variance_loss_weight > 0.:
            multihead_variance_loss = model.classifier.get_last_layer_variance() / model.out_dim
            loss = loss + multihead_variance_loss * args.aux_head_variance_loss_weight
            # Divide by # batches? Since we're looking at the parameters here, and should be batch independent.
            # multihead_variance_loss *= batch_adjustment
        if args.aux_lm_loss:
            lm_labels = text_batch[1:]
            lm_losses = aux_loss_fn(lm_out[:-1].view(-1, lm_out.size(2)).contiguous().float(),
                                    lm_labels.contiguous().view(-1))
            padding_mask = (torch.arange(lm_labels.size(0)).unsqueeze(1).cuda() > length_batch).float()
            portion_unpadded = padding_mask.sum() / padding_mask.size(0)
            lm_loss = portion_unpadded * torch.mean(lm_losses * (padding_mask.view(-1).float()))
            # Scale LM loss -- since it's so big
            if args.aux_lm_loss_weight > 0.:
                loss = loss + lm_loss * args.aux_lm_loss_weight

        # Training
        if LR is not None:
            LR.optimizer.zero_grad()
            loss.backward()
            LR.optimizer.step()
            LR.step()

        # Remove loss from CUDA -- kill gradients and save memory.
        total_loss += loss.detach().cpu().numpy()
        if args.use_softmax:
            labels_batch = onehot(labels_batch.squeeze(), model.out_dim)
            class_out = onehot(clf_out.view(-1), int(model.out_dim / heads_per_class))
        total_classifier_loss += classifier_loss.detach().cpu().numpy()
        if args.aux_lm_loss:
            total_lm_loss += lm_loss.detach().cpu().numpy()
        if args.aux_head_variance_loss_weight > 0:
            total_multihead_variance_loss += multihead_variance_loss.detach().cpu().numpy()
        for j in range(int(model.out_dim / heads_per_class)):
            std = None
            if class_std is not None:
                std = class_std[:, j]
            info_dicts[j] = update_info_dict(info_dicts[j], labels_batch[:, j], class_out[:, j], thresholds[j], std=std)

        # Save, for overall thresholding (not on training)
        if threshold_validation and LR is None:
            all_labels.append(labels_batch.detach().cpu().numpy())
            all_batches.append(class_out.detach().cpu().numpy())
            if class_std is not None:
                all_stds.append(class_std.detach().cpu().numpy())

    if threshold_validation and LR is None:
        all_batches = np.concatenate(all_batches)
        all_labels = np.concatenate(all_labels)
        if heads_per_class > 1:
            all_stds = np.concatenate(all_stds)
        # Compute new thresholds -- per class
        _, thresholds, _, _ = _binary_threshold(all_batches, all_labels, args.threshold_metric, args.micro,
                                                global_tweaks=args.global_tweaks)
        info_dicts = [{'fp': 0, 'tp': 0, 'fn': 0, 'tn': 0, 'std': 0.,
                       'metric': args.report_metric, 'micro': args.micro} for k in keys]
        # In multihead case, look at class averages? Why? More predictive. Works especially well when we force single per-class threshold.
        for j in range(int(model.out_dim / heads_per_class)):
            std = None
            if heads_per_class > 1:
                std = all_stds[:, j]
            info_dicts[j] = update_info_dict(info_dicts[j], all_labels[:, j], all_batches[:, j], thresholds[j], std=std)

    # Metrics for all items -- with current best thresholds
    total_metrics, class_metric_strs = get_metric_report(info_dicts, args, keys, LR)

    # Show losses
    if debug:
        tqdm.write('losses -- total / classifier / LM / multihead_variance')
        tqdm.write(str(total_loss * batch_adjustment))
        tqdm.write(str(total_classifier_loss * batch_adjustment))
        tqdm.write(str(total_lm_loss * batch_adjustment))
        tqdm.write(str(total_multihead_variance_loss * batch_adjustment))

    return total_loss.item() / (i + 1), total_metrics, class_metric_strs, thresholds
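
# Calling-convention sketch (illustrative only; mirrors how main() drives finetune above).
# The helper name `_demo_finetune_epoch` is hypothetical; `model`, `train_data`, `val_data`,
# `args`, and `LR` are assumed to come from get_model_and_optim / get_data_and_args.
def _demo_finetune_epoch(model, train_data, val_data, args, LR):
    # Training pass: LR provided, so encoder/classifier run in train mode and the optimizer steps.
    train_loss, train_metrics, _, _ = finetune(model, train_data, args, LR=LR, tqdm_desc='train',
                                               heads_per_class=args.heads_per_class,
                                               threshold_validation=False)
    # Evaluation pass: LR=None switches to eval mode and, by default, re-tunes per-class thresholds.
    val_loss, val_metrics, _, val_thresholds = finetune(model, val_data, args, tqdm_desc='val',
                                                        heads_per_class=args.heads_per_class)
    return train_loss, val_loss, val_thresholds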