예제 #1
0
파일: dataset.py 프로젝트: auger1/MMWHS
 def datasources(self):
     """
     Create the data sources for images, landmarks and groundtruth masks.

     The returned dict has the following entries:
     'image': CachedImageDataSource for the image files, smoothed during
              preprocessing with gaussian_sitk and self.input_gaussian_sigma.
     'landmarks': LandmarkDataSource built from self.landmarks_file.
     'mask': CachedImageDataSource for the groundtruth labels, loaded with
             pixel type sitk.sitkUInt8.

     :return: A dict of data sources.
     """

     def smooth(image):
         # The sigma attribute is looked up on every call, not captured once.
         return gaussian_sitk(image, self.input_gaussian_sigma)

     sources = {}
     sources['image'] = CachedImageDataSource(
         self.image_base_folder, '', '_image', '.mha',
         preprocessing=smooth)
     sources['landmarks'] = LandmarkDataSource(
         self.landmarks_file, 1, self.dim)
     sources['mask'] = CachedImageDataSource(
         self.image_base_folder, '', '_label_sorted', '.mha',
         sitk_pixel_type=sitk.sitkUInt8)
     return sources
 def datasources_single_frame(self, iterator):
     """
     Create the data sources that load the data of a single frame.

     The returned dict always contains 'image'; the remaining entries are
     only created when the corresponding self.load_* flag is set:
     'image': CachedImageDataSource that loads the image files.
     'merged': CachedImageDataSource that loads the segmentation/tracking
               label files (only if self.load_merged).
     'seg_loss_mask': CachedImageDataSource that loads the mask defining the
                      region of the image where the segmentation is valid
                      (only if self.load_seg_loss_mask). This is needed, as
                      the images are not segmented in their border regions.
     'has_complete_seg': LabelDatasource that loads a label file, which
                         contains '1' if the current frame has a complete
                         segmentation (all cells are segmented), or '0' if
                         not (only if self.load_has_complete_seg).

     :param iterator: The node that is passed as the single parent of every
                      created data source.
     :return: A dict of data sources.
     """

     # Functions translating the incoming id dict into the 'image_id' that
     # the individual data sources look up.
     def frame_image_id(entry):
         return {'image_id': entry['video_id'] + '/t' + entry['frame_id']}

     def merged_image_id(entry):
         return {'image_id': entry['video_id'] + '_GT/MERGED/' + entry['frame_id']}

     def seg_loss_mask_image_id(entry):
         return {'image_id': entry['video_id'] + '_GT/seg_loss_mask'}

     def complete_seg_image_id(entry):
         return {'image_id': entry['video_id'] + '/' + entry['frame_id']}

     def cached_mha_source(name, to_image_id, **extra):
         # All image-like sources share folder, extension, spacing handling
         # and cache size; a fresh parents list is created for each source.
         return CachedImageDataSource(
             self.image_base_folder,
             file_ext='.mha',
             set_identity_spacing=True,
             id_dict_preprocessing=to_image_id,
             cache_maxsize=1024,
             name=name,
             parents=[iterator],
             **extra)

     sources = {
         # image data source loads input image.
         'image': cached_mha_source(
             'image',
             frame_image_id,
             preprocessing=self.image_postprocessing,
             sitk_pixel_type=sitk.sitkUInt16),
     }

     if self.load_merged:
         sources['merged'] = cached_mha_source(
             'merged',
             merged_image_id,
             sitk_pixel_type=sitk.sitkUInt16)

     if self.load_seg_loss_mask:
         sources['seg_loss_mask'] = cached_mha_source(
             'seg_loss_mask',
             seg_loss_mask_image_id,
             sitk_pixel_type=sitk.sitkInt8,
             return_none_if_not_found=False)

     if self.load_has_complete_seg:
         sources['has_complete_seg'] = LabelDatasource(
             self.seg_mask_file_name,
             id_dict_preprocessing=complete_seg_image_id,
             name='has_complete_seg',
             parents=[iterator])

     return sources
예제 #3
0
 def data_sources(self, cached, image_extension='.mha'):
     """
     Create the data sources for the images and their landmarks.

     The returned dict has the following entries:
     'image_datasource': ImageDataSource (or CachedImageDataSource) that
                         loads the image files.
     'landmarks_datasource': LandmarkDataSource that loads the landmark
                             coordinates.

     :param cached: If true, use a CachedImageDataSource (cache_maxsize=16384)
                    instead of an ImageDataSource.
     :param image_extension: The image extension of the input data.
     :return: A dict of data sources.
     """
     # Only the cached variant receives a cache size; everything else is
     # passed identically to both classes.
     if cached:
         source_type, extra_kwargs = CachedImageDataSource, {'cache_maxsize': 16384}
     else:
         source_type, extra_kwargs = ImageDataSource, {}

     sources = {}
     sources['image_datasource'] = source_type(
         self.image_base_folder,
         '',
         '',
         image_extension,
         preprocessing=reduce_dimension,
         set_identity_spacing=True,
         **extra_kwargs)
     sources['landmarks_datasource'] = LandmarkDataSource(
         self.point_list_file_name, self.num_landmarks, self.dim)
     return sources
