def train(model_params, dataset_train, dataset_val, training_params, log_directory, device,
          cuda_available=True, metric_fns=None, n_gif=0, resume_training=False, debugging=False):
    """Main command to train the network.

    Builds the data loaders, the model (from scratch or from a pretrained file), the Adam
    optimizer and the learning-rate scheduler, then runs the epoch loop (training pass,
    validation pass, checkpointing of the best model, early stopping). After the loop the
    final and best models are written to ``log_directory`` together with optional GIFs.

    Args:
        model_params (dict): Model's parameters. Reads at least "name" and "is_2d"; "depth",
            "film_layers", "folder_name", "missing_probability" and
            "missing_probability_growth" are read when the corresponding feature is used.
        dataset_train (imed_loader): Training dataset.
        dataset_val (imed_loader): Validation dataset. When falsy, the validation pass,
            checkpointing and early stopping are skipped.
        training_params (dict): Training parameters. Reads the keys "balance_samples",
            "batch_size", "transfer_learning", "training_time", "scheduler", "loss" and
            "mixup_alpha".
        log_directory (str): Folder where log files, best and final models are saved.
        device (str): Indicates the CPU or GPU ID.
        cuda_available (bool): If True, CUDA is available.
        metric_fns (list): List of metrics, see :mod:`ivadomed.metrics`.
        n_gif (int): Generates a GIF during training if larger than zero, one frame per epoch
            for a given slice. The parameter indicates the number of 2D slices used to
            generate GIFs, one GIF per slice. A GIF shows predictions of a given slice from
            the validation sub-dataset. They are saved within the log directory.
        resume_training (bool): Load a saved model ("checkpoint.pth.tar" in the
            log_directory) for resume training. This training state is saved everytime a new
            best model is saved in the log directory.
        debugging (bool): If True, extended verbosity and intermediate outputs.

    Returns:
        float, float, float, float: best_training_dice, best_training_loss,
        best_validation_dice, best_validation_loss. The "dice" values are Dice *losses*
        (computed with ``imed_losses.DiceLoss``), recorded at the epoch with the lowest
        validation loss; all four stay at ``inf`` if no epoch improved on it.
    """
    # Write the metrics, images, etc to TensorBoard format
    writer = SummaryWriter(log_dir=log_directory)

    # BALANCE SAMPLES AND PYTORCH LOADER
    conditions = all([
        training_params["balance_samples"]["applied"],
        model_params["name"] != "HeMIS"
    ])
    sampler_train, shuffle_train = get_sampler(
        dataset_train, conditions, training_params['balance_samples']['type'])

    train_loader = DataLoader(dataset_train,
                              batch_size=training_params["batch_size"],
                              shuffle=shuffle_train,
                              pin_memory=True,
                              sampler=sampler_train,
                              collate_fn=imed_loader_utils.imed_collate,
                              num_workers=0)

    gif_dict = {"image_path": [], "slice_id": [], "gif": []}
    if dataset_val:
        sampler_val, shuffle_val = get_sampler(
            dataset_val, conditions, training_params['balance_samples']['type'])

        val_loader = DataLoader(dataset_val,
                                batch_size=training_params["batch_size"],
                                shuffle=shuffle_val,
                                pin_memory=True,
                                sampler=sampler_val,
                                collate_fn=imed_loader_utils.imed_collate,
                                num_workers=0)

        # Init GIF: pick n_gif random validation samples and remember which file/slice
        # each one is, so the matching prediction can be appended at every epoch.
        if n_gif > 0:
            indexes_gif = random.sample(range(len(dataset_val)), n_gif)
        for i_gif in range(n_gif):
            random_metadata = dict(
                dataset_val[indexes_gif[i_gif]]["input_metadata"][0])
            gif_dict["image_path"].append(random_metadata['input_filenames'])
            gif_dict["slice_id"].append(random_metadata['slice_index'])
            gif_obj = imed_visualize.AnimatedGif(
                size=dataset_val[indexes_gif[i_gif]]["input"].numpy()[0].shape)
            gif_dict["gif"].append(copy.copy(gif_obj))

    # GET MODEL
    if training_params["transfer_learning"]["retrain_model"]:
        old_model_path = training_params["transfer_learning"]["retrain_model"]
        fraction = training_params["transfer_learning"]['retrain_fraction']
        # The placeholder is now filled with the pretrained model path.
        print("\nLoading pretrained model's weights: {}.".format(old_model_path))
        print("\tFreezing the {}% first layers.".format(100 - fraction * 100.))
        # Reset the last layers unless the configuration says otherwise.
        if 'reset' in training_params["transfer_learning"]:
            reset = training_params["transfer_learning"]['reset']
        else:
            reset = True
        # Freeze first layers and reset last layers
        model = imed_models.set_model_for_retrain(old_model_path,
                                                  retrain_fraction=fraction,
                                                  map_location=device,
                                                  reset=reset)
    else:
        print("\nInitialising model's weights from scratch.")
        model_class = getattr(imed_models, model_params["name"])
        model = model_class(**model_params)
    if cuda_available:
        model.cuda()

    num_epochs = training_params["training_time"]["num_epochs"]

    # OPTIMIZER
    initial_lr = training_params["scheduler"]["initial_lr"]
    # Only optimise the parameters that are not frozen
    params_to_opt = filter(lambda p: p.requires_grad, model.parameters())
    # Using Adam
    optimizer = optim.Adam(params_to_opt, lr=initial_lr)
    scheduler, step_scheduler_batch = get_scheduler(
        copy.copy(training_params["scheduler"]["lr_scheduler"]), optimizer,
        num_epochs)
    print("\nScheduler parameters: {}".format(
        training_params["scheduler"]["lr_scheduler"]))

    # Create dict containing gammas and betas after each FiLM layer.
    if 'film_layers' in model_params and any(model_params['film_layers']):
        gammas_dict = {i: [] for i in range(1, 2 * model_params["depth"] + 3)}
        betas_dict = {i: [] for i in range(1, 2 * model_params["depth"] + 3)}
        contrast_list = []

    # Early-stopping counter. It is initialised BEFORE the resume step so that the value
    # restored from the checkpoint is not overwritten afterwards.
    patience_count = 0

    # Resume
    start_epoch = 1
    resume_path = os.path.join(log_directory, "checkpoint.pth.tar")
    if resume_training:
        model, optimizer, gif_dict, start_epoch, val_loss_total_avg, scheduler, patience_count = load_checkpoint(
            model=model,
            optimizer=optimizer,
            gif_dict=gif_dict,
            scheduler=scheduler,
            fname=resume_path)
        # Individually transfer the optimizer parts
        # TODO: check if following lines are needed
        for state in optimizer.state.values():
            for k, v in state.items():
                if torch.is_tensor(v):
                    state[k] = v.to(device)

    # LOSS
    print("\nSelected Loss: {}".format(training_params["loss"]["name"]))
    print("\twith the parameters: {}".format([
        training_params["loss"][k] for k in training_params["loss"]
        if k != "name"
    ]))
    loss_fct = get_loss_function(copy.copy(training_params["loss"]))
    loss_dice_fct = imed_losses.DiceLoss()  # For comparison when another loss is used

    # INIT TRAINING VARIABLES
    # NOTE(review): the best scores restart at +inf even when resuming, so the first resumed
    # epoch always overwrites the checkpoint -- confirm this is intended.
    best_training_dice, best_training_loss = float("inf"), float("inf")
    best_validation_loss, best_validation_dice = float("inf"), float("inf")
    begin_time = time.time()

    # EPOCH LOOP
    for epoch in tqdm(range(num_epochs), desc="Training", initial=start_epoch):
        # Shift to a 1-based epoch number (or continue from the checkpoint epoch)
        epoch = epoch + start_epoch
        start_time = time.time()

        lr = scheduler.get_last_lr()[0]
        writer.add_scalar('learning_rate', lr, epoch)

        # Training loop -----------------------------------------------------------
        model.train()
        train_loss_total, train_dice_loss_total = 0.0, 0.0
        num_steps = 0
        for i, batch in enumerate(train_loader):
            # GET SAMPLES
            if model_params["name"] == "HeMISUnet":
                input_samples = imed_utils.cuda(
                    imed_utils.unstack_tensors(batch["input"]), cuda_available)
            else:
                input_samples = imed_utils.cuda(batch["input"], cuda_available)
            gt_samples = imed_utils.cuda(batch["gt"],
                                         cuda_available,
                                         non_blocking=True)

            # MIXUP
            if training_params["mixup_alpha"]:
                input_samples, gt_samples = imed_mixup.mixup(
                    input_samples, gt_samples, training_params["mixup_alpha"],
                    debugging and epoch == 1, log_directory)

            # RUN MODEL
            if model_params["name"] == "HeMISUnet" or \
                    ('film_layers' in model_params and any(model_params['film_layers'])):
                metadata = get_metadata(batch["input_metadata"], model_params)
                preds = model(input_samples, metadata)
            else:
                preds = model(input_samples)

            # LOSS
            loss = loss_fct(preds, gt_samples)
            train_loss_total += loss.item()
            train_dice_loss_total += loss_dice_fct(preds, gt_samples).item()

            # UPDATE OPTIMIZER
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if step_scheduler_batch:
                scheduler.step()
            num_steps += 1

            if i == 0 and debugging:
                imed_visualize.save_tensorboard_img(
                    writer,
                    epoch,
                    "Train",
                    input_samples,
                    gt_samples,
                    preds,
                    is_three_dim=not model_params["is_2d"])

        # Schedulers that are not stepped per batch are stepped once per epoch
        if not step_scheduler_batch:
            scheduler.step()

        # TRAINING LOSS
        train_loss_total_avg = train_loss_total / num_steps
        msg = "Epoch {} training loss: {:.4f}.".format(epoch, train_loss_total_avg)
        train_dice_loss_total_avg = train_dice_loss_total / num_steps
        if training_params["loss"]["name"] != "DiceLoss":
            msg += "\tDice training loss: {:.4f}.".format(
                train_dice_loss_total_avg)
        tqdm.write(msg)

        # CURRICULUM LEARNING
        if model_params["name"] == "HeMISUnet":
            # Increase the probability of a missing modality
            model_params["missing_probability"] **= model_params[
                "missing_probability_growth"]
            dataset_train.update(p=model_params["missing_probability"])

        # Validation loop -----------------------------------------------------
        model.eval()
        val_loss_total, val_dice_loss_total = 0.0, 0.0
        num_steps = 0
        metric_mgr = imed_metrics.MetricManager(metric_fns)
        # Everything that depends on validation results (summary, checkpointing, early
        # stopping) is nested under this guard; the nesting was reconstructed from
        # whitespace-collapsed source, since it is the only layout in which the guard
        # avoids a division by zero.
        if dataset_val:
            for i, batch in enumerate(val_loader):
                with torch.no_grad():
                    # GET SAMPLES
                    if model_params["name"] == "HeMISUnet":
                        input_samples = imed_utils.cuda(
                            imed_utils.unstack_tensors(batch["input"]),
                            cuda_available)
                    else:
                        input_samples = imed_utils.cuda(
                            batch["input"], cuda_available)
                    gt_samples = imed_utils.cuda(batch["gt"],
                                                 cuda_available,
                                                 non_blocking=True)

                    # RUN MODEL
                    if model_params["name"] == "HeMISUnet" or \
                            ('film_layers' in model_params and any(model_params['film_layers'])):
                        metadata = get_metadata(batch["input_metadata"],
                                                model_params)
                        preds = model(input_samples, metadata)
                    else:
                        preds = model(input_samples)

                    # LOSS
                    loss = loss_fct(preds, gt_samples)
                    val_loss_total += loss.item()
                    val_dice_loss_total += loss_dice_fct(preds, gt_samples).item()

                    # Add frame to GIF: one frame for each tracked (file, slice) pair
                    for i_ in range(len(input_samples)):
                        im = input_samples[i_].cpu().numpy()[0]
                        pr = preds[i_].cpu().numpy()[0]
                        met = batch["input_metadata"][i_][0]
                        for i_gif in range(n_gif):
                            if gif_dict["image_path"][i_gif] == met['input_filenames'] and \
                                    gif_dict["slice_id"][i_gif] == met['slice_index']:
                                overlap = imed_visualize.overlap_im_seg(im, pr)
                                gif_dict["gif"][i_gif].add(overlap, label=str(epoch))

                num_steps += 1

                # METRICS COMPUTATION
                gt_npy = gt_samples.cpu().numpy()
                preds_npy = preds.data.cpu().numpy()
                metric_mgr(preds_npy, gt_npy)

                if i == 0 and debugging:
                    imed_visualize.save_tensorboard_img(
                        writer,
                        epoch,
                        "Validation",
                        input_samples,
                        gt_samples,
                        preds,
                        is_three_dim=not model_params['is_2d'])

                if 'film_layers' in model_params and any(model_params['film_layers']) and debugging and \
                        epoch == num_epochs and i < int(len(dataset_val) / training_params["batch_size"]) + 1:
                    # Store the values of gammas and betas after the last epoch for each batch
                    gammas_dict, betas_dict, contrast_list = store_film_params(
                        gammas_dict, betas_dict, contrast_list,
                        batch['input_metadata'], model,
                        model_params["film_layers"], model_params["depth"])

            # METRICS COMPUTATION FOR CURRENT EPOCH
            # On a resumed run the previous value comes from the checkpoint.
            val_loss_total_avg_old = val_loss_total_avg if epoch > 1 else None
            metrics_dict = metric_mgr.get_results()
            metric_mgr.reset()
            writer.add_scalars('Validation/Metrics', metrics_dict, epoch)
            val_loss_total_avg = val_loss_total / num_steps
            writer.add_scalars(
                'losses', {
                    'train_loss': train_loss_total_avg,
                    'val_loss': val_loss_total_avg,
                }, epoch)
            msg = "Epoch {} validation loss: {:.4f}.".format(
                epoch, val_loss_total_avg)
            val_dice_loss_total_avg = val_dice_loss_total / num_steps
            if training_params["loss"]["name"] != "DiceLoss":
                msg += "\tDice validation loss: {:.4f}.".format(
                    val_dice_loss_total_avg)
            tqdm.write(msg)
            end_time = time.time()
            total_time = end_time - start_time
            tqdm.write("Epoch {} took {:.2f} seconds.".format(
                epoch, total_time))

            # UPDATE BEST RESULTS
            if val_loss_total_avg < best_validation_loss:
                # Save checkpoint ('epoch' is the epoch to resume FROM, hence + 1)
                state = {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'gif_dict': gif_dict,
                    'scheduler': scheduler,
                    'patience_count': patience_count,
                    'validation_loss': val_loss_total_avg
                }
                torch.save(state, resume_path)

                # Save best model file
                model_path = os.path.join(log_directory, "best_model.pt")
                torch.save(model, model_path)

                # Update best scores
                best_validation_loss, best_training_loss = val_loss_total_avg, train_loss_total_avg
                best_validation_dice, best_training_dice = val_dice_loss_total_avg, train_dice_loss_total_avg

            # EARLY STOPPING
            if epoch > 1:
                # Relative improvement of the validation loss, in percent
                val_diff = (val_loss_total_avg_old -
                            val_loss_total_avg) * 100 / abs(val_loss_total_avg)
                if val_diff < training_params["training_time"][
                        "early_stopping_epsilon"]:
                    patience_count += 1
                if patience_count >= training_params["training_time"][
                        "early_stopping_patience"]:
                    print(
                        "Stopping training due to {} epochs without improvements"
                        .format(patience_count))
                    break

    # Save final model
    final_model_path = os.path.join(log_directory, "final_model.pt")
    torch.save(model, final_model_path)
    if 'film_layers' in model_params and any(
            model_params['film_layers']) and debugging:
        save_film_params(gammas_dict, betas_dict, contrast_list,
                         model_params["depth"], log_directory)

    # Save best model in log directory
    if os.path.isfile(resume_path):
        # Reload the weights of the best epoch before exporting
        state = torch.load(resume_path)
        model_path = os.path.join(log_directory, "best_model.pt")
        model.load_state_dict(state['state_dict'])
        torch.save(model, model_path)
        # Save best model as ONNX in the model directory
        try:
            # Convert best model to ONNX and save it in model directory
            best_model_path = os.path.join(
                log_directory, model_params["folder_name"],
                model_params["folder_name"] + ".onnx")
            imed_utils.save_onnx_model(model, input_samples, best_model_path)
        except Exception:
            # Best-effort fallback: save best model as '.pt' in model directory.
            # `Exception` (not a bare except) so Ctrl-C / SystemExit still propagate.
            best_model_path = os.path.join(log_directory,
                                           model_params["folder_name"],
                                           model_params["folder_name"] + ".pt")
            torch.save(model, best_model_path)
            logger.warning(
                "Failed to save the model as '.onnx', saved it as '.pt': {}".
                format(best_model_path))

    # Save GIFs
    gif_folder = os.path.join(log_directory, "gifs")
    if n_gif > 0 and not os.path.isdir(gif_folder):
        os.makedirs(gif_folder)
    for i_gif in range(n_gif):
        # Output name: <third-from-last path component>__<file stem suffix>__<slice>.gif
        path_parts = gif_dict["image_path"][i_gif].split('/')
        prefix = path_parts[-3]
        stem = path_parts[-1].split(".nii.gz")[0]
        fname_out = prefix + "__"
        fname_out += stem.split(prefix + "_")[1] + "__"
        fname_out += str(gif_dict["slice_id"][i_gif]) + ".gif"
        path_gif_out = os.path.join(gif_folder, fname_out)
        gif_dict["gif"][i_gif].save(path_gif_out)

    writer.close()
    final_time = time.time()
    duration_time = final_time - begin_time
    print('begin ' + time.strftime('%H:%M:%S', time.localtime(begin_time)) +
          "| End " + time.strftime('%H:%M:%S', time.localtime(final_time)) +
          "| duration " + str(datetime.timedelta(seconds=duration_time)))

    return best_training_dice, best_training_loss, best_validation_dice, best_validation_loss
