Beispiel #1
0
def load_dataset(data_list,
                 bids_path,
                 transforms_params,
                 model_params,
                 target_suffix,
                 roi_params,
                 contrast_params,
                 slice_filter_params,
                 slice_axis,
                 multichannel,
                 dataset_type="training",
                 requires_undo=False,
                 metadata_type=None,
                 object_detection_params=None,
                 soft_gt=False,
                 device=None,
                 cuda_available=None,
                 **kwargs):
    """Get the dataset appropriate for the model type.

    Available datasets are Bids3DDataset for 3D data, BidsDataset for 2D data and HDF5Dataset for HeMIS.

    Args:
        data_list (list): Subject names list.
        bids_path (str): Path to the BIDS dataset.
        transforms_params (dict): Dictionary containing transformations for "training", "validation", "testing" (keys),
            eg output of imed_transforms.get_subdatasets_transforms.
        model_params (dict): Dictionary containing model parameters.
        target_suffix (list of str): List of suffixes for target masks.
        roi_params (dict): Contains ROI related parameters. Modified in place: "slice_filter_roi" is set to None
            when ROICrop is not part of the transforms.
        contrast_params (dict): Contains image contrasts related parameters.
        slice_filter_params (dict): Contains slice_filter parameters, see :doc:`configuration_file` for more details.
        slice_axis (string): Choice between "axial", "sagittal", "coronal" ; controls the axis used to extract the 2D
            data.
        multichannel (bool): If True, the input contrasts are combined as input channels for the model. Otherwise, each
            contrast is processed individually (ie different sample / tensor).
        dataset_type (str): Choice between "training", "validation" or "testing".
        requires_undo (bool): If True, the transformations without undo_transform will be discarded.
        metadata_type (str): Choice between None, "mri_params", "contrasts".
        object_detection_params (dict): Object detection parameters.
        soft_gt (bool): If True, ground truths will be converted to float32, otherwise to uint8 and binarized
            (to save memory).
        device: Forwarded to the slice filter (2D and HeMIS datasets only).
        cuda_available (bool): Forwarded to the slice filter (2D and HeMIS datasets only).
        **kwargs: Additional loader parameters; accepted and ignored.

    Returns:
        Bids3DDataset, HDF5Dataset or BidsDataset: The loaded dataset.

    Note: For more details on the parameters transform_params, target_suffix, roi_params, contrast_params,
    slice_filter_params and object_detection_params see :doc:`configuration_file`.
    """
    # Compose transforms
    transform_lst, _ = imed_transforms.prepare_transforms(
        copy.deepcopy(transforms_params), requires_undo)

    # If ROICrop is not part of the transforms, then enforce no slice filtering based on ROI data.
    # NOTE: this modifies the caller's ``roi_params`` dict in place.
    if 'ROICrop' not in transforms_params:
        roi_params["slice_filter_roi"] = None

    # 3D data is used for the 3D U-Net, or whenever the model is explicitly flagged as not 2D.
    # The same flag drives the log message below so that both stay consistent.
    is_3d = model_params["name"] == "Modified3DUNet" or ('is_2d' in model_params and
                                                        not model_params['is_2d'])
    axis = imed_utils.AXIS_DCT[slice_axis]

    if is_3d:
        dataset = Bids3DDataset(
            bids_path,
            subject_lst=data_list,
            target_suffix=target_suffix,
            roi_params=roi_params,
            contrast_params=contrast_params,
            metadata_choice=metadata_type,
            slice_axis=axis,
            transform=transform_lst,
            multichannel=multichannel,
            model_params=model_params,
            object_detection_params=object_detection_params,
            soft_gt=soft_gt)

    elif model_params["name"] == "HeMISUnet":
        dataset = imed_adaptative.HDF5Dataset(
            root_dir=bids_path,
            subject_lst=data_list,
            model_params=model_params,
            contrast_params=contrast_params,
            target_suffix=target_suffix,
            slice_axis=axis,
            transform=transform_lst,
            metadata_choice=metadata_type,
            slice_filter_fn=imed_loader_utils.SliceFilter(
                **slice_filter_params,
                device=device,
                cuda_available=cuda_available),
            roi_params=roi_params,
            object_detection_params=object_detection_params,
            soft_gt=soft_gt)
    else:
        # Task selection
        task = imed_utils.get_task(model_params["name"])

        dataset = BidsDataset(bids_path,
                              subject_lst=data_list,
                              target_suffix=target_suffix,
                              roi_params=roi_params,
                              contrast_params=contrast_params,
                              metadata_choice=metadata_type,
                              slice_axis=axis,
                              transform=transform_lst,
                              multichannel=multichannel,
                              slice_filter_fn=imed_loader_utils.SliceFilter(
                                  **slice_filter_params,
                                  device=device,
                                  cuda_available=cuda_available),
                              soft_gt=soft_gt,
                              object_detection_params=object_detection_params,
                              task=task)
        dataset.load_filenames()

    if is_3d:
        # The previous message formatted slice_axis into a "size" slot; no size is known here.
        print("Loaded {} volumes for the {} set.".format(
            len(dataset), dataset_type))
    else:
        print("Loaded {} {} slices for the {} set.".format(
            len(dataset), slice_axis, dataset_type))

    return dataset