예제 #4
0
 def data_sources(self,
                  iterator,
                  cached,
                  use_landmarks=True,
                  image_extension='.nii.gz'):
     """
     Create the data sources for the images and, optionally, the landmarks.

     The returned dict has the following entries:
     'image_datasource': ImageDataSource (or CachedImageDataSource) that
                         loads the image files.
     'landmarks_datasource': LandmarkDataSource that loads the landmark
                             coordinates (only if use_landmarks is set).

     If self.smoothing_sigma is set and positive, the images are smoothed
     with gaussian() and cast to sitk.sitkInt16 during preprocessing;
     otherwise no preprocessing is applied.

     :param iterator: The iterator node object, used as parent of all sources.
     :param cached: If True, use CachedImageDataSource (cache_maxsize=25000)
                    instead of ImageDataSource.
     :param use_landmarks: If True, create a landmarks datasource.
     :param image_extension: The extension of the image files.
     :return: A dict of data sources.
     """

     def smooth_and_cast(image):
         # The sigma attribute is looked up on every call, not captured once.
         return sitk.Cast(gaussian(image, self.smoothing_sigma), sitk.sitkInt16)

     smoothing_enabled = (self.smoothing_sigma is not None
                          and self.smoothing_sigma > 0)
     preprocessing = smooth_and_cast if smoothing_enabled else None

     # Only the cached variant receives a cache size.
     if cached:
         source_type, extra_kwargs = CachedImageDataSource, {'cache_maxsize': 25000}
     else:
         source_type, extra_kwargs = ImageDataSource, {}

     sources = {
         'image_datasource': source_type(
             self.image_base_folder,
             '',
             '',
             image_extension,
             preprocessing=preprocessing,
             name='image_datasource',
             parents=[iterator],
             **extra_kwargs),
     }
     if use_landmarks:
         sources['landmarks_datasource'] = LandmarkDataSource(
             self.point_list_file_name,
             self.num_landmarks,
             self.dim,
             name='landmarks_datasource',
             parents=[iterator])
     return sources
예제 #5
0
    def datasources_single_frame(self, iterator):
        """
        Create the data sources that load the data of a single frame.

        The returned dict has a single entry:
        'image': CachedImageDataSource that loads the '.tif' image files as
                 sitk.sitkUInt16 and applies self.image_postprocessing.

        :param iterator: The node that is passed as the single parent of the
                         created data source.
        :return: A dict of data sources.
        """
        # image data source loads input image.
        image_source = CachedImageDataSource(
            self.image_base_folder,
            't',
            '',
            '.tif',
            set_identity_spacing=True,
            preprocessing=self.image_postprocessing,
            sitk_pixel_type=sitk.sitkUInt16,
            cache_maxsize=512,
            name='image',
            parents=[iterator],
        )
        return {'image': image_source}
예제 #6
0
 def data_sources(self, cached, iterator, image_extension='.nii.gz'):
     """
     Create the data sources for the images and all landmark groundtruths.

     The returned dict has the following entries:
     'image_datasource': ImageDataSource (or CachedImageDataSource) that
                         loads the image files.
     Every entry returned by self.landmark_datasources(iterator) is merged
     in as well. According to the original documentation these are
     'landmarks_datasource', 'junior_landmarks_datasource',
     'senior_landmarks_datasource', 'challenge_landmarks_datasource',
     'mean_landmarks_datasource' and 'random_landmarks_datasource', i.e.
     LandmarkDataSources for different groundtruths (that method is not
     visible here, verify against its implementation).

     :param cached: If true, use a CachedImageDataSource (cache_maxsize=16384)
                    instead of an ImageDataSource.
     :param iterator: The iterator that is used as the parent node.
     :param image_extension: The image extension of the input data.
     :return: A dict of data sources.
     """
     # Only the cached variant receives a cache size.
     if cached:
         source_type, extra_kwargs = CachedImageDataSource, {'cache_maxsize': 16384}
     else:
         source_type, extra_kwargs = ImageDataSource, {}

     image_source = source_type(
         self.image_base_folder,
         '',
         '',
         image_extension,
         preprocessing=self.image_preprocessing,
         set_identity_spacing=False,
         parents=[iterator],
         name='image_datasource',
         **extra_kwargs)

     # The image source is built before the landmark sources are requested.
     return {
         'image_datasource': image_source,
         **self.landmark_datasources(iterator),
     }