def load_idl(file_name, num_landmarks, dim):
    """
    Loads landmarks from a text file in which every line holds an id in double quotes, followed by
    integer coordinate triples written as (x,y,z).
    :param file_name: The path of the file to read.
    :param num_landmarks: The number of coordinate triples that every line must contain.
    :param dim: The number of leading coordinates of each triple to keep. Must be 2 or 3.
    :return: A dict mapping each id to its list of Landmark objects (coordinates as np.float32 arrays).
    :raises ValueError: If dim is neither 2 nor 3, or if a non-blank line does not contain a quoted id.
    """
    if dim not in (2, 3):
        # an unsupported dim previously left the landmark list unbound (or stale from the previous line)
        raise ValueError('dim must be 2 or 3, got {}'.format(dim))
    num_coords = int(dim)
    # raw strings, such that \( and \d reach the regex engine instead of being parsed as string escapes
    id_pattern = re.compile(r'"(.*)"')
    coords_pattern = re.compile(r'\((\d*),(\d*),(\d*)\)')

    landmarks_dict = {}
    with open(file_name, 'r') as file:
        for line in file:
            if not line.strip():
                # blank lines (e.g., a trailing newline) carry no entry
                continue
            id_match = id_pattern.search(line)
            if id_match is None:
                raise ValueError('line does not contain a quoted id: {!r}'.format(line))
            coords_matches = coords_pattern.findall(line)
            assert num_landmarks == len(
                coords_matches
            ), 'number of row entries and landmark coordinates do not match'
            # the pattern always captures three values; only the first num_coords are used
            landmarks_dict[id_match.group(1)] = [
                Landmark(
                    np.array([float(value) for value in coords_match[:num_coords]],
                             np.float32))
                for coords_match in coords_matches
            ]
    return landmarks_dict
Пример #2
0
 def reshift_landmarks(self, curr_landmarks):
     """
     Re-arranges the landmark list when certain validity patterns at fixed list indices are found.
     In every case, one invalid placeholder Landmark is inserted and one entry is dropped, such that
     the first 26 entries are shifted by one position within a sub-range of the list.
     The checks are performed in sequence; the second check sees the result of the first one.
     :param curr_landmarks: The list of landmarks. Entries up to index 25 are accessed.
     :return: The re-arranged list, or the unchanged input list if no pattern matched.
     """
     lms = curr_landmarks

     # pattern 1: entries 0 and 6 are invalid, while entries 5 and 7 are valid
     # NOTE(review): the slices below drop entry 5 (valid) and keep entry 6 (invalid), in contrast to
     # the analogous 'up' shift further below that drops the invalid entry 18 - confirm this is intended.
     shift_first_block = (not lms[0].is_valid and lms[7].is_valid
                          and not lms[6].is_valid and lms[5].is_valid)
     if shift_first_block:
         print('shift c indizes up')
         lms = [Landmark([np.nan] * 3, is_valid=False)] + lms[:5] + lms[6:26]

     # pattern 2: entry 7 is invalid, while entry 19 is valid
     if not lms[7].is_valid and lms[19].is_valid:
         gap_at_18 = not lms[18].is_valid and lms[17].is_valid
         if gap_at_18:
             # move entries 7-17 one position up; the invalid entry 18 is dropped
             print('shift t indizes up')
             lms = (lms[:7] + [Landmark([np.nan] * 3, is_valid=False)] +
                    lms[7:18] + lms[19:26])
         elif lms[25].is_valid:
             # move entries 8-18 one position down, put entry 25 behind them, and invalidate index 25
             print('shift t indizes down')
             lms = (lms[:7] + lms[8:19] + [lms[25]] + lms[19:25] +
                    [Landmark([np.nan] * 3, is_valid=False)])
     return lms
Пример #3
0
def load_multi_csv(file_name, num_landmarks, dim=2):
    """
    Loads landmarks of possibly several persons per image from a csv file.
    Every row is expected to hold the image name (column 0), the person id (column 1), and then
    dim coordinate columns for each of the num_landmarks landmarks.
    A landmark whose first coordinate is NaN is stored as an invalid Landmark without coordinates;
    only the first coordinate is checked.
    :param file_name: The path of the csv file.
    :param num_landmarks: The number of landmarks per row.
    :param dim: The number of coordinates per landmark (any positive integer).
    :return: A dict of dicts: image name -> person id -> list of Landmark objects.
    :raises ValueError: If dim is smaller than 1.
    """
    if dim < 1:
        raise ValueError('dim must be a positive integer, got {}'.format(dim))
    landmarks_dict = {}
    # newline='' is the file mode recommended by the csv module
    with open(file_name, 'r', newline='') as csv_file:
        for row in csv.reader(csv_file):
            if not row:
                # csv.reader yields an empty list for blank lines
                continue
            landmarks = []
            for i in range(2, dim * num_landmarks + 2, dim):
                if np.isnan(float(row[i])):
                    landmarks.append(Landmark(None, False))
                else:
                    # explicit indexing (instead of slicing) keeps the IndexError for truncated rows
                    coords = np.array(
                        [float(row[i + offset]) for offset in range(dim)],
                        np.float32)
                    landmarks.append(Landmark(coords))
            landmarks_dict.setdefault(row[0], {})[row[1]] = landmarks
    return landmarks_dict
def load_csv(file_name, num_landmarks, dim):
    """
    Loads landmarks from a csv file with one row per id.
    Every row is expected to hold the id (column 0), followed by dim coordinate columns for each of
    the num_landmarks landmarks. A landmark whose first coordinate is NaN is stored as an invalid
    Landmark without coordinates; only the first coordinate is checked.
    :param file_name: The path of the csv file.
    :param num_landmarks: The number of landmarks per row.
    :param dim: The number of coordinates per landmark (any positive integer).
    :return: A dict mapping each id to its list of Landmark objects.
    :raises ValueError: If dim is smaller than 1.
    """
    if dim < 1:
        raise ValueError('dim must be a positive integer, got {}'.format(dim))
    num_entries = dim * num_landmarks + 1
    landmarks_dict = {}
    # newline='' is the file mode recommended by the csv module
    with open(file_name, 'r', newline='') as file:
        for row in csv.reader(file):
            if not row:
                # csv.reader yields an empty list for blank lines
                continue
            assert num_entries == len(
                row
            ), 'number of row entries and landmark coordinates do not match'
            landmarks = []
            for i in range(1, num_entries, dim):
                if np.isnan(float(row[i])):
                    landmarks.append(Landmark(None, False))
                else:
                    coords = np.array(
                        [float(row[i + offset]) for offset in range(dim)],
                        np.float32)
                    landmarks.append(Landmark(coords))
            landmarks_dict[row[0]] = landmarks
    return landmarks_dict
