Ejemplo n.º 1
0
def segment_volume(folder_model, fname_images, gpu_number=0, options=None):
    """Segment an image.

    Segment an image (`fname_images`) using a pre-trained model (`folder_model`). If provided, a region of interest
    (`fname_roi`, taken from ``options['fname_prior']``) is used to crop the image prior to segment it.

    Args:
        folder_model (str): foldername which contains
            (1) the model ('folder_model/folder_model.pt') to use
            (2) its configuration file ('folder_model/folder_model.json') used for the training,
            see https://github.com/neuropoly/ivadomed/wiki/configuration-file
        fname_images (list): list of image filenames (e.g. .nii.gz) to segment. Multichannel models require multiple
            images to segment, i.e., len(fname_images) > 1. The first filename is used as the reference image when
            the prediction is converted back to a nibabel object.
        gpu_number (int): Number representing gpu number if available.
        options (dict): Optional overrides. Recognised keys:
            - 'binarize_prediction': threshold used to binarize the prediction.
            - 'keep_largest', 'fill_holes': enable the step when truthy; a falsy value that is not None removes the
              step from the model's postprocessing configuration.
            - 'remove_small': list of thresholds suffixed by their unit, e.g. ['5mm3'] or ['10vox'].
            - 'fname_prior': image filename (e.g., .nii.gz) containing processing information (i.e., spinal cord
              segmentation, spinal location or MS lesion classification), e.g., spinal cord centerline, used to
              crop the image prior to segment it if provided. The segmentation is not performed on the slices
              that are empty in this image.
            - 'metadata': metadata value given to FiLMedUnet models.

    Returns:
        list: List of nibabel objects containing the soft segmentation(s), one per prediction class.
        list: List of target suffix associated with each prediction in `pred_list`

    """
    # Define device
    cuda_available = torch.cuda.is_available()
    device = torch.device("cuda:" + str(gpu_number)) if cuda_available else torch.device("cpu")

    # Check if model folder exists and get filenames
    fname_model, fname_model_metadata = imed_models.get_model_filenames(folder_model)

    # Load model training config
    context = imed_config_manager.ConfigurationManager(fname_model_metadata).get_config()

    # Postprocessing steps that may be overridden through `options` (keys must match exactly).
    postpro_list = ['binarize_prediction', 'keep_largest', 'fill_holes', 'remove_small']
    if options is not None and any(pp in options for pp in postpro_list):
        postpro = {}
        if options.get('binarize_prediction'):
            postpro['binarize_prediction'] = {"thr": options['binarize_prediction']}
        if options.get('keep_largest') is not None:
            if options['keep_largest']:
                postpro['keep_largest'] = {}
            # Remove key in context if value set to 0
            elif 'keep_largest' in context['postprocessing']:
                del context['postprocessing']['keep_largest']
        if options.get('fill_holes') is not None:
            if options['fill_holes']:
                postpro['fill_holes'] = {}
            # Remove key in context if value set to 0
            elif 'fill_holes' in context['postprocessing']:
                del context['postprocessing']['fill_holes']
        remove_small = options.get('remove_small')
        # The unit is read from the suffix of the last threshold and stripped from every threshold.
        if remove_small and ('mm' in remove_small[-1] or 'vox' in remove_small[-1]):
            unit = 'mm3' if 'mm3' in remove_small[-1] else 'vox'
            thr = [int(t.replace(unit, "")) for t in remove_small]
            postpro['remove_small'] = {"unit": unit, "thr": thr}

        context['postprocessing'].update(postpro)

    # LOADER
    loader_params = context["loader_parameters"]
    slice_axis = imed_utils.AXIS_DCT[loader_params['slice_axis']]
    metadata = {}
    fname_roi = None
    fname_prior = options.get('fname_prior') if options is not None else None
    if fname_prior is not None:
        if 'roi_params' in loader_params and loader_params['roi_params']['suffix'] is not None:
            fname_roi = fname_prior
        # TRANSFORMATIONS
        # If ROI is not provided then force center cropping
        if fname_roi is None and 'ROICrop' in context["transformation"]:
            print(
                "\n WARNING: fname_roi has not been specified, then a cropping around the center of the image is "
                "performed instead of a cropping around a Region of Interest.")

            # Rename the 'ROICrop' transform to 'CenterCrop', keeping its parameters and the transform order.
            context["transformation"] = dict(
                (key, value) if key != 'ROICrop' else ('CenterCrop', value)
                for (key, value) in context["transformation"].items())

        if 'object_detection_params' in context and \
                context['object_detection_params']['object_detection_path'] is not None:
            imed_obj_detect.bounding_box_prior(
                fname_prior, metadata, slice_axis,
                context['object_detection_params']['safety_factor'])
            # One metadata entry per input image.
            metadata = [metadata] * len(fname_images)

    # Compose transforms
    _, _, transform_test_params = imed_transforms.get_subdatasets_transforms(context["transformation"])

    transform_lst, undo_transforms = imed_transforms.prepare_transforms(transform_test_params)

    # Force filter_empty_mask to False if fname_roi = None
    if fname_roi is None and 'filter_empty_mask' in loader_params["slice_filter_params"]:
        print(
            "\nWARNING: fname_roi has not been specified, then the entire volume is processed."
        )
        loader_params["slice_filter_params"]["filter_empty_mask"] = False

    filename_pairs = [(fname_images, None, fname_roi,
                       metadata if isinstance(metadata, list) else [metadata])]

    kernel_3D = bool('Modified3DUNet' in context and context['Modified3DUNet']['applied']) or \
        not context['default_model']['is_2d']
    if kernel_3D:
        ds = imed_loader.MRI3DSubVolumeSegmentationDataset(
            filename_pairs,
            transform=transform_lst,
            length=context["Modified3DUNet"]["length_3D"],
            stride=context["Modified3DUNet"]["stride_3D"])
    else:
        ds = imed_loader.MRI2DSegmentationDataset(
            filename_pairs,
            slice_axis=slice_axis,
            cache=True,
            transform=transform_lst,
            slice_filter_fn=imed_loader_utils.SliceFilter(
                **loader_params["slice_filter_params"]))
        ds.load_filenames()

    if kernel_3D:
        print("\nLoaded {} {} volumes of shape {}.".format(
            len(ds), loader_params['slice_axis'],
            context['Modified3DUNet']['length_3D']))
    else:
        print("\nLoaded {} {} slices.".format(len(ds), loader_params['slice_axis']))

    model_params = {}
    if 'FiLMedUnet' in context and context['FiLMedUnet']['applied']:
        metadata_dict = joblib.load(os.path.join(folder_model, 'metadata_dict.joblib'))
        for idx in ds.indexes:
            for i in range(len(idx)):
                idx[i]['input_metadata'][0][context['FiLMedUnet']['metadata']] = options['metadata']
                idx[i]['input_metadata'][0]['metadata_dict'] = metadata_dict

        ds = imed_film.normalize_metadata(ds, None, context["debugging"],
                                          context['FiLMedUnet']['metadata'])
        onehotencoder = joblib.load(os.path.join(folder_model, 'one_hot_encoder.joblib'))

        model_params.update({
            "name": 'FiLMedUnet',
            "film_onehotencoder": onehotencoder,
            # Total number of one-hot categories over all encoded features.
            "n_metadata": sum(len(categories) for categories in onehotencoder.categories_),
        })

    # Data Loader
    data_loader = DataLoader(
        ds,
        batch_size=context["training_parameters"]["batch_size"],
        shuffle=False,
        pin_memory=True,
        collate_fn=imed_loader_utils.imed_collate,
        num_workers=0)

    # MODEL
    if fname_model.endswith('.pt'):
        model = torch.load(fname_model, map_location=device)
        # Inference time
        model.eval()

    # Loop across batches
    preds_list, slice_idx_list = [], []
    last_sample_bool, volume, weight_matrix = False, None, None
    for i_batch, batch in enumerate(data_loader):
        with torch.no_grad():
            img = imed_utils.cuda(batch['input'], cuda_available=cuda_available)

            if ('FiLMedUnet' in context and context['FiLMedUnet']['applied']) or \
                    ('HeMISUnet' in context and context['HeMISUnet']['applied']):
                metadata = imed_training.get_metadata(batch["input_metadata"], model_params)
                preds = model(img, metadata)

            else:
                preds = model(img) if fname_model.endswith('.pt') else onnx_inference(fname_model, img)

            preds = preds.cpu()

        # Set datatype to gt since prediction should be processed the same way as gt
        for b in batch['input_metadata']:
            for modality in b:
                modality['data_type'] = 'gt'

        # Reconstruct 3D object
        for i_slice in range(len(preds)):
            if "bounding_box" in batch['input_metadata'][i_slice][0]:
                imed_obj_detect.adjust_undo_transforms(undo_transforms.transforms, batch, i_slice)

            batch['gt_metadata'] = [[metadata[0]] * preds.shape[1]
                                    for metadata in batch['input_metadata']]
            if kernel_3D:
                preds_undo, metadata, last_sample_bool, volume, weight_matrix = \
                    volume_reconstruction(batch, preds, undo_transforms, i_slice, volume, weight_matrix)
                preds_list = [np.array(preds_undo)]
            else:
                # undo transformations
                preds_i_undo, metadata_idx = undo_transforms(
                    preds[i_slice],
                    batch["input_metadata"][i_slice],
                    data_type='gt')

                # Add new segmented slice to preds_list
                preds_list.append(np.array(preds_i_undo))
                # Store the slice index of preds_i_undo in the original 3D image
                slice_idx_list.append(int(batch['input_metadata'][i_slice][0]['slice_index']))

            # If last batch and last sample of this batch, then reconstruct 3D object
            if (i_batch == len(data_loader) - 1
                    and i_slice == len(batch['gt']) - 1) or last_sample_bool:
                pred_nib = pred_to_nib(
                    data_lst=preds_list,
                    fname_ref=fname_images[0],
                    fname_out=None,
                    z_lst=slice_idx_list,
                    slice_axis=slice_axis,
                    kernel_dim='3d' if kernel_3D else '2d',
                    debug=False,
                    bin_thr=-1,
                    postprocessing=context['postprocessing'])

                pred_list = split_classes(pred_nib)
                target_list = context['loader_parameters']['target_suffix']

    return pred_list, target_list
