示例#1
0
    def test(self):
        """
        The test function. Performs inference on the validation images and calculates the loss.

        Every entry of ``self.dataset_val`` is run through ``self.test_cropped_image``.
        The input image and prediction are optionally written to the output folder of
        the current iteration. Landmarks are then extracted from the heatmap maxima of
        the prediction (optionally refined by ``SpinePostprocessing``) and saved to
        ``points.csv``. If ``self.has_validation_groundtruth`` is set, the predictions
        are compared against the ground-truth landmarks. The resulting point-error
        statistics are printed, passed to ``self.val_loss_aggregator`` and written to
        ``eval.txt``.
        """
        print('Testing...')

        channel_axis = 0
        if self.data_format == 'channels_last':
            channel_axis = 3
        if self.use_spine_postprocessing:
            # keep multiple candidate maxima per landmark; they are handed to the
            # spine postprocessing below, which builds the final landmark sequence
            heatmap_maxima = HeatmapTest(channel_axis,
                                         False,
                                         return_multiple_maxima=True,
                                         min_max_distance=7,
                                         min_max_value=0.25,
                                         multiple_min_max_value_factor=0.1)
            spine_postprocessing = SpinePostprocessing(
                num_landmarks=self.num_landmarks,
                image_spacing=self.image_spacing)
        else:
            heatmap_maxima = HeatmapTest(channel_axis, False)

        landmark_statistics = LandmarkStatistics()
        landmarks = {}
        num_entries = self.dataset_val.num_entries()
        for i in range(num_entries):
            dataset_entry = self.dataset_val.get_next()
            current_id = dataset_entry['id']['image_id']
            datasources = dataset_entry['datasources']
            input_image = datasources['image']
            # ground-truth landmarks are only needed for the evaluation below
            if self.has_validation_groundtruth:
                target_landmarks = datasources['landmarks']
            else:
                target_landmarks = None

            image, prediction, transformation = self.test_cropped_image(
                dataset_entry)

            if self.save_output_images:
                if self.save_output_images_as_uint:
                    image_normalization = 'min_max'
                    heatmap_normalization = (0, 1)
                    output_image_type = np.uint8
                else:
                    image_normalization = None
                    heatmap_normalization = None
                    output_image_type = np.float32
                # physical position of voxel (0, 0, 0) of the network-space images
                origin = transformation.TransformPoint(np.zeros(3, np.float64))
                utils.io.image.write_multichannel_np(
                    image,
                    self.output_file_for_current_iteration(current_id +
                                                           '_input.mha'),
                    output_normalization_mode=image_normalization,
                    data_format=self.data_format,
                    image_type=output_image_type,
                    spacing=self.image_spacing,
                    origin=origin)
                utils.io.image.write_multichannel_np(
                    prediction,
                    self.output_file_for_current_iteration(current_id +
                                                           '_prediction.mha'),
                    output_normalization_mode=heatmap_normalization,
                    data_format=self.data_format,
                    image_type=output_image_type,
                    spacing=self.image_spacing,
                    origin=origin)

            # both branches must produce ``current_landmarks`` so that the
            # evaluation below works with and without spine postprocessing
            if self.use_spine_postprocessing:
                local_maxima_landmarks = heatmap_maxima.get_landmarks(
                    prediction, input_image, self.image_spacing,
                    transformation)
                current_landmarks = spine_postprocessing.postprocess_landmarks(
                    local_maxima_landmarks, prediction.shape)
            else:
                current_landmarks = heatmap_maxima.get_landmarks(
                    prediction, input_image, self.image_spacing,
                    transformation)
            landmarks[current_id] = current_landmarks

            if self.has_validation_groundtruth:
                landmark_statistics.add_landmarks(current_id,
                                                  current_landmarks,
                                                  target_landmarks)

            print_progress_bar(i,
                               num_entries,
                               prefix='Testing ',
                               suffix=' complete')

        utils.io.landmark.save_points_csv(
            landmarks, self.output_file_for_current_iteration('points.csv'))

        # finalize loss values
        if self.has_validation_groundtruth:
            print(landmark_statistics.get_pe_overview_string())
            print(landmark_statistics.get_correct_id_string(20.0))
            summary_values = OrderedDict(
                zip(
                    self.point_statistics_names,
                    list(landmark_statistics.get_pe_statistics()) +
                    [landmark_statistics.get_correct_id(20)]))

            self.val_loss_aggregator.finalize(self.current_iter,
                                              summary_values)
            overview_string = landmark_statistics.get_overview_string(
                [2, 2.5, 3, 4, 10, 20], 10, 20.0)
            utils.io.text.save_string_txt(
                overview_string,
                self.output_file_for_current_iteration('eval.txt'))
