def spatial_transformation_augmented(self, iterator, datasources, image_size):
     """
     The spatial image transformation with random augmentation.

     The input is first centered, depending on the configuration, on the landmark
     bounding box (translate_to_center_landmarks), on a single selected landmark
     followed by a fixed [0, 20, 0] translation (single-vertebra modes), or on the
     input image center. An optional random bounding-box crop follows
     (translate_by_random_factor), then random translation, rotation, uniform and
     per-axis scaling, random flip (probability 0.5 on the first entry only if
     random_flip), the move of the origin to the output center and a random
     deformation.
     :param iterator: The iterator node; its 'landmark_id' selects the landmark in single-vertebra mode.
     :param datasources: datasources dict.
     :param image_size: Node that provides the output size.
     :return: LambdaNode that yields the composite as a sitk.DisplacementFieldTransform.
     """
     dim = self.dim
     spacing = self.image_spacing
     inputs = {'image': datasources['image'], 'output_size': image_size}
     steps = []

     if self.translate_to_center_landmarks:
         # the bounding box start/extent are fed to the composite as extra inputs
         inputs['start'] = datasources['landmarks_bb_start']
         inputs['extent'] = datasources['landmarks_bb_extent']
         steps.append(translation.BoundingBoxCenterToOrigin(dim, None, spacing))
     elif self.generate_single_vertebrae or self.generate_single_vertebrae_heatmap:
         def select_landmark(id_dict, landmarks):
             return [landmarks[int(id_dict['landmark_id'])]]

         inputs['landmarks'] = LambdaNode(select_landmark,
                                          parents=[iterator, datasources['landmarks']])
         steps += [landmark.Center(dim, True),
                   translation.Fixed(dim, [0, 20, 0])]
     else:
         steps.append(translation.InputCenterToOrigin(dim))

     if self.translate_by_random_factor:
         steps.append(translation.RandomCropBoundingBox(dim, None, spacing))

     flip_probability = 0.5 if self.random_flip else 0.0
     steps += [translation.Random(dim, [self.random_translation] * dim),
               rotation.Random(dim, [self.random_rotate] * dim),
               scale.RandomUniform(dim, self.random_scale),
               scale.Random(dim, [self.random_scale] * dim),
               flip.Random(dim, [flip_probability, 0.0, 0.0]),
               translation.OriginToOutputCenter(dim, None, spacing),
               deformation.Output(dim, [6, 6, 6], [self.random_deformation] * dim, None, spacing)]
     composite_node = composite.Composite(dim, steps, name='image', kwparents=inputs)

     # parameter names must match the kwparents keys below
     def to_displacement_field(comp, output_size):
         # rasterize the composite into a float64 displacement field on the output grid
         field = sitk.TransformToDisplacementField(comp,
                                                   sitk.sitkVectorFloat64,
                                                   size=output_size,
                                                   outputSpacing=self.image_spacing)
         return sitk.DisplacementFieldTransform(field)

     return LambdaNode(to_displacement_field,
                       name='image',
                       kwparents={'comp': composite_node, 'output_size': image_size})
    def dataset_val(self):
        """
        Returns the validation dataset. No random augmentation is performed.

        In cross-validation mode 'inference' the iterator is the plain string 'iterator';
        otherwise it is built from self.test_file with False as second argument. With
        use_variable_image_size the output size comes from an ImageSizeGenerator fed either
        with the spine bounding-box extent (load_spine_bbs) or with the image itself;
        otherwise the fixed self.image_size is used.
        :return: The validation dataset.
        """
        iterator = 'iterator' if self.cv == 'inference' else self.iterator(self.test_file, False)
        sources = self.datasources(iterator, True, True, self.preprocessing, 2048)

        if not self.use_variable_image_size:
            image_size = LambdaNode(lambda: self.image_size, name='output_size')
        else:
            if self.load_spine_bbs:
                size_inputs = {'extent': sources['landmarks_bb_extent']}
            else:
                size_inputs = {'image': sources['image']}
            allowed_sizes = [[32, 64, 96, 128],
                             [32, 64, 96, 128],
                             [32 * (k + 1) for k in range(20)]]
            image_size = ImageSizeGenerator(self.dim,
                                            [None] * 3,
                                            self.image_spacing,
                                            valid_output_sizes=allowed_sizes,
                                            name='output_size',
                                            kwparents=size_inputs)

        reference_transformation = self.spatial_transformation(iterator, sources, image_size)
        generators = self.data_generators(iterator, sources, reference_transformation,
                                          self.postprocessing, False, image_size, False)
        generators['image_id'] = LambdaNode(lambda id_dict: np.array(id_dict['image_id']),
                                            name='image_id',
                                            parents=[iterator])

        return GraphDataset(data_generators=list(generators.values()),
                            data_sources=list(sources.values()),
                            transformations=[reference_transformation],
                            iterator=iterator,
                            debug_image_folder='debug_val' if self.save_debug_images else None)
    def all_generators_post_processing(self,
                                       generators_dict,
                                       random_move=False):
        """
        Function that will be called after all data generators generated their data. Used to
        combine the results of the individual data generators into the following dict:
        {
        'image': generators_dict['image'] (always),
        'instances_merged': if self.create_instances_merged, generators_dict['merged'] passed
                            through self.instance_image(..., self.instances_datatype),
        'instances_bac': if self.create_instances_bac, an int8 image computed as
                         (binary_seg[0] + 2 * binary_seg[1]) * loss_mask, where
                         binary_seg = self.binary_labels(generators_dict['merged']) and
                         loss_mask = self.loss_mask(generators_dict['seg_loss_mask'],
                                                    generators_dict['has_complete_seg'])
        }
        :param generators_dict: The generators_dict that will be generated from the dataset.
        :param random_move: If true, every entry is moved with self.move_volume by one shared
                            random move (self.get_random_move of the image); 'image' uses
                            mode 'reflect', all other entries use mode 'constant'.
        :return: The final generators_dict.
        """
        final_generators_dict = {'image': generators_dict['image']}

        if self.create_instances_merged:
            final_generators_dict['instances_merged'] = LambdaNode(
                lambda merged: self.instance_image(merged, self.instances_datatype),
                name='instances_merged',
                parents=[generators_dict['merged']])

        if self.create_instances_bac:
            def compute_instances_bac(merged, seg_loss_mask, has_complete_seg):
                binary_seg = self.binary_labels(merged)
                mask = self.loss_mask(seg_loss_mask, has_complete_seg)
                combined = (binary_seg[0] + 2 * binary_seg[1]) * mask
                return combined.astype(np.int8)

            final_generators_dict['instances_bac'] = LambdaNode(
                compute_instances_bac,
                name='instances_bac',
                parents=[generators_dict['merged'],
                         generators_dict['seg_loss_mask'],
                         generators_dict['has_complete_seg']])

        if random_move:
            move_node = LambdaNode(self.get_random_move,
                                   name='random_move',
                                   parents=[final_generators_dict['image']])

            def make_mover(mode):
                # binds the padding mode per key, independent of the loop variable
                return lambda image, random_move: self.move_volume(image, random_move, mode)

            # iterate over a snapshot of the keys since entries are replaced in place
            for key in list(final_generators_dict):
                mode = 'reflect' if key == 'image' else 'constant'
                final_generators_dict[key] = LambdaNode(
                    make_mover(mode),
                    parents=[final_generators_dict[key], move_node],
                    name=key)

        return final_generators_dict
