Example 1
0
def apply_preprocessing_transforms(transforms,
                                   seg_pair,
                                   roi_pair=None) -> Tuple[dict, dict]:
    """Apply preprocessing transforms to a segmentation pair (input, gt and metadata).

    Args:
        transforms (Compose): Preprocessing transforms; when ``None`` the pairs
            are returned unchanged.
        seg_pair (dict): Segmentation pair containing input and gt.
        roi_pair (dict): Segmentation pair containing input and roi.

    Returns:
        tuple: Segmentation pair and roi pair.
    """
    if transforms is None:
        return (seg_pair, roi_pair)

    input_meta = seg_pair['input_metadata']

    # The ROI goes first: parameters produced by ROI transforms (e.g. crop
    # coordinates) are propagated to the input metadata before the image pass.
    if roi_pair is not None:
        roi_stack, roi_meta = transforms(sample=roi_pair["gt"],
                                         metadata=roi_pair['gt_metadata'],
                                         data_type="roi",
                                         preprocessing=True)
        input_meta = imed_loader_utils.update_metadata(roi_meta, input_meta)

    # Transform the input images.
    image_stack, input_meta = transforms(sample=seg_pair["input"],
                                         metadata=input_meta,
                                         data_type="im",
                                         preprocessing=True)

    # Propagate the (possibly updated) input metadata to the ground truth,
    # then transform the ground truth itself.
    gt_meta = imed_loader_utils.update_metadata(input_meta,
                                                seg_pair['gt_metadata'])
    gt_stack, gt_meta = transforms(sample=seg_pair["gt"],
                                   metadata=gt_meta,
                                   data_type="gt",
                                   preprocessing=True)

    seg_pair = {
        'input': image_stack,
        'gt': gt_stack,
        MetadataKW.INPUT_METADATA: input_meta,
        MetadataKW.GT_METADATA: gt_meta
    }

    if roi_pair is not None and len(roi_pair['gt']):
        roi_pair = {
            'input': image_stack,
            'gt': roi_stack,
            MetadataKW.INPUT_METADATA: input_meta,
            MetadataKW.GT_METADATA: roi_meta
        }
    return (seg_pair, roi_pair)