Ejemplo n.º 2
0
def test_HeMIS(p=0.0001):
    """Train a HeMISUnet on an in-RAM HDF5 dataset and print timing statistics for each training stage.

    After every epoch, ``p`` is updated as ``p ** (2 / 3)`` and passed to ``dataset.update``; the DataLoader is
    then rebuilt. (``p`` is presumably the missing-contrast probability -- confirm against ``HDF5Dataset.update``.)

    Args:
        p (float): Initial value of the probability that is updated and passed to ``dataset.update`` each epoch.
    """
    print('[INFO]: Starting test ... \n')
    training_transform_dict = {
        "Resample": {
            "wspace": 0.75,
            "hspace": 0.75
        },
        "CenterCrop": {
            "size": [48, 48]
        },
        "NumpyToTensor": {}
    }

    transform_lst, _ = imed_transforms.prepare_transforms(training_transform_dict)

    roi_params = {"suffix": "_seg-manual", "slice_filter_roi": None}

    train_lst = ['sub-unf01']
    contrasts = ['T1w', 'T2w', 'T2star']

    print('[INFO]: Creating dataset ...\n')
    model_params = {
        "name": "HeMISUnet",
        "dropout_rate": 0.3,
        "bn_momentum": 0.9,
        "depth": 2,
        "in_channel": 1,
        "out_channel": 1,
        "missing_probability": 0.00001,
        "missing_probability_growth": 0.9,
        "contrasts": ["T1w", "T2w"],
        "ram": False,
        "path_hdf5": 'testing_data/mytestfile.hdf5',
        "csv_path": 'testing_data/hdf5.csv',
        "target_lst": ["T2w"],
        "roi_lst": ["T2w"]
    }
    contrast_params = {
        "contrast_lst": ['T1w', 'T2w', 'T2star'],
        "balance": {}
    }
    dataset = imed_adaptative.HDF5Dataset(root_dir=PATH_BIDS,
                                          subject_lst=train_lst,
                                          model_params=model_params,
                                          contrast_params=contrast_params,
                                          target_suffix=["_lesion-manual"],
                                          slice_axis=2,
                                          transform=transform_lst,
                                          metadata_choice=False,
                                          dim=2,
                                          slice_filter_fn=imed_loader_utils.SliceFilter(filter_empty_input=True,
                                                                                        filter_empty_mask=True),
                                          roi_params=roi_params)

    dataset.load_into_ram(['T1w', 'T2w', 'T2star'])
    print("[INFO]: Dataset RAM status:")
    print(dataset.status)
    print("[INFO]: In memory Dataframe:")
    print(dataset.dataframe)

    # TODO
    # ds_train.filter_roi(nb_nonzero_thr=10)

    def make_train_loader():
        """Build the shuffled training DataLoader over the current state of `dataset`."""
        return DataLoader(dataset, batch_size=BATCH_SIZE,
                          shuffle=True, pin_memory=True,
                          collate_fn=imed_loader_utils.imed_collate,
                          num_workers=1)

    train_loader = make_train_loader()

    model = models.HeMISUnet(contrasts=contrasts,
                             depth=3,
                             drop_rate=DROPOUT,
                             bn_momentum=BN)

    print(model)
    cuda_available = torch.cuda.is_available()

    if cuda_available:
        torch.cuda.set_device(GPU_NUMBER)
        print("Using GPU number {}".format(GPU_NUMBER))
        model.cuda()

    # Initialing Optimizer and scheduler
    step_scheduler_batch = False
    optimizer = optim.Adam(model.parameters(), lr=INIT_LR)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, N_EPOCHS)

    # Durations (in seconds) of each stage, collected over all epochs / batches.
    load_lst, reload_lst, pred_lst, opt_lst, schedul_lst, init_lst, gen_lst = [], [], [], [], [], [], []

    # perf_counter is monotonic and high-resolution, unlike time.time(), so it is the right clock for durations.
    for epoch in tqdm(range(1, N_EPOCHS + 1), desc="Training"):
        start_time = time.perf_counter()

        start_init = time.perf_counter()
        model.train()
        init_lst.append(time.perf_counter() - start_init)

        start_gen = 0.0
        for i, batch in enumerate(train_loader):
            # Time the loader spent producing this batch, measured from the end of the previous step.
            if i > 0:
                gen_lst.append(time.perf_counter() - start_gen)

            start_load = time.perf_counter()
            input_samples, gt_samples = imed_utils.unstack_tensors(batch["input"]), batch["gt"]

            print(batch["input_metadata"][0][0]["missing_mod"])
            missing_mod = imed_training.get_metadata(batch["input_metadata"], model_params)

            print("Number of missing contrasts = {}."
                  .format(len(input_samples) * len(input_samples[0]) - missing_mod.sum()))
            print("len input = {}".format(len(input_samples)))
            print("Batch = {}, {}".format(input_samples[0].shape, gt_samples[0].shape))

            if cuda_available:
                var_input = imed_utils.cuda(input_samples)
                var_gt = imed_utils.cuda(gt_samples, non_blocking=True)
            else:
                var_input = input_samples
                var_gt = gt_samples

            load_lst.append(time.perf_counter() - start_load)

            start_pred = time.perf_counter()
            preds = model(var_input, missing_mod)
            pred_lst.append(time.perf_counter() - start_pred)

            start_opt = time.perf_counter()
            # NOTE(review): the DiceLoss output is negated here. If DiceLoss already returns the negative Dice
            # coefficient (as a loss to minimize), this minimizes Dice instead -- confirm the intended sign.
            loss = - losses.DiceLoss()(preds, var_gt)

            optimizer.zero_grad()
            loss.backward()

            optimizer.step()
            if step_scheduler_batch:
                scheduler.step()

            opt_lst.append(time.perf_counter() - start_opt)

            start_gen = time.perf_counter()

        start_schedul = time.perf_counter()
        if not step_scheduler_batch:
            scheduler.step()
        schedul_lst.append(time.perf_counter() - start_schedul)

        start_reload = time.perf_counter()
        print("[INFO]: Updating Dataset")
        p = p ** (2 / 3)
        dataset.update(p=p)
        print("[INFO]: Reloading dataset")
        train_loader = make_train_loader()
        reload_lst.append(time.perf_counter() - start_reload)

        total_time = time.perf_counter() - start_time
        tqdm.write("Epoch {} took {:.2f} seconds.".format(epoch, total_time))

    print('Mean SD init {} -- {}'.format(np.mean(init_lst), np.std(init_lst)))
    print('Mean SD load {} -- {}'.format(np.mean(load_lst), np.std(load_lst)))
    print('Mean SD reload {} -- {}'.format(np.mean(reload_lst), np.std(reload_lst)))
    print('Mean SD pred {} -- {}'.format(np.mean(pred_lst), np.std(pred_lst)))
    print('Mean SD opt {} --  {}'.format(np.mean(opt_lst), np.std(opt_lst)))
    print('Mean SD gen {} -- {}'.format(np.mean(gen_lst), np.std(gen_lst)))
    print('Mean SD scheduler {} -- {}'.format(np.mean(schedul_lst), np.std(schedul_lst)))