Example #4
0
    def landmark_datasources(self, iterator):
        """
        Returns all landmark data sources as a dictionary.

        Entries: 'junior_landmarks_datasource', 'senior_landmarks_datasource',
        'challenge_landmarks_datasource', 'mean_landmarks_datasource' (junior and senior
        landmarks combined with get_mean_landmark_list), 'random_landmarks_datasource'
        (senior or junior landmarks, chosen with bool_bernoulli(0.5) each time the node is
        evaluated), and 'landmark_datasource', which is the entry selected by
        self.landmark_source.
        :param iterator: The iterator that is used as the parent node.
        :return: Dictionary of all landmark data sources.
        :raises ValueError: If self.landmark_source is not one of 'senior', 'junior',
                            'challenge', 'mean' or 'random'.
        """
        valid_sources = ('senior', 'junior', 'challenge', 'mean', 'random')
        if self.landmark_source not in valid_sources:
            # fail fast; previously this surfaced as an UnboundLocalError after all nodes were built
            raise ValueError('Unknown landmark_source {!r}, expected one of: {}.'.format(
                self.landmark_source, ', '.join(valid_sources)))

        senior_landmarks_datasource = LandmarkDataSource(
            self.senior_point_list_file_name,
            self.num_landmarks,
            self.dim,
            parents=[iterator],
            name='senior_landmarks_datasource')
        junior_landmarks_datasource = LandmarkDataSource(
            self.junior_point_list_file_name,
            self.num_landmarks,
            self.dim,
            parents=[iterator],
            name='junior_landmarks_datasource')
        challenge_landmarks_datasource = LandmarkDataSource(
            self.challenge_point_list_file_name,
            self.num_landmarks,
            self.dim,
            parents=[iterator],
            name='challenge_landmarks_datasource')
        mean_landmarks_datasource = LambdaNode(
            lambda junior, senior: get_mean_landmark_list(junior, senior),
            parents=[junior_landmarks_datasource, senior_landmarks_datasource],
            name='mean_landmarks_datasource')
        random_landmarks_datasource = LambdaNode(
            lambda junior, senior: senior if bool_bernoulli(0.5) else junior,
            parents=[junior_landmarks_datasource, senior_landmarks_datasource],
            name='random_landmarks_datasource')

        datasources = {
            'junior_landmarks_datasource': junior_landmarks_datasource,
            'senior_landmarks_datasource': senior_landmarks_datasource,
            'challenge_landmarks_datasource': challenge_landmarks_datasource,
            'mean_landmarks_datasource': mean_landmarks_datasource,
            'random_landmarks_datasource': random_landmarks_datasource,
        }
        # every valid source name maps to the '<name>_landmarks_datasource' entry
        datasources['landmark_datasource'] = datasources[self.landmark_source + '_landmarks_datasource']
        return datasources
