def eval_model(test_data, models, args):
    '''
    Run model on test data, and return test stats (includes loss, accuracy, etc.)
    '''
    if not isinstance(models, dict):
        models = {'model': models}
    if args.cuda:
        models['model'] = models['model'].cuda()

    batch_size = args.batch_size // args.batch_splits
    test_stats = init_metrics_dictionary(modes=['test'])
    data_loader = torch.utils.data.DataLoader(
        test_data,
        batch_size=batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        collate_fn=ignore_None_collate,
        pin_memory=True,
        drop_last=False)

    loss, accuracy, confusion_matrix, golds, preds, probs, auc, exams, \
        meta_loss, reg_loss, precision, recall, f1 = run_epoch(
            data_loader,
            train_model=False,
            truncate_epoch=False,
            models=models,
            optimizers=None,
            args=args)

    log_statement, test_stats = compute_eval_metrics(
        args, loss, accuracy, confusion_matrix, golds, preds, probs, auc,
        exams, meta_loss, reg_loss, precision, recall, f1, test_stats, 'test')
    print(log_statement)

    return test_stats
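# NOTE: `ignore_None_collate` is imported from elsewhere in this repo. The hypothetical
# helper below is only a sketch of the usual pattern behind such a collate_fn: drop
# samples that failed to load (and were returned as None), then fall back to PyTorch's
# default collation for the rest. It is not the repo's actual implementation.
def _example_ignore_none_collate(batch):
    from torch.utils.data.dataloader import default_collate
    batch = [sample for sample in batch if sample is not None]
    if len(batch) == 0:
        return None
    return default_collate(batch)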
def compute_threshold_and_dev_stats(dev_data, models, args):
    '''
    Compute threshold based on the Dev results
    '''
    if not isinstance(models, dict):
        models = {'model': models}
    models['model'] = models['model'].to(args.device)

    dev_stats = init_metrics_dictionary(modes=['dev'])
    batch_size = args.batch_size // args.batch_splits
    data_loader = torch.utils.data.DataLoader(
        dev_data,
        batch_size=batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        collate_fn=ignore_None_collate,
        pin_memory=True,
        drop_last=False)

    loss, golds, preds, probs, exams, reg_loss, censor_times, adv_loss = run_epoch(
        data_loader,
        train_model=False,
        truncate_epoch=False,
        models=models,
        optimizers=None,
        args=args)

    if ('detection' in args.dataset or 'risk' in args.dataset) and '1year' in args.dataset \
            and not args.survival_analysis_setup:
        human_preds = get_human_preds(exams, dev_data.metadata_json)
        threshold, (th_lb, th_ub) = stats.get_thresholds_interval(
            probs, golds, human_preds,
            rebalance_eval_cancers=args.rebalance_eval_cancers)
        args.threshold = threshold
        print(' Dev Threshold: {:.8f} ({:.8f} - {:.8f})'.format(threshold, th_lb, th_ub))
    else:
        args.threshold = None

    log_statement, dev_stats = compute_eval_metrics(
        args, loss, golds, preds, probs, exams, reg_loss,
        censor_times, adv_loss, dev_stats, 'dev')
    print(log_statement)

    return dev_stats
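# NOTE: `stats.get_thresholds_interval` is imported from elsewhere in this repo and also
# bootstraps a confidence interval (and can rebalance cancer prevalence). The hypothetical
# helper below only sketches the core idea assumed here: pick the score cutoff whose
# dev-set sensitivity matches that of the human readers, so model and radiologists can be
# compared at an equivalent operating point.
def _example_threshold_at_human_sensitivity(probs, golds, human_preds):
    import numpy as np
    probs, golds, human_preds = map(np.asarray, (probs, golds, human_preds))
    # Sensitivity of the human readers on the dev-set positives.
    human_sensitivity = human_preds[golds == 1].mean()
    # The (1 - sensitivity) quantile of the positive-class scores flags roughly the same
    # fraction of true cancers as the human readers did.
    return float(np.quantile(probs[golds == 1], 1.0 - human_sensitivity))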
def train_model(train_data, dev_data, model, args):
    '''
    Train model and tune on dev set. If model doesn't improve dev performance within
    args.patience epochs, then halve the learning rate, restore the model to best and
    continue training.

    At the end of training, the function will restore the model to the best dev version.

    returns epoch_stats: a dictionary of epoch-level metrics for train and dev
    returns models: dict of models, containing the best performing model setting from this call to train
    '''
    start_epoch, epoch_stats, state_keeper, batch_size, models, optimizers, tuning_key, \
        num_epoch_sans_improvement, num_epoch_since_reducing_lr, no_tuning_on_dev = get_train_variables(
            args, model)

    train_data_loader, dev_data_loader = get_train_and_dev_dataset_loaders(
        args, train_data, dev_data, batch_size)

    for epoch in range(start_epoch, args.epochs + 1):
        print("-------------\nEpoch {}:\n".format(epoch))

        for mode, data_loader in [('Train', train_data_loader), ('Dev', dev_data_loader)]:
            train_model = mode == 'Train'
            key_prefix = mode.lower()
            loss, golds, preds, probs, exams, reg_loss, censor_times, adv_loss = run_epoch(
                data_loader,
                train_model=train_model,
                truncate_epoch=True,
                models=models,
                optimizers=optimizers,
                args=args)
            log_statement, epoch_stats = compute_eval_metrics(
                args, loss, golds, preds, probs, exams, reg_loss,
                censor_times, adv_loss, epoch_stats, key_prefix)

            if mode == 'Dev' and 'mammo_1year' in args.dataset:
                dev_human_preds = get_human_preds(exams, dev_data.metadata_json)
                threshold, _ = stats.get_thresholds_interval(
                    probs, golds, dev_human_preds,
                    rebalance_eval_cancers=args.rebalance_eval_cancers,
                    num_resamples=NUM_RESAMPLES_DURING_TRAIN)
                print(' Dev Threshold: {:.8f} '.format(threshold))
                (fnr, _), (tpr, _), (tnr, _) = stats.get_rates_intervals(
                    probs, golds, threshold,
                    rebalance_eval_cancers=args.rebalance_eval_cancers,
                    num_resamples=NUM_RESAMPLES_DURING_TRAIN)
                epoch_stats['{}_fnr'.format(key_prefix)].append(fnr)
                epoch_stats['{}_tnr'.format(key_prefix)].append(tnr)
                epoch_stats['{}_tpr'.format(key_prefix)].append(tpr)
                log_statement = "{} fnr: {:.3f} tnr: {:.3f} tpr: {:.3f}".format(
                    log_statement, fnr, tnr, tpr)

            print(log_statement)

        # Save model if it beats best dev, or if not tuning on dev
        best_func, arg_best = (min, np.argmin) if tuning_key == 'dev_loss' else (max, np.argmax)
        improved = best_func(epoch_stats[tuning_key]) == epoch_stats[tuning_key][-1]
        if improved or no_tuning_on_dev:
            num_epoch_sans_improvement = 0
            if not os.path.isdir(args.save_dir):
                os.makedirs(args.save_dir)
            epoch_stats['best_epoch'] = arg_best(epoch_stats[tuning_key])
            state_keeper.save(models, optimizers, epoch, args.lr, epoch_stats)

        num_epoch_since_reducing_lr += 1
        if improved:
            num_epoch_sans_improvement = 0
        else:
            num_epoch_sans_improvement += 1

        print('---- Best Dev {} is {} at epoch {}'.format(
            args.tuning_metric,
            epoch_stats[tuning_key][epoch_stats['best_epoch']],
            epoch_stats['best_epoch'] + 1))

        if num_epoch_sans_improvement >= args.patience or \
                (no_tuning_on_dev and num_epoch_since_reducing_lr >= args.lr_reduction_interval):
            print("Reducing learning rate")
            num_epoch_sans_improvement = 0
            num_epoch_since_reducing_lr = 0
            if not args.turn_off_model_reset:
                models, optimizer_states, _, _, _ = state_keeper.load()

                # Reset optimizers
                for name in optimizers:
                    optimizer = optimizers[name]
                    state_dict = optimizer_states[name]
                    optimizers[name] = state_keeper.load_optimizer(optimizer, state_dict)

            # Reduce LR
            for name in optimizers:
                optimizer = optimizers[name]
                for param_group in optimizer.param_groups:
                    param_group['lr'] *= args.lr_decay

            # Update lr also in args for resumable usage
            args.lr *= .5

    # Restore model to best dev performance, or last epoch when not tuning on dev
    models, _, _, _, _ = state_keeper.load()

    return epoch_stats, models
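# NOTE: hypothetical glue code, not the repo's actual entry point. It only illustrates the
# intended order of the functions above: train with dev-set tuning, pick the operating
# threshold on dev, then report final test metrics. The real pipeline builds `args`, the
# datasets, and the model from the repo's own argument parsing and factories.
def _example_train_eval_pipeline(train_data, dev_data, test_data, model, args):
    epoch_stats, models = train_model(train_data, dev_data, model, args)
    dev_stats = compute_threshold_and_dev_stats(dev_data, models, args)
    test_stats = eval_model(test_data, models, args)
    return epoch_stats, dev_stats, test_stats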