Ejemplo n.º 3
0
def test_hdf5(download_data_testing_test_files, loader_parameters):
    """Convert a BIDS subject to HDF5, build the HDF5 dataframe and dataset, and load one batch from it.

    The HDF5 architecture, the generated dataframe and the dataset's in-RAM status are printed along the way.
    Pytest test functions must return None, so this test does not return a value.
    """
    print('[INFO]: Starting test ... \n')

    bids_df = imed_loader_utils.BidsDataframe(loader_parameters,
                                              __tmp_dir__,
                                              derivatives=True)

    contrast_params = loader_parameters["contrast_params"]
    target_suffix = loader_parameters["target_suffix"]
    roi_params = loader_parameters["roi_params"]

    train_lst = ['sub-unf01']

    # Output files shared by the HDF5 conversion, the dataframe and the dataset.
    path_hdf5 = os.path.join(__data_testing_dir__, 'mytestfile.hdf5')
    path_csv = os.path.join(__data_testing_dir__, 'hdf5.csv')

    training_transform_dict = {
        "Resample": {
            "wspace": 0.75,
            "hspace": 0.75
        },
        "CenterCrop": {
            "size": [48, 48]
        },
        "NumpyToTensor": {}
    }
    transform_lst, _ = imed_transforms.prepare_transforms(training_transform_dict)

    bids_to_hdf5 = imed_adaptative.BIDStoHDF5(
        bids_df=bids_df,
        subject_file_lst=train_lst,
        path_hdf5=path_hdf5,
        target_suffix=target_suffix,
        roi_params=roi_params,
        contrast_lst=contrast_params["contrast_lst"],
        metadata_choice="contrast",
        transform=transform_lst,
        contrast_balance={},
        slice_axis=2,
        slice_filter_fn=imed_loader_utils.SliceFilter(filter_empty_input=True,
                                                      filter_empty_mask=True))

    # Checking architecture
    def print_attrs(name, obj):
        """Print the name, type and attributes of one HDF5 object (callback for `visititems`)."""
        print("\nName of the object: {}".format(name))
        print("Type: {}".format(type(obj)))
        print("Including the following attributes:")
        for key, val in obj.attrs.items():
            print("    %s: %s" % (key, val))

    print('\n[INFO]: HDF5 architecture:')
    with h5py.File(bids_to_hdf5.path_hdf5, "a") as hdf5_file:
        hdf5_file.visititems(print_attrs)
        print('\n[INFO]: HDF5 file successfully generated.')
        print('[INFO]: Generating dataframe ...\n')

        df = imed_adaptative.Dataframe(hdf5_file=hdf5_file,
                                       contrasts=['T1w', 'T2w', 'T2star'],
                                       path=path_csv,
                                       target_suffix=['T1w', 'T2w', 'T2star'],
                                       roi_suffix=['T1w', 'T2w', 'T2star'],
                                       dim=2,
                                       filter_slices=True)

        print(df.df)

        print('\n[INFO]: Dataframe successfully generated. ')
        print('[INFO]: Creating dataset ...\n')

        model_params = {
            "name": "HeMISUnet",
            "dropout_rate": 0.3,
            "bn_momentum": 0.9,
            "depth": 2,
            "in_channel": 1,
            "out_channel": 1,
            "missing_probability": 0.00001,
            "missing_probability_growth": 0.9,
            "contrasts": ["T1w", "T2w"],
            "ram": False,
            "path_hdf5": path_hdf5,
            "csv_path": path_csv,
            "target_lst": ["T2w"],
            "roi_lst": ["T2w"]
        }

        dataset = imed_adaptative.HDF5Dataset(
            bids_df=bids_df,
            subject_file_lst=train_lst,
            target_suffix=target_suffix,
            slice_axis=2,
            model_params=model_params,
            contrast_params=contrast_params,
            transform=transform_lst,
            metadata_choice=False,
            dim=2,
            slice_filter_fn=imed_loader_utils.SliceFilter(
                filter_empty_input=True, filter_empty_mask=True),
            roi_params=roi_params)

        dataset.load_into_ram(['T1w', 'T2w', 'T2star'])
        print("Dataset RAM status:")
        print(dataset.status)
        print("In memory Dataframe:")
        print(dataset.dataframe)
        print('\n[INFO]: Test passed successfully. ')

        print("\n[INFO]: Starting loader test ...")

        device = torch.device("cuda:" + str(GPU_ID) if torch.cuda.is_available() else "cpu")
        cuda_available = torch.cuda.is_available()
        if cuda_available:
            torch.cuda.set_device(device)
            print("Using GPU ID {}".format(device))

        train_loader = DataLoader(dataset,
                                  batch_size=BATCH_SIZE,
                                  shuffle=False,
                                  pin_memory=True,
                                  collate_fn=imed_loader_utils.imed_collate,
                                  num_workers=1)

        # Only the first batch is checked.
        for batch in train_loader:
            input_samples, gt_samples = batch["input"], batch["gt"]
            print("len input = {}".format(len(input_samples)))
            print("Batch = {}, {}".format(input_samples[0].shape, gt_samples[0].shape))

            # Exercise the device transfer of inputs and ground truths.
            if cuda_available:
                var_input = imed_utils.cuda(input_samples)
                var_gt = imed_utils.cuda(gt_samples, non_blocking=True)
            else:
                var_input = input_samples
                var_gt = gt_samples

            break
        print(
            "Congrats your dataloader works! You can go home now and get a beer."
        )