Example #5
0
 def spatial_transformation_volumetric_augmented(self, image):
     """
     The spatial image transformation with random augmentation for 3D volumes whose
     third axis holds self.num_frames frames (the third axis is excluded from aspect-ratio
     fitting, random uniform scaling and deformation).

     If scale_factor is 1.0 in both in-plane axes, the random translation range is
     0.35 * image_size per in-plane axis. Otherwise it is image_size * (1 - scale_factor) * 0.5
     and a fixed scaling by scale_factor (1.0 on the third axis) is appended.
     :param image: The image node used as input of the composite transformation.
     :return: LambdaNode that yields the composite as a sitk.DisplacementFieldTransform.
     """
     dim = 3
     steps = [translation.InputCenterToOrigin(dim),
              scale.FitFixedAr(dim, self.image_size + [None], ignore_dim=[2])]
     size_x, size_y = self.image_size[0], self.image_size[1]
     if self.scale_factor[0] == 1.0 and self.scale_factor[1] == 1.0:
         # no fixed rescaling: random in-plane shift range of 0.35 * output size
         steps.append(translation.Random(dim, [size_x * 0.35, size_y * 0.35, 0.0]))
     else:
         # shift range is half of the margin (1 - scale_factor) in each in-plane axis
         margin_x, margin_y = ((1.0 - self.scale_factor[axis]) * 0.5 for axis in (0, 1))
         steps.append(translation.Random(dim, [size_x * margin_x, size_y * margin_y, 0]))
         steps.append(scale.Fixed(dim, self.scale_factor + [1.0]))

     def output_volume_size():
         # fresh list on every call: in-plane output size followed by the frame count
         return self.image_size + [self.num_frames]

     steps += [flip.Random(dim, [0.5, 0.5, 0.0]),
               rotation.Random(dim, [0., 0., math.pi]),
               scale.RandomUniform(dim, 0.25, ignore_dim=[2]),
               scale.Random(dim, [0.25, 0.25, 0.0]),
               translation.OriginToOutputCenter(dim, output_volume_size()),
               deformation.Output(dim, [6, 6, 4], [10, 10, 0], output_volume_size())]
     composite_node = composite.Composite(dim, steps,
                                          name='image_transformation_comp',
                                          kwparents={'image': image})

     # parameter name must match the kwparents key below
     def to_displacement_field(comp):
         field = sitk.TransformToDisplacementField(comp, sitk.sitkVectorFloat64,
                                                   size=output_volume_size())
         return sitk.DisplacementFieldTransform(field)

     return LambdaNode(to_displacement_field,
                       name='image_transformation',
                       kwparents={'comp': composite_node})
