def train(logger): """ perform the training routine for a given fold. saves plots and selected parameters to the experiment dir specified in the configs. """ logger.info( 'performing training in {}D over fold {} on experiment {} with model {}' .format(cf.dim, cf.fold, cf.exp_dir, cf.model)) net = model.net(cf, logger).cuda() optimizer = torch.optim.Adam(net.parameters(), lr=cf.learning_rate[0], weight_decay=cf.weight_decay) model_selector = utils.ModelSelector(cf, logger) train_evaluator = Evaluator(cf, logger, mode='train') val_evaluator = Evaluator(cf, logger, mode=cf.val_mode) starting_epoch = 1 # prepare monitoring monitor_metrics, TrainingPlot = utils.prepare_monitoring(cf) if cf.resume_to_checkpoint: starting_epoch, monitor_metrics = utils.load_checkpoint( cf.resume_to_checkpoint, net, optimizer) logger.info('resumed to checkpoint {} at epoch {}'.format( cf.resume_to_checkpoint, starting_epoch)) logger.info('loading dataset and initializing batch generators...') batch_gen = data_loader.get_train_generators(cf, logger) for epoch in range(starting_epoch, cf.num_epochs + 1): logger.info('starting training epoch {}'.format(epoch)) for param_group in optimizer.param_groups: param_group['lr'] = cf.learning_rate[epoch - 1] start_time = time.time() net.train() train_results_list = [] for bix in range(cf.num_train_batches): batch = next(batch_gen['train']) tic_fw = time.time() results_dict = net.train_forward(batch) tic_bw = time.time() optimizer.zero_grad() results_dict['torch_loss'].backward() optimizer.step() logger.info( 'tr. batch {0}/{1} (ep. {2}) fw {3:.3f}s / bw {4:.3f}s / total {5:.3f}s || ' .format(bix + 1, cf.num_train_batches, epoch, tic_bw - tic_fw, time.time() - tic_bw, time.time() - tic_fw) + results_dict['logger_string']) train_results_list.append([results_dict['boxes'], batch['pid']]) monitor_metrics['train']['monitor_values'][epoch].append( results_dict['monitor_values']) _, monitor_metrics['train'] = train_evaluator.evaluate_predictions( train_results_list, monitor_metrics['train']) train_time = time.time() - start_time logger.info('starting validation in mode {}.'.format(cf.val_mode)) with torch.no_grad(): net.eval() if cf.do_validation: val_results_list = [] val_predictor = Predictor(cf, net, logger, mode='val') for _ in range(batch_gen['n_val']): batch = next(batch_gen[cf.val_mode]) if cf.val_mode == 'val_patient': results_dict = val_predictor.predict_patient(batch) elif cf.val_mode == 'val_sampling': results_dict = net.train_forward(batch, is_validation=True) val_results_list.append( [results_dict['boxes'], batch['pid']]) monitor_metrics['val']['monitor_values'][epoch].append( results_dict['monitor_values']) _, monitor_metrics['val'] = val_evaluator.evaluate_predictions( val_results_list, monitor_metrics['val']) model_selector.run_model_selection(net, optimizer, monitor_metrics, epoch) # update monitoring and prediction plots TrainingPlot.update_and_save(monitor_metrics, epoch) epoch_time = time.time() - start_time logger.info( 'trained epoch {}: took {} sec. ({} train / {} val)'.format( epoch, epoch_time, train_time, epoch_time - train_time)) batch = next(batch_gen['val_sampling']) results_dict = net.train_forward(batch, is_validation=True) logger.info('plotting predictions from validation sampling.') plot_batch_prediction(batch, results_dict, cf)
def train(cf, logger):
    """
    performs the training routine for a given fold. saves plots and selected parameters to the experiment dir
    specified in the configs. logs to file and tensorboard.
    """
    logger.info('performing training in {}D over fold {} on experiment {} with model {}'.format(
        cf.dim, cf.fold, cf.exp_dir, cf.model))
    logger.time("train_val")

    # -------------- inits and settings -----------------
    net = model.net(cf, logger).cuda()
    if cf.optimizer == "ADAMW":
        optimizer = torch.optim.AdamW(
            utils.parse_params_for_optim(net, weight_decay=cf.weight_decay, exclude_from_wd=cf.exclude_from_wd),
            lr=cf.learning_rate[0])
    elif cf.optimizer == "SGD":
        optimizer = torch.optim.SGD(
            utils.parse_params_for_optim(net, weight_decay=cf.weight_decay),
            lr=cf.learning_rate[0], momentum=0.3)
    if cf.dynamic_lr_scheduling:
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode=cf.scheduling_mode, factor=cf.lr_decay_factor, patience=cf.scheduling_patience)
    model_selector = utils.ModelSelector(cf, logger)

    starting_epoch = 1
    if cf.resume:
        checkpoint_path = os.path.join(cf.fold_dir, "last_state.pth")
        starting_epoch, net, optimizer, model_selector = \
            utils.load_checkpoint(checkpoint_path, net, optimizer, model_selector)
        logger.info('resumed from checkpoint {} to epoch {}'.format(checkpoint_path, starting_epoch))

    # prepare monitoring
    monitor_metrics = utils.prepare_monitoring(cf)

    logger.info('loading dataset and initializing batch generators...')
    batch_gen = data_loader.get_train_generators(cf, logger)

    # -------------- training -----------------
    for epoch in range(starting_epoch, cf.num_epochs + 1):
        logger.info('starting training epoch {}/{}'.format(epoch, cf.num_epochs))
        logger.time("train_epoch")

        net.train()

        train_results_list = []
        train_evaluator = Evaluator(cf, logger, mode='train')

        for i in range(cf.num_train_batches):
            logger.time("train_batch_loadfw")
            batch = next(batch_gen['train'])
            batch_gen['train'].generator.stats['roi_counts'] += batch['roi_counts']
            batch_gen['train'].generator.stats['empty_counts'] += batch['empty_counts']

            logger.time("train_batch_loadfw")
            logger.time("train_batch_netfw")
            results_dict = net.train_forward(batch)
            logger.time("train_batch_netfw")
            logger.time("train_batch_bw")
            optimizer.zero_grad()
            results_dict['torch_loss'].backward()
            if cf.clip_norm:
                torch.nn.utils.clip_grad_norm_(net.parameters(), cf.clip_norm, norm_type=2)  # gradient clipping
            optimizer.step()
            train_results_list.append(({k: v for k, v in results_dict.items() if k != "seg_preds"},
                                       batch["pid"]))  # slim res dict
            if not cf.server_env:
                print("\rFinished training batch "
                      + "{}/{} in {:.1f}s ({:.2f}/{:.2f} forw load/net, {:.2f} backw).".format(
                          i + 1, cf.num_train_batches,
                          logger.get_time("train_batch_loadfw") + logger.get_time("train_batch_netfw")
                          + logger.time("train_batch_bw"),
                          logger.get_time("train_batch_loadfw", reset=True),
                          logger.get_time("train_batch_netfw", reset=True),
                          logger.get_time("train_batch_bw", reset=True)),
                      end="", flush=True)
        print()

        # --------------- train eval ----------------
        if (epoch - 1) % cf.plot_frequency == 0:
            # view an example batch
            utils.split_off_process(plg.view_batch, cf, batch, results_dict,
                                    has_colorchannels=cf.has_colorchannels, show_gt_labels=True,
                                    get_time="train-example plot",
                                    out_file=os.path.join(cf.plot_dir, 'batch_example_train_{}.png'.format(cf.fold)))

        logger.time("evals")
        _, monitor_metrics['train'] = train_evaluator.evaluate_predictions(train_results_list, monitor_metrics['train'])
        logger.time("evals")
        logger.time("train_epoch", toggle=False)
        del train_results_list

        # ----------- validation ------------
        logger.info('starting validation in mode {}.'.format(cf.val_mode))
        logger.time("val_epoch")
        with torch.no_grad():
            net.eval()
            val_results_list = []
            val_evaluator = Evaluator(cf, logger, mode=cf.val_mode)
            val_predictor = Predictor(cf, net, logger, mode='val')

            for i in range(batch_gen['n_val']):
                logger.time("val_batch")
                batch = next(batch_gen[cf.val_mode])
                if cf.val_mode == 'val_patient':
                    results_dict = val_predictor.predict_patient(batch)
                elif cf.val_mode == 'val_sampling':
                    results_dict = net.train_forward(batch, is_validation=True)
                val_results_list.append([results_dict, batch["pid"]])
                if not cf.server_env:
                    print("\rFinished validation {} {}/{} in {:.1f}s.".format(
                        'patient' if cf.val_mode == 'val_patient' else 'batch',
                        i + 1, batch_gen['n_val'], logger.time("val_batch")),
                        end="", flush=True)
            print()

            # ------------ val eval -------------
            if (epoch - 1) % cf.plot_frequency == 0:
                utils.split_off_process(plg.view_batch, cf, batch, results_dict,
                                        has_colorchannels=cf.has_colorchannels, show_gt_labels=True,
                                        get_time="val-example plot",
                                        out_file=os.path.join(cf.plot_dir, 'batch_example_val_{}.png'.format(cf.fold)))

            logger.time("evals")
            _, monitor_metrics['val'] = val_evaluator.evaluate_predictions(val_results_list, monitor_metrics['val'])
            model_selector.run_model_selection(net, optimizer, monitor_metrics, epoch)
            del val_results_list

        # ----------- monitoring -------------
        monitor_metrics.update({"lr": {str(g): group['lr'] for (g, group) in enumerate(optimizer.param_groups)}})
        logger.metrics2tboard(monitor_metrics, global_step=epoch)
        logger.time("evals")

        logger.info('finished epoch {}/{}, took {:.2f}s. train total: {:.2f}s, average: {:.2f}s. val total: {:.2f}s, average: {:.2f}s.'.format(
            epoch, cf.num_epochs, logger.get_time("train_epoch") + logger.time("val_epoch"),
            logger.get_time("train_epoch"),
            logger.get_time("train_epoch", reset=True) / cf.num_train_batches,
            logger.get_time("val_epoch"),
            logger.get_time("val_epoch", reset=True) / batch_gen["n_val"]))
        logger.info("time for evals: {:.2f}s".format(logger.get_time("evals", reset=True)))

        # -------------- scheduling -----------------
        if cf.dynamic_lr_scheduling:
            scheduler.step(monitor_metrics["val"][cf.scheduling_criterion][-1])
        else:
            for param_group in optimizer.param_groups:
                param_group['lr'] = cf.learning_rate[epoch - 1]

    logger.time("train_val")
    logger.info("Training and validating over {} epochs took {}".format(
        cf.num_epochs, logger.get_time("train_val", format="hms", reset=True)))
    batch_gen['train'].generator.print_stats(logger, plot=True)
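# ---------------------------------------------------------------------------
# Several of the variants in this file build optimizer parameter groups via
# utils.parse_params_for_optim(). That helper is not shown here; the following
# is a minimal sketch of the usual technique, assuming `exclude_from_wd` holds
# parameter-name substrings (e.g. "norm", "bias") whose matches should receive
# zero weight decay. The repo's own implementation may differ.
def parse_params_for_optim_sketch(net, weight_decay=0., exclude_from_wd=()):
    """Split net parameters into a decayed and a non-decayed group."""
    exclude_from_wd = [str(k) for k in exclude_from_wd]
    with_wd, without_wd = [], []
    for name, param in net.named_parameters():
        if any(key in name for key in exclude_from_wd):
            without_wd.append(param)
        else:
            with_wd.append(param)
    return [{'params': with_wd, 'weight_decay': weight_decay},
            {'params': without_wd, 'weight_decay': 0.}]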
def train(logger): """ perform the training routine for a given fold. saves plots and selected parameters to the experiment dir specified in the configs. """ logger.info( 'performing training in {}D over fold {} on experiment {} with model {}' .format(cf.dim, cf.fold, cf.exp_dir, cf.model)) net = model.net(cf, logger).cuda() if hasattr(cf, "optimizer") and cf.optimizer.lower() == "adam": logger.info("Using Adam optimizer.") optimizer = torch.optim.Adam(utils.parse_params_for_optim( net, weight_decay=cf.weight_decay, exclude_from_wd=cf.exclude_from_wd), lr=cf.learning_rate[0]) else: logger.info("Using AdamW optimizer.") optimizer = torch.optim.AdamW(utils.parse_params_for_optim( net, weight_decay=cf.weight_decay, exclude_from_wd=cf.exclude_from_wd), lr=cf.learning_rate[0]) if cf.dynamic_lr_scheduling: scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( optimizer, mode=cf.scheduling_mode, factor=cf.lr_decay_factor, patience=cf.scheduling_patience) model_selector = utils.ModelSelector(cf, logger) train_evaluator = Evaluator(cf, logger, mode='train') val_evaluator = Evaluator(cf, logger, mode=cf.val_mode) starting_epoch = 1 # prepare monitoring monitor_metrics = utils.prepare_monitoring(cf) if cf.resume: checkpoint_path = os.path.join(cf.fold_dir, "last_checkpoint") starting_epoch, net, optimizer, monitor_metrics = \ utils.load_checkpoint(checkpoint_path, net, optimizer) logger.info('resumed from checkpoint {} to epoch {}'.format( checkpoint_path, starting_epoch)) logger.info('loading dataset and initializing batch generators...') batch_gen = data_loader.get_train_generators(cf, logger) # Prepare MLFlow best_loss = 1e3 step = 1 mlflow.log_artifacts(cf.exp_dir, "exp") for epoch in range(starting_epoch, cf.num_epochs + 1): logger.info('starting training epoch {}'.format(epoch)) start_time = time.time() net.train() train_results_list = [] bix = 0 seen_pids = [] while True: bix = bix + 1 try: batch = next(batch_gen['train']) except StopIteration: break for pid in batch['pid']: seen_pids.append(pid) # print(f'\rtr. batch {bix}: {batch["pid"]}') tic_fw = time.time() results_dict = net.train_forward(batch) tic_bw = time.time() optimizer.zero_grad() results_dict['torch_loss'].backward() optimizer.step() print( '\rtr. batch {0} (ep. 
{1}) fw {2:.2f}s / bw {3:.2f} s / total {4:.2f} s || ' .format(bix + 1, epoch, tic_bw - tic_fw, time.time() - tic_bw, time.time() - tic_fw) + results_dict['logger_string'], flush=True, end="") train_results_list.append( ({k: v for k, v in results_dict.items() if k != "seg_preds"}, batch["pid"])) print(f"Seen pids (unique): {len(np.unique(seen_pids))}") print() _, monitor_metrics['train'] = train_evaluator.evaluate_predictions( train_results_list, monitor_metrics['train']) logger.info('generating training example plot.') utils.split_off_process( plot_batch_prediction, batch, results_dict, cf, outfile=os.path.join(cf.plot_dir, 'pred_example_{}_train.png'.format(cf.fold))) train_time = time.time() - start_time logger.info('starting validation in mode {}.'.format(cf.val_mode)) with torch.no_grad(): net.eval() if cf.do_validation: val_results_list = [] val_predictor = Predictor(cf, net, logger, mode='val') while True: try: batch = next(batch_gen[cf.val_mode]) except StopIteration: break if cf.val_mode == 'val_patient': results_dict = val_predictor.predict_patient(batch) elif cf.val_mode == 'val_sampling': results_dict = net.train_forward(batch, is_validation=True) val_results_list.append(({ k: v for k, v in results_dict.items() if k != "seg_preds" }, batch["pid"])) _, monitor_metrics['val'] = val_evaluator.evaluate_predictions( val_results_list, monitor_metrics['val']) best_model_path = model_selector.run_model_selection( net, optimizer, monitor_metrics, epoch) # Save best model mlflow.log_artifacts( best_model_path, os.path.join("exp", os.path.basename(cf.fold_dir), 'best_checkpoint')) # Save logs and plots mlflow.log_artifacts(os.path.join(cf.exp_dir, "logs"), os.path.join("exp", 'logs')) mlflow.log_artifacts( cf.plot_dir, os.path.join("exp", os.path.basename(cf.plot_dir))) # update monitoring and prediction plots monitor_metrics.update({ "lr": { str(g): group['lr'] for (g, group) in enumerate(optimizer.param_groups) } }) # replace tboard metrics with MLFlow #logger.metrics2tboard(monitor_metrics, global_step=epoch) mlflow.log_metric('learning rate', optimizer.param_groups[0]['lr'], cf.num_epochs * cf.fold + epoch) for key in ['train', 'val']: for tag, val in monitor_metrics[key].items(): val = val[ -1] # maybe remove list wrapping, recording in evaluator? if 'loss' in tag.lower() and not np.isnan(val): mlflow.log_metric(f'{key}_{tag}', val, cf.num_epochs * cf.fold + epoch) elif not np.isnan(val): mlflow.log_metric(f'{key}_{tag}', val, cf.num_epochs * cf.fold + epoch) epoch_time = time.time() - start_time logger.info('trained epoch {}: took {} ({} train / {} val)'.format( epoch, utils.get_formatted_duration(epoch_time, "ms"), utils.get_formatted_duration(train_time, "ms"), utils.get_formatted_duration(epoch_time - train_time, "ms"))) batch = next(batch_gen['val_sampling']) results_dict = net.train_forward(batch, is_validation=True) logger.info('generating validation-sampling example plot.') utils.split_off_process(plot_batch_prediction, batch, results_dict, cf, outfile=os.path.join( cf.plot_dir, 'pred_example_{}_val.png'.format( cf.fold))) # -------------- scheduling ----------------- if cf.dynamic_lr_scheduling: scheduler.step(monitor_metrics["val"][cf.scheduling_criterion][-1]) else: for param_group in optimizer.param_groups: param_group['lr'] = cf.learning_rate[epoch - 1] # Save whole experiment to MLFlow mlflow.log_artifacts(cf.exp_dir, "exp")
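# ---------------------------------------------------------------------------
# The MLflow variant above logs artifacts and metrics directly, which assumes
# an active MLflow run. A hypothetical driver, not part of the original code
# (`cf` and `logger` are assumed to be set up as in the surrounding module):
import mlflow

def run_fold_with_mlflow():
    # make the run naming explicit; without start_run() MLflow would
    # implicitly create a default run on the first logging call
    with mlflow.start_run(run_name='fold_{}'.format(cf.fold)):
        mlflow.log_param('model', cf.model)
        mlflow.log_param('fold', cf.fold)
        train(logger)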
def train(logger): """ perform the training routine for a given fold. saves plots and selected parameters to the experiment dir specified in the configs. """ logger.info( 'performing training in {}D over fold {} on experiment {} with model {}' .format(cf.dim, cf.fold, cf.exp_dir, cf.model)) net = model.net(cf, logger).cuda() if hasattr(cf, "optimizer") and cf.optimizer.lower() == "adam": logger.info("Using Adam optimizer.") optimizer = torch.optim.Adam(utils.parse_params_for_optim( net, weight_decay=cf.weight_decay, exclude_from_wd=cf.exclude_from_wd), lr=cf.learning_rate[0]) else: logger.info("Using AdamW optimizer.") optimizer = torch.optim.AdamW(utils.parse_params_for_optim( net, weight_decay=cf.weight_decay, exclude_from_wd=cf.exclude_from_wd), lr=cf.learning_rate[0]) if cf.dynamic_lr_scheduling: scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( optimizer, mode=cf.scheduling_mode, factor=cf.lr_decay_factor, patience=cf.scheduling_patience) model_selector = utils.ModelSelector(cf, logger) train_evaluator = Evaluator(cf, logger, mode='train') val_evaluator = Evaluator(cf, logger, mode=cf.val_mode) starting_epoch = 1 # prepare monitoring monitor_metrics = utils.prepare_monitoring(cf) if cf.resume: checkpoint_path = os.path.join(cf.fold_dir, "last_checkpoint") starting_epoch, net, optimizer, monitor_metrics = \ utils.load_checkpoint(checkpoint_path, net, optimizer) logger.info('resumed from checkpoint {} to epoch {}'.format( checkpoint_path, starting_epoch)) logger.info('loading dataset and initializing batch generators...') batch_gen = data_loader.get_train_generators(cf, logger) for epoch in range(starting_epoch, cf.num_epochs + 1): logger.info('starting training epoch {}'.format(epoch)) start_time = time.time() net.train() train_results_list = [] for bix in range(cf.num_train_batches): batch = next(batch_gen['train']) tic_fw = time.time() results_dict = net.train_forward(batch) tic_bw = time.time() optimizer.zero_grad() results_dict['torch_loss'].backward() optimizer.step() print( '\rtr. batch {0}/{1} (ep. 
{2}) fw {3:.2f}s / bw {4:.2f} s / total {5:.2f} s || ' .format(bix + 1, cf.num_train_batches, epoch, tic_bw - tic_fw, time.time() - tic_bw, time.time() - tic_fw) + results_dict['logger_string'], flush=True, end="") train_results_list.append( ({k: v for k, v in results_dict.items() if k != "seg_preds"}, batch["pid"])) print() _, monitor_metrics['train'] = train_evaluator.evaluate_predictions( train_results_list, monitor_metrics['train']) logger.info('generating training example plot.') utils.split_off_process( plot_batch_prediction, batch, results_dict, cf, outfile=os.path.join(cf.plot_dir, 'pred_example_{}_train.png'.format(cf.fold))) train_time = time.time() - start_time logger.info('starting validation in mode {}.'.format(cf.val_mode)) with torch.no_grad(): net.eval() if cf.do_validation: val_results_list = [] val_predictor = Predictor(cf, net, logger, mode='val') for _ in range(batch_gen['n_val']): batch = next(batch_gen[cf.val_mode]) if cf.val_mode == 'val_patient': results_dict = val_predictor.predict_patient(batch) elif cf.val_mode == 'val_sampling': results_dict = net.train_forward(batch, is_validation=True) #val_results_list.append([results_dict['boxes'], batch['pid']]) val_results_list.append(({ k: v for k, v in results_dict.items() if k != "seg_preds" }, batch["pid"])) _, monitor_metrics['val'] = val_evaluator.evaluate_predictions( val_results_list, monitor_metrics['val']) model_selector.run_model_selection(net, optimizer, monitor_metrics, epoch) # update monitoring and prediction plots monitor_metrics.update({ "lr": { str(g): group['lr'] for (g, group) in enumerate(optimizer.param_groups) } }) logger.metrics2tboard(monitor_metrics, global_step=epoch) epoch_time = time.time() - start_time logger.info('trained epoch {}: took {} ({} train / {} val)'.format( epoch, utils.get_formatted_duration(epoch_time, "ms"), utils.get_formatted_duration(train_time, "ms"), utils.get_formatted_duration(epoch_time - train_time, "ms"))) batch = next(batch_gen['val_sampling']) results_dict = net.train_forward(batch, is_validation=True) logger.info('generating validation-sampling example plot.') utils.split_off_process(plot_batch_prediction, batch, results_dict, cf, outfile=os.path.join( cf.plot_dir, 'pred_example_{}_val.png'.format( cf.fold))) # -------------- scheduling ----------------- if cf.dynamic_lr_scheduling: scheduler.step(monitor_metrics["val"][cf.scheduling_criterion][-1]) else: for param_group in optimizer.param_groups: param_group['lr'] = cf.learning_rate[epoch - 1]
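# ---------------------------------------------------------------------------
# The variants above hand plotting off to utils.split_off_process() so the
# expensive matplotlib call does not block the training loop. A minimal sketch
# of such a helper, assuming plain multiprocessing semantics (the actual
# utility may differ, e.g. in how it times or joins the child process):
import multiprocessing as mp

def split_off_process_sketch(target, *args, daemon=False, **kwargs):
    """Run target(*args, **kwargs) in a separate process; return the handle."""
    p = mp.Process(target=target, args=args, kwargs=kwargs, daemon=daemon)
    p.start()
    return p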
def train(logger): """ perform the training routine for a given fold. saves plots and selected parameters to the experiment dir specified in the configs. """ logger.info( 'performing training in {}D over fold {} on experiment {} with model {}' .format(cf.dim, cf.fold, cf.exp_dir, cf.model)) writer = SummaryWriter(os.path.join(cf.exp_dir, 'tensorboard')) net = model.net(cf, logger).cuda() #optimizer = torch.optim.Adam(net.parameters(), lr=cf.learning_rate[0], weight_decay=cf.weight_decay) optimizer = torch.optim.Adam(net.parameters(), lr=cf.initial_learning_rate, weight_decay=cf.weight_decay) model_selector = utils.ModelSelector(cf, logger) train_evaluator = Evaluator(cf, logger, mode='train') val_evaluator = Evaluator(cf, logger, mode=cf.val_mode) #val_sampling starting_epoch = 1 # prepare monitoring if cf.resume_to_checkpoint: #default: False lastepochpth = cf.resume_to_checkpoint + 'last_checkpoint/' best_epoch = np.load(lastepochpth + 'epoch_ranking.npy')[0] df = open(lastepochpth + 'monitor_metrics.pickle', 'rb') monitor_metrics = pickle.load(df) df.close() starting_epoch = utils.load_checkpoint(lastepochpth, net, optimizer) logger.info('resumed to checkpoint {} at epoch {}'.format( cf.resume_to_checkpoint, starting_epoch)) num_batch = starting_epoch * cf.num_train_batches + 1 num_val = starting_epoch * cf.num_val_batches + 1 else: monitor_metrics = utils.prepare_monitoring(cf) num_batch = 0 #for show loss num_val = 0 logger.info('loading dataset and initializing batch generators...') batch_gen = data_loader.get_train_generators(cf, logger) best_train_recall, best_val_recall = 0, 0 lr_now = cf.initial_learning_rate for epoch in range(starting_epoch, cf.num_epochs + 1): logger.info('starting training epoch {}'.format(epoch)) for param_group in optimizer.param_groups: #param_group['lr'] = cf.learning_rate[epoch - 1] print('lr_now', lr_now) lr_next = utils.learning_rate_decreasing( cf, epoch, lr_now, mode='step') #cf.learning_rate[epoch - 1] print('lr_next', lr_next) param_group[ 'lr'] = lr_next #learning_rate_decreasing(cf,epoch,lr_now,mode='step')#cf.learning_rate[epoch - 1] lr_now = lr_next start_time = time.time() net.train() train_results_list = [] #this batch train_results_list_seg = [] for bix in range(cf.num_train_batches): #200 num_batch += 1 batch = next( batch_gen['train'] ) #data,seg,pid,class_target,bb_target,roi_masks,roi_labels for ii, i in enumerate(batch['roi_labels']): if i[0] > 0: batch['roi_labels'][ii] = [1] else: batch['roi_labels'][ii] = [-1] tic_fw = time.time() results_dict = net.train_forward(batch) tic_bw = time.time() optimizer.zero_grad() results_dict['torch_loss'].backward() #total loss optimizer.step() if (num_batch) % cf.show_train_images == 0: fig = plot_batch_prediction(batch, results_dict, cf, 'train') writer.add_figure('/Train/results', fig, num_batch) fig.clear() print('model', cf.exp_dir.split('/')[-2]) logger.info( 'tr. batch {0}/{1} (ep. 
{2}) fw {3:.3f}s / bw {4:.3f}s / total {5:.3f}s || ' .format(bix + 1, cf.num_train_batches, epoch, tic_bw - tic_fw, time.time() - tic_bw, time.time() - tic_fw)) #writer.add_scalar('Train/total_loss',results_dict['torch_loss'].item(),num_batch) #writer.add_scalar('Train/rpn_class_loss',results_dict['monitor_losses']['rpn_class_loss'],num_batch) #writer.add_scalar('Train/rpn_bbox_loss',results_dict['monitor_losses']['rpn_bbox_loss'],num_batch) #writer.add_scalar('Train/mrcnn_class_loss',results_dict['monitor_losses']['mrcnn_class_loss'],num_batch) #writer.add_scalar('Train/mrcnn_bbox_loss',results_dict['monitor_losses']['mrcnn_bbox_loss'],num_batch) #writer.add_scalar('Train/mrcnn_mask_loss',results_dict['monitor_losses']['mrcnn_mask_loss'],num_batch) #writer.add_scalar('Train/seg_dice_loss',results_dict['monitor_losses']['seg_loss_dice'],num_batch) #writer.add_scalar('Train/fusion_dice_loss',results_dict['monitor_losses']['fusion_loss_dice'],num_batch) train_results_list.append([results_dict['boxes'], batch['pid']]) #just gt and det monitor_metrics['train']['monitor_values'][epoch].append( results_dict['monitor_losses']) count_train = train_evaluator.evaluate_predictions(train_results_list, epoch, cf, flag='train') precision = count_train[0] / (count_train[0] + count_train[2] + 0.01) recall = count_train[0] / (count_train[3]) print('tp_patient {}, tp_roi {}, fp_roi {}, total_num {}'.format( count_train[0], count_train[1], count_train[2], count_train[3])) print('precision:{}, recall:{}'.format(precision, recall)) monitor_metrics['train']['train_recall'].append(recall) monitor_metrics['train']['train_percision'].append(precision) writer.add_scalar('Train/train_precision', precision, epoch) writer.add_scalar('Train/train_recall', recall, epoch) train_time = time.time() - start_time logger.info('starting validation in mode {}.'.format(cf.val_mode)) with torch.no_grad(): net.eval() if cf.do_validation: val_results_list = [] val_predictor = Predictor(cf, net, logger, mode='val') dice_val_seg, dice_val_mask, dice_val_fusion = [], [], [] for _ in range(batch_gen['n_val']): #50 num_val += 1 batch = next(batch_gen[cf.val_mode]) print('eval', batch['pid']) for ii, i in enumerate(batch['roi_labels']): if i[0] > 0: batch['roi_labels'][ii] = [1] else: batch['roi_labels'][ii] = [-1] if cf.val_mode == 'val_patient': results_dict = val_predictor.predict_patient( batch) #result of one patient elif cf.val_mode == 'val_sampling': results_dict = net.train_forward(batch, is_validation=True) if (num_val) % cf.show_val_images == 0: fig = plot_batch_prediction(batch, results_dict, cf, cf.val_mode) writer.add_figure('Val/results', fig, num_val) fig.clear() # compute dice for vnet this_batch_seg_label = torch.FloatTensor( mutils.get_one_hot_encoding( batch['seg'], cf.num_seg_classes + 1)).cuda() if cf.fusion_feature_method == 'after': this_batch_dice_seg = mutils.dice_val( results_dict['seg_logits'], this_batch_seg_label) else: this_batch_dice_seg = mutils.dice_val( F.softmax(results_dict['seg_logits'], dim=1), this_batch_seg_label) dice_val_seg.append(this_batch_dice_seg) # compute dice for mask #mask_map = torch.from_numpy(results_dict['seg_preds']).cuda() if cf.fusion_feature_method == 'after': this_batch_dice_mask = mutils.dice_val( results_dict['seg_preds'], this_batch_seg_label) else: this_batch_dice_mask = mutils.dice_val( F.softmax(results_dict['seg_preds'], dim=1), this_batch_seg_label) dice_val_mask.append(this_batch_dice_mask) # compute dice for fusion if cf.fusion_feature_method == 'after': 
this_batch_dice_fusion = mutils.dice_val( results_dict['fusion_map'], this_batch_seg_label) else: this_batch_dice_fusion = mutils.dice_val( F.softmax(results_dict['fusion_map'], dim=1), this_batch_seg_label) dice_val_fusion.append(this_batch_dice_fusion) val_results_list.append( [results_dict['boxes'], batch['pid']]) monitor_metrics['val']['monitor_values'][epoch].append( results_dict['monitor_values']) count_val = val_evaluator.evaluate_predictions( val_results_list, epoch, cf, flag='val') precision = count_val[0] / (count_val[0] + count_val[2] + 0.01) recall = count_val[0] / (count_val[3]) print( 'tp_patient {}, tp_roi {}, fp_roi {}, total_num {}'.format( count_val[0], count_val[1], count_val[2], count_val[3])) print('precision:{}, recall:{}'.format(precision, recall)) val_dice_seg = sum(dice_val_seg) / float(len(dice_val_seg)) val_dice_mask = sum(dice_val_mask) / float(len(dice_val_mask)) val_dice_fusion = sum(dice_val_fusion) / float( len(dice_val_fusion)) monitor_metrics['val']['val_recall'].append(recall) monitor_metrics['val']['val_precision'].append(precision) monitor_metrics['val']['val_dice_seg'].append(val_dice_seg) monitor_metrics['val']['val_dice_mask'].append(val_dice_mask) monitor_metrics['val']['val_dice_fusion'].append( val_dice_fusion) writer.add_scalar('Val/val_precision', precision, epoch) writer.add_scalar('Val/val_recall', recall, epoch) writer.add_scalar('Val/val_dice_seg', val_dice_seg, epoch) writer.add_scalar('Val/val_dice_mask', val_dice_mask, epoch) writer.add_scalar('Val/val_dice_fusion', val_dice_fusion, epoch) model_selector.run_model_selection(net, optimizer, monitor_metrics, epoch) # update monitoring and prediction plots #TrainingPlot.update_and_save(monitor_metrics, epoch) epoch_time = time.time() - start_time logger.info( 'trained epoch {}: took {} sec. ({} train / {} val)'.format( epoch, epoch_time, train_time, epoch_time - train_time)) writer.close()
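# ---------------------------------------------------------------------------
# The variant above updates the LR through utils.learning_rate_decreasing(cf,
# epoch, lr_now, mode='step'). That helper is not shown in this file; a
# hypothetical step schedule could look like the following (cf.lr_decay_every
# and cf.lr_decay_factor are assumed config names, not confirmed by the source):
def learning_rate_decreasing_sketch(cf, epoch, lr_now, mode='step'):
    """Multiply the LR by a decay factor every cf.lr_decay_every epochs."""
    if mode == 'step':
        if epoch > 1 and (epoch - 1) % cf.lr_decay_every == 0:
            return lr_now * cf.lr_decay_factor
        return lr_now
    raise ValueError('unknown decay mode: {}'.format(mode))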
def train(logger): """ perform the training routine for a given fold. saves plots and selected parameters to the experiment dir specified in the configs. """ logger.info('performing training in {}D over fold {} on experiment {} with model {}'.format( cf.dim, cf.fold, cf.exp_dir, cf.model)) writer = SummaryWriter(os.path.join(cf.exp_dir,'tensorboard')) net = model.net(cf, logger).cuda() #print('finish initial network') optimizer = torch.optim.Adam(net.parameters(), lr=cf.learning_rate[0], weight_decay=cf.weight_decay) #print('finish initial optimizer') model_selector = utils.ModelSelector(cf, logger) train_evaluator = Evaluator(cf, logger, mode='train') val_evaluator = Evaluator(cf, logger, mode=cf.val_mode)#val_sampling starting_epoch = 1 # prepare monitoring #monitor_metrics, TrainingPlot = utils.prepare_monitoring(cf) #print('monitor_metrics',monitor_metrics) if cf.resume_to_checkpoint:#default: False best_epoch = np.load(cf.resume_to_checkpoint + 'epoch_ranking.npy')[0] df = open(cf.resume_to_checkpoint+'monitor_metrics.pickle','rb') monitor_metrics = pickle.load(df) df.close() starting_epoch = utils.load_checkpoint(cf.resume_to_checkpoint, net, optimizer) logger.info('resumed to checkpoint {} at epoch {}'.format(cf.resume_to_checkpoint, starting_epoch)) num_batch = starting_epoch * cf.num_train_batches+1 num_val = starting_epoch * cf.num_val_batches+1 else: monitor_metrics = utils.prepare_monitoring(cf) num_batch = 0#for show loss num_val = 0 logger.info('loading dataset and initializing batch generators...') batch_gen = data_loader.get_train_generators(cf, logger) #for k in batch_gen.keys(): # print('k in batch_gen are {}'.format(k)) best_train_recall,best_val_recall = 0,0 for epoch in range(starting_epoch, cf.num_epochs + 1): logger.info('starting training epoch {}'.format(epoch)) for param_group in optimizer.param_groups: param_group['lr'] = cf.learning_rate[epoch - 1] start_time = time.time() net.train() train_results_list = []#this batch #print('net.train()') for bix in range(cf.num_train_batches):#200 num_batch += 1 batch = next(batch_gen['train'])#data,seg,pid,class_target,bb_target,roi_masks,roi_labels #print('training',batch['pid']) for ii,i in enumerate(batch['roi_labels']): if i[0] > 0: batch['roi_labels'][ii] = [1] else: batch['roi_labels'][ii] = [-1] #for k in batch.keys(): # print('k',k) tic_fw = time.time() results_dict = net.train_forward(batch) tic_bw = time.time() optimizer.zero_grad() results_dict['torch_loss'].backward()#total loss optimizer.step() if (num_batch) % cf.show_train_images == 0: fig = plot_batch_prediction(batch, results_dict, cf,'train') writer.add_figure('/Train/results',fig,num_batch) fig.clear() logger.info('tr. batch {0}/{1} (ep. 
{2}) fw {3:.3f}s / bw {4:.3f}s / total {5:.3f}s || ' .format(bix + 1, cf.num_train_batches, epoch, tic_bw - tic_fw, time.time() - tic_bw, time.time() - tic_fw) + results_dict['logger_string']) writer.add_scalar('Train/total_loss',results_dict['torch_loss'].item(),num_batch) writer.add_scalar('Train/rpn_class_loss',results_dict['monitor_losses']['rpn_class_loss'],num_batch) writer.add_scalar('Train/rpn_bbox_loss',results_dict['monitor_losses']['rpn_bbox_loss'],num_batch) writer.add_scalar('Train/mrcnn_class_loss',results_dict['monitor_losses']['mrcnn_class_loss'],num_batch) writer.add_scalar('Train/mrcnn_bbox_loss',results_dict['monitor_losses']['mrcnn_bbox_loss'],num_batch) if 'mrcnn' in cf.model_path: writer.add_scalar('Train/mrcnn_mask_loss',results_dict['monitor_losses']['mrcnn_mask_loss'],num_batch) if 'ufrcnn' in cf.model_path: writer.add_scalar('Train/seg_dice_loss',results_dict['monitor_losses']['seg_loss_dice'],num_batch) train_results_list.append([results_dict['boxes'], batch['pid']])#just gt and det monitor_metrics['train']['monitor_values'][epoch].append(results_dict['monitor_values']) count_train = train_evaluator.evaluate_predictions(train_results_list,epoch,cf,flag = 'train') print('tp_patient {}, tp_roi {}, fp_roi {}, total_num {}'.format(count_train[0],count_train[1],count_train[2],count_train[3])) precision = count_train[0]/ (count_train[0]+count_train[2]+0.01) recall = count_train[0]/ (count_train[3]) print('precision:{}, recall:{}'.format(precision,recall)) monitor_metrics['train']['train_recall'].append(recall) monitor_metrics['train']['train_percision'].append(precision) writer.add_scalar('Train/train_precision',precision,epoch) writer.add_scalar('Train/train_recall',recall,epoch) train_time = time.time() - start_time print('*'*50 + 'finish epoch {}'.format(epoch)) logger.info('starting validation in mode {}.'.format(cf.val_mode)) with torch.no_grad(): net.eval() if cf.do_validation: val_results_list = [] val_predictor = Predictor(cf, net, logger, mode='val') dice_val = [] for _ in range(batch_gen['n_val']):#50 num_val += 1 batch = next(batch_gen[cf.val_mode]) #print('valing',batch['pid']) for ii,i in enumerate(batch['roi_labels']): if i[0] > 0: batch['roi_labels'][ii] = [1] else: batch['roi_labels'][ii] = [-1] if cf.val_mode == 'val_patient': results_dict = val_predictor.predict_patient(batch) elif cf.val_mode == 'val_sampling': results_dict = net.train_forward(batch, is_validation=True) if (num_val) % cf.show_val_images == 0: fig = plot_batch_prediction(batch, results_dict, cf,'val') writer.add_figure('Val/results',fig,num_val) fig.clear() this_batch_seg_label = torch.FloatTensor(mutils.get_one_hot_encoding(batch['seg'], cf.num_seg_classes)).cuda() this_batch_dice = DiceLoss() dice = 1- this_batch_dice(F.softmax(results_dict['seg_logits'],dim=1),this_batch_seg_label) #this_batch_dice = batch_dice(F.softmax(results_dict['seg_logits'],dim = 1),this_batch_seg_label,showdice = True) dice_val.append(dice) val_results_list.append([results_dict['boxes'], batch['pid']]) monitor_metrics['val']['monitor_values'][epoch].append(results_dict['monitor_values']) count_val = val_evaluator.evaluate_predictions(val_results_list,epoch,cf,flag = 'val') print('tp_patient {}, tp_roi {}, fp_roi {}, total_num {}'.format(count_val[0],count_val[1],count_val[2],count_val[3])) precision = count_val[0]/ (count_val[0]+count_val[2]+0.01) recall = count_val[0]/ (count_val[3]) print('precision:{}, recall:{}'.format(precision,recall)) monitor_metrics['val']['val_recall'].append(recall) 
monitor_metrics['val']['val_percision'].append(precision) writer.add_scalar('Val/val_precision',precision,epoch) writer.add_scalar('Val/val_recall',recall,epoch) writer.add_scalar('Val/val_dice',sum(dice_val)/float(len(dice_val)),epoch) model_selector.run_model_selection(net, optimizer, monitor_metrics, epoch) # update monitoring and prediction plots #TrainingPlot.update_and_save(monitor_metrics, epoch) epoch_time = time.time() - start_time logger.info('trained epoch {}: took {} sec. ({} train / {} val)'.format( epoch, epoch_time, train_time, epoch_time-train_time)) writer.close()
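# ---------------------------------------------------------------------------
# The two tensorboard variants above derive a validation Dice either from
# mutils.dice_val() or from 1 - DiceLoss(). For reference, a minimal soft-Dice
# over softmaxed class probabilities and one-hot targets; this is a sketch of
# the standard formulation, not necessarily identical to mutils' implementation:
import torch

def soft_dice_sketch(probs, one_hot_target, eps=1e-6):
    """probs, one_hot_target: (N, C, ...) tensors; returns mean Dice over N, C."""
    spatial_dims = tuple(range(2, probs.dim()))
    intersection = (probs * one_hot_target).sum(spatial_dims)
    denom = probs.sum(spatial_dims) + one_hot_target.sum(spatial_dims)
    return ((2. * intersection + eps) / (denom + eps)).mean()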