Ejemplo n.º 4
0
def load_dataset(data_list,
                 bids_path,
                 transforms_params,
                 model_params,
                 target_suffix,
                 roi_params,
                 contrast_params,
                 slice_filter_params,
                 slice_axis,
                 multichannel,
                 dataset_type="training",
                 requires_undo=False,
                 metadata_type=None,
                 object_detection_params=None,
                 soft_gt=False,
                 device=None,
                 cuda_available=None,
                 **kwargs):
    """Get loader appropriate loader according to model type. Available loaders are Bids3DDataset for 3D data,
    BidsDataset for 2D data and HDF5Dataset for HeMIS.

    Args:
        data_list (list): Subject names list.
        bids_path (str): Path to the BIDS dataset.
        transforms_params (dict): Dictionary containing transformations for "training", "validation", "testing" (keys),
            eg output of imed_transforms.get_subdatasets_transforms.
        model_params (dict): Dictionary containing model parameters. A model named "Modified3DUNet", or one with
            ``is_2d`` set to False, is loaded as 3D data.
        target_suffix (list of str): List of suffixes for target masks.
        roi_params (dict): Contains ROI related parameters. It is not modified by this function.
        contrast_params (dict): Contains image contrasts related parameters.
        slice_filter_params (dict): Contains slice_filter parameters, see :doc:`configuration_file` for more details.
        slice_axis (string): Choice between "axial", "sagittal", "coronal" ; controls the axis used to extract the 2D
            data.
        multichannel (bool): If True, the input contrasts are combined as input channels for the model. Otherwise, each
            contrast is processed individually (ie different sample / tensor).
        metadata_type (str): Choice between None, "mri_params", "contrasts".
        dataset_type (str): Choice between "training", "validation" or "testing".
        requires_undo (bool): If True, the transformations without undo_transform will be discarded.
        object_detection_params (dict): Object dection parameters.
        soft_gt (bool): If True, ground truths will be converted to float32, otherwise to uint8 and binarized
            (to save memory).
        device: Device forwarded to the SliceFilter of the 2D and HDF5 datasets.
        cuda_available (bool): Forwarded to the SliceFilter of the 2D and HDF5 datasets.
        **kwargs: Extra keyword arguments are accepted and ignored.

    Returns:
        Bids3DDataset, HDF5Dataset or BidsDataset, depending on the model type.

    Note: For more details on the parameters transform_params, target_suffix, roi_params, contrast_params,
    slice_filter_params and object_detection_params see :doc:`configuration_file`.
    """
    # Compose transforms
    transform_lst, _ = imed_transforms.prepare_transforms(
        copy.deepcopy(transforms_params), requires_undo)

    # If ROICrop is not part of the transforms, then enforce no slice filtering based on ROI data.
    # Work on a shallow copy: the caller's dict is typically shared between the training / validation /
    # testing calls, and mutating it would leak this override into them.
    if 'ROICrop' not in transforms_params:
        roi_params = dict(roi_params, slice_filter_roi=None)

    is_3d = model_params["name"] == "Modified3DUNet" or not model_params.get('is_2d', True)

    if is_3d:
        dataset = Bids3DDataset(
            bids_path,
            subject_lst=data_list,
            target_suffix=target_suffix,
            roi_params=roi_params,
            contrast_params=contrast_params,
            metadata_choice=metadata_type,
            slice_axis=imed_utils.AXIS_DCT[slice_axis],
            transform=transform_lst,
            multichannel=multichannel,
            model_params=model_params,
            object_detection_params=object_detection_params,
            soft_gt=soft_gt)

    elif model_params["name"] == "HeMISUnet":
        dataset = imed_adaptative.HDF5Dataset(
            root_dir=bids_path,
            subject_lst=data_list,
            model_params=model_params,
            contrast_params=contrast_params,
            target_suffix=target_suffix,
            slice_axis=imed_utils.AXIS_DCT[slice_axis],
            transform=transform_lst,
            metadata_choice=metadata_type,
            slice_filter_fn=imed_loader_utils.SliceFilter(
                **slice_filter_params,
                device=device,
                cuda_available=cuda_available),
            roi_params=roi_params,
            object_detection_params=object_detection_params,
            soft_gt=soft_gt)
    else:
        # Task selection
        task = imed_utils.get_task(model_params["name"])

        dataset = BidsDataset(bids_path,
                              subject_lst=data_list,
                              target_suffix=target_suffix,
                              roi_params=roi_params,
                              contrast_params=contrast_params,
                              metadata_choice=metadata_type,
                              slice_axis=imed_utils.AXIS_DCT[slice_axis],
                              transform=transform_lst,
                              multichannel=multichannel,
                              slice_filter_fn=imed_loader_utils.SliceFilter(
                                  **slice_filter_params,
                                  device=device,
                                  cuda_available=cuda_available),
                              soft_gt=soft_gt,
                              object_detection_params=object_detection_params,
                              task=task)
        dataset.load_filenames()

    # Report using the same 2D/3D decision as the dataset selection above.
    if is_3d:
        print("Loaded {} {} volumes of shape {} for the {} set.".format(
            len(dataset), slice_axis, model_params.get("length_3D"), dataset_type))
    else:
        print("Loaded {} {} slices for the {} set.".format(
            len(dataset), slice_axis, dataset_type))

    return dataset
Ejemplo n.º 5
0
def test_image_orientation(download_data_testing_test_files,
                           loader_parameters):
    """Check that transforms, undo transforms and 3D reconstruction preserve the ground-truth orientation.

    For 2D and 3D datasets and every slice axis (0, 1, 2), the ground truth is passed through the training
    transforms and then undone (in 2D, the slices are re-stacked by slice index). The thresholded result is
    compared with Dice >= 0.8 against the ground truth in HWD orientation, in RAS orientation (after
    `orient_img_ras`) and in its original orientation (after `reorient_image`).
    """
    device = torch.device("cuda:" + str(GPU_ID) if torch.cuda.is_available() else "cpu")
    cuda_available = torch.cuda.is_available()
    if cuda_available:
        torch.cuda.set_device(device)
        logger.info(f"Using GPU ID {device}")

    bids_df = BidsDataframe(loader_parameters, __tmp_dir__, derivatives=True)

    contrast_params = loader_parameters["contrast_params"]
    target_suffix = loader_parameters["target_suffix"]

    train_lst = ['sub-unf01_T1w.nii.gz']

    training_transform_dict = {
        "Resample": {
            "wspace": 1.5,
            "hspace": 1,
            "dspace": 3
        },
        "CenterCrop": {
            "size": [176, 128, 160]
        },
        "NormalizeInstance": {
            "applied_to": ['im']
        }
    }

    transform_lst, training_undo_transform = imed_transforms.prepare_transforms(training_transform_dict)

    model_params = {
        "name": "Modified3DUNet",
        "dropout_rate": 0.3,
        "bn_momentum": 0.9,
        "depth": 2,
        "in_channel": 1,
        "out_channel": 1,
        "length_3D": [176, 128, 160],
        "stride_3D": [176, 128, 160],
        "attention": False,
        "n_filters": 8
    }

    for dim in ['2d', '3d']:
        for slice_axis in [0, 1, 2]:
            if dim == '2d':
                ds = BidsDataset(bids_df=bids_df,
                                 subject_file_lst=train_lst,
                                 target_suffix=target_suffix,
                                 contrast_params=contrast_params,
                                 model_params=model_params,
                                 metadata_choice=False,
                                 slice_axis=slice_axis,
                                 transform=transform_lst,
                                 multichannel=False)
                ds.load_filenames()
            else:
                ds = Bids3DDataset(bids_df=bids_df,
                                   subject_file_lst=train_lst,
                                   target_suffix=target_suffix,
                                   model_params=model_params,
                                   contrast_params=contrast_params,
                                   metadata_choice=False,
                                   slice_axis=slice_axis,
                                   transform=transform_lst,
                                   multichannel=False)

            loader = DataLoader(ds,
                                batch_size=1,
                                shuffle=True,
                                pin_memory=True,
                                collate_fn=imed_loader_utils.imed_collate,
                                num_workers=1)

            input_filename, gt_filename, _, metadata = ds.filename_pairs[0]
            segpair = SegmentationPair(input_filename,
                                       gt_filename,
                                       metadata=metadata,
                                       slice_axis=slice_axis)
            nib_original = nib.load(gt_filename[0])
            # Get image with original, ras and hwd orientations
            input_init = nib_original.get_fdata()
            input_ras = nib.as_closest_canonical(nib_original).get_fdata()
            _, gt = segpair.get_pair_data()
            input_hwd = gt[0]

            pred_tmp_lst, z_tmp_lst = [], []
            for i, batch in enumerate(loader):
                for smp_idx in range(len(batch['gt'])):
                    # undo transformations (same call for 2D slices and 3D volumes)
                    preds_idx_undo, metadata_idx = training_undo_transform(
                        batch["gt"][smp_idx],
                        batch["gt_metadata"][smp_idx],
                        data_type='gt')

                    if dim == '2d':
                        # add new sample to pred_tmp_lst, with its slice index in the original volume
                        pred_tmp_lst.append(preds_idx_undo[0])
                        z_tmp_lst.append(int(batch['input_metadata'][smp_idx][0]['slice_index']))

                    fname_ref = metadata_idx[0]['gt_filenames'][0]

                    if (pred_tmp_lst and i == len(loader) - 1) or dim == '3d':
                        # save the completely processed file as a nii
                        nib_ref = nib.load(fname_ref)
                        nib_ref_can = nib.as_closest_canonical(nib_ref)

                        if dim == '2d':
                            # The loader shuffles, so re-order the slices by their index before stacking.
                            pred_by_z = dict(zip(z_tmp_lst, pred_tmp_lst))
                            n_slices = nib_ref_can.header.get_data_shape()[slice_axis]
                            arr = np.stack([pred_by_z[z] for z in range(n_slices)], axis=-1)
                        else:
                            arr = np.array(preds_idx_undo[0])

                        # verify image after transform, undo transform and 3D reconstruction
                        input_hwd_2 = imed_postpro.threshold_predictions(arr)
                        # Some difference are generated due to transform and undo transform
                        # (e.i. Resample interpolation)
                        assert imed_metrics.dice_score(input_hwd_2, input_hwd) >= 0.8
                        input_ras_2 = imed_loader_utils.orient_img_ras(input_hwd_2, slice_axis)
                        assert imed_metrics.dice_score(input_ras_2, input_ras) >= 0.8
                        input_init_2 = imed_loader_utils.reorient_image(
                            input_hwd_2, slice_axis, nib_ref, nib_ref_can)
                        assert imed_metrics.dice_score(input_init_2, input_init) >= 0.8

                        # re-init pred_stack_lst
                        pred_tmp_lst, z_tmp_lst = [], []