def threshold_analysis(model_path, ds_lst, model_params, testing_params, metric="dice", increment=0.1,
                       fname_out="thr.png", cuda_available=True):
    """Run a threshold analysis to find the optimal threshold on a sub-dataset.

    Inference is run once; the predictions are then binarised at every tested threshold and
    scored. The selected threshold maximises the Dice score (``metric="dice"``) or the
    difference "recall - (1 - specificity)" (``metric="recall_specificity"``). When several
    thresholds tie, the largest one is returned. A plot is written to ``fname_out``.

    Note:
        ``testing_params["uncertainty"]["applied"]`` is set to False in place.

    Args:
        model_path (str): Model path.
        ds_lst (list): List of loaders.
        model_params (dict): Model's parameters.
        testing_params (dict): Testing parameters. Reads "batch_size" and writes
            ``["uncertainty"]["applied"]``.
        metric (str): Choice between "dice" and "recall_specificity". If
            "recall_specificity", then a ROC analysis is performed.
        increment (float): Increment between tested thresholds, strictly between 0 and 1.
        fname_out (str): Plot output filename.
        cuda_available (bool): If True, CUDA is available.

    Returns:
        float: optimal threshold.

    Raises:
        ValueError: If ``metric`` is not supported, if ``increment`` is not in (0, 1), or if
            no tested threshold produced a valid (non-NaN) score.
    """
    if metric not in ["dice", "recall_specificity"]:
        raise ValueError(
            '\nChoice of metric for threshold analysis: dice, recall_specificity.'
        )
    # Fail fast: an out-of-range increment yields no threshold to test, which previously
    # only surfaced after the whole inference as an obscure empty-array reduction error.
    if not 0.0 < increment < 1.0:
        raise ValueError(
            "\nIncrement for threshold analysis must be in (0, 1), got: {}.".format(increment))

    # Adjust some testing parameters
    testing_params["uncertainty"]["applied"] = False

    # Load model. Without CUDA, remap the stored tensors to the CPU so that a model saved
    # from a GPU can still be loaded; with CUDA the stored locations are kept as before.
    model = torch.load(model_path, map_location=None if cuda_available else "cpu")
    # Eval mode
    model.eval()

    # List of thresholds (0.0 itself is skipped)
    thr_list = list(np.arange(0.0, 1.0, increment))[1:]

    # Init metric manager for each thr
    metric_fns = [
        imed_metrics.recall_score, imed_metrics.dice_score,
        imed_metrics.specificity_score
    ]
    metric_dict = {
        thr: imed_metrics.MetricManager(metric_fns)
        for thr in thr_list
    }

    # Load
    loader = DataLoader(ConcatDataset(ds_lst),
                        batch_size=testing_params["batch_size"],
                        shuffle=False,
                        pin_memory=True,
                        sampler=None,
                        collate_fn=imed_loader_utils.imed_collate,
                        num_workers=0)

    # Run inference
    preds_npy, gt_npy = run_inference(loader,
                                      model,
                                      model_params,
                                      testing_params,
                                      ofolder=None,
                                      cuda_available=cuda_available)

    print('\nRunning threshold analysis to find optimal threshold')
    # Make sure the GT is binarized
    gt_npy = [threshold_predictions(gt, thr=0.5) for gt in gt_npy]

    # Move threshold
    for thr in tqdm(thr_list, desc="Search"):
        # Deep copy so the raw predictions stay available for the next threshold
        preds_thr = [
            threshold_predictions(copy.deepcopy(pred), thr=thr)
            for pred in preds_npy
        ]
        metric_dict[thr](preds_thr, gt_npy)

    # Get results
    tpr_list, fpr_list, dice_list = [], [], []
    for thr in thr_list:
        result_thr = metric_dict[thr].get_results()
        tpr_list.append(result_thr["recall_score"])
        fpr_list.append(1 - result_thr["specificity_score"])
        dice_list.append(result_thr["dice_score"])

    # Get optimal threshold
    if metric == "dice":
        diff_list = dice_list
    else:
        diff_list = [tpr - fpr for tpr, fpr in zip(tpr_list, fpr_list)]

    # Explicit array comparison instead of comparing a Python list with a NumPy scalar.
    # NaN scores, if the metrics produce any, are ignored rather than hiding the maximum.
    scores = np.asarray(diff_list, dtype=float)
    if np.all(np.isnan(scores)):
        raise ValueError("\nThreshold analysis failed: no valid score was computed.")
    # Last index reaching the maximum, i.e. the largest threshold among ties
    optimal_idx = int(np.flatnonzero(scores == np.nanmax(scores))[-1])
    optimal_threshold = thr_list[optimal_idx]
    print('\tOptimal threshold: {}'.format(optimal_threshold))

    # Save plot
    print('\tSaving plot: {}'.format(fname_out))
    if metric == "dice":
        # Run plot
        imed_metrics.plot_dice_thr(thr_list, dice_list, optimal_idx, fname_out)
    else:
        # Add 0 and 1 as extrema
        tpr_list = [0.0] + tpr_list + [1.0]
        fpr_list = [0.0] + fpr_list + [1.0]
        # The prepended extremum shifts every index by one
        optimal_idx += 1
        # Run plot
        imed_metrics.plot_roc_curve(tpr_list, fpr_list, optimal_idx, fname_out)

    return optimal_threshold
