# NOTE: the routines below rely on module-level imports of the surrounding project
# (e.g. torch, numpy as np, os, time, mlflow, and project modules/names such as
# utils, model, data_loader, plg/plotting, Evaluator, Predictor); they are shown
# here as excerpts without those import statements.


def train(cf, logger):
    """
    Performs the training routine for a given fold. Saves plots and selected parameters to the
    experiment dir specified in the configs. Logs to file and tensorboard.
    """
    logger.info('performing training in {}D over fold {} on experiment {} with model {}'.format(
        cf.dim, cf.fold, cf.exp_dir, cf.model))
    logger.time("train_val")

    # -------------- inits and settings -----------------
    net = model.net(cf, logger).cuda()
    if cf.optimizer == "ADAMW":
        optimizer = torch.optim.AdamW(
            utils.parse_params_for_optim(net, weight_decay=cf.weight_decay,
                                         exclude_from_wd=cf.exclude_from_wd),
            lr=cf.learning_rate[0])
    elif cf.optimizer == "SGD":
        optimizer = torch.optim.SGD(
            utils.parse_params_for_optim(net, weight_decay=cf.weight_decay),
            lr=cf.learning_rate[0], momentum=0.3)
    else:
        # guard against misconfigured optimizer names instead of failing later with a NameError
        raise ValueError("Unknown optimizer '{}' in configs.".format(cf.optimizer))
    if cf.dynamic_lr_scheduling:
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode=cf.scheduling_mode, factor=cf.lr_decay_factor,
            patience=cf.scheduling_patience)
    model_selector = utils.ModelSelector(cf, logger)

    starting_epoch = 1
    if cf.resume:
        checkpoint_path = os.path.join(cf.fold_dir, "last_state.pth")
        starting_epoch, net, optimizer, model_selector = \
            utils.load_checkpoint(checkpoint_path, net, optimizer, model_selector)
        logger.info('resumed from checkpoint {} to epoch {}'.format(checkpoint_path, starting_epoch))

    # prepare monitoring
    monitor_metrics = utils.prepare_monitoring(cf)

    logger.info('loading dataset and initializing batch generators...')
    batch_gen = data_loader.get_train_generators(cf, logger)

    # -------------- training -----------------
    for epoch in range(starting_epoch, cf.num_epochs + 1):
        logger.info('starting training epoch {}/{}'.format(epoch, cf.num_epochs))
        logger.time("train_epoch")
        net.train()

        train_results_list = []
        train_evaluator = Evaluator(cf, logger, mode='train')

        for i in range(cf.num_train_batches):
            logger.time("train_batch_loadfw")
            batch = next(batch_gen['train'])
            batch_gen['train'].generator.stats['roi_counts'] += batch['roi_counts']
            batch_gen['train'].generator.stats['empty_counts'] += batch['empty_counts']
            logger.time("train_batch_loadfw")

            logger.time("train_batch_netfw")
            results_dict = net.train_forward(batch)
            logger.time("train_batch_netfw")

            logger.time("train_batch_bw")
            optimizer.zero_grad()
            results_dict['torch_loss'].backward()
            if cf.clip_norm:
                # gradient clipping
                torch.nn.utils.clip_grad_norm_(net.parameters(), cf.clip_norm, norm_type=2)
            optimizer.step()

            # slim results dict: drop the segmentation predictions before collecting for evaluation
            train_results_list.append(
                ({k: v for k, v in results_dict.items() if k != "seg_preds"}, batch["pid"]))

            if not cf.server_env:
                print("\rFinished training batch "
                      "{}/{} in {:.1f}s ({:.2f}/{:.2f} forw load/net, {:.2f} backw).".format(
                          i + 1, cf.num_train_batches,
                          logger.get_time("train_batch_loadfw") + logger.get_time("train_batch_netfw")
                          + logger.time("train_batch_bw"),
                          logger.get_time("train_batch_loadfw", reset=True),
                          logger.get_time("train_batch_netfw", reset=True),
                          logger.get_time("train_batch_bw", reset=True)),
                      end="", flush=True)
        print()

        # --------------- train eval ----------------
        if (epoch - 1) % cf.plot_frequency == 0:
            # view an example batch
            utils.split_off_process(
                plg.view_batch, cf, batch, results_dict,
                has_colorchannels=cf.has_colorchannels, show_gt_labels=True,
                get_time="train-example plot",
                out_file=os.path.join(cf.plot_dir, 'batch_example_train_{}.png'.format(cf.fold)))

        logger.time("evals")
        _, monitor_metrics['train'] = train_evaluator.evaluate_predictions(
            train_results_list, monitor_metrics['train'])
        logger.time("evals")
        logger.time("train_epoch", toggle=False)
        del train_results_list

        # ----------- validation ------------
        logger.info('starting validation in mode {}.'.format(cf.val_mode))
        logger.time("val_epoch")
        with torch.no_grad():
            net.eval()
            val_results_list = []
            val_evaluator = Evaluator(cf, logger, mode=cf.val_mode)
            val_predictor = Predictor(cf, net, logger, mode='val')

            for i in range(batch_gen['n_val']):
                logger.time("val_batch")
                batch = next(batch_gen[cf.val_mode])
                if cf.val_mode == 'val_patient':
                    results_dict = val_predictor.predict_patient(batch)
                elif cf.val_mode == 'val_sampling':
                    results_dict = net.train_forward(batch, is_validation=True)
                val_results_list.append([results_dict, batch["pid"]])
                if not cf.server_env:
                    print("\rFinished validation {} {}/{} in {:.1f}s.".format(
                        'patient' if cf.val_mode == 'val_patient' else 'batch',
                        i + 1, batch_gen['n_val'], logger.time("val_batch")),
                        end="", flush=True)
            print()

            # ------------ val eval -------------
            if (epoch - 1) % cf.plot_frequency == 0:
                utils.split_off_process(
                    plg.view_batch, cf, batch, results_dict,
                    has_colorchannels=cf.has_colorchannels, show_gt_labels=True,
                    get_time="val-example plot",
                    out_file=os.path.join(cf.plot_dir, 'batch_example_val_{}.png'.format(cf.fold)))

            logger.time("evals")
            _, monitor_metrics['val'] = val_evaluator.evaluate_predictions(
                val_results_list, monitor_metrics['val'])
            model_selector.run_model_selection(net, optimizer, monitor_metrics, epoch)
            del val_results_list

            # ----------- monitoring -------------
            monitor_metrics.update({"lr": {str(g): group['lr']
                                           for (g, group) in enumerate(optimizer.param_groups)}})
            logger.metrics2tboard(monitor_metrics, global_step=epoch)
            logger.time("evals")

        logger.info('finished epoch {}/{}, took {:.2f}s. train total: {:.2f}s, average: {:.2f}s. '
                    'val total: {:.2f}s, average: {:.2f}s.'.format(
                        epoch, cf.num_epochs,
                        logger.get_time("train_epoch") + logger.time("val_epoch"),
                        logger.get_time("train_epoch"),
                        logger.get_time("train_epoch", reset=True) / cf.num_train_batches,
                        logger.get_time("val_epoch"),
                        logger.get_time("val_epoch", reset=True) / batch_gen["n_val"]))
        logger.info("time for evals: {:.2f}s".format(logger.get_time("evals", reset=True)))

        # -------------- scheduling -----------------
        if cf.dynamic_lr_scheduling:
            scheduler.step(monitor_metrics["val"][cf.scheduling_criterion][-1])
        else:
            for param_group in optimizer.param_groups:
                param_group['lr'] = cf.learning_rate[epoch - 1]

    logger.time("train_val")
    logger.info("Training and validating over {} epochs took {}".format(
        cf.num_epochs, logger.get_time("train_val", format="hms", reset=True)))
    batch_gen['train'].generator.print_stats(logger, plot=True)
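

# ---------------------------------------------------------------------------
# Usage note (editorial sketch, not part of the routine above): the manual
# schedule branch indexes cf.learning_rate[epoch - 1], so the config is
# expected to supply one learning rate per epoch. The helper below is a
# hypothetical example of how such a list could be built (name, decay value,
# and defaults are assumptions, not taken from the project config).
import numpy as np


def build_lr_schedule(initial_lr=1e-4, decay=0.985, num_epochs=100):
    """Return a list with one learning rate per epoch, e.g. for cf.learning_rate."""
    return list(initial_lr * decay ** np.arange(num_epochs))

# Example: cf.learning_rate = build_lr_schedule(num_epochs=cf.num_epochs)
# ---------------------------------------------------------------------------
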
# (method of the Evaluator class, shown here as a standalone excerpt)
def return_metrics(self, monitor_metrics=None):
    """
    Calculates AP/AUC scores for the internal dataframe. Called directly from evaluate_predictions
    during training for monitoring, or from score_test_df during inference (for single folds or the
    aggregated test set). Loops over foreground classes and score_levels (typically 'roi' and
    'patient'), gets the scores and stores them. Optionally creates plots of prediction histograms
    and roc/prc curves.
    :param monitor_metrics: dict of dicts with all metrics of previous epochs. This function adds
        the metrics for the current epoch and returns the same object.
    :return: all_stats: list. Contains dicts with the resulting scores for each combination of
        foreground class and score_level.
    :return: monitor_metrics
    """
    # -------------- monitoring independent of class, score level ------------
    if monitor_metrics is not None:
        for l_name in self.epoch_losses:
            monitor_metrics[l_name] = [self.epoch_losses[l_name]]

    df = self.test_df

    all_stats = []
    for cl in list(self.cf.class_dict.keys()):
        cl_df = df[df.pred_class == cl]

        for score_level in self.cf.report_score_level:
            stats_dict = {}
            stats_dict['name'] = 'fold_{} {} cl_{}'.format(self.cf.fold, score_level, cl)

            if score_level == 'rois':
                # kick out dummy entries for true negative patients. not needed on roi-level.
                spec_df = cl_df[cl_df.det_type != 'patient_tn']
                stats_dict['ap'] = get_roi_ap_from_df(
                    [spec_df, self.cf.min_det_thresh, self.cf.per_patient_ap])
                # AUC is not sensible on roi-level, since true negative box predictions do not exist.
                # It would reward higher amounts of low-confidence false positives.
                stats_dict['auc'] = np.nan
                stats_dict['roc'] = np.nan
                stats_dict['prc'] = np.nan

                # for the aggregated test set case, additionally get the scores for averaging over fold results.
                if len(df.fold.unique()) > 1:
                    aps = []
                    for fold in df.fold.unique():
                        fold_df = spec_df[spec_df.fold == fold]
                        aps.append(get_roi_ap_from_df(
                            [fold_df, self.cf.min_det_thresh, self.cf.per_patient_ap]))
                    stats_dict['mean_ap'] = np.mean(aps)
                    stats_dict['mean_auc'] = 0

            # on patient level, aggregate predictions per patient (pid): the patient predicted score is
            # the highest-confidence prediction for this class. The patient class label is 1 if a roi of
            # this class exists in the patient, else 0.
            if score_level == 'patient':
                spec_df = cl_df.groupby(['pid'], as_index=False).agg(
                    {'class_label': 'max', 'pred_score': 'max', 'fold': 'first'})

                if len(spec_df.class_label.unique()) > 1:
                    stats_dict['auc'] = roc_auc_score(spec_df.class_label.tolist(),
                                                      spec_df.pred_score.tolist())
                    stats_dict['roc'] = roc_curve(spec_df.class_label.tolist(),
                                                  spec_df.pred_score.tolist())
                else:
                    stats_dict['auc'] = np.nan
                    stats_dict['roc'] = np.nan

                if (spec_df.class_label == 1).any():
                    stats_dict['ap'] = average_precision_score(spec_df.class_label.tolist(),
                                                               spec_df.pred_score.tolist())
                    stats_dict['prc'] = precision_recall_curve(spec_df.class_label.tolist(),
                                                               spec_df.pred_score.tolist())
                else:
                    stats_dict['ap'] = np.nan
                    stats_dict['prc'] = np.nan

                # for the aggregated test set case, additionally get the scores for averaging over fold results.
                if len(df.fold.unique()) > 1:
                    aucs = []
                    aps = []
                    for fold in df.fold.unique():
                        fold_df = spec_df[spec_df.fold == fold]
                        if len(fold_df.class_label.unique()) > 1:
                            aucs.append(roc_auc_score(fold_df.class_label.tolist(),
                                                      fold_df.pred_score.tolist()))
                        if (fold_df.class_label == 1).any():
                            aps.append(average_precision_score(fold_df.class_label.tolist(),
                                                               fold_df.pred_score.tolist()))
                    stats_dict['mean_auc'] = np.mean(aucs)
                    stats_dict['mean_ap'] = np.mean(aps)

            # fill new results into monitor_metrics dict. for simplicity, only one class (of interest)
            # is monitored on patient level.
            if monitor_metrics is not None and not (
                    score_level == 'patient' and cl != self.cf.patient_class_of_interest):
                score_level_name = 'patient' if score_level == 'patient' else self.cf.class_dict[cl]
                monitor_metrics[score_level_name + '_ap'].append(
                    stats_dict['ap'] if stats_dict['ap'] > 0 else np.nan)
                if score_level == 'patient':
                    monitor_metrics[score_level_name + '_auc'].append(
                        stats_dict['auc'] if stats_dict['auc'] > 0 else np.nan)

            if self.cf.plot_prediction_histograms:
                out_filename = os.path.join(
                    self.hist_dir, 'pred_hist_{}_{}_{}_cl{}'.format(
                        self.cf.fold, 'val' if 'val' in self.mode else self.mode, score_level, cl))
                type_list = None if score_level == 'patient' else spec_df.det_type.tolist()
                utils.split_off_process(plotting.plot_prediction_hist, spec_df.class_label.tolist(),
                                        spec_df.pred_score.tolist(), type_list, out_filename)

            all_stats.append(stats_dict)

            # analysis of the hyper-parameter cf.min_det_thresh, for optimization on the validation set.
            if self.cf.scan_det_thresh:
                conf_threshs = list(np.arange(0.9, 1, 0.01))
                pool = Pool(processes=10)
                mp_inputs = [[spec_df, ii, self.cf.per_patient_ap] for ii in conf_threshs]
                aps = pool.map(get_roi_ap_from_df, mp_inputs, chunksize=1)
                pool.close()
                pool.join()
                self.logger.info('results from scanning over det_threshs: {}'.format(
                    [[i, j] for i, j in zip(conf_threshs, aps)]))

    if self.cf.plot_stat_curves:
        out_filename = os.path.join(self.curves_dir,
                                    '{}_{}_stat_curves'.format(self.cf.fold, self.mode))
        utils.split_off_process(plotting.plot_stat_curves, all_stats, out_filename)

    # get average stats over foreground classes on roi level.
    avg_ap = np.mean([d['ap'] for d in all_stats if 'rois' in d['name']])
    all_stats.append({'name': 'average_foreground_roi', 'auc': 0, 'ap': avg_ap})
    if len(df.fold.unique()) > 1:
        avg_mean_ap = np.mean([d['mean_ap'] for d in all_stats if 'rois' in d['name']])
        all_stats[-1]['mean_ap'] = avg_mean_ap
        all_stats[-1]['mean_auc'] = 0

    # in small data sets, values of model_selection_criterion can be identical across epochs, which breaks
    # the ranking of model_selector. Thus, perturb identical values by a negligible random term.
    for sc in self.cf.model_selection_criteria:
        if 'val' in self.mode and monitor_metrics[sc].count(monitor_metrics[sc][-1]) > 1 \
                and monitor_metrics[sc][-1] is not None:
            monitor_metrics[sc][-1] += 1e-6 * np.random.rand()

    return all_stats, monitor_metrics
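

# ---------------------------------------------------------------------------
# Illustration (standalone editorial sketch, not used by the evaluator above):
# on patient level, return_metrics keeps one row per pid, with the
# highest-confidence prediction as the patient score and the max class_label
# as the patient label. The toy data and demo function below are invented for
# demonstration only.
import pandas as pd
from sklearn.metrics import average_precision_score, roc_auc_score


def _patient_level_aggregation_demo():
    """Show the groupby('pid') aggregation on a small, made-up prediction table."""
    toy = pd.DataFrame({
        "pid":         [0, 0, 1, 1, 2],
        "class_label": [1, 1, 0, 0, 1],
        "pred_score":  [0.2, 0.9, 0.4, 0.1, 0.7],
        "fold":        [0, 0, 0, 0, 0],
    })
    per_patient = toy.groupby("pid", as_index=False).agg(
        {"class_label": "max", "pred_score": "max", "fold": "first"})
    # with more than one label present, both AUC and AP are defined:
    print("toy patient-level AUC:", roc_auc_score(per_patient.class_label, per_patient.pred_score))
    print("toy patient-level AP: ", average_precision_score(per_patient.class_label, per_patient.pred_score))
# ---------------------------------------------------------------------------
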
def train(logger):
    """
    Performs the training routine for a given fold. Saves plots and selected parameters to the
    experiment dir specified in the configs.
    """
    logger.info('performing training in {}D over fold {} on experiment {} with model {}'.format(
        cf.dim, cf.fold, cf.exp_dir, cf.model))

    net = model.net(cf, logger).cuda()
    if hasattr(cf, "optimizer") and cf.optimizer.lower() == "adam":
        logger.info("Using Adam optimizer.")
        optimizer = torch.optim.Adam(
            utils.parse_params_for_optim(net, weight_decay=cf.weight_decay,
                                         exclude_from_wd=cf.exclude_from_wd),
            lr=cf.learning_rate[0])
    else:
        logger.info("Using AdamW optimizer.")
        optimizer = torch.optim.AdamW(
            utils.parse_params_for_optim(net, weight_decay=cf.weight_decay,
                                         exclude_from_wd=cf.exclude_from_wd),
            lr=cf.learning_rate[0])

    if cf.dynamic_lr_scheduling:
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode=cf.scheduling_mode, factor=cf.lr_decay_factor,
            patience=cf.scheduling_patience)

    model_selector = utils.ModelSelector(cf, logger)
    train_evaluator = Evaluator(cf, logger, mode='train')
    val_evaluator = Evaluator(cf, logger, mode=cf.val_mode)

    starting_epoch = 1

    # prepare monitoring
    monitor_metrics = utils.prepare_monitoring(cf)

    if cf.resume:
        checkpoint_path = os.path.join(cf.fold_dir, "last_checkpoint")
        starting_epoch, net, optimizer, monitor_metrics = \
            utils.load_checkpoint(checkpoint_path, net, optimizer)
        logger.info('resumed from checkpoint {} to epoch {}'.format(checkpoint_path, starting_epoch))

    logger.info('loading dataset and initializing batch generators...')
    batch_gen = data_loader.get_train_generators(cf, logger)

    # prepare MLflow: log the initial experiment state as artifacts
    best_loss = 1e3
    step = 1
    mlflow.log_artifacts(cf.exp_dir, "exp")

    for epoch in range(starting_epoch, cf.num_epochs + 1):

        logger.info('starting training epoch {}'.format(epoch))
        start_time = time.time()

        net.train()
        train_results_list = []
        bix = 0
        seen_pids = []
        while True:
            try:
                batch = next(batch_gen['train'])
            except StopIteration:
                break
            bix += 1
            seen_pids.extend(batch['pid'])
            # print(f'\rtr. batch {bix}: {batch["pid"]}')

            tic_fw = time.time()
            results_dict = net.train_forward(batch)
            tic_bw = time.time()
            optimizer.zero_grad()
            results_dict['torch_loss'].backward()
            optimizer.step()
            print('\rtr. batch {0} (ep. {1}) fw {2:.2f}s / bw {3:.2f}s / total {4:.2f}s || '.format(
                bix, epoch, tic_bw - tic_fw, time.time() - tic_bw, time.time() - tic_fw)
                + results_dict['logger_string'], flush=True, end="")
            train_results_list.append(
                ({k: v for k, v in results_dict.items() if k != "seg_preds"}, batch["pid"]))
        print(f"Seen pids (unique): {len(np.unique(seen_pids))}")
        print()

        _, monitor_metrics['train'] = train_evaluator.evaluate_predictions(
            train_results_list, monitor_metrics['train'])

        logger.info('generating training example plot.')
        utils.split_off_process(
            plot_batch_prediction, batch, results_dict, cf,
            outfile=os.path.join(cf.plot_dir, 'pred_example_{}_train.png'.format(cf.fold)))

        train_time = time.time() - start_time

        logger.info('starting validation in mode {}.'.format(cf.val_mode))
        with torch.no_grad():
            net.eval()
            if cf.do_validation:
                val_results_list = []
                val_predictor = Predictor(cf, net, logger, mode='val')
                while True:
                    try:
                        batch = next(batch_gen[cf.val_mode])
                    except StopIteration:
                        break
                    if cf.val_mode == 'val_patient':
                        results_dict = val_predictor.predict_patient(batch)
                    elif cf.val_mode == 'val_sampling':
                        results_dict = net.train_forward(batch, is_validation=True)
                    val_results_list.append(
                        ({k: v for k, v in results_dict.items() if k != "seg_preds"}, batch["pid"]))

                _, monitor_metrics['val'] = val_evaluator.evaluate_predictions(
                    val_results_list, monitor_metrics['val'])
                best_model_path = model_selector.run_model_selection(net, optimizer, monitor_metrics, epoch)
                # save best model
                mlflow.log_artifacts(best_model_path,
                                     os.path.join("exp", os.path.basename(cf.fold_dir), 'best_checkpoint'))

            # save logs and plots
            mlflow.log_artifacts(os.path.join(cf.exp_dir, "logs"), os.path.join("exp", 'logs'))
            mlflow.log_artifacts(cf.plot_dir, os.path.join("exp", os.path.basename(cf.plot_dir)))

            # update monitoring and prediction plots
            monitor_metrics.update({"lr": {str(g): group['lr']
                                           for (g, group) in enumerate(optimizer.param_groups)}})

            # replace tboard metrics with MLflow
            # logger.metrics2tboard(monitor_metrics, global_step=epoch)
            mlflow.log_metric('learning rate', optimizer.param_groups[0]['lr'],
                              cf.num_epochs * cf.fold + epoch)
            for key in ['train', 'val']:
                for tag, val in monitor_metrics[key].items():
                    val = val[-1]  # maybe remove list wrapping, recording in evaluator?
                    if not np.isnan(val):
                        mlflow.log_metric(f'{key}_{tag}', val, cf.num_epochs * cf.fold + epoch)

            epoch_time = time.time() - start_time
            logger.info('trained epoch {}: took {} ({} train / {} val)'.format(
                epoch, utils.get_formatted_duration(epoch_time, "ms"),
                utils.get_formatted_duration(train_time, "ms"),
                utils.get_formatted_duration(epoch_time - train_time, "ms")))

            batch = next(batch_gen['val_sampling'])
            results_dict = net.train_forward(batch, is_validation=True)
            logger.info('generating validation-sampling example plot.')
            utils.split_off_process(
                plot_batch_prediction, batch, results_dict, cf,
                outfile=os.path.join(cf.plot_dir, 'pred_example_{}_val.png'.format(cf.fold)))

        # -------------- scheduling -----------------
        if cf.dynamic_lr_scheduling:
            scheduler.step(monitor_metrics["val"][cf.scheduling_criterion][-1])
        else:
            for param_group in optimizer.param_groups:
                param_group['lr'] = cf.learning_rate[epoch - 1]

    # save the whole experiment to MLflow
    mlflow.log_artifacts(cf.exp_dir, "exp")
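

# ---------------------------------------------------------------------------
# Note on the MLflow step index used above (editorial sketch): metrics are
# logged at step = cf.num_epochs * cf.fold + epoch, so consecutive folds get
# disjoint, contiguous step ranges within a single MLflow run. The helper
# below only prints the resulting mapping; its name and toy numbers are
# assumptions for illustration.
def _mlflow_step_ranges(num_epochs=100, num_folds=3):
    """Print the step range each fold occupies under the offset scheme."""
    for fold in range(num_folds):
        first = num_epochs * fold + 1
        last = num_epochs * fold + num_epochs
        print("fold {}: steps {}..{}".format(fold, first, last))
# ---------------------------------------------------------------------------
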
def train(logger):
    """
    Performs the training routine for a given fold. Saves plots and selected parameters to the
    experiment dir specified in the configs.
    """
    logger.info('performing training in {}D over fold {} on experiment {} with model {}'.format(
        cf.dim, cf.fold, cf.exp_dir, cf.model))

    net = model.net(cf, logger).cuda()
    if hasattr(cf, "optimizer") and cf.optimizer.lower() == "adam":
        logger.info("Using Adam optimizer.")
        optimizer = torch.optim.Adam(
            utils.parse_params_for_optim(net, weight_decay=cf.weight_decay,
                                         exclude_from_wd=cf.exclude_from_wd),
            lr=cf.learning_rate[0])
    else:
        logger.info("Using AdamW optimizer.")
        optimizer = torch.optim.AdamW(
            utils.parse_params_for_optim(net, weight_decay=cf.weight_decay,
                                         exclude_from_wd=cf.exclude_from_wd),
            lr=cf.learning_rate[0])

    if cf.dynamic_lr_scheduling:
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode=cf.scheduling_mode, factor=cf.lr_decay_factor,
            patience=cf.scheduling_patience)

    model_selector = utils.ModelSelector(cf, logger)
    train_evaluator = Evaluator(cf, logger, mode='train')
    val_evaluator = Evaluator(cf, logger, mode=cf.val_mode)

    starting_epoch = 1

    # prepare monitoring
    monitor_metrics = utils.prepare_monitoring(cf)

    if cf.resume:
        checkpoint_path = os.path.join(cf.fold_dir, "last_checkpoint")
        starting_epoch, net, optimizer, monitor_metrics = \
            utils.load_checkpoint(checkpoint_path, net, optimizer)
        logger.info('resumed from checkpoint {} to epoch {}'.format(checkpoint_path, starting_epoch))

    logger.info('loading dataset and initializing batch generators...')
    batch_gen = data_loader.get_train_generators(cf, logger)

    for epoch in range(starting_epoch, cf.num_epochs + 1):

        logger.info('starting training epoch {}'.format(epoch))
        start_time = time.time()

        net.train()
        train_results_list = []
        for bix in range(cf.num_train_batches):
            batch = next(batch_gen['train'])
            tic_fw = time.time()
            results_dict = net.train_forward(batch)
            tic_bw = time.time()
            optimizer.zero_grad()
            results_dict['torch_loss'].backward()
            optimizer.step()
            print('\rtr. batch {0}/{1} (ep. {2}) fw {3:.2f}s / bw {4:.2f}s / total {5:.2f}s || '.format(
                bix + 1, cf.num_train_batches, epoch, tic_bw - tic_fw,
                time.time() - tic_bw, time.time() - tic_fw) + results_dict['logger_string'],
                flush=True, end="")
            train_results_list.append(
                ({k: v for k, v in results_dict.items() if k != "seg_preds"}, batch["pid"]))
        print()

        _, monitor_metrics['train'] = train_evaluator.evaluate_predictions(
            train_results_list, monitor_metrics['train'])

        logger.info('generating training example plot.')
        utils.split_off_process(
            plot_batch_prediction, batch, results_dict, cf,
            outfile=os.path.join(cf.plot_dir, 'pred_example_{}_train.png'.format(cf.fold)))

        train_time = time.time() - start_time

        logger.info('starting validation in mode {}.'.format(cf.val_mode))
        with torch.no_grad():
            net.eval()
            if cf.do_validation:
                val_results_list = []
                val_predictor = Predictor(cf, net, logger, mode='val')
                for _ in range(batch_gen['n_val']):
                    batch = next(batch_gen[cf.val_mode])
                    if cf.val_mode == 'val_patient':
                        results_dict = val_predictor.predict_patient(batch)
                    elif cf.val_mode == 'val_sampling':
                        results_dict = net.train_forward(batch, is_validation=True)
                    # val_results_list.append([results_dict['boxes'], batch['pid']])
                    val_results_list.append(
                        ({k: v for k, v in results_dict.items() if k != "seg_preds"}, batch["pid"]))

                _, monitor_metrics['val'] = val_evaluator.evaluate_predictions(
                    val_results_list, monitor_metrics['val'])
                model_selector.run_model_selection(net, optimizer, monitor_metrics, epoch)

            # update monitoring and prediction plots
            monitor_metrics.update({"lr": {str(g): group['lr']
                                           for (g, group) in enumerate(optimizer.param_groups)}})
            logger.metrics2tboard(monitor_metrics, global_step=epoch)

            epoch_time = time.time() - start_time
            logger.info('trained epoch {}: took {} ({} train / {} val)'.format(
                epoch, utils.get_formatted_duration(epoch_time, "ms"),
                utils.get_formatted_duration(train_time, "ms"),
                utils.get_formatted_duration(epoch_time - train_time, "ms")))

            batch = next(batch_gen['val_sampling'])
            results_dict = net.train_forward(batch, is_validation=True)
            logger.info('generating validation-sampling example plot.')
            utils.split_off_process(
                plot_batch_prediction, batch, results_dict, cf,
                outfile=os.path.join(cf.plot_dir, 'pred_example_{}_val.png'.format(cf.fold)))

        # -------------- scheduling -----------------
        if cf.dynamic_lr_scheduling:
            scheduler.step(monitor_metrics["val"][cf.scheduling_criterion][-1])
        else:
            for param_group in optimizer.param_groups:
                param_group['lr'] = cf.learning_rate[epoch - 1]
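

# ---------------------------------------------------------------------------
# The plotting calls above are handed to utils.split_off_process so that slow
# matplotlib I/O does not block the training loop. The actual helper lives in
# the project's utils module; the sketch below is an assumption of what such a
# wrapper might look like, not the project's implementation.
import multiprocessing


def split_off_process_sketch(target, *args, daemon=False, **kwargs):
    """Run target(*args, **kwargs) in a detached child process and return it (hypothetical helper)."""
    p = multiprocessing.Process(target=target, args=args, kwargs=kwargs, daemon=daemon)
    p.start()
    return p
# ---------------------------------------------------------------------------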