Example #1
def train(model_params,
          dataset_train,
          dataset_val,
          training_params,
          log_directory,
          device,
          cuda_available=True,
          metric_fns=None,
          n_gif=0,
          resume_training=False,
          debugging=False):
    """Main command to train the network.

    Args:
        model_params (dict): Model's parameters.
        dataset_train (imed_loader): Training dataset.
        dataset_val (imed_loader): Validation dataset.
        training_params (dict): Training parameters.
        log_directory (str): Folder where log files, best and final models are saved.
        device (str): Indicates the CPU or GPU ID.
        cuda_available (bool): If True, CUDA is available.
        metric_fns (list): List of metrics, see :mod:`ivadomed.metrics`.
        n_gif (int): Generates a GIF during training if larger than zero, one frame per epoch for a given slice. The
            parameter indicates the number of 2D slices used to generate GIFs, one GIF per slice. A GIF shows
            predictions of a given slice from the validation sub-dataset. They are saved within the log directory.
        resume_training (bool): Load a saved model ("checkpoint.pth.tar" in the log_directory) to resume
                                training. This training state is saved every time a new best model is saved in the log
                                directory.
        debugging (bool): If True, extended verbosity and intermediate outputs.

    Returns:
        float, float, float, float: best_training_dice, best_training_loss, best_validation_dice,
            best_validation_loss.
    """
    # Write the metrics, images, etc to TensorBoard format
    writer = SummaryWriter(log_dir=log_directory)

    # BALANCE SAMPLES AND PYTORCH LOADER
    conditions = all([
        training_params["balance_samples"]["applied"],
        model_params["name"] != "HeMIS"
    ])
    sampler_train, shuffle_train = get_sampler(
        dataset_train, conditions, training_params['balance_samples']['type'])

    train_loader = DataLoader(dataset_train,
                              batch_size=training_params["batch_size"],
                              shuffle=shuffle_train,
                              pin_memory=True,
                              sampler=sampler_train,
                              collate_fn=imed_loader_utils.imed_collate,
                              num_workers=0)

    gif_dict = {"image_path": [], "slice_id": [], "gif": []}
    if dataset_val:
        sampler_val, shuffle_val = get_sampler(
            dataset_val, conditions,
            training_params['balance_samples']['type'])

        val_loader = DataLoader(dataset_val,
                                batch_size=training_params["batch_size"],
                                shuffle=shuffle_val,
                                pin_memory=True,
                                sampler=sampler_val,
                                collate_fn=imed_loader_utils.imed_collate,
                                num_workers=0)

        # Init GIF
        if n_gif > 0:
            indexes_gif = random.sample(range(len(dataset_val)), n_gif)
        for i_gif in range(n_gif):
            random_metadata = dict(
                dataset_val[indexes_gif[i_gif]]["input_metadata"][0])
            gif_dict["image_path"].append(random_metadata['input_filenames'])
            gif_dict["slice_id"].append(random_metadata['slice_index'])
            gif_obj = imed_visualize.AnimatedGif(
                size=dataset_val[indexes_gif[i_gif]]["input"].numpy()[0].shape)
            gif_dict["gif"].append(copy.copy(gif_obj))

    # GET MODEL
    if training_params["transfer_learning"]["retrain_model"]:
        print("\nLoading pretrained model's weights: {}.")
        print("\tFreezing the {}% first layers.".format(
            100 -
            training_params["transfer_learning"]['retrain_fraction'] * 100.))
        old_model_path = training_params["transfer_learning"]["retrain_model"]
        fraction = training_params["transfer_learning"]['retrain_fraction']
        if 'reset' in training_params["transfer_learning"]:
            reset = training_params["transfer_learning"]['reset']
        else:
            reset = True
        # Freeze first layers and reset last layers
        model = imed_models.set_model_for_retrain(old_model_path,
                                                  retrain_fraction=fraction,
                                                  map_location=device,
                                                  reset=reset)
    else:
        print("\nInitialising model's weights from scratch.")
        model_class = getattr(imed_models, model_params["name"])
        model = model_class(**model_params)
    if cuda_available:
        model.cuda()

    num_epochs = training_params["training_time"]["num_epochs"]

    # OPTIMIZER
    initial_lr = training_params["scheduler"]["initial_lr"]
    # Keep only the parameters that will be fine-tuned (i.e. those with requires_grad=True)
    params_to_opt = filter(lambda p: p.requires_grad, model.parameters())
    # Using Adam
    optimizer = optim.Adam(params_to_opt, lr=initial_lr)
    scheduler, step_scheduler_batch = get_scheduler(
        copy.copy(training_params["scheduler"]["lr_scheduler"]), optimizer,
        num_epochs)
    print("\nScheduler parameters: {}".format(
        training_params["scheduler"]["lr_scheduler"]))

    # Create dict containing gammas and betas after each FiLM layer.
    if 'film_layers' in model_params and any(model_params['film_layers']):
        gammas_dict = {i: [] for i in range(1, 2 * model_params["depth"] + 3)}
        betas_dict = {i: [] for i in range(1, 2 * model_params["depth"] + 3)}
        contrast_list = []

    # Resume
    start_epoch = 1
    resume_path = os.path.join(log_directory, "checkpoint.pth.tar")
    if resume_training:
        model, optimizer, gif_dict, start_epoch, val_loss_total_avg, scheduler, patience_count = load_checkpoint(
            model=model,
            optimizer=optimizer,
            gif_dict=gif_dict,
            scheduler=scheduler,
            fname=resume_path)
        # Individually transfer the optimizer parts
        # TODO: check if following lines are needed
        for state in optimizer.state.values():
            for k, v in state.items():
                if torch.is_tensor(v):
                    state[k] = v.to(device)

    # LOSS
    print("\nSelected Loss: {}".format(training_params["loss"]["name"]))
    print("\twith the parameters: {}".format([
        training_params["loss"][k] for k in training_params["loss"]
        if k != "name"
    ]))
    loss_fct = get_loss_function(copy.copy(training_params["loss"]))
    loss_dice_fct = imed_losses.DiceLoss()  # For comparison when another loss is used

    # INIT TRAINING VARIABLES
    best_training_dice, best_training_loss = float("inf"), float("inf")
    best_validation_loss, best_validation_dice = float("inf"), float("inf")
    patience_count = 0
    begin_time = time.time()

    # EPOCH LOOP
    for epoch in tqdm(range(num_epochs), desc="Training", initial=start_epoch):
        epoch = epoch + start_epoch
        start_time = time.time()

        lr = scheduler.get_last_lr()[0]
        writer.add_scalar('learning_rate', lr, epoch)

        # Training loop -----------------------------------------------------------
        model.train()
        train_loss_total, train_dice_loss_total = 0.0, 0.0
        num_steps = 0
        for i, batch in enumerate(train_loader):
            # GET SAMPLES
            if model_params["name"] == "HeMISUnet":
                input_samples = imed_utils.cuda(
                    imed_utils.unstack_tensors(batch["input"]), cuda_available)
            else:
                input_samples = imed_utils.cuda(batch["input"], cuda_available)
            gt_samples = imed_utils.cuda(batch["gt"],
                                         cuda_available,
                                         non_blocking=True)

            # MIXUP
            if training_params["mixup_alpha"]:
                input_samples, gt_samples = imed_mixup.mixup(
                    input_samples, gt_samples, training_params["mixup_alpha"],
                    debugging and epoch == 1, log_directory)

            # RUN MODEL
            if model_params["name"] == "HeMISUnet" or \
                    ('film_layers' in model_params and any(model_params['film_layers'])):
                metadata = get_metadata(batch["input_metadata"], model_params)
                preds = model(input_samples, metadata)
            else:
                preds = model(input_samples)

            # LOSS
            loss = loss_fct(preds, gt_samples)
            train_loss_total += loss.item()
            train_dice_loss_total += loss_dice_fct(preds, gt_samples).item()

            # UPDATE OPTIMIZER
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if step_scheduler_batch:
                scheduler.step()
            num_steps += 1

            if i == 0 and debugging:
                imed_visualize.save_tensorboard_img(
                    writer,
                    epoch,
                    "Train",
                    input_samples,
                    gt_samples,
                    preds,
                    is_three_dim=not model_params["is_2d"])

        if not step_scheduler_batch:
            scheduler.step()

        # TRAINING LOSS
        train_loss_total_avg = train_loss_total / num_steps
        msg = "Epoch {} training loss: {:.4f}.".format(epoch,
                                                       train_loss_total_avg)
        train_dice_loss_total_avg = train_dice_loss_total / num_steps
        if training_params["loss"]["name"] != "DiceLoss":
            msg += "\tDice training loss: {:.4f}.".format(
                train_dice_loss_total_avg)
        tqdm.write(msg)

        # CURRICULUM LEARNING
        if model_params["name"] == "HeMISUnet":
            # Increase the probability of a missing modality
            model_params["missing_probability"] **= model_params[
                "missing_probability_growth"]
            dataset_train.update(p=model_params["missing_probability"])

        # Validation loop -----------------------------------------------------
        model.eval()
        val_loss_total, val_dice_loss_total = 0.0, 0.0
        num_steps = 0
        metric_mgr = imed_metrics.MetricManager(metric_fns)
        if dataset_val:
            for i, batch in enumerate(val_loader):
                with torch.no_grad():
                    # GET SAMPLES
                    if model_params["name"] == "HeMISUnet":
                        input_samples = imed_utils.cuda(
                            imed_utils.unstack_tensors(batch["input"]),
                            cuda_available)
                    else:
                        input_samples = imed_utils.cuda(
                            batch["input"], cuda_available)
                    gt_samples = imed_utils.cuda(batch["gt"],
                                                 cuda_available,
                                                 non_blocking=True)

                    # RUN MODEL
                    if model_params["name"] == "HeMISUnet" or \
                            ('film_layers' in model_params and any(model_params['film_layers'])):
                        metadata = get_metadata(batch["input_metadata"],
                                                model_params)
                        preds = model(input_samples, metadata)
                    else:
                        preds = model(input_samples)

                    # LOSS
                    loss = loss_fct(preds, gt_samples)
                    val_loss_total += loss.item()
                    val_dice_loss_total += loss_dice_fct(preds,
                                                         gt_samples).item()

                    # Add frame to GIF
                    for i_ in range(len(input_samples)):
                        im, pr, met = input_samples[i_].cpu().numpy()[0], preds[i_].cpu().numpy()[0], \
                                      batch["input_metadata"][i_][0]
                        for i_gif in range(n_gif):
                            if gif_dict["image_path"][i_gif] == met.__getitem__('input_filenames') and \
                                    gif_dict["slice_id"][i_gif] == met.__getitem__('slice_index'):
                                overlap = imed_visualize.overlap_im_seg(im, pr)
                                gif_dict["gif"][i_gif].add(overlap,
                                                           label=str(epoch))

                num_steps += 1

                # METRICS COMPUTATION
                gt_npy = gt_samples.cpu().numpy()
                preds_npy = preds.data.cpu().numpy()
                metric_mgr(preds_npy, gt_npy)

                if i == 0 and debugging:
                    imed_visualize.save_tensorboard_img(
                        writer,
                        epoch,
                        "Validation",
                        input_samples,
                        gt_samples,
                        preds,
                        is_three_dim=not model_params['is_2d'])

                if 'film_layers' in model_params and any(model_params['film_layers']) and debugging and \
                        epoch == num_epochs and i < int(len(dataset_val) / training_params["batch_size"]) + 1:
                    # Store the values of gammas and betas after the last epoch for each batch
                    gammas_dict, betas_dict, contrast_list = store_film_params(
                        gammas_dict, betas_dict, contrast_list,
                        batch['input_metadata'], model,
                        model_params["film_layers"], model_params["depth"])

            # METRICS COMPUTATION FOR CURRENT EPOCH
            val_loss_total_avg_old = val_loss_total_avg if epoch > 1 else None
            metrics_dict = metric_mgr.get_results()
            metric_mgr.reset()
            writer.add_scalars('Validation/Metrics', metrics_dict, epoch)
            val_loss_total_avg = val_loss_total / num_steps
            writer.add_scalars(
                'losses', {
                    'train_loss': train_loss_total_avg,
                    'val_loss': val_loss_total_avg,
                }, epoch)
            msg = "Epoch {} validation loss: {:.4f}.".format(
                epoch, val_loss_total_avg)
            val_dice_loss_total_avg = val_dice_loss_total / num_steps
            if training_params["loss"]["name"] != "DiceLoss":
                msg += "\tDice validation loss: {:.4f}.".format(
                    val_dice_loss_total_avg)
            tqdm.write(msg)
            end_time = time.time()
            total_time = end_time - start_time
            tqdm.write("Epoch {} took {:.2f} seconds.".format(
                epoch, total_time))

            # UPDATE BEST RESULTS
            if val_loss_total_avg < best_validation_loss:
                # Save checkpoint
                state = {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'gif_dict': gif_dict,
                    'scheduler': scheduler,
                    'patience_count': patience_count,
                    'validation_loss': val_loss_total_avg
                }
                torch.save(state, resume_path)

                # Save best model file
                model_path = os.path.join(log_directory, "best_model.pt")
                torch.save(model, model_path)

                # Update best scores
                best_validation_loss, best_training_loss = val_loss_total_avg, train_loss_total_avg
                best_validation_dice, best_training_dice = val_dice_loss_total_avg, train_dice_loss_total_avg

            # EARLY STOPPING
            if epoch > 1:
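                # Relative improvement of the validation loss, in percent, w.r.t. the previous epoch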
                val_diff = (val_loss_total_avg_old -
                            val_loss_total_avg) * 100 / abs(val_loss_total_avg)
                if val_diff < training_params["training_time"][
                        "early_stopping_epsilon"]:
                    patience_count += 1
                if patience_count >= training_params["training_time"][
                        "early_stopping_patience"]:
                    print(
                        "Stopping training due to {} epochs without improvements"
                        .format(patience_count))
                    break

    # Save final model
    final_model_path = os.path.join(log_directory, "final_model.pt")
    torch.save(model, final_model_path)
    if 'film_layers' in model_params and any(
            model_params['film_layers']) and debugging:
        save_film_params(gammas_dict, betas_dict, contrast_list,
                         model_params["depth"], log_directory)

    # Save best model in log directory
    if os.path.isfile(resume_path):
        state = torch.load(resume_path)
        model_path = os.path.join(log_directory, "best_model.pt")
        model.load_state_dict(state['state_dict'])
        torch.save(model, model_path)
        # Save best model as ONNX in the model directory
        try:
            # Convert best model to ONNX and save it in model directory
            best_model_path = os.path.join(
                log_directory, model_params["folder_name"],
                model_params["folder_name"] + ".onnx")
            imed_utils.save_onnx_model(model, input_samples, best_model_path)
        except Exception as err:
            # Save best model in model directory
            best_model_path = os.path.join(log_directory,
                                           model_params["folder_name"],
                                           model_params["folder_name"] + ".pt")
            torch.save(model, best_model_path)
            logger.warning(
                "Failed to save the model as '.onnx' ({}), saved it as '.pt': {}".
                format(err, best_model_path))

    # Save GIFs
    gif_folder = os.path.join(log_directory, "gifs")
    if n_gif > 0 and not os.path.isdir(gif_folder):
        os.makedirs(gif_folder)
    for i_gif in range(n_gif):
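        # Build the output filename from the image path: <parent folder>__<file-specific suffix>__<slice id>.gif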
        fname_out = gif_dict["image_path"][i_gif].split('/')[-3] + "__"
        fname_out += gif_dict["image_path"][i_gif].split('/')[-1].split(
            ".nii.gz")[0].split(gif_dict["image_path"][i_gif].split('/')[-3] +
                                "_")[1] + "__"
        fname_out += str(gif_dict["slice_id"][i_gif]) + ".gif"
        path_gif_out = os.path.join(gif_folder, fname_out)
        gif_dict["gif"][i_gif].save(path_gif_out)

    writer.close()
    final_time = time.time()
    duration_time = final_time - begin_time
    print('Begin: ' + time.strftime('%H:%M:%S', time.localtime(begin_time)) +
          " | End: " + time.strftime('%H:%M:%S', time.localtime(final_time)) +
          " | Duration: " + str(datetime.timedelta(seconds=duration_time)))

    return best_training_dice, best_training_loss, best_validation_dice, best_validation_loss
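
For orientation, a minimal and purely hypothetical call sketch for train() follows. The dictionary keys mirror the lookups made inside the function; the concrete values, the two datasets and the metric list are placeholders, not a verified configuration.

# Hypothetical usage sketch for train(); only the dictionary keys are taken from the
# accesses made in the function body, all values and the datasets are placeholders.
training_params_example = {
    "batch_size": 16,
    "balance_samples": {"applied": False, "type": "gt"},
    "mixup_alpha": None,
    "transfer_learning": {"retrain_model": None, "retrain_fraction": 1.0, "reset": True},
    "scheduler": {"initial_lr": 0.001, "lr_scheduler": {"name": "CosineAnnealingLR"}},
    "training_time": {"num_epochs": 100, "early_stopping_epsilon": 0.001,
                      "early_stopping_patience": 50},
    "loss": {"name": "DiceLoss"},
}
best_scores = train(model_params={"name": "Unet", "is_2d": True, "depth": 3,
                                  "folder_name": "my_model"},
                    dataset_train=dataset_train,   # imed_loader dataset built elsewhere
                    dataset_val=dataset_val,       # imed_loader dataset built elsewhere
                    training_params=training_params_example,
                    log_directory="path/to/log_directory",
                    device="cuda:0",
                    cuda_available=torch.cuda.is_available(),
                    metric_fns=[imed_metrics.dice_score],
                    n_gif=0,
                    resume_training=False,
                    debugging=False)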
Example #2
def threshold_analysis(model_path,
                       ds_lst,
                       model_params,
                       testing_params,
                       metric="dice",
                       increment=0.1,
                       fname_out="thr.png",
                       cuda_available=True):
    """Run a threshold analysis to find the optimal threshold on a sub-dataset.

    Args:
        model_path (str): Model path.
        ds_lst (list): List of loaders.
        model_params (dict): Model's parameters.
        testing_params (dict): Testing parameters.
        metric (str): Choice between "dice" and "recall_specificity". If "recall_specificity", then a ROC analysis
            is performed.
        increment (float): Increment between tested thresholds.
        fname_out (str): Plot output filename.
        cuda_available (bool): If True, CUDA is available.

    Returns:
        float: optimal threshold.
    """
    if metric not in ["dice", "recall_specificity"]:
        raise ValueError(
            '\nChoice of metric for threshold analysis: dice, recall_specificity.'
        )

    # Adjust some testing parameters
    testing_params["uncertainty"]["applied"] = False

    # Load model
    model = torch.load(model_path)
    # Eval mode
    model.eval()

    # List of thresholds
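    # Drop the leading 0.0 so only strictly positive thresholds are evaluated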
    thr_list = list(np.arange(0.0, 1.0, increment))[1:]

    # Init metric manager for each thr
    metric_fns = [
        imed_metrics.recall_score, imed_metrics.dice_score,
        imed_metrics.specificity_score
    ]
    metric_dict = {
        thr: imed_metrics.MetricManager(metric_fns)
        for thr in thr_list
    }

    # Load
    loader = DataLoader(ConcatDataset(ds_lst),
                        batch_size=testing_params["batch_size"],
                        shuffle=False,
                        pin_memory=True,
                        sampler=None,
                        collate_fn=imed_loader_utils.imed_collate,
                        num_workers=0)

    # Run inference
    preds_npy, gt_npy = run_inference(loader,
                                      model,
                                      model_params,
                                      testing_params,
                                      ofolder=None,
                                      cuda_available=cuda_available)

    print('\nRunning threshold analysis to find optimal threshold')
    # Make sure the GT is binarized
    gt_npy = [threshold_predictions(gt, thr=0.5) for gt in gt_npy]
    # Move threshold
    for thr in tqdm(thr_list, desc="Search"):
        preds_thr = [
            threshold_predictions(copy.deepcopy(pred), thr=thr)
            for pred in preds_npy
        ]
        metric_dict[thr](preds_thr, gt_npy)

    # Get results
    tpr_list, fpr_list, dice_list = [], [], []
    for thr in thr_list:
        result_thr = metric_dict[thr].get_results()
        tpr_list.append(result_thr["recall_score"])
        fpr_list.append(1 - result_thr["specificity_score"])
        dice_list.append(result_thr["dice_score"])

    # Get optimal threshold
    if metric == "dice":
        diff_list = dice_list
    else:
        diff_list = [tpr - fpr for tpr, fpr in zip(tpr_list, fpr_list)]

    # Among ties, keep the largest index, i.e. the highest threshold reaching the best score
    optimal_idx = int(np.max(np.where(np.asarray(diff_list) == np.max(diff_list))))
    optimal_threshold = thr_list[optimal_idx]
    print('\tOptimal threshold: {}'.format(optimal_threshold))

    # Save plot
    print('\tSaving plot: {}'.format(fname_out))
    if metric == "dice":
        # Run plot
        imed_metrics.plot_dice_thr(thr_list, dice_list, optimal_idx, fname_out)
    else:
        # Add 0 and 1 as extrema
        tpr_list = [0.0] + tpr_list + [1.0]
        fpr_list = [0.0] + fpr_list + [1.0]
        optimal_idx += 1
        # Run plot
        imed_metrics.plot_roc_curve(tpr_list, fpr_list, optimal_idx, fname_out)

    return optimal_threshold
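
A hypothetical call sketch follows; the model path, the loader list and the testing parameters are placeholders. Besides "batch_size" and the "uncertainty" sub-dictionary adjusted above, further testing parameters are consumed inside run_inference.

# Hypothetical usage sketch for threshold_analysis(); all values are placeholders.
# testing_params must contain at least "batch_size" and an "uncertainty" sub-dict;
# additional keys are read inside run_inference.
optimal_thr = threshold_analysis(model_path="path/to/best_model.pt",
                                 ds_lst=[ds_val],
                                 model_params={"name": "Unet", "is_2d": True},
                                 testing_params=testing_params,
                                 metric="dice",
                                 increment=0.05,
                                 fname_out="thr.png",
                                 cuda_available=torch.cuda.is_available())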
Example #3
def test_inference(transforms_dict, test_lst, target_lst, roi_params, testing_params):
    cuda_available, device = imed_utils.define_device(GPU_ID)

    model_params = {"name": "Unet", "is_2d": True}
    loader_params = {
        "transforms_params": transforms_dict,
        "data_list": test_lst,
        "dataset_type": "testing",
        "requires_undo": True,
        "contrast_params": {"contrast_lst": ['T2w'], "balance": {}},
        "path_data": [__data_testing_dir__],
        "target_suffix": target_lst,
        "roi_params": roi_params,
        "slice_filter_params": {
            "filter_empty_mask": False,
            "filter_empty_input": True
        },
        "slice_axis": SLICE_AXIS,
        "multichannel": False
    }
    loader_params.update({"model_params": model_params})

    # Get Testing dataset
    ds_test = imed_loader.load_dataset(**loader_params)
    test_loader = DataLoader(ds_test, batch_size=BATCH_SIZE,
                             shuffle=False, pin_memory=True,
                             collate_fn=imed_loader_utils.imed_collate,
                             num_workers=0)

    # Undo transform
    val_undo_transform = imed_transforms.UndoCompose(imed_transforms.Compose(transforms_dict))

    # Update testing_params
    testing_params.update({
        "slice_axis": loader_params["slice_axis"],
        "target_suffix": loader_params["target_suffix"],
        "undo_transforms": val_undo_transform
    })

    # Model
    model = imed_models.Unet()

    if cuda_available:
        model.cuda()
    model.eval()

    metric_fns = [imed_metrics.dice_score,
                  imed_metrics.hausdorff_score,
                  imed_metrics.precision_score,
                  imed_metrics.recall_score,
                  imed_metrics.specificity_score,
                  imed_metrics.intersection_over_union,
                  imed_metrics.accuracy_score]

    metric_mgr = imed_metrics.MetricManager(metric_fns)

    if not os.path.isdir(__output_dir__):
        os.makedirs(__output_dir__)

    preds_npy, gt_npy = imed_testing.run_inference(test_loader=test_loader,
                                                   model=model,
                                                   model_params=model_params,
                                                   testing_params=testing_params,
                                                   ofolder=__output_dir__,
                                                   cuda_available=cuda_available)

    metric_mgr(preds_npy, gt_npy)
    metrics_dict = metric_mgr.get_results()
    metric_mgr.reset()
    print(metrics_dict)
Example #4
def test(model_params,
         dataset_test,
         testing_params,
         path_output,
         device,
         cuda_available=True,
         metric_fns=None,
         postprocessing=None):
    """Main command to test the network.

    Args:
        model_params (dict): Model's parameters.
        dataset_test (imed_loader): Testing dataset.
        testing_params (dict): Testing parameters.
        path_output (str): Folder where predictions are saved.
        device (torch.device): Indicates the CPU or GPU ID.
        cuda_available (bool): If True, CUDA is available.
        metric_fns (list): List of metrics, see :mod:`ivadomed.metrics`.
        postprocessing (dict): Contains postprocessing steps.

    Returns:
        dict: result metrics.
    """
    # DATA LOADER
    test_loader = DataLoader(dataset_test,
                             batch_size=testing_params["batch_size"],
                             shuffle=False,
                             pin_memory=True,
                             collate_fn=imed_loader_utils.imed_collate,
                             num_workers=0)

    # LOAD TRAIN MODEL
    fname_model = os.path.join(path_output, "best_model.pt")
    print('\nLoading model: {}'.format(fname_model))
    model = torch.load(fname_model, map_location=device)
    if cuda_available:
        model.cuda()
    model.eval()

    # CREATE OUTPUT FOLDER
    path_3Dpred = os.path.join(path_output, 'pred_masks')
    if not os.path.isdir(path_3Dpred):
        os.makedirs(path_3Dpred)

    # METRIC MANAGER
    metric_mgr = imed_metrics.MetricManager(metric_fns)

    # UNCERTAINTY SETTINGS
    if (testing_params['uncertainty']['epistemic'] or testing_params['uncertainty']['aleatoric']) and \
            testing_params['uncertainty']['n_it'] > 0:
        n_monteCarlo = testing_params['uncertainty']['n_it'] + 1
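        # One extra iteration: the last pass runs with uncertainty disabled and gives the regular prediction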
        testing_params['uncertainty']['applied'] = True
        print('\nComputing model uncertainty over {} iterations.'.format(
            n_monteCarlo - 1))
    else:
        testing_params['uncertainty']['applied'] = False
        n_monteCarlo = 1

    for i_monteCarlo in range(n_monteCarlo):
        preds_npy, gt_npy = run_inference(test_loader, model, model_params,
                                          testing_params, path_3Dpred,
                                          cuda_available, i_monteCarlo,
                                          postprocessing)
        metric_mgr(preds_npy, gt_npy)
        # When computing uncertainty, disable it before the last iteration so the final pass yields the regular prediction
        if testing_params['uncertainty']['applied'] and (n_monteCarlo - 2
                                                         == i_monteCarlo):
            testing_params['uncertainty']['applied'] = False
            # COMPUTE UNCERTAINTY MAPS
            imed_uncertainty.run_uncertainty(ifolder=path_3Dpred)

    metrics_dict = metric_mgr.get_results()
    metric_mgr.reset()
    print(metrics_dict)
    return metrics_dict
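
A hypothetical call sketch for test(): the "uncertainty" sub-dictionary mirrors the keys read above, everything else is a placeholder, and run_inference consumes further testing parameters (e.g. "slice_axis", "target_suffix" and "undo_transforms", as set in Example #3).

# Hypothetical usage sketch for test(); keys follow the accesses made in the function,
# values are placeholders. run_inference reads additional testing parameters
# (e.g. "slice_axis", "target_suffix", "undo_transforms", see Example #3).
testing_params_example = {
    "batch_size": 16,
    "uncertainty": {"epistemic": False, "aleatoric": False, "n_it": 0},
}
metrics = test(model_params={"name": "Unet", "is_2d": True},
               dataset_test=ds_test,              # imed_loader testing dataset built elsewhere
               testing_params=testing_params_example,
               path_output="path/to/output",
               device=torch.device("cuda:0"),
               cuda_available=torch.cuda.is_available(),
               metric_fns=[imed_metrics.dice_score, imed_metrics.precision_score],
               postprocessing={})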
Example #5
def train(model_params,
          dataset_train,
          dataset_val,
          training_params,
          path_output,
          device,
          wandb_params=None,
          cuda_available=True,
          metric_fns=None,
          n_gif=0,
          resume_training=False,
          debugging=False):
    """Main command to train the network.

    Args:
        model_params (dict): Model's parameters.
        dataset_train (imed_loader): Training dataset.
        dataset_val (imed_loader): Validation dataset.
        training_params (dict): Training parameters.
        path_output (str): Folder where log files, best and final models are saved.
        device (str): Indicates the CPU or GPU ID.
        cuda_available (bool): If True, CUDA is available.
        metric_fns (list): List of metrics, see :mod:`ivadomed.metrics`.
        n_gif (int): Generates a GIF during training if larger than zero, one frame per epoch for a given slice. The
            parameter indicates the number of 2D slices used to generate GIFs, one GIF per slice. A GIF shows
            predictions of a given slice from the validation sub-dataset. They are saved within the output path.
        resume_training (bool): Load a saved model ("checkpoint.pth.tar" in the path_output) to resume
                                training. This training state is saved every time a new best model is saved in the
                                output path.
        debugging (bool): If True, extended verbosity and intermediate outputs.

    Returns:
        float, float, float, float: best_training_dice, best_training_loss, best_validation_dice,
            best_validation_loss.
    """
    # Write the metrics, images, etc to TensorBoard format
    writer = SummaryWriter(log_dir=path_output)

    # Enable wandb tracking  if the required params are found in the config file and the api key is correct
    wandb_tracking = imed_utils.initialize_wandb(wandb_params)

    if wandb_tracking:
        # Collect all hyperparameters into a dictionary
        cfg = {**training_params, **model_params}

        # Get the actual project, group, and run names if they exist, else choose the temporary names as default
        project_name = wandb_params.get(WandbKW.PROJECT_NAME, "temp_project")
        group_name = wandb_params.get(WandbKW.GROUP_NAME, "temp_group")
        run_name = wandb_params.get(WandbKW.RUN_NAME, "temp_run")

        if project_name == "temp_project" or group_name == "temp_group" or run_name == "temp_run":
            logger.info(
                "{PROJECT/GROUP/RUN} name not found, initializing as {'temp_project'/'temp_group'/'temp_run'}"
            )

        # Initialize WandB with metrics and hyperparameters
        wandb.init(project=project_name,
                   group=group_name,
                   name=run_name,
                   config=cfg)

    # BALANCE SAMPLES AND PYTORCH LOADER
    conditions = all([
        training_params[TrainingParamsKW.BALANCE_SAMPLES]
        [BalanceSamplesKW.APPLIED], model_params[ModelParamsKW.NAME] != "HeMIS"
    ])
    sampler_train, shuffle_train = get_sampler(
        dataset_train, conditions, training_params[
            TrainingParamsKW.BALANCE_SAMPLES][BalanceSamplesKW.TYPE])

    train_loader = DataLoader(
        dataset_train,
        batch_size=training_params[TrainingParamsKW.BATCH_SIZE],
        shuffle=shuffle_train,
        pin_memory=True,
        sampler=sampler_train,
        collate_fn=imed_loader_utils.imed_collate,
        num_workers=0)

    gif_dict = {"image_path": [], "slice_id": [], "gif": []}
    if dataset_val:
        sampler_val, shuffle_val = get_sampler(
            dataset_val, conditions, training_params[
                TrainingParamsKW.BALANCE_SAMPLES][BalanceSamplesKW.TYPE])

        val_loader = DataLoader(
            dataset_val,
            batch_size=training_params[TrainingParamsKW.BATCH_SIZE],
            shuffle=shuffle_val,
            pin_memory=True,
            sampler=sampler_val,
            collate_fn=imed_loader_utils.imed_collate,
            num_workers=0)

        # Init GIF
        if n_gif > 0:
            indexes_gif = random.sample(range(len(dataset_val)), n_gif)
        for i_gif in range(n_gif):
            random_metadata = dict(
                dataset_val[indexes_gif[i_gif]][MetadataKW.INPUT_METADATA][0])
            gif_dict["image_path"].append(
                random_metadata[MetadataKW.INPUT_FILENAMES])
            gif_dict["slice_id"].append(
                random_metadata[MetadataKW.SLICE_INDEX])
            gif_obj = imed_visualize.AnimatedGif(
                size=dataset_val[indexes_gif[i_gif]]["input"].numpy()[0].shape)
            gif_dict["gif"].append(copy.copy(gif_obj))

    # GET MODEL
    if training_params["transfer_learning"]["retrain_model"]:
        logger.info("Loading pretrained model's weights: {}.")
        logger.info("\tFreezing the {}% first layers.".format(
            100 -
            training_params["transfer_learning"]['retrain_fraction'] * 100.))
        old_model_path = training_params["transfer_learning"]["retrain_model"]
        fraction = training_params["transfer_learning"]['retrain_fraction']
        if 'reset' in training_params["transfer_learning"]:
            reset = training_params["transfer_learning"]['reset']
        else:
            reset = True
        # Freeze first layers and reset last layers
        model = imed_models.set_model_for_retrain(old_model_path,
                                                  retrain_fraction=fraction,
                                                  map_location=device,
                                                  reset=reset)
    else:
        logger.info("Initialising model's weights from scratch.")
        model_class = getattr(imed_models, model_params[ModelParamsKW.NAME])
        model = model_class(**model_params)
    if cuda_available:
        model.cuda()

    num_epochs = training_params["training_time"]["num_epochs"]

    # OPTIMIZER
    initial_lr = training_params["scheduler"]["initial_lr"]
    # Keep only the parameters that will be fine-tuned (i.e. those with requires_grad=True)
    params_to_opt = filter(lambda p: p.requires_grad, model.parameters())
    # Using Adam
    optimizer = optim.Adam(params_to_opt, lr=initial_lr)
    scheduler, step_scheduler_batch = get_scheduler(
        copy.copy(training_params["scheduler"]["lr_scheduler"]), optimizer,
        num_epochs)
    logger.info("Scheduler parameters: {}".format(
        training_params["scheduler"]["lr_scheduler"]))

    # Only call wandb methods if required params are found in the config file
    if wandb_tracking:
        # Logs gradients (at every log_freq steps) to the dashboard.
        wandb.watch(model,
                    log="gradients",
                    log_freq=wandb_params["log_grads_every"])

    # Resume
    start_epoch = 1
    resume_path = Path(path_output, "checkpoint.pth.tar")
    if resume_training:
        model, optimizer, gif_dict, start_epoch, val_loss_total_avg, scheduler, patience_count = load_checkpoint(
            model=model,
            optimizer=optimizer,
            gif_dict=gif_dict,
            scheduler=scheduler,
            fname=str(resume_path))
        # Individually transfer the optimizer parts
        # TODO: check if following lines are needed
        for state in optimizer.state.values():
            for k, v in state.items():
                if torch.is_tensor(v):
                    state[k] = v.to(device)

    # LOSS
    logger.info("Selected Loss: {}".format(training_params["loss"]["name"]))
    logger.info("\twith the parameters: {}".format([
        training_params["loss"][k] for k in training_params["loss"]
        if k != "name"
    ]))
    loss_fct = get_loss_function(copy.copy(training_params["loss"]))
    loss_dice_fct = imed_losses.DiceLoss()  # For comparison when another loss is used

    # INIT TRAINING VARIABLES
    best_training_dice, best_training_loss = float("inf"), float("inf")
    best_validation_loss, best_validation_dice = float("inf"), float("inf")
    patience_count = 0
    begin_time = time.time()

    # EPOCH LOOP
    for epoch in tqdm(range(num_epochs), desc="Training", initial=start_epoch):
        epoch = epoch + start_epoch
        start_time = time.time()

        lr = scheduler.get_last_lr()[0]
        writer.add_scalar('learning_rate', lr, epoch)
        if wandb_tracking:
            wandb.log({"learning_rate": lr})

        # Training loop -----------------------------------------------------------
        model.train()
        train_loss_total, train_dice_loss_total = 0.0, 0.0
        num_steps = 0
        for i, batch in enumerate(train_loader):
            # GET SAMPLES
            if model_params[ModelParamsKW.NAME] == ConfigKW.HEMIS_UNET:
                input_samples = imed_utils.cuda(
                    imed_utils.unstack_tensors(batch["input"]), cuda_available)
            else:
                input_samples = imed_utils.cuda(batch["input"], cuda_available)
            gt_samples = imed_utils.cuda(batch["gt"],
                                         cuda_available,
                                         non_blocking=True)

            # MIXUP
            if training_params["mixup_alpha"]:
                input_samples, gt_samples = imed_mixup.mixup(
                    input_samples, gt_samples, training_params["mixup_alpha"],
                    debugging and epoch == 1, path_output)

            # RUN MODEL
            if model_params[ModelParamsKW.NAME] == ConfigKW.HEMIS_UNET or \
                    (ModelParamsKW.FILM_LAYERS in model_params and any(model_params[ModelParamsKW.FILM_LAYERS])):
                metadata = get_metadata(batch[MetadataKW.INPUT_METADATA],
                                        model_params)
                preds = model(input_samples, metadata)
            else:
                preds = model(input_samples)

            # LOSS
            loss = loss_fct(preds, gt_samples)
            train_loss_total += loss.item()
            train_dice_loss_total += loss_dice_fct(preds, gt_samples).item()

            # UPDATE OPTIMIZER
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if step_scheduler_batch:
                scheduler.step()
            num_steps += 1

            # Save image at every 50th step if debugging is true
            if i % 50 == 0 and debugging:
                imed_visualize.save_img(
                    writer,
                    epoch,
                    "Train",
                    input_samples,
                    gt_samples,
                    preds,
                    wandb_tracking=wandb_tracking,
                    is_three_dim=not model_params[ModelParamsKW.IS_2D])

        if not step_scheduler_batch:
            scheduler.step()

        # TRAINING LOSS
        train_loss_total_avg = train_loss_total / num_steps
        msg = "Epoch {} training loss: {:.4f}.".format(epoch,
                                                       train_loss_total_avg)
        train_dice_loss_total_avg = train_dice_loss_total / num_steps
        if training_params["loss"]["name"] != "DiceLoss":
            msg += "\tDice training loss: {:.4f}.".format(
                train_dice_loss_total_avg)
        logger.info(msg)
        tqdm.write(msg)

        # CURRICULUM LEARNING
        if model_params[ModelParamsKW.NAME] == ConfigKW.HEMIS_UNET:
            # Increase the probability of a missing modality
            model_params[ModelParamsKW.MISSING_PROBABILITY] **= model_params[
                ModelParamsKW.MISSING_PROBABILITY_GROWTH]
            dataset_train.update(
                p=model_params[ModelParamsKW.MISSING_PROBABILITY])

        # Validation loop -----------------------------------------------------
        model.eval()
        val_loss_total, val_dice_loss_total = 0.0, 0.0
        num_steps = 0
        metric_mgr = imed_metrics.MetricManager(metric_fns)
        if dataset_val:
            for i, batch in enumerate(val_loader):
                with torch.no_grad():
                    # GET SAMPLES
                    if model_params[ModelParamsKW.NAME] == ConfigKW.HEMIS_UNET:
                        input_samples = imed_utils.cuda(
                            imed_utils.unstack_tensors(batch["input"]),
                            cuda_available)
                    else:
                        input_samples = imed_utils.cuda(
                            batch["input"], cuda_available)
                    gt_samples = imed_utils.cuda(batch["gt"],
                                                 cuda_available,
                                                 non_blocking=True)

                    # RUN MODEL
                    if model_params[ModelParamsKW.NAME] == ConfigKW.HEMIS_UNET or \
                            (ModelParamsKW.FILM_LAYERS in model_params and any(model_params[ModelParamsKW.FILM_LAYERS])):
                        metadata = get_metadata(
                            batch[MetadataKW.INPUT_METADATA], model_params)
                        preds = model(input_samples, metadata)
                    else:
                        preds = model(input_samples)

                    # LOSS
                    loss = loss_fct(preds, gt_samples)
                    val_loss_total += loss.item()
                    val_dice_loss_total += loss_dice_fct(preds,
                                                         gt_samples).item()

                    # Add frame to GIF
                    for i_ in range(len(input_samples)):
                        im, pr, met = input_samples[i_].cpu().numpy()[0], preds[i_].cpu().numpy()[0], \
                                      batch[MetadataKW.INPUT_METADATA][i_][0]
                        for i_gif in range(n_gif):
                            if gif_dict["image_path"][i_gif] == met.__getitem__('input_filenames') and \
                                    gif_dict["slice_id"][i_gif] == met.__getitem__('slice_index'):
                                overlap = imed_visualize.overlap_im_seg(im, pr)
                                gif_dict["gif"][i_gif].add(overlap,
                                                           label=str(epoch))

                num_steps += 1

                # METRICS COMPUTATION
                gt_npy = gt_samples.cpu().numpy()
                preds_npy = preds.data.cpu().numpy()
                metric_mgr(preds_npy, gt_npy)

                # Save image at every 50th step if debugging is true
                if i % 50 == 0 and debugging:
                    imed_visualize.save_img(
                        writer,
                        epoch,
                        "Validation",
                        input_samples,
                        gt_samples,
                        preds,
                        wandb_tracking=wandb_tracking,
                        is_three_dim=not model_params[ModelParamsKW.IS_2D])

            # METRICS COMPUTATION FOR CURRENT EPOCH
            val_loss_total_avg_old = val_loss_total_avg if epoch > 1 else None
            metrics_dict = metric_mgr.get_results()
            metric_mgr.reset()
            val_loss_total_avg = val_loss_total / num_steps
            # log losses on Tensorboard by default
            writer.add_scalars('Validation/Metrics', metrics_dict, epoch)
            writer.add_scalars(
                'losses', {
                    'train_loss': train_loss_total_avg,
                    'val_loss': val_loss_total_avg,
                }, epoch)
            # log on wandb if the corresponding dictionary is provided
            if wandb_tracking:
                wandb.log({"validation-metrics": metrics_dict})
                wandb.log({
                    "losses": {
                        'train_loss': train_loss_total_avg,
                        'val_loss': val_loss_total_avg,
                    }
                })
            msg = "Epoch {} validation loss: {:.4f}.".format(
                epoch, val_loss_total_avg)
            val_dice_loss_total_avg = val_dice_loss_total / num_steps
            if training_params["loss"]["name"] != "DiceLoss":
                msg += "\tDice validation loss: {:.4f}.".format(
                    val_dice_loss_total_avg)
            logger.info(msg)
            end_time = time.time()
            total_time = end_time - start_time
            msg_epoch = "Epoch {} took {:.2f} seconds.".format(
                epoch, total_time)
            logger.info(msg_epoch)

            # UPDATE BEST RESULTS
            if val_loss_total_avg < best_validation_loss:
                # Save checkpoint
                state = {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'gif_dict': gif_dict,
                    'scheduler': scheduler,
                    'patience_count': patience_count,
                    'validation_loss': val_loss_total_avg
                }
                torch.save(state, resume_path)

                # Save best model file
                model_path = Path(path_output, "best_model.pt")
                torch.save(model, model_path)

                # Update best scores
                best_validation_loss, best_training_loss = val_loss_total_avg, train_loss_total_avg
                best_validation_dice, best_training_dice = val_dice_loss_total_avg, train_dice_loss_total_avg

            # EARLY STOPPING
            if epoch > 1:
                val_diff = (val_loss_total_avg_old -
                            val_loss_total_avg) * 100 / abs(val_loss_total_avg)
                if val_diff < training_params["training_time"][
                        "early_stopping_epsilon"]:
                    patience_count += 1
                if patience_count >= training_params["training_time"][
                        "early_stopping_patience"]:
                    logger.info(
                        "Stopping training due to {} epochs without improvements"
                        .format(patience_count))
                    break

    # Save final model
    final_model_path = Path(path_output, "final_model.pt")
    torch.save(model, final_model_path)

    # Save best model in output path
    if resume_path.is_file():
        state = torch.load(resume_path)
        model_path = Path(path_output, "best_model.pt")
        model.load_state_dict(state['state_dict'])
        torch.save(model, model_path)
        # Save best model as ONNX in the model directory
        try:
            # Convert best model to ONNX and save it in model directory
            best_model_path = Path(
                path_output, model_params[ModelParamsKW.FOLDER_NAME],
                model_params[ModelParamsKW.FOLDER_NAME] + ".onnx")
            imed_utils.save_onnx_model(model, input_samples,
                                       str(best_model_path))
            logger.info(f"Model saved as '.onnx': {best_model_path}")
        except Exception as e:
            logger.warning(f"Failed to save the model as '.onnx': {e}")

        # Save best model as PT in the model directory
        best_model_path = Path(path_output,
                               model_params[ModelParamsKW.FOLDER_NAME],
                               model_params[ModelParamsKW.FOLDER_NAME] + ".pt")
        torch.save(model, best_model_path)
        logger.info(f"Model saved as '.pt': {best_model_path}")

    # Save GIFs
    gif_folder = Path(path_output, "gifs")
    if n_gif > 0 and not gif_folder.is_dir():
        gif_folder.mkdir(parents=True)
    for i_gif in range(n_gif):
        fname_out = gif_dict["image_path"][i_gif].split(os.sep)[-3] + "__"
        fname_out += gif_dict["image_path"][i_gif].split(
            os.sep)[-1].split(".nii.gz")[0].split(
                gif_dict["image_path"][i_gif].split(os.sep)[-3] +
                "_")[1] + "__"
        fname_out += str(gif_dict["slice_id"][i_gif]) + ".gif"
        path_gif_out = Path(gif_folder, fname_out)
        gif_dict["gif"][i_gif].save(str(path_gif_out))

    writer.close()
    wandb.finish()
    final_time = time.time()
    duration_time = final_time - begin_time
    logger.info('Begin: ' +
                time.strftime('%H:%M:%S', time.localtime(begin_time)) +
                " | End: " +
                time.strftime('%H:%M:%S', time.localtime(final_time)) +
                " | Duration: " + str(datetime.timedelta(seconds=duration_time)))

    return best_training_dice, best_training_loss, best_validation_dice, best_validation_loss
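
The wandb configuration is only sketched below: apart from "log_grads_every", which the function reads literally, the key names behind the WandbKW constants are assumptions. If imed_utils.initialize_wandb(wandb_params) does not enable tracking (e.g. no valid API key), the function falls back to TensorBoard-only logging.

# Hypothetical wandb_params sketch for the updated train(). Only "log_grads_every" is read
# literally in the function body; the remaining key names (behind WandbKW.*) are assumptions.
wandb_params_example = {
    "project_name": "my_project",   # assumed key behind WandbKW.PROJECT_NAME
    "group_name": "my_group",       # assumed key behind WandbKW.GROUP_NAME
    "run_name": "my_run",           # assumed key behind WandbKW.RUN_NAME
    "log_grads_every": 100,         # log frequency passed to wandb.watch(..., log_freq=...)
}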