def test_inference(transforms_dict, test_lst, target_lst, roi_params, testing_params):
    """Run inference with an untrained 2D Unet on the testing data and print the metrics.

    The dataset is loaded with the given transforms, targets and ROI settings,
    ``testing_params`` is completed in place with the slice axis, the target suffix and the
    undo-transform, predictions are written to ``__output_dir__`` and the resulting metrics
    dictionary is printed.
    """
    cuda_available, _ = imed_utils.define_device(GPU_ID)

    unet_params = {"name": "Unet", "is_2d": True}
    slice_filter = {"filter_empty_mask": False, "filter_empty_input": True}
    contrast = {"contrast_lst": ['T2w'], "balance": {}}
    loader_params = {
        "transforms_params": transforms_dict,
        "data_list": test_lst,
        "dataset_type": "testing",
        "requires_undo": True,
        "contrast_params": contrast,
        "path_data": [__data_testing_dir__],
        "target_suffix": target_lst,
        "roi_params": roi_params,
        "slice_filter_params": slice_filter,
        "slice_axis": SLICE_AXIS,
        "multichannel": False,
        "model_params": unet_params,
    }

    # Testing dataset and its loader
    dataset = imed_loader.load_dataset(**loader_params)
    test_loader = DataLoader(dataset,
                             batch_size=BATCH_SIZE,
                             shuffle=False,
                             pin_memory=True,
                             collate_fn=imed_loader_utils.imed_collate,
                             num_workers=0)

    # Transform used to undo the preprocessing on the predictions
    undo_transform = imed_transforms.UndoCompose(imed_transforms.Compose(transforms_dict))

    # Complete the testing parameters (in place)
    testing_params.update(
        slice_axis=loader_params["slice_axis"],
        target_suffix=loader_params["target_suffix"],
        undo_transforms=undo_transform,
    )

    # Model with default (untrained) weights, in evaluation mode
    model = imed_models.Unet()
    if cuda_available:
        model.cuda()
    model.eval()

    metric_mgr = imed_metrics.MetricManager([
        imed_metrics.dice_score,
        imed_metrics.hausdorff_score,
        imed_metrics.precision_score,
        imed_metrics.recall_score,
        imed_metrics.specificity_score,
        imed_metrics.intersection_over_union,
        imed_metrics.accuracy_score,
    ])

    if not os.path.isdir(__output_dir__):
        os.makedirs(__output_dir__)

    preds_npy, gt_npy = imed_testing.run_inference(test_loader=test_loader,
                                                   model=model,
                                                   model_params=unet_params,
                                                   testing_params=testing_params,
                                                   ofolder=__output_dir__,
                                                   cuda_available=cuda_available)

    # Score the predictions, then clear the manager
    metric_mgr(preds_npy, gt_npy)
    results = metric_mgr.get_results()
    metric_mgr.reset()
    print(results)