示例#2
0
    def test(self):
        """
        Run inference on every validation image and evaluate the predicted landmark.

        Each dataset entry is processed with ``self.test_full_image``. The network
        output is resampled into the geometry of the input image and reduced to a
        single landmark located at its center of mass. All predictions are written
        to ``points.csv``. When ``self.has_validation_groundtruth`` is set, each
        prediction is compared against the mean of the ground-truth landmarks; the
        point-error statistics are then printed, handed to
        ``self.val_loss_aggregator`` and saved to ``eval.txt``.
        """
        print('Testing...')

        channel_axis = 3 if self.data_format == 'channels_last' else 0

        landmark_statistics = LandmarkStatistics()
        landmarks = {}
        num_entries = self.dataset_val.num_entries()
        for entry_index in range(num_entries):
            dataset_entry = self.dataset_val.get_next()
            current_id = dataset_entry['id']['image_id']
            datasources = dataset_entry['datasources']
            if self.has_validation_groundtruth:
                # the evaluation target is one landmark: the mean of all annotated ones
                groundtruth_landmark = [get_mean_landmark(datasources['landmarks'])]
            input_image = datasources['image']

            image, prediction, transformation = self.test_full_image(dataset_entry)
            # bring the network output back into the space of the input image
            predictions_sitk = utils.sitk_image.transform_np_output_to_sitk_input(
                output_image=prediction,
                output_spacing=self.image_spacing,
                channel_axis=channel_axis,
                input_image_sitk=input_image,
                transform=transformation,
                interpolator='linear',
                output_pixel_type=sitk.sitkFloat32)

            if self.save_output_images:
                if self.save_output_images_as_uint:
                    image_normalization, heatmap_normalization, output_image_type = 'min_max', (0, 1), np.uint8
                else:
                    image_normalization, heatmap_normalization, output_image_type = None, None, np.float32
                origin = transformation.TransformPoint(np.zeros(3, np.float64))
                outputs_to_write = ((image, '_input.mha', image_normalization),
                                    (prediction, '_prediction.mha', heatmap_normalization))
                for output_array, file_suffix, normalization_mode in outputs_to_write:
                    utils.io.image.write_multichannel_np(
                        output_array,
                        self.output_file_for_current_iteration(current_id + file_suffix),
                        output_normalization_mode=normalization_mode,
                        data_format=self.data_format,
                        image_type=output_image_type,
                        spacing=self.image_spacing,
                        origin=origin)

            # the numpy center of mass is reversed before being used as a sitk
            # continuous index (numpy and sitk presumably use opposite axis order)
            prediction_np = utils.sitk_np.sitk_to_np_no_copy(predictions_sitk[0])
            center_index = list(reversed(utils.np_image.center_of_mass(prediction_np)))
            predicted_point = input_image.TransformContinuousIndexToPhysicalPoint(center_index)
            current_landmark = [Landmark(predicted_point)]
            landmarks[current_id] = current_landmark

            if self.has_validation_groundtruth:
                landmark_statistics.add_landmarks(current_id, current_landmark, groundtruth_landmark)

            print_progress_bar(entry_index,
                               num_entries,
                               prefix='Testing ',
                               suffix=' complete')

        utils.io.landmark.save_points_csv(
            landmarks, self.output_file_for_current_iteration('points.csv'))

        # without ground truth there is nothing to evaluate
        if not self.has_validation_groundtruth:
            return

        print(landmark_statistics.get_pe_overview_string())
        pe_statistics = list(landmark_statistics.get_pe_statistics())
        summary_values = OrderedDict(zip(self.point_statistics_names, pe_statistics))
        self.val_loss_aggregator.finalize(self.current_iter, summary_values)
        overview_string = landmark_statistics.get_overview_string([2, 2.5, 3, 4, 10, 20], 10)
        utils.io.text.save_string_txt(
            overview_string,
            self.output_file_for_current_iteration('eval.txt'))