Пример #5
0
    def get_landmark(self,
                     image,
                     transformation=None,
                     reference_sitk=None,
                     output_spacing=None):
        """
        Returns the landmark(s) for the given single-channel image. The coordinates are taken from the
        maximum (or the multiple maxima) of the image, possibly transformed with the transformation parameter.
        The coordinate order returned by the maximum search is reversed with np.flip before further use.
        :param image: The np array with a single channel.
        :param transformation: The transformation. If transformation is None (and self.return_multiple_maxima is
                               not set), the plain maximum is returned without applying spacing or transformation.
        :param reference_sitk: The reference sitk image. Used if self.invert_transformation is set.
        :param output_spacing: The output spacing of the np array. If None or empty, a spacing of 1 is used for every dimension.
        :return: A list of Landmark objects, if self.return_multiple_maxima is set. A single Landmark object, otherwise.
        """
        # Explicit None/empty check instead of `output_spacing or ...`: the truth value of a
        # multi-element np array is ambiguous and raises a ValueError.
        if output_spacing is None or len(output_spacing) == 0:
            output_spacing = [1] * image.ndim
        if self.return_multiple_maxima:
            landmarks = []
            value_coord_pairs = self.get_multiple_maximum_coordinates(image)
            for value, coord in value_coord_pairs:
                coord = np.flip(coord, axis=0)
                # NOTE(review): in-place scaling assumes that coord has a dtype that can hold the scaled values - confirm
                coord *= output_spacing
                # NOTE(review): transformation is passed on even if it is None - confirm transform_coords supports that
                coord = utils.landmark.transform.transform_coords(
                    coord, transformation)
                landmarks.append(
                    Landmark(coords=coord, is_valid=True, scale=1,
                             value=value))
            return landmarks
        if transformation is not None:
            if self.invert_transformation:
                # transform prediction back to input image resolution, if specified.
                transformed_sitk = self.get_transformed_image_sitk(
                    image,
                    reference_sitk=reference_sitk,
                    output_spacing=output_spacing,
                    transformation=transformation)
                transformed_np = utils.sitk_np.sitk_to_np_no_copy(
                    transformed_sitk[0])
                value, coord = utils.np_image.find_maximum_in_image(
                    transformed_np)
                coord = np.flip(coord, axis=0)
                # convert to float before scaling with the spacing of the reference image
                coord = coord.astype(np.float32)
                coord *= np.array(reference_sitk.GetSpacing())
            else:
                # search for subpixel accurate maximum in image
                value, coord = utils.np_image.find_quadratic_subpixel_maximum_in_image(
                    image)
                coord = np.flip(coord, axis=0)
                coord *= output_spacing
                coord = utils.landmark.transform.transform_coords(
                    coord, transformation)
        else:
            # just take maximum of image (neither spacing nor transformation is applied)
            value, coord = utils.np_image.find_maximum_in_image(image)
            coord = np.flip(coord, axis=0)

        return Landmark(coords=coord, is_valid=True, scale=1, value=value)
Пример #6
0
 def finalize_landmarks(self, final_landmarks, min_idx, max_idx):
     """
     Pads the final landmark sequence with invalid landmarks such that the list contains self.num_landmarks entries, and the list index is the landmark index.
     :param final_landmarks: The final landmark sequence (a list), covering the landmark indices min_idx to max_idx.
     :param min_idx: The first landmark index.
     :param max_idx: The last landmark index.
     :return: A landmark sequence with self.num_landmarks entries, where not found landmarks are set to be invalid.
     """
     # Every padding slot gets its own Landmark instance. `[Landmark(...)] * n` would put the very same
     # object into all slots, such that modifying one padded landmark would modify all of them.
     leading = [
         Landmark(coords=[np.nan] * 3, is_valid=False)
         for _ in range(min_idx)
     ]
     trailing = [
         Landmark(coords=[np.nan] * 3, is_valid=False)
         for _ in range(self.num_landmarks - max_idx - 1)
     ]
     return leading + final_landmarks + trailing
 def project_landmarks(self, landmarks, axis):
     """
     Projects landmarks along an axis by removing the coordinate with the given index.
     Invalid landmarks are replaced by new invalid Landmark objects.
     :param landmarks: The landmarks list.
     :param axis: The index of the coordinate to remove.
     :return: List of projected landmarks, in the same order as the input.
     """
     def drop_axis(landmark):
         # invalid landmarks carry no usable coordinates
         if landmark.is_valid:
             remaining = [
                 coord for index, coord in enumerate(landmark.coords)
                 if index != axis
             ]
             return Landmark(remaining)
         return Landmark(is_valid=False)

     return [drop_axis(landmark) for landmark in landmarks]
Пример #8
0
def load_lml(file_name, num_landmarks, landmark_ids):
    """
    Loads landmarks from a tab-separated text file.
    Lines starting with '#' and blank lines are skipped. Every other line is expected to hold the
    landmark id in column 0 and three coordinates in columns 2 to 4.
    :param file_name: The path of the file to read.
    :param num_landmarks: The length of the returned list.
    :param landmark_ids: Sequence of known landmark ids. The position of an id in this sequence is the
                         index of the landmark in the returned list.
    :return: List of num_landmarks Landmark objects. Entries without a line in the file stay default-constructed.
    """
    landmarks = [Landmark() for _ in range(num_landmarks)]
    with open(file_name, 'r') as file:
        for line in file:
            # blank lines (e.g., a trailing newline) previously crashed with an IndexError
            if line.startswith('#') or not line.strip():
                continue
            tokens = line.split('\t')
            current_id = int(tokens[0])
            try:
                # single lookup instead of a membership test followed by index()
                landmark_index = landmark_ids.index(current_id)
            except ValueError:
                print('Warning: invalid id {} for file {}'.format(current_id, file_name))
                continue
            coords = [float(tokens[2]), float(tokens[3]), float(tokens[4])]
            landmarks[landmark_index] = Landmark(coords)
    return landmarks
Пример #9
0
 def get_landmarks(self, image_id):
     """
     Looks up the list of landmarks stored in self.point_list for a given image_id.
     If the image_id is unknown and self.silent_not_found is set, a list of self.num_points
     default-constructed Landmark objects is returned instead; otherwise the KeyError is re-raised.
     :param image_id: The image_id.
     :return: The list of landmarks.
     """
     try:
         found = self.point_list[image_id]
     except KeyError:
         if not self.silent_not_found:
             raise
         found = [Landmark() for _ in range(self.num_points)]
     return found
Пример #10
0
 def get_landmarks(self, image_id, instance_id):
     """
     Looks up the landmarks stored in self.point_list for a given image_id and instance_id.
     If self.multiple is set, the landmark lists of all instances of the image are returned and
     instance_id is not used. If a key is unknown and self.silent_not_found is set, default-constructed
     Landmark objects are returned instead; otherwise the KeyError is re-raised.
     :param image_id: Used as key for the landmarks file.
     :param instance_id: Used as key for the instance_id.
     :return: List of list of Landmarks(), if multiple is True. List of Landmarks(), otherwise.
     """
     try:
         if self.multiple:
             all_instances = self.point_list[image_id]
             return list(all_instances.values())
         per_image = self.point_list[image_id]
         return per_image[instance_id]
     except KeyError:
         if not self.silent_not_found:
             raise
         if self.multiple:
             # a single instance with placeholder landmarks
             return [[Landmark() for _ in range(self.num_points)]]
         return [Landmark() for _ in range(self.num_points)]
