Пример #1
0
    def _load_filenames(self):
        """Load every (input, gt) pair, preprocess it and store it in the handlers.

        For each entry of ``self.filename_pairs``:

        1. a ``SegmentationPair`` is built and its data and metadata are gathered
           into a dict with keys ``input``, ``gt``, ``input_metadata`` and
           ``gt_metadata``;
        2. the bounding-box metadata is checked; when a bounding box is present,
           ``self.prepro_transforms`` is adjusted using ``self.length`` and
           ``self.stride``;
        3. the preprocessing transforms are applied;
        4. the shape of the first input channel is written to every input-metadata
           entry under ``'index_shape'``;
        5. the resulting ``(seg_pair, roi_pair)`` tuple is appended to
           ``self.handlers``.

        The ROI filename of each entry is not used here.
        """
        for input_filename, gt_filename, _roi_filename, pair_metadata in self.filename_pairs:
            pair = SegmentationPair(input_filename,
                                    gt_filename,
                                    metadata=pair_metadata,
                                    slice_axis=self.slice_axis,
                                    soft_gt=self.soft_gt)
            input_data, gt_data = pair.get_pair_data()
            pair_meta = pair.get_pair_metadata()

            sample = {}
            sample['input'] = input_data
            sample['gt'] = gt_data
            sample['input_metadata'] = pair_meta['input_metadata']
            sample['gt_metadata'] = pair_meta['gt_metadata']

            self.has_bounding_box = imed_obj_detect.verify_metadata(
                sample, self.has_bounding_box)
            if self.has_bounding_box:
                self.prepro_transforms = imed_obj_detect.adjust_transforms(
                    self.prepro_transforms, sample,
                    length=self.length, stride=self.stride)

            sample, roi_sample = imed_transforms.apply_preprocessing_transforms(
                self.prepro_transforms, seg_pair=sample)

            # Record the post-transform shape of the first input channel.
            for input_meta in sample['input_metadata']:
                input_meta['index_shape'] = sample['input'][0].shape

            self.handlers.append((sample, roi_sample))
Пример #2
0
    def load_filenames(self):
        """Slice each (input, gt) volume and store the preprocessed 2D samples.

        For every entry of ``self.filename_pairs`` two ``SegmentationPair`` objects
        are built: one pairing the inputs with the ROI file, one pairing them with
        the ground truth. Each slice along the last axis of the input is then:

        * checked for bounding-box metadata (``self.prepro_transforms`` is adjusted
          when one is present);
        * skipped when ``self.slice_filter_fn`` is set and rejects it;
        * skipped when ``self.slice_filter_roi`` is set and ``filter_roi`` flags the
          ROI slice with threshold ``self.roi_thr``;
        * otherwise preprocessed and appended to ``self.indexes``.
        """
        for input_filenames, gt_filenames, roi_filename, metadata in self.filename_pairs:
            common = dict(metadata=metadata,
                          slice_axis=self.slice_axis,
                          cache=self.cache,
                          prepro_transforms=self.prepro_transforms)
            roi_pair = SegmentationPair(input_filenames, roi_filename, **common)
            seg_pair = SegmentationPair(input_filenames, gt_filenames,
                                        soft_gt=self.soft_gt, **common)

            input_shape, _ = seg_pair.get_pair_shapes()

            for slice_idx in range(input_shape[-1]):
                seg_slice = seg_pair.get_pair_slice(slice_idx, gt_type=self.task)

                self.has_bounding_box = imed_obj_detect.verify_metadata(
                    seg_slice, self.has_bounding_box)
                if self.has_bounding_box:
                    self.prepro_transforms = imed_obj_detect.adjust_transforms(
                        self.prepro_transforms, seg_slice)

                rejected = self.slice_filter_fn and not self.slice_filter_fn(seg_slice)
                if rejected:
                    continue

                # The ROI slice is always requested as a segmentation because it
                # is used to crop the image.
                roi_slice = roi_pair.get_pair_slice(slice_idx, gt_type="segmentation")
                if self.slice_filter_roi and imed_loader_utils.filter_roi(
                        roi_slice['gt'], self.roi_thr):
                    continue

                self.indexes.append(
                    imed_transforms.apply_preprocessing_transforms(
                        self.prepro_transforms, seg_slice, roi_slice))