示例#3
0
    def test(self):
        """
        Run inference on all validation images and evaluate the predicted landmarks.

        Each entry of ``self.dataset_val`` is processed either on crops or on the full
        image, depending on ``self.cropped_training``. The heatmaps are written as
        uint8, after scaling by 255 and clamping to [0, 255]. The network input image
        is written to the current iteration's output folder as well. Landmarks are
        extracted from the heatmap maxima and compared against the ground-truth
        landmarks. Point-error statistics are printed, passed to
        ``self.val_loss_aggregator`` and saved to ``eval.txt``. Predicted points are
        saved to ``points.csv``.
        """
        print('Testing...')
        if self.data_format == 'channels_first':
            np_channel_index = 0
        else:
            np_channel_index = 3
        heatmap_maxima = HeatmapTest(np_channel_index, False)
        landmark_statistics = LandmarkStatistics()
        landmarks = {}
        # loop-invariant (assumes the dataset size does not change while testing);
        # query it once instead of on every iteration
        num_entries = self.dataset_val.num_entries()
        for i in range(num_entries):
            dataset_entry = self.dataset_val.get_next()
            current_id = dataset_entry['id']['image_id']
            datasources = dataset_entry['datasources']
            image_datasource = datasources['image_datasource']
            landmarks_datasource = datasources['landmarks_datasource']

            if self.cropped_training:
                image, heatmaps, heatmap_transform = self.test_cropped_image(
                    dataset_entry)
            else:
                image, heatmaps, heatmap_transform = self.test_full_image(
                    dataset_entry)

            utils.io.image.write_np(
                ShiftScaleClamp(scale=255, clamp_min=0,
                                clamp_max=255)(heatmaps).astype(np.uint8),
                self.output_file_for_current_iteration(current_id +
                                                       '_heatmaps.mha'))
            utils.io.image.write_np(
                image,
                self.output_file_for_current_iteration(current_id +
                                                       '_image.mha'))

            predicted_landmarks = heatmap_maxima.get_landmarks(
                heatmaps, image_datasource, self.image_spacing,
                heatmap_transform)
            landmarks[current_id] = predicted_landmarks
            landmark_statistics.add_landmarks(current_id, predicted_landmarks,
                                              landmarks_datasource)

            tensorflow_train.utils.tensorflow_util.print_progress_bar(
                i,
                num_entries,
                prefix='Testing ',
                suffix=' complete')

        # final call so that the progress bar shows completion
        tensorflow_train.utils.tensorflow_util.print_progress_bar(
            num_entries,
            num_entries,
            prefix='Testing ',
            suffix=' complete')
        print(landmark_statistics.get_pe_overview_string())
        print(landmark_statistics.get_correct_id_string(20.0))
        summary_values = OrderedDict(
            zip(
                self.point_statistics_names,
                list(landmark_statistics.get_pe_statistics()) +
                [landmark_statistics.get_correct_id(20)]))

        # finalize loss values
        self.val_loss_aggregator.finalize(self.current_iter, summary_values)
        utils.io.landmark.save_points_csv(
            landmarks, self.output_file_for_current_iteration('points.csv'))
        overview_string = landmark_statistics.get_overview_string(
            [2, 2.5, 3, 4, 10, 20], 10, 20.0)
        utils.io.text.save_string_txt(
            overview_string,
            self.output_file_for_current_iteration('eval.txt'))
    def test(self):
        """
        The test function. Performs inference on the validation images and calculates the loss.

        For every validation entry the network is run via ``self.test_cropped_image``
        and local heatmap maxima are extracted. Two sets of landmark predictions are
        built from them:

        * without postprocessing: the first (largest) maximum per landmark, or an
          invalid NaN landmark if there is none;
        * with graph-based spine postprocessing, only if
          ``self.evaluate_landmarks_postprocessing`` is set.

        The predictions are written to ``points.csv`` and
        ``points_no_postprocessing.csv`` and visualized as projection images. When
        ground truth is available, point-error statistics are printed, saved to
        ``eval.txt`` and ``eval_no_postprocessing.txt``, and forwarded to
        ``self.loss_metric_logger_val``. That logger is finalized at the end in
        every case.
        """
        print('Testing...')

        def output_path(filename):
            # every file written by this method goes to the current iteration's folder
            return self.output_folder_handler.path_for_iteration(self.current_iter, filename)

        vis = LandmarkVisualizationMatplotlib(dim=3,
                                              annotations=dict([(i, f'C{i + 1}') for i in range(7)] +        # 0-6: C1-C7
                                                               [(i, f'T{i - 6}') for i in range(7, 19)] +    # 7-18: T1-12
                                                               [(i, f'L{i - 18}') for i in range(19, 25)] +  # 19-24: L1-6
                                                               [(25, 'T13')]))                               # 25: T13

        channel_axis = 0
        if self.data_format == 'channels_last':
            channel_axis = 3
        # keep multiple local maxima per landmark; they are the candidates for the
        # postprocessing step below
        heatmap_maxima = HeatmapTest(channel_axis,
                                     False,
                                     return_multiple_maxima=True,
                                     min_max_value=0.05,
                                     smoothing_sigma=2.0)

        # NOTE(review): these paths are relative to the current working directory, and
        # pickle.load can execute arbitrary code -- only load trusted files here.
        with open('possible_successors.pickle', 'rb') as f:
            possible_successors = pickle.load(f)
        with open('units_distances.pickle', 'rb') as f:
            offsets_mean, distances_mean, distances_std = pickle.load(f)
        spine_postprocessing = SpinePostprocessingGraph(num_landmarks=self.num_landmarks,
                                                        possible_successors=possible_successors,
                                                        offsets_mean=offsets_mean,
                                                        distances_mean=distances_mean,
                                                        distances_std=distances_std,
                                                        bias=2.0,
                                                        l=0.2)

        landmark_statistics = LandmarkStatistics()
        landmarks = {}
        landmark_statistics_no_postprocessing = LandmarkStatistics()
        landmarks_no_postprocessing = {}
        num_entries = self.dataset_val.num_entries()
        for _ in tqdm(range(num_entries), desc='Testing'):
            dataset_entry = self.dataset_val.get_next()
            current_id = dataset_entry['id']['image_id']
            datasources = dataset_entry['datasources']
            input_image = datasources['image']
            if self.has_validation_groundtruth:
                target_landmarks = datasources['landmarks']
            else:
                target_landmarks = None

            image, prediction, prediction_local, prediction_spatial, transformation = self.test_cropped_image(dataset_entry)

            origin = transformation.TransformPoint(np.zeros(3, np.float64))
            if self.save_output_images:
                heatmap_normalization_mode = (-1, 1)
                image_type = np.uint8
                outputs_to_write = ((image, '_input.mha', 'min_max'),
                                    (prediction, '_prediction.mha', heatmap_normalization_mode),
                                    (prediction_local, '_prediction_local.mha', heatmap_normalization_mode),
                                    (prediction_spatial, '_prediction_spatial.mha', heatmap_normalization_mode))
                for output_array, file_suffix, normalization_mode in outputs_to_write:
                    utils.io.image.write_multichannel_np(output_array, output_path(current_id + file_suffix), output_normalization_mode=normalization_mode, sitk_image_output_mode='vector', data_format=self.data_format, image_type=image_type, spacing=self.image_spacing, origin=origin)

            local_maxima_landmarks = heatmap_maxima.get_landmarks(prediction, input_image, self.image_spacing, transformation)

            # landmarks without postprocessing are the first local maxima (with the largest value);
            # a landmark without any maximum becomes an invalid NaN placeholder
            curr_landmarks_no_postprocessing = [maxima[0] if len(maxima) > 0 else Landmark(coords=[np.nan] * 3, is_valid=False)
                                                for maxima in local_maxima_landmarks]
            landmarks_no_postprocessing[current_id] = curr_landmarks_no_postprocessing

            if self.has_validation_groundtruth:
                landmark_statistics_no_postprocessing.add_landmarks(current_id, curr_landmarks_no_postprocessing, target_landmarks)
                vis.visualize_landmark_projections(input_image, target_landmarks, filename=output_path(current_id + '_landmarks_gt.png'))
                vis.visualize_prediction_groundtruth_projections(input_image, curr_landmarks_no_postprocessing, target_landmarks, filename=output_path(current_id + '_landmarks.png'))
            else:
                vis.visualize_landmark_projections(input_image, curr_landmarks_no_postprocessing, filename=output_path(current_id + '_landmarks.png'))

            if self.evaluate_landmarks_postprocessing:
                try:
                    local_maxima_landmarks = add_landmarks_from_neighbors(local_maxima_landmarks)
                    curr_landmarks = spine_postprocessing.solve_local_heatmap_maxima(local_maxima_landmarks)
                    curr_landmarks = reshift_landmarks(curr_landmarks)
                    curr_landmarks = filter_landmarks_top_bottom(curr_landmarks, input_image)
                except Exception as e:
                    # best effort: fall back to the unprocessed landmarks, but report
                    # the cause so that failures can be diagnosed
                    print('error in postprocessing', current_id, repr(e))
                    curr_landmarks = curr_landmarks_no_postprocessing
                landmarks[current_id] = curr_landmarks

                if self.has_validation_groundtruth:
                    landmark_statistics.add_landmarks(current_id, curr_landmarks, target_landmarks)
                    vis.visualize_prediction_groundtruth_projections(input_image, curr_landmarks, target_landmarks, filename=output_path(current_id + '_landmarks_pp.png'))
                else:
                    vis.visualize_landmark_projections(input_image, curr_landmarks, filename=output_path(current_id + '_landmarks_pp.png'))

        utils.io.landmark.save_points_csv(landmarks, output_path('points.csv'))
        utils.io.landmark.save_points_csv(landmarks_no_postprocessing, output_path('points_no_postprocessing.csv'))

        # collect loss/metric values (only possible with ground truth)
        if self.has_validation_groundtruth:
            summary_values = OrderedDict()
            if self.evaluate_landmarks_postprocessing:
                print(landmark_statistics.get_pe_overview_string())
                print(landmark_statistics.get_correct_id_string(20.0))
                overview_string = landmark_statistics.get_overview_string([2, 2.5, 3, 4, 10, 20], 10, 20.0)
                utils.io.text.save_string_txt(overview_string, output_path('eval.txt'))
                summary_values.update(zip(['pe_mean', 'pe_stdev', 'pe_median', 'num_correct'],
                                          list(landmark_statistics.get_pe_statistics()) + [landmark_statistics.get_num_correct_id(20)]))
            print(landmark_statistics_no_postprocessing.get_pe_overview_string())
            print(landmark_statistics_no_postprocessing.get_correct_id_string(20.0))
            overview_string = landmark_statistics_no_postprocessing.get_overview_string([2, 2.5, 3, 4, 10, 20], 10, 20.0)
            utils.io.text.save_string_txt(overview_string, output_path('eval_no_postprocessing.txt'))
            summary_values.update(zip(['pe_mean_np', 'pe_stdev_np', 'pe_median_np', 'num_correct_np'],
                                      list(landmark_statistics_no_postprocessing.get_pe_statistics()) + [landmark_statistics_no_postprocessing.get_num_correct_id(20)]))
            self.loss_metric_logger_val.update_metrics(summary_values)

        # finalize loss values
        self.loss_metric_logger_val.finalize(self.current_iter)