Beispiel #2
0
def run_inference(test_loader,
                  model,
                  model_params,
                  testing_params,
                  ofolder,
                  cuda_available,
                  i_monte_carlo=None):
    """Run inference on the test data and save results as nibabel files.

    For 2D models, slice predictions are buffered until every slice of a volume has been processed, then the
    volume is assembled with ``imed_utils.pred_to_nib``. The buffer is flushed when a new volume starts and
    after the very last sample, so the final slice is included. For models whose name ends with ``'3D'``,
    volumes are rebuilt with ``imed_utils.volume_reconstruction``.

    Args:
        test_loader (torch DataLoader): Loader over the test set.
        model (nn.Module): Model to evaluate.
        model_params (dict): Model parameters; ``"name"`` selects the input handling and the 2D/3D path.
        testing_params (dict): Testing parameters. Keys read here: ``"uncertainty"``, ``"undo_transforms"``,
            ``"slice_axis"``, ``"target_suffix"`` and ``"binarize_prediction"``.
        ofolder (str): Folder where predictions are saved. If falsy, no file is written.
        cuda_available (bool): If True, CUDA is available.
        i_monte_carlo (int): i_th Monte Carlo iteration.

    Returns:
        list, list: Predictions and ground truths. Segmentation: one array per volume.
        Classification: one array per batch.
    """
    # INIT STORAGE VARIABLES
    preds_npy_list, gt_npy_list = [], []
    pred_tmp_lst, z_tmp_lst, fname_tmp = [], [], ''
    volume = None
    weight_matrix = None

    # Settings that do not change from one batch to the next
    task = imed_utils.get_task(model_params["name"])
    slice_axis = imed_utils.AXIS_DCT[testing_params['slice_axis']]

    def _save_2d_volume(pred_lst, z_lst, fname_vol):
        """Assemble the buffered slices of one 2D volume, save it when ``ofolder`` is set, and store pred/gt."""
        if ofolder:
            fname_pred = os.path.join(ofolder, fname_vol.split('/')[-1])
            fname_pred = fname_pred.rsplit(
                testing_params['target_suffix'][0], 1)[0] + '_pred.nii.gz'
            # If Uncertainty running, then we save each simulation result
            if testing_params['uncertainty']['applied']:
                fname_pred = fname_pred.split(
                    '.nii.gz')[0] + '_' + str(i_monte_carlo).zfill(2) + '.nii.gz'
        else:
            fname_pred = None
        output_nii = imed_utils.pred_to_nib(
            data_lst=pred_lst,
            z_lst=z_lst,
            fname_ref=fname_vol,
            fname_out=fname_pred,
            slice_axis=slice_axis,
            kernel_dim='2d',
            bin_thr=testing_params["binarize_prediction"])
        # TODO: Adapt to multilabel
        output_data = output_nii.get_fdata()[:, :, :, 0]
        preds_npy_list.append(output_data)

        gt_npy_list.append(nib.load(fname_vol).get_fdata())

        output_nii_shape = output_nii.get_fdata().shape
        if len(output_nii_shape) == 4 and output_nii_shape[-1] > 1 and ofolder:
            imed_utils.save_color_labels(
                np.stack(pred_lst, -1),
                testing_params["binarize_prediction"] > 0,
                fname_vol,
                fname_pred.split(".nii.gz")[0] + '_color.nii.gz',
                slice_axis)

    for i, batch in enumerate(
            tqdm(test_loader,
                 desc="Inference - Iteration " + str(i_monte_carlo))):
        with torch.no_grad():
            # GET SAMPLES
            # input_samples: list of batch_size tensors, whose size is n_channels X height X width X depth
            # gt_samples: idem with n_labels
            # batch['*_metadata']: list of batch_size lists, whose size is n_channels or n_labels
            if model_params["name"] == "HeMISUnet":
                input_samples = imed_utils.cuda(
                    imed_utils.unstack_tensors(batch["input"]), cuda_available)
            else:
                input_samples = imed_utils.cuda(batch["input"], cuda_available)
            gt_samples = imed_utils.cuda(batch["gt"],
                                         cuda_available,
                                         non_blocking=True)

            # EPISTEMIC UNCERTAINTY: keep dropout layers active at test time
            if testing_params['uncertainty']['applied'] and testing_params[
                    'uncertainty']['epistemic']:
                for m in model.modules():
                    if m.__class__.__name__.startswith('Dropout'):
                        m.train()

            # RUN MODEL
            if model_params["name"] in ["HeMISUnet", "FiLMedUnet"]:
                metadata = get_metadata(batch["input_metadata"], model_params)
                preds = model(input_samples, metadata)
            else:
                preds = model(input_samples)

        if model_params["name"] == "HeMISUnet":
            # Reconstruct image with only one modality
            input_samples = batch['input'][0]

        if model_params["name"] == "UNet3D" and model_params[
                "attention"] and ofolder:
            imed_utils.save_feature_map(
                batch,
                "attentionblock2",
                os.path.dirname(ofolder),
                model,
                input_samples,
                slice_axis=test_loader.dataset.slice_axis)

        # PREDS TO CPU
        preds_cpu = preds.cpu()

        if task == "classification":
            gt_npy_list.append(gt_samples.cpu().numpy())
            preds_npy_list.append(preds_cpu.data.numpy())

        # RECONSTRUCT 3D IMAGE
        last_batch_bool = (i == len(test_loader) - 1)

        # LOOP ACROSS SAMPLES
        for smp_idx in range(len(preds_cpu)):
            if "bounding_box" in batch['input_metadata'][smp_idx][0]:
                imed_obj_detect.adjust_undo_transforms(
                    testing_params["undo_transforms"].transforms, batch,
                    smp_idx)

            if not model_params["name"].endswith('3D'):
                last_sample_bool = (last_batch_bool
                                    and smp_idx == len(preds_cpu) - 1)
                # undo transformations
                preds_idx_undo, metadata_idx = testing_params[
                    "undo_transforms"](preds_cpu[smp_idx],
                                       batch['gt_metadata'][smp_idx],
                                       data_type='gt')
                # preds_idx_undo is a list n_label arrays
                preds_idx_arr = np.array(preds_idx_undo)

                # TODO: gt_filenames should not be a list
                fname_ref = metadata_idx[0]['gt_filenames'][0]

                # NEW COMPLETE VOLUME: the buffered volume is finished, save it before starting this one
                if pred_tmp_lst and fname_ref != fname_tmp and task != "classification":
                    _save_2d_volume(pred_tmp_lst, z_tmp_lst, fname_tmp)
                    # re-init pred_stack_lst
                    pred_tmp_lst, z_tmp_lst = [], []

                # add new sample to pred_tmp_lst, of size n_label X h X w ...
                pred_tmp_lst.append(preds_idx_arr)

                # TODO: slice_index should be stored in gt_metadata as well
                z_tmp_lst.append(
                    int(batch['input_metadata'][smp_idx][0]['slice_index']))
                fname_tmp = fname_ref

                # LAST SAMPLE: save the final volume, including the slice that was just added
                if last_sample_bool and task != "classification":
                    _save_2d_volume(pred_tmp_lst, z_tmp_lst, fname_tmp)
                    pred_tmp_lst, z_tmp_lst = [], []

            else:
                pred_undo, metadata, last_sample_bool, volume, weight_matrix = \
                    imed_utils.volume_reconstruction(batch,
                                                     preds_cpu,
                                                     testing_params['undo_transforms'],
                                                     smp_idx, volume, weight_matrix)
                # Indicator of last batch
                if last_sample_bool:
                    # First available ground truth: with multi-label data, some labels may be missing (None)
                    fname_ref = list(filter(None, metadata[0]['gt_filenames']))[0]
                    pred_undo = np.array(pred_undo)
                    if ofolder:
                        fname_pred = os.path.join(ofolder,
                                                  fname_ref.split('/')[-1])
                        # rsplit (as in the 2D path) so only the suffix in the file name is stripped
                        fname_pred = fname_pred.rsplit(
                            testing_params['target_suffix'][0], 1)[0] + '_pred.nii.gz'
                        # If uncertainty running, then we save each simulation result
                        if testing_params['uncertainty']['applied']:
                            fname_pred = fname_pred.split(
                                '.nii.gz')[0] + '_' + str(i_monte_carlo).zfill(
                                    2) + '.nii.gz'
                    else:
                        fname_pred = None
                    # Choose only one modality
                    output_nii = imed_utils.pred_to_nib(
                        data_lst=[pred_undo],
                        z_lst=[],
                        fname_ref=fname_ref,
                        fname_out=fname_pred,
                        slice_axis=slice_axis,
                        kernel_dim='3d',
                        bin_thr=testing_params["binarize_prediction"])
                    output_data = output_nii.get_fdata().transpose(3, 0, 1, 2)
                    preds_npy_list.append(output_data)

                    # For multi-label, if all labels are not in every image, missing labels (None) are
                    # filled with zeros shaped like the first available label; that label need not be first.
                    gt_lst = [nib.load(gt).get_fdata() if gt is not None else None
                              for gt in metadata[0]['gt_filenames']]
                    gt_shape = next(gt.shape for gt in gt_lst if gt is not None)
                    gt_npy_list.append(
                        np.array([gt if gt is not None else np.zeros(gt_shape)
                                  for gt in gt_lst]))

                    # Save merged labels with color
                    if pred_undo.shape[0] > 1 and ofolder:
                        imed_utils.save_color_labels(
                            pred_undo,
                            testing_params['binarize_prediction'] > 0,
                            batch['input_metadata'][smp_idx][0]
                            ['input_filenames'],
                            fname_pred.split(".nii.gz")[0] + '_color.nii.gz',
                            slice_axis)

    return preds_npy_list, gt_npy_list