Example 2
0
 def wrapper(self, sample, metadata):
     # A sample may be a list of samples; process each item in order,
     # carrying metadata forward from the previously processed item.
     if isinstance(sample, list):
         out_data = []
         out_metadata = []
         for cur_sample, cur_metadata in zip(sample, metadata):
             if out_metadata:
                 imed_loader_utils.update_metadata([out_metadata[-1]],
                                                   [cur_metadata])
             # Apply the wrapped transform to this single item.
             new_data, new_metadata = wrapped(self, cur_sample, cur_metadata)
             out_data.append(new_data)
             out_metadata.append(new_metadata)
         return out_data, out_metadata
     # A missing sample maps to the pair (None, None).
     if sample is None:
         return None, None
     return wrapped(self, sample, metadata)
    def __getitem__(self, index):
        """Return the specific processed data corresponding to index (input, ground truth, roi and metadata).

        Args:
            index (int): Slice index.

        Returns:
            dict: Keys ``input``, ``gt``, ``roi`` holding the transformed
            stacks, plus ``input_metadata``, ``gt_metadata`` and
            ``roi_metadata`` lists.
        """

        # copy.deepcopy is used to have different coordinates for reconstruction for a given handler with patch,
        # to allow a different rater at each iteration of training, and to clean transforms params from previous
        # transforms i.e. remove params from previous iterations so that the coming transforms are different
        # When disk_cache is enabled the pairs were pickled to disk, so they are
        # re-loaded (already a fresh object) instead of deep-copied.
        if self.is_2d_patch:
            coord = self.indexes[index]
            if self.disk_cache:
                with self.handlers[coord['handler_index']].open(
                        mode="rb") as f:
                    seg_pair_slice, roi_pair_slice = pickle.load(f)
            else:
                seg_pair_slice, roi_pair_slice = copy.deepcopy(
                    self.handlers[coord['handler_index']])
        else:
            if self.disk_cache:
                with self.indexes[index].open(mode="rb") as f:
                    seg_pair_slice, roi_pair_slice = pickle.load(f)
            else:
                seg_pair_slice, roi_pair_slice = copy.deepcopy(
                    self.indexes[index])

        # In case multiple raters: 'gt' is then a list of per-rater lists.
        if seg_pair_slice['gt'] and isinstance(seg_pair_slice['gt'][0], list):
            # Randomly pick a rater
            idx_rater = random.randint(0, len(seg_pair_slice['gt'][0]) - 1)
            # Use it as ground truth for this iteration
            # Note: in case of multi-class: the same rater is used across classes
            for idx_class in range(len(seg_pair_slice['gt'])):
                seg_pair_slice['gt'][idx_class] = seg_pair_slice['gt'][
                    idx_class][idx_rater]
                seg_pair_slice['gt_metadata'][idx_class] = seg_pair_slice[
                    'gt_metadata'][idx_class][idx_rater]

        # Fall back to empty metadata lists when a pair carries no metadata.
        metadata_input = seg_pair_slice['input_metadata'] if seg_pair_slice[
            'input_metadata'] is not None else []
        metadata_roi = roi_pair_slice['gt_metadata'] if roi_pair_slice[
            'gt_metadata'] is not None else []
        metadata_gt = seg_pair_slice['gt_metadata'] if seg_pair_slice[
            'gt_metadata'] is not None else []

        if self.is_2d_patch:
            # NOTE(review): for patches no ROI transform is run here —
            # presumably ROI handling already happened when the patch
            # coordinates were indexed; confirm against the indexing code.
            stack_roi, metadata_roi = None, None
        else:
            # Set coordinates to the slices full size
            coord = {}
            coord['x_min'], coord['x_max'] = 0, seg_pair_slice["input"][
                0].shape[0]
            coord['y_min'], coord['y_max'] = 0, seg_pair_slice["input"][
                0].shape[1]

            # Run transforms on ROI
            # ROI goes first because params of ROICrop are needed for the followings
            stack_roi, metadata_roi = self.transform(
                sample=roi_pair_slice["gt"],
                metadata=metadata_roi,
                data_type="roi")
            # Update metadata_input with metadata_roi
            metadata_input = imed_loader_utils.update_metadata(
                metadata_roi, metadata_input)

        # Add coordinates of slices or patches to input metadata
        # (for patches, coord came from self.indexes above).
        for metadata in metadata_input:
            metadata['coord'] = [
                coord["x_min"], coord["x_max"], coord["y_min"], coord["y_max"]
            ]

        # Extract image and gt slices or patches from coordinates
        stack_input = np.asarray(
            seg_pair_slice["input"])[:, coord['x_min']:coord['x_max'],
                                     coord['y_min']:coord['y_max']]
        if seg_pair_slice["gt"]:
            stack_gt = np.asarray(
                seg_pair_slice["gt"])[:, coord['x_min']:coord['x_max'],
                                      coord['y_min']:coord['y_max']]
        else:
            stack_gt = []

        # Run transforms on image slices or patches
        stack_input, metadata_input = self.transform(sample=list(stack_input),
                                                     metadata=metadata_input,
                                                     data_type="im")
        # Update metadata_gt with metadata_input
        metadata_gt = imed_loader_utils.update_metadata(
            metadata_input, metadata_gt)
        if self.task == "segmentation":
            # Run transforms on gt slices or patches
            stack_gt, metadata_gt = self.transform(sample=list(stack_gt),
                                                   metadata=metadata_gt,
                                                   data_type="gt")
            # Make sure stack_gt is binarized
            if stack_gt is not None and not self.soft_gt:
                stack_gt = imed_postpro.threshold_predictions(stack_gt,
                                                              thr=0.5).astype(
                                                                  np.uint8)
        else:
            # Force no transformation on labels for classification task
            # stack_gt is a tensor of size 1x1, values: 0 or 1
            # "expand(1)" is necessary to be compatible with segmentation convention: n_labelxhxwxd
            stack_gt = torch.from_numpy(seg_pair_slice["gt"][0]).expand(1)

        data_dict = {
            'input': stack_input,
            'gt': stack_gt,
            'roi': stack_roi,
            'input_metadata': metadata_input,
            'gt_metadata': metadata_gt,
            'roi_metadata': metadata_roi
        }

        # Input-level dropout to train with missing modalities
        if self.is_input_dropout:
            data_dict = dropout_input(data_dict)

        return data_dict
