Esempio n. 1
0
def test_onnx():
    """Check that ONNX inference reproduces the PyTorch output of a UNet3D.

    A ``UNet3D(1, 1)`` is saved both as a pickled PyTorch model and, through
    ``imed_utils.save_onnx_model``, as an ONNX file. Both are then run on the
    same crop of the test image and the two outputs are compared with a
    relative tolerance of 1e-3.

    The scratch directory ``PATH_MODEL`` is removed even when an intermediate
    step raises, so a failing run does not leave files behind.
    """
    model = imed_models.UNet3D(1, 1)
    # makedirs(exist_ok=True) also creates missing parents and tolerates a
    # directory left over from a previous run.
    os.makedirs(PATH_MODEL, exist_ok=True)
    try:
        torch.save(model, PATH_MODEL_PT)
        img = nib.load(IMAGE_PATH).get_fdata().astype('float32')[:16, :64, :32]
        # Add batch and channel dimensions
        img_tensor = torch.tensor(img).unsqueeze(0).unsqueeze(0)
        # The export uses a random input whose spatial size differs from the
        # image crop that is fed to both models below.
        dummy_input = torch.randn(1, 1, 32, 32, 32)
        imed_utils.save_onnx_model(model, dummy_input, PATH_MODEL_ONNX)

        model = torch.load(PATH_MODEL_PT)
        model.eval()
        out_pt = model(img_tensor).detach().numpy()

        out_onnx = imed_utils.onnx_inference(PATH_MODEL_ONNX, img_tensor).numpy()
    finally:
        # Clean up the scratch directory whether or not the steps above succeeded.
        shutil.rmtree(PATH_MODEL)
    assert np.allclose(out_pt, out_onnx, rtol=1e-3)
Esempio n. 2
0
def convert_pytorch_to_onnx(model, dimension, n_channels, gpu_id=0):
    """Convert PyTorch model to ONNX.

    The integration of Deep Learning models into the clinical routine requires cpu optimized models. To export the
    PyTorch models to `ONNX <https://github.com/onnx/onnx>`_ format and to run the inference using
    `ONNX Runtime <https://github.com/microsoft/onnxruntime>`_ is a time and memory efficient way to answer this need.

    This function converts a model from PyTorch to ONNX format, with information of whether it is a 2D or 3D model
    (``-d``). The ONNX file is written next to the input file, with the file extension replaced by ``.onnx``.

    Args:
        model (string): Model filename. Flag: ``--model``, ``-m``.
        dimension (int): Indicates whether the model is 2D or 3D. Choice between 2 or 3. Any value other than ``2``
            is handled as 3D. Flag: ``--dimension``, ``-d``
        n_channels (int): Number of channels of the dummy input used to trace the model.
        gpu_id (int or string): GPU ID, if available. Flag: ``--gpu_id``, ``-g``
    """
    # Function-scope import so this block does not depend on the file's import section.
    import os

    if torch.cuda.is_available():
        device = "cuda:" + str(gpu_id)
    else:
        device = "cpu"

    # NOTE: torch.load unpickles the file; only use it on model files from a trusted source.
    model_net = torch.load(model, map_location=device)
    spatial_shape = (96, 96) if dimension == 2 else (96, 96, 96)
    dummy_input = torch.randn(1, n_channels, *spatial_shape, device=device)

    # Replace only the extension. A plain str.replace("pt", "onnx") would also rewrite any "pt" found in
    # directory names or in the file stem (e.g. "/opt/..." or "script.pt").
    path_onnx = os.path.splitext(model)[0] + ".onnx"
    imed_utils.save_onnx_model(model_net, dummy_input, path_onnx)