Example #6
0
 def spatial_transformation_augmented(self, image):
     """
     The spatial image transformation with random augmentation (2D).

     The image is centered and fitted to the output aspect ratio. If scale_factor is 1.0
     in both axes, the random translation range is 0.35 * image_size per axis; otherwise it
     is image_size * (1 - scale_factor) * 0.5 and a fixed scaling by scale_factor follows.
     Random flip, rotation, scaling, the move to the output center and a random
     deformation complete the composite.
     :param image: The image node used as input of the composite transformation.
     :return: LambdaNode that yields the composite as a sitk.DisplacementFieldTransform.
     """
     dim = self.dim
     # bring image to center and fit to the output aspect ratio
     steps = [translation.InputCenterToOrigin(dim),
              scale.FitFixedAr(dim, self.image_size)]
     unscaled = self.scale_factor[0] == 1.0 and self.scale_factor[1] == 1.0
     if unscaled:
         shift_range = [self.image_size[axis] * 0.35 for axis in (0, 1)]
         steps.append(translation.Random(dim, shift_range))
     else:
         # shift range is half of the margin (1 - scale_factor) in each axis
         shift_range = [self.image_size[axis] * ((1.0 - self.scale_factor[axis]) * 0.5) for axis in (0, 1)]
         steps.append(translation.Random(dim, shift_range))
         steps.append(scale.Fixed(dim, self.scale_factor))
     steps += [flip.Random(dim, [0.5, 0.5]),
               rotation.Random(dim, [math.pi]),
               scale.RandomUniform(dim, 0.25),
               scale.Random(dim, [0.25, 0.25]),
               translation.OriginToOutputCenter(dim, self.image_size),
               deformation.Output(dim, [8, 8], 10, self.image_size)]
     composite_node = composite.Composite(dim, steps,
                                          name='image_transformation_comp',
                                          kwparents={'image': image})

     # parameter name must match the kwparents key below
     def to_displacement_field(comp):
         field = sitk.TransformToDisplacementField(comp, sitk.sitkVectorFloat64, size=self.image_size)
         return sitk.DisplacementFieldTransform(field)

     return LambdaNode(to_displacement_field,
                       name='image_transformation',
                       kwparents={'comp': composite_node})
 def datasources(self, iterator, cached):
     """
     Returns the data sources that load data.
     {
     'image': CachedImageDataSource / ImageDataSource that loads the image files (int16).
     'landmark_mask': if self.generate_landmark_mask, LambdaNode applying
                      self.landmark_mask_preprocessing to the loaded image.
     'labels': if self.generate_labels or self.generate_single_vertebrae, image data source
               that loads the '_seg' groundtruth labels (uint8).
     'landmarks': LandmarkDataSource that loads the landmark coordinates from
                  self.landmarks_file, if any landmark/heatmap output needs them or if
                  translate_to_center_landmarks is set without load_spine_landmarks.
     'spine_landmarks': if self.load_spine_landmarks, LandmarkDataSource with a single
                        landmark loaded from self.spine_landmarks_file.
     }
     :param iterator: The dataset iterator.
     :param cached: If true, use CachedImageDataSource, else ImageDataSource.
     :return: A dict of data sources.
     """
     datasources_dict = {}
     image_data_source = CachedImageDataSource if cached else ImageDataSource
     datasources_dict['image'] = image_data_source(
         self.image_base_folder,
         '',
         '',
         '.nii.gz',
         set_zero_origin=False,
         set_identity_direction=False,
         set_identity_spacing=False,
         sitk_pixel_type=sitk.sitkInt16,
         preprocessing=self.preprocessing,
         name='image',
         parents=[iterator])
     if self.generate_landmark_mask:
         # own node name; it was 'image', colliding with the image data source above
         datasources_dict['landmark_mask'] = LambdaNode(
             self.landmark_mask_preprocessing,
             name='landmark_mask',
             parents=[datasources_dict['image']])
     if self.generate_labels or self.generate_single_vertebrae:
         datasources_dict['labels'] = image_data_source(
             self.image_base_folder,
             '',
             '_seg',
             '.nii.gz',
             set_zero_origin=False,
             set_identity_direction=False,
             set_identity_spacing=False,
             sitk_pixel_type=sitk.sitkUInt8,
             name='labels',
             parents=[iterator])
     needs_landmarks = (self.generate_landmarks
                        or self.generate_heatmaps
                        or self.generate_spine_heatmap
                        or self.generate_single_vertebrae
                        or self.generate_single_vertebrae_heatmap
                        or (self.translate_to_center_landmarks and not self.load_spine_landmarks))
     if needs_landmarks:
         datasources_dict['landmarks'] = LandmarkDataSource(
             self.landmarks_file,
             self.num_landmarks,
             self.dim,
             name='landmarks',
             parents=[iterator])
     if self.load_spine_landmarks:
         datasources_dict['spine_landmarks'] = LandmarkDataSource(
             self.spine_landmarks_file,
             1,
             self.dim,
             name='spine_landmarks',
             parents=[iterator])
     return datasources_dict
 def spatial_transformationx(self, iterator, datasources, image_size):
     """
     The spatial image transformation without random augmentation.

     With translate_to_center_landmarks, the third axis is centered on the input image and
     the first two axes on the landmark center (spine landmarks if available, otherwise the
     regular landmarks). In single-vertebra modes the landmark selected by the iterator's
     'landmark_id' is centered and shifted by [0, 20, 0]. Otherwise the input center is used.
     Finally the origin is moved to the output center.
     :param iterator: The iterator node; its 'landmark_id' selects the landmark in single-vertebra mode.
     :param datasources: datasources dict.
     :param image_size: Node that provides the output size.
     :return: The composite transformation node.
     """
     dim = self.dim
     inputs = {'image': datasources['image'], 'output_size': image_size}
     if self.translate_to_center_landmarks:
         # not dict.get: 'landmarks' may be absent when spine landmarks are loaded
         inputs['landmarks'] = (datasources['spine_landmarks']
                                if 'spine_landmarks' in datasources
                                else datasources['landmarks'])
         steps = [translation.InputCenterToOrigin(dim, used_dimensions=[False, False, True]),
                  landmark.Center(dim, True, used_dimensions=[True, True, False])]
     elif self.generate_single_vertebrae or self.generate_single_vertebrae_heatmap:
         def select_landmark(id_dict, landmarks):
             return [landmarks[int(id_dict['landmark_id'])]]

         inputs['landmarks'] = LambdaNode(select_landmark,
                                          parents=[iterator, datasources['landmarks']])
         steps = [landmark.Center(dim, True),
                  translation.Fixed(dim, [0, 20, 0])]
     else:
         steps = [translation.InputCenterToOrigin(dim)]
     steps.append(translation.OriginToOutputCenter(dim, None, self.image_spacing))
     return composite.Composite(dim, steps, name='image', kwparents=inputs)