Ejemplo n.º 6
0
def segment_volume(folder_model: str,
                   fname_images: list,
                   gpu_id: int = 0,
                   options: dict = None):
    """Segment an image.

    Segment the image(s) in `fname_images` using a pre-trained model (`folder_model`). If `options` provides a
    prior (`fname_prior`) and the model configuration defines an ROI suffix, the prior is used as a region of
    interest to crop the image prior to segmenting it.

    Args:
        folder_model (str): foldername which contains
            (1) the model ('folder_model/folder_model.pt') to use
            (2) its configuration file ('folder_model/folder_model.json') used for the training,
            see https://github.com/neuropoly/ivadomed/wiki/configuration-file
        fname_images (list): list of image filenames (e.g. .nii.gz) to segment. Multichannel models require multiple
            images to segment, i.e., len(fname_images) > 1.
        gpu_id (int): Number representing gpu number if available. Currently does NOT support multiple GPU segmentation.
        options (dict): Optional postprocessing steps ('binarize_prediction', 'keep_largest', 'fill_holes',
            'remove_small') and prior filename ('fname_prior'), which is an image filename (e.g., .nii.gz)
            containing processing information (e.g., spinal cord segmentation, spinal location, MS lesion
            classification or spinal cord centerline), used to crop the image prior to segment it if provided.
            The segmentation is not performed on the slices that are empty in this image.

    Returns:
        list: List of nibabel objects containing the soft segmentation(s), one per prediction class.
        list: List of target suffix associated with each prediction in `pred_list`.
        Both lists are empty if the data loader yields no batch.
    """

    # Check if model folder exists and get filenames to be stored as string
    fname_model: str
    fname_model_metadata: str
    fname_model, fname_model_metadata = imed_models.get_model_filenames(
        folder_model)

    # Load model training config
    context = imed_config_manager.ConfigurationManager(
        fname_model_metadata).get_config()

    # If any of these keys is present in `options`, they are applied to `context`
    # through set_postprocessing_options.
    postpro_list = [
        'binarize_prediction', 'keep_largest', 'fill_holes', 'remove_small'
    ]
    if options is not None and any(pp in options for pp in postpro_list):
        set_postprocessing_options(options, context)

    # LOADER
    loader_params = context["loader_parameters"]
    slice_axis = imed_utils.AXIS_DCT[loader_params['slice_axis']]
    metadata = {}
    fname_roi = None
    fname_prior = options['fname_prior'] if (options is not None) and (
        'fname_prior' in options) else None
    if fname_prior is not None:
        # The prior is used as ROI only when the config defines an ROI suffix.
        if 'roi_params' in loader_params and loader_params['roi_params'][
                'suffix'] is not None:
            fname_roi = fname_prior
        # TRANSFORMATIONS
        metadata = process_transformations(context, fname_roi, fname_prior,
                                           metadata, slice_axis, fname_images)

    # Compose transforms
    _, _, transform_test_params = imed_transforms.get_subdatasets_transforms(
        context["transformation"])

    tranform_lst, undo_transforms = imed_transforms.prepare_transforms(
        transform_test_params)

    # Force filter_empty_mask to False if fname_roi = None
    if fname_roi is None and 'filter_empty_mask' in loader_params[
            "slice_filter_params"]:
        logger.warning(
            "fname_roi has not been specified, then the entire volume is processed."
        )
        loader_params["slice_filter_params"]["filter_empty_mask"] = False

    filename_pairs = [(fname_images, None, fname_roi,
                       metadata if isinstance(metadata, list) else [metadata])]

    # NOTE(review): when only `default_model.is_2d` is False, the 3D branch below still
    # reads context["Modified3DUNet"]; confirm that section is always present for 3D models.
    kernel_3D = bool('Modified3DUNet' in context and context['Modified3DUNet']['applied']) or \
                not context['default_model']['is_2d']
    if kernel_3D:
        ds = imed_loader.MRI3DSubVolumeSegmentationDataset(
            filename_pairs,
            transform=tranform_lst,
            length=context["Modified3DUNet"]["length_3D"],
            stride=context["Modified3DUNet"]["stride_3D"])
        logger.info(
            f"Loaded {len(ds)} {loader_params['slice_axis']} volumes of shape "
            f"{context['Modified3DUNet']['length_3D']}.")
    else:
        ds = imed_loader.MRI2DSegmentationDataset(
            filename_pairs,
            slice_axis=slice_axis,
            cache=True,
            transform=tranform_lst,
            slice_filter_fn=imed_loader_utils.SliceFilter(
                **loader_params["slice_filter_params"]))
        ds.load_filenames()
        logger.info(f"Loaded {len(ds)} {loader_params['slice_axis']} slices.")

    model_params = {}
    if 'FiLMedUnet' in context and context['FiLMedUnet']['applied']:
        onehotencoder = get_onehotencoder(context, folder_model, options, ds)
        model_params.update({
            "name": 'FiLMedUnet',
            "film_onehotencoder": onehotencoder,
            # Total number of one-hot categories across all encoded features.
            "n_metadata": sum(
                len(categories) for categories in onehotencoder.categories_)
        })

    # Data Loader
    data_loader = DataLoader(
        ds,
        batch_size=context["training_parameters"]["batch_size"],
        shuffle=False,
        pin_memory=True,
        collate_fn=imed_loader_utils.imed_collate,
        num_workers=0)

    # Loop across batches
    preds_list, slice_idx_list = [], []
    last_sample_bool, weight_matrix, volume = False, None, None
    # Defined up-front so an empty loader returns empty lists instead of raising UnboundLocalError.
    pred_list, target_list = [], []
    for i_batch, batch in enumerate(data_loader):
        preds = get_preds(context, fname_model, model_params, gpu_id, batch)

        # Set datatype to gt since prediction should be processed the same way as gt
        for b in batch['input_metadata']:
            for modality in b:
                modality['data_type'] = 'gt'

        # Reconstruct 3D object
        pred_list, target_list, last_sample_bool, weight_matrix, volume = reconstruct_3d_object(
            context, batch, undo_transforms, preds, preds_list, kernel_3D,
            slice_axis, slice_idx_list, data_loader, fname_images, i_batch,
            last_sample_bool, weight_matrix, volume)

    return pred_list, target_list