def test(model_params, dataset_test, testing_params, path_output, device, cuda_available=True,
         metric_fns=None, postprocessing=None):
    """Main command to test the network.

    Loads "best_model.pt" from ``path_output``, runs inference on ``dataset_test`` (several
    times when uncertainty estimation is requested) and writes the predictions to the
    "pred_masks" sub-folder of ``path_output``.

    Note:
        ``testing_params['uncertainty']['applied']`` is written in place.

    Args:
        model_params (dict): Model's parameters.
        dataset_test (imed_loader): Testing dataset.
        testing_params (dict): Testing parameters.
        path_output (str): Folder where predictions are saved.
        device (torch.device): Indicates the CPU or GPU ID.
        cuda_available (bool): If True, CUDA is available.
        metric_fns (list): List of metrics, see :mod:`ivadomed.metrics`.
        postprocessing (dict): Contains postprocessing steps.

    Returns:
        dict: result metrics.
    """
    # DATA LOADER
    loader = DataLoader(dataset_test,
                        batch_size=testing_params["batch_size"],
                        shuffle=False,
                        pin_memory=True,
                        collate_fn=imed_loader_utils.imed_collate,
                        num_workers=0)

    # LOAD TRAIN MODEL
    model_file = os.path.join(path_output, "best_model.pt")
    print('\nLoading model: {}'.format(model_file))
    model = torch.load(model_file, map_location=device)
    if cuda_available:
        model.cuda()
    model.eval()

    # CREATE OUTPUT FOLDER
    pred_dir = os.path.join(path_output, 'pred_masks')
    if not os.path.isdir(pred_dir):
        os.makedirs(pred_dir)

    # METRIC MANAGER
    metric_mgr = imed_metrics.MetricManager(metric_fns)

    # UNCERTAINTY SETTINGS
    uncertainty = testing_params['uncertainty']
    uncertainty_requested = (uncertainty['epistemic'] or uncertainty['aleatoric']) \
        and uncertainty['n_it'] > 0
    if uncertainty_requested:
        # n_it Monte Carlo passes, plus one last regular pass for the prediction itself
        n_passes = uncertainty['n_it'] + 1
        uncertainty['applied'] = True
        print('\nComputing model uncertainty over {} iterations.'.format(
            n_passes - 1))
    else:
        uncertainty['applied'] = False
        n_passes = 1

    for pass_idx in range(n_passes):
        preds_npy, gt_npy = run_inference(loader, model, model_params,
                                          testing_params, pred_dir,
                                          cuda_available, pass_idx,
                                          postprocessing)
        metric_mgr(preds_npy, gt_npy)

        # If uncertainty computation, don't apply it on last iteration for prediction
        is_last_mc_pass = pass_idx == n_passes - 2
        if testing_params['uncertainty']['applied'] and is_last_mc_pass:
            testing_params['uncertainty']['applied'] = False
            # COMPUTE UNCERTAINTY MAPS
            imed_uncertainty.run_uncertainty(ifolder=pred_dir)

    results = metric_mgr.get_results()
    metric_mgr.reset()
    print(results)
    return results