Example #9
0
    def dataset_val(self):
        """
        Returns the validation dataset for videos. No random augmentation is performed.

        Each iterator entry is expanded into a list of id dicts by
        VideoFrameList.get_id_dict_list(video_id, frame_id), using a VideoFrameList with
        border_mode 'valid', no random start and no random frame skipping.
        :return: The validation dataset.
        """
        dim = 3
        full_video_frame_list_image = VideoFrameList(
            self.video_frame_list_file_name,
            self.num_frames - 1,
            0,
            border_mode='valid',
            random_start=False,
            random_skip_probability=0.0)
        iterator = 'image_ids'
        iterator_postprocessing = LambdaNode(
            lambda x: full_video_frame_list_image.get_id_dict_list(
                x['video_id'], x['frame_id']),
            parents=[iterator])

        sources = self.datasources(iterator_postprocessing)
        image_key = 'image'
        image_transformation = self.spatial_transformation_volumetric(
            sources[image_key])
        generators = self.data_generators(dim, sources, image_transformation,
                                          None)
        final_generators = self.all_generators_post_processing(
            generators, False)

        # validation debug images go to their own folder instead of overwriting 'debug_train'
        return GraphDataset(data_generators=list(final_generators.values()),
                            data_sources=list(sources.values()),
                            transformations=[image_transformation],
                            iterator=iterator,
                            debug_image_folder='debug_val'
                            if self.save_debug_images else None)