Beispiel #3
0
def run_inference(test_loader,
                  model,
                  model_params,
                  testing_params,
                  ofolder,
                  cuda_available,
                  i_monte_carlo=None,
                  postprocessing=None):
    """Run inference on the test data and save results as nibabel files.

    For 2D models (``model_params["is_2d"]``), slice predictions are buffered until every slice of a volume
    has been processed, then the volume is assembled with ``imed_inference.pred_to_nib``. The buffer is
    flushed when a new volume starts and after the very last sample, so the final slice is included.
    Otherwise, volumes are rebuilt with ``imed_inference.volume_reconstruction``.

    Args:
        test_loader (torch DataLoader): Loader over the test set.
        model (nn.Module): Model to evaluate.
        model_params (dict): Model parameters (``"name"``, ``"is_2d"``, optional ``"film_layers"``, ...).
        testing_params (dict): Testing parameters. Keys read here: ``"uncertainty"``, ``"undo_transforms"``,
            ``"slice_axis"`` and ``"target_suffix"``.
        ofolder (str): Folder where predictions are saved. If falsy, no file is written.
        cuda_available (bool): If True, CUDA is available.
        i_monte_carlo (int): i_th Monte Carlo iteration.
        postprocessing (dict): Indicates postprocessing steps. Ignored when uncertainty is applied and
            ``ofolder`` is set.

    Returns:
        list, list: Predictions and ground truths. Segmentation: one array per volume.
        Classification: one array per batch.
    """
    # INIT STORAGE VARIABLES
    preds_npy_list, gt_npy_list, filenames = [], [], []
    pred_tmp_lst, z_tmp_lst, fname_tmp = [], [], ''
    volume = None
    weight_matrix = None

    # Settings that do not change from one batch to the next
    use_film = 'film_layers' in model_params and any(model_params['film_layers'])
    task = imed_utils.get_task(model_params["name"])
    slice_axis = imed_utils.AXIS_DCT[testing_params['slice_axis']]
    # When each uncertainty (Monte Carlo) simulation is saved, predictions are written without postprocessing.
    if ofolder and testing_params['uncertainty']['applied']:
        postprocessing = None

    # Create dict containing gammas and betas after each FiLM layer.
    if use_film:
        # 2 * model_params["depth"] + 2 is the number of FiLM layers. 1 is added since the range starts at one.
        n_film_range_end = 2 * model_params["depth"] + 3
        gammas_dict = {i: [] for i in range(1, n_film_range_end)}
        betas_dict = {i: [] for i in range(1, n_film_range_end)}
        metadata_values_lst = []

    def _save_2d_volume(pred_lst, z_lst, fname_vol, gt_filenames):
        """Assemble the buffered slices of one 2D volume, save it when ``ofolder`` is set, and store pred/gt."""
        if ofolder:
            fname_pred = os.path.join(ofolder, Path(fname_vol).name)
            fname_pred = fname_pred.rsplit("_", 1)[0] + '_pred.nii.gz'
            # If Uncertainty running, then we save each simulation result
            if testing_params['uncertainty']['applied']:
                fname_pred = fname_pred.split(
                    '.nii.gz')[0] + '_' + str(i_monte_carlo).zfill(2) + '.nii.gz'
        else:
            fname_pred = None
        output_nii = imed_inference.pred_to_nib(
            data_lst=pred_lst,
            z_lst=z_lst,
            fname_ref=fname_vol,
            fname_out=fname_pred,
            slice_axis=slice_axis,
            kernel_dim='2d',
            bin_thr=-1,
            postprocessing=postprocessing)
        output_data = output_nii.get_fdata().transpose(3, 0, 1, 2)
        preds_npy_list.append(output_data)

        gt_npy_list.append(get_gt(gt_filenames))

        output_nii_shape = output_nii.get_fdata().shape
        if len(output_nii_shape) == 4 and output_nii_shape[-1] > 1 and ofolder:
            logger.warning(
                'No color labels saved due to a temporary issue. For more details see:'
                'https://github.com/ivadomed/ivadomed/issues/720')
            # TODO: put back the code below. See #720
            # imed_visualize.save_color_labels(np.stack(pred_lst, -1),
            #                              False,
            #                              fname_vol,
            #                              fname_pred.split(".nii.gz")[0] + '_color.nii.gz',
            #                              slice_axis)

    for i, batch in enumerate(
            tqdm(test_loader,
                 desc="Inference - Iteration " + str(i_monte_carlo))):
        with torch.no_grad():
            # GET SAMPLES
            # input_samples: list of batch_size tensors, whose size is n_channels X height X width X depth
            # gt_samples: idem with n_labels
            # batch['*_metadata']: list of batch_size lists, whose size is n_channels or n_labels
            if model_params["name"] == "HeMISUnet":
                input_samples = imed_utils.cuda(
                    imed_utils.unstack_tensors(batch["input"]), cuda_available)
            else:
                input_samples = imed_utils.cuda(batch["input"], cuda_available)
            gt_samples = imed_utils.cuda(batch["gt"],
                                         cuda_available,
                                         non_blocking=True)

            # EPISTEMIC UNCERTAINTY: keep dropout layers active at test time
            if testing_params['uncertainty']['applied'] and testing_params[
                    'uncertainty']['epistemic']:
                for m in model.modules():
                    if m.__class__.__name__.startswith('Dropout'):
                        m.train()

            # RUN MODEL
            if model_params["name"] == "HeMISUnet" or use_film:
                metadata = get_metadata(batch["input_metadata"], model_params)
                preds = model(input_samples, metadata)
            else:
                preds = model(input_samples)

        if model_params["name"] == "HeMISUnet":
            # Reconstruct image with only one modality
            input_samples = batch['input'][0]

        if model_params["name"] == "Modified3DUNet" and model_params[
                "attention"] and ofolder:
            imed_visualize.save_feature_map(
                batch,
                "attentionblock2",
                os.path.dirname(ofolder),
                model,
                input_samples,
                slice_axis=test_loader.dataset.slice_axis)

        if use_film:
            # Store the values of gammas and betas after the last epoch for each batch
            gammas_dict, betas_dict, metadata_values_lst = store_film_params(
                gammas_dict, betas_dict, metadata_values_lst,
                batch['input_metadata'], model, model_params["film_layers"],
                model_params["depth"], model_params['metadata'])

        # PREDS TO CPU
        preds_cpu = preds.cpu()

        if task == "classification":
            gt_npy_list.append(gt_samples.cpu().numpy())
            preds_npy_list.append(preds_cpu.data.numpy())

        # RECONSTRUCT 3D IMAGE
        last_batch_bool = (i == len(test_loader) - 1)

        # LOOP ACROSS SAMPLES
        for smp_idx in range(len(preds_cpu)):
            if "bounding_box" in batch['input_metadata'][smp_idx][0]:
                imed_obj_detect.adjust_undo_transforms(
                    testing_params["undo_transforms"].transforms, batch,
                    smp_idx)

            if model_params["is_2d"]:
                last_sample_bool = (last_batch_bool
                                    and smp_idx == len(preds_cpu) - 1)
                # undo transformations
                preds_idx_undo, metadata_idx = testing_params[
                    "undo_transforms"](preds_cpu[smp_idx],
                                       batch['gt_metadata'][smp_idx],
                                       data_type='gt')
                # preds_idx_undo is a list n_label arrays
                preds_idx_arr = np.array(preds_idx_undo)

                # TODO: gt_filenames should not be a list
                fname_ref = list(filter(None,
                                        metadata_idx[0]['gt_filenames']))[0]

                # NEW COMPLETE VOLUME: the buffered volume is finished, save it before starting this one
                if pred_tmp_lst and fname_ref != fname_tmp and task != "classification":
                    _save_2d_volume(pred_tmp_lst, z_tmp_lst, fname_tmp, filenames)
                    # re-init pred_stack_lst
                    pred_tmp_lst, z_tmp_lst = [], []

                # add new sample to pred_tmp_lst, of size n_label X h X w ...
                pred_tmp_lst.append(preds_idx_arr)

                # TODO: slice_index should be stored in gt_metadata as well
                z_tmp_lst.append(
                    int(batch['input_metadata'][smp_idx][0]['slice_index']))
                fname_tmp = fname_ref
                filenames = metadata_idx[0]['gt_filenames']

                # LAST SAMPLE: save the final volume, including the slice that was just added
                if last_sample_bool and task != "classification":
                    _save_2d_volume(pred_tmp_lst, z_tmp_lst, fname_tmp, filenames)
                    pred_tmp_lst, z_tmp_lst = [], []

            else:
                pred_undo, metadata, last_sample_bool, volume, weight_matrix = \
                    imed_inference.volume_reconstruction(batch,
                                                         preds_cpu,
                                                         testing_params['undo_transforms'],
                                                         smp_idx, volume, weight_matrix)
                # Indicator of last batch
                if last_sample_bool:
                    # First available ground truth, as in the 2D path: some labels may be missing (None)
                    fname_ref = list(filter(None, metadata[0]['gt_filenames']))[0]
                    pred_undo = np.array(pred_undo)
                    if ofolder:
                        fname_pred = os.path.join(ofolder,
                                                  fname_ref.split('/')[-1])
                        fname_pred = fname_pred.split(
                            testing_params['target_suffix']
                            [0])[0] + '_pred.nii.gz'
                        # If uncertainty running, then we save each simulation result
                        if testing_params['uncertainty']['applied']:
                            fname_pred = fname_pred.split(
                                '.nii.gz')[0] + '_' + str(i_monte_carlo).zfill(
                                    2) + '.nii.gz'
                    else:
                        fname_pred = None
                    # Choose only one modality
                    output_nii = imed_inference.pred_to_nib(
                        data_lst=[pred_undo],
                        z_lst=[],
                        fname_ref=fname_ref,
                        fname_out=fname_pred,
                        slice_axis=slice_axis,
                        kernel_dim='3d',
                        bin_thr=-1,
                        postprocessing=postprocessing)
                    output_data = output_nii.get_fdata().transpose(3, 0, 1, 2)
                    preds_npy_list.append(output_data)

                    gt = get_gt(metadata[0]['gt_filenames'])
                    gt_npy_list.append(gt)

                    # Save merged labels with color
                    if pred_undo.shape[0] > 1 and ofolder:
                        logger.warning(
                            'No color labels saved due to a temporary issue. For more details see:'
                            'https://github.com/ivadomed/ivadomed/issues/720')
                        # TODO: put back the code below. See #720
                        # imed_visualize.save_color_labels(pred_undo,
                        #                              False,
                        #                              batch['input_metadata'][smp_idx][0]['input_filenames'],
                        #                              fname_pred.split(".nii.gz")[0] + '_color.nii.gz',
                        #                              slice_axis)

    # FiLM parameters are saved next to the prediction folder; skipped when no output folder is given
    # (``ofolder.replace`` would otherwise fail on None).
    if use_film and ofolder:
        save_film_params(gammas_dict, betas_dict, metadata_values_lst,
                         model_params["depth"],
                         ofolder.replace("pred_masks", ""))
    return preds_npy_list, gt_npy_list