Ejemplo n.º 7
0
def segment_volume(folder_model: str,
                   fname_images: list,
                   gpu_id: int = 0,
                   options: dict = None):
    """Segment an image.

    Segment the image(s) in `fname_images` using a pre-trained model (`folder_model`). If `options` provides a
    prior (`fname_prior`) and the model configuration defines an ROI suffix, the prior is used as a region of
    interest to crop the image prior to segmenting it.

    Args:
        folder_model (str): foldername which contains
            (1) the model ('folder_model/folder_model.pt') to use
            (2) its configuration file ('folder_model/folder_model.json') used for the training,
            see https://github.com/neuropoly/ivadomed/wiki/configuration-file
        fname_images (list): list of image filenames (e.g. .nii.gz) to segment. Multichannel models require multiple
            images to segment, i.e., len(fname_images) > 1.
        gpu_id (int): Number representing gpu number if available. Currently does NOT support multiple GPU segmentation.
        options (dict): This can optionally contain any of the following key-value pairs:

            * 'binarize_prediction': (float) Binarize segmentation with specified threshold. \
                Predictions below the threshold become 0, and predictions above or equal to \
                threshold become 1. Set to -1 for no thresholding (i.e., soft segmentation).
            * 'binarize_maxpooling': (bool) Binarize by setting to 1 the voxel having the maximum prediction across \
                all classes. Useful for multiclass models.
            * 'fill_holes': (bool) Fill small holes in the segmentation.
            * 'keep_largest': (bool) Keep the largest connected-object for each class from the output segmentation.
            * 'remove_small': (list of str) Minimal object size to keep with unit (mm3 or vox). A single value can be provided \
                              or one value per prediction class. Single value example: ["1mm3"], ["5vox"]. Multiple values \
                              example: ["10", "20", "10vox"] (remove objects smaller than 10 voxels for class 1 and 3, \
                              and smaller than 20 voxels for class 2).
            * 'pixel_size': (list of float) List of microscopy pixel size in micrometers. \
                            Length equals 2 [PixelSizeX, PixelSizeY] for 2D or 3 [PixelSizeX, PixelSizeY, PixelSizeZ] for 3D, \
                            where X is the width, Y the height and Z the depth of the image.
            * 'pixel_size_units': (str) Units of pixel size (Must be either "mm", "um" or "nm")
            * 'overlap_2D': (list of int) List of overlaps in pixels for 2D patching. Length equals 2 [OverlapX, OverlapY], \
                            where X is the width and Y the height of the image. The caller's list is not modified.
            * 'metadata': (str) Film metadata.
            * 'fname_prior': (str) An image filename (e.g., .nii.gz) containing processing information \
                (e.g., spinal cord segmentation, spinal location or MS lesion classification, spinal cord centerline), \
                used to crop the image prior to segment it if provided. \
                The segmentation is not performed on the slices that are empty in this image.

    Returns:
        list, list: List of nibabel objects containing the soft segmentation(s), one per prediction class, \
            List of target suffix associated with each prediction in `pred_list`. \
            Both lists are empty if the data loader yields no batch.

    """

    # Check if model folder exists and get filenames to be stored as string
    fname_model: str
    fname_model_metadata: str
    fname_model, fname_model_metadata = imed_models.get_model_filenames(
        folder_model)

    # Load model training config
    context = imed_config_manager.ConfigurationManager(
        fname_model_metadata).get_config()

    # If any of these keys is present in `options`, they are applied to `context`
    # through set_postprocessing_options.
    postpro_list = [
        'binarize_prediction', 'binarize_maxpooling', 'keep_largest',
        'fill_holes', 'remove_small'
    ]
    if options is not None and any(pp in options for pp in postpro_list):
        set_postprocessing_options(options, context)

    # LOADER
    loader_params = context[ConfigKW.LOADER_PARAMETERS]
    slice_axis = imed_utils.AXIS_DCT[loader_params[LoaderParamsKW.SLICE_AXIS]]
    metadata = {}
    fname_roi = None

    if (options is not None) and (OptionKW.FNAME_PRIOR in options):
        fname_prior = options.get(OptionKW.FNAME_PRIOR)
    else:
        fname_prior = None

    if fname_prior is not None:
        # The prior is used as ROI only when the config defines an ROI suffix.
        if LoaderParamsKW.ROI_PARAMS in loader_params and loader_params[
                LoaderParamsKW.ROI_PARAMS][ROIParamsKW.SUFFIX] is not None:
            fname_roi = fname_prior
        # TRANSFORMATIONS
        metadata = process_transformations(context, fname_roi, fname_prior,
                                           metadata, slice_axis, fname_images)

    # Compose transforms
    _, _, transform_test_params = imed_transforms.get_subdatasets_transforms(
        context[ConfigKW.TRANSFORMATION])

    tranform_lst, undo_transforms = imed_transforms.prepare_transforms(
        transform_test_params)

    # Force filter_empty_mask to False if fname_roi = None
    if fname_roi is None and SliceFilterParamsKW.FILTER_EMPTY_MASK in loader_params[
            LoaderParamsKW.SLICE_FILTER_PARAMS]:
        logger.warning(
            "fname_roi has not been specified, then the entire volume is processed."
        )
        loader_params[LoaderParamsKW.SLICE_FILTER_PARAMS][
            SliceFilterParamsKW.FILTER_EMPTY_MASK] = False

    kernel_3D = bool(ConfigKW.MODIFIED_3D_UNET in context and context[ConfigKW.MODIFIED_3D_UNET][ModelParamsKW.APPLIED]) or \
                not context[ConfigKW.DEFAULT_MODEL][ModelParamsKW.IS_2D]

    # Assign length_2D and stride_2D for 2D patching
    length_2D = context[ConfigKW.DEFAULT_MODEL][ModelParamsKW.LENGTH_2D] if \
        ModelParamsKW.LENGTH_2D in context[ConfigKW.DEFAULT_MODEL] else []
    stride_2D = context[ConfigKW.DEFAULT_MODEL][ModelParamsKW.STRIDE_2D] if \
        ModelParamsKW.STRIDE_2D in context[ConfigKW.DEFAULT_MODEL] else []
    is_2d_patch = bool(length_2D)

    if is_2d_patch and (options is not None) and (OptionKW.OVERLAP_2D
                                                  in options):
        # Work on a copy: swapping in place would flip the caller's `options` list
        # on every call that reuses the same `options` dict.
        overlap_2D = list(options.get(OptionKW.OVERLAP_2D))
        # Swap OverlapX and OverlapY resulting in an array in order [OverlapY, OverlapX]
        # to match length_2D and stride_2D in [Height, Width] orientation.
        overlap_2D[1], overlap_2D[0] = overlap_2D[0], overlap_2D[1]
        # Adjust stride_2D with overlap_2D
        stride_2D = [x1 - x2 for (x1, x2) in zip(length_2D, overlap_2D)]

    # Add microscopy pixel size and pixel size units from options to metadata for filenames_pairs
    # NOTE(review): this assumes `metadata` is a dict; process_transformations may return a
    # list (see the isinstance check below), in which case item assignment by key would fail.
    if (options is not None) and (OptionKW.PIXEL_SIZE in options):
        metadata[MetadataKW.PIXEL_SIZE] = options.get(OptionKW.PIXEL_SIZE)
    if (options is not None) and (OptionKW.PIXEL_SIZE_UNITS in options):
        metadata[MetadataKW.PIXEL_SIZE_UNITS] = options.get(
            OptionKW.PIXEL_SIZE_UNITS)

    filename_pairs = [(fname_images, None, fname_roi,
                       metadata if isinstance(metadata, list) else [metadata])]

    if kernel_3D:
        ds = MRI3DSubVolumeSegmentationDataset(
            filename_pairs,
            transform=tranform_lst,
            length=context[ConfigKW.MODIFIED_3D_UNET][ModelParamsKW.LENGTH_3D],
            stride=context[ConfigKW.MODIFIED_3D_UNET][ModelParamsKW.STRIDE_3D],
            slice_axis=slice_axis)
        logger.info(
            f"Loaded {len(ds)} {loader_params[LoaderParamsKW.SLICE_AXIS]} volumes of shape "
            f"{context[ConfigKW.MODIFIED_3D_UNET][ModelParamsKW.LENGTH_3D]}.")
    else:
        ds = MRI2DSegmentationDataset(
            filename_pairs,
            length=length_2D,
            stride=stride_2D,
            slice_axis=slice_axis,
            nibabel_cache=True,
            transform=tranform_lst,
            slice_filter_fn=SliceFilter(
                **loader_params[LoaderParamsKW.SLICE_FILTER_PARAMS]))
        ds.load_filenames()
        if is_2d_patch:
            logger.info(
                f"Loaded {len(ds)} {loader_params[LoaderParamsKW.SLICE_AXIS]} patches of shape {length_2D}."
            )
        else:
            logger.info(
                f"Loaded {len(ds)} {loader_params[LoaderParamsKW.SLICE_AXIS]} slices."
            )

    model_params = {}
    if ConfigKW.FILMED_UNET in context and context[ConfigKW.FILMED_UNET][
            ModelParamsKW.APPLIED]:
        onehotencoder = get_onehotencoder(context, folder_model, options, ds)
        model_params.update({
            ModelParamsKW.NAME: ConfigKW.FILMED_UNET,
            ModelParamsKW.FILM_ONEHOTENCODER: onehotencoder,
            # Total number of one-hot categories across all encoded features.
            ModelParamsKW.N_METADATA: sum(
                len(categories) for categories in onehotencoder.categories_)
        })

    # Data Loader
    data_loader = DataLoader(ds,
                             batch_size=context[ConfigKW.TRAINING_PARAMETERS][
                                 TrainingParamsKW.BATCH_SIZE],
                             shuffle=False,
                             pin_memory=True,
                             collate_fn=imed_loader_utils.imed_collate,
                             num_workers=0)

    # Loop across batches
    preds_list, slice_idx_list = [], []
    last_sample_bool, weight_matrix, volume, image = False, None, None, None
    # Defined up-front so an empty loader returns empty lists instead of raising UnboundLocalError.
    pred_list, target_list = [], []
    for i_batch, batch in enumerate(data_loader):
        preds = get_preds(context, fname_model, model_params, gpu_id, batch)

        # Set datatype to gt since prediction should be processed the same way as gt
        for b in batch[MetadataKW.INPUT_METADATA]:
            for modality in b:
                modality['data_type'] = 'gt'

        # Reconstruct 3D object
        pred_list, target_list, last_sample_bool, weight_matrix, volume, image = reconstruct_3d_object(
            context, batch, undo_transforms, preds, preds_list, kernel_3D,
            is_2d_patch, slice_axis, slice_idx_list, data_loader, fname_images,
            i_batch, last_sample_bool, weight_matrix, volume, image)

    return pred_list, target_list