def filter_landmarks_top_bottom(curr_landmarks, input_image, z_distance_top_bottom=10):
    """
    Invalidates landmarks whose third coordinate is too close to the lower or upper end of the image
    along the third axis. The image extent is calculated as spacing * size; the image origin is not
    taken into account.
    :param curr_landmarks: The list of landmarks. The third entry of the coords of each landmark is inspected.
    :param input_image: The image, providing GetSpacing() and GetSize().
    :param z_distance_top_bottom: The margin at both ends of the third axis, in the same units as
                                  the landmark coordinates. Landmarks must lie strictly inside the margins.
    :return: A list of the same length, where landmarks outside the allowed range are replaced by invalid landmarks.
    """
    image_extent = [
        spacing * size for spacing, size in zip(input_image.GetSpacing(),
                                                input_image.GetSize())
    ]
    filtered_landmarks = []
    for landmark in curr_landmarks:
        if z_distance_top_bottom < landmark.coords[
                2] < image_extent[2] - z_distance_top_bottom:
            filtered_landmarks.append(landmark)
        else:
            # also reached for landmarks with np.nan coordinates, as comparisons with nan are False
            filtered_landmarks.append(
                Landmark(coords=[np.nan] * 3, is_valid=False))
    return filtered_landmarks
Пример #12
0
 def path_to_landmarks(self, path, local_heatmap_maxima):
     """
     Builds a list of self.num_landmarks landmarks from a path. Every list entry starts as a Landmark with
     np.nan coordinates and is_valid = False. For every node of the path except 's' and 't', the entry at
     the landmark index of the node is replaced by the referenced local heatmap maximum.
     :param path: The path, a sequence of vertex names.
     :param local_heatmap_maxima: The local heatmap maxima, indexed by landmark index and maxima index.
     :return: List of landmarks.
     """
     result = []
     for _ in range(self.num_landmarks):
         result.append(Landmark(coords=[np.nan] * 3, is_valid=False))

     # 's' and 't' do not refer to a landmark
     inner_nodes = (node for node in path if node != 's' and node != 't')
     for node in inner_nodes:
         landmark_index, maxima_index = self.vertex_name_to_indizes(node)
         candidates = local_heatmap_maxima[landmark_index]
         result[landmark_index] = candidates[maxima_index]
     return result
    def test(self):
        """
        Runs inference on all entries of the validation dataset and saves the predicted landmarks.
        For every entry, the prediction is transformed back to the input image, and the physical position
        of the center of mass of its first channel is stored as the single landmark of the image.
        Input and prediction images are written in addition, if self.save_output_images is set.
        Finally, the validation loss aggregator is finalized, if groundtruth is available.
        """
        print('Testing...')

        # index of the channel axis in the np arrays
        channel_axis = 3 if self.data_format == 'channels_last' else 0

        predicted_landmarks = {}
        entry_count = self.dataset_val.num_entries()
        for entry_index in range(entry_count):
            entry = self.dataset_val.get_next()
            image_id = entry['id']['image_id']
            input_image_sitk = entry['datasources']['image']

            image, prediction, transformation = self.test_full_image(entry)
            predictions_sitk = utils.sitk_image.transform_np_output_to_sitk_input(
                output_image=prediction,
                output_spacing=self.image_spacing,
                channel_axis=channel_axis,
                input_image_sitk=input_image_sitk,
                transform=transformation,
                interpolator='linear',
                output_pixel_type=sitk.sitkUInt8)

            if self.save_output_images:
                # the origin of the network input/output is the transformed zero point
                origin = transformation.TransformPoint(np.zeros(3, np.float64))
                heatmap_normalization_mode = (0, 1)
                utils.io.image.write_multichannel_np(
                    image,
                    self.output_file_for_current_iteration(image_id + '_input.mha'),
                    normalization_mode='min_max',
                    split_channel_axis=True,
                    sitk_image_mode='default',
                    data_format=self.data_format,
                    image_type=np.uint8,
                    spacing=self.image_spacing,
                    origin=origin)
                utils.io.image.write_multichannel_np(
                    prediction,
                    self.output_file_for_current_iteration(image_id + '_prediction.mha'),
                    normalization_mode=heatmap_normalization_mode,
                    split_channel_axis=True,
                    data_format=self.data_format,
                    image_type=np.uint8,
                    spacing=self.image_spacing,
                    origin=origin)
                utils.io.image.write(
                    predictions_sitk[0],
                    self.output_file_for_current_iteration(image_id + '_prediction_original.mha'))

            # center of mass in np index order -> reversed index order -> physical point of the input image
            prediction_np = utils.sitk_np.sitk_to_np_no_copy(predictions_sitk[0])
            center_index = list(reversed(utils.np_image.center_of_mass(prediction_np)))
            center_physical = input_image_sitk.TransformContinuousIndexToPhysicalPoint(center_index)
            predicted_landmarks[image_id] = [Landmark(center_physical)]
            print_progress_bar(entry_index, entry_count, prefix='Testing ', suffix=' complete')

        utils.io.landmark.save_points_csv(
            predicted_landmarks,
            self.output_file_for_current_iteration('points.csv'))

        # finalize loss values
        if self.has_validation_groundtruth:
            self.val_loss_aggregator.finalize(self.current_iter)