Beispiel #4
0
def load_dataset(bids_df,
                 data_list,
                 transforms_params,
                 model_params,
                 target_suffix,
                 roi_params,
                 contrast_params,
                 slice_filter_params,
                 patch_filter_params,
                 slice_axis,
                 multichannel,
                 dataset_type="training",
                 requires_undo=False,
                 metadata_type=None,
                 object_detection_params=None,
                 soft_gt=False,
                 device=None,
                 cuda_available=None,
                 is_input_dropout=False,
                 **kwargs):
    """Get the dataset appropriate for the model type.

    Available datasets are Bids3DDataset for 3D data and BidsDataset for 2D data. The HeMIS loader
    (HDF5Dataset) is currently disabled, so HeMIS models are also loaded with BidsDataset.

    Args:
        bids_df (BidsDataframe): Object containing dataframe with all BIDS image files and their metadata.
        data_list (list): Subject names list.
        transforms_params (dict): Dictionary containing transformations for "training", "validation", "testing" (keys),
            eg output of imed_transforms.get_subdatasets_transforms.
        model_params (dict): Dictionary containing model parameters.
        target_suffix (list of str): List of suffixes for target masks.
        roi_params (dict): Contains ROI related parameters. Modified in place: the slice_filter_roi entry is set
            to None when ROICrop is not part of the transforms.
        contrast_params (dict): Contains image contrasts related parameters.
        slice_filter_params (dict): Contains slice_filter_params, see :doc:`configuration_file` for more details.
        patch_filter_params (dict): Contains patch_filter_params, see :doc:`configuration_file` for more details.
        slice_axis (string): Choice between "axial", "sagittal", "coronal" ; controls the axis used to extract the 2D
            data from 3D NifTI files. 2D PNG/TIF/JPG files use default "axial".
        multichannel (bool): If True, the input contrasts are combined as input channels for the model. Otherwise, each
            contrast is processed individually (ie different sample / tensor).
        dataset_type (str): Choice between "training", "validation" or "testing". The patch filter runs in
            training mode for every value except "testing".
        requires_undo (bool): If True, the transformations without undo_transform will be discarded.
        metadata_type (str): Choice between None, "mri_params", "contrasts".
        object_detection_params (dict): Object detection parameters.
        soft_gt (bool): If True, ground truths are not binarized before being fed to the network. Otherwise, ground
            truths are thresholded (0.5) after the data augmentation operations.
        device: Forwarded to the slice filter (2D datasets only).
        cuda_available (bool): Forwarded to the slice filter (2D datasets only).
        is_input_dropout (bool): Return input with missing modalities.
        **kwargs: Additional loader parameters; accepted and ignored.

    Returns:
        Bids3DDataset or BidsDataset: The loaded dataset.

    Note: For more details on the parameters transform_params, target_suffix, roi_params, contrast_params,
    slice_filter_params, patch_filter_params and object_detection_params see :doc:`configuration_file`.
    """

    # Compose transforms
    transform_lst, _ = imed_transforms.prepare_transforms(
        copy.deepcopy(transforms_params), requires_undo)

    # If ROICrop is not part of the transforms, then enforce no slice filtering based on ROI data.
    # NOTE: this modifies the caller's ``roi_params`` dict in place.
    if TransformationKW.ROICROP not in transforms_params:
        roi_params[ROIParamsKW.SLICE_FILTER_ROI] = None

    model_name = model_params[ModelParamsKW.NAME]
    # 3D data is used for the 3D U-Net, or whenever the model is explicitly flagged as not 2D.
    # The same flag drives the log message below so that both stay consistent.
    is_3d = model_name == ConfigKW.MODIFIED_3D_UNET \
        or (ModelParamsKW.IS_2D in model_params and not model_params[ModelParamsKW.IS_2D])

    if is_3d:
        dataset = Bids3DDataset(
            bids_df=bids_df,
            subject_file_lst=data_list,
            target_suffix=target_suffix,
            roi_params=roi_params,
            contrast_params=contrast_params,
            metadata_choice=metadata_type,
            slice_axis=imed_utils.AXIS_DCT[slice_axis],
            transform=transform_lst,
            multichannel=multichannel,
            model_params=model_params,
            object_detection_params=object_detection_params,
            soft_gt=soft_gt,
            is_input_dropout=is_input_dropout)
    else:
        # NOTE: the HeMIS-specific loader (imed_adaptative.HDF5Dataset) is disabled, so HeMIS models
        # are loaded with BidsDataset as well.
        # Task selection
        task = imed_utils.get_task(model_name)

        dataset = BidsDataset(
            bids_df=bids_df,
            subject_file_lst=data_list,
            target_suffix=target_suffix,
            roi_params=roi_params,
            contrast_params=contrast_params,
            model_params=model_params,
            metadata_choice=metadata_type,
            slice_axis=imed_utils.AXIS_DCT[slice_axis],
            transform=transform_lst,
            multichannel=multichannel,
            slice_filter_fn=SliceFilter(**slice_filter_params,
                                        device=device,
                                        cuda_available=cuda_available),
            patch_filter_fn=PatchFilter(**patch_filter_params,
                                        is_train=dataset_type != "testing"),
            soft_gt=soft_gt,
            object_detection_params=object_detection_params,
            task=task,
            is_input_dropout=is_input_dropout)
        dataset.load_filenames()

    if is_3d:
        logger.info(
            f"Loaded {len(dataset)} volumes of shape {dataset.length} for the {dataset_type} set."
        )
    elif model_name != ConfigKW.HEMIS_UNET and dataset.length:
        logger.info(
            f"Loaded {len(dataset)} {slice_axis} patches of shape {dataset.length} for the {dataset_type} set."
        )
    else:
        logger.info(
            f"Loaded {len(dataset)} {slice_axis} slices for the {dataset_type} set."
        )

    return dataset