Example #10
0
    def dataset_train(self):
        """
        Returns the training dataset. Random augmentation is performed.

        The iterator is built from self.train_file with True as second argument. With
        use_variable_image_size the output size comes from an ImageSizeGenerator fed with the
        landmark bounding-box extent, optionally post-processed by
        self.crop_randomly_smaller_image_size; otherwise the fixed self.image_size is used.
        :return: The training dataset.
        """
        iterator = self.iterator(self.train_file, True)
        sources = self.datasources(iterator, False, False, self.preprocessing_random, 8192)

        if not self.use_variable_image_size:
            image_size = LambdaNode(lambda: self.image_size, name='output_size')
        else:
            allowed_sizes = [self.valid_output_sizes_x,
                             self.valid_output_sizes_y,
                             self.valid_output_sizes_z]
            image_size = ImageSizeGenerator(self.dim,
                                            [None] * 3,
                                            self.image_spacing,
                                            valid_output_sizes=allowed_sizes,
                                            name='output_size',
                                            kwparents={'extent': sources['landmarks_bb_extent']})
            if self.crop_randomly_smaller:
                image_size = LambdaNode(self.crop_randomly_smaller_image_size,
                                        name='output_size',
                                        parents=[image_size])

        reference_transformation = self.spatial_transformation_augmented(iterator, sources, image_size)
        generators = self.data_generators(iterator, sources, reference_transformation,
                                          self.postprocessing_random, True, image_size,
                                          self.crop_image_top_bottom)
        generators['image_id'] = LambdaNode(lambda id_dict: np.array(id_dict['image_id']),
                                            name='image_id',
                                            parents=[iterator])

        return GraphDataset(data_generators=list(generators.values()),
                            data_sources=list(sources.values()),
                            transformations=[reference_transformation],
                            iterator=iterator,
                            debug_image_folder='debug_train' if self.save_debug_images else None)
    def data_generators(self, iterator, datasources, transformation,
                        image_post_processing,
                        random_translation_single_landmark, image_size):
        """
        Returns the data generators that process one input. See datasources() for dict values.

        Entries (all but 'image' depend on configuration flags):
            'image': linear interpolation, post-processed with image_post_processing.
            'landmark_mask': if generate_landmark_mask; nearest neighbor, default value 0.
            'labels': if generate_labels or generate_single_vertebrae; nearest neighbor,
                      post-processed with self.split_labels.
            'heatmaps': if generate_heatmaps or generate_spine_heatmap.
            'landmarks': if generate_landmarks.
            'single_heatmap': if generate_single_vertebrae_heatmap; heatmap of the landmark
                              selected by the iterator's 'landmark_id'.
            'single_label': if generate_single_vertebrae; channel landmark_id + 1 of 'labels'.
            'spine_heatmap': if generate_spine_heatmap; channel sum of 'heatmaps' filtered
                             with gaussian(sigma=self.spine_heatmap_sigma).
        :param iterator: The iterator node providing the id dict.
        :param datasources: datasources dict.
        :param transformation: transformation.
        :param image_post_processing: The np postprocessing function for the image data generator.
        :param random_translation_single_landmark: If true, the single landmark is shifted by
               float_uniform(-r, r, [dim]) with r = self.random_translation_single_landmark.
        :param image_size: The output size given to every generator.
        :return: A dict of data generators.
        """
        # positional (dim, output size, spacing) shared by all generators
        geometry = (self.dim, image_size, self.image_spacing)
        generators_dict = {}

        generators_dict['image'] = ImageGenerator(
            *geometry,
            interpolator='linear',
            post_processing_np=image_post_processing,
            data_format=self.data_format,
            resample_default_pixel_value=self.image_default_pixel_value,
            name='image',
            parents=[datasources['image'], transformation])

        if self.generate_landmark_mask:
            generators_dict['landmark_mask'] = ImageGenerator(
                *geometry,
                interpolator='nearest',
                data_format=self.data_format,
                resample_default_pixel_value=0,
                name='landmark_mask',
                parents=[datasources['landmark_mask'], transformation])

        if self.generate_labels or self.generate_single_vertebrae:
            generators_dict['labels'] = ImageGenerator(
                *geometry,
                interpolator='nearest',
                post_processing_np=self.split_labels,
                data_format=self.data_format,
                name='labels',
                parents=[datasources['labels'], transformation])

        if self.generate_heatmaps or self.generate_spine_heatmap:
            generators_dict['heatmaps'] = LandmarkGeneratorHeatmap(
                *geometry,
                sigma=self.heatmap_sigma,
                scale_factor=1.0,
                normalize_center=True,
                data_format=self.data_format,
                name='heatmaps',
                parents=[datasources['landmarks'], transformation])

        if self.generate_landmarks:
            generators_dict['landmarks'] = LandmarkGenerator(
                *geometry,
                data_format=self.data_format,
                name='landmarks',
                parents=[datasources['landmarks'], transformation])

        if self.generate_single_vertebrae_heatmap:
            def select_landmark(id_dict, landmarks):
                index = int(id_dict['landmark_id'])
                return landmarks[index:index + 1]

            single_landmark = LambdaNode(select_landmark,
                                         name='single_landmark',
                                         parents=[iterator, datasources['landmarks']])
            if random_translation_single_landmark:
                def jitter_landmark(selected):
                    coords = selected[0].coords
                    offset = float_uniform(-self.random_translation_single_landmark,
                                           self.random_translation_single_landmark,
                                           [self.dim])
                    return [Landmark(coords + offset, True)]

                single_landmark = LambdaNode(jitter_landmark,
                                             name='single_landmark_translation',
                                             parents=[single_landmark])
            generators_dict['single_heatmap'] = LandmarkGeneratorHeatmap(
                *geometry,
                sigma=self.heatmap_sigma,
                scale_factor=1.0,
                normalize_center=True,
                data_format=self.data_format,
                name='single_heatmap',
                parents=[single_landmark, transformation])

        if self.generate_single_vertebrae:
            # the channel axis is chosen once, when the graph is built
            if self.data_format == 'channels_first':
                def select_label(id_dict, images):
                    channel = int(id_dict['landmark_id']) + 1
                    return images[channel:channel + 1, ...]
            else:
                def select_label(id_dict, images):
                    channel = int(id_dict['landmark_id']) + 1
                    return images[..., channel:channel + 1]
            generators_dict['single_label'] = LambdaNode(
                select_label,
                name='single_label',
                parents=[iterator, generators_dict['labels']])

        if self.generate_spine_heatmap:
            def spine_heatmap(images):
                channel_axis = 0 if self.data_format == 'channels_first' else -1
                summed = np.sum(images, axis=channel_axis, keepdims=True)
                return gaussian(summed, sigma=self.spine_heatmap_sigma)

            generators_dict['spine_heatmap'] = LambdaNode(
                spine_heatmap,
                name='spine_heatmap',
                parents=[generators_dict['heatmaps']])

        return generators_dict
    def data_generators(self, iterator, datasources, transformation, image_post_processing, random_translation_single_landmark, image_size, crop=False):
        """
        Returns the data generators that process one input. See datasources() for dict values.
        :param datasources: datasources dict.
        :param transformation: transformation.
        :param image_post_processing: The np postprocessing function for the image data generator.
        :return: A dict of data generators.
        """
        generators_dict = {}
        kwparents = {'output_size': image_size}
        image_datasource = datasources['image'] if not crop else LambdaNode(self.landmark_based_crop, name='image_cropped', kwparents={'image': datasources['image'], 'landmarks': datasources['landmarks']})
        generators_dict['image'] = ImageGenerator(self.dim,
                                                  None,
                                                  self.image_spacing,
                                                  interpolator='linear',
                                                  post_processing_np=image_post_processing,
                                                  data_format=self.data_format,
                                                  resample_default_pixel_value=self.image_default_pixel_value,
                                                  np_pixel_type=self.output_image_type,
                                                  name='image',
                                                  parents=[image_datasource, transformation],
                                                  kwparents=kwparents)
        # generators_dict['image'] = ImageGenerator(self.dim,
        #                                           None,
        #                                           self.image_spacing,
        #                                           interpolator='linear',
        #                                           post_processing_np=image_post_processing,
        #                                           data_format=self.data_format,
        #                                           resample_default_pixel_value=self.image_default_pixel_value,
        #                                           np_pixel_type=self.output_image_type,
        #                                           name='image_cropped',
        #                                           parents=[LambdaNode(self.landmark_based_crop, name='image_cropped', kwparents={'image': datasources['image'], 'landmarks': datasources['landmarks']}), transformation],
        #                                           kwparents=kwparents)
        if self.generate_landmark_mask:
            generators_dict['landmark_mask'] = ImageGenerator(self.dim,
                                                              None,
                                                              self.image_spacing,
                                                              interpolator='nearest',
                                                              data_format=self.data_format,
                                                              resample_default_pixel_value=0,
                                                              name='landmark_mask',
                                                              parents=[datasources['landmark_mask'], transformation],
                                                              kwparents=kwparents)
        if self.generate_labels:
            generators_dict['labels'] = ImageGenerator(self.dim,
                                                       None,
                                                       self.image_spacing,
                                                       interpolator='nearest',
                                                       post_processing_np=self.split_labels,
                                                       data_format=self.data_format,
                                                       name='labels',
                                                       parents=[datasources['labels'], transformation],
                                                       kwparents=kwparents)
        if self.generate_heatmaps or self.generate_spine_heatmap:
            generators_dict['heatmaps'] = LandmarkGeneratorHeatmap(self.dim,
                                                                   None,
                                                                   self.image_spacing,
                                                                   sigma=self.heatmap_sigma,
                                                                   scale_factor=1.0,
                                                                   normalize_center=True,
                                                                   data_format=self.data_format,
                                                                   name='heatmaps',
                                                                   parents=[datasources['landmarks'], transformation],
                                                                   kwparents=kwparents)
        if self.generate_landmarks:
            generators_dict['landmarks'] = LandmarkGenerator(self.dim,
                                                             None,
                                                             self.image_spacing,
                                                             data_format=self.data_format,
                                                             name='landmarks',
                                                             parents=[datasources['landmarks'], transformation],
                                                             kwparents=kwparents)
        if self.generate_single_vertebrae_heatmap:
            single_landmark = LambdaNode(lambda id_dict, landmarks: landmarks[int(id_dict['landmark_id']):int(id_dict['landmark_id']) + 1],
                                         name='single_landmark',
                                         parents=[iterator, datasources['landmarks']])
            if random_translation_single_landmark:
                single_landmark = LambdaNode(lambda l: [Landmark(l[0].coords + float_uniform(-self.random_translation_single_landmark, self.random_translation_single_landmark, [self.dim]), True)],
                                             name='single_landmark_translation',
                                             parents=[single_landmark])
            generators_dict['single_heatmap'] = LandmarkGeneratorHeatmap(self.dim,
                                                                         None,
                                                                         self.image_spacing,
                                                                         sigma=self.single_heatmap_sigma,
                                                                         scale_factor=1.0,
                                                                         normalize_center=True,
                                                                         data_format=self.data_format,
                                                                         np_pixel_type=self.output_image_type,
                                                                         name='single_heatmap',
                                                                         parents=[single_landmark, transformation],
                                                                         kwparents=kwparents)
        if self.generate_single_vertebrae:
            if self.generate_labels:
                if self.data_format == 'channels_first':
                    generators_dict['single_label'] = LambdaNode(lambda id_dict, images: images[int(id_dict['landmark_id']) + 1:int(id_dict['landmark_id']) + 2, ...],
                                                                 name='single_label',
                                                                 parents=[iterator, generators_dict['labels']])
                else:
                    generators_dict['single_label'] = LambdaNode(lambda id_dict, images: images[..., int(id_dict['landmark_id']) + 1:int(id_dict['landmark_id']) + 2],
                                                                 name='single_label',
                                                                 parents=[iterator, generators_dict['labels']])
            else:
                labels_unsmoothed = ImageGenerator(self.dim,
                                                   None,
                                                   self.image_spacing,
                                                   interpolator='nearest',
                                                   post_processing_np=None,
                                                   data_format=self.data_format,
                                                   name='labels_unsmoothed',
                                                   parents=[datasources['labels'], transformation],
                                                   kwparents=kwparents)
                generators_dict['single_label'] = LambdaNode(lambda id_dict, labels: self.split_and_smooth_single_label(labels, int(id_dict['landmark_id'])),
                                                             name='single_label',
                                                             parents=[iterator, labels_unsmoothed])
        if self.generate_spine_heatmap:
            generators_dict['spine_heatmap'] = LambdaNode(lambda images: normalize(gaussian(np.sum(images, axis=0 if self.data_format == 'channels_first' else -1, keepdims=True), sigma=self.spine_heatmap_sigma), out_range=(0, 1)),
                                                          name='spine_heatmap',
                                                          parents=[generators_dict['heatmaps']])

        return generators_dict
 def datasources(self, iterator, image_cached, labels_cached, image_preprocessing, cache_size):
     """
     Returns the data sources that load data.
     {
     'image': CachedImageDataSource or ImageDataSource that loads the '.nii.gz' image files.
     'landmark_mask': (if generate_landmark_mask) LambdaNode applying landmark_mask_preprocessing to the image.
     'labels': (if generate_labels or generate_single_vertebrae) CachedImageDataSource or ImageDataSource
               that loads the '_seg' '.nii.gz' groundtruth label files.
     'landmarks': (if any landmark-based output is generated, or translate_to_center_landmarks is set without
                  spine landmarks/bbs) LandmarkDataSource that loads the landmark coordinates.
     'landmarks_bb': LambdaNode computing a bounding box from the image and either the landmarks or, if
                     load_spine_bbs is set, the spine bounding box (the spine bounding box takes precedence).
     'landmarks_bb_start' / 'landmarks_bb_extent': element 0 / element 1 of 'landmarks_bb'.
     'spine_landmarks': (if load_spine_landmarks) LandmarkDataSource loading one landmark from spine_landmarks_file.
     'spine_bb': (if load_spine_bbs) LabelDatasource loading spine_bbs_file.
     }
     :param iterator: The dataset iterator.
     :param image_cached: If true, load images with CachedImageDataSource, else with ImageDataSource.
     :param labels_cached: If true, load labels with CachedImageDataSource, else with ImageDataSource.
     :param image_preprocessing: Preprocessing passed to the image data source (not applied to labels).
     :param cache_size: Passed as cache_maxsize to every CachedImageDataSource.
     :return: A dict of data sources.
     """
     datasources_dict = {}

     def image_source_type(cached):
         # Returns the image data source class and its extra kwargs; only the cached
         # variant takes a cache size.
         if cached:
             return CachedImageDataSource, {'cache_maxsize': cache_size}
         return ImageDataSource, {}

     def add_bounding_box(bb_function, bb_parent):
         # Registers 'landmarks_bb' (bb_function applied to the image and bb_parent) and
         # accessor nodes for its start (element 0) and extent (element 1).
         # A later call overwrites the entries of an earlier one.
         bb_node = LambdaNode(bb_function, name='landmarks_bb', parents=[datasources_dict['image'], bb_parent])
         datasources_dict['landmarks_bb'] = bb_node
         datasources_dict['landmarks_bb_start'] = LambdaNode(lambda x: x[0], name='landmarks_bb_start', parents=[bb_node])
         datasources_dict['landmarks_bb_extent'] = LambdaNode(lambda x: x[1], name='landmarks_bb_extent', parents=[bb_node])

     image_data_source, image_source_kwargs = image_source_type(image_cached)
     datasources_dict['image'] = image_data_source(self.image_base_folder,
                                                   '',
                                                   '',
                                                   '.nii.gz',
                                                   set_zero_origin=False,
                                                   set_identity_direction=False,
                                                   set_identity_spacing=False,
                                                   sitk_pixel_type=sitk.sitkInt16,
                                                   preprocessing=image_preprocessing,
                                                   name='image',
                                                   parents=[iterator],
                                                   **image_source_kwargs)
     if self.generate_landmark_mask:
         datasources_dict['landmark_mask'] = LambdaNode(self.landmark_mask_preprocessing,
                                                        name='landmark_mask',
                                                        parents=[datasources_dict['image']])
     if self.generate_labels or self.generate_single_vertebrae:
         labels_data_source, labels_source_kwargs = image_source_type(labels_cached)
         datasources_dict['labels'] = labels_data_source(self.image_base_folder,
                                                         '',
                                                         '_seg',
                                                         '.nii.gz',
                                                         set_zero_origin=False,
                                                         set_identity_direction=False,
                                                         set_identity_spacing=False,
                                                         sitk_pixel_type=sitk.sitkUInt8,
                                                         name='labels',
                                                         parents=[iterator],
                                                         **labels_source_kwargs)
     # Landmarks are needed by every landmark-based output, and for centering on the landmark
     # bounding box when no spine landmarks/bbs are loaded to provide it instead.
     needs_landmarks = (self.generate_landmarks
                        or self.generate_heatmaps
                        or self.generate_spine_heatmap
                        or self.generate_single_vertebrae
                        or self.generate_single_vertebrae_heatmap
                        or (self.translate_to_center_landmarks and not (self.load_spine_landmarks or self.load_spine_bbs)))
     if needs_landmarks:
         datasources_dict['landmarks'] = LandmarkDataSource(self.landmarks_file,
                                                            self.num_landmarks,
                                                            self.dim,
                                                            name='landmarks',
                                                            parents=[iterator])
         add_bounding_box(self.image_landmark_bounding_box, datasources_dict['landmarks'])
     if self.load_spine_landmarks:
         datasources_dict['spine_landmarks'] = LandmarkDataSource(self.spine_landmarks_file, 1, self.dim, name='spine_landmarks', parents=[iterator])
     if self.load_spine_bbs:
         # Named after its dict key; it previously reused the name 'spine_landmarks'.
         datasources_dict['spine_bb'] = LabelDatasource(self.spine_bbs_file, name='spine_bb', parents=[iterator])
         # Runs after the landmark branch, so the spine bounding box overrides the landmark-based one.
         add_bounding_box(self.image_bounding_box, datasources_dict['spine_bb'])
     return datasources_dict