Esempio n. 3
0
def train(model_params,
          dataset_train,
          dataset_val,
          training_params,
          log_directory,
          device,
          cuda_available=True,
          metric_fns=None,
          n_gif=0,
          resume_training=False,
          debugging=False):
    """Main command to train the network.

    Args:
        model_params (dict): Model's parameters.
        dataset_train (imed_loader): Training dataset.
        dataset_val (imed_loader): Validation dataset.
        training_params (dict): Training configuration. The keys read by this function are "balance_samples",
            "batch_size", "transfer_learning", "training_time", "scheduler", "loss" and "mixup_alpha".
        log_directory (str): Folder where log files, best and final models are saved.
        device (str): Indicates the CPU or GPU ID.
        cuda_available (bool): If True, CUDA is available.
        metric_fns (list): List of metrics, see :mod:`ivadomed.metrics`.
        n_gif (int): Generates a GIF during training if larger than zero, one frame per epoch for a given slice. The
            parameter indicates the number of 2D slices used to generate GIFs, one GIF per slice. A GIF shows
            predictions of a given slice from the validation sub-dataset. They are saved within the log directory.
        resume_training (bool): Load a saved model ("checkpoint.pth.tar" in the log_directory) for resume
                                training. This training state is saved everytime a new best model is saved in the log
                                directory.
        debugging (bool): If True, extended verbosity and intermediate outputs.

    Returns:
        float, float, float, float: best_training_dice, best_training_loss, best_validation_dice,
            best_validation_loss. These are the loss values recorded at the epoch with the lowest validation loss;
            they stay at ``float("inf")`` when no validation epoch improved on the initial value.
    """
    # Write the metrics, images, etc to TensorBoard format
    writer = SummaryWriter(log_dir=log_directory)

    # BALANCE SAMPLES AND PYTORCH LOADER
    # NOTE(review): this test compares the model name with "HeMIS" whereas the rest of the function checks for
    # "HeMISUnet" -- confirm which name is intended.
    conditions = all([
        training_params["balance_samples"]["applied"],
        model_params["name"] != "HeMIS"
    ])
    sampler_train, shuffle_train = get_sampler(
        dataset_train, conditions, training_params['balance_samples']['type'])

    train_loader = DataLoader(dataset_train,
                              batch_size=training_params["batch_size"],
                              shuffle=shuffle_train,
                              pin_memory=True,
                              sampler=sampler_train,
                              collate_fn=imed_loader_utils.imed_collate,
                              num_workers=0)

    gif_dict = {"image_path": [], "slice_id": [], "gif": []}
    if dataset_val:
        sampler_val, shuffle_val = get_sampler(
            dataset_val, conditions,
            training_params['balance_samples']['type'])

        val_loader = DataLoader(dataset_val,
                                batch_size=training_params["batch_size"],
                                shuffle=shuffle_val,
                                pin_memory=True,
                                sampler=sampler_val,
                                collate_fn=imed_loader_utils.imed_collate,
                                num_workers=0)

        # Init GIF: pick n_gif random validation samples and create one GIF object per sample
        if n_gif > 0:
            indexes_gif = random.sample(range(len(dataset_val)), n_gif)
        for i_gif in range(n_gif):
            random_metadata = dict(
                dataset_val[indexes_gif[i_gif]]["input_metadata"][0])
            gif_dict["image_path"].append(random_metadata['input_filenames'])
            gif_dict["slice_id"].append(random_metadata['slice_index'])
            gif_obj = imed_visualize.AnimatedGif(
                size=dataset_val[indexes_gif[i_gif]]["input"].numpy()[0].shape)
            gif_dict["gif"].append(copy.copy(gif_obj))

    # GET MODEL
    transfer_params = training_params["transfer_learning"]
    if transfer_params["retrain_model"]:
        old_model_path = transfer_params["retrain_model"]
        fraction = transfer_params['retrain_fraction']
        # The placeholder is now filled with the path of the pretrained model (it used to be printed as "{}")
        print("\nLoading pretrained model's weights: {}.".format(
            old_model_path))
        print("\tFreezing the {}% first layers.".format(
            100 - fraction * 100.))
        # 'reset' is optional and defaults to True
        reset = transfer_params.get('reset', True)
        # Freeze first layers and reset last layers
        model = imed_models.set_model_for_retrain(old_model_path,
                                                  retrain_fraction=fraction,
                                                  map_location=device,
                                                  reset=reset)
    else:
        print("\nInitialising model's weights from scratch.")
        model_class = getattr(imed_models, model_params["name"])
        model = model_class(**model_params)
    if cuda_available:
        model.cuda()

    num_epochs = training_params["training_time"]["num_epochs"]

    # OPTIMIZER
    initial_lr = training_params["scheduler"]["initial_lr"]
    # Only optimize the parameters that still require gradients (i.e. that are not frozen)
    params_to_opt = filter(lambda p: p.requires_grad, model.parameters())
    # Using Adam
    optimizer = optim.Adam(params_to_opt, lr=initial_lr)
    scheduler, step_scheduler_batch = get_scheduler(
        copy.copy(training_params["scheduler"]["lr_scheduler"]), optimizer,
        num_epochs)
    print("\nScheduler parameters: {}".format(
        training_params["scheduler"]["lr_scheduler"]))

    # Create dict containing gammas and betas after each FiLM layer.
    if 'film_layers' in model_params and any(model_params['film_layers']):
        gammas_dict = {i: [] for i in range(1, 2 * model_params["depth"] + 3)}
        betas_dict = {i: [] for i in range(1, 2 * model_params["depth"] + 3)}
        contrast_list = []

    # Resume
    start_epoch = 1
    # Initialised BEFORE the checkpoint is loaded so that the early-stopping counter restored from the
    # checkpoint is not overwritten afterwards.
    patience_count = 0
    resume_path = os.path.join(log_directory, "checkpoint.pth.tar")
    if resume_training:
        model, optimizer, gif_dict, start_epoch, val_loss_total_avg, scheduler, patience_count = load_checkpoint(
            model=model,
            optimizer=optimizer,
            gif_dict=gif_dict,
            scheduler=scheduler,
            fname=resume_path)
        # Individually transfer the optimizer parts
        # TODO: check if following lines are needed
        for state in optimizer.state.values():
            for k, v in state.items():
                if torch.is_tensor(v):
                    state[k] = v.to(device)

    # LOSS
    print("\nSelected Loss: {}".format(training_params["loss"]["name"]))
    print("\twith the parameters: {}".format([
        training_params["loss"][k] for k in training_params["loss"]
        if k != "name"
    ]))
    loss_fct = get_loss_function(copy.copy(training_params["loss"]))
    loss_dice_fct = imed_losses.DiceLoss(
    )  # For comparison when another loss is used

    # INIT TRAINING VARIABLES
    # NOTE(review): the best scores restart from +inf even when resuming, so the first validated epoch of a
    # resumed run always overwrites the checkpoint -- confirm whether this is intended.
    best_training_dice, best_training_loss = float("inf"), float("inf")
    best_validation_loss, best_validation_dice = float("inf"), float("inf")
    begin_time = time.time()

    # EPOCH LOOP
    for epoch in tqdm(range(num_epochs), desc="Training", initial=start_epoch):
        # Shift the 0-based loop index so that epochs are numbered from start_epoch
        epoch = epoch + start_epoch
        start_time = time.time()

        lr = scheduler.get_last_lr()[0]
        writer.add_scalar('learning_rate', lr, epoch)

        # Training loop -----------------------------------------------------------
        model.train()
        train_loss_total, train_dice_loss_total = 0.0, 0.0
        num_steps = 0
        for i, batch in enumerate(train_loader):
            # GET SAMPLES
            if model_params["name"] == "HeMISUnet":
                input_samples = imed_utils.cuda(
                    imed_utils.unstack_tensors(batch["input"]), cuda_available)
            else:
                input_samples = imed_utils.cuda(batch["input"], cuda_available)
            gt_samples = imed_utils.cuda(batch["gt"],
                                         cuda_available,
                                         non_blocking=True)

            # MIXUP
            if training_params["mixup_alpha"]:
                input_samples, gt_samples = imed_mixup.mixup(
                    input_samples, gt_samples, training_params["mixup_alpha"],
                    debugging and epoch == 1, log_directory)

            # RUN MODEL
            if model_params["name"] == "HeMISUnet" or \
                    ('film_layers' in model_params and any(model_params['film_layers'])):
                metadata = get_metadata(batch["input_metadata"], model_params)
                preds = model(input_samples, metadata)
            else:
                preds = model(input_samples)

            # LOSS
            loss = loss_fct(preds, gt_samples)
            train_loss_total += loss.item()
            train_dice_loss_total += loss_dice_fct(preds, gt_samples).item()

            # UPDATE OPTIMIZER
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if step_scheduler_batch:
                scheduler.step()
            num_steps += 1

            if i == 0 and debugging:
                imed_visualize.save_tensorboard_img(
                    writer,
                    epoch,
                    "Train",
                    input_samples,
                    gt_samples,
                    preds,
                    is_three_dim=not model_params["is_2d"])

        # Schedulers that are not stepped per batch are stepped once per epoch
        if not step_scheduler_batch:
            scheduler.step()

        # TRAINING LOSS
        train_loss_total_avg = train_loss_total / num_steps
        msg = "Epoch {} training loss: {:.4f}.".format(epoch,
                                                       train_loss_total_avg)
        train_dice_loss_total_avg = train_dice_loss_total / num_steps
        if training_params["loss"]["name"] != "DiceLoss":
            msg += "\tDice training loss: {:.4f}.".format(
                train_dice_loss_total_avg)
        tqdm.write(msg)

        # CURRICULUM LEARNING
        if model_params["name"] == "HeMISUnet":
            # Increase the probability of a missing modality
            model_params["missing_probability"] **= model_params[
                "missing_probability_growth"]
            dataset_train.update(p=model_params["missing_probability"])

        # Validation loop -----------------------------------------------------
        model.eval()
        val_loss_total, val_dice_loss_total = 0.0, 0.0
        num_steps = 0
        metric_mgr = imed_metrics.MetricManager(metric_fns)
        if dataset_val:
            for i, batch in enumerate(val_loader):
                with torch.no_grad():
                    # GET SAMPLES
                    if model_params["name"] == "HeMISUnet":
                        input_samples = imed_utils.cuda(
                            imed_utils.unstack_tensors(batch["input"]),
                            cuda_available)
                    else:
                        input_samples = imed_utils.cuda(
                            batch["input"], cuda_available)
                    gt_samples = imed_utils.cuda(batch["gt"],
                                                 cuda_available,
                                                 non_blocking=True)

                    # RUN MODEL
                    if model_params["name"] == "HeMISUnet" or \
                            ('film_layers' in model_params and any(model_params['film_layers'])):
                        metadata = get_metadata(batch["input_metadata"],
                                                model_params)
                        preds = model(input_samples, metadata)
                    else:
                        preds = model(input_samples)

                    # LOSS
                    loss = loss_fct(preds, gt_samples)
                    val_loss_total += loss.item()
                    val_dice_loss_total += loss_dice_fct(preds,
                                                         gt_samples).item()

                    # Add frame to GIF: a frame is appended when a sample of the batch matches the file name and
                    # slice index of one of the tracked GIF entries
                    for i_ in range(len(input_samples)):
                        im, pr, met = input_samples[i_].cpu().numpy()[0], preds[i_].cpu().numpy()[0], \
                                      batch["input_metadata"][i_][0]
                        for i_gif in range(n_gif):
                            if gif_dict["image_path"][i_gif] == met['input_filenames'] and \
                                    gif_dict["slice_id"][i_gif] == met['slice_index']:
                                overlap = imed_visualize.overlap_im_seg(im, pr)
                                gif_dict["gif"][i_gif].add(overlap,
                                                           label=str(epoch))

                num_steps += 1

                # METRICS COMPUTATION
                gt_npy = gt_samples.cpu().numpy()
                preds_npy = preds.data.cpu().numpy()
                metric_mgr(preds_npy, gt_npy)

                if i == 0 and debugging:
                    imed_visualize.save_tensorboard_img(
                        writer,
                        epoch,
                        "Validation",
                        input_samples,
                        gt_samples,
                        preds,
                        is_three_dim=not model_params['is_2d'])

                if 'film_layers' in model_params and any(model_params['film_layers']) and debugging and \
                        epoch == num_epochs and i < int(len(dataset_val) / training_params["batch_size"]) + 1:
                    # Store the values of gammas and betas after the last epoch for each batch
                    gammas_dict, betas_dict, contrast_list = store_film_params(
                        gammas_dict, betas_dict, contrast_list,
                        batch['input_metadata'], model,
                        model_params["film_layers"], model_params["depth"])

            # METRICS COMPUTATION FOR CURRENT EPOCH
            # Keep the previous epoch's validation loss for the early-stopping comparison below
            val_loss_total_avg_old = val_loss_total_avg if epoch > 1 else None
            metrics_dict = metric_mgr.get_results()
            metric_mgr.reset()
            writer.add_scalars('Validation/Metrics', metrics_dict, epoch)
            val_loss_total_avg = val_loss_total / num_steps
            writer.add_scalars(
                'losses', {
                    'train_loss': train_loss_total_avg,
                    'val_loss': val_loss_total_avg,
                }, epoch)
            msg = "Epoch {} validation loss: {:.4f}.".format(
                epoch, val_loss_total_avg)
            val_dice_loss_total_avg = val_dice_loss_total / num_steps
            if training_params["loss"]["name"] != "DiceLoss":
                msg += "\tDice validation loss: {:.4f}.".format(
                    val_dice_loss_total_avg)
            tqdm.write(msg)
            end_time = time.time()
            total_time = end_time - start_time
            tqdm.write("Epoch {} took {:.2f} seconds.".format(
                epoch, total_time))

            # UPDATE BEST RESULTS
            if val_loss_total_avg < best_validation_loss:
                # Save checkpoint
                state = {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'gif_dict': gif_dict,
                    'scheduler': scheduler,
                    'patience_count': patience_count,
                    'validation_loss': val_loss_total_avg
                }
                torch.save(state, resume_path)

                # Save best model file
                model_path = os.path.join(log_directory, "best_model.pt")
                torch.save(model, model_path)

                # Update best scores
                best_validation_loss, best_training_loss = val_loss_total_avg, train_loss_total_avg
                best_validation_dice, best_training_dice = val_dice_loss_total_avg, train_dice_loss_total_avg

            # EARLY STOPPING
            if epoch > 1:
                # Relative decrease (in percent) of the validation loss with respect to the previous epoch
                val_diff = (val_loss_total_avg_old -
                            val_loss_total_avg) * 100 / abs(val_loss_total_avg)
                if val_diff < training_params["training_time"][
                        "early_stopping_epsilon"]:
                    patience_count += 1
                if patience_count >= training_params["training_time"][
                        "early_stopping_patience"]:
                    print(
                        "Stopping training due to {} epochs without improvements"
                        .format(patience_count))
                    break

    # Save final model
    final_model_path = os.path.join(log_directory, "final_model.pt")
    torch.save(model, final_model_path)
    if 'film_layers' in model_params and any(
            model_params['film_layers']) and debugging:
        save_film_params(gammas_dict, betas_dict, contrast_list,
                         model_params["depth"], log_directory)

    # Save best model in log directory
    if os.path.isfile(resume_path):
        # Reload the weights stored in the checkpoint, which is written whenever a new best model is found
        state = torch.load(resume_path)
        model_path = os.path.join(log_directory, "best_model.pt")
        model.load_state_dict(state['state_dict'])
        torch.save(model, model_path)
        # Save best model as ONNX in the model directory
        try:
            # Convert best model to ONNX and save it in model directory
            best_model_path = os.path.join(
                log_directory, model_params["folder_name"],
                model_params["folder_name"] + ".onnx")
            imed_utils.save_onnx_model(model, input_samples, best_model_path)
        except Exception:
            # Best-effort fallback for any export error. KeyboardInterrupt and SystemExit are deliberately
            # not caught, so the run can still be interrupted here.
            # Save best model in model directory
            best_model_path = os.path.join(log_directory,
                                           model_params["folder_name"],
                                           model_params["folder_name"] + ".pt")
            torch.save(model, best_model_path)
            logger.warning(
                "Failed to save the model as '.onnx', saved it as '.pt': {}".
                format(best_model_path))

    # Save GIFs
    gif_folder = os.path.join(log_directory, "gifs")
    if n_gif > 0 and not os.path.isdir(gif_folder):
        os.makedirs(gif_folder)
    # gif_dict stays empty when no validation dataset was given, so never index past its actual length
    for i_gif in range(min(n_gif, len(gif_dict["gif"]))):
        path_parts = gif_dict["image_path"][i_gif].split('/')
        # Output name: <third-to-last path component>__<file name without that prefix and ".nii.gz">__<slice>.gif
        prefix = path_parts[-3]
        suffix = path_parts[-1].split(".nii.gz")[0].split(prefix + "_")[1]
        fname_out = prefix + "__" + suffix + "__"
        fname_out += str(gif_dict["slice_id"][i_gif]) + ".gif"
        path_gif_out = os.path.join(gif_folder, fname_out)
        gif_dict["gif"][i_gif].save(path_gif_out)

    writer.close()
    final_time = time.time()
    duration_time = final_time - begin_time
    print('begin ' + time.strftime('%H:%M:%S', time.localtime(begin_time)) +
          "| End " + time.strftime('%H:%M:%S', time.localtime(final_time)) +
          "| duration " + str(datetime.timedelta(seconds=duration_time)))

    return best_training_dice, best_training_loss, best_validation_dice, best_validation_loss