Ejemplo n.º 8
0
def load_dataset(bids_df,
                 data_list,
                 transforms_params,
                 model_params,
                 target_suffix,
                 roi_params,
                 contrast_params,
                 slice_filter_params,
                 patch_filter_params,
                 slice_axis,
                 multichannel,
                 dataset_type="training",
                 requires_undo=False,
                 metadata_type=None,
                 object_detection_params=None,
                 soft_gt=False,
                 device=None,
                 cuda_available=None,
                 is_input_dropout=False,
                 **kwargs):
    """Get loader appropriate loader according to model type.

    A Bids3DDataset is built for Modified3DUNet or any model whose `is_2d` parameter is False; a
    BidsDataset is built otherwise (the former HDF5Dataset/HeMIS branch is disabled, so HeMIS model
    names also get a BidsDataset).

    Args:
        bids_df (BidsDataframe): Object containing dataframe with all BIDS image files and their metadata.
        data_list (list): Subject names list.
        transforms_params (dict): Dictionary containing transformations for "training", "validation", "testing" (keys),
            eg output of imed_transforms.get_subdatasets_transforms.
        model_params (dict): Dictionary containing model parameters.
        target_suffix (list of str): List of suffixes for target masks.
        roi_params (dict): Contains ROI related parameters. NOTE: modified in place — its slice_filter_roi entry
            is set to None when ROICrop is not part of `transforms_params`.
        contrast_params (dict): Contains image contrasts related parameters.
        slice_filter_params (dict): Contains slice_filter_params, see :doc:`configuration_file` for more details.
        patch_filter_params (dict): Contains patch_filter_params, see :doc:`configuration_file` for more details.
        slice_axis (string): Choice between "axial", "sagittal", "coronal" ; controls the axis used to extract the 2D
            data from 3D NifTI files. 2D PNG/TIF/JPG files use default "axial".
        multichannel (bool): If True, the input contrasts are combined as input channels for the model. Otherwise, each
            contrast is processed individually (ie different sample / tensor).
        dataset_type (str): Choice between "training", "validation" or "testing". The patch filter is built with
            is_train=False only for "testing".
        requires_undo (bool): If True, the transformations without undo_transform will be discarded.
        metadata_type (str): Choice between None, "mri_params", "contrasts".
        object_detection_params (dict): Object detection parameters.
        soft_gt (bool): If True, ground truths are not binarized before being fed to the network. Otherwise, ground
            truths are thresholded (0.5) after the data augmentation operations.
        device: Forwarded to SliceFilter (2D datasets only).
        cuda_available: Forwarded to SliceFilter (2D datasets only).
        is_input_dropout (bool): Return input with missing modalities.
        **kwargs: Extra keyword arguments are accepted and ignored.

    Returns:
        Bids3DDataset or BidsDataset

    Note: For more details on the parameters transform_params, target_suffix, roi_params, contrast_params,
    slice_filter_params, patch_filter_params and object_detection_params see :doc:`configuration_file`.
    """

    # Compose transforms (on a deep copy, presumably so the caller's dict is left untouched)
    tranform_lst, _ = imed_transforms.prepare_transforms(
        copy.deepcopy(transforms_params), requires_undo)

    # If ROICrop is not part of the transforms, then enforce no slice filtering based on ROI data.
    if TransformationKW.ROICROP not in transforms_params:
        roi_params[ROIParamsKW.SLICE_FILTER_ROI] = None

    # Computed once so that dataset selection and the log message below agree.
    is_3d = model_params[ModelParamsKW.NAME] == ConfigKW.MODIFIED_3D_UNET \
        or not model_params.get(ModelParamsKW.IS_2D, True)

    if is_3d:
        dataset = Bids3DDataset(
            bids_df=bids_df,
            subject_file_lst=data_list,
            target_suffix=target_suffix,
            roi_params=roi_params,
            contrast_params=contrast_params,
            metadata_choice=metadata_type,
            slice_axis=imed_utils.AXIS_DCT[slice_axis],
            transform=tranform_lst,
            multichannel=multichannel,
            model_params=model_params,
            object_detection_params=object_detection_params,
            soft_gt=soft_gt,
            is_input_dropout=is_input_dropout)
    else:
        # NOTE: the HeMIS / HDF5Dataset branch is disabled, so every non-3D model
        # (including HeMIS names) is loaded as a BidsDataset.
        # Task selection
        task = imed_utils.get_task(model_params[ModelParamsKW.NAME])

        dataset = BidsDataset(
            bids_df=bids_df,
            subject_file_lst=data_list,
            target_suffix=target_suffix,
            roi_params=roi_params,
            contrast_params=contrast_params,
            model_params=model_params,
            metadata_choice=metadata_type,
            slice_axis=imed_utils.AXIS_DCT[slice_axis],
            transform=tranform_lst,
            multichannel=multichannel,
            slice_filter_fn=SliceFilter(**slice_filter_params,
                                        device=device,
                                        cuda_available=cuda_available),
            patch_filter_fn=PatchFilter(**patch_filter_params,
                                        is_train=dataset_type != "testing"),
            soft_gt=soft_gt,
            object_detection_params=object_detection_params,
            task=task,
            is_input_dropout=is_input_dropout)
        dataset.load_filenames()

    if is_3d:
        logger.info(
            f"Loaded {len(dataset)} volumes of shape {dataset.length} for the {dataset_type} set."
        )
    elif model_params[
            ModelParamsKW.NAME] != ConfigKW.HEMIS_UNET and dataset.length:
        logger.info(
            f"Loaded {len(dataset)} {slice_axis} patches of shape {dataset.length} for the {dataset_type} set."
        )
    else:
        logger.info(
            f"Loaded {len(dataset)} {slice_axis} slices for the {dataset_type} set."
        )

    return dataset