Example 4
0
    def __getitem__(self, index):
        """Get samples.

        Warning: For now, this method only supports one gt / roi.

        Args:
            index (int): Sample index.

        Returns:
            dict: Dictionary containing image and label tensors as well as metadata.
        """
        line = self.dataframe.iloc[index]
        # For HeMIS strategy. Otherwise the values of the matrix don't change anything.
        missing_modalities = self.cst_matrix[index]

        input_metadata = []
        input_tensors = []

        # Inputs: read each contrast slice from the HDF5 file (or reuse the
        # in-memory value when self.status marks the contrast as loaded).
        with h5py.File(self.path_hdf5, "r") as f:
            for i, ct in enumerate(self.cst_lst):
                if self.status[ct]:
                    input_tensor = line[ct] * missing_modalities[i]
                else:
                    input_tensor = f[line[ct]][
                        line['Slices']] * missing_modalities[i]

                input_tensors.append(input_tensor)
                # input Metadata: copied from the HDF5 attributes of the input.
                metadata = imed_loader_utils.SampleMetadata({
                    key: value
                    for key, value in f['{}/inputs/{}'.format(
                        line['Subjects'], ct)].attrs.items()
                })
                metadata['slice_index'] = line["Slices"]
                metadata['missing_mod'] = missing_modalities
                metadata['crop_params'] = {}
                input_metadata.append(metadata)

            # GT
            gt_img = []
            gt_metadata = []
            for idx, gt in enumerate(self.gt_lst):
                if self.status['gt/' + gt]:
                    gt_data = line['gt/' + gt]
                else:
                    gt_data = f[line['gt/' + gt]][line['Slices']]

                gt_data = gt_data.astype(np.uint8)
                gt_img.append(gt_data)
                gt_metadata.append(
                    imed_loader_utils.SampleMetadata({
                        key: value
                        for key, value in f[line['gt/' + gt]].attrs.items()
                    }))
                gt_metadata[idx]['crop_params'] = {}

            # ROI (only the first ROI of the list is supported)
            roi_img = []
            roi_metadata = []
            if self.roi_lst:
                if self.status['roi/' + self.roi_lst[0]]:
                    roi_data = line['roi/' + self.roi_lst[0]]
                else:
                    roi_data = f[line['roi/' +
                                      self.roi_lst[0]]][line['Slices']]

                roi_data = roi_data.astype(np.uint8)
                roi_img.append(roi_data)

                roi_metadata.append(
                    imed_loader_utils.SampleMetadata({
                        key: value
                        for key, value in f[line[
                            'roi/' + self.roi_lst[0]]].attrs.items()
                    }))
                roi_metadata[0]['crop_params'] = {}

            # Run transforms on ROI
            # ROI goes first because params of ROICrop are needed for the followings
            stack_roi, metadata_roi = self.transform(sample=roi_img,
                                                     metadata=roi_metadata,
                                                     data_type="roi")
            # Update metadata_input with metadata_roi
            metadata_input = imed_loader_utils.update_metadata(
                metadata_roi, input_metadata)

            # Run transforms on images
            stack_input, metadata_input = self.transform(
                sample=input_tensors, metadata=metadata_input, data_type="im")
            # Update metadata_gt with metadata_input
            metadata_gt = imed_loader_utils.update_metadata(
                metadata_input, gt_metadata)

            # Run transforms on ground truth
            stack_gt, metadata_gt = self.transform(sample=gt_img,
                                                   metadata=metadata_gt,
                                                   data_type="gt")
            data_dict = {
                'input': stack_input,
                'gt': stack_gt,
                'roi': stack_roi,
                'input_metadata': metadata_input,
                'gt_metadata': metadata_gt,
                'roi_metadata': metadata_roi
            }

            return data_dict
