def test_transfer_learning(download_data_testing_test_files, path_model, fraction, tolerance=0.15):
    device = torch.device("cpu")
    logger.info(f"Working on {'cpu'}.")
    logger.info(__data_testing_dir__)

    # Load pretrained model
    model_pretrained = torch.load(path_model, map_location=device)

    # Setup model for retrain
    model_to_retrain = imed_models.set_model_for_retrain(path_model,
                                                         retrain_fraction=fraction,
                                                         map_location=device)
    logger.info(f"\nSet fraction to retrain: {fraction}")

    # Check frozen part
    grad_list = [param.requires_grad for name, param in model_to_retrain.named_parameters()]
    fraction_retrain_measured = sum(grad_list) * 1.0 / len(grad_list)
    logger.debug(f"\nMeasure: retrained fraction of the model: {round(fraction_retrain_measured, 1)}")
    # for name, param in model.named_parameters():
    #     print("\t", name, param.requires_grad)
    assert abs(fraction_retrain_measured - fraction) <= tolerance

    total_params = sum(p.numel() for p in model_to_retrain.parameters())
    logger.info(f"{total_params} total parameters.")
    total_trainable_params = sum(p.numel() for p in model_to_retrain.parameters() if p.requires_grad)
    logger.info(f"{total_trainable_params} parameters to retrain.")
    assert total_params > total_trainable_params

    # Check reset weights
    reset_list = [(p1.data.ne(p2.data).sum() > 0).cpu().numpy()
                  for p1, p2 in zip(model_pretrained.parameters(), model_to_retrain.parameters())]
    reset_measured = sum(reset_list) * 1.0 / len(reset_list)
    logger.info(f"\nMeasure: reset fraction of the model: {round(reset_measured, 1)}")
    assert abs(reset_measured - fraction) <= tolerance

def test_transfer_learning(path_model, fraction, tolerance=0.15):
    device = torch.device("cpu")
    print("Working on {}.".format('cpu'))

    # Load pretrained model
    model_pretrained = torch.load(path_model, map_location=device)

    # Setup model for retrain
    model_to_retrain = imed_models.set_model_for_retrain(path_model,
                                                         retrain_fraction=fraction,
                                                         map_location=device)
    print('\nSet fraction to retrain: ' + str(fraction))

    # Check frozen part
    grad_list = [param.requires_grad for name, param in model_to_retrain.named_parameters()]
    fraction_retrain_measured = sum(grad_list) * 1.0 / len(grad_list)
    print('\nMeasure: retrained fraction of the model: ' + str(round(fraction_retrain_measured, 1)))
    # for name, param in model.named_parameters():
    #     print("\t", name, param.requires_grad)
    assert abs(fraction_retrain_measured - fraction) <= tolerance

    total_params = sum(p.numel() for p in model_to_retrain.parameters())
    print('{:,} total parameters.'.format(total_params))
    total_trainable_params = sum(p.numel() for p in model_to_retrain.parameters() if p.requires_grad)
    print('{:,} parameters to retrain.'.format(total_trainable_params))
    assert total_params > total_trainable_params

    # Check reset weights
    reset_list = [(p1.data.ne(p2.data).sum() > 0).cpu().numpy()
                  for p1, p2 in zip(model_pretrained.parameters(), model_to_retrain.parameters())]
    reset_measured = sum(reset_list) * 1.0 / len(reset_list)
    print('\nMeasure: reset fraction of the model: ' + str(round(reset_measured, 1)))
    assert abs(reset_measured - fraction) <= tolerance

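# --- Illustrative parametrization (not part of the original test suite) ---
# A minimal sketch of how the test above could be driven by pytest: the retrain fraction
# is parametrized and the model path points at a pretrained checkpoint. Both values
# below are hypothetical placeholders, not files shipped with the repository.
import pytest


@pytest.mark.parametrize("fraction", [0.1, 0.3, 0.5])
@pytest.mark.parametrize("path_model", ["path/to/pretrained_model.pt"])  # hypothetical path
def example_transfer_learning_parametrized(path_model, fraction):
    # Delegates to the fixture-free variant above with the default tolerance of 0.15.
    test_transfer_learning(path_model, fraction, tolerance=0.15)
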
def train(model_params, dataset_train, dataset_val, training_params, log_directory, device,
          cuda_available=True, metric_fns=None, n_gif=0, resume_training=False, debugging=False):
    """Main command to train the network.

    Args:
        model_params (dict): Model's parameters.
        dataset_train (imed_loader): Training dataset.
        dataset_val (imed_loader): Validation dataset.
        training_params (dict): Training parameters.
        log_directory (str): Folder where log files, best and final models are saved.
        device (str): Indicates the CPU or GPU ID.
        cuda_available (bool): If True, CUDA is available.
        metric_fns (list): List of metrics, see :mod:`ivadomed.metrics`.
        n_gif (int): Generates a GIF during training if larger than zero, one frame per epoch for a given slice.
            The parameter indicates the number of 2D slices used to generate GIFs, one GIF per slice. A GIF shows
            predictions of a given slice from the validation sub-dataset. They are saved within the log directory.
        resume_training (bool): Load a saved model ("checkpoint.pth.tar" in the log_directory) to resume training.
            This training state is saved every time a new best model is saved in the log directory.
        debugging (bool): If True, extended verbosity and intermediate outputs.

    Returns:
        float, float, float, float: best_training_dice, best_training_loss, best_validation_dice,
        best_validation_loss.
    """
    # Write the metrics, images, etc to TensorBoard format
    writer = SummaryWriter(log_dir=log_directory)

    # BALANCE SAMPLES AND PYTORCH LOADER
    conditions = all([training_params["balance_samples"]["applied"], model_params["name"] != "HeMIS"])
    sampler_train, shuffle_train = get_sampler(dataset_train, conditions,
                                               training_params['balance_samples']['type'])

    train_loader = DataLoader(dataset_train, batch_size=training_params["batch_size"],
                              shuffle=shuffle_train, pin_memory=True, sampler=sampler_train,
                              collate_fn=imed_loader_utils.imed_collate,
                              num_workers=0)

    gif_dict = {"image_path": [], "slice_id": [], "gif": []}
    if dataset_val:
        sampler_val, shuffle_val = get_sampler(dataset_val, conditions,
                                               training_params['balance_samples']['type'])

        val_loader = DataLoader(dataset_val, batch_size=training_params["batch_size"],
                                shuffle=shuffle_val, pin_memory=True, sampler=sampler_val,
                                collate_fn=imed_loader_utils.imed_collate,
                                num_workers=0)

        # Init GIF
        if n_gif > 0:
            indexes_gif = random.sample(range(len(dataset_val)), n_gif)
            for i_gif in range(n_gif):
                random_metadata = dict(dataset_val[indexes_gif[i_gif]]["input_metadata"][0])
                gif_dict["image_path"].append(random_metadata['input_filenames'])
                gif_dict["slice_id"].append(random_metadata['slice_index'])
                gif_obj = imed_visualize.AnimatedGif(
                    size=dataset_val[indexes_gif[i_gif]]["input"].numpy()[0].shape)
                gif_dict["gif"].append(copy.copy(gif_obj))

    # GET MODEL
    if training_params["transfer_learning"]["retrain_model"]:
        print("\nLoading pretrained model's weights: {}.".format(
            training_params["transfer_learning"]["retrain_model"]))
        print("\tFreezing the {}% first layers.".format(
            100 - training_params["transfer_learning"]['retrain_fraction'] * 100.))
        old_model_path = training_params["transfer_learning"]["retrain_model"]
        fraction = training_params["transfer_learning"]['retrain_fraction']
        if 'reset' in training_params["transfer_learning"]:
            reset = training_params["transfer_learning"]['reset']
        else:
            reset = True
        # Freeze first layers and reset last layers
        model = imed_models.set_model_for_retrain(old_model_path, retrain_fraction=fraction,
                                                  map_location=device, reset=reset)
    else:
        print("\nInitialising model's weights from scratch.")
        model_class = getattr(imed_models, model_params["name"])
        model = model_class(**model_params)
    if cuda_available:
        model.cuda()

    num_epochs = training_params["training_time"]["num_epochs"]

    # OPTIMIZER
    initial_lr = training_params["scheduler"]["initial_lr"]
    # Keep only the parameters that are going to be fine-tuned
    params_to_opt = filter(lambda p: p.requires_grad, model.parameters())
    # Using Adam
    optimizer = optim.Adam(params_to_opt, lr=initial_lr)
    scheduler, step_scheduler_batch = get_scheduler(copy.copy(training_params["scheduler"]["lr_scheduler"]),
                                                    optimizer, num_epochs)
    print("\nScheduler parameters: {}".format(training_params["scheduler"]["lr_scheduler"]))

    # Create dict containing gammas and betas after each FiLM layer.
    if 'film_layers' in model_params and any(model_params['film_layers']):
        gammas_dict = {i: [] for i in range(1, 2 * model_params["depth"] + 3)}
        betas_dict = {i: [] for i in range(1, 2 * model_params["depth"] + 3)}
        contrast_list = []

    # Resume
    start_epoch = 1
    resume_path = os.path.join(log_directory, "checkpoint.pth.tar")
    if resume_training:
        model, optimizer, gif_dict, start_epoch, val_loss_total_avg, scheduler, patience_count = load_checkpoint(
            model=model,
            optimizer=optimizer,
            gif_dict=gif_dict,
            scheduler=scheduler,
            fname=resume_path)
        # Individually transfer the optimizer parts
        # TODO: check if following lines are needed
        for state in optimizer.state.values():
            for k, v in state.items():
                if torch.is_tensor(v):
                    state[k] = v.to(device)

    # LOSS
    print("\nSelected Loss: {}".format(training_params["loss"]["name"]))
    print("\twith the parameters: {}".format(
        [training_params["loss"][k] for k in training_params["loss"] if k != "name"]))
    loss_fct = get_loss_function(copy.copy(training_params["loss"]))
    loss_dice_fct = imed_losses.DiceLoss()  # For comparison when another loss is used

    # INIT TRAINING VARIABLES
    best_training_dice, best_training_loss = float("inf"), float("inf")
    best_validation_loss, best_validation_dice = float("inf"), float("inf")
    patience_count = 0
    begin_time = time.time()

    # EPOCH LOOP
    for epoch in tqdm(range(num_epochs), desc="Training", initial=start_epoch):
        epoch = epoch + start_epoch
        start_time = time.time()

        lr = scheduler.get_last_lr()[0]
        writer.add_scalar('learning_rate', lr, epoch)

        # Training loop -----------------------------------------------------------
        model.train()
        train_loss_total, train_dice_loss_total = 0.0, 0.0
        num_steps = 0
        for i, batch in enumerate(train_loader):
            # GET SAMPLES
            if model_params["name"] == "HeMISUnet":
                input_samples = imed_utils.cuda(imed_utils.unstack_tensors(batch["input"]), cuda_available)
            else:
                input_samples = imed_utils.cuda(batch["input"], cuda_available)
            gt_samples = imed_utils.cuda(batch["gt"], cuda_available, non_blocking=True)

            # MIXUP
            if training_params["mixup_alpha"]:
                input_samples, gt_samples = imed_mixup.mixup(input_samples, gt_samples,
                                                             training_params["mixup_alpha"],
                                                             debugging and epoch == 1, log_directory)

            # RUN MODEL
            if model_params["name"] == "HeMISUnet" or \
                    ('film_layers' in model_params and any(model_params['film_layers'])):
                metadata = get_metadata(batch["input_metadata"], model_params)
                preds = model(input_samples, metadata)
            else:
                preds = model(input_samples)

            # LOSS
            loss = loss_fct(preds, gt_samples)
            train_loss_total += loss.item()
            train_dice_loss_total += loss_dice_fct(preds, gt_samples).item()

            # UPDATE OPTIMIZER
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if step_scheduler_batch:
                scheduler.step()
            num_steps += 1

            if i == 0 and debugging:
                imed_visualize.save_tensorboard_img(writer, epoch, "Train", input_samples, gt_samples, preds,
                                                    is_three_dim=not model_params["is_2d"])

        if not step_scheduler_batch:
            scheduler.step()

        # TRAINING LOSS
        train_loss_total_avg = train_loss_total / num_steps
        msg = "Epoch {} training loss: {:.4f}.".format(epoch, train_loss_total_avg)
        train_dice_loss_total_avg = train_dice_loss_total / num_steps
        if training_params["loss"]["name"] != "DiceLoss":
            msg += "\tDice training loss: {:.4f}.".format(train_dice_loss_total_avg)
        tqdm.write(msg)

        # CURRICULUM LEARNING
        if model_params["name"] == "HeMISUnet":
            # Increase the probability of a missing modality
            model_params["missing_probability"] **= model_params["missing_probability_growth"]
            dataset_train.update(p=model_params["missing_probability"])

        # Validation loop -----------------------------------------------------
        model.eval()
        val_loss_total, val_dice_loss_total = 0.0, 0.0
        num_steps = 0
        metric_mgr = imed_metrics.MetricManager(metric_fns)
        if dataset_val:
            for i, batch in enumerate(val_loader):
                with torch.no_grad():
                    # GET SAMPLES
                    if model_params["name"] == "HeMISUnet":
                        input_samples = imed_utils.cuda(imed_utils.unstack_tensors(batch["input"]),
                                                        cuda_available)
                    else:
                        input_samples = imed_utils.cuda(batch["input"], cuda_available)
                    gt_samples = imed_utils.cuda(batch["gt"], cuda_available, non_blocking=True)

                    # RUN MODEL
                    if model_params["name"] == "HeMISUnet" or \
                            ('film_layers' in model_params and any(model_params['film_layers'])):
                        metadata = get_metadata(batch["input_metadata"], model_params)
                        preds = model(input_samples, metadata)
                    else:
                        preds = model(input_samples)

                    # LOSS
                    loss = loss_fct(preds, gt_samples)
                    val_loss_total += loss.item()
                    val_dice_loss_total += loss_dice_fct(preds, gt_samples).item()

                    # Add frame to GIF
                    for i_ in range(len(input_samples)):
                        im, pr, met = input_samples[i_].cpu().numpy()[0], preds[i_].cpu().numpy()[0], \
                                      batch["input_metadata"][i_][0]
                        for i_gif in range(n_gif):
                            if gif_dict["image_path"][i_gif] == met.__getitem__('input_filenames') and \
                                    gif_dict["slice_id"][i_gif] == met.__getitem__('slice_index'):
                                overlap = imed_visualize.overlap_im_seg(im, pr)
                                gif_dict["gif"][i_gif].add(overlap, label=str(epoch))

                num_steps += 1

                # METRICS COMPUTATION
                gt_npy = gt_samples.cpu().numpy()
                preds_npy = preds.data.cpu().numpy()
                metric_mgr(preds_npy, gt_npy)

                if i == 0 and debugging:
                    imed_visualize.save_tensorboard_img(writer, epoch, "Validation", input_samples, gt_samples,
                                                        preds, is_three_dim=not model_params['is_2d'])

                if 'film_layers' in model_params and any(model_params['film_layers']) and debugging and \
                        epoch == num_epochs and i < int(len(dataset_val) / training_params["batch_size"]) + 1:
                    # Store the values of gammas and betas after the last epoch for each batch
                    gammas_dict, betas_dict, contrast_list = store_film_params(gammas_dict, betas_dict,
                                                                               contrast_list,
                                                                               batch['input_metadata'], model,
                                                                               model_params["film_layers"],
                                                                               model_params["depth"])

            # METRICS COMPUTATION FOR CURRENT EPOCH
            val_loss_total_avg_old = val_loss_total_avg if epoch > 1 else None
            metrics_dict = metric_mgr.get_results()
            metric_mgr.reset()
            writer.add_scalars('Validation/Metrics', metrics_dict, epoch)
            val_loss_total_avg = val_loss_total / num_steps
            writer.add_scalars('losses', {
                'train_loss': train_loss_total_avg,
                'val_loss': val_loss_total_avg,
            }, epoch)
            msg = "Epoch {} validation loss: {:.4f}.".format(epoch, val_loss_total_avg)
            val_dice_loss_total_avg = val_dice_loss_total / num_steps
            if training_params["loss"]["name"] != "DiceLoss":
                msg += "\tDice validation loss: {:.4f}.".format(val_dice_loss_total_avg)
            tqdm.write(msg)
            end_time = time.time()
            total_time = end_time - start_time
            tqdm.write("Epoch {} took {:.2f} seconds.".format(epoch, total_time))

            # UPDATE BEST RESULTS
            if val_loss_total_avg < best_validation_loss:
                # Save checkpoint
                state = {'epoch': epoch + 1,
                         'state_dict': model.state_dict(),
                         'optimizer': optimizer.state_dict(),
                         'gif_dict': gif_dict,
                         'scheduler': scheduler,
                         'patience_count': patience_count,
                         'validation_loss': val_loss_total_avg}
                torch.save(state, resume_path)

                # Save best model file
                model_path = os.path.join(log_directory, "best_model.pt")
                torch.save(model, model_path)

                # Update best scores
                best_validation_loss, best_training_loss = val_loss_total_avg, train_loss_total_avg
                best_validation_dice, best_training_dice = val_dice_loss_total_avg, train_dice_loss_total_avg

            # EARLY STOPPING
            if epoch > 1:
                val_diff = (val_loss_total_avg_old - val_loss_total_avg) * 100 / abs(val_loss_total_avg)
                if val_diff < training_params["training_time"]["early_stopping_epsilon"]:
                    patience_count += 1
                if patience_count >= training_params["training_time"]["early_stopping_patience"]:
                    print("Stopping training due to {} epochs without improvements".format(patience_count))
                    break

    # Save final model
    final_model_path = os.path.join(log_directory, "final_model.pt")
    torch.save(model, final_model_path)
    if 'film_layers' in model_params and any(model_params['film_layers']) and debugging:
        save_film_params(gammas_dict, betas_dict, contrast_list, model_params["depth"], log_directory)

    # Save best model in log directory
    if os.path.isfile(resume_path):
        state = torch.load(resume_path)
        model_path = os.path.join(log_directory, "best_model.pt")
        model.load_state_dict(state['state_dict'])
        torch.save(model, model_path)

    # Save best model as ONNX in the model directory
    try:
        # Convert best model to ONNX and save it in model directory
        best_model_path = os.path.join(log_directory, model_params["folder_name"],
                                       model_params["folder_name"] + ".onnx")
        imed_utils.save_onnx_model(model, input_samples, best_model_path)
    except Exception:
        # Save best model in model directory
        best_model_path = os.path.join(log_directory, model_params["folder_name"],
                                       model_params["folder_name"] + ".pt")
        torch.save(model, best_model_path)
        logger.warning("Failed to save the model as '.onnx', saved it as '.pt': {}".format(best_model_path))

    # Save GIFs
    gif_folder = os.path.join(log_directory, "gifs")
    if n_gif > 0 and not os.path.isdir(gif_folder):
        os.makedirs(gif_folder)
    for i_gif in range(n_gif):
        fname_out = gif_dict["image_path"][i_gif].split('/')[-3] + "__"
        fname_out += gif_dict["image_path"][i_gif].split('/')[-1].split(".nii.gz")[0].split(
            gif_dict["image_path"][i_gif].split('/')[-3] + "_")[1] + "__"
        fname_out += str(gif_dict["slice_id"][i_gif]) + ".gif"
        path_gif_out = os.path.join(gif_folder, fname_out)
        gif_dict["gif"][i_gif].save(path_gif_out)

    writer.close()
    final_time = time.time()
    duration_time = final_time - begin_time
    print('begin ' + time.strftime('%H:%M:%S', time.localtime(begin_time)) +
          "| End " + time.strftime('%H:%M:%S', time.localtime(final_time)) +
          "| duration " + str(datetime.timedelta(seconds=duration_time)))

    return best_training_dice, best_training_loss, best_validation_dice, best_validation_loss

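# --- Illustrative helper (not part of the original module) ---
# The early-stopping criterion above counts an epoch as "no improvement" when the relative
# decrease of the validation loss, expressed in percent, falls below early_stopping_epsilon.
# A hypothetical standalone version of that check, for clarity:
def _relative_improvement_pct(previous_loss, current_loss):
    """Percent improvement of the validation loss relative to its current value."""
    return (previous_loss - current_loss) * 100 / abs(current_loss)

# Example: a drop from 0.050 to 0.0499 is about a 0.2% improvement; with
# "early_stopping_epsilon": 0.5 that epoch would increment patience_count.
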
def train(model_params, dataset_train, dataset_val, training_params, path_output, device, wandb_params=None,
          cuda_available=True, metric_fns=None, n_gif=0, resume_training=False, debugging=False):
    """Main command to train the network.

    Args:
        model_params (dict): Model's parameters.
        dataset_train (imed_loader): Training dataset.
        dataset_val (imed_loader): Validation dataset.
        training_params (dict): Training parameters.
        path_output (str): Folder where log files, best and final models are saved.
        device (str): Indicates the CPU or GPU ID.
        wandb_params (dict): Parameters for Weights & Biases (wandb) experiment tracking.
        cuda_available (bool): If True, CUDA is available.
        metric_fns (list): List of metrics, see :mod:`ivadomed.metrics`.
        n_gif (int): Generates a GIF during training if larger than zero, one frame per epoch for a given slice.
            The parameter indicates the number of 2D slices used to generate GIFs, one GIF per slice. A GIF shows
            predictions of a given slice from the validation sub-dataset. They are saved within the output path.
        resume_training (bool): Load a saved model ("checkpoint.pth.tar" in the path_output) to resume training.
            This training state is saved every time a new best model is saved in the output path.
        debugging (bool): If True, extended verbosity and intermediate outputs.

    Returns:
        float, float, float, float: best_training_dice, best_training_loss, best_validation_dice,
        best_validation_loss.
    """
    # Write the metrics, images, etc to TensorBoard format
    writer = SummaryWriter(log_dir=path_output)

    # Enable wandb tracking if the required params are found in the config file and the api key is correct
    wandb_tracking = imed_utils.initialize_wandb(wandb_params)
    if wandb_tracking:
        # Collect all hyperparameters into a dictionary
        cfg = {**training_params, **model_params}
        # Get the actual project, group, and run names if they exist, else choose the temporary names as default
        project_name = wandb_params.get(WandbKW.PROJECT_NAME, "temp_project")
        group_name = wandb_params.get(WandbKW.GROUP_NAME, "temp_group")
        run_name = wandb_params.get(WandbKW.RUN_NAME, "temp_run")
        if project_name == "temp_project" or group_name == "temp_group" or run_name == "temp_run":
            logger.info("{PROJECT/GROUP/RUN} name not found, initializing as {'temp_project'/'temp_group'/'temp_run'}")
        # Initialize WandB with metrics and hyperparameters
        wandb.init(project=project_name, group=group_name, name=run_name, config=cfg)

    # BALANCE SAMPLES AND PYTORCH LOADER
    conditions = all([training_params[TrainingParamsKW.BALANCE_SAMPLES][BalanceSamplesKW.APPLIED],
                      model_params[ModelParamsKW.NAME] != "HeMIS"])
    sampler_train, shuffle_train = get_sampler(
        dataset_train, conditions,
        training_params[TrainingParamsKW.BALANCE_SAMPLES][BalanceSamplesKW.TYPE])

    train_loader = DataLoader(dataset_train,
                              batch_size=training_params[TrainingParamsKW.BATCH_SIZE],
                              shuffle=shuffle_train, pin_memory=True, sampler=sampler_train,
                              collate_fn=imed_loader_utils.imed_collate,
                              num_workers=0)

    gif_dict = {"image_path": [], "slice_id": [], "gif": []}
    if dataset_val:
        sampler_val, shuffle_val = get_sampler(
            dataset_val, conditions,
            training_params[TrainingParamsKW.BALANCE_SAMPLES][BalanceSamplesKW.TYPE])

        val_loader = DataLoader(dataset_val,
                                batch_size=training_params[TrainingParamsKW.BATCH_SIZE],
                                shuffle=shuffle_val, pin_memory=True, sampler=sampler_val,
                                collate_fn=imed_loader_utils.imed_collate,
                                num_workers=0)

        # Init GIF
        if n_gif > 0:
            indexes_gif = random.sample(range(len(dataset_val)), n_gif)
            for i_gif in range(n_gif):
                random_metadata = dict(dataset_val[indexes_gif[i_gif]][MetadataKW.INPUT_METADATA][0])
                gif_dict["image_path"].append(random_metadata[MetadataKW.INPUT_FILENAMES])
                gif_dict["slice_id"].append(random_metadata[MetadataKW.SLICE_INDEX])
                gif_obj = imed_visualize.AnimatedGif(
                    size=dataset_val[indexes_gif[i_gif]]["input"].numpy()[0].shape)
                gif_dict["gif"].append(copy.copy(gif_obj))

    # GET MODEL
    if training_params["transfer_learning"]["retrain_model"]:
        logger.info("Loading pretrained model's weights: {}.".format(
            training_params["transfer_learning"]["retrain_model"]))
        logger.info("\tFreezing the {}% first layers.".format(
            100 - training_params["transfer_learning"]['retrain_fraction'] * 100.))
        old_model_path = training_params["transfer_learning"]["retrain_model"]
        fraction = training_params["transfer_learning"]['retrain_fraction']
        if 'reset' in training_params["transfer_learning"]:
            reset = training_params["transfer_learning"]['reset']
        else:
            reset = True
        # Freeze first layers and reset last layers
        model = imed_models.set_model_for_retrain(old_model_path, retrain_fraction=fraction,
                                                  map_location=device, reset=reset)
    else:
        logger.info("Initialising model's weights from scratch.")
        model_class = getattr(imed_models, model_params[ModelParamsKW.NAME])
        model = model_class(**model_params)
    if cuda_available:
        model.cuda()

    num_epochs = training_params["training_time"]["num_epochs"]

    # OPTIMIZER
    initial_lr = training_params["scheduler"]["initial_lr"]
    # Keep only the parameters that are going to be fine-tuned
    params_to_opt = filter(lambda p: p.requires_grad, model.parameters())
    # Using Adam
    optimizer = optim.Adam(params_to_opt, lr=initial_lr)
    scheduler, step_scheduler_batch = get_scheduler(copy.copy(training_params["scheduler"]["lr_scheduler"]),
                                                    optimizer, num_epochs)
    logger.info("Scheduler parameters: {}".format(training_params["scheduler"]["lr_scheduler"]))

    # Only call wandb methods if required params are found in the config file
    if wandb_tracking:
        # Logs gradients (at every log_freq steps) to the dashboard.
        wandb.watch(model, log="gradients", log_freq=wandb_params["log_grads_every"])

    # Resume
    start_epoch = 1
    resume_path = Path(path_output, "checkpoint.pth.tar")
    if resume_training:
        model, optimizer, gif_dict, start_epoch, val_loss_total_avg, scheduler, patience_count = load_checkpoint(
            model=model,
            optimizer=optimizer,
            gif_dict=gif_dict,
            scheduler=scheduler,
            fname=str(resume_path))
        # Individually transfer the optimizer parts
        # TODO: check if following lines are needed
        for state in optimizer.state.values():
            for k, v in state.items():
                if torch.is_tensor(v):
                    state[k] = v.to(device)

    # LOSS
    logger.info("Selected Loss: {}".format(training_params["loss"]["name"]))
    logger.info("\twith the parameters: {}".format(
        [training_params["loss"][k] for k in training_params["loss"] if k != "name"]))
    loss_fct = get_loss_function(copy.copy(training_params["loss"]))
    loss_dice_fct = imed_losses.DiceLoss()  # For comparison when another loss is used

    # INIT TRAINING VARIABLES
    best_training_dice, best_training_loss = float("inf"), float("inf")
    best_validation_loss, best_validation_dice = float("inf"), float("inf")
    patience_count = 0
    begin_time = time.time()

    # EPOCH LOOP
    for epoch in tqdm(range(num_epochs), desc="Training", initial=start_epoch):
        epoch = epoch + start_epoch
        start_time = time.time()

        lr = scheduler.get_last_lr()[0]
        writer.add_scalar('learning_rate', lr, epoch)
        if wandb_tracking:
            wandb.log({"learning_rate": lr})

        # Training loop -----------------------------------------------------------
        model.train()
        train_loss_total, train_dice_loss_total = 0.0, 0.0
        num_steps = 0
        for i, batch in enumerate(train_loader):
            # GET SAMPLES
            if model_params[ModelParamsKW.NAME] == ConfigKW.HEMIS_UNET:
                input_samples = imed_utils.cuda(imed_utils.unstack_tensors(batch["input"]), cuda_available)
            else:
                input_samples = imed_utils.cuda(batch["input"], cuda_available)
            gt_samples = imed_utils.cuda(batch["gt"], cuda_available, non_blocking=True)

            # MIXUP
            if training_params["mixup_alpha"]:
                input_samples, gt_samples = imed_mixup.mixup(input_samples, gt_samples,
                                                             training_params["mixup_alpha"],
                                                             debugging and epoch == 1, path_output)

            # RUN MODEL
            if model_params[ModelParamsKW.NAME] == ConfigKW.HEMIS_UNET or \
                    (ModelParamsKW.FILM_LAYERS in model_params and any(model_params[ModelParamsKW.FILM_LAYERS])):
                metadata = get_metadata(batch[MetadataKW.INPUT_METADATA], model_params)
                preds = model(input_samples, metadata)
            else:
                preds = model(input_samples)

            # LOSS
            loss = loss_fct(preds, gt_samples)
            train_loss_total += loss.item()
            train_dice_loss_total += loss_dice_fct(preds, gt_samples).item()

            # UPDATE OPTIMIZER
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if step_scheduler_batch:
                scheduler.step()
            num_steps += 1

            # Save image at every 50th step if debugging is true
            if i % 50 == 0 and debugging:
                imed_visualize.save_img(writer, epoch, "Train", input_samples, gt_samples, preds,
                                        wandb_tracking=wandb_tracking,
                                        is_three_dim=not model_params[ModelParamsKW.IS_2D])

        if not step_scheduler_batch:
            scheduler.step()

        # TRAINING LOSS
        train_loss_total_avg = train_loss_total / num_steps
        msg = "Epoch {} training loss: {:.4f}.".format(epoch, train_loss_total_avg)
        train_dice_loss_total_avg = train_dice_loss_total / num_steps
        if training_params["loss"]["name"] != "DiceLoss":
            msg += "\tDice training loss: {:.4f}.".format(train_dice_loss_total_avg)
        logger.info(msg)
        tqdm.write(msg)

        # CURRICULUM LEARNING
        if model_params[ModelParamsKW.NAME] == ConfigKW.HEMIS_UNET:
            # Increase the probability of a missing modality
            model_params[ModelParamsKW.MISSING_PROBABILITY] **= model_params[
                ModelParamsKW.MISSING_PROBABILITY_GROWTH]
            dataset_train.update(p=model_params[ModelParamsKW.MISSING_PROBABILITY])

        # Validation loop -----------------------------------------------------
        model.eval()
        val_loss_total, val_dice_loss_total = 0.0, 0.0
        num_steps = 0
        metric_mgr = imed_metrics.MetricManager(metric_fns)
        if dataset_val:
            for i, batch in enumerate(val_loader):
                with torch.no_grad():
                    # GET SAMPLES
                    if model_params[ModelParamsKW.NAME] == ConfigKW.HEMIS_UNET:
                        input_samples = imed_utils.cuda(imed_utils.unstack_tensors(batch["input"]),
                                                        cuda_available)
                    else:
                        input_samples = imed_utils.cuda(batch["input"], cuda_available)
                    gt_samples = imed_utils.cuda(batch["gt"], cuda_available, non_blocking=True)

                    # RUN MODEL
                    if model_params[ModelParamsKW.NAME] == ConfigKW.HEMIS_UNET or \
                            (ModelParamsKW.FILM_LAYERS in model_params and
                             any(model_params[ModelParamsKW.FILM_LAYERS])):
                        metadata = get_metadata(batch[MetadataKW.INPUT_METADATA], model_params)
                        preds = model(input_samples, metadata)
                    else:
                        preds = model(input_samples)

                    # LOSS
                    loss = loss_fct(preds, gt_samples)
                    val_loss_total += loss.item()
                    val_dice_loss_total += loss_dice_fct(preds, gt_samples).item()

                    # Add frame to GIF
                    for i_ in range(len(input_samples)):
                        im, pr, met = input_samples[i_].cpu().numpy()[0], preds[i_].cpu().numpy()[0], \
                                      batch[MetadataKW.INPUT_METADATA][i_][0]
                        for i_gif in range(n_gif):
                            if gif_dict["image_path"][i_gif] == met.__getitem__('input_filenames') and \
                                    gif_dict["slice_id"][i_gif] == met.__getitem__('slice_index'):
                                overlap = imed_visualize.overlap_im_seg(im, pr)
                                gif_dict["gif"][i_gif].add(overlap, label=str(epoch))

                num_steps += 1

                # METRICS COMPUTATION
                gt_npy = gt_samples.cpu().numpy()
                preds_npy = preds.data.cpu().numpy()
                metric_mgr(preds_npy, gt_npy)

                # Save image at every 50th step if debugging is true
                if i % 50 == 0 and debugging:
                    imed_visualize.save_img(writer, epoch, "Validation", input_samples, gt_samples, preds,
                                            wandb_tracking=wandb_tracking,
                                            is_three_dim=not model_params[ModelParamsKW.IS_2D])

            # METRICS COMPUTATION FOR CURRENT EPOCH
            val_loss_total_avg_old = val_loss_total_avg if epoch > 1 else None
            metrics_dict = metric_mgr.get_results()
            metric_mgr.reset()
            val_loss_total_avg = val_loss_total / num_steps
            # log losses on Tensorboard by default
            writer.add_scalars('Validation/Metrics', metrics_dict, epoch)
            writer.add_scalars('losses', {
                'train_loss': train_loss_total_avg,
                'val_loss': val_loss_total_avg,
            }, epoch)
            # log on wandb if the corresponding dictionary is provided
            if wandb_tracking:
                wandb.log({"validation-metrics": metrics_dict})
                wandb.log({"losses": {
                    'train_loss': train_loss_total_avg,
                    'val_loss': val_loss_total_avg,
                }})
            msg = "Epoch {} validation loss: {:.4f}.".format(epoch, val_loss_total_avg)
            val_dice_loss_total_avg = val_dice_loss_total / num_steps
            if training_params["loss"]["name"] != "DiceLoss":
                msg += "\tDice validation loss: {:.4f}.".format(val_dice_loss_total_avg)
            logger.info(msg)
            end_time = time.time()
            total_time = end_time - start_time
            msg_epoch = "Epoch {} took {:.2f} seconds.".format(epoch, total_time)
            logger.info(msg_epoch)

            # UPDATE BEST RESULTS
            if val_loss_total_avg < best_validation_loss:
                # Save checkpoint
                state = {'epoch': epoch + 1,
                         'state_dict': model.state_dict(),
                         'optimizer': optimizer.state_dict(),
                         'gif_dict': gif_dict,
                         'scheduler': scheduler,
                         'patience_count': patience_count,
                         'validation_loss': val_loss_total_avg}
                torch.save(state, resume_path)

                # Save best model file
                model_path = Path(path_output, "best_model.pt")
                torch.save(model, model_path)

                # Update best scores
                best_validation_loss, best_training_loss = val_loss_total_avg, train_loss_total_avg
                best_validation_dice, best_training_dice = val_dice_loss_total_avg, train_dice_loss_total_avg

            # EARLY STOPPING
            if epoch > 1:
                val_diff = (val_loss_total_avg_old - val_loss_total_avg) * 100 / abs(val_loss_total_avg)
                if val_diff < training_params["training_time"]["early_stopping_epsilon"]:
                    patience_count += 1
                if patience_count >= training_params["training_time"]["early_stopping_patience"]:
                    logger.info("Stopping training due to {} epochs without improvements".format(patience_count))
                    break

    # Save final model
    final_model_path = Path(path_output, "final_model.pt")
    torch.save(model, final_model_path)

    # Save best model in output path
    if resume_path.is_file():
        state = torch.load(resume_path)
        model_path = Path(path_output, "best_model.pt")
        model.load_state_dict(state['state_dict'])
        torch.save(model, model_path)

    # Save best model as ONNX in the model directory
    try:
        # Convert best model to ONNX and save it in model directory
        best_model_path = Path(path_output, model_params[ModelParamsKW.FOLDER_NAME],
                               model_params[ModelParamsKW.FOLDER_NAME] + ".onnx")
        imed_utils.save_onnx_model(model, input_samples, str(best_model_path))
        logger.info(f"Model saved as '.onnx': {best_model_path}")
    except Exception as e:
        logger.warning(f"Failed to save the model as '.onnx': {e}")

    # Save best model as PT in the model directory
    best_model_path = Path(path_output, model_params[ModelParamsKW.FOLDER_NAME],
                           model_params[ModelParamsKW.FOLDER_NAME] + ".pt")
    torch.save(model, best_model_path)
    logger.info(f"Model saved as '.pt': {best_model_path}")

    # Save GIFs
    gif_folder = Path(path_output, "gifs")
    if n_gif > 0 and not gif_folder.is_dir():
        gif_folder.mkdir(parents=True)
    for i_gif in range(n_gif):
        fname_out = gif_dict["image_path"][i_gif].split(os.sep)[-3] + "__"
        fname_out += gif_dict["image_path"][i_gif].split(os.sep)[-1].split(".nii.gz")[0].split(
            gif_dict["image_path"][i_gif].split(os.sep)[-3] + "_")[1] + "__"
        fname_out += str(gif_dict["slice_id"][i_gif]) + ".gif"
        path_gif_out = Path(gif_folder, fname_out)
        gif_dict["gif"][i_gif].save(str(path_gif_out))

    writer.close()
    wandb.finish()
    final_time = time.time()
    duration_time = final_time - begin_time
    logger.info('begin ' + time.strftime('%H:%M:%S', time.localtime(begin_time)) +
                "| End " + time.strftime('%H:%M:%S', time.localtime(final_time)) +
                "| duration " + str(datetime.timedelta(seconds=duration_time)))

    return best_training_dice, best_training_loss, best_validation_dice, best_validation_loss

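# --- Illustrative usage (hypothetical; not part of the original module) ---
# A minimal sketch of how a caller might drive train() and consume its return values.
# Every argument below is an assumption: the datasets and parameter dictionaries are
# normally built by ivadomed from its JSON configuration file, the output folder and
# metric list are placeholders, and wandb_params=None is assumed to disable tracking.
def _example_train_call(model_params, training_params, ds_train, ds_val, device):
    best_train_dice, best_train_loss, best_val_dice, best_val_loss = train(
        model_params=model_params,
        dataset_train=ds_train,
        dataset_val=ds_val,
        training_params=training_params,
        path_output="results/my_experiment",      # hypothetical output folder
        device=device,
        wandb_params=None,                        # assumed to skip wandb tracking
        cuda_available=torch.cuda.is_available(),
        metric_fns=[imed_metrics.dice_score],     # assumed helper from ivadomed.metrics
        n_gif=0,
        resume_training=False,
        debugging=False)
    logger.info(f"Best validation Dice loss: {best_val_dice:.4f}")
    return best_train_dice, best_train_loss, best_val_dice, best_val_loss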