Пример #3
0
    def _slice_seg_pair(self, idx_pair_slice, seg_pair, roi_pair,
                        useful_slices, input_volumes, gt_volume, roi_volume):
        """Extract, preprocess and accumulate one slice of a segmentation pair.

        Helper used at load time. Slice ``idx_pair_slice`` is taken from both
        ``seg_pair`` and ``roi_pair``, the preprocessing transforms are applied, and
        the results are appended to the caller-owned accumulator lists.

        Args:
            idx_pair_slice (int): Index of the slice to extract.
            seg_pair: (input, gt) pair object exposing ``get_pair_slice``.
            roi_pair: (input, roi) pair object exposing ``get_pair_slice``.
            useful_slices (list): Receives ``idx_pair_slice`` when
                ``self.slice_filter_fn`` is set and accepts the slice.
            input_volumes (list): Receives the first input channel of the slice.
            gt_volume (list): Receives the first GT channel passed through
                ``(x * 255).astype(np.uint8) / 255.``. It is emptied in place when
                the slice has no GT (unlabeled data).
            roi_volume (list): Same as ``gt_volume``, for the ROI.

        Returns:
            tuple: ``(slice_seg_pair, roi_pair_slice)`` after preprocessing.
        """
        slice_seg_pair = seg_pair.get_pair_slice(idx_pair_slice)

        self.has_bounding_box = imed_obj_detect.verify_metadata(
            slice_seg_pair, self.has_bounding_box)
        if self.has_bounding_box:
            # Keep the returned transforms, as the other loaders do (presumably
            # adjust_transforms returns the adjusted object).
            self.prepro_transforms = imed_obj_detect.adjust_transforms(
                self.prepro_transforms, slice_seg_pair)

        # keeping idx of slices with gt
        if self.slice_filter_fn and self.slice_filter_fn(slice_seg_pair):
            useful_slices.append(idx_pair_slice)

        roi_pair_slice = roi_pair.get_pair_slice(idx_pair_slice)
        slice_seg_pair, roi_pair_slice = imed_transforms.apply_preprocessing_transforms(
            self.prepro_transforms, slice_seg_pair, roi_pair_slice)

        input_volumes.append(slice_seg_pair["input"][0])

        # Handle unlabeled data. The accumulators belong to the caller, so they
        # are emptied in place: rebinding the parameter would not affect the caller.
        if not len(slice_seg_pair["gt"]):
            gt_volume.clear()
        else:
            gt_volume.append(
                (slice_seg_pair["gt"][0] * 255).astype(np.uint8) / 255.)

        # Handle data with no ROI provided (same in-place reset as above).
        if not len(roi_pair_slice["gt"]):
            roi_volume.clear()
        else:
            roi_volume.append(
                (roi_pair_slice["gt"][0] * 255).astype(np.uint8) / 255.)

        return slice_seg_pair, roi_pair_slice

    def load_filenames(self):
        """Load, slice and preprocess every (input, gt) pair of ``self.filename_pairs``.

        Each slice along the last axis of the input volume is:

        1. checked for bounding-box metadata (``self.prepro_transforms`` is
           adjusted when one is present);
        2. filtered by ``self.slice_filter_fn`` and, when ``self.slice_filter_roi``
           is set, by ``filter_roi`` with ``self.roi_thr``;
        3. preprocessed;
        4. stored in ``self.handlers`` when ``self.is_2d_patch``, after recording
           the input shape under ``MetadataKW.INDEX_SHAPE``, or in
           ``self.indexes`` otherwise.

        When ``self.disk_cache`` is truthy, the item is pickled into a temporary
        directory and the file path is stored instead of the item. Whether disk
        caching is used is decided once, from the first preprocessed item, by
        ``determine_cache_need``. When ``self.is_2d_patch``, ``prepare_indices`` is
        called at the end.
        """
        for input_filenames, gt_filenames, roi_filename, metadata in self.filename_pairs:
            roi_pair = SegmentationPair(
                input_filenames,
                roi_filename,
                metadata=metadata,
                slice_axis=self.slice_axis,
                cache=self.cache,
                prepro_transforms=self.prepro_transforms)

            seg_pair = SegmentationPair(
                input_filenames,
                gt_filenames,
                metadata=metadata,
                slice_axis=self.slice_axis,
                cache=self.cache,
                prepro_transforms=self.prepro_transforms,
                soft_gt=self.soft_gt)

            input_data_shape, _ = seg_pair.get_pair_shapes()
            n_slices = input_data_shape[-1]

            # Temporary directory for this subject, created only if an item is
            # actually written to disk.
            path_temp = None

            for idx_pair_slice in range(n_slices):
                slice_seg_pair = seg_pair.get_pair_slice(idx_pair_slice,
                                                         gt_type=self.task)
                self.has_bounding_box = imed_obj_detect.verify_metadata(
                    slice_seg_pair, self.has_bounding_box)

                if self.has_bounding_box:
                    self.prepro_transforms = imed_obj_detect.adjust_transforms(
                        self.prepro_transforms, slice_seg_pair)

                if self.slice_filter_fn and not self.slice_filter_fn(slice_seg_pair):
                    continue

                # Note: we force here gt_type=segmentation since ROI slice is needed to Crop the image
                slice_roi_pair = roi_pair.get_pair_slice(
                    idx_pair_slice, gt_type="segmentation")

                if self.slice_filter_roi and imed_loader_utils.filter_roi(
                        slice_roi_pair['gt'], self.roi_thr):
                    continue

                item: Tuple[dict, dict] = imed_transforms.apply_preprocessing_transforms(
                    self.prepro_transforms, slice_seg_pair, slice_roi_pair)

                # Run once code to keep track if disk cache is used
                if self.disk_cache is None:
                    self.determine_cache_need(item, n_slices)

                if self.is_2d_patch:
                    # 2D patches: record the slice shape in every input-metadata
                    # entry, then index the slice through self.handlers.
                    for input_metadata in item[0][MetadataKW.INPUT_METADATA]:
                        input_metadata[MetadataKW.INDEX_SHAPE] = item[0]['input'][0].shape
                    destination = self.handlers
                else:
                    # Otherwise the whole slice goes to self.indexes.
                    destination = self.indexes

                if self.disk_cache:
                    if path_temp is None:
                        path_temp = Path(create_temp_directory())
                    # The slice index keeps file names unique within this
                    # directory even when two items get the same timestamp.
                    path_item = path_temp / f"item_{get_timestamp()}_{idx_pair_slice}.pkl"
                    with path_item.open(mode="wb") as f:
                        pickle.dump(item, f)
                    destination.append(path_item)
                else:
                    destination.append(item)

        # If is_2d_patch, prepare indices of patches
        if self.is_2d_patch:
            self.prepare_indices()
