Code example #1
def test_HeMIS(p=0.0001):
    print('[INFO]: Starting test ... \n')
    training_transform_dict = {
        "Resample":
            {
                "wspace": 0.75,
                "hspace": 0.75
            },
        "CenterCrop":
            {
                "size": [48, 48]
            },
        "NumpyToTensor": {}
    }

    transform_lst, _ = imed_transforms.prepare_transforms(training_transform_dict)

    roi_params = {"suffix": "_seg-manual", "slice_filter_roi": None}

    train_lst = ['sub-unf01']
    contrasts = ['T1w', 'T2w', 'T2star']

    print('[INFO]: Creating dataset ...\n')
    model_params = {
            "name": "HeMISUnet",
            "dropout_rate": 0.3,
            "bn_momentum": 0.9,
            "depth": 2,
            "in_channel": 1,
            "out_channel": 1,
            "missing_probability": 0.00001,
            "missing_probability_growth": 0.9,
            "contrasts": ["T1w", "T2w"],
            "ram": False,
            "path_hdf5": 'testing_data/mytestfile.hdf5',
            "csv_path": 'testing_data/hdf5.csv',
            "target_lst": ["T2w"],
            "roi_lst": ["T2w"]
        }
    contrast_params = {
        "contrast_lst": ['T1w', 'T2w', 'T2star'],
        "balance": {}
    }
    dataset = imed_adaptative.HDF5Dataset(root_dir=PATH_BIDS,
                                          subject_lst=train_lst,
                                          model_params=model_params,
                                          contrast_params=contrast_params,
                                          target_suffix=["_lesion-manual"],
                                          slice_axis=2,
                                          transform=transform_lst,
                                          metadata_choice=False,
                                          dim=2,
                                          slice_filter_fn=imed_loader_utils.SliceFilter(
                                              filter_empty_input=True, filter_empty_mask=True),
                                          roi_params=roi_params)

    dataset.load_into_ram(['T1w', 'T2w', 'T2star'])
    print("[INFO]: Dataset RAM status:")
    print(dataset.status)
    print("[INFO]: In memory Dataframe:")
    print(dataset.dataframe)

    # TODO
    # ds_train.filter_roi(nb_nonzero_thr=10)

    train_loader = DataLoader(dataset, batch_size=BATCH_SIZE,
                              shuffle=True, pin_memory=True,
                              collate_fn=imed_loader_utils.imed_collate,
                              num_workers=1)

    model = models.HeMISUnet(contrasts=contrasts,
                             depth=3,
                             drop_rate=DROPOUT,
                             bn_momentum=BN)

    print(model)
    cuda_available = torch.cuda.is_available()

    if cuda_available:
        torch.cuda.set_device(GPU_NUMBER)
        print("Using GPU number {}".format(GPU_NUMBER))
        model.cuda()

    # Initializing optimizer and scheduler
    step_scheduler_batch = False
    optimizer = optim.Adam(model.parameters(), lr=INIT_LR)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, N_EPOCHS)

    load_lst, reload_lst, pred_lst, opt_lst, schedul_lst, init_lst, gen_lst = [], [], [], [], [], [], []

    for epoch in tqdm(range(1, N_EPOCHS + 1), desc="Training"):
        start_time = time.time()

        start_init = time.time()
        lr = scheduler.get_last_lr()[0]
        model.train()

        tot_init = time.time() - start_init
        init_lst.append(tot_init)

        num_steps = 0
        start_gen = 0
        for i, batch in enumerate(train_loader):
            if i > 0:
                tot_gen = time.time() - start_gen
                gen_lst.append(tot_gen)

            start_load = time.time()
            input_samples, gt_samples = imed_utils.unstack_tensors(batch["input"]), batch["gt"]

            print(batch["input_metadata"][0][0]["missing_mod"])
            missing_mod = imed_training.get_metadata(batch["input_metadata"], model_params)

            print("Number of missing contrasts = {}."
                  .format(len(input_samples) * len(input_samples[0]) - missing_mod.sum()))
            print("len input = {}".format(len(input_samples)))
            print("Batch = {}, {}".format(input_samples[0].shape, gt_samples[0].shape))

            if cuda_available:
                var_input = imed_utils.cuda(input_samples)
                var_gt = imed_utils.cuda(gt_samples, non_blocking=True)
            else:
                var_input = input_samples
                var_gt = gt_samples

            tot_load = time.time() - start_load
            load_lst.append(tot_load)

            start_pred = time.time()
            preds = model(var_input, missing_mod)
            tot_pred = time.time() - start_pred
            pred_lst.append(tot_pred)

            start_opt = time.time()
            loss = - losses.DiceLoss()(preds, var_gt)

            optimizer.zero_grad()
            loss.backward()

            optimizer.step()
            if step_scheduler_batch:
                scheduler.step()

            num_steps += 1
            tot_opt = time.time() - start_opt
            opt_lst.append(tot_opt)

            start_gen = time.time()

        start_schedul = time.time()
        if not step_scheduler_batch:
            scheduler.step()
        tot_schedul = time.time() - start_schedul
        schedul_lst.append(tot_schedul)

        start_reload = time.time()
        print("[INFO]: Updating Dataset")
        p = p ** (2 / 3)
        dataset.update(p=p)
        print("[INFO]: Reloading dataset")
        train_loader = DataLoader(dataset, batch_size=BATCH_SIZE,
                                  shuffle=True, pin_memory=True,
                                  collate_fn=imed_loader_utils.imed_collate,
                                  num_workers=1)
        tot_reload = time.time() - start_reload
        reload_lst.append(tot_reload)

        end_time = time.time()
        total_time = end_time - start_time
        tqdm.write("Epoch {} took {:.2f} seconds.".format(epoch, total_time))

    print('Mean SD init {} -- {}'.format(np.mean(init_lst), np.std(init_lst)))
    print('Mean SD load {} -- {}'.format(np.mean(load_lst), np.std(load_lst)))
    print('Mean SD reload {} -- {}'.format(np.mean(reload_lst), np.std(reload_lst)))
    print('Mean SD pred {} -- {}'.format(np.mean(pred_lst), np.std(pred_lst)))
    print('Mean SD opt {} -- {}'.format(np.mean(opt_lst), np.std(opt_lst)))
    print('Mean SD gen {} -- {}'.format(np.mean(gen_lst), np.std(gen_lst)))
    print('Mean SD scheduler {} -- {}'.format(np.mean(schedul_lst), np.std(schedul_lst)))
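The test above assumes module-level imports and constants that are defined elsewhere in the file. The sketch below shows one plausible set of those names; the import paths follow the older ivadomed layout this snippet uses, and every value is an illustrative assumption rather than the project's actual test configuration.

# Hedged sketch: assumed imports and constants for test_HeMIS(). Import paths
# follow an older ivadomed layout and all values are illustrative only.
import time

import numpy as np
import torch
from torch import optim
from torch.utils.data import DataLoader
from tqdm import tqdm

from ivadomed import losses, models
from ivadomed import training as imed_training
from ivadomed import transforms as imed_transforms
from ivadomed import utils as imed_utils
from ivadomed.loader import adaptative as imed_adaptative
from ivadomed.loader import utils as imed_loader_utils

PATH_BIDS = 'testing_data'  # root of the BIDS-formatted test dataset (assumed)
BATCH_SIZE = 4
DROPOUT = 0.3               # dropout rate passed to HeMISUnet
BN = 0.9                    # batch-norm momentum
GPU_NUMBER = 0
INIT_LR = 1e-3
N_EPOCHS = 2

test_HeMIS()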
Code example #2
def train(model_params,
          dataset_train,
          dataset_val,
          training_params,
          log_directory,
          device,
          cuda_available=True,
          metric_fns=None,
          n_gif=0,
          resume_training=False,
          debugging=False):
    """Main command to train the network.

    Args:
        model_params (dict): Model's parameters.
        dataset_train (imed_loader): Training dataset.
        dataset_val (imed_loader): Validation dataset.
        training_params (dict): Training parameters.
        log_directory (str): Folder where log files, best and final models are saved.
        device (str): Indicates the CPU or GPU ID.
        cuda_available (bool): If True, CUDA is available.
        metric_fns (list): List of metrics, see :mod:`ivadomed.metrics`.
        n_gif (int): Generates a GIF during training if larger than zero, one frame per epoch for a given slice. The
            parameter indicates the number of 2D slices used to generate GIFs, one GIF per slice. A GIF shows
            predictions of a given slice from the validation sub-dataset. They are saved within the log directory.
        resume_training (bool): Load a saved model ("checkpoint.pth.tar" in the log_directory) to resume
                                training. This training state is saved every time a new best model is saved in the log
                                directory.
        debugging (bool): If True, extended verbosity and intermediate outputs.

    Returns:
        float, float, float, float: best_training_dice, best_training_loss, best_validation_dice,
            best_validation_loss.
    """
    # Write the metrics, images, etc to TensorBoard format
    writer = SummaryWriter(log_dir=log_directory)

    # BALANCE SAMPLES AND PYTORCH LOADER
    conditions = all([
        training_params["balance_samples"]["applied"],
        model_params["name"] != "HeMIS"
    ])
    sampler_train, shuffle_train = get_sampler(
        dataset_train, conditions, training_params['balance_samples']['type'])

    train_loader = DataLoader(dataset_train,
                              batch_size=training_params["batch_size"],
                              shuffle=shuffle_train,
                              pin_memory=True,
                              sampler=sampler_train,
                              collate_fn=imed_loader_utils.imed_collate,
                              num_workers=0)

    gif_dict = {"image_path": [], "slice_id": [], "gif": []}
    if dataset_val:
        sampler_val, shuffle_val = get_sampler(
            dataset_val, conditions,
            training_params['balance_samples']['type'])

        val_loader = DataLoader(dataset_val,
                                batch_size=training_params["batch_size"],
                                shuffle=shuffle_val,
                                pin_memory=True,
                                sampler=sampler_val,
                                collate_fn=imed_loader_utils.imed_collate,
                                num_workers=0)

        # Init GIF
        if n_gif > 0:
            indexes_gif = random.sample(range(len(dataset_val)), n_gif)
        for i_gif in range(n_gif):
            random_metadata = dict(
                dataset_val[indexes_gif[i_gif]]["input_metadata"][0])
            gif_dict["image_path"].append(random_metadata['input_filenames'])
            gif_dict["slice_id"].append(random_metadata['slice_index'])
            gif_obj = imed_visualize.AnimatedGif(
                size=dataset_val[indexes_gif[i_gif]]["input"].numpy()[0].shape)
            gif_dict["gif"].append(copy.copy(gif_obj))

    # GET MODEL
    if training_params["transfer_learning"]["retrain_model"]:
        print("\nLoading pretrained model's weights: {}.")
        print("\tFreezing the {}% first layers.".format(
            100 -
            training_params["transfer_learning"]['retrain_fraction'] * 100.))
        old_model_path = training_params["transfer_learning"]["retrain_model"]
        fraction = training_params["transfer_learning"]['retrain_fraction']
        if 'reset' in training_params["transfer_learning"]:
            reset = training_params["transfer_learning"]['reset']
        else:
            reset = True
        # Freeze first layers and reset last layers
        model = imed_models.set_model_for_retrain(old_model_path,
                                                  retrain_fraction=fraction,
                                                  map_location=device,
                                                  reset=reset)
    else:
        print("\nInitialising model's weights from scratch.")
        model_class = getattr(imed_models, model_params["name"])
        model = model_class(**model_params)
    if cuda_available:
        model.cuda()

    num_epochs = training_params["training_time"]["num_epochs"]

    # OPTIMIZER
    initial_lr = training_params["scheduler"]["initial_lr"]
    # Keep only the parameters that will be fine-tuned (requires_grad=True)
    params_to_opt = filter(lambda p: p.requires_grad, model.parameters())
    # Using Adam
    optimizer = optim.Adam(params_to_opt, lr=initial_lr)
    scheduler, step_scheduler_batch = get_scheduler(
        copy.copy(training_params["scheduler"]["lr_scheduler"]), optimizer,
        num_epochs)
    print("\nScheduler parameters: {}".format(
        training_params["scheduler"]["lr_scheduler"]))

    # Create dict containing gammas and betas after each FiLM layer.
    if 'film_layers' in model_params and any(model_params['film_layers']):
        gammas_dict = {i: [] for i in range(1, 2 * model_params["depth"] + 3)}
        betas_dict = {i: [] for i in range(1, 2 * model_params["depth"] + 3)}
        contrast_list = []

    # Resume
    start_epoch = 1
    resume_path = os.path.join(log_directory, "checkpoint.pth.tar")
    if resume_training:
        model, optimizer, gif_dict, start_epoch, val_loss_total_avg, scheduler, patience_count = load_checkpoint(
            model=model,
            optimizer=optimizer,
            gif_dict=gif_dict,
            scheduler=scheduler,
            fname=resume_path)
        # Individually transfer the optimizer parts
        # TODO: check if following lines are needed
        for state in optimizer.state.values():
            for k, v in state.items():
                if torch.is_tensor(v):
                    state[k] = v.to(device)

    # LOSS
    print("\nSelected Loss: {}".format(training_params["loss"]["name"]))
    print("\twith the parameters: {}".format([
        training_params["loss"][k] for k in training_params["loss"]
        if k != "name"
    ]))
    loss_fct = get_loss_function(copy.copy(training_params["loss"]))
    loss_dice_fct = imed_losses.DiceLoss(
    )  # For comparison when another loss is used

    # INIT TRAINING VARIABLES
    best_training_dice, best_training_loss = float("inf"), float("inf")
    best_validation_loss, best_validation_dice = float("inf"), float("inf")
    patience_count = 0
    begin_time = time.time()

    # EPOCH LOOP
    for epoch in tqdm(range(num_epochs), desc="Training", initial=start_epoch):
        epoch = epoch + start_epoch
        start_time = time.time()

        lr = scheduler.get_last_lr()[0]
        writer.add_scalar('learning_rate', lr, epoch)

        # Training loop -----------------------------------------------------------
        model.train()
        train_loss_total, train_dice_loss_total = 0.0, 0.0
        num_steps = 0
        for i, batch in enumerate(train_loader):
            # GET SAMPLES
            if model_params["name"] == "HeMISUnet":
                input_samples = imed_utils.cuda(
                    imed_utils.unstack_tensors(batch["input"]), cuda_available)
            else:
                input_samples = imed_utils.cuda(batch["input"], cuda_available)
            gt_samples = imed_utils.cuda(batch["gt"],
                                         cuda_available,
                                         non_blocking=True)

            # MIXUP
            if training_params["mixup_alpha"]:
                input_samples, gt_samples = imed_mixup.mixup(
                    input_samples, gt_samples, training_params["mixup_alpha"],
                    debugging and epoch == 1, log_directory)

            # RUN MODEL
            if model_params["name"] == "HeMISUnet" or \
                    ('film_layers' in model_params and any(model_params['film_layers'])):
                metadata = get_metadata(batch["input_metadata"], model_params)
                preds = model(input_samples, metadata)
            else:
                preds = model(input_samples)

            # LOSS
            loss = loss_fct(preds, gt_samples)
            train_loss_total += loss.item()
            train_dice_loss_total += loss_dice_fct(preds, gt_samples).item()

            # UPDATE OPTIMIZER
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if step_scheduler_batch:
                scheduler.step()
            num_steps += 1

            if i == 0 and debugging:
                imed_visualize.save_tensorboard_img(
                    writer,
                    epoch,
                    "Train",
                    input_samples,
                    gt_samples,
                    preds,
                    is_three_dim=not model_params["is_2d"])

        if not step_scheduler_batch:
            scheduler.step()

        # TRAINING LOSS
        train_loss_total_avg = train_loss_total / num_steps
        msg = "Epoch {} training loss: {:.4f}.".format(epoch,
                                                       train_loss_total_avg)
        train_dice_loss_total_avg = train_dice_loss_total / num_steps
        if training_params["loss"]["name"] != "DiceLoss":
            msg += "\tDice training loss: {:.4f}.".format(
                train_dice_loss_total_avg)
        tqdm.write(msg)

        # CURRICULUM LEARNING
        if model_params["name"] == "HeMISUnet":
            # Increase the probability of a missing modality
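            # With 0 < p < 1 and growth < 1, p **= growth raises p toward 1
            # (e.g. 1e-5 ** 0.9 ≈ 3.2e-5), so more modalities get dropped as
            # training progresses.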
            model_params["missing_probability"] **= model_params[
                "missing_probability_growth"]
            dataset_train.update(p=model_params["missing_probability"])

        # Validation loop -----------------------------------------------------
        model.eval()
        val_loss_total, val_dice_loss_total = 0.0, 0.0
        num_steps = 0
        metric_mgr = imed_metrics.MetricManager(metric_fns)
        if dataset_val:
            for i, batch in enumerate(val_loader):
                with torch.no_grad():
                    # GET SAMPLES
                    if model_params["name"] == "HeMISUnet":
                        input_samples = imed_utils.cuda(
                            imed_utils.unstack_tensors(batch["input"]),
                            cuda_available)
                    else:
                        input_samples = imed_utils.cuda(
                            batch["input"], cuda_available)
                    gt_samples = imed_utils.cuda(batch["gt"],
                                                 cuda_available,
                                                 non_blocking=True)

                    # RUN MODEL
                    if model_params["name"] == "HeMISUnet" or \
                            ('film_layers' in model_params and any(model_params['film_layers'])):
                        metadata = get_metadata(batch["input_metadata"],
                                                model_params)
                        preds = model(input_samples, metadata)
                    else:
                        preds = model(input_samples)

                    # LOSS
                    loss = loss_fct(preds, gt_samples)
                    val_loss_total += loss.item()
                    val_dice_loss_total += loss_dice_fct(preds,
                                                         gt_samples).item()

                    # Add frame to GIF
                    for i_ in range(len(input_samples)):
                        im, pr, met = input_samples[i_].cpu().numpy()[0], preds[i_].cpu().numpy()[0], \
                                      batch["input_metadata"][i_][0]
                        for i_gif in range(n_gif):
                            if gif_dict["image_path"][i_gif] == met['input_filenames'] and \
                                    gif_dict["slice_id"][i_gif] == met['slice_index']:
                                overlap = imed_visualize.overlap_im_seg(im, pr)
                                gif_dict["gif"][i_gif].add(overlap,
                                                           label=str(epoch))

                num_steps += 1

                # METRICS COMPUTATION
                gt_npy = gt_samples.cpu().numpy()
                preds_npy = preds.data.cpu().numpy()
                metric_mgr(preds_npy, gt_npy)

                if i == 0 and debugging:
                    imed_visualize.save_tensorboard_img(
                        writer,
                        epoch,
                        "Validation",
                        input_samples,
                        gt_samples,
                        preds,
                        is_three_dim=not model_params['is_2d'])

                if 'film_layers' in model_params and any(model_params['film_layers']) and debugging and \
                        epoch == num_epochs and i < int(len(dataset_val) / training_params["batch_size"]) + 1:
                    # Store the values of gammas and betas after the last epoch for each batch
                    gammas_dict, betas_dict, contrast_list = store_film_params(
                        gammas_dict, betas_dict, contrast_list,
                        batch['input_metadata'], model,
                        model_params["film_layers"], model_params["depth"])

            # METRICS COMPUTATION FOR CURRENT EPOCH
            val_loss_total_avg_old = val_loss_total_avg if epoch > 1 else None
            metrics_dict = metric_mgr.get_results()
            metric_mgr.reset()
            writer.add_scalars('Validation/Metrics', metrics_dict, epoch)
            val_loss_total_avg = val_loss_total / num_steps
            writer.add_scalars(
                'losses', {
                    'train_loss': train_loss_total_avg,
                    'val_loss': val_loss_total_avg,
                }, epoch)
            msg = "Epoch {} validation loss: {:.4f}.".format(
                epoch, val_loss_total_avg)
            val_dice_loss_total_avg = val_dice_loss_total / num_steps
            if training_params["loss"]["name"] != "DiceLoss":
                msg += "\tDice validation loss: {:.4f}.".format(
                    val_dice_loss_total_avg)
            tqdm.write(msg)
            end_time = time.time()
            total_time = end_time - start_time
            tqdm.write("Epoch {} took {:.2f} seconds.".format(
                epoch, total_time))

            # UPDATE BEST RESULTS
            if val_loss_total_avg < best_validation_loss:
                # Save checkpoint
                state = {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'gif_dict': gif_dict,
                    'scheduler': scheduler,
                    'patience_count': patience_count,
                    'validation_loss': val_loss_total_avg
                }
                torch.save(state, resume_path)

                # Save best model file
                model_path = os.path.join(log_directory, "best_model.pt")
                torch.save(model, model_path)

                # Update best scores
                best_validation_loss, best_training_loss = val_loss_total_avg, train_loss_total_avg
                best_validation_dice, best_training_dice = val_dice_loss_total_avg, train_dice_loss_total_avg

            # EARLY STOPPING
            if epoch > 1:
                val_diff = (val_loss_total_avg_old -
                            val_loss_total_avg) * 100 / abs(val_loss_total_avg)
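                # val_diff is the relative improvement in percent: each epoch it
                # stays below early_stopping_epsilon increments patience_count,
                # and training halts after early_stopping_patience such epochs.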
                if val_diff < training_params["training_time"][
                        "early_stopping_epsilon"]:
                    patience_count += 1
                if patience_count >= training_params["training_time"][
                        "early_stopping_patience"]:
                    print(
                        "Stopping training due to {} epochs without improvements"
                        .format(patience_count))
                    break

    # Save final model
    final_model_path = os.path.join(log_directory, "final_model.pt")
    torch.save(model, final_model_path)
    if 'film_layers' in model_params and any(
            model_params['film_layers']) and debugging:
        save_film_params(gammas_dict, betas_dict, contrast_list,
                         model_params["depth"], log_directory)

    # Save best model in log directory
    if os.path.isfile(resume_path):
        state = torch.load(resume_path)
        model_path = os.path.join(log_directory, "best_model.pt")
        model.load_state_dict(state['state_dict'])
        torch.save(model, model_path)
        # Save best model as ONNX in the model directory
        try:
            # Convert best model to ONNX and save it in model directory
            best_model_path = os.path.join(
                log_directory, model_params["folder_name"],
                model_params["folder_name"] + ".onnx")
            imed_utils.save_onnx_model(model, input_samples, best_model_path)
        except Exception:
            # Save best model in model directory
            best_model_path = os.path.join(log_directory,
                                           model_params["folder_name"],
                                           model_params["folder_name"] + ".pt")
            torch.save(model, best_model_path)
            logger.warning(
                "Failed to save the model as '.onnx', saved it as '.pt': {}".
                format(best_model_path))

    # Save GIFs
    gif_folder = os.path.join(log_directory, "gifs")
    if n_gif > 0 and not os.path.isdir(gif_folder):
        os.makedirs(gif_folder)
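    # Build "<subject>__<contrast>__<slice_id>.gif" from the BIDS image path,
    # e.g. ".../sub-01/anat/sub-01_T1w.nii.gz" -> "sub-01__T1w__42.gif"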
    for i_gif in range(n_gif):
        fname_out = gif_dict["image_path"][i_gif].split('/')[-3] + "__"
        fname_out += gif_dict["image_path"][i_gif].split('/')[-1].split(
            ".nii.gz")[0].split(gif_dict["image_path"][i_gif].split('/')[-3] +
                                "_")[1] + "__"
        fname_out += str(gif_dict["slice_id"][i_gif]) + ".gif"
        path_gif_out = os.path.join(gif_folder, fname_out)
        gif_dict["gif"][i_gif].save(path_gif_out)

    writer.close()
    final_time = time.time()
    duration_time = final_time - begin_time
    print('begin ' + time.strftime('%H:%M:%S', time.localtime(begin_time)) +
          "| End " + time.strftime('%H:%M:%S', time.localtime(final_time)) +
          "| duration " + str(datetime.timedelta(seconds=duration_time)))

    return best_training_dice, best_training_loss, best_validation_dice, best_validation_loss
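For reference, below is a hedged sketch of how this train() function might be invoked. The training_params keys mirror the accesses made inside the function; the values, the dataset objects, model_params, and the metric list are placeholder assumptions, not a validated configuration.

# Illustrative call only: keys mirror those read by train() above; values,
# datasets, and metrics are placeholder assumptions.
training_params = {
    "batch_size": 8,
    "balance_samples": {"applied": False, "type": "gt"},
    "training_time": {"num_epochs": 100,
                      "early_stopping_epsilon": 0.001,
                      "early_stopping_patience": 50},
    "scheduler": {"initial_lr": 1e-3,
                  "lr_scheduler": {"name": "CosineAnnealingLR"}},
    "loss": {"name": "DiceLoss"},
    "transfer_learning": {"retrain_model": None, "retrain_fraction": 1.0},
    "mixup_alpha": None,
}

best_scores = train(model_params=model_params,    # e.g. {"name": "Unet", "is_2d": True, ...}
                    dataset_train=dataset_train,  # imed_loader training dataset
                    dataset_val=dataset_val,      # imed_loader validation dataset
                    training_params=training_params,
                    log_directory="./results",
                    device="cuda:0",
                    cuda_available=torch.cuda.is_available(),
                    metric_fns=[imed_metrics.dice_score])  # assumed metric helper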
Code example #3
File: testing.py Project: mpompolas/ivadomed
def run_inference(test_loader,
                  model,
                  model_params,
                  testing_params,
                  ofolder,
                  cuda_available,
                  i_monte_carlo=None):
    """Run inference on the test data and save results as nibabel files.

    Args:
        test_loader (torch DataLoader): DataLoader of the testing dataset.
        model (nn.Module): Trained model to run inference with.
        model_params (dict): Model's parameters.
        testing_params (dict): Testing parameters.
        ofolder (str): Folder where predictions are saved.
        cuda_available (bool): If True, CUDA is available.
        i_monte_carlo (int): i_th Monte Carlo iteration.

    Returns:
        ndarray, ndarray: Prediction, Ground-truth of shape n_sample, n_label, h, w, d.
    """
    # INIT STORAGE VARIABLES
    preds_npy_list, gt_npy_list = [], []
    pred_tmp_lst, z_tmp_lst, fname_tmp = [], [], ''
    volume = None
    weight_matrix = None

    for i, batch in enumerate(
            tqdm(test_loader,
                 desc="Inference - Iteration " + str(i_monte_carlo))):
        with torch.no_grad():
            # GET SAMPLES
            # input_samples: list of batch_size tensors, whose size is n_channels X height X width X depth
            # gt_samples: idem with n_labels
            # batch['*_metadata']: list of batch_size lists, whose size is n_channels or n_labels
            if model_params["name"] == "HeMISUnet":
                input_samples = imed_utils.cuda(
                    imed_utils.unstack_tensors(batch["input"]), cuda_available)
            else:
                input_samples = imed_utils.cuda(batch["input"], cuda_available)
            gt_samples = imed_utils.cuda(batch["gt"],
                                         cuda_available,
                                         non_blocking=True)

            # EPISTEMIC UNCERTAINTY
            if testing_params['uncertainty']['applied'] and testing_params[
                    'uncertainty']['epistemic']:
                for m in model.modules():
                    if m.__class__.__name__.startswith('Dropout'):
                        m.train()

            # RUN MODEL
            if model_params["name"] in ["HeMISUnet", "FiLMedUnet"]:
                metadata = get_metadata(batch["input_metadata"], model_params)
                preds = model(input_samples, metadata)
            else:
                preds = model(input_samples)

        if model_params["name"] == "HeMISUnet":
            # Reconstruct image with only one modality
            input_samples = batch['input'][0]

        if model_params["name"] == "UNet3D" and model_params[
                "attention"] and ofolder:
            imed_utils.save_feature_map(
                batch,
                "attentionblock2",
                os.path.dirname(ofolder),
                model,
                input_samples,
                slice_axis=test_loader.dataset.slice_axis)

        # PREDS TO CPU
        preds_cpu = preds.cpu()

        task = imed_utils.get_task(model_params["name"])
        if task == "classification":
            gt_npy_list.append(gt_samples.cpu().numpy())
            preds_npy_list.append(preds_cpu.data.numpy())

        # RECONSTRUCT 3D IMAGE
        last_batch_bool = (i == len(test_loader) - 1)

        slice_axis = imed_utils.AXIS_DCT[testing_params['slice_axis']]

        # LOOP ACROSS SAMPLES
        for smp_idx in range(len(preds_cpu)):
            if "bounding_box" in batch['input_metadata'][smp_idx][0]:
                imed_obj_detect.adjust_undo_transforms(
                    testing_params["undo_transforms"].transforms, batch,
                    smp_idx)

            if not model_params["name"].endswith('3D'):
                last_sample_bool = (last_batch_bool
                                    and smp_idx == len(preds_cpu) - 1)
                # undo transformations
                preds_idx_undo, metadata_idx = testing_params[
                    "undo_transforms"](preds_cpu[smp_idx],
                                       batch['gt_metadata'][smp_idx],
                                       data_type='gt')
                # preds_idx_undo is a list of n_label arrays
                preds_idx_arr = np.array(preds_idx_undo)

                # TODO: gt_filenames should not be a list
                fname_ref = metadata_idx[0]['gt_filenames'][0]

                # NEW COMPLETE VOLUME
                if pred_tmp_lst and (fname_ref != fname_tmp or last_sample_bool
                                     ) and task != "classification":
                    # save the completely processed file as a nifti file
                    if ofolder:
                        fname_pred = os.path.join(ofolder,
                                                  fname_tmp.split('/')[-1])
                        fname_pred = fname_pred.rsplit(
                            testing_params['target_suffix'][0],
                            1)[0] + '_pred.nii.gz'
                        # If uncertainty is enabled, save each simulation result
                        if testing_params['uncertainty']['applied']:
                            fname_pred = fname_pred.split(
                                '.nii.gz')[0] + '_' + str(i_monte_carlo).zfill(
                                    2) + '.nii.gz'
                    else:
                        fname_pred = None
                    output_nii = imed_utils.pred_to_nib(
                        data_lst=pred_tmp_lst,
                        z_lst=z_tmp_lst,
                        fname_ref=fname_tmp,
                        fname_out=fname_pred,
                        slice_axis=slice_axis,
                        kernel_dim='2d',
                        bin_thr=testing_params["binarize_prediction"])
                    # TODO: Adapt to multilabel
                    output_data = output_nii.get_fdata()[:, :, :, 0]
                    preds_npy_list.append(output_data)

                    gt_npy_list.append(nib.load(fname_tmp).get_fdata())

                    output_nii_shape = output_nii.get_fdata().shape
                    if len(output_nii_shape
                           ) == 4 and output_nii_shape[-1] > 1 and ofolder:
                        imed_utils.save_color_labels(
                            np.stack(pred_tmp_lst, -1),
                            testing_params["binarize_prediction"] > 0,
                            fname_tmp,
                            fname_pred.split(".nii.gz")[0] + '_color.nii.gz',
                            imed_utils.AXIS_DCT[testing_params['slice_axis']])

                    # re-init pred_stack_lst
                    pred_tmp_lst, z_tmp_lst = [], []

                # add new sample to pred_tmp_lst, of size n_label X h X w ...
                pred_tmp_lst.append(preds_idx_arr)

                # TODO: slice_index should be stored in gt_metadata as well
                z_tmp_lst.append(
                    int(batch['input_metadata'][smp_idx][0]['slice_index']))
                fname_tmp = fname_ref

            else:
                pred_undo, metadata, last_sample_bool, volume, weight_matrix = \
                    imed_utils.volume_reconstruction(batch,
                                                     preds_cpu,
                                                     testing_params['undo_transforms'],
                                                     smp_idx, volume, weight_matrix)
                fname_ref = metadata[0]['gt_filenames'][0]
                # Save the reconstructed volume once the last sample is reached
                if last_sample_bool:
                    pred_undo = np.array(pred_undo)
                    if ofolder:
                        fname_pred = os.path.join(ofolder,
                                                  fname_ref.split('/')[-1])
                        fname_pred = fname_pred.split(
                            testing_params['target_suffix']
                            [0])[0] + '_pred.nii.gz'
                        # If uncertainty is enabled, save each simulation result
                        if testing_params['uncertainty']['applied']:
                            fname_pred = fname_pred.split(
                                '.nii.gz')[0] + '_' + str(i_monte_carlo).zfill(
                                    2) + '.nii.gz'
                    else:
                        fname_pred = None
                    # Choose only one modality
                    output_nii = imed_utils.pred_to_nib(
                        data_lst=[pred_undo],
                        z_lst=[],
                        fname_ref=fname_ref,
                        fname_out=fname_pred,
                        slice_axis=slice_axis,
                        kernel_dim='3d',
                        bin_thr=testing_params["binarize_prediction"])
                    output_data = output_nii.get_fdata().transpose(3, 0, 1, 2)
                    preds_npy_list.append(output_data)

                    gt_lst = []
                    for gt in metadata[0]['gt_filenames']:
                        # For multi-label tasks, a label may be missing from some images
                        if gt is not None:
                            gt_lst.append(nib.load(gt).get_fdata())
                        else:
                            gt_lst.append(np.zeros(gt_lst[0].shape))

                    gt_npy_list.append(np.array(gt_lst))
                    # Save merged labels with color

                    if pred_undo.shape[0] > 1 and ofolder:
                        imed_utils.save_color_labels(
                            pred_undo,
                            testing_params['binarize_prediction'] > 0,
                            batch['input_metadata'][smp_idx][0]
                            ['input_filenames'],
                            fname_pred.split(".nii.gz")[0] + '_color.nii.gz',
                            slice_axis)

    return preds_npy_list, gt_npy_list
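As a usage reference, here is a hedged sketch of a single call to this run_inference(); the testing_params keys mirror the ones the function reads, and every value, including the undo_transforms object, is a placeholder assumption.

# Illustrative call only: keys mirror those read by run_inference() above.
testing_params = {
    "slice_axis": "axial",                  # a key of imed_utils.AXIS_DCT
    "target_suffix": ["_seg-manual"],
    "binarize_prediction": 0.5,
    "uncertainty": {"applied": False, "epistemic": False},
    "undo_transforms": undo_transforms,     # undo-transform compose (assumed)
}

preds_npy, gt_npy = run_inference(test_loader, model, model_params,
                                  testing_params, ofolder="./pred_masks",
                                  cuda_available=torch.cuda.is_available())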
Code example #4
File: testing.py Project: AmmieQi/ivadomed
def run_inference(test_loader,
                  model,
                  model_params,
                  testing_params,
                  ofolder,
                  cuda_available,
                  i_monte_carlo=None,
                  postprocessing=None):
    """Run inference on the test data and save results as nibabel files.

    Args:
        test_loader (torch DataLoader): DataLoader of the testing dataset.
        model (nn.Module): Trained model to run inference with.
        model_params (dict): Model's parameters.
        testing_params (dict): Testing parameters.
        ofolder (str): Folder where predictions are saved.
        cuda_available (bool): If True, CUDA is available.
        i_monte_carlo (int): i_th Monte Carlo iteration.
        postprocessing (dict): Indicates postprocessing steps.

    Returns:
        ndarray, ndarray: Prediction, Ground-truth of shape n_sample, n_label, h, w, d.
    """
    # INIT STORAGE VARIABLES
    preds_npy_list, gt_npy_list, filenames = [], [], []
    pred_tmp_lst, z_tmp_lst, fname_tmp = [], [], ''
    volume = None
    weight_matrix = None

    # Create dict containing gammas and betas after each FiLM layer.
    if 'film_layers' in model_params and any(model_params['film_layers']):
        # 2 * model_params["depth"] + 2 is the number of FiLM layers. 1 is added since the range starts at one.
        gammas_dict = {i: [] for i in range(1, 2 * model_params["depth"] + 3)}
        betas_dict = {i: [] for i in range(1, 2 * model_params["depth"] + 3)}
        metadata_values_lst = []

    for i, batch in enumerate(
            tqdm(test_loader,
                 desc="Inference - Iteration " + str(i_monte_carlo))):
        with torch.no_grad():
            # GET SAMPLES
            # input_samples: list of batch_size tensors, whose size is n_channels X height X width X depth
            # gt_samples: idem with n_labels
            # batch['*_metadata']: list of batch_size lists, whose size is n_channels or n_labels
            if model_params["name"] == "HeMISUnet":
                input_samples = imed_utils.cuda(
                    imed_utils.unstack_tensors(batch["input"]), cuda_available)
            else:
                input_samples = imed_utils.cuda(batch["input"], cuda_available)
            gt_samples = imed_utils.cuda(batch["gt"],
                                         cuda_available,
                                         non_blocking=True)

            # EPISTEMIC UNCERTAINTY
            if testing_params['uncertainty']['applied'] and testing_params[
                    'uncertainty']['epistemic']:
                for m in model.modules():
                    if m.__class__.__name__.startswith('Dropout'):
                        m.train()

            # RUN MODEL
            if model_params["name"] == "HeMISUnet" or \
                    ('film_layers' in model_params and any(model_params['film_layers'])):
                metadata = get_metadata(batch["input_metadata"], model_params)
                preds = model(input_samples, metadata)
            else:
                preds = model(input_samples)

        if model_params["name"] == "HeMISUnet":
            # Reconstruct image with only one modality
            input_samples = batch['input'][0]

        if model_params["name"] == "Modified3DUNet" and model_params[
                "attention"] and ofolder:
            imed_visualize.save_feature_map(
                batch,
                "attentionblock2",
                os.path.dirname(ofolder),
                model,
                input_samples,
                slice_axis=test_loader.dataset.slice_axis)

        if 'film_layers' in model_params and any(model_params['film_layers']):
            # Store the values of gammas and betas after the last epoch for each batch
            gammas_dict, betas_dict, metadata_values_lst = store_film_params(
                gammas_dict, betas_dict, metadata_values_lst,
                batch['input_metadata'], model, model_params["film_layers"],
                model_params["depth"], model_params['metadata'])

        # PREDS TO CPU
        preds_cpu = preds.cpu()

        task = imed_utils.get_task(model_params["name"])
        if task == "classification":
            gt_npy_list.append(gt_samples.cpu().numpy())
            preds_npy_list.append(preds_cpu.data.numpy())

        # RECONSTRUCT 3D IMAGE
        last_batch_bool = (i == len(test_loader) - 1)

        slice_axis = imed_utils.AXIS_DCT[testing_params['slice_axis']]

        # LOOP ACROSS SAMPLES
        for smp_idx in range(len(preds_cpu)):
            if "bounding_box" in batch['input_metadata'][smp_idx][0]:
                imed_obj_detect.adjust_undo_transforms(
                    testing_params["undo_transforms"].transforms, batch,
                    smp_idx)

            if model_params["is_2d"]:
                last_sample_bool = (last_batch_bool
                                    and smp_idx == len(preds_cpu) - 1)
                # undo transformations
                preds_idx_undo, metadata_idx = testing_params[
                    "undo_transforms"](preds_cpu[smp_idx],
                                       batch['gt_metadata'][smp_idx],
                                       data_type='gt')
                # preds_idx_undo is a list of n_label arrays
                preds_idx_arr = np.array(preds_idx_undo)

                # TODO: gt_filenames should not be a list
                fname_ref = list(filter(None,
                                        metadata_idx[0]['gt_filenames']))[0]

                # NEW COMPLETE VOLUME
                if pred_tmp_lst and (fname_ref != fname_tmp or last_sample_bool
                                     ) and task != "classification":
                    # save the completely processed file as a nifti file
                    if ofolder:
                        fname_pred = os.path.join(ofolder,
                                                  Path(fname_tmp).name)
                        fname_pred = fname_pred.rsplit("_",
                                                       1)[0] + '_pred.nii.gz'
                        # If uncertainty is enabled, save each simulation result
                        if testing_params['uncertainty']['applied']:
                            fname_pred = fname_pred.split(
                                '.nii.gz')[0] + '_' + str(i_monte_carlo).zfill(
                                    2) + '.nii.gz'
                            postprocessing = None
                    else:
                        fname_pred = None
                    output_nii = imed_inference.pred_to_nib(
                        data_lst=pred_tmp_lst,
                        z_lst=z_tmp_lst,
                        fname_ref=fname_tmp,
                        fname_out=fname_pred,
                        slice_axis=slice_axis,
                        kernel_dim='2d',
                        bin_thr=-1,
                        postprocessing=postprocessing)
                    output_data = output_nii.get_fdata().transpose(3, 0, 1, 2)
                    preds_npy_list.append(output_data)

                    gt = get_gt(filenames)
                    gt_npy_list.append(gt)

                    output_nii_shape = output_nii.get_fdata().shape
                    if len(output_nii_shape
                           ) == 4 and output_nii_shape[-1] > 1 and ofolder:
                        logger.warning(
                            'No color labels saved due to a temporary issue. For more details see: '
                            'https://github.com/ivadomed/ivadomed/issues/720')
                        # TODO: put back the code below. See #720
                        # imed_visualize.save_color_labels(np.stack(pred_tmp_lst, -1),
                        #                              False,
                        #                              fname_tmp,
                        #                              fname_pred.split(".nii.gz")[0] + '_color.nii.gz',
                        #                              imed_utils.AXIS_DCT[testing_params['slice_axis']])

                    # re-init pred_stack_lst
                    pred_tmp_lst, z_tmp_lst = [], []

                # add new sample to pred_tmp_lst, of size n_label X h X w ...
                pred_tmp_lst.append(preds_idx_arr)

                # TODO: slice_index should be stored in gt_metadata as well
                z_tmp_lst.append(
                    int(batch['input_metadata'][smp_idx][0]['slice_index']))
                fname_tmp = fname_ref
                filenames = metadata_idx[0]['gt_filenames']

            else:
                pred_undo, metadata, last_sample_bool, volume, weight_matrix = \
                    imed_inference.volume_reconstruction(batch,
                                                     preds_cpu,
                                                     testing_params['undo_transforms'],
                                                     smp_idx, volume, weight_matrix)
                fname_ref = metadata[0]['gt_filenames'][0]
                # Save the reconstructed volume once the last sample is reached
                if last_sample_bool:
                    pred_undo = np.array(pred_undo)
                    if ofolder:
                        fname_pred = os.path.join(ofolder,
                                                  fname_ref.split('/')[-1])
                        fname_pred = fname_pred.split(
                            testing_params['target_suffix']
                            [0])[0] + '_pred.nii.gz'
                        # If uncertainty is enabled, save each simulation result
                        if testing_params['uncertainty']['applied']:
                            fname_pred = fname_pred.split(
                                '.nii.gz')[0] + '_' + str(i_monte_carlo).zfill(
                                    2) + '.nii.gz'
                            postprocessing = None
                    else:
                        fname_pred = None
                    # Choose only one modality
                    output_nii = imed_inference.pred_to_nib(
                        data_lst=[pred_undo],
                        z_lst=[],
                        fname_ref=fname_ref,
                        fname_out=fname_pred,
                        slice_axis=slice_axis,
                        kernel_dim='3d',
                        bin_thr=-1,
                        postprocessing=postprocessing)
                    output_data = output_nii.get_fdata().transpose(3, 0, 1, 2)
                    preds_npy_list.append(output_data)

                    gt = get_gt(metadata[0]['gt_filenames'])
                    gt_npy_list.append(gt)
                    # Save merged labels with color

                    if pred_undo.shape[0] > 1 and ofolder:
                        logger.warning(
                            'No color labels saved due to a temporary issue. For more details see: '
                            'https://github.com/ivadomed/ivadomed/issues/720')
                        # TODO: put back the code below. See #720
                        # imed_visualize.save_color_labels(pred_undo,
                        #                              False,
                        #                              batch['input_metadata'][smp_idx][0]['input_filenames'],
                        #                              fname_pred.split(".nii.gz")[0] + '_color.nii.gz',
                        #                              slice_axis)

    if 'film_layers' in model_params and any(model_params['film_layers']):
        save_film_params(gammas_dict, betas_dict, metadata_values_lst,
                         model_params["depth"],
                         ofolder.replace("pred_masks", ""))
    return preds_npy_list, gt_npy_list
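Because this variant threads an i_monte_carlo index through the output filenames and re-enables Dropout when epistemic uncertainty is requested, a Monte Carlo (MC-dropout) loop over it might look like the hedged sketch below; the iteration count is an arbitrary assumption.

# Hedged sketch: epistemic uncertainty via repeated stochastic forward passes.
# With applied=True and epistemic=True, run_inference() keeps Dropout layers in
# train mode, and each simulation is saved with a zero-padded suffix such as
# "_pred_00.nii.gz".
testing_params['uncertainty'] = {'applied': True, 'epistemic': True}
n_monte_carlo = 10  # arbitrary number of simulations
for i_mc in range(n_monte_carlo):
    run_inference(test_loader, model, model_params, testing_params,
                  ofolder="./pred_masks", cuda_available=True,
                  i_monte_carlo=i_mc)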
Code example #5
File: training.py Project: mathieuboudreau/ivadomed
def train(model_params,
          dataset_train,
          dataset_val,
          training_params,
          path_output,
          device,
          wandb_params=None,
          cuda_available=True,
          metric_fns=None,
          n_gif=0,
          resume_training=False,
          debugging=False):
    """Main command to train the network.

    Args:
        model_params (dict): Model's parameters.
        dataset_train (imed_loader): Training dataset.
        dataset_val (imed_loader): Validation dataset.
        training_params (dict): Training parameters.
        path_output (str): Folder where log files, best and final models are saved.
        device (str): Indicates the CPU or GPU ID.
        cuda_available (bool): If True, CUDA is available.
        metric_fns (list): List of metrics, see :mod:`ivadomed.metrics`.
        n_gif (int): Generates a GIF during training if larger than zero, one frame per epoch for a given slice. The
            parameter indicates the number of 2D slices used to generate GIFs, one GIF per slice. A GIF shows
            predictions of a given slice from the validation sub-dataset. They are saved within the output path.
        resume_training (bool): Load a saved model ("checkpoint.pth.tar" in the path_output) to resume
                                training. This training state is saved every time a new best model is saved in the log
                                directory.
        debugging (bool): If True, extended verbosity and intermediate outputs.

    Returns:
        float, float, float, float: best_training_dice, best_training_loss, best_validation_dice,
            best_validation_loss.
    """
    # Write the metrics, images, etc to TensorBoard format
    writer = SummaryWriter(log_dir=path_output)

    # Enable wandb tracking if the required params are found in the config file and the API key is correct
    wandb_tracking = imed_utils.initialize_wandb(wandb_params)

    if wandb_tracking:
        # Collect all hyperparameters into a dictionary
        cfg = {**training_params, **model_params}

        # Get the actual project, group, and run names if they exist, else choose the temporary names as default
        project_name = wandb_params.get(WandbKW.PROJECT_NAME, "temp_project")
        group_name = wandb_params.get(WandbKW.GROUP_NAME, "temp_group")
        run_name = wandb_params.get(WandbKW.RUN_NAME, "temp_run")

        if project_name == "temp_project" or group_name == "temp_group" or run_name == "temp_run":
            logger.info(
                "{PROJECT/GROUP/RUN} name not found, initializing as {'temp_project'/'temp_group'/'temp_run'}"
            )

        # Initialize WandB with metrics and hyperparameters
        wandb.init(project=project_name,
                   group=group_name,
                   name=run_name,
                   config=cfg)

    # BALANCE SAMPLES AND PYTORCH LOADER
    conditions = all([
        training_params[TrainingParamsKW.BALANCE_SAMPLES]
        [BalanceSamplesKW.APPLIED], model_params[ModelParamsKW.NAME] != "HeMIS"
    ])
    sampler_train, shuffle_train = get_sampler(
        dataset_train, conditions, training_params[
            TrainingParamsKW.BALANCE_SAMPLES][BalanceSamplesKW.TYPE])

    train_loader = DataLoader(
        dataset_train,
        batch_size=training_params[TrainingParamsKW.BATCH_SIZE],
        shuffle=shuffle_train,
        pin_memory=True,
        sampler=sampler_train,
        collate_fn=imed_loader_utils.imed_collate,
        num_workers=0)

    gif_dict = {"image_path": [], "slice_id": [], "gif": []}
    if dataset_val:
        sampler_val, shuffle_val = get_sampler(
            dataset_val, conditions, training_params[
                TrainingParamsKW.BALANCE_SAMPLES][BalanceSamplesKW.TYPE])

        val_loader = DataLoader(
            dataset_val,
            batch_size=training_params[TrainingParamsKW.BATCH_SIZE],
            shuffle=shuffle_val,
            pin_memory=True,
            sampler=sampler_val,
            collate_fn=imed_loader_utils.imed_collate,
            num_workers=0)

        # Init GIF
        if n_gif > 0:
            indexes_gif = random.sample(range(len(dataset_val)), n_gif)
        for i_gif in range(n_gif):
            random_metadata = dict(
                dataset_val[indexes_gif[i_gif]][MetadataKW.INPUT_METADATA][0])
            gif_dict["image_path"].append(
                random_metadata[MetadataKW.INPUT_FILENAMES])
            gif_dict["slice_id"].append(
                random_metadata[MetadataKW.SLICE_INDEX])
            gif_obj = imed_visualize.AnimatedGif(
                size=dataset_val[indexes_gif[i_gif]]["input"].numpy()[0].shape)
            gif_dict["gif"].append(copy.copy(gif_obj))

    # GET MODEL
    if training_params["transfer_learning"]["retrain_model"]:
        logger.info("Loading pretrained model's weights: {}.")
        logger.info("\tFreezing the {}% first layers.".format(
            100 -
            training_params["transfer_learning"]['retrain_fraction'] * 100.))
        old_model_path = training_params["transfer_learning"]["retrain_model"]
        fraction = training_params["transfer_learning"]['retrain_fraction']
        if 'reset' in training_params["transfer_learning"]:
            reset = training_params["transfer_learning"]['reset']
        else:
            reset = True
        # Freeze first layers and reset last layers
        model = imed_models.set_model_for_retrain(old_model_path,
                                                  retrain_fraction=fraction,
                                                  map_location=device,
                                                  reset=reset)
    else:
        logger.info("Initialising model's weights from scratch.")
        model_class = getattr(imed_models, model_params[ModelParamsKW.NAME])
        model = model_class(**model_params)
    if cuda_available:
        model.cuda()

    num_epochs = training_params["training_time"]["num_epochs"]

    # OPTIMIZER
    initial_lr = training_params["scheduler"]["initial_lr"]
    # Keep only the parameters that will be fine-tuned (requires_grad=True)
    params_to_opt = filter(lambda p: p.requires_grad, model.parameters())
    # Using Adam
    optimizer = optim.Adam(params_to_opt, lr=initial_lr)
    scheduler, step_scheduler_batch = get_scheduler(
        copy.copy(training_params["scheduler"]["lr_scheduler"]), optimizer,
        num_epochs)
    logger.info("Scheduler parameters: {}".format(
        training_params["scheduler"]["lr_scheduler"]))

    # Only call wandb methods if required params are found in the config file
    if wandb_tracking:
        # Logs gradients (at every log_freq steps) to the dashboard.
        wandb.watch(model,
                    log="gradients",
                    log_freq=wandb_params["log_grads_every"])

    # Resume
    start_epoch = 1
    resume_path = Path(path_output, "checkpoint.pth.tar")
    if resume_training:
        model, optimizer, gif_dict, start_epoch, val_loss_total_avg, scheduler, patience_count = load_checkpoint(
            model=model,
            optimizer=optimizer,
            gif_dict=gif_dict,
            scheduler=scheduler,
            fname=str(resume_path))
        # Individually transfer the optimizer parts
        # TODO: check if following lines are needed
        for state in optimizer.state.values():
            for k, v in state.items():
                if torch.is_tensor(v):
                    state[k] = v.to(device)

    # LOSS
    logger.info("Selected Loss: {}".format(training_params["loss"]["name"]))
    logger.info("\twith the parameters: {}".format([
        training_params["loss"][k] for k in training_params["loss"]
        if k != "name"
    ]))
    loss_fct = get_loss_function(copy.copy(training_params["loss"]))
    loss_dice_fct = imed_losses.DiceLoss(
    )  # For comparison when another loss is used

    # INIT TRAINING VARIABLES
    best_training_dice, best_training_loss = float("inf"), float("inf")
    best_validation_loss, best_validation_dice = float("inf"), float("inf")
    patience_count = 0
    begin_time = time.time()

    # EPOCH LOOP
    for epoch in tqdm(range(num_epochs), desc="Training", initial=start_epoch):
        epoch = epoch + start_epoch
        start_time = time.time()

        lr = scheduler.get_last_lr()[0]
        writer.add_scalar('learning_rate', lr, epoch)
        if wandb_tracking:
            wandb.log({"learning_rate": lr})

        # Training loop -----------------------------------------------------------
        model.train()
        train_loss_total, train_dice_loss_total = 0.0, 0.0
        num_steps = 0
        for i, batch in enumerate(train_loader):
            # GET SAMPLES
            if model_params[ModelParamsKW.NAME] == ConfigKW.HEMIS_UNET:
                input_samples = imed_utils.cuda(
                    imed_utils.unstack_tensors(batch["input"]), cuda_available)
            else:
                input_samples = imed_utils.cuda(batch["input"], cuda_available)
            gt_samples = imed_utils.cuda(batch["gt"],
                                         cuda_available,
                                         non_blocking=True)

            # MIXUP
            if training_params["mixup_alpha"]:
                input_samples, gt_samples = imed_mixup.mixup(
                    input_samples, gt_samples, training_params["mixup_alpha"],
                    debugging and epoch == 1, path_output)

            # RUN MODEL
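            # HeMIS-UNet and FiLM-equipped models also take image metadata
            # (e.g., contrast labels) as a second input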
            if model_params[ModelParamsKW.NAME] == ConfigKW.HEMIS_UNET or \
                    (ModelParamsKW.FILM_LAYERS in model_params and any(model_params[ModelParamsKW.FILM_LAYERS])):
                metadata = get_metadata(batch[MetadataKW.INPUT_METADATA],
                                        model_params)
                preds = model(input_samples, metadata)
            else:
                preds = model(input_samples)

            # LOSS
            loss = loss_fct(preds, gt_samples)
            train_loss_total += loss.item()
            train_dice_loss_total += loss_dice_fct(preds, gt_samples).item()

            # UPDATE OPTIMIZER
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if step_scheduler_batch:
                scheduler.step()
            num_steps += 1

            # Save images every 50 steps when debugging is enabled
            if i % 50 == 0 and debugging:
                imed_visualize.save_img(
                    writer,
                    epoch,
                    "Train",
                    input_samples,
                    gt_samples,
                    preds,
                    wandb_tracking=wandb_tracking,
                    is_three_dim=not model_params[ModelParamsKW.IS_2D])

        if not step_scheduler_batch:
            scheduler.step()

        # TRAINING LOSS
        train_loss_total_avg = train_loss_total / num_steps
        msg = "Epoch {} training loss: {:.4f}.".format(epoch,
                                                       train_loss_total_avg)
        train_dice_loss_total_avg = train_dice_loss_total / num_steps
        if training_params["loss"]["name"] != "DiceLoss":
            msg += "\tDice training loss: {:.4f}.".format(
                train_dice_loss_total_avg)
        logger.info(msg)
        tqdm.write(msg)

        # CURRICULUM LEARNING
        if model_params[ModelParamsKW.NAME] == ConfigKW.HEMIS_UNET:
            # Increase the probability of a missing modality
            model_params[ModelParamsKW.MISSING_PROBABILITY] **= model_params[
                ModelParamsKW.MISSING_PROBABILITY_GROWTH]
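            # For 0 < p < 1 and 0 < growth < 1 (the typical setting),
            # p ** growth > p, so the chance of dropping a modality rises
            # toward 1 as the epochs go by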
            dataset_train.update(
                p=model_params[ModelParamsKW.MISSING_PROBABILITY])

        # Validation loop -----------------------------------------------------
        model.eval()
        val_loss_total, val_dice_loss_total = 0.0, 0.0
        num_steps = 0
        metric_mgr = imed_metrics.MetricManager(metric_fns)
        if dataset_val:
            for i, batch in enumerate(val_loader):
                with torch.no_grad():
                    # GET SAMPLES
                    if model_params[ModelParamsKW.NAME] == ConfigKW.HEMIS_UNET:
                        input_samples = imed_utils.cuda(
                            imed_utils.unstack_tensors(batch["input"]),
                            cuda_available)
                    else:
                        input_samples = imed_utils.cuda(
                            batch["input"], cuda_available)
                    gt_samples = imed_utils.cuda(batch["gt"],
                                                 cuda_available,
                                                 non_blocking=True)

                    # RUN MODEL
                    if model_params[ModelParamsKW.NAME] == ConfigKW.HEMIS_UNET or \
                            (ModelParamsKW.FILM_LAYERS in model_params and any(model_params[ModelParamsKW.FILM_LAYERS])):
                        metadata = get_metadata(
                            batch[MetadataKW.INPUT_METADATA], model_params)
                        preds = model(input_samples, metadata)
                    else:
                        preds = model(input_samples)

                    # LOSS
                    loss = loss_fct(preds, gt_samples)
                    val_loss_total += loss.item()
                    val_dice_loss_total += loss_dice_fct(preds,
                                                         gt_samples).item()

                    # Add a frame to each GIF that tracks this image/slice
                    for i_ in range(len(input_samples)):
                        im = input_samples[i_].cpu().numpy()[0]
                        pr = preds[i_].cpu().numpy()[0]
                        met = batch[MetadataKW.INPUT_METADATA][i_][0]
                        for i_gif in range(n_gif):
                            if gif_dict["image_path"][i_gif] == met['input_filenames'] and \
                                    gif_dict["slice_id"][i_gif] == met['slice_index']:
                                overlap = imed_visualize.overlap_im_seg(im, pr)
                                gif_dict["gif"][i_gif].add(overlap,
                                                           label=str(epoch))

                num_steps += 1

                # METRICS COMPUTATION
                gt_npy = gt_samples.cpu().numpy()
                preds_npy = preds.data.cpu().numpy()
                metric_mgr(preds_npy, gt_npy)

                # Save images every 50 steps when debugging is enabled
                if i % 50 == 0 and debugging:
                    imed_visualize.save_img(
                        writer,
                        epoch,
                        "Validation",
                        input_samples,
                        gt_samples,
                        preds,
                        wandb_tracking=wandb_tracking,
                        is_three_dim=not model_params[ModelParamsKW.IS_2D])

            # METRICS COMPUTATION FOR CURRENT EPOCH
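            # Keep last epoch's average validation loss for the
            # early-stopping relative-improvement check further below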
            val_loss_total_avg_old = val_loss_total_avg if epoch > 1 else None
            metrics_dict = metric_mgr.get_results()
            metric_mgr.reset()
            val_loss_total_avg = val_loss_total / num_steps
            # log losses on Tensorboard by default
            writer.add_scalars('Validation/Metrics', metrics_dict, epoch)
            writer.add_scalars(
                'losses', {
                    'train_loss': train_loss_total_avg,
                    'val_loss': val_loss_total_avg,
                }, epoch)
            # log on wandb if the corresponding dictionary is provided
            if wandb_tracking:
                wandb.log({"validation-metrics": metrics_dict})
                wandb.log({
                    "losses": {
                        'train_loss': train_loss_total_avg,
                        'val_loss': val_loss_total_avg,
                    }
                })
            msg = "Epoch {} validation loss: {:.4f}.".format(
                epoch, val_loss_total_avg)
            val_dice_loss_total_avg = val_dice_loss_total / num_steps
            if training_params["loss"]["name"] != "DiceLoss":
                msg += "\tDice validation loss: {:.4f}.".format(
                    val_dice_loss_total_avg)
            logger.info(msg)
            end_time = time.time()
            total_time = end_time - start_time
            msg_epoch = "Epoch {} took {:.2f} seconds.".format(
                epoch, total_time)
            logger.info(msg_epoch)

            # UPDATE BEST RESULTS
            if val_loss_total_avg < best_validation_loss:
                # Save checkpoint
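                # 'epoch' is stored as epoch + 1 so that a resumed run picks
                # up at the next epoch (see the start_epoch offset above)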
                state = {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'gif_dict': gif_dict,
                    'scheduler': scheduler,
                    'patience_count': patience_count,
                    'validation_loss': val_loss_total_avg
                }
                torch.save(state, resume_path)

                # Save best model file
                model_path = Path(path_output, "best_model.pt")
                torch.save(model, model_path)
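                # Note: torch.save(model, ...) pickles the whole module (not
                # just the state_dict), so loading requires the same class code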

                # Update best scores
                best_validation_loss, best_training_loss = val_loss_total_avg, train_loss_total_avg
                best_validation_dice, best_training_dice = val_dice_loss_total_avg, train_dice_loss_total_avg

            # EARLY STOPPING
            if epoch > 1:
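                # Relative improvement (%) of the validation loss over the
                # previous epoch; gains below epsilon count toward patience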
                val_diff = (val_loss_total_avg_old -
                            val_loss_total_avg) * 100 / abs(val_loss_total_avg)
                if val_diff < training_params["training_time"][
                        "early_stopping_epsilon"]:
                    patience_count += 1
                if patience_count >= training_params["training_time"][
                        "early_stopping_patience"]:
                    logger.info(
                        "Stopping training after {} epochs without improvement"
                        .format(patience_count))
                    break

    # Save final model
    final_model_path = Path(path_output, "final_model.pt")
    torch.save(model, final_model_path)

    # Save best model in output path
    if resume_path.is_file():
        state = torch.load(resume_path)
        model_path = Path(path_output, "best_model.pt")
        model.load_state_dict(state['state_dict'])
        torch.save(model, model_path)
        # Convert the best model to ONNX and save it in the model directory
        try:
            best_model_path = Path(
                path_output, model_params[ModelParamsKW.FOLDER_NAME],
                model_params[ModelParamsKW.FOLDER_NAME] + ".onnx")
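            # ONNX export needs an example input; the input_samples left over
            # from the last processed batch is used here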
            imed_utils.save_onnx_model(model, input_samples,
                                       str(best_model_path))
            logger.info(f"Model saved as '.onnx': {best_model_path}")
        except Exception as e:
            logger.warning(f"Failed to save the model as '.onnx': {e}")

        # Save best model as PT in the model directory
        best_model_path = Path(path_output,
                               model_params[ModelParamsKW.FOLDER_NAME],
                               model_params[ModelParamsKW.FOLDER_NAME] + ".pt")
        torch.save(model, best_model_path)
        logger.info(f"Model saved as '.pt': {best_model_path}")

    # Save GIFs
    gif_folder = Path(path_output, "gifs")
    if n_gif > 0 and not gif_folder.is_dir():
        gif_folder.mkdir(parents=True)
    for i_gif in range(n_gif):
        # Output filename: <subject>__<contrast>__<slice_id>.gif, rebuilt from
        # the BIDS-style image path (subject folder and file suffix)
        path_parts = gif_dict["image_path"][i_gif].split(os.sep)
        subject = path_parts[-3]
        contrast = path_parts[-1].split(".nii.gz")[0].split(subject + "_")[1]
        fname_out = subject + "__" + contrast + "__" + str(gif_dict["slice_id"][i_gif]) + ".gif"
        path_gif_out = Path(gif_folder, fname_out)
        gif_dict["gif"][i_gif].save(str(path_gif_out))

    writer.close()
    if wandb_tracking:
        wandb.finish()
    final_time = time.time()
    duration_time = final_time - begin_time
    logger.info('Begin ' +
                time.strftime('%H:%M:%S', time.localtime(begin_time)) +
                " | End " +
                time.strftime('%H:%M:%S', time.localtime(final_time)) +
                " | Duration " + str(datetime.timedelta(seconds=duration_time)))

    return best_training_dice, best_training_loss, best_validation_dice, best_validation_loss