Esempio n. 4
0
def train(model_params,
          dataset_train,
          dataset_val,
          training_params,
          path_output,
          device,
          wandb_params=None,
          cuda_available=True,
          metric_fns=None,
          n_gif=0,
          resume_training=False,
          debugging=False):
    """Main command to train the network.

    Args:
        model_params (dict): Model's parameters.
        dataset_train (imed_loader): Training dataset.
        dataset_val (imed_loader): Validation dataset.
        training_params (dict):
        path_output (str): Folder where log files, best and final models are saved.
        device (str): Indicates the CPU or GPU ID.
        cuda_available (bool): If True, CUDA is available.
        metric_fns (list): List of metrics, see :mod:`ivadomed.metrics`.
        n_gif (int): Generates a GIF during training if larger than zero, one frame per epoch for a given slice. The
            parameter indicates the number of 2D slices used to generate GIFs, one GIF per slice. A GIF shows
            predictions of a given slice from the validation sub-dataset. They are saved within the output path.
        resume_training (bool): Load a saved model ("checkpoint.pth.tar" in the path_output) for resume
                                training. This training state is saved everytime a new best model is saved in the log
                                directory.
        debugging (bool): If True, extended verbosity and intermediate outputs.

    Returns:
        float, float, float, float: best_training_dice, best_training_loss, best_validation_dice,
            best_validation_loss.
    """
    # Write the metrics, images, etc to TensorBoard format
    writer = SummaryWriter(log_dir=path_output)

    # Enable wandb tracking  if the required params are found in the config file and the api key is correct
    wandb_tracking = imed_utils.initialize_wandb(wandb_params)

    if wandb_tracking:
        # Collect all hyperparameters into a dictionary
        cfg = {**training_params, **model_params}

        # Get the actual project, group, and run names if they exist, else choose the temporary names as default
        project_name = wandb_params.get(WandbKW.PROJECT_NAME, "temp_project")
        group_name = wandb_params.get(WandbKW.GROUP_NAME, "temp_group")
        run_name = wandb_params.get(WandbKW.RUN_NAME, "temp_run")

        if project_name == "temp_project" or group_name == "temp_group" or run_name == "temp_run":
            logger.info(
                "{PROJECT/GROUP/RUN} name not found, initializing as {'temp_project'/'temp_group'/'temp_run'}"
            )

        # Initialize WandB with metrics and hyperparameters
        wandb.init(project=project_name,
                   group=group_name,
                   name=run_name,
                   config=cfg)

    # BALANCE SAMPLES AND PYTORCH LOADER
    conditions = all([
        training_params[TrainingParamsKW.BALANCE_SAMPLES]
        [BalanceSamplesKW.APPLIED], model_params[ModelParamsKW.NAME] != "HeMIS"
    ])
    sampler_train, shuffle_train = get_sampler(
        dataset_train, conditions, training_params[
            TrainingParamsKW.BALANCE_SAMPLES][BalanceSamplesKW.TYPE])

    train_loader = DataLoader(
        dataset_train,
        batch_size=training_params[TrainingParamsKW.BATCH_SIZE],
        shuffle=shuffle_train,
        pin_memory=True,
        sampler=sampler_train,
        collate_fn=imed_loader_utils.imed_collate,
        num_workers=0)

    gif_dict = {"image_path": [], "slice_id": [], "gif": []}
    if dataset_val:
        sampler_val, shuffle_val = get_sampler(
            dataset_val, conditions, training_params[
                TrainingParamsKW.BALANCE_SAMPLES][BalanceSamplesKW.TYPE])

        val_loader = DataLoader(
            dataset_val,
            batch_size=training_params[TrainingParamsKW.BATCH_SIZE],
            shuffle=shuffle_val,
            pin_memory=True,
            sampler=sampler_val,
            collate_fn=imed_loader_utils.imed_collate,
            num_workers=0)

        # Init GIF
        if n_gif > 0:
            indexes_gif = random.sample(range(len(dataset_val)), n_gif)
        for i_gif in range(n_gif):
            random_metadata = dict(
                dataset_val[indexes_gif[i_gif]][MetadataKW.INPUT_METADATA][0])
            gif_dict["image_path"].append(
                random_metadata[MetadataKW.INPUT_FILENAMES])
            gif_dict["slice_id"].append(
                random_metadata[MetadataKW.SLICE_INDEX])
            gif_obj = imed_visualize.AnimatedGif(
                size=dataset_val[indexes_gif[i_gif]]["input"].numpy()[0].shape)
            gif_dict["gif"].append(copy.copy(gif_obj))

    # GET MODEL
    if training_params["transfer_learning"]["retrain_model"]:
        logger.info("Loading pretrained model's weights: {}.")
        logger.info("\tFreezing the {}% first layers.".format(
            100 -
            training_params["transfer_learning"]['retrain_fraction'] * 100.))
        old_model_path = training_params["transfer_learning"]["retrain_model"]
        fraction = training_params["transfer_learning"]['retrain_fraction']
        if 'reset' in training_params["transfer_learning"]:
            reset = training_params["transfer_learning"]['reset']
        else:
            reset = True
        # Freeze first layers and reset last layers
        model = imed_models.set_model_for_retrain(old_model_path,
                                                  retrain_fraction=fraction,
                                                  map_location=device,
                                                  reset=reset)
    else:
        logger.info("Initialising model's weights from scratch.")
        model_class = getattr(imed_models, model_params[ModelParamsKW.NAME])
        model = model_class(**model_params)
    if cuda_available:
        model.cuda()

    num_epochs = training_params["training_time"]["num_epochs"]

    # OPTIMIZER
    initial_lr = training_params["scheduler"]["initial_lr"]
    # filter out the parameters you are going to fine-tuning
    params_to_opt = filter(lambda p: p.requires_grad, model.parameters())
    # Using Adam
    optimizer = optim.Adam(params_to_opt, lr=initial_lr)
    scheduler, step_scheduler_batch = get_scheduler(
        copy.copy(training_params["scheduler"]["lr_scheduler"]), optimizer,
        num_epochs)
    logger.info("Scheduler parameters: {}".format(
        training_params["scheduler"]["lr_scheduler"]))

    # Only call wandb methods if required params are found in the config file
    if wandb_tracking:
        # Logs gradients (at every log_freq steps) to the dashboard.
        wandb.watch(model,
                    log="gradients",
                    log_freq=wandb_params["log_grads_every"])

    # Resume
    start_epoch = 1
    resume_path = Path(path_output, "checkpoint.pth.tar")
    if resume_training:
        model, optimizer, gif_dict, start_epoch, val_loss_total_avg, scheduler, patience_count = load_checkpoint(
            model=model,
            optimizer=optimizer,
            gif_dict=gif_dict,
            scheduler=scheduler,
            fname=str(resume_path))
        # Individually transfer the optimizer parts
        # TODO: check if following lines are needed
        for state in optimizer.state.values():
            for k, v in state.items():
                if torch.is_tensor(v):
                    state[k] = v.to(device)

    # LOSS
    logger.info("Selected Loss: {}".format(training_params["loss"]["name"]))
    logger.info("\twith the parameters: {}".format([
        training_params["loss"][k] for k in training_params["loss"]
        if k != "name"
    ]))
    loss_fct = get_loss_function(copy.copy(training_params["loss"]))
    loss_dice_fct = imed_losses.DiceLoss(
    )  # For comparison when another loss is used

    # INIT TRAINING VARIABLES
    best_training_dice, best_training_loss = float("inf"), float("inf")
    best_validation_loss, best_validation_dice = float("inf"), float("inf")
    patience_count = 0
    begin_time = time.time()

    # EPOCH LOOP
    for epoch in tqdm(range(num_epochs), desc="Training", initial=start_epoch):
        epoch = epoch + start_epoch
        start_time = time.time()

        lr = scheduler.get_last_lr()[0]
        writer.add_scalar('learning_rate', lr, epoch)
        if wandb_tracking:
            wandb.log({"learning_rate": lr})

        # Training loop -----------------------------------------------------------
        model.train()
        train_loss_total, train_dice_loss_total = 0.0, 0.0
        num_steps = 0
        for i, batch in enumerate(train_loader):
            # GET SAMPLES
            if model_params[ModelParamsKW.NAME] == ConfigKW.HEMIS_UNET:
                input_samples = imed_utils.cuda(
                    imed_utils.unstack_tensors(batch["input"]), cuda_available)
            else:
                input_samples = imed_utils.cuda(batch["input"], cuda_available)
            gt_samples = imed_utils.cuda(batch["gt"],
                                         cuda_available,
                                         non_blocking=True)

            # MIXUP
            if training_params["mixup_alpha"]:
                input_samples, gt_samples = imed_mixup.mixup(
                    input_samples, gt_samples, training_params["mixup_alpha"],
                    debugging and epoch == 1, path_output)

            # RUN MODEL
            if model_params[ModelParamsKW.NAME] == ConfigKW.HEMIS_UNET or \
                    (ModelParamsKW.FILM_LAYERS in model_params and any(model_params[ModelParamsKW.FILM_LAYERS])):
                metadata = get_metadata(batch[MetadataKW.INPUT_METADATA],
                                        model_params)
                preds = model(input_samples, metadata)
            else:
                preds = model(input_samples)

            # LOSS
            loss = loss_fct(preds, gt_samples)
            train_loss_total += loss.item()
            train_dice_loss_total += loss_dice_fct(preds, gt_samples).item()

            # UPDATE OPTIMIZER
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if step_scheduler_batch:
                scheduler.step()
            num_steps += 1

            # Save image at every 50th step if debugging is true
            if i % 50 == 0 and debugging:
                imed_visualize.save_img(
                    writer,
                    epoch,
                    "Train",
                    input_samples,
                    gt_samples,
                    preds,
                    wandb_tracking=wandb_tracking,
                    is_three_dim=not model_params[ModelParamsKW.IS_2D])

        if not step_scheduler_batch:
            scheduler.step()

        # TRAINING LOSS
        train_loss_total_avg = train_loss_total / num_steps
        msg = "Epoch {} training loss: {:.4f}.".format(epoch,
                                                       train_loss_total_avg)
        train_dice_loss_total_avg = train_dice_loss_total / num_steps
        if training_params["loss"]["name"] != "DiceLoss":
            msg += "\tDice training loss: {:.4f}.".format(
                train_dice_loss_total_avg)
        logger.info(msg)
        tqdm.write(msg)

        # CURRICULUM LEARNING
        if model_params[ModelParamsKW.NAME] == ConfigKW.HEMIS_UNET:
            # Increase the probability of a missing modality
            model_params[ModelParamsKW.MISSING_PROBABILITY] **= model_params[
                ModelParamsKW.MISSING_PROBABILITY_GROWTH]
            dataset_train.update(
                p=model_params[ModelParamsKW.MISSING_PROBABILITY])

        # Validation loop -----------------------------------------------------
        model.eval()
        val_loss_total, val_dice_loss_total = 0.0, 0.0
        num_steps = 0
        metric_mgr = imed_metrics.MetricManager(metric_fns)
        if dataset_val:
            for i, batch in enumerate(val_loader):
                with torch.no_grad():
                    # GET SAMPLES
                    if model_params[ModelParamsKW.NAME] == ConfigKW.HEMIS_UNET:
                        input_samples = imed_utils.cuda(
                            imed_utils.unstack_tensors(batch["input"]),
                            cuda_available)
                    else:
                        input_samples = imed_utils.cuda(
                            batch["input"], cuda_available)
                    gt_samples = imed_utils.cuda(batch["gt"],
                                                 cuda_available,
                                                 non_blocking=True)

                    # RUN MODEL
                    if model_params[ModelParamsKW.NAME] == ConfigKW.HEMIS_UNET or \
                            (ModelParamsKW.FILM_LAYERS in model_params and any(model_params[ModelParamsKW.FILM_LAYERS])):
                        metadata = get_metadata(
                            batch[MetadataKW.INPUT_METADATA], model_params)
                        preds = model(input_samples, metadata)
                    else:
                        preds = model(input_samples)

                    # LOSS
                    loss = loss_fct(preds, gt_samples)
                    val_loss_total += loss.item()
                    val_dice_loss_total += loss_dice_fct(preds,
                                                         gt_samples).item()

                    # Add frame to GIF
                    for i_ in range(len(input_samples)):
                        im, pr, met = input_samples[i_].cpu().numpy()[0], preds[i_].cpu().numpy()[0], \
                                      batch[MetadataKW.INPUT_METADATA][i_][0]
                        for i_gif in range(n_gif):
                            if gif_dict["image_path"][i_gif] == met.__getitem__('input_filenames') and \
                                    gif_dict["slice_id"][i_gif] == met.__getitem__('slice_index'):
                                overlap = imed_visualize.overlap_im_seg(im, pr)
                                gif_dict["gif"][i_gif].add(overlap,
                                                           label=str(epoch))

                num_steps += 1

                # METRICS COMPUTATION
                gt_npy = gt_samples.cpu().numpy()
                preds_npy = preds.data.cpu().numpy()
                metric_mgr(preds_npy, gt_npy)

                # Save image at every 50th step if debugging is true
                if i % 50 == 0 and debugging:
                    imed_visualize.save_img(
                        writer,
                        epoch,
                        "Validation",
                        input_samples,
                        gt_samples,
                        preds,
                        wandb_tracking=wandb_tracking,
                        is_three_dim=not model_params[ModelParamsKW.IS_2D])

            # METRICS COMPUTATION FOR CURRENT EPOCH
            val_loss_total_avg_old = val_loss_total_avg if epoch > 1 else None
            metrics_dict = metric_mgr.get_results()
            metric_mgr.reset()
            val_loss_total_avg = val_loss_total / num_steps
            # log losses on Tensorboard by default
            writer.add_scalars('Validation/Metrics', metrics_dict, epoch)
            writer.add_scalars(
                'losses', {
                    'train_loss': train_loss_total_avg,
                    'val_loss': val_loss_total_avg,
                }, epoch)
            # log on wandb if the corresponding dictionary is provided
            if wandb_tracking:
                wandb.log({"validation-metrics": metrics_dict})
                wandb.log({
                    "losses": {
                        'train_loss': train_loss_total_avg,
                        'val_loss': val_loss_total_avg,
                    }
                })
            msg = "Epoch {} validation loss: {:.4f}.".format(
                epoch, val_loss_total_avg)
            val_dice_loss_total_avg = val_dice_loss_total / num_steps
            if training_params["loss"]["name"] != "DiceLoss":
                msg += "\tDice validation loss: {:.4f}.".format(
                    val_dice_loss_total_avg)
            logger.info(msg)
            end_time = time.time()
            total_time = end_time - start_time
            msg_epoch = "Epoch {} took {:.2f} seconds.".format(
                epoch, total_time)
            logger.info(msg_epoch)

            # UPDATE BEST RESULTS
            if val_loss_total_avg < best_validation_loss:
                # Save checkpoint
                state = {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'gif_dict': gif_dict,
                    'scheduler': scheduler,
                    'patience_count': patience_count,
                    'validation_loss': val_loss_total_avg
                }
                torch.save(state, resume_path)

                # Save best model file
                model_path = Path(path_output, "best_model.pt")
                torch.save(model, model_path)

                # Update best scores
                best_validation_loss, best_training_loss = val_loss_total_avg, train_loss_total_avg
                best_validation_dice, best_training_dice = val_dice_loss_total_avg, train_dice_loss_total_avg

            # EARLY STOPPING
            if epoch > 1:
                val_diff = (val_loss_total_avg_old -
                            val_loss_total_avg) * 100 / abs(val_loss_total_avg)
                if val_diff < training_params["training_time"][
                        "early_stopping_epsilon"]:
                    patience_count += 1
                if patience_count >= training_params["training_time"][
                        "early_stopping_patience"]:
                    logger.info(
                        "Stopping training due to {} epochs without improvements"
                        .format(patience_count))
                    break

    # Save final model
    final_model_path = Path(path_output, "final_model.pt")
    torch.save(model, final_model_path)

    # Save best model in output path
    if resume_path.is_file():
        state = torch.load(resume_path)
        model_path = Path(path_output, "best_model.pt")
        model.load_state_dict(state['state_dict'])
        torch.save(model, model_path)
        # Save best model as ONNX in the model directory
        try:
            # Convert best model to ONNX and save it in model directory
            best_model_path = Path(
                path_output, model_params[ModelParamsKW.FOLDER_NAME],
                model_params[ModelParamsKW.FOLDER_NAME] + ".onnx")
            imed_utils.save_onnx_model(model, input_samples,
                                       str(best_model_path))
            logger.info(f"Model saved as '.onnx': {best_model_path}")
        except Exception as e:
            logger.warning(f"Failed to save the model as '.onnx': {e}")

        # Save best model as PT in the model directory
        best_model_path = Path(path_output,
                               model_params[ModelParamsKW.FOLDER_NAME],
                               model_params[ModelParamsKW.FOLDER_NAME] + ".pt")
        torch.save(model, best_model_path)
        logger.info(f"Model saved as '.pt': {best_model_path}")

    # Save GIFs
    gif_folder = Path(path_output, "gifs")
    if n_gif > 0 and not gif_folder.is_dir():
        gif_folder.mkdir(parents=True)
    for i_gif in range(n_gif):
        fname_out = gif_dict["image_path"][i_gif].split(os.sep)[-3] + "__"
        fname_out += gif_dict["image_path"][i_gif].split(
            os.sep)[-1].split(".nii.gz")[0].split(
                gif_dict["image_path"][i_gif].split(os.sep)[-3] +
                "_")[1] + "__"
        fname_out += str(gif_dict["slice_id"][i_gif]) + ".gif"
        path_gif_out = Path(gif_folder, fname_out)
        gif_dict["gif"][i_gif].save(str(path_gif_out))

    writer.close()
    wandb.finish()
    final_time = time.time()
    duration_time = final_time - begin_time
    logger.info('begin ' +
                time.strftime('%H:%M:%S', time.localtime(begin_time)) +
                "| End " +
                time.strftime('%H:%M:%S', time.localtime(final_time)) +
                "| duration " + str(datetime.timedelta(seconds=duration_time)))

    return best_training_dice, best_training_loss, best_validation_dice, best_validation_loss