Пример #5
0
    def _load_filenames(self):
        """Load preprocessed pair data (input, gt and ROI) into the HDF5 file.

        For every ``(subject_id, input_filename, gt_filename, roi_filename,
        metadata)`` entry of ``self.filename_pairs``:

        * each slice along the last axis is preprocessed and stacked into input,
          GT and ROI volumes;
        * the volumes are written to the subject group of ``self.hdf5_file`` as the
          datasets ``inputs/<contrast>``, ``gt/<contrast>`` and ``roi/<contrast>``,
          with their metadata stored as attributes;
        * the subject group's ``slices`` attribute accumulates (sorted, unique)
          the indices of slices accepted by ``self.slice_filter_fn``;
        * the ``contrast`` attribute of the group and its sub-groups lists the
          contrasts written.

        Entries whose input yields no slice are skipped with a warning. The subject
        group is still created if it did not exist.
        """
        import logging

        logger = logging.getLogger(__name__)

        for subject_id, input_filename, gt_filename, roi_filename, metadata in self.filename_pairs:
            # Creating / getting the subject group
            group_name = str(subject_id)
            if group_name in self.hdf5_file:
                grp = self.hdf5_file[group_name]
            else:
                grp = self.hdf5_file.create_group(group_name)

            roi_pair = imed_loader.SegmentationPair(input_filename,
                                                    roi_filename,
                                                    metadata=metadata,
                                                    slice_axis=self.slice_axis,
                                                    cache=False,
                                                    soft_gt=self.soft_gt)

            seg_pair = imed_loader.SegmentationPair(input_filename,
                                                    gt_filename,
                                                    metadata=metadata,
                                                    slice_axis=self.slice_axis,
                                                    cache=False,
                                                    soft_gt=self.soft_gt)
            input_data_shape, _ = seg_pair.get_pair_shapes()

            useful_slices = []
            input_volumes = []
            gt_volume = []
            roi_volume = []

            for idx_pair_slice in range(input_data_shape[-1]):
                slice_seg_pair = seg_pair.get_pair_slice(idx_pair_slice)

                self.has_bounding_box = imed_obj_detect.verify_metadata(
                    slice_seg_pair, self.has_bounding_box)
                if self.has_bounding_box:
                    # Keep the returned transforms, as the other loaders do.
                    self.prepro_transforms = imed_obj_detect.adjust_transforms(
                        self.prepro_transforms, slice_seg_pair)

                # keeping idx of slices with gt
                if self.slice_filter_fn and self.slice_filter_fn(slice_seg_pair):
                    useful_slices.append(idx_pair_slice)

                roi_pair_slice = roi_pair.get_pair_slice(idx_pair_slice)
                slice_seg_pair, roi_pair_slice = imed_transforms.apply_preprocessing_transforms(
                    self.prepro_transforms, slice_seg_pair, roi_pair_slice)

                input_volumes.append(slice_seg_pair["input"][0])

                # Handle unlabeled data
                if not len(slice_seg_pair["gt"]):
                    gt_volume = []
                else:
                    gt_volume.append(
                        (slice_seg_pair["gt"][0] * 255).astype(np.uint8) / 255.)

                # Handle data with no ROI provided
                if not len(roi_pair_slice["gt"]):
                    roi_volume = []
                else:
                    roi_volume.append(
                        (roi_pair_slice["gt"][0] * 255).astype(np.uint8) / 255.)

            # Nothing was loaded: skip before reading the metadata below, which would
            # otherwise come from an unbound or stale (previous subject's) slice.
            if not input_volumes:
                logger.warning("No slice loaded for subject %s (%s); skipping.",
                               subject_id, input_filename)
                continue

            # Getting metadata using the one from the last slice
            input_metadata = slice_seg_pair['input_metadata'][0]
            gt_metadata = slice_seg_pair['gt_metadata'][0]
            roi_metadata = roi_pair_slice['input_metadata'][0]

            if 'slices' in grp.attrs:
                # union1d gives sorted unique indices. The int cast is needed because
                # merging with an empty list would otherwise produce floats.
                grp.attrs['slices'] = np.union1d(
                    grp.attrs['slices'], useful_slices).astype(int)
            else:
                grp.attrs['slices'] = useful_slices

            contrast = input_metadata['contrast']

            # Inputs
            key = "inputs/{}".format(contrast)
            grp.create_dataset(key, data=input_volumes)
            self._append_contrast_attr(grp['inputs'].attrs, contrast)
            grp[key].attrs['input_filenames'] = input_metadata['input_filenames']
            grp[key].attrs['data_type'] = input_metadata['data_type']
            self._copy_metadata_attrs(grp[key].attrs, input_metadata,
                                      ("zooms", "data_shape", "bounding_box"))

            # GT
            key = "gt/{}".format(contrast)
            grp.create_dataset(key, data=gt_volume)
            self._append_contrast_attr(grp['gt'].attrs, contrast)
            # NOTE(review): GT filenames are read from the input metadata (unchanged
            # from the original); confirm gt_metadata holds the same value.
            grp[key].attrs['gt_filenames'] = input_metadata['gt_filenames']
            grp[key].attrs['data_type'] = gt_metadata['data_type']
            self._copy_metadata_attrs(grp[key].attrs, gt_metadata,
                                      ("zooms", "data_shape", "bounding_box"))

            # ROI
            key = "roi/{}".format(contrast)
            grp.create_dataset(key, data=roi_volume)
            self._append_contrast_attr(grp['roi'].attrs, contrast)
            grp[key].attrs['roi_filename'] = roi_metadata['gt_filenames']
            grp[key].attrs['data_type'] = roi_metadata['data_type']
            self._copy_metadata_attrs(grp[key].attrs, roi_metadata,
                                      ("zooms", "data_shape"))

            # Adding contrast to group metadata
            self._append_contrast_attr(grp.attrs, contrast)

    def _append_contrast_attr(self, attrs, contrast):
        """Append ``contrast`` to the ``'contrast'`` entry of ``attrs``, creating it if missing.

        ``attrs`` is an HDF5 attribute manager (a group's ``.attrs``). The entry
        is rewritten with ``dtype=self.dt``.
        """
        contrasts = list(attrs['contrast']) if 'contrast' in attrs else []
        contrasts.append(contrast)
        attrs.create('contrast', contrasts, dtype=self.dt)

    @staticmethod
    def _copy_metadata_attrs(attrs, metadata, names):
        """Copy each key in ``names`` from ``metadata`` to ``attrs`` when present and not None.

        ``None`` values are skipped because h5py cannot store them as attributes.
        """
        available = metadata.keys()
        for name in names:
            if name in available and metadata[name] is not None:
                attrs[name] = metadata[name]