def train(model_params, dataset_train, dataset_val, training_params, path_output, device, wandb_params=None,
          cuda_available=True, metric_fns=None, n_gif=0, resume_training=False, debugging=False):
    """Main command to train the network.

    Builds the data loaders, the model (from scratch or from a pretrained file), the Adam
    optimizer and the learning-rate scheduler, then runs the epoch loop (training pass,
    validation pass, checkpointing of the best model, early stopping). Scalars are logged to
    TensorBoard and, when enabled, to Weights & Biases. After the loop the final and best
    models are written to ``path_output`` together with optional GIFs.

    Args:
        model_params (dict): Model's parameters.
        dataset_train (imed_loader): Training dataset.
        dataset_val (imed_loader): Validation dataset. When falsy, the validation pass,
            checkpointing and early stopping are skipped.
        training_params (dict): Training parameters. Reads the balance-samples and
            batch-size entries as well as the keys "transfer_learning", "training_time",
            "scheduler", "loss" and "mixup_alpha".
        path_output (str): Folder where log files, best and final models are saved.
        device (str): Indicates the CPU or GPU ID.
        wandb_params (dict): Weights & Biases settings, passed to
            ``imed_utils.initialize_wandb``. When tracking is enabled, the project, group
            and run names and the "log_grads_every" entry are read from it.
        cuda_available (bool): If True, CUDA is available.
        metric_fns (list): List of metrics, see :mod:`ivadomed.metrics`.
        n_gif (int): Generates a GIF during training if larger than zero, one frame per epoch
            for a given slice. The parameter indicates the number of 2D slices used to
            generate GIFs, one GIF per slice. A GIF shows predictions of a given slice from
            the validation sub-dataset. They are saved within the output path.
        resume_training (bool): Load a saved model ("checkpoint.pth.tar" in the path_output)
            for resume training. This training state is saved everytime a new best model is
            saved in the log directory.
        debugging (bool): If True, extended verbosity and intermediate outputs.

    Returns:
        float, float, float, float: best_training_dice, best_training_loss,
        best_validation_dice, best_validation_loss. The "dice" values are Dice *losses*
        (computed with ``imed_losses.DiceLoss``), recorded at the epoch with the lowest
        validation loss; all four stay at ``inf`` if no epoch improved on it.
    """
    # Write the metrics, images, etc to TensorBoard format
    writer = SummaryWriter(log_dir=path_output)

    # Enable wandb tracking if the required params are found in the config file and the api key is correct
    wandb_tracking = imed_utils.initialize_wandb(wandb_params)

    if wandb_tracking:
        # Collect all hyperparameters into a dictionary
        cfg = {**training_params, **model_params}

        # Get the actual project, group, and run names if they exist, else choose the temporary names as default
        project_name = wandb_params.get(WandbKW.PROJECT_NAME, "temp_project")
        group_name = wandb_params.get(WandbKW.GROUP_NAME, "temp_group")
        run_name = wandb_params.get(WandbKW.RUN_NAME, "temp_run")

        if project_name == "temp_project" or group_name == "temp_group" or run_name == "temp_run":
            logger.info(
                "{PROJECT/GROUP/RUN} name not found, initializing as {'temp_project'/'temp_group'/'temp_run'}"
            )

        # Initialize WandB with metrics and hyperparameters
        wandb.init(project=project_name,
                   group=group_name,
                   name=run_name,
                   config=cfg)

    # BALANCE SAMPLES AND PYTORCH LOADER
    conditions = all([
        training_params[TrainingParamsKW.BALANCE_SAMPLES][BalanceSamplesKW.APPLIED],
        model_params[ModelParamsKW.NAME] != "HeMIS"
    ])
    sampler_train, shuffle_train = get_sampler(
        dataset_train, conditions,
        training_params[TrainingParamsKW.BALANCE_SAMPLES][BalanceSamplesKW.TYPE])

    train_loader = DataLoader(
        dataset_train,
        batch_size=training_params[TrainingParamsKW.BATCH_SIZE],
        shuffle=shuffle_train,
        pin_memory=True,
        sampler=sampler_train,
        collate_fn=imed_loader_utils.imed_collate,
        num_workers=0)

    gif_dict = {"image_path": [], "slice_id": [], "gif": []}
    if dataset_val:
        sampler_val, shuffle_val = get_sampler(
            dataset_val, conditions,
            training_params[TrainingParamsKW.BALANCE_SAMPLES][BalanceSamplesKW.TYPE])

        val_loader = DataLoader(
            dataset_val,
            batch_size=training_params[TrainingParamsKW.BATCH_SIZE],
            shuffle=shuffle_val,
            pin_memory=True,
            sampler=sampler_val,
            collate_fn=imed_loader_utils.imed_collate,
            num_workers=0)

        # Init GIF: pick n_gif random validation samples and remember which file/slice
        # each one is, so the matching prediction can be appended at every epoch.
        if n_gif > 0:
            indexes_gif = random.sample(range(len(dataset_val)), n_gif)
        for i_gif in range(n_gif):
            random_metadata = dict(
                dataset_val[indexes_gif[i_gif]][MetadataKW.INPUT_METADATA][0])
            gif_dict["image_path"].append(
                random_metadata[MetadataKW.INPUT_FILENAMES])
            gif_dict["slice_id"].append(
                random_metadata[MetadataKW.SLICE_INDEX])
            gif_obj = imed_visualize.AnimatedGif(
                size=dataset_val[indexes_gif[i_gif]]["input"].numpy()[0].shape)
            gif_dict["gif"].append(copy.copy(gif_obj))

    # GET MODEL
    if training_params["transfer_learning"]["retrain_model"]:
        old_model_path = training_params["transfer_learning"]["retrain_model"]
        fraction = training_params["transfer_learning"]['retrain_fraction']
        # The placeholder is now filled with the pretrained model path.
        logger.info("Loading pretrained model's weights: {}.".format(old_model_path))
        logger.info("\tFreezing the {}% first layers.".format(100 - fraction * 100.))
        # Reset the last layers unless the configuration says otherwise.
        if 'reset' in training_params["transfer_learning"]:
            reset = training_params["transfer_learning"]['reset']
        else:
            reset = True
        # Freeze first layers and reset last layers
        model = imed_models.set_model_for_retrain(old_model_path,
                                                  retrain_fraction=fraction,
                                                  map_location=device,
                                                  reset=reset)
    else:
        logger.info("Initialising model's weights from scratch.")
        model_class = getattr(imed_models, model_params[ModelParamsKW.NAME])
        model = model_class(**model_params)
    if cuda_available:
        model.cuda()

    num_epochs = training_params["training_time"]["num_epochs"]

    # OPTIMIZER
    initial_lr = training_params["scheduler"]["initial_lr"]
    # Only optimise the parameters that are not frozen
    params_to_opt = filter(lambda p: p.requires_grad, model.parameters())
    # Using Adam
    optimizer = optim.Adam(params_to_opt, lr=initial_lr)
    scheduler, step_scheduler_batch = get_scheduler(
        copy.copy(training_params["scheduler"]["lr_scheduler"]), optimizer,
        num_epochs)
    logger.info("Scheduler parameters: {}".format(
        training_params["scheduler"]["lr_scheduler"]))

    # Only call wandb methods if required params are found in the config file
    if wandb_tracking:
        # Logs gradients (at every log_freq steps) to the dashboard.
        wandb.watch(model,
                    log="gradients",
                    log_freq=wandb_params["log_grads_every"])

    # Early-stopping counter. It is initialised BEFORE the resume step so that the value
    # restored from the checkpoint is not overwritten afterwards.
    patience_count = 0

    # Resume
    start_epoch = 1
    resume_path = Path(path_output, "checkpoint.pth.tar")
    if resume_training:
        model, optimizer, gif_dict, start_epoch, val_loss_total_avg, scheduler, patience_count = load_checkpoint(
            model=model,
            optimizer=optimizer,
            gif_dict=gif_dict,
            scheduler=scheduler,
            fname=str(resume_path))
        # Individually transfer the optimizer parts
        # TODO: check if following lines are needed
        for state in optimizer.state.values():
            for k, v in state.items():
                if torch.is_tensor(v):
                    state[k] = v.to(device)

    # LOSS
    logger.info("Selected Loss: {}".format(training_params["loss"]["name"]))
    logger.info("\twith the parameters: {}".format([
        training_params["loss"][k] for k in training_params["loss"]
        if k != "name"
    ]))
    loss_fct = get_loss_function(copy.copy(training_params["loss"]))
    loss_dice_fct = imed_losses.DiceLoss()  # For comparison when another loss is used

    # INIT TRAINING VARIABLES
    # NOTE(review): the best scores restart at +inf even when resuming, so the first resumed
    # epoch always overwrites the checkpoint -- confirm this is intended.
    best_training_dice, best_training_loss = float("inf"), float("inf")
    best_validation_loss, best_validation_dice = float("inf"), float("inf")
    begin_time = time.time()

    # EPOCH LOOP
    for epoch in tqdm(range(num_epochs), desc="Training", initial=start_epoch):
        # Shift to a 1-based epoch number (or continue from the checkpoint epoch)
        epoch = epoch + start_epoch
        start_time = time.time()

        lr = scheduler.get_last_lr()[0]
        writer.add_scalar('learning_rate', lr, epoch)
        if wandb_tracking:
            wandb.log({"learning_rate": lr})

        # Training loop -----------------------------------------------------------
        model.train()
        train_loss_total, train_dice_loss_total = 0.0, 0.0
        num_steps = 0
        for i, batch in enumerate(train_loader):
            # GET SAMPLES
            if model_params[ModelParamsKW.NAME] == ConfigKW.HEMIS_UNET:
                input_samples = imed_utils.cuda(
                    imed_utils.unstack_tensors(batch["input"]), cuda_available)
            else:
                input_samples = imed_utils.cuda(batch["input"], cuda_available)
            gt_samples = imed_utils.cuda(batch["gt"],
                                         cuda_available,
                                         non_blocking=True)

            # MIXUP
            if training_params["mixup_alpha"]:
                input_samples, gt_samples = imed_mixup.mixup(
                    input_samples, gt_samples, training_params["mixup_alpha"],
                    debugging and epoch == 1, path_output)

            # RUN MODEL
            if model_params[ModelParamsKW.NAME] == ConfigKW.HEMIS_UNET or \
                    (ModelParamsKW.FILM_LAYERS in model_params and any(model_params[ModelParamsKW.FILM_LAYERS])):
                metadata = get_metadata(batch[MetadataKW.INPUT_METADATA],
                                        model_params)
                preds = model(input_samples, metadata)
            else:
                preds = model(input_samples)

            # LOSS
            loss = loss_fct(preds, gt_samples)
            train_loss_total += loss.item()
            train_dice_loss_total += loss_dice_fct(preds, gt_samples).item()

            # UPDATE OPTIMIZER
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if step_scheduler_batch:
                scheduler.step()
            num_steps += 1

            # Save image at every 50th step if debugging is true
            if i % 50 == 0 and debugging:
                imed_visualize.save_img(
                    writer,
                    epoch,
                    "Train",
                    input_samples,
                    gt_samples,
                    preds,
                    wandb_tracking=wandb_tracking,
                    is_three_dim=not model_params[ModelParamsKW.IS_2D])

        # Schedulers that are not stepped per batch are stepped once per epoch
        if not step_scheduler_batch:
            scheduler.step()

        # TRAINING LOSS
        train_loss_total_avg = train_loss_total / num_steps
        msg = "Epoch {} training loss: {:.4f}.".format(epoch, train_loss_total_avg)
        train_dice_loss_total_avg = train_dice_loss_total / num_steps
        if training_params["loss"]["name"] != "DiceLoss":
            msg += "\tDice training loss: {:.4f}.".format(
                train_dice_loss_total_avg)
        logger.info(msg)
        tqdm.write(msg)

        # CURRICULUM LEARNING
        if model_params[ModelParamsKW.NAME] == ConfigKW.HEMIS_UNET:
            # Increase the probability of a missing modality
            model_params[ModelParamsKW.MISSING_PROBABILITY] **= model_params[
                ModelParamsKW.MISSING_PROBABILITY_GROWTH]
            dataset_train.update(
                p=model_params[ModelParamsKW.MISSING_PROBABILITY])

        # Validation loop -----------------------------------------------------
        model.eval()
        val_loss_total, val_dice_loss_total = 0.0, 0.0
        num_steps = 0
        metric_mgr = imed_metrics.MetricManager(metric_fns)
        # Everything that depends on validation results (summary, checkpointing, early
        # stopping) is nested under this guard; the nesting was reconstructed from
        # whitespace-collapsed source, since it is the only layout in which the guard
        # avoids a division by zero.
        if dataset_val:
            for i, batch in enumerate(val_loader):
                with torch.no_grad():
                    # GET SAMPLES
                    if model_params[ModelParamsKW.NAME] == ConfigKW.HEMIS_UNET:
                        input_samples = imed_utils.cuda(
                            imed_utils.unstack_tensors(batch["input"]),
                            cuda_available)
                    else:
                        input_samples = imed_utils.cuda(
                            batch["input"], cuda_available)
                    gt_samples = imed_utils.cuda(batch["gt"],
                                                 cuda_available,
                                                 non_blocking=True)

                    # RUN MODEL
                    if model_params[ModelParamsKW.NAME] == ConfigKW.HEMIS_UNET or \
                            (ModelParamsKW.FILM_LAYERS in model_params and any(model_params[ModelParamsKW.FILM_LAYERS])):
                        metadata = get_metadata(
                            batch[MetadataKW.INPUT_METADATA], model_params)
                        preds = model(input_samples, metadata)
                    else:
                        preds = model(input_samples)

                    # LOSS
                    loss = loss_fct(preds, gt_samples)
                    val_loss_total += loss.item()
                    val_dice_loss_total += loss_dice_fct(preds, gt_samples).item()

                    # Add frame to GIF: one frame for each tracked (file, slice) pair
                    for i_ in range(len(input_samples)):
                        im = input_samples[i_].cpu().numpy()[0]
                        pr = preds[i_].cpu().numpy()[0]
                        met = batch[MetadataKW.INPUT_METADATA][i_][0]
                        for i_gif in range(n_gif):
                            if gif_dict["image_path"][i_gif] == met['input_filenames'] and \
                                    gif_dict["slice_id"][i_gif] == met['slice_index']:
                                overlap = imed_visualize.overlap_im_seg(im, pr)
                                gif_dict["gif"][i_gif].add(overlap, label=str(epoch))

                num_steps += 1

                # METRICS COMPUTATION
                gt_npy = gt_samples.cpu().numpy()
                preds_npy = preds.data.cpu().numpy()
                metric_mgr(preds_npy, gt_npy)

                # Save image at every 50th step if debugging is true
                if i % 50 == 0 and debugging:
                    imed_visualize.save_img(
                        writer,
                        epoch,
                        "Validation",
                        input_samples,
                        gt_samples,
                        preds,
                        wandb_tracking=wandb_tracking,
                        is_three_dim=not model_params[ModelParamsKW.IS_2D])

            # METRICS COMPUTATION FOR CURRENT EPOCH
            # On a resumed run the previous value comes from the checkpoint.
            val_loss_total_avg_old = val_loss_total_avg if epoch > 1 else None
            metrics_dict = metric_mgr.get_results()
            metric_mgr.reset()
            val_loss_total_avg = val_loss_total / num_steps

            # log losses on Tensorboard by default
            writer.add_scalars('Validation/Metrics', metrics_dict, epoch)
            writer.add_scalars(
                'losses', {
                    'train_loss': train_loss_total_avg,
                    'val_loss': val_loss_total_avg,
                }, epoch)

            # log on wandb if the corresponding dictionary is provided
            if wandb_tracking:
                wandb.log({"validation-metrics": metrics_dict})
                wandb.log({
                    "losses": {
                        'train_loss': train_loss_total_avg,
                        'val_loss': val_loss_total_avg,
                    }
                })
            msg = "Epoch {} validation loss: {:.4f}.".format(
                epoch, val_loss_total_avg)
            val_dice_loss_total_avg = val_dice_loss_total / num_steps
            if training_params["loss"]["name"] != "DiceLoss":
                msg += "\tDice validation loss: {:.4f}.".format(
                    val_dice_loss_total_avg)
            logger.info(msg)
            end_time = time.time()
            total_time = end_time - start_time
            msg_epoch = "Epoch {} took {:.2f} seconds.".format(
                epoch, total_time)
            logger.info(msg_epoch)

            # UPDATE BEST RESULTS
            if val_loss_total_avg < best_validation_loss:
                # Save checkpoint ('epoch' is the epoch to resume FROM, hence + 1)
                state = {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'gif_dict': gif_dict,
                    'scheduler': scheduler,
                    'patience_count': patience_count,
                    'validation_loss': val_loss_total_avg
                }
                torch.save(state, resume_path)

                # Save best model file
                model_path = Path(path_output, "best_model.pt")
                torch.save(model, model_path)

                # Update best scores
                best_validation_loss, best_training_loss = val_loss_total_avg, train_loss_total_avg
                best_validation_dice, best_training_dice = val_dice_loss_total_avg, train_dice_loss_total_avg

            # EARLY STOPPING
            if epoch > 1:
                # Relative improvement of the validation loss, in percent
                val_diff = (val_loss_total_avg_old -
                            val_loss_total_avg) * 100 / abs(val_loss_total_avg)
                if val_diff < training_params["training_time"][
                        "early_stopping_epsilon"]:
                    patience_count += 1
                if patience_count >= training_params["training_time"][
                        "early_stopping_patience"]:
                    logger.info(
                        "Stopping training due to {} epochs without improvements"
                        .format(patience_count))
                    break

    # Save final model
    final_model_path = Path(path_output, "final_model.pt")
    torch.save(model, final_model_path)

    # Save best model in output path
    if resume_path.is_file():
        # Reload the weights of the best epoch before exporting
        state = torch.load(resume_path)
        model_path = Path(path_output, "best_model.pt")
        model.load_state_dict(state['state_dict'])
        torch.save(model, model_path)
        # Save best model as ONNX in the model directory
        try:
            # Convert best model to ONNX and save it in model directory
            best_model_path = Path(
                path_output, model_params[ModelParamsKW.FOLDER_NAME],
                model_params[ModelParamsKW.FOLDER_NAME] + ".onnx")
            imed_utils.save_onnx_model(model, input_samples,
                                       str(best_model_path))
            logger.info(f"Model saved as '.onnx': {best_model_path}")
        except Exception as e:
            # Best effort: the '.pt' copy below is written regardless
            logger.warning(f"Failed to save the model as '.onnx': {e}")

        # Save best model as PT in the model directory
        best_model_path = Path(path_output,
                               model_params[ModelParamsKW.FOLDER_NAME],
                               model_params[ModelParamsKW.FOLDER_NAME] + ".pt")
        torch.save(model, best_model_path)
        logger.info(f"Model saved as '.pt': {best_model_path}")

    # Save GIFs
    gif_folder = Path(path_output, "gifs")
    if n_gif > 0 and not gif_folder.is_dir():
        gif_folder.mkdir(parents=True)
    for i_gif in range(n_gif):
        # Output name: <third-from-last path component>__<file stem suffix>__<slice>.gif
        path_parts = gif_dict["image_path"][i_gif].split(os.sep)
        prefix = path_parts[-3]
        stem = path_parts[-1].split(".nii.gz")[0]
        fname_out = prefix + "__"
        fname_out += stem.split(prefix + "_")[1] + "__"
        fname_out += str(gif_dict["slice_id"][i_gif]) + ".gif"
        path_gif_out = Path(gif_folder, fname_out)
        gif_dict["gif"][i_gif].save(str(path_gif_out))

    writer.close()
    # Only close the wandb run if one was actually started by this function
    if wandb_tracking:
        wandb.finish()
    final_time = time.time()
    duration_time = final_time - begin_time
    logger.info('begin ' +
                time.strftime('%H:%M:%S', time.localtime(begin_time)) +
                "| End " +
                time.strftime('%H:%M:%S', time.localtime(final_time)) +
                "| duration " +
                str(datetime.timedelta(seconds=duration_time)))

    return best_training_dice, best_training_loss, best_validation_dice, best_validation_loss