def add_landmarks_from_neighbors(local_maxima_landmarks):
    """
    Returns a deep copy of the per-landmark candidate lists, in which candidates are additionally
    copied between fixed pairs of list indices.
    For the neighboring index pairs (i, i + 1) with i in 2-5, 8-17, and 20-23, the candidates of i are
    appended to i + 1 and afterwards the (already extended) candidates of i + 1 are appended to i,
    each with the value multiplied by a penalty of 0.1. The pairs are processed in ascending order, so
    previously appended copies are copied on as well. Between index 18 and index 25, candidates are
    exchanged in the same way, but without a penalty.
    :param local_maxima_landmarks: List of lists of landmark candidates. Indices up to 25 are accessed.
    :return: The extended copy. The input is not modified.
    """
    candidates = deepcopy(local_maxima_landmarks)
    duplicate_penalty = 0.1

    def penalized_copies(source):
        return [
            Landmark(coords=entry.coords, value=entry.value * duplicate_penalty)
            for entry in source
        ]

    def plain_copies(source):
        return [
            Landmark(coords=entry.coords, value=entry.value)
            for entry in source
        ]

    def exchange_with_next(lower):
        upper = lower + 1
        candidates[upper].extend(penalized_copies(candidates[lower]))
        candidates[lower].extend(penalized_copies(candidates[upper]))

    for start, stop in ((2, 6), (8, 18)):
        for lower in range(start, stop):
            exchange_with_next(lower)

    # index 18 and index 25 exchange their candidates without penalty
    candidates[25].extend(plain_copies(candidates[18]))
    candidates[18].extend(plain_copies(candidates[25]))

    for lower in range(20, 24):
        exchange_with_next(lower)
    return candidates
 def visualize_offsets_to_reference_projections(self,
                                                image_sitk,
                                                reference_groundtruth,
                                                predicted_per_image_id_list,
                                                groundtruth_per_image_id,
                                                landmark_colors_list,
                                                filename):
     """
     Visualizes prediction offsets relative to reference landmarks onto projections of a given sitk image.
     For every image_id and every prediction dictionary, the offset (prediction - groundtruth) of each
     landmark is added to the corresponding reference landmark, and the resulting points are drawn.
     Finally, the reference landmarks themselves are drawn in black, and the merged canvas is saved.
     :param image_sitk: The sitk image that the image canvases are prepared from.
     :param reference_groundtruth: The list of reference landmarks for the image.
     :param predicted_per_image_id_list: The list of dictionaries (image_id -> landmarks) of predicted landmarks.
     :param groundtruth_per_image_id: The dictionary (image_id -> landmarks) of groundtruth landmarks.
     :param landmark_colors_list: List of landmark colors, one entry per entry of predicted_per_image_id_list. If None, None is passed to visualize_landmarks.
     :param filename: The filename to save the image to.
     """
     canvases = self.prepare_image_canvas_list(image_sitk)
     for image_id, groundtruth in groundtruth_per_image_id.items():
         for list_index, predictions in enumerate(predicted_per_image_id_list):
             if landmark_colors_list is None:
                 colors = None
             else:
                 colors = landmark_colors_list[list_index]
             # move every reference landmark by the prediction error of the current image
             offset_landmarks = []
             for predicted, target, reference in zip(predictions[image_id], groundtruth, reference_groundtruth):
                 offset_landmarks.append(Landmark(predicted.coords - target.coords + reference.coords))
             projections = self.project_landmarks_list(offset_landmarks)
             for canvas, projected in zip(canvases, projections):
                 self.visualize_landmarks(canvas, projected, colors, self.annotations)

     # visualize black dots on original groundtruth
     reference_projections = self.project_landmarks_list(reference_groundtruth)
     for canvas, projected_reference in zip(canvases, reference_projections):
         black = [(0, 0, 0)] * len(projected_reference)
         self.visualize_landmarks(canvas, projected_reference, black, self.annotations)

     merged_canvas = self.merge_image_canvas(canvases)
     self.save(merged_canvas, filename)
    def test(self):
        """
        The test function. Performs inference on all images of self.image_id_list, extracts landmarks from the
        predicted heatmaps with and without graph-based postprocessing, and writes the results to the output folder.
        Errors while processing a single image are printed, and the image is skipped.
        """
        print('Testing...')

        # a model is only loaded here, if exactly one model filename is given
        if len(self.load_model_filenames) == 1:
            self.load_model(self.load_model_filenames[0])

        # mapping from landmark index to the annotation text used in the visualization
        vis = LandmarkVisualizationMatplotlib(
            dim=3,
            annotations=dict([(i, f'C{i + 1}')
                              for i in range(7)] +  # 0-6: C1-C7
                             [(i, f'T{i - 6}')
                              for i in range(7, 19)] +  # 7-18: T1-12
                             [(i, f'L{i - 18}')
                              for i in range(19, 25)] +  # 19-24: L1-6
                             [(25, 'T13')]))  # 25: T13

        # index of the channel axis in the prediction np arrays
        channel_axis = 0
        if self.data_format == 'channels_last':
            channel_axis = 3
        heatmap_maxima = HeatmapTest(channel_axis,
                                     False,
                                     return_multiple_maxima=True,
                                     min_max_value=0.05,
                                     smoothing_sigma=2.0)

        # both pickle files are read relative to the current working directory
        # NOTE(review): pickle.load executes arbitrary code from the file - only use trusted files
        with open('possible_successors.pickle', 'rb') as f:
            possible_successors = pickle.load(f)
        with open('units_distances.pickle', 'rb') as f:
            offsets_mean, distances_mean, distances_std = pickle.load(f)
        spine_postprocessing = SpinePostprocessingGraph(
            num_landmarks=self.num_landmarks,
            possible_successors=possible_successors,
            offsets_mean=offsets_mean,
            distances_mean=distances_mean,
            distances_std=distances_std,
            bias=2.0,
            l=0.2)

        # results per image id, with and without postprocessing
        landmarks = {}
        landmarks_no_postprocessing = {}
        for current_id in tqdm(self.image_id_list, desc='Testing'):
            try:
                dataset_entry = self.dataset_val.get({'image_id': current_id})
                print(current_id)
                datasources = dataset_entry['datasources']
                input_image = datasources['image']

                image, prediction, prediction_local, prediction_spatial, transformation = self.test_cropped_image(
                    dataset_entry)

                # the origin of the written images is the transformed zero point
                origin = transformation.TransformPoint(np.zeros(3, np.float64))
                if self.save_output_images:
                    heatmap_normalization_mode = (-1, 1)
                    image_type = np.uint8
                    utils.io.image.write_multichannel_np(
                        image,
                        self.output_folder_handler.path(
                            'output', current_id + '_input.mha'),
                        output_normalization_mode='min_max',
                        sitk_image_output_mode='vector',
                        data_format=self.data_format,
                        image_type=image_type,
                        spacing=self.image_spacing,
                        origin=origin)
                    utils.io.image.write_multichannel_np(
                        prediction,
                        self.output_folder_handler.path(
                            'output', current_id + '_prediction.mha'),
                        output_normalization_mode=heatmap_normalization_mode,
                        sitk_image_output_mode='vector',
                        data_format=self.data_format,
                        image_type=image_type,
                        spacing=self.image_spacing,
                        origin=origin)
                    utils.io.image.write_multichannel_np(
                        prediction,
                        self.output_folder_handler.path(
                            'output', current_id + '_prediction_rgb.mha'),
                        output_normalization_mode=(0, 1),
                        channel_layout_mode='channel_rgb',
                        sitk_image_output_mode='vector',
                        data_format=self.data_format,
                        image_type=image_type,
                        spacing=self.image_spacing,
                        origin=origin)
                    utils.io.image.write_multichannel_np(
                        prediction_local,
                        self.output_folder_handler.path(
                            'output', current_id + '_prediction_local.mha'),
                        output_normalization_mode=heatmap_normalization_mode,
                        sitk_image_output_mode='vector',
                        data_format=self.data_format,
                        image_type=image_type,
                        spacing=self.image_spacing,
                        origin=origin)
                    utils.io.image.write_multichannel_np(
                        prediction_spatial,
                        self.output_folder_handler.path(
                            'output', current_id + '_prediction_spatial.mha'),
                        output_normalization_mode=heatmap_normalization_mode,
                        sitk_image_output_mode='vector',
                        data_format=self.data_format,
                        image_type=image_type,
                        spacing=self.image_spacing,
                        origin=origin)

                # one list of candidate landmarks per landmark index
                local_maxima_landmarks = heatmap_maxima.get_landmarks(
                    prediction, input_image, self.image_spacing,
                    transformation)
                # without postprocessing, the first candidate of every landmark is used (invalid, if there is none)
                curr_landmarks_no_postprocessing = [
                    l[0] if len(l) > 0 else Landmark(coords=[np.nan] * 3,
                                                     is_valid=False)
                    for l in local_maxima_landmarks
                ]
                landmarks_no_postprocessing[
                    current_id] = curr_landmarks_no_postprocessing

                # if the postprocessing fails, the landmarks without postprocessing are used instead
                try:
                    local_maxima_landmarks = add_landmarks_from_neighbors(
                        local_maxima_landmarks)
                    curr_landmarks = spine_postprocessing.solve_local_heatmap_maxima(
                        local_maxima_landmarks)
                    curr_landmarks = reshift_landmarks(curr_landmarks)
                    curr_landmarks = filter_landmarks_top_bottom(
                        curr_landmarks, input_image)
                except Exception:
                    print('error in postprocessing', current_id)
                    traceback.print_exc(file=sys.stdout)
                    curr_landmarks = curr_landmarks_no_postprocessing
                landmarks[current_id] = curr_landmarks

                if self.save_output_images:
                    vis.visualize_landmark_projections(
                        input_image,
                        curr_landmarks_no_postprocessing,
                        filename=self.output_folder_handler.path(
                            'output', current_id + '_landmarks.png'))
                    vis.visualize_landmark_projections(
                        input_image,
                        curr_landmarks,
                        filename=self.output_folder_handler.path(
                            'output', current_id + '_landmarks_pp.png'))

                # save the postprocessed landmarks of the current image as a json file
                verse_landmarks = self.convert_landmarks_to_verse_indexing(
                    curr_landmarks, input_image)
                self.save_landmarks_verse_json(
                    verse_landmarks,
                    self.output_folder_handler.path(current_id + '_ctd.json'))
            except Exception:
                # best effort: report the error and continue with the next image
                print('ERROR predicting', current_id)
                traceback.print_exc(file=sys.stdout)
                pass

        # save the landmarks of all successfully processed images
        utils.io.landmark.save_points_csv(
            landmarks, self.output_folder_handler.path('landmarks.csv'))
        utils.io.landmark.save_points_csv(
            landmarks_no_postprocessing,
            self.output_folder_handler.path('landmarks_no_postprocessing.csv'))
        self.save_valid_landmarks_list(
            landmarks, self.output_folder_handler.path('valid_landmarks.csv'))
    def test(self):
        """
        The test function. Performs inference on all entries of the validation dataset, extracts landmarks from
        the predicted heatmaps (optionally with graph-based postprocessing), visualizes them, and saves the
        landmarks as points.csv into the output folder of the current iteration.
        """
        print('Testing...')
        # mapping from landmark index to the annotation text used in the visualization
        vis = LandmarkVisualizationMatplotlib(dim=3,
                                              annotations=dict([(i, f'C{i + 1}') for i in range(7)] +        # 0-6: C1-C7
                                                               [(i, f'T{i - 6}') for i in range(7, 19)] +    # 7-18: T1-12
                                                               [(i, f'L{i - 18}') for i in range(19, 25)] +  # 19-24: L1-6
                                                               [(25, 'T13')]))                               # 25: T13

        # index of the channel axis in the prediction np arrays
        channel_axis = 0
        if self.data_format == 'channels_last':
            channel_axis = 3
        heatmap_maxima = HeatmapTest(channel_axis,
                                     False,
                                     return_multiple_maxima=True,
                                     min_max_value=0.05,
                                     smoothing_sigma=2.0)

        # both pickle files are read relative to the current working directory
        # NOTE(review): pickle.load executes arbitrary code from the file - only use trusted files
        with open('possible_successors.pickle', 'rb') as f:
            possible_successors = pickle.load(f)
        with open('units_distances.pickle', 'rb') as f:
            offsets_mean, distances_mean, distances_std = pickle.load(f)
        spine_postprocessing = SpinePostprocessingGraph(num_landmarks=self.num_landmarks,
                                                        possible_successors=possible_successors,
                                                        offsets_mean=offsets_mean,
                                                        distances_mean=distances_mean,
                                                        distances_std=distances_std,
                                                        bias=2.0,
                                                        l=0.2)

        # NOTE(review): in the code visible here, the statistics objects are only filled and
        # all_local_maxima_landmarks is never used - presumably they are evaluated elsewhere; confirm
        landmark_statistics = LandmarkStatistics()
        landmarks = {}
        landmark_statistics_no_postprocessing = LandmarkStatistics()
        landmarks_no_postprocessing = {}
        all_local_maxima_landmarks = {}
        num_entries = self.dataset_val.num_entries()
        for _ in tqdm(range(num_entries), desc='Testing'):
            dataset_entry = self.dataset_val.get_next()
            current_id = dataset_entry['id']['image_id']
            datasources = dataset_entry['datasources']
            input_image = datasources['image']
            if self.has_validation_groundtruth:
                target_landmarks = datasources['landmarks']
            else:
                target_landmarks = None

            image, prediction, prediction_local, prediction_spatial, transformation = self.test_cropped_image(dataset_entry)

            # the origin of the written images is the transformed zero point
            origin = transformation.TransformPoint(np.zeros(3, np.float64))
            if self.save_output_images:
                heatmap_normalization_mode = (-1, 1)
                image_type = np.uint8
                utils.io.image.write_multichannel_np(image,self.output_folder_handler.path_for_iteration(self.current_iter, current_id + '_input.mha'), output_normalization_mode='min_max', sitk_image_output_mode='vector', data_format=self.data_format, image_type=image_type, spacing=self.image_spacing, origin=origin)
                utils.io.image.write_multichannel_np(prediction, self.output_folder_handler.path_for_iteration(self.current_iter, current_id + '_prediction.mha'), output_normalization_mode=heatmap_normalization_mode, sitk_image_output_mode='vector', data_format=self.data_format, image_type=image_type, spacing=self.image_spacing, origin=origin)
                utils.io.image.write_multichannel_np(prediction_local, self.output_folder_handler.path_for_iteration(self.current_iter, current_id + '_prediction_local.mha'), output_normalization_mode=heatmap_normalization_mode, sitk_image_output_mode='vector', data_format=self.data_format, image_type=image_type, spacing=self.image_spacing, origin=origin)
                utils.io.image.write_multichannel_np(prediction_spatial, self.output_folder_handler.path_for_iteration(self.current_iter, current_id + '_prediction_spatial.mha'), output_normalization_mode=heatmap_normalization_mode, sitk_image_output_mode='vector', data_format=self.data_format, image_type=image_type, spacing=self.image_spacing, origin=origin)

            # one list of candidate landmarks per landmark index
            local_maxima_landmarks = heatmap_maxima.get_landmarks(prediction, input_image, self.image_spacing, transformation)

            # landmarks without postprocessing are the first local maxima (with the largest value)
            curr_landmarks_no_postprocessing = [l[0] if len(l) > 0 else Landmark(coords=[np.nan] * 3, is_valid=False)  for l in local_maxima_landmarks]
            landmarks_no_postprocessing[current_id] = curr_landmarks_no_postprocessing

            # visualize the landmarks without postprocessing (together with the groundtruth, if available)
            if self.has_validation_groundtruth:
                landmark_statistics_no_postprocessing.add_landmarks(current_id, curr_landmarks_no_postprocessing, target_landmarks)
                vis.visualize_landmark_projections(input_image, target_landmarks, filename=self.output_folder_handler.path_for_iteration(self.current_iter, current_id + '_landmarks_gt.png'))
                vis.visualize_prediction_groundtruth_projections(input_image, curr_landmarks_no_postprocessing, target_landmarks, filename=self.output_folder_handler.path_for_iteration(self.current_iter, current_id + '_landmarks.png'))
            else:
                vis.visualize_landmark_projections(input_image, curr_landmarks_no_postprocessing, filename=self.output_folder_handler.path_for_iteration(self.current_iter, current_id + '_landmarks.png'))

            # the landmarks dict is only filled, if postprocessing is evaluated
            if self.evaluate_landmarks_postprocessing:
                # if the postprocessing fails, the landmarks without postprocessing are used instead
                try:
                    local_maxima_landmarks = add_landmarks_from_neighbors(local_maxima_landmarks)
                    curr_landmarks = spine_postprocessing.solve_local_heatmap_maxima(local_maxima_landmarks)
                    curr_landmarks = reshift_landmarks(curr_landmarks)
                    curr_landmarks = filter_landmarks_top_bottom(curr_landmarks, input_image)
                except Exception:
                    print('error in postprocessing', current_id)
                    curr_landmarks = curr_landmarks_no_postprocessing
                landmarks[current_id] = curr_landmarks

                if self.has_validation_groundtruth:
                    landmark_statistics.add_landmarks(current_id, curr_landmarks, target_landmarks)
                    vis.visualize_prediction_groundtruth_projections(input_image, curr_landmarks, target_landmarks, filename=self.output_folder_handler.path_for_iteration(self.current_iter, current_id + '_landmarks_pp.png'))
                else:
                    vis.visualize_landmark_projections(input_image, curr_landmarks, filename=self.output_folder_handler.path_for_iteration(self.current_iter, current_id + '_landmarks_pp.png'))

        # NOTE(review): landmarks stays empty, if self.evaluate_landmarks_postprocessing is not set - confirm that an empty points.csv is intended
        utils.io.landmark.save_points_csv(landmarks, self.output_folder_handler.path_for_iteration(self.current_iter, 'points.csv'))
        utils.io.landmark.save_points_csv(landmarks_no_postprocessing, self.output_folder_handler.path_for_iteration(self.current_iter, 'points_no_postprocessing.csv'))

        # finalize loss values
        if self.has_validation_groundtruth:
            summary_values = OrderedDict()
            if self.evaluate_landmarks_postprocessing:
                print(landmark_statistics.get_pe_overview_string())
                print(landmark_statistics.get_correct_id_string(20.0))
                overview_string = landmark_statistics.get_overview_string([2, 2.5, 3, 4, 10, 20], 10, 20.0)
                utils.io.text.save_string_txt(overview_string, self.output_folder_handler.path_for_iteration(self.current_iter, 'eval.txt'))
                summary_values.update(OrderedDict(zip(['pe_mean', 'pe_stdev', 'pe_median', 'num_correct'], list(landmark_statistics.get_pe_statistics()) + [landmark_statistics.get_num_correct_id(20)])))
            print(landmark_statistics_no_postprocessing.get_pe_overview_string())
            print(landmark_statistics_no_postprocessing.get_correct_id_string(20.0))
            overview_string = landmark_statistics_no_postprocessing.get_overview_string([2, 2.5, 3, 4, 10, 20], 10, 20.0)
            utils.io.text.save_string_txt(overview_string, self.output_folder_handler.path_for_iteration(self.current_iter, 'eval_no_postprocessing.txt'))
            summary_values.update(OrderedDict(zip(['pe_mean_np', 'pe_stdev_np', 'pe_median_np', 'num_correct_np'], list(landmark_statistics_no_postprocessing.get_pe_statistics()) + [landmark_statistics_no_postprocessing.get_num_correct_id(20)])))
            self.loss_metric_logger_val.update_metrics(summary_values)

            # finalize loss values
        self.loss_metric_logger_val.finalize(self.current_iter)
