Example #1
    def __call__(self, sample):
        input_data, gt_data = sample['input'], sample['gt']

        if self.filter_empty_mask:
            # Filter slices that do not have ANY ground truth (i.e. all masks are empty)
            if not np.any(gt_data):
                return False

        if self.filter_absent_class:
            # Filter slices that have absent classes (i.e. one or more masks are empty)
            if not np.all([np.any(mask) for mask in gt_data]):
                return False

        if self.filter_empty_input:
            # Filter set of images if one of them is empty or filled with constant value (i.e. std == 0)
            if np.any([img.std() == 0 for img in input_data]):
                return False

        if self.filter_classification:
            if not np.all([
                    int(
                        self.classifier(
                            imed_utils.cuda(
                                torch.from_numpy(
                                    img.copy()).unsqueeze(0).unsqueeze(0),
                                self.cuda_available))) for img in input_data
            ]):
                return False

        return True
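For context, here is a minimal, hypothetical sketch of how this filter might be applied to a single sample before it enters a dataset. The constructor arguments mirror those used in Example #4; the other flags (`filter_absent_class`, `filter_classification`) are assumed to have defaults in the real class, and `imed_loader_utils` is assumed imported as in the surrounding examples.

import numpy as np

# Build the filter (flags as used in Example #4).
slice_filter = imed_loader_utils.SliceFilter(filter_empty_input=True,
                                             filter_empty_mask=True)

# A sample is a dict of lists of 2D arrays, one per contrast / label.
sample = {
    'input': [np.random.rand(48, 48)],  # non-constant input image
    'gt': [np.zeros((48, 48))]          # empty ground-truth mask
}

if not slice_filter(sample):
    print("Slice rejected: no ground truth present")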
Example #2
def get_preds(context: dict, path_model_file: str, model_params: dict,
              gpu_id: int, batch: dict) -> tensor:
    """Returns the predictions from the given model.

    Args:
        context (dict): configuration dict.
        path_model_file (str): name of file containing model.
        model_params (dict): dictionary containing model parameters.
        gpu_id (int): ID of the GPU to use, if available. Multi-GPU segmentation is currently NOT supported.
        batch (dict): Dictionary containing input, ground truth, and metadata.

    Returns:
        tensor: predictions from the model.
    """
    # Define device
    cuda_available, device = imed_utils.define_device(gpu_id)

    with torch.no_grad():

        # Load the Input
        img = imed_utils.cuda(batch['input'], cuda_available=cuda_available)

        # If a PyTorch (.pt) model file is given, load it and set it to evaluation mode.
        if path_model_file.lower().endswith('.pt'):
            logger.debug(f"PyTorch model detected at: {path_model_file}")
            logger.debug(f"Loading model from: {path_model_file}")
            model = torch.load(path_model_file, map_location=device)
            # Inference time
            logger.debug(f"Evaluating model: {path_model_file}")
            model.eval()

            # FiLM/HeMIS-based predictions require metadata to be loaded
            if ('FiLMedUnet' in context and context['FiLMedUnet']['applied']) or \
                    ('HeMISUnet' in context and context['HeMISUnet']['applied']):
                # Load meta data before prediction
                metadata = imed_training.get_metadata(batch["input_metadata"],
                                                      model_params)
                preds = model(img, metadata)
            else:
                preds = model(img)
        # Otherwise, run ONNX inference (PyTorch can't load .onnx files)
        else:
            logger.debug(f"Likely ONNX model detected at: {path_model_file}")
            logger.debug(f"Conduct ONNX model inference... ")
            preds = onnx_inference(path_model_file, img)

        logger.debug("Sending predictions to CPU")
        # Move prediction to CPU
        preds = preds.cpu()

    return preds
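A hedged example of calling `get_preds` on one batch; `test_loader`, `context`, and `model_params` are hypothetical objects standing in for values normally built with the ivadomed loaders and configuration files, and the model path is a placeholder.

# Hypothetical call site: run one batch through a saved model.
batch = next(iter(test_loader))
preds = get_preds(context=context,
                  path_model_file="path/to/model/best_model.pt",
                  model_params=model_params,
                  gpu_id=0,
                  batch=batch)
print(preds.shape)  # predictions are already moved to the CPU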
Example #3
    def __call__(self, sample):
        input_data, gt_data = sample['input'], sample['gt']

        if self.filter_empty_mask:
            if not np.any(gt_data):
                return False

        if self.filter_empty_input:
            # Filter set of images if one of them is empty or filled with constant value (i.e. std == 0)
            if np.any([img.std() == 0 for img in input_data]):
                return False

        if self.filter_classification:
            if not np.all([int(
                    self.classifier(imed_utils.cuda(torch.from_numpy(img.copy()).unsqueeze(0).unsqueeze(0), self.cuda_available)))
                           for img in input_data]):
                return False

        return True
Example #4
def test_HeMIS(p=0.0001):
    print('[INFO]: Starting test ... \n')
    training_transform_dict = {
        "Resample":
            {
                "wspace": 0.75,
                "hspace": 0.75
            },
        "CenterCrop":
            {
                "size": [48, 48]
            },
        "NumpyToTensor": {}
    }

    transform_lst, _ = imed_transforms.prepare_transforms(training_transform_dict)

    roi_params = {"suffix": "_seg-manual", "slice_filter_roi": None}

    train_lst = ['sub-unf01']
    contrasts = ['T1w', 'T2w', 'T2star']

    print('[INFO]: Creating dataset ...\n')
    model_params = {
            "name": "HeMISUnet",
            "dropout_rate": 0.3,
            "bn_momentum": 0.9,
            "depth": 2,
            "in_channel": 1,
            "out_channel": 1,
            "missing_probability": 0.00001,
            "missing_probability_growth": 0.9,
            "contrasts": ["T1w", "T2w"],
            "ram": False,
            "path_hdf5": 'testing_data/mytestfile.hdf5',
            "csv_path": 'testing_data/hdf5.csv',
            "target_lst": ["T2w"],
            "roi_lst": ["T2w"]
        }
    contrast_params = {
        "contrast_lst": ['T1w', 'T2w', 'T2star'],
        "balance": {}
    }
    dataset = imed_adaptative.HDF5Dataset(root_dir=PATH_BIDS,
                                          subject_lst=train_lst,
                                          model_params=model_params,
                                          contrast_params=contrast_params,
                                          target_suffix=["_lesion-manual"],
                                          slice_axis=2,
                                          transform=transform_lst,
                                          metadata_choice=False,
                                          dim=2,
                                          slice_filter_fn=imed_loader_utils.SliceFilter(filter_empty_input=True,
                                                                                 filter_empty_mask=True),
                                          roi_params=roi_params)

    dataset.load_into_ram(['T1w', 'T2w', 'T2star'])
    print("[INFO]: Dataset RAM status:")
    print(dataset.status)
    print("[INFO]: In memory Dataframe:")
    print(dataset.dataframe)

    # TODO
    # ds_train.filter_roi(nb_nonzero_thr=10)

    train_loader = DataLoader(dataset, batch_size=BATCH_SIZE,
                              shuffle=True, pin_memory=True,
                              collate_fn=imed_loader_utils.imed_collate,
                              num_workers=1)

    model = models.HeMISUnet(contrasts=contrasts,
                             depth=3,
                             drop_rate=DROPOUT,
                             bn_momentum=BN)

    print(model)
    cuda_available = torch.cuda.is_available()

    if cuda_available:
        torch.cuda.set_device(GPU_NUMBER)
        print("Using GPU number {}".format(GPU_NUMBER))
        model.cuda()

    # Initializing optimizer and scheduler
    step_scheduler_batch = False
    optimizer = optim.Adam(model.parameters(), lr=INIT_LR)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, N_EPOCHS)

    load_lst, reload_lst, pred_lst, opt_lst, schedul_lst, init_lst, gen_lst = [], [], [], [], [], [], []

    for epoch in tqdm(range(1, N_EPOCHS + 1), desc="Training"):
        start_time = time.time()

        start_init = time.time()
        lr = scheduler.get_last_lr()[0]
        model.train()

        tot_init = time.time() - start_init
        init_lst.append(tot_init)

        num_steps = 0
        start_gen = 0
        for i, batch in enumerate(train_loader):
            if i > 0:
                tot_gen = time.time() - start_gen
                gen_lst.append(tot_gen)

            start_load = time.time()
            input_samples, gt_samples = imed_utils.unstack_tensors(batch["input"]), batch["gt"]

            print(batch["input_metadata"][0][0]["missing_mod"])
            missing_mod = imed_training.get_metadata(batch["input_metadata"], model_params)

            print("Number of missing contrasts = {}."
                  .format(len(input_samples) * len(input_samples[0]) - missing_mod.sum()))
            print("len input = {}".format(len(input_samples)))
            print("Batch = {}, {}".format(input_samples[0].shape, gt_samples[0].shape))

            if cuda_available:
                var_input = imed_utils.cuda(input_samples)
                var_gt = imed_utils.cuda(gt_samples, non_blocking=True)
            else:
                var_input = input_samples
                var_gt = gt_samples

            tot_load = time.time() - start_load
            load_lst.append(tot_load)

            start_pred = time.time()
            preds = model(var_input, missing_mod)
            tot_pred = time.time() - start_pred
            pred_lst.append(tot_pred)

            start_opt = time.time()
            loss = - losses.DiceLoss()(preds, var_gt)

            optimizer.zero_grad()
            loss.backward()

            optimizer.step()
            if step_scheduler_batch:
                scheduler.step()

            num_steps += 1
            tot_opt = time.time() - start_opt
            opt_lst.append(tot_opt)

            start_gen = time.time()

        start_schedul = time.time()
        if not step_scheduler_batch:
            scheduler.step()
        tot_schedul = time.time() - start_schedul
        schedul_lst.append(tot_schedul)

        start_reload = time.time()
        print("[INFO]: Updating Dataset")
        p = p ** (2 / 3)
        dataset.update(p=p)
        print("[INFO]: Reloading dataset")
        train_loader = DataLoader(dataset, batch_size=BATCH_SIZE,
                                  shuffle=True, pin_memory=True,
                                  collate_fn=imed_loader_utils.imed_collate,
                                  num_workers=1)
        tot_reload = time.time() - start_reload
        reload_lst.append(tot_reload)

        end_time = time.time()
        total_time = end_time - start_time
        tqdm.write("Epoch {} took {:.2f} seconds.".format(epoch, total_time))

    print('Mean SD init {} -- {}'.format(np.mean(init_lst), np.std(init_lst)))
    print('Mean SD load {} -- {}'.format(np.mean(load_lst), np.std(load_lst)))
    print('Mean SD reload {} -- {}'.format(np.mean(reload_lst), np.std(reload_lst)))
    print('Mean SD pred {} -- {}'.format(np.mean(pred_lst), np.std(pred_lst)))
    print('Mean SD opt {} -- {}'.format(np.mean(opt_lst), np.std(opt_lst)))
    print('Mean SD gen {} -- {}'.format(np.mean(gen_lst), np.std(gen_lst)))
    print('Mean SD scheduler {} -- {}'.format(np.mean(schedul_lst), np.std(schedul_lst)))
Example #5
def train(model_params,
          dataset_train,
          dataset_val,
          training_params,
          log_directory,
          device,
          cuda_available=True,
          metric_fns=None,
          n_gif=0,
          resume_training=False,
          debugging=False):
    """Main command to train the network.

    Args:
        model_params (dict): Model's parameters.
        dataset_train (imed_loader): Training dataset.
        dataset_val (imed_loader): Validation dataset.
        training_params (dict): Training parameters.
        log_directory (str): Folder where log files, best and final models are saved.
        device (str): Indicates the CPU or GPU ID.
        cuda_available (bool): If True, CUDA is available.
        metric_fns (list): List of metrics, see :mod:`ivadomed.metrics`.
        n_gif (int): Generates a GIF during training if larger than zero, one frame per epoch for a given slice. The
            parameter indicates the number of 2D slices used to generate GIFs, one GIF per slice. A GIF shows
            predictions of a given slice from the validation sub-dataset. They are saved within the log directory.
        resume_training (bool): Load a saved model ("checkpoint.pth.tar" in the log_directory) to resume
                                training. This training state is saved every time a new best model is saved in the log
                                directory.
        debugging (bool): If True, extended verbosity and intermediate outputs.

    Returns:
        float, float, float, float: best_training_dice, best_training_loss, best_validation_dice,
            best_validation_loss.
    """
    # Write the metrics, images, etc to TensorBoard format
    writer = SummaryWriter(log_dir=log_directory)

    # BALANCE SAMPLES AND PYTORCH LOADER
    conditions = all([
        training_params["balance_samples"]["applied"],
        model_params["name"] != "HeMIS"
    ])
    sampler_train, shuffle_train = get_sampler(
        dataset_train, conditions, training_params['balance_samples']['type'])

    train_loader = DataLoader(dataset_train,
                              batch_size=training_params["batch_size"],
                              shuffle=shuffle_train,
                              pin_memory=True,
                              sampler=sampler_train,
                              collate_fn=imed_loader_utils.imed_collate,
                              num_workers=0)

    gif_dict = {"image_path": [], "slice_id": [], "gif": []}
    if dataset_val:
        sampler_val, shuffle_val = get_sampler(
            dataset_val, conditions,
            training_params['balance_samples']['type'])

        val_loader = DataLoader(dataset_val,
                                batch_size=training_params["batch_size"],
                                shuffle=shuffle_val,
                                pin_memory=True,
                                sampler=sampler_val,
                                collate_fn=imed_loader_utils.imed_collate,
                                num_workers=0)

        # Init GIF
        if n_gif > 0:
            indexes_gif = random.sample(range(len(dataset_val)), n_gif)
        for i_gif in range(n_gif):
            random_metadata = dict(
                dataset_val[indexes_gif[i_gif]]["input_metadata"][0])
            gif_dict["image_path"].append(random_metadata['input_filenames'])
            gif_dict["slice_id"].append(random_metadata['slice_index'])
            gif_obj = imed_visualize.AnimatedGif(
                size=dataset_val[indexes_gif[i_gif]]["input"].numpy()[0].shape)
            gif_dict["gif"].append(copy.copy(gif_obj))

    # GET MODEL
    if training_params["transfer_learning"]["retrain_model"]:
        print("\nLoading pretrained model's weights: {}.")
        print("\tFreezing the {}% first layers.".format(
            100 -
            training_params["transfer_learning"]['retrain_fraction'] * 100.))
        old_model_path = training_params["transfer_learning"]["retrain_model"]
        fraction = training_params["transfer_learning"]['retrain_fraction']
        if 'reset' in training_params["transfer_learning"]:
            reset = training_params["transfer_learning"]['reset']
        else:
            reset = True
        # Freeze first layers and reset last layers
        model = imed_models.set_model_for_retrain(old_model_path,
                                                  retrain_fraction=fraction,
                                                  map_location=device,
                                                  reset=reset)
    else:
        print("\nInitialising model's weights from scratch.")
        model_class = getattr(imed_models, model_params["name"])
        model = model_class(**model_params)
    if cuda_available:
        model.cuda()

    num_epochs = training_params["training_time"]["num_epochs"]

    # OPTIMIZER
    initial_lr = training_params["scheduler"]["initial_lr"]
    # Filter out the parameters that are going to be fine-tuned
    params_to_opt = filter(lambda p: p.requires_grad, model.parameters())
    # Using Adam
    optimizer = optim.Adam(params_to_opt, lr=initial_lr)
    scheduler, step_scheduler_batch = get_scheduler(
        copy.copy(training_params["scheduler"]["lr_scheduler"]), optimizer,
        num_epochs)
    print("\nScheduler parameters: {}".format(
        training_params["scheduler"]["lr_scheduler"]))

    # Create dict containing gammas and betas after each FiLM layer.
    if 'film_layers' in model_params and any(model_params['film_layers']):
        gammas_dict = {i: [] for i in range(1, 2 * model_params["depth"] + 3)}
        betas_dict = {i: [] for i in range(1, 2 * model_params["depth"] + 3)}
        contrast_list = []

    # Resume
    start_epoch = 1
    resume_path = os.path.join(log_directory, "checkpoint.pth.tar")
    if resume_training:
        model, optimizer, gif_dict, start_epoch, val_loss_total_avg, scheduler, patience_count = load_checkpoint(
            model=model,
            optimizer=optimizer,
            gif_dict=gif_dict,
            scheduler=scheduler,
            fname=resume_path)
        # Individually transfer the optimizer parts
        # TODO: check if following lines are needed
        for state in optimizer.state.values():
            for k, v in state.items():
                if torch.is_tensor(v):
                    state[k] = v.to(device)

    # LOSS
    print("\nSelected Loss: {}".format(training_params["loss"]["name"]))
    print("\twith the parameters: {}".format([
        training_params["loss"][k] for k in training_params["loss"]
        if k != "name"
    ]))
    loss_fct = get_loss_function(copy.copy(training_params["loss"]))
    loss_dice_fct = imed_losses.DiceLoss()  # For comparison when another loss is used

    # INIT TRAINING VARIABLES
    best_training_dice, best_training_loss = float("inf"), float("inf")
    best_validation_loss, best_validation_dice = float("inf"), float("inf")
    patience_count = 0
    begin_time = time.time()

    # EPOCH LOOP
    for epoch in tqdm(range(num_epochs), desc="Training", initial=start_epoch):
        epoch = epoch + start_epoch
        start_time = time.time()

        lr = scheduler.get_last_lr()[0]
        writer.add_scalar('learning_rate', lr, epoch)

        # Training loop -----------------------------------------------------------
        model.train()
        train_loss_total, train_dice_loss_total = 0.0, 0.0
        num_steps = 0
        for i, batch in enumerate(train_loader):
            # GET SAMPLES
            if model_params["name"] == "HeMISUnet":
                input_samples = imed_utils.cuda(
                    imed_utils.unstack_tensors(batch["input"]), cuda_available)
            else:
                input_samples = imed_utils.cuda(batch["input"], cuda_available)
            gt_samples = imed_utils.cuda(batch["gt"],
                                         cuda_available,
                                         non_blocking=True)

            # MIXUP
            if training_params["mixup_alpha"]:
                input_samples, gt_samples = imed_mixup.mixup(
                    input_samples, gt_samples, training_params["mixup_alpha"],
                    debugging and epoch == 1, log_directory)

            # RUN MODEL
            if model_params["name"] == "HeMISUnet" or \
                    ('film_layers' in model_params and any(model_params['film_layers'])):
                metadata = get_metadata(batch["input_metadata"], model_params)
                preds = model(input_samples, metadata)
            else:
                preds = model(input_samples)

            # LOSS
            loss = loss_fct(preds, gt_samples)
            train_loss_total += loss.item()
            train_dice_loss_total += loss_dice_fct(preds, gt_samples).item()

            # UPDATE OPTIMIZER
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if step_scheduler_batch:
                scheduler.step()
            num_steps += 1

            if i == 0 and debugging:
                imed_visualize.save_tensorboard_img(
                    writer,
                    epoch,
                    "Train",
                    input_samples,
                    gt_samples,
                    preds,
                    is_three_dim=not model_params["is_2d"])

        if not step_scheduler_batch:
            scheduler.step()

        # TRAINING LOSS
        train_loss_total_avg = train_loss_total / num_steps
        msg = "Epoch {} training loss: {:.4f}.".format(epoch,
                                                       train_loss_total_avg)
        train_dice_loss_total_avg = train_dice_loss_total / num_steps
        if training_params["loss"]["name"] != "DiceLoss":
            msg += "\tDice training loss: {:.4f}.".format(
                train_dice_loss_total_avg)
        tqdm.write(msg)

        # CURRICULUM LEARNING
        if model_params["name"] == "HeMISUnet":
            # Increase the probability of a missing modality
            model_params["missing_probability"] **= model_params[
                "missing_probability_growth"]
            dataset_train.update(p=model_params["missing_probability"])

        # Validation loop -----------------------------------------------------
        model.eval()
        val_loss_total, val_dice_loss_total = 0.0, 0.0
        num_steps = 0
        metric_mgr = imed_metrics.MetricManager(metric_fns)
        if dataset_val:
            for i, batch in enumerate(val_loader):
                with torch.no_grad():
                    # GET SAMPLES
                    if model_params["name"] == "HeMISUnet":
                        input_samples = imed_utils.cuda(
                            imed_utils.unstack_tensors(batch["input"]),
                            cuda_available)
                    else:
                        input_samples = imed_utils.cuda(
                            batch["input"], cuda_available)
                    gt_samples = imed_utils.cuda(batch["gt"],
                                                 cuda_available,
                                                 non_blocking=True)

                    # RUN MODEL
                    if model_params["name"] == "HeMISUnet" or \
                            ('film_layers' in model_params and any(model_params['film_layers'])):
                        metadata = get_metadata(batch["input_metadata"],
                                                model_params)
                        preds = model(input_samples, metadata)
                    else:
                        preds = model(input_samples)

                    # LOSS
                    loss = loss_fct(preds, gt_samples)
                    val_loss_total += loss.item()
                    val_dice_loss_total += loss_dice_fct(preds,
                                                         gt_samples).item()

                    # Add frame to GIF
                    for i_ in range(len(input_samples)):
                        im, pr, met = input_samples[i_].cpu().numpy()[0], preds[i_].cpu().numpy()[0], \
                                      batch["input_metadata"][i_][0]
                        for i_gif in range(n_gif):
                            if gif_dict["image_path"][i_gif] == met.__getitem__('input_filenames') and \
                                    gif_dict["slice_id"][i_gif] == met.__getitem__('slice_index'):
                                overlap = imed_visualize.overlap_im_seg(im, pr)
                                gif_dict["gif"][i_gif].add(overlap,
                                                           label=str(epoch))

                num_steps += 1

                # METRICS COMPUTATION
                gt_npy = gt_samples.cpu().numpy()
                preds_npy = preds.data.cpu().numpy()
                metric_mgr(preds_npy, gt_npy)

                if i == 0 and debugging:
                    imed_visualize.save_tensorboard_img(
                        writer,
                        epoch,
                        "Validation",
                        input_samples,
                        gt_samples,
                        preds,
                        is_three_dim=not model_params['is_2d'])

                if 'film_layers' in model_params and any(model_params['film_layers']) and debugging and \
                        epoch == num_epochs and i < int(len(dataset_val) / training_params["batch_size"]) + 1:
                    # Store the values of gammas and betas after the last epoch for each batch
                    gammas_dict, betas_dict, contrast_list = store_film_params(
                        gammas_dict, betas_dict, contrast_list,
                        batch['input_metadata'], model,
                        model_params["film_layers"], model_params["depth"])

            # METRICS COMPUTATION FOR CURRENT EPOCH
            val_loss_total_avg_old = val_loss_total_avg if epoch > 1 else None
            metrics_dict = metric_mgr.get_results()
            metric_mgr.reset()
            writer.add_scalars('Validation/Metrics', metrics_dict, epoch)
            val_loss_total_avg = val_loss_total / num_steps
            writer.add_scalars(
                'losses', {
                    'train_loss': train_loss_total_avg,
                    'val_loss': val_loss_total_avg,
                }, epoch)
            msg = "Epoch {} validation loss: {:.4f}.".format(
                epoch, val_loss_total_avg)
            val_dice_loss_total_avg = val_dice_loss_total / num_steps
            if training_params["loss"]["name"] != "DiceLoss":
                msg += "\tDice validation loss: {:.4f}.".format(
                    val_dice_loss_total_avg)
            tqdm.write(msg)
            end_time = time.time()
            total_time = end_time - start_time
            tqdm.write("Epoch {} took {:.2f} seconds.".format(
                epoch, total_time))

            # UPDATE BEST RESULTS
            if val_loss_total_avg < best_validation_loss:
                # Save checkpoint
                state = {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'gif_dict': gif_dict,
                    'scheduler': scheduler,
                    'patience_count': patience_count,
                    'validation_loss': val_loss_total_avg
                }
                torch.save(state, resume_path)

                # Save best model file
                model_path = os.path.join(log_directory, "best_model.pt")
                torch.save(model, model_path)

                # Update best scores
                best_validation_loss, best_training_loss = val_loss_total_avg, train_loss_total_avg
                best_validation_dice, best_training_dice = val_dice_loss_total_avg, train_dice_loss_total_avg

            # EARLY STOPPING
            if epoch > 1:
                val_diff = (val_loss_total_avg_old -
                            val_loss_total_avg) * 100 / abs(val_loss_total_avg)
                if val_diff < training_params["training_time"][
                        "early_stopping_epsilon"]:
                    patience_count += 1
                if patience_count >= training_params["training_time"][
                        "early_stopping_patience"]:
                    print(
                        "Stopping training due to {} epochs without improvements"
                        .format(patience_count))
                    break

    # Save final model
    final_model_path = os.path.join(log_directory, "final_model.pt")
    torch.save(model, final_model_path)
    if 'film_layers' in model_params and any(
            model_params['film_layers']) and debugging:
        save_film_params(gammas_dict, betas_dict, contrast_list,
                         model_params["depth"], log_directory)

    # Save best model in log directory
    if os.path.isfile(resume_path):
        state = torch.load(resume_path)
        model_path = os.path.join(log_directory, "best_model.pt")
        model.load_state_dict(state['state_dict'])
        torch.save(model, model_path)
        # Save best model as ONNX in the model directory
        try:
            # Convert best model to ONNX and save it in model directory
            best_model_path = os.path.join(
                log_directory, model_params["folder_name"],
                model_params["folder_name"] + ".onnx")
            imed_utils.save_onnx_model(model, input_samples, best_model_path)
        except Exception:
            # Save best model in model directory
            best_model_path = os.path.join(log_directory,
                                           model_params["folder_name"],
                                           model_params["folder_name"] + ".pt")
            torch.save(model, best_model_path)
            logger.warning(
                "Failed to save the model as '.onnx', saved it as '.pt': {}".
                format(best_model_path))

    # Save GIFs
    gif_folder = os.path.join(log_directory, "gifs")
    if n_gif > 0 and not os.path.isdir(gif_folder):
        os.makedirs(gif_folder)
    for i_gif in range(n_gif):
        fname_out = gif_dict["image_path"][i_gif].split('/')[-3] + "__"
        fname_out += gif_dict["image_path"][i_gif].split('/')[-1].split(
            ".nii.gz")[0].split(gif_dict["image_path"][i_gif].split('/')[-3] +
                                "_")[1] + "__"
        fname_out += str(gif_dict["slice_id"][i_gif]) + ".gif"
        path_gif_out = os.path.join(gif_folder, fname_out)
        gif_dict["gif"][i_gif].save(path_gif_out)

    writer.close()
    final_time = time.time()
    duration_time = final_time - begin_time
    print('begin ' + time.strftime('%H:%M:%S', time.localtime(begin_time)) +
          "| End " + time.strftime('%H:%M:%S', time.localtime(final_time)) +
          "| duration " + str(datetime.timedelta(seconds=duration_time)))

    return best_training_dice, best_training_loss, best_validation_dice, best_validation_loss
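A sketch of a plausible call to `train`; the dataset and parameter objects here are assumptions standing in for values normally produced by the ivadomed loader and configuration manager, and the metric function is assumed to exist in `ivadomed.metrics` as the docstring suggests.

# Hypothetical invocation (all arguments are placeholders).
best_train_dice, best_train_loss, best_val_dice, best_val_loss = train(
    model_params=model_params,
    dataset_train=ds_train,
    dataset_val=ds_val,
    training_params=training_params,
    log_directory="path/to/log_dir",
    device=device,
    cuda_available=cuda_available,
    metric_fns=[imed_metrics.dice_score],
    n_gif=0,
    resume_training=False,
    debugging=False)
print("Best validation loss: {:.4f}".format(best_val_loss))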
Example #6
def segment_volume(folder_model, fname_images, gpu_number=0, options=None):
    """Segment an image.
    Segment an image (`fname_image`) using a pre-trained model (`folder_model`). If provided, a region of interest
    (`fname_roi`) is used to crop the image prior to segment it.
    Args:
        folder_model (str): foldername which contains
            (1) the model ('folder_model/folder_model.pt') to use
            (2) its configuration file ('folder_model/folder_model.json') used for the training,
            see https://github.com/neuropoly/ivadomed/wiki/configuration-file
        fname_images (list): List of image filenames (e.g., .nii.gz) to segment. Multichannel models require multiple
            images to segment, i.e., len(fname_images) > 1.
        gpu_number (int): Number representing gpu number if available.
        options (dict): Contains postprocessing steps and the prior filename (fname_prior), an image filename
            (e.g., .nii.gz) containing processing information (i.e., spinal cord segmentation, spine location, or MS
            lesion classification), e.g., a spinal cord centerline, used to crop the image prior to segmentation
            if provided. Segmentation is not performed on slices that are empty in this image.
    Returns:
        list: List of nibabel objects containing the soft segmentation(s), one per prediction class.
        list: List of target suffix associated with each prediction in `pred_list`

    """
    # Define device
    cuda_available = torch.cuda.is_available()
    device = torch.device("cpu") if not cuda_available else torch.device(
        "cuda:" + str(gpu_number))

    # Check if model folder exists and get filenames
    fname_model, fname_model_metadata = imed_models.get_model_filenames(
        folder_model)

    # Load model training config
    context = imed_config_manager.ConfigurationManager(
        fname_model_metadata).get_config()

    postpro_list = [
        'binarize_prediction', 'keep_largest', 'fill_holes', 'remove_small'
    ]
    if options is not None and any(pp in options for pp in postpro_list):
        postpro = {}
        if 'binarize_prediction' in options and options['binarize_prediction']:
            postpro['binarize_prediction'] = {
                "thr": options['binarize_prediction']
            }
        if 'keep_largest' in options and options['keep_largest'] is not None:
            if options['keep_largest']:
                postpro['keep_largest'] = {}
            # Remove key in context if value set to 0
            elif 'keep_largest' in context['postprocessing']:
                del context['postprocessing']['keep_largest']
        if 'fill_holes' in options and options['fill_holes'] is not None:
            if options['fill_holes']:
                postpro['fill_holes'] = {}
            # Remove key in context if value set to 0
            elif 'fill_holes' in context['postprocessing']:
                del context['postprocessing']['fill_holes']
        if 'remove_small' in options and options['remove_small'] and \
                ('mm' in options['remove_small'][-1] or 'vox' in options['remove_small'][-1]):
            unit = 'mm3' if 'mm3' in options['remove_small'][-1] else 'vox'
            thr = [int(t.replace(unit, "")) for t in options['remove_small']]
            postpro['remove_small'] = {"unit": unit, "thr": thr}

        context['postprocessing'].update(postpro)

    # LOADER
    loader_params = context["loader_parameters"]
    slice_axis = imed_utils.AXIS_DCT[loader_params['slice_axis']]
    metadata = {}
    fname_roi = None
    fname_prior = options['fname_prior'] if (options is not None) and (
        'fname_prior' in options) else None
    if fname_prior is not None:
        if 'roi_params' in loader_params and loader_params['roi_params'][
                'suffix'] is not None:
            fname_roi = fname_prior
        # TRANSFORMATIONS
        # If ROI is not provided then force center cropping
        if fname_roi is None and 'ROICrop' in context["transformation"].keys():
            print(
                "\nWARNING: fname_roi has not been specified; cropping around the center of the image is "
                "performed instead of cropping around a region of interest.")

            context["transformation"] = dict(
                (key, value) if key != 'ROICrop' else ('CenterCrop', value)
                for (key, value) in context["transformation"].items())

        if 'object_detection_params' in context and \
                context['object_detection_params']['object_detection_path'] is not None:
            imed_obj_detect.bounding_box_prior(
                fname_prior, metadata, slice_axis,
                context['object_detection_params']['safety_factor'])
            metadata = [metadata] * len(fname_images)

    # Compose transforms
    _, _, transform_test_params = imed_transforms.get_subdatasets_transforms(
        context["transformation"])

    transform_lst, undo_transforms = imed_transforms.prepare_transforms(
        transform_test_params)

    # Force filter_empty_mask to False if fname_roi = None
    if fname_roi is None and 'filter_empty_mask' in loader_params[
            "slice_filter_params"]:
        print(
            "\nWARNING: fname_roi has not been specified; the entire volume will be processed."
        )
        loader_params["slice_filter_params"]["filter_empty_mask"] = False

    filename_pairs = [(fname_images, None, fname_roi,
                       metadata if isinstance(metadata, list) else [metadata])]

    kernel_3D = bool('Modified3DUNet' in context and context['Modified3DUNet']['applied']) or \
                not context['default_model']['is_2d']
    if kernel_3D:
        ds = imed_loader.MRI3DSubVolumeSegmentationDataset(
            filename_pairs,
            transform=transform_lst,
            length=context["Modified3DUNet"]["length_3D"],
            stride=context["Modified3DUNet"]["stride_3D"])
    else:
        ds = imed_loader.MRI2DSegmentationDataset(
            filename_pairs,
            slice_axis=slice_axis,
            cache=True,
            transform=transform_lst,
            slice_filter_fn=imed_loader_utils.SliceFilter(
                **loader_params["slice_filter_params"]))
        ds.load_filenames()

    if kernel_3D:
        print("\nLoaded {} {} volumes of shape {}.".format(
            len(ds), loader_params['slice_axis'],
            context['Modified3DUNet']['length_3D']))
    else:
        print("\nLoaded {} {} slices.".format(len(ds),
                                              loader_params['slice_axis']))

    model_params = {}
    if 'FiLMedUnet' in context and context['FiLMedUnet']['applied']:
        metadata_dict = joblib.load(
            os.path.join(folder_model, 'metadata_dict.joblib'))
        for idx in ds.indexes:
            for i in range(len(idx)):
                idx[i]['input_metadata'][0][context['FiLMedUnet']
                                            ['metadata']] = options['metadata']
                idx[i]['input_metadata'][0]['metadata_dict'] = metadata_dict

        ds = imed_film.normalize_metadata(ds, None, context["debugging"],
                                          context['FiLMedUnet']['metadata'])
        onehotencoder = joblib.load(
            os.path.join(folder_model, 'one_hot_encoder.joblib'))

        model_params.update({
            "name": 'FiLMedUnet',
            "film_onehotencoder": onehotencoder,
            "n_metadata": len([ll for l in onehotencoder.categories_ for ll in l])
        })

    # Data Loader
    data_loader = DataLoader(
        ds,
        batch_size=context["training_parameters"]["batch_size"],
        shuffle=False,
        pin_memory=True,
        collate_fn=imed_loader_utils.imed_collate,
        num_workers=0)

    # MODEL
    if fname_model.endswith('.pt'):
        model = torch.load(fname_model, map_location=device)
        # Inference time
        model.eval()

    # Loop across batches
    preds_list, slice_idx_list = [], []
    last_sample_bool, volume, weight_matrix = False, None, None
    for i_batch, batch in enumerate(data_loader):
        with torch.no_grad():
            img = imed_utils.cuda(batch['input'],
                                  cuda_available=cuda_available)

            if ('FiLMedUnet' in context and context['FiLMedUnet']['applied']) or \
                    ('HeMISUnet' in context and context['HeMISUnet']['applied']):
                metadata = imed_training.get_metadata(batch["input_metadata"],
                                                      model_params)
                preds = model(img, metadata)

            else:
                preds = model(img) if fname_model.endswith(
                    '.pt') else onnx_inference(fname_model, img)

            preds = preds.cpu()

        # Set datatype to gt since prediction should be processed the same way as gt
        for b in batch['input_metadata']:
            for modality in b:
                modality['data_type'] = 'gt'

        # Reconstruct 3D object
        for i_slice in range(len(preds)):
            if "bounding_box" in batch['input_metadata'][i_slice][0]:
                imed_obj_detect.adjust_undo_transforms(
                    undo_transforms.transforms, batch, i_slice)

            batch['gt_metadata'] = [[metadata[0]] * preds.shape[1]
                                    for metadata in batch['input_metadata']]
            if kernel_3D:
                preds_undo, metadata, last_sample_bool, volume, weight_matrix = \
                    volume_reconstruction(batch, preds, undo_transforms, i_slice, volume, weight_matrix)
                preds_list = [np.array(preds_undo)]
            else:
                # undo transformations
                preds_i_undo, metadata_idx = undo_transforms(
                    preds[i_slice],
                    batch["input_metadata"][i_slice],
                    data_type='gt')

                # Add new segmented slice to preds_list
                preds_list.append(np.array(preds_i_undo))
                # Store the slice index of preds_i_undo in the original 3D image
                slice_idx_list.append(
                    int(batch['input_metadata'][i_slice][0]['slice_index']))

            # If last batch and last sample of this batch, then reconstruct 3D object
            if (i_batch == len(data_loader) - 1
                    and i_slice == len(batch['gt']) - 1) or last_sample_bool:
                pred_nib = pred_to_nib(
                    data_lst=preds_list,
                    fname_ref=fname_images[0],
                    fname_out=None,
                    z_lst=slice_idx_list,
                    slice_axis=slice_axis,
                    kernel_dim='3d' if kernel_3D else '2d',
                    debug=False,
                    bin_thr=-1,
                    postprocessing=context['postprocessing'])

                pred_list = split_classes(pred_nib)
                target_list = context['loader_parameters']['target_suffix']

    return pred_list, target_list
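A hedged usage example for `segment_volume`; the model folder and image filename are placeholders, and the `binarize_prediction` option follows the handling shown above. Each returned prediction is a nibabel object paired with a target suffix.

# Hypothetical call (paths are placeholders).
pred_list, target_list = segment_volume(
    folder_model="path/to/my_model",
    fname_images=["sub-01_T2w.nii.gz"],
    gpu_number=0,
    options={"binarize_prediction": 0.5})

# Save one soft segmentation per prediction class.
for pred_nib, suffix in zip(pred_list, target_list):
    pred_nib.to_filename("pred" + suffix + ".nii.gz")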
Example #7
def test_unet_time(train_lst, target_lst, config):
    cuda_available, device = imed_utils.define_device(GPU_NUMBER)

    loader_params = {
        "data_list": train_lst,
        "dataset_type": "training",
        "requires_undo": False,
        "bids_path": PATH_BIDS,
        "target_suffix": target_lst,
        "slice_filter_params": {
            "filter_empty_mask": False,
            "filter_empty_input": True
        },
        "slice_axis": "axial"
    }
    # Update loader_params with config
    loader_params.update(config)
    # Get Training dataset
    ds_train = imed_loader.load_dataset(**loader_params)

    # Loader
    train_loader = DataLoader(ds_train,
                              batch_size=1 if config["model_params"]["name"]
                              == "UNet3D" else BATCH_SIZE,
                              shuffle=True,
                              pin_memory=True,
                              collate_fn=imed_loader_utils.imed_collate,
                              num_workers=1)

    # MODEL
    model_params = loader_params["model_params"]
    model_params.update(MODEL_DEFAULT)
    # Get in_channel from contrast_lst
    if loader_params["multichannel"]:
        model_params["in_channel"] = len(
            loader_params["contrast_params"]["contrast_lst"])
    else:
        model_params["in_channel"] = 1
    # Get out_channel from target_suffix
    model_params["out_channel"] = len(loader_params["target_suffix"])
    model_class = getattr(imed_models, model_params["name"])
    model = model_class(**model_params)

    print("Training {}".format(model_params["name"]))
    if cuda_available:
        model.cuda()

    step_scheduler_batch = False
    # TODO: Add optim in pytest
    optimizer = optim.Adam(model.parameters(), lr=INIT_LR)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, N_EPOCHS)

    # TODO: add to pytest
    loss_fct = imed_losses.DiceLoss()

    load_lst, pred_lst, opt_lst, schedul_lst, init_lst, gen_lst = [], [], [], [], [], []
    for epoch in tqdm(range(1, N_EPOCHS + 1), desc="Training"):
        start_time = time.time()

        start_init = time.time()

        model.train()

        tot_init = time.time() - start_init
        init_lst.append(tot_init)

        num_steps = 0
        start_gen = 0
        for i, batch in enumerate(train_loader):
            if i > 0:
                tot_gen = time.time() - start_gen
                gen_lst.append(tot_gen)

            start_load = time.time()
            input_samples = imed_utils.cuda(batch["input"], cuda_available)
            gt_samples = imed_utils.cuda(batch["gt"],
                                         cuda_available,
                                         non_blocking=True)

            tot_load = time.time() - start_load
            load_lst.append(tot_load)

            start_pred = time.time()
            preds = model(input_samples)
            tot_pred = time.time() - start_pred
            pred_lst.append(tot_pred)

            start_opt = time.time()
            loss = loss_fct(preds, gt_samples)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if step_scheduler_batch:
                scheduler.step()

            num_steps += 1
            tot_opt = time.time() - start_opt
            opt_lst.append(tot_opt)

            start_gen = time.time()

        start_schedul = time.time()
        if not step_scheduler_batch:
            scheduler.step()
        tot_schedul = time.time() - start_schedul
        schedul_lst.append(tot_schedul)

        end_time = time.time()
        total_time = end_time - start_time
        tqdm.write("Epoch {} took {:.2f} seconds.".format(epoch, total_time))

    print('Mean SD init {} -- {}'.format(np.mean(init_lst), np.std(init_lst)))
    print('Mean SD load {} -- {}'.format(np.mean(load_lst), np.std(load_lst)))
    print('Mean SD pred {} -- {}'.format(np.mean(pred_lst), np.std(pred_lst)))
    print('Mean SD opt {} -- {}'.format(np.mean(opt_lst), np.std(opt_lst)))
    print('Mean SD gen {} -- {}'.format(np.mean(gen_lst), np.std(gen_lst)))
    print('Mean SD scheduler {} -- {}'.format(np.mean(schedul_lst),
                                              np.std(schedul_lst)))
Example #8
def run_inference(test_loader,
                  model,
                  model_params,
                  testing_params,
                  ofolder,
                  cuda_available,
                  i_monte_carlo=None,
                  postprocessing=None):
    """Run inference on the test data and save results as nibabel files.

    Args:
        test_loader (torch DataLoader): Loader over the testing dataset.
        model (nn.Module): Trained model used for inference.
        model_params (dict): Model's parameters.
        testing_params (dict): Testing parameters.
        ofolder (str): Folder where predictions are saved.
        cuda_available (bool): If True, CUDA is available.
        i_monte_carlo (int): i_th Monte Carlo iteration.
        postprocessing (dict): Indicates postprocessing steps.

    Returns:
        ndarray, ndarray: Prediction, Ground-truth of shape n_sample, n_label, h, w, d.
    """
    # INIT STORAGE VARIABLES
    preds_npy_list, gt_npy_list, filenames = [], [], []
    pred_tmp_lst, z_tmp_lst, fname_tmp = [], [], ''
    volume = None
    weight_matrix = None

    # Create dict containing gammas and betas after each FiLM layer.
    if 'film_layers' in model_params and any(model_params['film_layers']):
        # 2 * model_params["depth"] + 2 is the number of FiLM layers. 1 is added since the range starts at one.
        gammas_dict = {i: [] for i in range(1, 2 * model_params["depth"] + 3)}
        betas_dict = {i: [] for i in range(1, 2 * model_params["depth"] + 3)}
        metadata_values_lst = []

    for i, batch in enumerate(
            tqdm(test_loader,
                 desc="Inference - Iteration " + str(i_monte_carlo))):
        with torch.no_grad():
            # GET SAMPLES
            # input_samples: list of batch_size tensors, whose size is n_channels X height X width X depth
            # gt_samples: idem with n_labels
            # batch['*_metadata']: list of batch_size lists, whose size is n_channels or n_labels
            if model_params["name"] == "HeMISUnet":
                input_samples = imed_utils.cuda(
                    imed_utils.unstack_tensors(batch["input"]), cuda_available)
            else:
                input_samples = imed_utils.cuda(batch["input"], cuda_available)
            gt_samples = imed_utils.cuda(batch["gt"],
                                         cuda_available,
                                         non_blocking=True)

            # EPISTEMIC UNCERTAINTY
            if testing_params['uncertainty']['applied'] and testing_params[
                    'uncertainty']['epistemic']:
                for m in model.modules():
                    if m.__class__.__name__.startswith('Dropout'):
                        m.train()

            # RUN MODEL
            if model_params["name"] == "HeMISUnet" or \
                    ('film_layers' in model_params and any(model_params['film_layers'])):
                metadata = get_metadata(batch["input_metadata"], model_params)
                preds = model(input_samples, metadata)
            else:
                preds = model(input_samples)

        if model_params["name"] == "HeMISUnet":
            # Reconstruct image with only one modality
            input_samples = batch['input'][0]

        if model_params["name"] == "Modified3DUNet" and model_params[
                "attention"] and ofolder:
            imed_visualize.save_feature_map(
                batch,
                "attentionblock2",
                os.path.dirname(ofolder),
                model,
                input_samples,
                slice_axis=test_loader.dataset.slice_axis)

        if 'film_layers' in model_params and any(model_params['film_layers']):
            # Store the values of gammas and betas after the last epoch for each batch
            gammas_dict, betas_dict, metadata_values_lst = store_film_params(
                gammas_dict, betas_dict, metadata_values_lst,
                batch['input_metadata'], model, model_params["film_layers"],
                model_params["depth"], model_params['metadata'])

        # PREDS TO CPU
        preds_cpu = preds.cpu()

        task = imed_utils.get_task(model_params["name"])
        if task == "classification":
            gt_npy_list.append(gt_samples.cpu().numpy())
            preds_npy_list.append(preds_cpu.data.numpy())

        # RECONSTRUCT 3D IMAGE
        last_batch_bool = (i == len(test_loader) - 1)

        slice_axis = imed_utils.AXIS_DCT[testing_params['slice_axis']]

        # LOOP ACROSS SAMPLES
        for smp_idx in range(len(preds_cpu)):
            if "bounding_box" in batch['input_metadata'][smp_idx][0]:
                imed_obj_detect.adjust_undo_transforms(
                    testing_params["undo_transforms"].transforms, batch,
                    smp_idx)

            if model_params["is_2d"]:
                last_sample_bool = (last_batch_bool
                                    and smp_idx == len(preds_cpu) - 1)
                # undo transformations
                preds_idx_undo, metadata_idx = testing_params[
                    "undo_transforms"](preds_cpu[smp_idx],
                                       batch['gt_metadata'][smp_idx],
                                       data_type='gt')
                # preds_idx_undo is a list n_label arrays
                preds_idx_arr = np.array(preds_idx_undo)

                # TODO: gt_filenames should not be a list
                fname_ref = list(filter(None,
                                        metadata_idx[0]['gt_filenames']))[0]

                # NEW COMPLETE VOLUME
                if pred_tmp_lst and (fname_ref != fname_tmp or last_sample_bool
                                     ) and task != "classification":
                    # save the completely processed file as a nifti file
                    if ofolder:
                        fname_pred = os.path.join(ofolder,
                                                  Path(fname_tmp).name)
                        fname_pred = fname_pred.rsplit("_",
                                                       1)[0] + '_pred.nii.gz'
                        # If Uncertainty running, then we save each simulation result
                        if testing_params['uncertainty']['applied']:
                            fname_pred = fname_pred.split(
                                '.nii.gz')[0] + '_' + str(i_monte_carlo).zfill(
                                    2) + '.nii.gz'
                            postprocessing = None
                    else:
                        fname_pred = None
                    output_nii = imed_inference.pred_to_nib(
                        data_lst=pred_tmp_lst,
                        z_lst=z_tmp_lst,
                        fname_ref=fname_tmp,
                        fname_out=fname_pred,
                        slice_axis=slice_axis,
                        kernel_dim='2d',
                        bin_thr=-1,
                        postprocessing=postprocessing)
                    output_data = output_nii.get_fdata().transpose(3, 0, 1, 2)
                    preds_npy_list.append(output_data)

                    gt = get_gt(filenames)
                    gt_npy_list.append(gt)

                    output_nii_shape = output_nii.get_fdata().shape
                    if len(output_nii_shape
                           ) == 4 and output_nii_shape[-1] > 1 and ofolder:
                        logger.warning(
                            'No color labels saved due to a temporary issue. For more details see:'
                            'https://github.com/ivadomed/ivadomed/issues/720')
                        # TODO: put back the code below. See #720
                        # imed_visualize.save_color_labels(np.stack(pred_tmp_lst, -1),
                        #                              False,
                        #                              fname_tmp,
                        #                              fname_pred.split(".nii.gz")[0] + '_color.nii.gz',
                        #                              imed_utils.AXIS_DCT[testing_params['slice_axis']])

                    # re-init pred_stack_lst
                    pred_tmp_lst, z_tmp_lst = [], []

                # add new sample to pred_tmp_lst, of size n_label X h X w ...
                pred_tmp_lst.append(preds_idx_arr)

                # TODO: slice_index should be stored in gt_metadata as well
                z_tmp_lst.append(
                    int(batch['input_metadata'][smp_idx][0]['slice_index']))
                fname_tmp = fname_ref
                filenames = metadata_idx[0]['gt_filenames']

            else:
                pred_undo, metadata, last_sample_bool, volume, weight_matrix = \
                    imed_inference.volume_reconstruction(batch,
                                                     preds_cpu,
                                                     testing_params['undo_transforms'],
                                                     smp_idx, volume, weight_matrix)
                fname_ref = metadata[0]['gt_filenames'][0]
                # Indicator of last batch
                if last_sample_bool:
                    pred_undo = np.array(pred_undo)
                    if ofolder:
                        fname_pred = os.path.join(ofolder,
                                                  fname_ref.split('/')[-1])
                        fname_pred = fname_pred.split(
                            testing_params['target_suffix']
                            [0])[0] + '_pred.nii.gz'
                        # If uncertainty running, then we save each simulation result
                        if testing_params['uncertainty']['applied']:
                            fname_pred = fname_pred.split(
                                '.nii.gz')[0] + '_' + str(i_monte_carlo).zfill(
                                    2) + '.nii.gz'
                            postprocessing = None
                    else:
                        fname_pred = None
                    # Choose only one modality
                    output_nii = imed_inference.pred_to_nib(
                        data_lst=[pred_undo],
                        z_lst=[],
                        fname_ref=fname_ref,
                        fname_out=fname_pred,
                        slice_axis=slice_axis,
                        kernel_dim='3d',
                        bin_thr=-1,
                        postprocessing=postprocessing)
                    output_data = output_nii.get_fdata().transpose(3, 0, 1, 2)
                    preds_npy_list.append(output_data)

                    gt = get_gt(metadata[0]['gt_filenames'])
                    gt_npy_list.append(gt)
                    # Save merged labels with color

                    if pred_undo.shape[0] > 1 and ofolder:
                        logger.warning(
                            'No color labels saved due to a temporary issue. For more details see:'
                            'https://github.com/ivadomed/ivadomed/issues/720')
                        # TODO: put back the code below. See #720
                        # imed_visualize.save_color_labels(pred_undo,
                        #                              False,
                        #                              batch['input_metadata'][smp_idx][0]['input_filenames'],
                        #                              fname_pred.split(".nii.gz")[0] + '_color.nii.gz',
                        #                              slice_axis)

    if 'film_layers' in model_params and any(model_params['film_layers']):
        save_film_params(gammas_dict, betas_dict, metadata_values_lst,
                         model_params["depth"],
                         ofolder.replace("pred_masks", ""))
    return preds_npy_list, gt_npy_list
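
A minimal sketch of how the returned preds_npy_list / gt_npy_list can be scored with plain NumPy; the dice_score helper below is illustrative, not an ivadomed function:

import numpy as np

def dice_score(pred, gt, eps=1e-8):
    # Soft Dice between two binary arrays of identical shape
    pred, gt = pred.astype(bool), gt.astype(bool)
    intersection = np.logical_and(pred, gt).sum()
    return 2.0 * intersection / (pred.sum() + gt.sum() + eps)

# Names assumed from the function above:
# scores = [dice_score(p > 0.5, g > 0.5) for p, g in zip(preds_npy_list, gt_npy_list)]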
Exemple #9
0
def test_hdf5(download_data_testing_test_files, loader_parameters):
    print('[INFO]: Starting test ... \n')

    bids_df = imed_loader_utils.BidsDataframe(loader_parameters,
                                              __tmp_dir__,
                                              derivatives=True)

    contrast_params = loader_parameters["contrast_params"]
    target_suffix = loader_parameters["target_suffix"]
    roi_params = loader_parameters["roi_params"]

    train_lst = ['sub-unf01']

    training_transform_dict = {
        "Resample": {
            "wspace": 0.75,
            "hspace": 0.75
        },
        "CenterCrop": {
            "size": [48, 48]
        },
        "NumpyToTensor": {}
    }
    transform_lst, _ = imed_transforms.prepare_transforms(
        training_transform_dict)

    bids_to_hdf5 = imed_adaptative.BIDStoHDF5(
        bids_df=bids_df,
        subject_file_lst=train_lst,
        path_hdf5=os.path.join(__data_testing_dir__, 'mytestfile.hdf5'),
        target_suffix=target_suffix,
        roi_params=roi_params,
        contrast_lst=contrast_params["contrast_lst"],
        metadata_choice="contrast",
        transform=transform_lst,
        contrast_balance={},
        slice_axis=2,
        slice_filter_fn=imed_loader_utils.SliceFilter(filter_empty_input=True,
                                                      filter_empty_mask=True))

    # Checking architecture
    def print_attrs(name, obj):
        print("\nName of the object: {}".format(name))
        print("Type: {}".format(type(obj)))
        print("Including the following attributes:")
        for key, val in obj.attrs.items():
            print("    %s: %s" % (key, val))

    print('\n[INFO]: HDF5 architecture:')
    with h5py.File(bids_to_hdf5.path_hdf5, "a") as hdf5_file:
        hdf5_file.visititems(print_attrs)
        print('\n[INFO]: HDF5 file successfully generated.')
        print('[INFO]: Generating dataframe ...\n')

        df = imed_adaptative.Dataframe(hdf5_file=hdf5_file,
                                       contrasts=['T1w', 'T2w', 'T2star'],
                                       path=os.path.join(
                                           __data_testing_dir__, 'hdf5.csv'),
                                       target_suffix=['T1w', 'T2w', 'T2star'],
                                       roi_suffix=['T1w', 'T2w', 'T2star'],
                                       dim=2,
                                       filter_slices=True)

        print(df.df)

        print('\n[INFO]: Dataframe successfully generated. ')
        print('[INFO]: Creating dataset ...\n')

        model_params = {
            "name": "HeMISUnet",
            "dropout_rate": 0.3,
            "bn_momentum": 0.9,
            "depth": 2,
            "in_channel": 1,
            "out_channel": 1,
            "missing_probability": 0.00001,
            "missing_probability_growth": 0.9,
            "contrasts": ["T1w", "T2w"],
            "ram": False,
            "path_hdf5": os.path.join(__data_testing_dir__, 'mytestfile.hdf5'),
            "csv_path": os.path.join(__data_testing_dir__, 'hdf5.csv'),
            "target_lst": ["T2w"],
            "roi_lst": ["T2w"]
        }

        dataset = imed_adaptative.HDF5Dataset(
            bids_df=bids_df,
            subject_file_lst=train_lst,
            target_suffix=target_suffix,
            slice_axis=2,
            model_params=model_params,
            contrast_params=contrast_params,
            transform=transform_lst,
            metadata_choice=False,
            dim=2,
            slice_filter_fn=imed_loader_utils.SliceFilter(
                filter_empty_input=True, filter_empty_mask=True),
            roi_params=roi_params)

        dataset.load_into_ram(['T1w', 'T2w', 'T2star'])
        print("Dataset RAM status:")
        print(dataset.status)
        print("In memory Dataframe:")
        print(dataset.dataframe)
        print('\n[INFO]: Test passed successfully. ')

        print("\n[INFO]: Starting loader test ...")

        device = torch.device(
            "cuda:" + str(GPU_ID) if torch.cuda.is_available() else "cpu")
        cuda_available = torch.cuda.is_available()
        if cuda_available:
            torch.cuda.set_device(device)
            print("Using GPU ID {}".format(device))

        train_loader = DataLoader(dataset,
                                  batch_size=BATCH_SIZE,
                                  shuffle=False,
                                  pin_memory=True,
                                  collate_fn=imed_loader_utils.imed_collate,
                                  num_workers=1)

        for i, batch in enumerate(train_loader):
            input_samples, gt_samples = batch["input"], batch["gt"]
            print("len input = {}".format(len(input_samples)))
            print("Batch = {}, {}".format(input_samples[0].shape,
                                          gt_samples[0].shape))

            if cuda_available:
                var_input = imed_utils.cuda(input_samples)
                var_gt = imed_utils.cuda(gt_samples, non_blocking=True)
            else:
                var_input = input_samples
                var_gt = gt_samples

            break
        print(
            "Congrats, your dataloader works! You can go home now and get a beer."
        )
        return 0
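
The HDF5 architecture check above relies only on standard h5py calls; a self-contained sketch, with the file path as a placeholder:

import h5py

def print_attrs(name, obj):
    # Called once per group/dataset visited by visititems
    print("\nName of the object: {}".format(name))
    print("Type: {}".format(type(obj)))
    for key, val in obj.attrs.items():
        print("    {}: {}".format(key, val))

with h5py.File("mytestfile.hdf5", "r") as hdf5_file:  # read-only is enough for inspection
    hdf5_file.visititems(print_attrs)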
Exemple #10
0
def run_inference(test_loader,
                  model,
                  model_params,
                  testing_params,
                  ofolder,
                  cuda_available,
                  i_monte_carlo=None):
    """Run inference on the test data and save results as nibabel files.

    Args:
        test_loader (torch DataLoader): DataLoader over the testing dataset.
        model (nn.Module): Trained model to run inference with.
        model_params (dict): Model's parameters.
        testing_params (dict): Testing parameters, including the undo transforms and uncertainty settings.
        ofolder (str): Folder where predictions are saved.
        cuda_available (bool): If True, CUDA is available.
        i_monte_carlo (int): i_th Monte Carlo iteration.

    Returns:
        list, list: Lists of prediction and ground-truth arrays, one per reconstructed volume.
    """
    # INIT STORAGE VARIABLES
    preds_npy_list, gt_npy_list = [], []
    pred_tmp_lst, z_tmp_lst, fname_tmp = [], [], ''
    volume = None
    weight_matrix = None

    for i, batch in enumerate(
            tqdm(test_loader,
                 desc="Inference - Iteration " + str(i_monte_carlo))):
        with torch.no_grad():
            # GET SAMPLES
            # input_samples: list of batch_size tensors, whose size is n_channels X height X width X depth
            # gt_samples: idem with n_labels
            # batch['*_metadata']: list of batch_size lists, whose size is n_channels or n_labels
            if model_params["name"] == "HeMISUnet":
                input_samples = imed_utils.cuda(
                    imed_utils.unstack_tensors(batch["input"]), cuda_available)
            else:
                input_samples = imed_utils.cuda(batch["input"], cuda_available)
            gt_samples = imed_utils.cuda(batch["gt"],
                                         cuda_available,
                                         non_blocking=True)

            # EPISTEMIC UNCERTAINTY
            if testing_params['uncertainty']['applied'] and testing_params[
                    'uncertainty']['epistemic']:
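                # Monte Carlo dropout: keep Dropout layers in train mode at inference
                # so that repeated forward passes sample the epistemic uncertainty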
                for m in model.modules():
                    if m.__class__.__name__.startswith('Dropout'):
                        m.train()

            # RUN MODEL
            if model_params["name"] in ["HeMISUnet", "FiLMedUnet"]:
                metadata = get_metadata(batch["input_metadata"], model_params)
                preds = model(input_samples, metadata)
            else:
                preds = model(input_samples)

        if model_params["name"] == "HeMISUnet":
            # Reconstruct image with only one modality
            input_samples = batch['input'][0]

        if model_params["name"] == "UNet3D" and model_params[
                "attention"] and ofolder:
            imed_utils.save_feature_map(
                batch,
                "attentionblock2",
                os.path.dirname(ofolder),
                model,
                input_samples,
                slice_axis=test_loader.dataset.slice_axis)

        # PREDS TO CPU
        preds_cpu = preds.cpu()

        task = imed_utils.get_task(model_params["name"])
        if task == "classification":
            gt_npy_list.append(gt_samples.cpu().numpy())
            preds_npy_list.append(preds_cpu.data.numpy())

        # RECONSTRUCT 3D IMAGE
        last_batch_bool = (i == len(test_loader) - 1)

        slice_axis = imed_utils.AXIS_DCT[testing_params['slice_axis']]

        # LOOP ACROSS SAMPLES
        for smp_idx in range(len(preds_cpu)):
            if "bounding_box" in batch['input_metadata'][smp_idx][0]:
                imed_obj_detect.adjust_undo_transforms(
                    testing_params["undo_transforms"].transforms, batch,
                    smp_idx)

            if not model_params["name"].endswith('3D'):
                last_sample_bool = (last_batch_bool
                                    and smp_idx == len(preds_cpu) - 1)
                # undo transformations
                preds_idx_undo, metadata_idx = testing_params[
                    "undo_transforms"](preds_cpu[smp_idx],
                                       batch['gt_metadata'][smp_idx],
                                       data_type='gt')
                # preds_idx_undo is a list of n_label arrays
                preds_idx_arr = np.array(preds_idx_undo)

                # TODO: gt_filenames should not be a list
                fname_ref = metadata_idx[0]['gt_filenames'][0]

                # NEW COMPLETE VOLUME
                if pred_tmp_lst and (fname_ref != fname_tmp or last_sample_bool
                                     ) and task != "classification":
                    # save the completely processed file as a nifti file
                    if ofolder:
                        fname_pred = os.path.join(ofolder,
                                                  fname_tmp.split('/')[-1])
                        fname_pred = fname_pred.rsplit(
                            testing_params['target_suffix'][0],
                            1)[0] + '_pred.nii.gz'
                        # If uncertainty is enabled, save each simulation result
                        if testing_params['uncertainty']['applied']:
                            fname_pred = fname_pred.split(
                                '.nii.gz')[0] + '_' + str(i_monte_carlo).zfill(
                                    2) + '.nii.gz'
                    else:
                        fname_pred = None
                    output_nii = imed_utils.pred_to_nib(
                        data_lst=pred_tmp_lst,
                        z_lst=z_tmp_lst,
                        fname_ref=fname_tmp,
                        fname_out=fname_pred,
                        slice_axis=slice_axis,
                        kernel_dim='2d',
                        bin_thr=testing_params["binarize_prediction"])
                    # TODO: Adapt to multilabel
                    output_data = output_nii.get_fdata()[:, :, :, 0]
                    preds_npy_list.append(output_data)

                    gt_npy_list.append(nib.load(fname_tmp).get_fdata())

                    output_nii_shape = output_nii.get_fdata().shape
                    if len(output_nii_shape
                           ) == 4 and output_nii_shape[-1] > 1 and ofolder:
                        imed_utils.save_color_labels(
                            np.stack(pred_tmp_lst, -1),
                            testing_params["binarize_prediction"] > 0,
                            fname_tmp,
                            fname_pred.split(".nii.gz")[0] + '_color.nii.gz',
                            imed_utils.AXIS_DCT[testing_params['slice_axis']])

                    # re-init pred_stack_lst
                    pred_tmp_lst, z_tmp_lst = [], []

                # add new sample to pred_tmp_lst, of size n_label X h X w ...
                pred_tmp_lst.append(preds_idx_arr)

                # TODO: slice_index should be stored in gt_metadata as well
                z_tmp_lst.append(
                    int(batch['input_metadata'][smp_idx][0]['slice_index']))
                fname_tmp = fname_ref

            else:
                pred_undo, metadata, last_sample_bool, volume, weight_matrix = \
                    imed_utils.volume_reconstruction(batch,
                                                     preds_cpu,
                                                     testing_params['undo_transforms'],
                                                     smp_idx, volume, weight_matrix)
                fname_ref = metadata[0]['gt_filenames'][0]
                # Last sample of the current volume: reconstruct and save
                if last_sample_bool:
                    pred_undo = np.array(pred_undo)
                    if ofolder:
                        fname_pred = os.path.join(ofolder,
                                                  fname_ref.split('/')[-1])
                        fname_pred = fname_pred.split(
                            testing_params['target_suffix']
                            [0])[0] + '_pred.nii.gz'
                        # If uncertainty is enabled, save each simulation result
                        if testing_params['uncertainty']['applied']:
                            fname_pred = fname_pred.split(
                                '.nii.gz')[0] + '_' + str(i_monte_carlo).zfill(
                                    2) + '.nii.gz'
                    else:
                        fname_pred = None
                    # Choose only one modality
                    output_nii = imed_utils.pred_to_nib(
                        data_lst=[pred_undo],
                        z_lst=[],
                        fname_ref=fname_ref,
                        fname_out=fname_pred,
                        slice_axis=slice_axis,
                        kernel_dim='3d',
                        bin_thr=testing_params["binarize_prediction"])
                    output_data = output_nii.get_fdata().transpose(3, 0, 1, 2)
                    preds_npy_list.append(output_data)

                    gt_lst = []
                    for gt in metadata[0]['gt_filenames']:
                        # For multi-label tasks, a label may be absent from a given image
                        if gt is not None:
                            gt_lst.append(nib.load(gt).get_fdata())
                        else:
                            gt_lst.append(np.zeros(gt_lst[0].shape))

                    gt_npy_list.append(np.array(gt_lst))
                    # Save merged labels with color

                    if pred_undo.shape[0] > 1 and ofolder:
                        imed_utils.save_color_labels(
                            pred_undo,
                            testing_params['binarize_prediction'] > 0,
                            batch['input_metadata'][smp_idx][0]
                            ['input_filenames'],
                            fname_pred.split(".nii.gz")[0] + '_color.nii.gz',
                            slice_axis)

    return preds_npy_list, gt_npy_list
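
A hedged call sketch for run_inference; the loader, model and parameter dicts are assumed to have been prepared as elsewhere in this document:

# preds_npy_list, gt_npy_list = run_inference(test_loader=test_loader,
#                                             model=model,
#                                             model_params=model_params,
#                                             testing_params=testing_params,
#                                             ofolder="path/to/pred_masks",
#                                             cuda_available=cuda_available,
#                                             i_monte_carlo=0)
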
Exemple #11
0
def test_unet_time(download_data_testing_test_files, train_lst, target_lst, config):
    cuda_available, device = imed_utils.define_device(GPU_ID)

    loader_params = {
        "data_list": train_lst,
        "dataset_type": "training",
        "requires_undo": False,
        "path_data": [__data_testing_dir__],
        "target_suffix": target_lst,
        "extensions": [".nii.gz"],
        "slice_filter_params": {"filter_empty_mask": False, "filter_empty_input": True},
        "patch_filter_params": {"filter_empty_mask": False, "filter_empty_input": False},
        "slice_axis": "axial"
    }
    # Update loader_params with config
    loader_params.update(config)
    # Get Training dataset
    bids_df = BidsDataframe(loader_params, __tmp_dir__, derivatives=True)
    ds_train = imed_loader.load_dataset(bids_df, **loader_params)

    # Loader
    train_loader = DataLoader(ds_train,
                              batch_size=1 if config["model_params"]["name"] == "Modified3DUNet" else BATCH_SIZE,
                              shuffle=True, pin_memory=True,
                              collate_fn=imed_loader_utils.imed_collate,
                              num_workers=1)

    # MODEL
    model_params = loader_params["model_params"]
    model_params.update(MODEL_DEFAULT)
    # Get in_channel from contrast_lst
    if loader_params["multichannel"]:
        model_params["in_channel"] = len(loader_params["contrast_params"]["contrast_lst"])
    else:
        model_params["in_channel"] = 1
    # Get out_channel from target_suffix
    model_params["out_channel"] = len(loader_params["target_suffix"])
    model_class = getattr(imed_models, model_params["name"])
    model = model_class(**model_params)

    logger.debug(f"Training {model_params['name']}")
    if cuda_available:
        model.cuda()

    step_scheduler_batch = False
    # TODO: Add optim in pytest
    optimizer = optim.Adam(model.parameters(), lr=INIT_LR)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, N_EPOCHS)

    # TODO: add to pytest
    loss_fct = imed_losses.DiceLoss()

    load_lst, pred_lst, opt_lst, schedule_lst, init_lst, gen_lst = [], [], [], [], [], []
    for epoch in tqdm(range(1, N_EPOCHS + 1), desc="Training"):
        start_time = time.time()

        start_init = time.time()

        model.train()

        tot_init = time.time() - start_init
        init_lst.append(tot_init)

        num_steps = 0
        start_gen = 0
        for i, batch in enumerate(train_loader):
            if i > 0:
                tot_gen = time.time() - start_gen
                gen_lst.append(tot_gen)

            start_load = time.time()
            input_samples = imed_utils.cuda(batch["input"], cuda_available)
            gt_samples = imed_utils.cuda(batch["gt"], cuda_available, non_blocking=True)

            tot_load = time.time() - start_load
            load_lst.append(tot_load)

            start_pred = time.time()

            if 'film_layers' in model_params:
                preds = model(input_samples, [[0, 1]])
            else:
                preds = model(input_samples)
            tot_pred = time.time() - start_pred
            pred_lst.append(tot_pred)

            start_opt = time.time()
            loss = loss_fct(preds, gt_samples)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if step_scheduler_batch:
                scheduler.step()

            num_steps += 1
            tot_opt = time.time() - start_opt
            opt_lst.append(tot_opt)

            start_gen = time.time()

        start_schedule = time.time()
        if not step_scheduler_batch:
            scheduler.step()
        tot_schedule = time.time() - start_schedule
        schedule_lst.append(tot_schedule)

        end_time = time.time()
        total_time = end_time - start_time
        tqdm.write("Epoch {} took {:.2f} seconds.".format(epoch, total_time))

    logger.info(f"Mean SD init {np.mean(init_lst)} -- {np.std(init_lst)}")
    logger.info(f"Mean SD load {np.mean(load_lst)} -- {np.std(load_lst)}")
    logger.info(f"Mean SD pred {np.mean(pred_lst)} -- {np.std(pred_lst)}")
    logger.info(f"Mean SDopt {np.mean(opt_lst)} --  {np.std(opt_lst)}")
    logger.info(f"Mean SD gen {np.mean(gen_lst)} -- {np.std(gen_lst)}")
    logger.info(f"Mean SD scheduler {np.mean(schedule_lst)} -- {np.std(schedule_lst)}")
Exemple #12
0
def train(model_params,
          dataset_train,
          dataset_val,
          training_params,
          path_output,
          device,
          wandb_params=None,
          cuda_available=True,
          metric_fns=None,
          n_gif=0,
          resume_training=False,
          debugging=False):
    """Main command to train the network.

    Args:
        model_params (dict): Model's parameters.
        dataset_train (imed_loader): Training dataset.
        dataset_val (imed_loader): Validation dataset.
        training_params (dict): Training parameters.
        path_output (str): Folder where log files, best and final models are saved.
        device (str): Indicates the CPU or GPU ID.
        wandb_params (dict): Wandb parameters; tracking is enabled when they are present and valid.
        cuda_available (bool): If True, CUDA is available.
        metric_fns (list): List of metrics, see :mod:`ivadomed.metrics`.
        n_gif (int): Generates a GIF during training if larger than zero, one frame per epoch for a given slice. The
            parameter indicates the number of 2D slices used to generate GIFs, one GIF per slice. A GIF shows
            predictions of a given slice from the validation sub-dataset. They are saved within the output path.
        resume_training (bool): Load a saved model ("checkpoint.pth.tar" in the path_output) to resume
            training. This training state is saved every time a new best model is saved in the log
            directory.
        debugging (bool): If True, extended verbosity and intermediate outputs.

    Returns:
        float, float, float, float: best_training_dice, best_training_loss, best_validation_dice,
            best_validation_loss.
    """
    # Write the metrics, images, etc to TensorBoard format
    writer = SummaryWriter(log_dir=path_output)

    # Enable wandb tracking if the required params are found in the config file and the API key is correct
    wandb_tracking = imed_utils.initialize_wandb(wandb_params)

    if wandb_tracking:
        # Collect all hyperparameters into a dictionary
        cfg = {**training_params, **model_params}

        # Get the actual project, group, and run names if they exist, else choose the temporary names as default
        project_name = wandb_params.get(WandbKW.PROJECT_NAME, "temp_project")
        group_name = wandb_params.get(WandbKW.GROUP_NAME, "temp_group")
        run_name = wandb_params.get(WandbKW.RUN_NAME, "temp_run")

        if project_name == "temp_project" or group_name == "temp_group" or run_name == "temp_run":
            logger.info(
                "{PROJECT/GROUP/RUN} name not found, initializing as {'temp_project'/'temp_group'/'temp_run'}"
            )

        # Initialize WandB with metrics and hyperparameters
        wandb.init(project=project_name,
                   group=group_name,
                   name=run_name,
                   config=cfg)

    # BALANCE SAMPLES AND PYTORCH LOADER
    conditions = all([
        training_params[TrainingParamsKW.BALANCE_SAMPLES]
        [BalanceSamplesKW.APPLIED], model_params[ModelParamsKW.NAME] != "HeMIS"
    ])
    sampler_train, shuffle_train = get_sampler(
        dataset_train, conditions, training_params[
            TrainingParamsKW.BALANCE_SAMPLES][BalanceSamplesKW.TYPE])

    train_loader = DataLoader(
        dataset_train,
        batch_size=training_params[TrainingParamsKW.BATCH_SIZE],
        shuffle=shuffle_train,
        pin_memory=True,
        sampler=sampler_train,
        collate_fn=imed_loader_utils.imed_collate,
        num_workers=0)

    gif_dict = {"image_path": [], "slice_id": [], "gif": []}
    if dataset_val:
        sampler_val, shuffle_val = get_sampler(
            dataset_val, conditions, training_params[
                TrainingParamsKW.BALANCE_SAMPLES][BalanceSamplesKW.TYPE])

        val_loader = DataLoader(
            dataset_val,
            batch_size=training_params[TrainingParamsKW.BATCH_SIZE],
            shuffle=shuffle_val,
            pin_memory=True,
            sampler=sampler_val,
            collate_fn=imed_loader_utils.imed_collate,
            num_workers=0)

        # Init GIF
        if n_gif > 0:
            indexes_gif = random.sample(range(len(dataset_val)), n_gif)
        for i_gif in range(n_gif):
            random_metadata = dict(
                dataset_val[indexes_gif[i_gif]][MetadataKW.INPUT_METADATA][0])
            gif_dict["image_path"].append(
                random_metadata[MetadataKW.INPUT_FILENAMES])
            gif_dict["slice_id"].append(
                random_metadata[MetadataKW.SLICE_INDEX])
            gif_obj = imed_visualize.AnimatedGif(
                size=dataset_val[indexes_gif[i_gif]]["input"].numpy()[0].shape)
            gif_dict["gif"].append(copy.copy(gif_obj))

    # GET MODEL
    if training_params["transfer_learning"]["retrain_model"]:
        logger.info("Loading pretrained model's weights: {}.")
        logger.info("\tFreezing the {}% first layers.".format(
            100 -
            training_params["transfer_learning"]['retrain_fraction'] * 100.))
        old_model_path = training_params["transfer_learning"]["retrain_model"]
        fraction = training_params["transfer_learning"]['retrain_fraction']
        if 'reset' in training_params["transfer_learning"]:
            reset = training_params["transfer_learning"]['reset']
        else:
            reset = True
        # Freeze first layers and reset last layers
        model = imed_models.set_model_for_retrain(old_model_path,
                                                  retrain_fraction=fraction,
                                                  map_location=device,
                                                  reset=reset)
    else:
        logger.info("Initialising model's weights from scratch.")
        model_class = getattr(imed_models, model_params[ModelParamsKW.NAME])
        model = model_class(**model_params)
    if cuda_available:
        model.cuda()

    num_epochs = training_params["training_time"]["num_epochs"]

    # OPTIMIZER
    initial_lr = training_params["scheduler"]["initial_lr"]
    # Keep only the parameters that will be fine-tuned (requires_grad=True)
    params_to_opt = filter(lambda p: p.requires_grad, model.parameters())
    # Using Adam
    optimizer = optim.Adam(params_to_opt, lr=initial_lr)
    scheduler, step_scheduler_batch = get_scheduler(
        copy.copy(training_params["scheduler"]["lr_scheduler"]), optimizer,
        num_epochs)
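    # step_scheduler_batch indicates whether the scheduler steps after every
    # batch or only once per epoch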
    logger.info("Scheduler parameters: {}".format(
        training_params["scheduler"]["lr_scheduler"]))

    # Only call wandb methods if required params are found in the config file
    if wandb_tracking:
        # Logs gradients (at every log_freq steps) to the dashboard.
        wandb.watch(model,
                    log="gradients",
                    log_freq=wandb_params["log_grads_every"])

    # Resume
    start_epoch = 1
    resume_path = Path(path_output, "checkpoint.pth.tar")
    if resume_training:
        model, optimizer, gif_dict, start_epoch, val_loss_total_avg, scheduler, patience_count = load_checkpoint(
            model=model,
            optimizer=optimizer,
            gif_dict=gif_dict,
            scheduler=scheduler,
            fname=str(resume_path))
        # Individually transfer the optimizer parts
        # TODO: check if following lines are needed
        for state in optimizer.state.values():
            for k, v in state.items():
                if torch.is_tensor(v):
                    state[k] = v.to(device)

    # LOSS
    logger.info("Selected Loss: {}".format(training_params["loss"]["name"]))
    logger.info("\twith the parameters: {}".format([
        training_params["loss"][k] for k in training_params["loss"]
        if k != "name"
    ]))
    loss_fct = get_loss_function(copy.copy(training_params["loss"]))
    loss_dice_fct = imed_losses.DiceLoss()  # For comparison when another loss is used

    # INIT TRAINING VARIABLES
    best_training_dice, best_training_loss = float("inf"), float("inf")
    best_validation_loss, best_validation_dice = float("inf"), float("inf")
    patience_count = 0
    begin_time = time.time()

    # EPOCH LOOP
    for epoch in tqdm(range(num_epochs), desc="Training", initial=start_epoch):
        epoch = epoch + start_epoch
        start_time = time.time()

        lr = scheduler.get_last_lr()[0]
        writer.add_scalar('learning_rate', lr, epoch)
        if wandb_tracking:
            wandb.log({"learning_rate": lr})

        # Training loop -----------------------------------------------------------
        model.train()
        train_loss_total, train_dice_loss_total = 0.0, 0.0
        num_steps = 0
        for i, batch in enumerate(train_loader):
            # GET SAMPLES
            if model_params[ModelParamsKW.NAME] == ConfigKW.HEMIS_UNET:
                input_samples = imed_utils.cuda(
                    imed_utils.unstack_tensors(batch["input"]), cuda_available)
            else:
                input_samples = imed_utils.cuda(batch["input"], cuda_available)
            gt_samples = imed_utils.cuda(batch["gt"],
                                         cuda_available,
                                         non_blocking=True)

            # MIXUP
            if training_params["mixup_alpha"]:
                input_samples, gt_samples = imed_mixup.mixup(
                    input_samples, gt_samples, training_params["mixup_alpha"],
                    debugging and epoch == 1, path_output)

            # RUN MODEL
            if model_params[ModelParamsKW.NAME] == ConfigKW.HEMIS_UNET or \
                    (ModelParamsKW.FILM_LAYERS in model_params and any(model_params[ModelParamsKW.FILM_LAYERS])):
                metadata = get_metadata(batch[MetadataKW.INPUT_METADATA],
                                        model_params)
                preds = model(input_samples, metadata)
            else:
                preds = model(input_samples)

            # LOSS
            loss = loss_fct(preds, gt_samples)
            train_loss_total += loss.item()
            train_dice_loss_total += loss_dice_fct(preds, gt_samples).item()

            # UPDATE OPTIMIZER
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if step_scheduler_batch:
                scheduler.step()
            num_steps += 1

            # Save image at every 50th step if debugging is true
            if i % 50 == 0 and debugging:
                imed_visualize.save_img(
                    writer,
                    epoch,
                    "Train",
                    input_samples,
                    gt_samples,
                    preds,
                    wandb_tracking=wandb_tracking,
                    is_three_dim=not model_params[ModelParamsKW.IS_2D])

        if not step_scheduler_batch:
            scheduler.step()

        # TRAINING LOSS
        train_loss_total_avg = train_loss_total / num_steps
        msg = "Epoch {} training loss: {:.4f}.".format(epoch,
                                                       train_loss_total_avg)
        train_dice_loss_total_avg = train_dice_loss_total / num_steps
        if training_params["loss"]["name"] != "DiceLoss":
            msg += "\tDice training loss: {:.4f}.".format(
                train_dice_loss_total_avg)
        logger.info(msg)
        tqdm.write(msg)

        # CURRICULUM LEARNING
        if model_params[ModelParamsKW.NAME] == ConfigKW.HEMIS_UNET:
            # Increase the probability of a missing modality
            model_params[ModelParamsKW.MISSING_PROBABILITY] **= model_params[
                ModelParamsKW.MISSING_PROBABILITY_GROWTH]
            dataset_train.update(
                p=model_params[ModelParamsKW.MISSING_PROBABILITY])

        # Validation loop -----------------------------------------------------
        model.eval()
        val_loss_total, val_dice_loss_total = 0.0, 0.0
        num_steps = 0
        metric_mgr = imed_metrics.MetricManager(metric_fns)
        if dataset_val:
            for i, batch in enumerate(val_loader):
                with torch.no_grad():
                    # GET SAMPLES
                    if model_params[ModelParamsKW.NAME] == ConfigKW.HEMIS_UNET:
                        input_samples = imed_utils.cuda(
                            imed_utils.unstack_tensors(batch["input"]),
                            cuda_available)
                    else:
                        input_samples = imed_utils.cuda(
                            batch["input"], cuda_available)
                    gt_samples = imed_utils.cuda(batch["gt"],
                                                 cuda_available,
                                                 non_blocking=True)

                    # RUN MODEL
                    if model_params[ModelParamsKW.NAME] == ConfigKW.HEMIS_UNET or \
                            (ModelParamsKW.FILM_LAYERS in model_params and any(model_params[ModelParamsKW.FILM_LAYERS])):
                        metadata = get_metadata(
                            batch[MetadataKW.INPUT_METADATA], model_params)
                        preds = model(input_samples, metadata)
                    else:
                        preds = model(input_samples)

                    # LOSS
                    loss = loss_fct(preds, gt_samples)
                    val_loss_total += loss.item()
                    val_dice_loss_total += loss_dice_fct(preds,
                                                         gt_samples).item()

                    # Add frame to GIF
                    for i_ in range(len(input_samples)):
                        im, pr, met = input_samples[i_].cpu().numpy()[0], preds[i_].cpu().numpy()[0], \
                                      batch[MetadataKW.INPUT_METADATA][i_][0]
                        for i_gif in range(n_gif):
                            if gif_dict["image_path"][i_gif] == met.__getitem__('input_filenames') and \
                                    gif_dict["slice_id"][i_gif] == met.__getitem__('slice_index'):
                                overlap = imed_visualize.overlap_im_seg(im, pr)
                                gif_dict["gif"][i_gif].add(overlap,
                                                           label=str(epoch))

                num_steps += 1

                # METRICS COMPUTATION
                gt_npy = gt_samples.cpu().numpy()
                preds_npy = preds.data.cpu().numpy()
                metric_mgr(preds_npy, gt_npy)

                # Save image at every 50th step if debugging is true
                if i % 50 == 0 and debugging:
                    imed_visualize.save_img(
                        writer,
                        epoch,
                        "Validation",
                        input_samples,
                        gt_samples,
                        preds,
                        wandb_tracking=wandb_tracking,
                        is_three_dim=not model_params[ModelParamsKW.IS_2D])

            # METRICS COMPUTATION FOR CURRENT EPOCH
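            # Keep the previous epoch's validation loss for the early-stopping check below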
            val_loss_total_avg_old = val_loss_total_avg if epoch > 1 else None
            metrics_dict = metric_mgr.get_results()
            metric_mgr.reset()
            val_loss_total_avg = val_loss_total / num_steps
            # log losses on Tensorboard by default
            writer.add_scalars('Validation/Metrics', metrics_dict, epoch)
            writer.add_scalars(
                'losses', {
                    'train_loss': train_loss_total_avg,
                    'val_loss': val_loss_total_avg,
                }, epoch)
            # log on wandb if the corresponding dictionary is provided
            if wandb_tracking:
                wandb.log({"validation-metrics": metrics_dict})
                wandb.log({
                    "losses": {
                        'train_loss': train_loss_total_avg,
                        'val_loss': val_loss_total_avg,
                    }
                })
            msg = "Epoch {} validation loss: {:.4f}.".format(
                epoch, val_loss_total_avg)
            val_dice_loss_total_avg = val_dice_loss_total / num_steps
            if training_params["loss"]["name"] != "DiceLoss":
                msg += "\tDice validation loss: {:.4f}.".format(
                    val_dice_loss_total_avg)
            logger.info(msg)
            end_time = time.time()
            total_time = end_time - start_time
            msg_epoch = "Epoch {} took {:.2f} seconds.".format(
                epoch, total_time)
            logger.info(msg_epoch)

            # UPDATE BEST RESULTS
            if val_loss_total_avg < best_validation_loss:
                # Save checkpoint
                state = {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'gif_dict': gif_dict,
                    'scheduler': scheduler,
                    'patience_count': patience_count,
                    'validation_loss': val_loss_total_avg
                }
                torch.save(state, resume_path)

                # Save best model file
                model_path = Path(path_output, "best_model.pt")
                torch.save(model, model_path)

                # Update best scores
                best_validation_loss, best_training_loss = val_loss_total_avg, train_loss_total_avg
                best_validation_dice, best_training_dice = val_dice_loss_total_avg, train_dice_loss_total_avg

            # EARLY STOPPING
            if epoch > 1:
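                # Relative improvement (%) of the validation loss over the previous
                # epoch; values below early_stopping_epsilon increment the patience counter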
                val_diff = (val_loss_total_avg_old -
                            val_loss_total_avg) * 100 / abs(val_loss_total_avg)
                if val_diff < training_params["training_time"][
                        "early_stopping_epsilon"]:
                    patience_count += 1
                if patience_count >= training_params["training_time"][
                        "early_stopping_patience"]:
                    logger.info(
                        "Stopping training due to {} epochs without improvements"
                        .format(patience_count))
                    break

    # Save final model
    final_model_path = Path(path_output, "final_model.pt")
    torch.save(model, final_model_path)

    # Save best model in output path
    if resume_path.is_file():
        state = torch.load(resume_path)
        model_path = Path(path_output, "best_model.pt")
        model.load_state_dict(state['state_dict'])
        torch.save(model, model_path)
        # Save best model as ONNX in the model directory
        try:
            # Convert best model to ONNX and save it in model directory
            best_model_path = Path(
                path_output, model_params[ModelParamsKW.FOLDER_NAME],
                model_params[ModelParamsKW.FOLDER_NAME] + ".onnx")
            imed_utils.save_onnx_model(model, input_samples,
                                       str(best_model_path))
            logger.info(f"Model saved as '.onnx': {best_model_path}")
        except Exception as e:
            logger.warning(f"Failed to save the model as '.onnx': {e}")

        # Save best model as PT in the model directory
        best_model_path = Path(path_output,
                               model_params[ModelParamsKW.FOLDER_NAME],
                               model_params[ModelParamsKW.FOLDER_NAME] + ".pt")
        torch.save(model, best_model_path)
        logger.info(f"Model saved as '.pt': {best_model_path}")

    # Save GIFs
    gif_folder = Path(path_output, "gifs")
    if n_gif > 0 and not gif_folder.is_dir():
        gif_folder.mkdir(parents=True)
    for i_gif in range(n_gif):
        fname_out = gif_dict["image_path"][i_gif].split(os.sep)[-3] + "__"
        fname_out += gif_dict["image_path"][i_gif].split(
            os.sep)[-1].split(".nii.gz")[0].split(
                gif_dict["image_path"][i_gif].split(os.sep)[-3] +
                "_")[1] + "__"
        fname_out += str(gif_dict["slice_id"][i_gif]) + ".gif"
        path_gif_out = Path(gif_folder, fname_out)
        gif_dict["gif"][i_gif].save(str(path_gif_out))

    writer.close()
    wandb.finish()
    final_time = time.time()
    duration_time = final_time - begin_time
    logger.info('Begin ' +
                time.strftime('%H:%M:%S', time.localtime(begin_time)) +
                " | End " +
                time.strftime('%H:%M:%S', time.localtime(final_time)) +
                " | Duration " + str(datetime.timedelta(seconds=duration_time)))

    return best_training_dice, best_training_loss, best_validation_dice, best_validation_loss
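
A hedged call sketch for train(); the dataset objects and parameter dicts are assumed to have been built as in the loaders shown earlier, and metric_fns is a list of metric callables from ivadomed.metrics:

# best_t_dice, best_t_loss, best_v_dice, best_v_loss = train(
#     model_params=model_params,
#     dataset_train=ds_train,
#     dataset_val=ds_valid,  # hypothetical validation dataset
#     training_params=training_params,
#     path_output="path/to/output",
#     device=device,
#     cuda_available=cuda_available,
#     metric_fns=metric_fns,
#     n_gif=0,
#     resume_training=False,
#     debugging=False)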