Example 5
0
    def __getitem__(self, index):
        """Return the specific index pair subvolume (input, ground truth).

        Args:
            index (int): Subvolume index.

        Returns:
            dict: ``input`` and ``gt`` tensors cropped to the subvolume
            coordinates (``gt`` may be ``None``), plus ``input_metadata``
            and ``gt_metadata`` lists.
        """
        coord = self.indexes[index]
        seg_pair, _ = self.handlers[coord['handler_index']]

        # Clean transforms params from previous transforms
        # i.e. remove params from previous iterations so that the coming transforms are different
        # Use copy to have different coordinates for reconstruction for a given handler
        metadata_input = imed_loader_utils.clean_metadata(
            copy.deepcopy(seg_pair['input_metadata']))
        metadata_gt = imed_loader_utils.clean_metadata(
            copy.deepcopy(seg_pair['gt_metadata']))

        # Run transforms on images
        stack_input, metadata_input = self.transform(sample=seg_pair['input'],
                                                     metadata=metadata_input,
                                                     data_type="im")
        # Update metadata_gt with metadata_input
        metadata_gt = imed_loader_utils.update_metadata(
            metadata_input, metadata_gt)

        # Run transforms on ground truth
        stack_gt, metadata_gt = self.transform(sample=seg_pair['gt'],
                                               metadata=metadata_gt,
                                               data_type="gt")
        # Make sure stack_gt is binarized
        if stack_gt is not None and not self.soft_gt:
            stack_gt = imed_postpro.threshold_predictions(stack_gt, thr=0.5)

        shape_x = coord["x_max"] - coord["x_min"]
        shape_y = coord["y_max"] - coord["y_min"]
        shape_z = coord["z_max"] - coord["z_min"]

        # add coordinates to metadata to reconstruct volume
        for metadata in metadata_input:
            metadata['coord'] = [
                coord["x_min"], coord["x_max"], coord["y_min"], coord["y_max"],
                coord["z_min"], coord["z_max"]
            ]

        # Placeholders are kept when a stack is empty (or gt is None).
        subvolumes = {
            'input':
            torch.zeros(stack_input.shape[0], shape_x, shape_y, shape_z),
            'gt':
            torch.zeros(stack_gt.shape[0], shape_x, shape_y, shape_z)
            if stack_gt is not None else None,
            'input_metadata':
            metadata_input,
            'gt_metadata':
            metadata_gt
        }

        # Crop the transformed stacks to the subvolume coordinates.
        # Fix: the original looped `len(stack)` times, re-assigning the
        # identical slice each iteration; one guarded assignment is equivalent.
        if len(stack_input):
            subvolumes['input'] = stack_input[:, coord['x_min']:coord['x_max'],
                                              coord['y_min']:coord['y_max'],
                                              coord['z_min']:coord['z_max']]

        if stack_gt is not None and len(stack_gt):
            subvolumes['gt'] = stack_gt[:, coord['x_min']:coord['x_max'],
                                        coord['y_min']:coord['y_max'],
                                        coord['z_min']:coord['z_max']]

        return subvolumes
Example 6
0
    def __getitem__(self, index):
        """Return the specific processed data corresponding to index (input, ground truth, roi and metadata).

        Args:
            index (int): Slice index.

        Returns:
            dict: Keys ``input``, ``gt``, ``roi`` holding the transformed
            stacks, plus ``input_metadata``, ``gt_metadata`` and
            ``roi_metadata`` lists.
        """
        seg_pair_slice, roi_pair_slice = self.indexes[index]

        # Clean transforms params from previous transforms
        # i.e. remove params from previous iterations so that the coming transforms are different
        metadata_input = imed_loader_utils.clean_metadata(
            seg_pair_slice['input_metadata'])
        metadata_roi = imed_loader_utils.clean_metadata(
            roi_pair_slice['gt_metadata'])
        metadata_gt = imed_loader_utils.clean_metadata(
            seg_pair_slice['gt_metadata'])

        # Run transforms on ROI
        # ROI goes first because params of ROICrop are needed for the followings
        stack_roi, metadata_roi = self.transform(sample=roi_pair_slice["gt"],
                                                 metadata=metadata_roi,
                                                 data_type="roi")

        # Update metadata_input with metadata_roi
        metadata_input = imed_loader_utils.update_metadata(
            metadata_roi, metadata_input)

        # Run transforms on images
        stack_input, metadata_input = self.transform(
            sample=seg_pair_slice["input"],
            metadata=metadata_input,
            data_type="im")

        # Update metadata_gt with metadata_input
        metadata_gt = imed_loader_utils.update_metadata(
            metadata_input, metadata_gt)

        if self.task == "segmentation":
            # Run transforms on ground truth
            stack_gt, metadata_gt = self.transform(sample=seg_pair_slice["gt"],
                                                   metadata=metadata_gt,
                                                   data_type="gt")
            # Make sure stack_gt is binarized
            if stack_gt is not None and not self.soft_gt:
                stack_gt = imed_postpro.threshold_predictions(stack_gt,
                                                              thr=0.5)

        else:
            # Force no transformation on labels for classification task
            # stack_gt is a tensor of size 1x1, values: 0 or 1
            # "expand(1)" is necessary to be compatible with segmentation convention: n_labelxhxwxd
            stack_gt = torch.from_numpy(seg_pair_slice["gt"][0]).expand(1)

        data_dict = {
            'input': stack_input,
            'gt': stack_gt,
            'roi': stack_roi,
            'input_metadata': metadata_input,
            'gt_metadata': metadata_gt,
            'roi_metadata': metadata_roi
        }

        return data_dict