Пример #18
0
    def test(self):
        """
        The test function. Performs inference on the validation images and evaluates the predicted landmark.

        For every entry of self.dataset_val, the network prediction is passed through
        utils.sitk_image.transform_np_output_to_sitk_input together with the input image and
        the transformation. The center of mass of the first resulting image is converted to a
        physical point of the input image and stored as the only landmark of that entry.
        If self.save_output_images is set, the network input and prediction are written as
        .mha files. All landmarks are saved to 'points.csv'. If groundtruth is available, the
        point error statistics are printed, handed to self.val_loss_aggregator and written
        to 'eval.txt'.
        """
        print('Testing...')

        channel_axis = 3 if self.data_format == 'channels_last' else 0

        statistics = LandmarkStatistics()
        predicted_landmarks = {}
        num_entries = self.dataset_val.num_entries()
        for entry_index in range(num_entries):
            entry = self.dataset_val.get_next()
            image_id = entry['id']['image_id']
            sources = entry['datasources']
            if self.has_validation_groundtruth:
                # single reference landmark derived from all groundtruth landmarks
                target_landmark = [get_mean_landmark(sources['landmarks'])]
            input_image = sources['image']

            image, prediction, transformation = self.test_full_image(entry)
            predictions_sitk = utils.sitk_image.transform_np_output_to_sitk_input(
                output_image=prediction,
                output_spacing=self.image_spacing,
                channel_axis=channel_axis,
                input_image_sitk=input_image,
                transform=transformation,
                interpolator='linear',
                output_pixel_type=sitk.sitkFloat32)

            if self.save_output_images:
                if self.save_output_images_as_uint:
                    image_normalization, heatmap_normalization, output_image_type = 'min_max', (0, 1), np.uint8
                else:
                    image_normalization, heatmap_normalization, output_image_type = None, None, np.float32
                origin = transformation.TransformPoint(np.zeros(3, np.float64))
                outputs = ((image, '_input.mha', image_normalization),
                           (prediction, '_prediction.mha', heatmap_normalization))
                for array, suffix, normalization in outputs:
                    utils.io.image.write_multichannel_np(
                        array,
                        self.output_file_for_current_iteration(image_id + suffix),
                        output_normalization_mode=normalization,
                        data_format=self.data_format,
                        image_type=output_image_type,
                        spacing=self.image_spacing,
                        origin=origin)

            # center of mass of the first prediction image; the index is reversed before it is
            # handed to SimpleITK (presumably numpy axis order -> SimpleITK axis order)
            prediction_np = utils.sitk_np.sitk_to_np_no_copy(predictions_sitk[0])
            com_index = list(reversed(utils.np_image.center_of_mass(prediction_np)))
            com_physical = input_image.TransformContinuousIndexToPhysicalPoint(com_index)
            entry_landmarks = [Landmark(com_physical)]
            predicted_landmarks[image_id] = entry_landmarks

            if self.has_validation_groundtruth:
                statistics.add_landmarks(image_id, entry_landmarks, target_landmark)

            print_progress_bar(entry_index, num_entries, prefix='Testing ', suffix=' complete')

        utils.io.landmark.save_points_csv(predicted_landmarks, self.output_file_for_current_iteration('points.csv'))

        # without groundtruth there is nothing to evaluate
        if not self.has_validation_groundtruth:
            return

        # finalize loss values
        print(statistics.get_pe_overview_string())
        summary_values = OrderedDict(zip(self.point_statistics_names, list(statistics.get_pe_statistics())))
        self.val_loss_aggregator.finalize(self.current_iter, summary_values)

        overview = statistics.get_overview_string([2, 2.5, 3, 4, 10, 20], 10)
        utils.io.text.save_string_txt(overview, self.output_file_for_current_iteration('eval.txt'))
    def data_generators(self, iterator, datasources, transformation,
                        image_post_processing,
                        random_translation_single_landmark, image_size):
        """
        Returns the data generators that process one input. See datasources() for dict values.
        Which generators are created depends on the self.generate_* flags; the 'image'
        generator is always created.
        :param iterator: Node providing the id dict; its 'landmark_id' entry selects the single landmark / label.
        :param datasources: datasources dict.
        :param transformation: transformation.
        :param image_post_processing: The np postprocessing function for the image data generator.
        :param random_translation_single_landmark: If truthy, the selected single landmark is randomly translated.
        :param image_size: The image size passed to every generator.
        :return: A dict of data generators.
        """
        # positional arguments shared by all generators
        geometry = (self.dim, image_size, self.image_spacing)
        data_format = self.data_format

        generators = {}
        generators['image'] = ImageGenerator(*geometry,
                                             interpolator='linear',
                                             post_processing_np=image_post_processing,
                                             data_format=data_format,
                                             resample_default_pixel_value=self.image_default_pixel_value,
                                             name='image',
                                             parents=[datasources['image'], transformation])

        if self.generate_landmark_mask:
            generators['landmark_mask'] = ImageGenerator(*geometry,
                                                         interpolator='nearest',
                                                         data_format=data_format,
                                                         resample_default_pixel_value=0,
                                                         name='landmark_mask',
                                                         parents=[datasources['landmark_mask'], transformation])

        # the split labels are also needed as input of the 'single_label' node
        if self.generate_labels or self.generate_single_vertebrae:
            generators['labels'] = ImageGenerator(*geometry,
                                                  interpolator='nearest',
                                                  post_processing_np=self.split_labels,
                                                  data_format=data_format,
                                                  name='labels',
                                                  parents=[datasources['labels'], transformation])

        # the heatmaps are also needed as input of the 'spine_heatmap' node
        if self.generate_heatmaps or self.generate_spine_heatmap:
            generators['heatmaps'] = LandmarkGeneratorHeatmap(*geometry,
                                                              sigma=self.heatmap_sigma,
                                                              scale_factor=1.0,
                                                              normalize_center=True,
                                                              data_format=data_format,
                                                              name='heatmaps',
                                                              parents=[datasources['landmarks'], transformation])

        if self.generate_landmarks:
            generators['landmarks'] = LandmarkGenerator(*geometry,
                                                        data_format=data_format,
                                                        name='landmarks',
                                                        parents=[datasources['landmarks'], transformation])

        if self.generate_single_vertebrae_heatmap:
            def select_landmark(id_dict, landmarks):
                # keep a one-element list containing only the requested landmark
                index = int(id_dict['landmark_id'])
                return landmarks[index:index + 1]

            single_landmark = LambdaNode(select_landmark,
                                         name='single_landmark',
                                         parents=[iterator, datasources['landmarks']])
            if random_translation_single_landmark:
                def translate_landmark(selected):
                    # offset drawn per call from [-r, r] with r = self.random_translation_single_landmark
                    coords = selected[0].coords
                    max_offset = self.random_translation_single_landmark
                    return [Landmark(coords + float_uniform(-max_offset, max_offset, [self.dim]), True)]

                single_landmark = LambdaNode(translate_landmark,
                                             name='single_landmark_translation',
                                             parents=[single_landmark])
            generators['single_heatmap'] = LandmarkGeneratorHeatmap(*geometry,
                                                                    sigma=self.heatmap_sigma,
                                                                    scale_factor=1.0,
                                                                    normalize_center=True,
                                                                    data_format=data_format,
                                                                    name='single_heatmap',
                                                                    parents=[single_landmark, transformation])

        if self.generate_single_vertebrae:
            # channel landmark_id + 1 is selected (channel 0 is skipped, presumably background);
            # the slice keeps the channel axis
            if data_format == 'channels_first':
                def select_label(id_dict, images):
                    channel = int(id_dict['landmark_id']) + 1
                    return images[channel:channel + 1, ...]
            else:
                def select_label(id_dict, images):
                    channel = int(id_dict['landmark_id']) + 1
                    return images[..., channel:channel + 1]

            generators['single_label'] = LambdaNode(select_label,
                                                    name='single_label',
                                                    parents=[iterator, generators['labels']])

        if self.generate_spine_heatmap:
            def spine_heatmap(images):
                # sum all heatmap channels into one channel and smooth the result
                channel_axis = 0 if self.data_format == 'channels_first' else -1
                summed = np.sum(images, axis=channel_axis, keepdims=True)
                return gaussian(summed, sigma=self.spine_heatmap_sigma)

            generators['spine_heatmap'] = LambdaNode(spine_heatmap,
                                                     name='spine_heatmap',
                                                     parents=[generators['heatmaps']])

        return generators
    def data_generators(self, iterator, datasources, transformation, image_post_processing, random_translation_single_landmark, image_size, crop=False):
        """
        Returns the data generators that process one input. See datasources() for dict values.
        The output size is not passed as a fixed argument; it is supplied to every generator
        through the shared kwparents dict as 'output_size'.
        :param iterator: Node providing the id dict; its 'landmark_id' entry selects the single landmark / label.
        :param datasources: datasources dict.
        :param transformation: transformation.
        :param image_post_processing: The np postprocessing function for the image data generator.
        :param random_translation_single_landmark: If truthy, the selected single landmark is randomly translated.
        :param image_size: The value passed to the generators as 'output_size' kwparent.
        :param crop: If truthy, the image datasource is replaced by a LambdaNode that applies
                     self.landmark_based_crop to the image and landmarks datasources.
        :return: A dict of data generators.
        """
        # positional arguments and kwparents shared by all generators
        geometry = (self.dim, None, self.image_spacing)
        kwparents = {'output_size': image_size}
        data_format = self.data_format

        if crop:
            image_datasource = LambdaNode(self.landmark_based_crop,
                                          name='image_cropped',
                                          kwparents={'image': datasources['image'], 'landmarks': datasources['landmarks']})
        else:
            image_datasource = datasources['image']

        generators = {}
        generators['image'] = ImageGenerator(*geometry,
                                             interpolator='linear',
                                             post_processing_np=image_post_processing,
                                             data_format=data_format,
                                             resample_default_pixel_value=self.image_default_pixel_value,
                                             np_pixel_type=self.output_image_type,
                                             name='image',
                                             parents=[image_datasource, transformation],
                                             kwparents=kwparents)

        if self.generate_landmark_mask:
            generators['landmark_mask'] = ImageGenerator(*geometry,
                                                         interpolator='nearest',
                                                         data_format=data_format,
                                                         resample_default_pixel_value=0,
                                                         name='landmark_mask',
                                                         parents=[datasources['landmark_mask'], transformation],
                                                         kwparents=kwparents)

        if self.generate_labels:
            generators['labels'] = ImageGenerator(*geometry,
                                                  interpolator='nearest',
                                                  post_processing_np=self.split_labels,
                                                  data_format=data_format,
                                                  name='labels',
                                                  parents=[datasources['labels'], transformation],
                                                  kwparents=kwparents)

        # the heatmaps are also needed as input of the 'spine_heatmap' node
        if self.generate_heatmaps or self.generate_spine_heatmap:
            generators['heatmaps'] = LandmarkGeneratorHeatmap(*geometry,
                                                              sigma=self.heatmap_sigma,
                                                              scale_factor=1.0,
                                                              normalize_center=True,
                                                              data_format=data_format,
                                                              name='heatmaps',
                                                              parents=[datasources['landmarks'], transformation],
                                                              kwparents=kwparents)

        if self.generate_landmarks:
            generators['landmarks'] = LandmarkGenerator(*geometry,
                                                        data_format=data_format,
                                                        name='landmarks',
                                                        parents=[datasources['landmarks'], transformation],
                                                        kwparents=kwparents)

        if self.generate_single_vertebrae_heatmap:
            def select_landmark(id_dict, landmarks):
                # keep a one-element list containing only the requested landmark
                index = int(id_dict['landmark_id'])
                return landmarks[index:index + 1]

            single_landmark = LambdaNode(select_landmark,
                                         name='single_landmark',
                                         parents=[iterator, datasources['landmarks']])
            if random_translation_single_landmark:
                def translate_landmark(selected):
                    # offset drawn per call from [-r, r] with r = self.random_translation_single_landmark
                    coords = selected[0].coords
                    max_offset = self.random_translation_single_landmark
                    return [Landmark(coords + float_uniform(-max_offset, max_offset, [self.dim]), True)]

                single_landmark = LambdaNode(translate_landmark,
                                             name='single_landmark_translation',
                                             parents=[single_landmark])
            generators['single_heatmap'] = LandmarkGeneratorHeatmap(*geometry,
                                                                    sigma=self.single_heatmap_sigma,
                                                                    scale_factor=1.0,
                                                                    normalize_center=True,
                                                                    data_format=data_format,
                                                                    np_pixel_type=self.output_image_type,
                                                                    name='single_heatmap',
                                                                    parents=[single_landmark, transformation],
                                                                    kwparents=kwparents)

        if self.generate_single_vertebrae:
            if self.generate_labels:
                # reuse the already split labels: channel landmark_id + 1 is selected
                # (channel 0 is skipped, presumably background); the slice keeps the channel axis
                if data_format == 'channels_first':
                    def select_label(id_dict, images):
                        channel = int(id_dict['landmark_id']) + 1
                        return images[channel:channel + 1, ...]
                else:
                    def select_label(id_dict, images):
                        channel = int(id_dict['landmark_id']) + 1
                        return images[..., channel:channel + 1]

                generators['single_label'] = LambdaNode(select_label,
                                                        name='single_label',
                                                        parents=[iterator, generators['labels']])
            else:
                # no split labels available: resample the raw label image without postprocessing
                # and let split_and_smooth_single_label extract the requested label
                labels_unsmoothed = ImageGenerator(*geometry,
                                                   interpolator='nearest',
                                                   post_processing_np=None,
                                                   data_format=data_format,
                                                   name='labels_unsmoothed',
                                                   parents=[datasources['labels'], transformation],
                                                   kwparents=kwparents)

                def extract_label(id_dict, labels):
                    return self.split_and_smooth_single_label(labels, int(id_dict['landmark_id']))

                generators['single_label'] = LambdaNode(extract_label,
                                                        name='single_label',
                                                        parents=[iterator, labels_unsmoothed])

        if self.generate_spine_heatmap:
            def spine_heatmap(images):
                # sum all heatmap channels into one channel, smooth it and normalize to (0, 1)
                channel_axis = 0 if self.data_format == 'channels_first' else -1
                summed = np.sum(images, axis=channel_axis, keepdims=True)
                smoothed = gaussian(summed, sigma=self.spine_heatmap_sigma)
                return normalize(smoothed, out_range=(0, 1))

            generators['spine_heatmap'] = LambdaNode(spine_heatmap,
                                                     name='spine_heatmap',
                                                     parents=[generators['heatmaps']])

        return generators