Ejemplo n.º 9
0
def segment_volume(folder_model, fname_image, fname_prior=None, gpu_number=0):
    """Segment an image.

    Segment an image (`fname_image`) using a pre-trained model (`folder_model`). If a prior (`fname_prior`) is
    provided and the model configuration defines an ROI suffix, the prior is used as a region of interest to crop
    the image prior to segmenting it.

    Args:
        folder_model (str): foldername which contains
            (1) the model ('folder_model/folder_model.pt') to use
            (2) its configuration file ('folder_model/folder_model.json') used for the training,
            see https://github.com/neuropoly/ivadomed/wiki/configuration-file
        fname_image (str): image filename (e.g. .nii.gz) to segment.
        fname_prior (str): Image filename (e.g. .nii.gz) containing processing information (e.g. spinal cord
            segmentation, spinal location or MS lesion classification)

            e.g. spinal cord centerline, used to crop the image prior to segment it if provided.
            The segmentation is not performed on the slices that are empty in this image.
        gpu_number (int): Number representing gpu number if available.

    Returns:
        nibabelObject: Object containing the soft segmentation.

    Note:
        NOTE(review): `pred_nib` is only assigned inside the batch loop, so an empty data loader
        (e.g. every slice filtered out) ends in UnboundLocalError rather than a clear error.
    """
    # Define device
    cuda_available = torch.cuda.is_available()
    device = torch.device("cpu") if not cuda_available else torch.device(
        "cuda:" + str(gpu_number))

    # Check if model folder exists and get filenames
    fname_model, fname_model_metadata = imed_models.get_model_filenames(
        folder_model)

    # Load model training config
    with open(fname_model_metadata, "r") as fhandle:
        context = json.load(fhandle)

    # LOADER
    loader_params = context["loader_parameters"]
    slice_axis = AXIS_DCT[loader_params['slice_axis']]
    metadata = {}
    fname_roi = None
    if fname_prior is not None:
        # The prior is used as ROI only when the config defines an ROI suffix.
        if 'roi_params' in loader_params and loader_params['roi_params'][
                'suffix'] is not None:
            fname_roi = fname_prior
        # TRANSFORMATIONS
        # If ROI is not provided then force center cropping
        if fname_roi is None and 'ROICrop' in context["transformation"].keys():
            print(
                "\nWARNING: fname_roi has not been specified, then a cropping around the center of the image is performed"
                " instead of a cropping around a Region of Interest.")
            # Rename the 'ROICrop' key to 'CenterCrop', keeping its parameters and the key order.
            context["transformation"] = dict(
                (key, value) if key != 'ROICrop' else ('CenterCrop', value)
                for (key, value) in context["transformation"].items())

        # With an object-detection model configured, bounding_box_prior is given the prior
        # and `metadata` (presumably filled in place with the bounding box — confirm).
        if 'object_detection_params' in context and \
                context['object_detection_params']['object_detection_path'] is not None:
            imed_obj_detect.bounding_box_prior(fname_prior, metadata,
                                               slice_axis)

    # Compose transforms
    _, _, transform_test_params = imed_transforms.get_subdatasets_transforms(
        context["transformation"])

    tranform_lst, undo_transforms = imed_transforms.prepare_transforms(
        transform_test_params)

    # Force filter_empty_mask to False if fname_roi = None
    if fname_roi is None and 'filter_empty_mask' in loader_params[
            "slice_filter_params"]:
        print(
            "\nWARNING: fname_roi has not been specified, then the entire volume is processed."
        )
        loader_params["slice_filter_params"]["filter_empty_mask"] = False

    # Single (inputs, gt, roi, metadata) tuple; there is no ground truth at inference time.
    filename_pairs = [([fname_image], None, fname_roi, [metadata])]

    kernel_3D = bool('UNet3D' in context and context['UNet3D']['applied'])
    if kernel_3D:
        ds = imed_loader.MRI3DSubVolumeSegmentationDataset(
            filename_pairs,
            transform=tranform_lst,
            length=context["UNet3D"]["length_3D"],
            stride=context["UNet3D"]["stride_3D"])
    else:
        ds = imed_loader.MRI2DSegmentationDataset(
            filename_pairs,
            slice_axis=slice_axis,
            cache=True,
            transform=tranform_lst,
            slice_filter_fn=SliceFilter(
                **loader_params["slice_filter_params"]))
        ds.load_filenames()

    if kernel_3D:
        print("\nLoaded {} {} volumes of shape {}.".format(
            len(ds), loader_params['slice_axis'],
            context['UNet3D']['length_3D']))
    else:
        print("\nLoaded {} {} slices.".format(len(ds),
                                              loader_params['slice_axis']))

    # Data Loader
    data_loader = DataLoader(
        ds,
        batch_size=context["training_parameters"]["batch_size"],
        shuffle=False,
        pin_memory=True,
        collate_fn=imed_loader_utils.imed_collate,
        num_workers=0)

    # MODEL
    # Only '.pt' models are loaded here; any other model file is run through onnx_inference below.
    if fname_model.endswith('.pt'):
        model = torch.load(fname_model, map_location=device)
        # Inference time
        model.eval()

    # Loop across batches
    preds_list, slice_idx_list = [], []
    last_sample_bool, volume, weight_matrix = False, None, None
    for i_batch, batch in enumerate(data_loader):
        with torch.no_grad():
            # NOTE(review): the model is mapped to `device` (cuda:<gpu_number>); confirm that
            # the `cuda` helper places the input on that same device when gpu_number != 0.
            img = cuda(batch['input'], cuda_available=cuda_available)
            preds = model(img) if fname_model.endswith(
                '.pt') else onnx_inference(fname_model, img)
            preds = preds.cpu()

        # Set datatype to gt since prediction should be processed the same way as gt
        for modality in batch['input_metadata']:
            modality[0]['data_type'] = 'gt'

        # Reconstruct 3D object
        for i_slice in range(len(preds)):
            if "bounding_box" in batch['input_metadata'][i_slice][0]:
                imed_obj_detect.adjust_undo_transforms(
                    undo_transforms.transforms, batch, i_slice)

            if kernel_3D:
                batch['gt_metadata'] = batch['input_metadata']
                # Note: this rebinds `metadata` (no longer the dict built above).
                preds_undo, metadata, last_sample_bool, volume, weight_matrix = \
                    volume_reconstruction(batch, preds, undo_transforms, i_slice, volume, weight_matrix)
                preds_list = [np.array(preds_undo)]
            else:
                # undo transformations
                preds_i_undo, metadata_idx = undo_transforms(
                    preds[i_slice],
                    batch["input_metadata"][i_slice],
                    data_type='gt')

                # Add new segmented slice to preds_list
                preds_list.append(np.array(preds_i_undo))
                # Store the slice index of preds_i_undo in the original 3D image
                slice_idx_list.append(
                    int(batch['input_metadata'][i_slice][0]['slice_index']))

            # If last batch and last sample of this batch, then reconstruct 3D object
            # NOTE(review): assumes len(batch['gt']) equals the batch size even though no
            # ground truth is provided (gt is None in filename_pairs) — confirm against imed_collate.
            if (i_batch == len(data_loader) - 1
                    and i_slice == len(batch['gt']) - 1) or last_sample_bool:
                # bin_thr=-1: presumably no binarization, matching the "soft segmentation" return.
                pred_nib = pred_to_nib(data_lst=preds_list,
                                       fname_ref=fname_image,
                                       fname_out=None,
                                       z_lst=slice_idx_list,
                                       slice_axis=slice_axis,
                                       kernel_dim='3d' if kernel_3D else '2d',
                                       debug=False,
                                       bin_thr=-1)

    return pred_nib