Пример #21
0
 else:
     ext_length = len('_ctd.json')
 filename_wo_folder_and_ext = filename_wo_folder[:-ext_length]
 image_id = filename_wo_folder_and_ext
 print(filename_wo_folder_and_ext)
 # get image meta data
 image_meta_data = read_meta_data(
     os.path.join(verse_dataset_folder, 'images_reoriented',
                  image_id + '.nii.gz'))
 spacing = np.array(image_meta_data.GetSpacing())
 origin = np.array(image_meta_data.GetOrigin())
 direction = np.array(image_meta_data.GetDirection()).reshape([3, 3])
 size = np.array(image_meta_data.GetSize())
 # placeholder for landmarks
 current_landmarks = [
     Landmark([np.nan] * 3, False, 1.0, 0.0)
     for _ in range(num_landmarks)
 ]
 with open(filename, 'r') as f:
     # load json file
     json_data = json.load(f)
     for landmark in json_data:
         # convert verse coordinate system to physical coordinates
         if verse2020:
             coords = np.array([
                 size[0] * spacing[0] - float(landmark['Z']),
                 float(landmark['X']),
                 size[2] * spacing[2] - float(landmark['Y'])
             ])
         else:
             coords = np.array([
Пример #22
0
    def test(self):
        """
        Run inference for every image id in self.image_id_list and save the predicted landmarks.

        For each image, the network prediction is passed through
        utils.sitk_image.transform_np_output_to_sitk_input together with the input image and
        the transformation. The center of mass of the first resulting image is converted to a
        physical point of the input image and stored as the only landmark of that image.
        If self.save_debug_images is set, the network input, the prediction and the first
        transformed prediction are written as .mha files. All landmarks are saved to
        'landmarks.csv'.

        Prediction is best-effort per image: an error while processing one image is printed
        together with its traceback and the image is skipped. Only Exception subclasses are
        handled this way, so KeyboardInterrupt and SystemExit still abort the run.
        """
        print('Testing...')

        channel_axis = 0
        if self.data_format == 'channels_last':
            channel_axis = 3

        # with exactly one model file, load it once up front
        # NOTE(review): presumably test_full_image handles the multi-model case itself - confirm
        if len(self.load_model_filenames) == 1:
            self.network_loop.load_model_filename = self.load_model_filenames[
                0]
            self.network_loop.load_model()

        landmarks = {}
        for image_id in self.image_id_list:
            try:
                print(image_id)
                dataset_entry = self.dataset_val.get({'image_id': image_id})
                current_id = dataset_entry['id']['image_id']
                datasources = dataset_entry['datasources']
                input_image = datasources['image']

                image, prediction, transformation = self.test_full_image(
                    dataset_entry)

                predictions_sitk = utils.sitk_image.transform_np_output_to_sitk_input(
                    output_image=prediction,
                    output_spacing=self.image_spacing,
                    channel_axis=channel_axis,
                    input_image_sitk=input_image,
                    transform=transformation,
                    interpolator='linear',
                    output_pixel_type=sitk.sitkFloat32)

                if self.save_debug_images:
                    origin = transformation.TransformPoint(
                        np.zeros(3, np.float64))
                    heatmap_normalization_mode = (0, 1)
                    utils.io.image.write_multichannel_np(
                        image,
                        self.output_file_for_current_iteration(current_id +
                                                               '_input.mha'),
                        output_normalization_mode='min_max',
                        data_format=self.data_format,
                        image_type=np.uint8,
                        spacing=self.image_spacing,
                        origin=origin)
                    utils.io.image.write_multichannel_np(
                        prediction,
                        self.output_file_for_current_iteration(
                            current_id + '_prediction.mha'),
                        output_normalization_mode=heatmap_normalization_mode,
                        data_format=self.data_format,
                        image_type=np.uint8,
                        spacing=self.image_spacing,
                        origin=origin)
                    utils.io.image.write(
                        predictions_sitk[0],
                        self.output_file_for_current_iteration(
                            current_id + '_prediction_original.mha'))

                # center of mass of the first prediction image; the index is reversed before it
                # is handed to SimpleITK (presumably numpy axis order -> SimpleITK axis order)
                predictions_com = input_image.TransformContinuousIndexToPhysicalPoint(
                    list(
                        reversed(
                            utils.np_image.center_of_mass(
                                utils.sitk_np.sitk_to_np_no_copy(
                                    predictions_sitk[0])))))
                landmarks[current_id] = [Landmark(predictions_com)]
            except Exception:
                # deliberately keep going with the next image; a bare except would also
                # swallow KeyboardInterrupt / SystemExit and make the loop uninterruptible
                print(traceback.format_exc())
                print('ERROR predicting', image_id)

        utils.io.landmark.save_points_csv(
            landmarks, self.output_file_for_current_iteration('landmarks.csv'))