예제 #1
0
    def merge_tiled_instances(self, tiled_instances):
        """
        Merges neighboring tiled instances into one full-size instance image.
        :param tiled_instances: list of per-tile instance data; entry i is merged into the i-th tile
                                visited by the instance tiler.
        :return: The output image of the instance tiler (created with dtype np.uint16).
        """
        # Tile size, full size and increment all get one extra leading axis; the increment along
        # that axis is 0, so tiles are only moved along the (reversed) spatial axes.
        # NOTE(review): the leading 2 is presumably a fixed channel count of the instance data — confirm.
        tile_size = [2, *reversed(self.image_size)]
        full_size = [2, *reversed(self.test_image_size)]
        increment = [0, *reversed(self.tiled_increment)]

        tiler = ImageTiler(full_size,
                           tile_size,
                           increment,
                           True,
                           0,
                           output_image_dtype=np.uint16)
        merger = InstanceMerger(ignore_border=self.instances_ignore_border)

        # Iterating rebinds ``tiler`` to each yielded item (presumably the tiler itself, advanced to
        # the next tile — confirm against ImageTiler); the result is read from it after the loop.
        for tile_index, tiler in enumerate(tiler):
            tiler.set_current_data(tiled_instances[tile_index],
                                   merger.merge_as_larger_instances,
                                   merge_whole_image=True)

        return tiler.output_image
예제 #2
0
    def test_cropped_image(self, full_image):
        """
        Runs the validation network tile by tile over full_image and merges the tile predictions.
        :param full_image: The full input image (np.array), laid out according to self.data_format.
        :return: The output image of the prediction tiler (merged predictions). The input tiler is
                 filled as well, but its output is not returned.
        """
        spatial_size = list(reversed(self.image_size))
        if self.data_format == 'channels_first':
            image_size_np = [1] + spatial_size
            labels_size_np = [self.num_labels] + spatial_size
            predictions_full_size_np = [self.num_labels] + list(full_image.shape[1:])
            cropped_inc = [0] + self.cropped_inc
        else:
            image_size_np = spatial_size + [1]
            labels_size_np = spatial_size + [self.num_labels]
            predictions_full_size_np = list(full_image.shape[:-1]) + [self.num_labels]
            cropped_inc = self.cropped_inc + [0]

        image_tiler = ImageTiler(full_image.shape, image_size_np, cropped_inc, True, -1)
        prediction_tiler = ImageTiler(predictions_full_size_np, labels_size_np, cropped_inc, True, 0)

        # Both tilers are advanced in lockstep; the loop rebinds the tiler names to the yielded items.
        for image_tiler, prediction_tiler in zip(image_tiler, prediction_tiler):
            tile = image_tiler.get_current_data(full_image)
            batched_tile = np.expand_dims(tile, axis=0)
            outputs = self.sess.run((self.prediction_val, ), feed_dict={self.data_val: batched_tile})
            tile_prediction = np.squeeze(outputs[0], axis=0)
            image_tiler.set_current_data(tile)
            prediction_tiler.set_current_data(tile_prediction)

        return prediction_tiler.output_image
예제 #3
0
    def test_cropped_image(self, dataset_entry):
        """
        Perform inference on a dataset_entry with the validation network. Performs cropped prediction and merges
        the tile outputs with the prediction tiler (as maxima per the original description — confirm against ImageTiler).
        If self.has_validation_groundtruth is set, the target landmarks of every tile are fed as well, and the loss and
        the update ops of self.val_loss_aggregator are run.
        :param dataset_entry: A dataset entry from the dataset. Its 'generators' must contain 'image' (and 'landmarks'
                              when self.has_validation_groundtruth is set); its 'transformations' must contain 'image'.
        :return: input image (np.array), predicted heatmaps (np.array), transformation (sitk.Transform)
        """
        generators = dataset_entry['generators']
        transformations = dataset_entry['transformations']
        transformation = transformations['image']

        image_size_np = [1] + list(reversed(self.image_size))
        labels_size_np = [self.num_landmarks] + list(reversed(self.image_size))
        full_image = generators['image']
        # Landmarks are only used to feed the loss, so only read them when groundtruth evaluation is enabled;
        # entries without a 'landmarks' generator then still work.
        landmarks = generators['landmarks'] if self.has_validation_groundtruth else None
        image_tiler = ImageTiler(full_image.shape, image_size_np,
                                 self.cropped_inc, True, -1)
        landmark_tiler = LandmarkTiler(full_image.shape, image_size_np,
                                       self.cropped_inc)
        prediction_tiler = ImageTiler(
            (self.num_landmarks, ) + full_image.shape[1:], labels_size_np,
            self.cropped_inc, True, 0)

        # All tilers are advanced in lockstep; the loop rebinds the tiler names to the yielded items.
        for image_tiler, landmark_tiler, prediction_tiler in zip(
                image_tiler, landmark_tiler, prediction_tiler):
            current_image = image_tiler.get_current_data(full_image)
            if self.has_validation_groundtruth:
                current_landmarks = landmark_tiler.get_current_data(landmarks)
                feed_dict = {
                    self.data_val:
                    np.expand_dims(current_image, axis=0),
                    self.target_landmarks_val:
                    np.expand_dims(current_landmarks, axis=0)
                }
                run_tuple = self.sess.run(
                    (self.prediction_val, self.loss_val) +
                    self.val_loss_aggregator.get_update_ops(),
                    feed_dict=feed_dict)
            else:
                feed_dict = {
                    self.data_val: np.expand_dims(current_image, axis=0)
                }
                run_tuple = self.sess.run((self.prediction_val, ),
                                          feed_dict=feed_dict)
            # Remove the batch axis added by np.expand_dims above.
            prediction = np.squeeze(run_tuple[0], axis=0)
            image_tiler.set_current_data(current_image)
            prediction_tiler.set_current_data(prediction)

        return image_tiler.output_image, prediction_tiler.output_image, transformation
예제 #4
0
    def test_cropped_image(self,
                           dataset_entry,
                           return_all_intermediate_embeddings=False):
        """
        Tests the whole image by cropping the input image into tiles, running the network on every tile and
        merging the per-tile embeddings with one tiler per embeddings output.
        :param dataset_entry: The dataset entry; dataset_entry['generators']['image'] is the full input image.
        :param return_all_intermediate_embeddings: If true, also return the last embeddings output of every tile.
        :return: list of merged embeddings (one per entry of self.embeddings_cropped_val), and, if
                 return_all_intermediate_embeddings is true, additionally the list of per-tile embeddings.
        """
        generators = dataset_entry['generators']
        full_image = generators['image']
        # initialize sizes based on data_format
        fetches = self.embeddings_cropped_val
        if self.data_format == 'channels_first':
            image_size_np = [1, self.num_frames] + list(
                reversed(self.image_size))
            full_image_size_np = list(full_image.shape)
            embeddings_size_np = [self.num_embeddings, self.num_frames] + list(
                reversed(self.image_size))
            full_embeddings_size_np = [self.num_embeddings] + list(
                full_image.shape[1:])
            inc = [0, 0] + list(reversed(self.tiled_increment))
        else:
            # NOTE(review): unlike the channels_first branch, this branch has no num_frames axis;
            # confirm that channels_last inputs are really single-frame.
            image_size_np = list(reversed(self.image_size)) + [1]
            full_image_size_np = list(full_image.shape)
            embeddings_size_np = list(reversed(
                self.image_size)) + [self.num_embeddings]
            # Keep every axis except the trailing channel axis. This equals shape[0:2] for 2D images,
            # and stays rank-consistent with embeddings_size_np for 3D images as well.
            full_embeddings_size_np = list(
                full_image.shape[:-1]) + [self.num_embeddings]
            inc = list(reversed(self.tiled_increment)) + [0]
        # initialize one image tiler for the input and one image tiler per embeddings output
        image_tiler = ImageTiler(full_image_size_np, image_size_np, inc, True,
                                 -1)
        embeddings_tilers = tuple([
            ImageTiler(full_embeddings_size_np, embeddings_size_np, inc, True,
                       -1) for _ in range(len(self.embeddings_cropped_val))
        ])

        all_intermediate_embeddings = []
        # All tilers are advanced in lockstep; the loop rebinds the tiler names to the yielded items.
        for all_tilers in zip(*((image_tiler, ) + embeddings_tilers)):
            image_tiler = all_tilers[0]
            embeddings_tilers = all_tilers[1:]
            current_image = image_tiler.get_current_data(full_image)
            feed_dict = {self.data_val: np.expand_dims(current_image, axis=0)}
            run_tuple = self.sess.run(fetches, feed_dict)
            image_tiler.set_current_data(current_image)
            for i, embeddings_tiler in enumerate(embeddings_tilers):
                # Remove the batch axis added by np.expand_dims above.
                embeddings = np.squeeze(run_tuple[i], axis=0)
                # Only the last embeddings output of each tile is collected.
                if return_all_intermediate_embeddings and i == len(
                        embeddings_tilers) - 1:
                    all_intermediate_embeddings.append(embeddings)
                embeddings_tiler.set_current_data(embeddings)

        embeddings = [
            embeddings_tiler.output_image
            for embeddings_tiler in embeddings_tilers
        ]

        if return_all_intermediate_embeddings:
            return embeddings, all_intermediate_embeddings
        else:
            return embeddings
예제 #5
0
    def test_cropped_image(self, dataset_entry):
        """
        Runs the validation network tile by tile on a dataset entry and merges the predicted heatmap tiles.
        Every tile feeds both the image and the target landmarks, and additionally runs the target heatmaps,
        the loss and the update ops of self.val_loss_aggregator; only the predicted heatmaps are kept.
        :param dataset_entry: A dataset entry with 'generators' ('image', 'landmarks') and 'transformations' ('image').
        :return: input image (np.array), predicted heatmaps (np.array), heatmap transformation
        """
        generators = dataset_entry['generators']
        heatmap_transform = dataset_entry['transformations']['image']

        spatial_size = list(reversed(self.image_size))
        image_size_np = [1] + spatial_size
        heatmap_size_np = [self.num_landmarks] + spatial_size
        full_image = generators['image']
        landmarks = generators['landmarks']
        full_heatmap_shape = (self.num_landmarks, ) + full_image.shape[1:]

        image_tiler = ImageTiler(full_image.shape, image_size_np, self.cropped_inc, True, -1)
        landmark_tiler = LandmarkTiler(full_image.shape, image_size_np, self.cropped_inc)
        heatmap_tiler = ImageTiler(full_heatmap_shape, heatmap_size_np, self.cropped_inc, True, 0)

        # All tilers are advanced in lockstep; the loop rebinds the tiler names to the yielded items.
        for image_tiler, landmark_tiler, heatmap_tiler in zip(image_tiler, landmark_tiler, heatmap_tiler):
            tile = image_tiler.get_current_data(full_image)
            tile_landmarks = landmark_tiler.get_current_data(landmarks)
            feed_dict = {
                self.image_val: np.expand_dims(tile, axis=0),
                self.target_landmarks_val: np.expand_dims(tile_landmarks, axis=0),
            }
            fetches = (self.heatmaps_val, self.target_heatmaps_val, self.loss_val) + self.val_loss_aggregator.get_update_ops()
            outputs = self.sess.run(fetches, feed_dict)
            tile_heatmaps = np.squeeze(outputs[0], axis=0)
            image_tiler.set_current_data(tile)
            heatmap_tiler.set_current_data(tile_heatmaps)

        return image_tiler.output_image, heatmap_tiler.output_image, heatmap_transform
    def test_cropped_image(self, dataset_entry):
        """
        Perform inference on a dataset_entry with the validation network. Performs cropped prediction and merges the
        overlapping tile outputs with the prediction tilers (pad value -np.inf, presumably so that merging as maxima
        ignores uncovered regions — confirm against ImageTiler).
        If self.has_validation_groundtruth is set, the losses of every tile are passed to self.loss_metric_logger_val.
        :param dataset_entry: A dataset entry from the dataset.
        :return: input image (np.array), predicted heatmaps (np.array), local predicted heatmaps (np.array),
                 spatial predicted heatmaps (np.array), transformation (sitk.Transform)
        """
        generators = dataset_entry['generators']
        transformations = dataset_entry['transformations']
        transformation = transformations['image']

        full_image = generators['image']
        if self.has_validation_groundtruth:
            landmarks = generators['landmarks']

        # Tiles are never larger than the image itself, nor larger than max_image_size_for_cropped_test.
        image_size_for_tilers = np.minimum(full_image.shape[1:], list(reversed(self.max_image_size_for_cropped_test))).tolist()

        image_size_np = [1] + image_size_for_tilers
        labels_size_np = [self.num_landmarks] + image_size_for_tilers
        full_labels_shape = (self.num_landmarks,) + full_image.shape[1:]
        image_tiler = ImageTiler(full_image.shape, image_size_np, self.cropped_inc, True, -1)
        landmark_tiler = LandmarkTiler(full_image.shape, image_size_np, self.cropped_inc)
        # One identically configured tiler each for the prediction, local prediction and spatial prediction outputs.
        prediction_tiler, prediction_local_tiler, prediction_spatial_tiler = (
            ImageTiler(full_labels_shape, labels_size_np, self.cropped_inc, True, -np.inf) for _ in range(3))
        # All tilers are advanced in lockstep; the loop rebinds the tiler names to the yielded items.
        for image_tiler, landmark_tiler, prediction_tiler, prediction_local_tiler, prediction_spatial_tiler in zip(image_tiler, landmark_tiler, prediction_tiler, prediction_local_tiler, prediction_spatial_tiler):
            current_image = image_tiler.get_current_data(full_image)
            if self.has_validation_groundtruth:
                current_landmarks = landmark_tiler.get_current_data(landmarks)
                (prediction, prediction_local, prediction_spatial), losses = self.call_model_and_loss(np.expand_dims(current_image, axis=0),
                                                                                                      np.expand_dims(current_landmarks, axis=0), training=False)
                self.loss_metric_logger_val.update_metrics(losses)
            else:
                prediction, prediction_local, prediction_spatial = self.model(np.expand_dims(current_image, axis=0), training=False)
            image_tiler.set_current_data(current_image)
            # Remove the batch axis added by np.expand_dims above.
            prediction_tiler.set_current_data(np.squeeze(prediction, axis=0))
            prediction_local_tiler.set_current_data(np.squeeze(prediction_local, axis=0))
            prediction_spatial_tiler.set_current_data(np.squeeze(prediction_spatial, axis=0))

        return image_tiler.output_image, prediction_tiler.output_image, prediction_local_tiler.output_image, prediction_spatial_tiler.output_image, transformation
예제 #7
0
    def test_cropped_image(self,
                           dataset_entry,
                           current_lstm_states,
                           return_all_intermediate_embeddings=False):
        """
        Tests the whole image by cropping the input image into tiles, running the network on every tile and
        merging the per-tile embeddings with one tiler per embeddings output.
        :param dataset_entry: The dataset entry; dataset_entry['generators']['image'] is the full input image.
        :param current_lstm_states: The current lstm states per tile (indexed by tile, then by lstm input state).
                                    If empty, no lstm input states are fed.
        :param return_all_intermediate_embeddings: If true, also return the last embeddings output of every tile.
        :return: merged embeddings, (list of all intermediate embeddings), list of next lstm states per tile
        """
        generators = dataset_entry['generators']
        full_image = generators['image']
        # The embeddings outputs come first in run_tuple, followed by the lstm output states.
        fetches = self.embeddings_cropped_val + self.lstm_output_states_cropped_val
        num_embeddings_outputs = len(self.embeddings_cropped_val)
        # initialize sizes based on data_format
        if self.data_format == 'channels_first':
            image_size_np = [1] + list(reversed(self.image_size))
            full_image_size_np = list(full_image.shape)
            embeddings_size_np = [self.num_embeddings] + list(
                reversed(self.image_size))
            full_embeddings_size_np = [self.num_embeddings] + list(
                full_image.shape[1:])
            inc = [0] + list(reversed(self.tiled_increment))
        else:
            image_size_np = list(reversed(self.image_size)) + [1]
            full_image_size_np = list(full_image.shape)
            embeddings_size_np = list(reversed(
                self.image_size)) + [self.num_embeddings]
            # Keep every axis except the trailing channel axis. This equals shape[0:2] for 2D images,
            # and stays rank-consistent with embeddings_size_np for 3D images as well.
            full_embeddings_size_np = list(
                full_image.shape[:-1]) + [self.num_embeddings]
            inc = list(reversed(self.tiled_increment)) + [0]
        # initialize one image tiler for the input and one image tiler per embeddings output
        image_tiler = ImageTiler(full_image_size_np, image_size_np, inc, True,
                                 -1)
        embeddings_tilers = tuple([
            ImageTiler(full_embeddings_size_np, embeddings_size_np, inc, True,
                       -1) for _ in range(num_embeddings_outputs)
        ])

        next_lstm_states = []
        all_intermediate_embeddings = []
        # All tilers are advanced in lockstep; the loop rebinds the tiler names to the yielded items.
        for state_index, all_tilers in enumerate(
                zip(*((image_tiler, ) + embeddings_tilers))):
            image_tiler = all_tilers[0]
            embeddings_tilers = all_tilers[1:]
            current_image = image_tiler.get_current_data(full_image)
            feed_dict = {
                self.data_cropped_val: np.expand_dims(current_image, axis=0)
            }
            if len(current_lstm_states) > 0:
                # Feed the lstm states that were produced for this tile index.
                for i, lstm_input_state in enumerate(self.lstm_input_states_cropped_val):
                    feed_dict[lstm_input_state] = current_lstm_states[state_index][i]
            run_tuple = self.sess.run(fetches, feed_dict)
            image_tiler.set_current_data(current_image)
            for i, embeddings_tiler in enumerate(embeddings_tilers):
                # Remove the batch axis added by np.expand_dims above.
                embeddings = np.squeeze(run_tuple[i], axis=0)
                # Only the last embeddings output of each tile is collected.
                if return_all_intermediate_embeddings and i == len(
                        embeddings_tilers) - 1:
                    all_intermediate_embeddings.append(embeddings)
                embeddings_tiler.set_current_data(embeddings)
            current_next_lstm_states = run_tuple[
                num_embeddings_outputs:num_embeddings_outputs +
                len(self.lstm_output_states_cropped_val)]
            next_lstm_states.append(current_next_lstm_states)

        embeddings = [
            embeddings_tiler.output_image
            for embeddings_tiler in embeddings_tilers
        ]

        if return_all_intermediate_embeddings:
            return embeddings, all_intermediate_embeddings, next_lstm_states
        else:
            return embeddings, next_lstm_states
예제 #8
0
    def test_cropped_image(self, dataset_entry):
        """
        Perform inference on a dataset_entry with the validation network. Performs cropped prediction and merges the
        overlapping tile outputs with the prediction tilers (pad value -np.inf, presumably so that merging as maxima
        ignores uncovered regions — confirm against ImageTiler).
        If more than one model file is configured in self.load_model_filenames, every model is evaluated per tile
        and the predictions are averaged.
        :param dataset_entry: A dataset entry from the dataset.
        :return: input image (np.array), predicted heatmaps (np.array), local predicted heatmaps (np.array),
                 spatial predicted heatmaps (np.array), transformation (sitk.Transform)
        """
        generators = dataset_entry['generators']
        transformations = dataset_entry['transformations']
        transformation = transformations['image']

        full_image = generators['image']

        # Only the tile / output shapes depend on data_format; tiles are never larger than the image itself,
        # nor larger than max_image_size_for_cropped_test.
        if self.data_format == 'channels_first':
            image_size_for_tilers = np.minimum(
                full_image.shape[1:],
                list(reversed(self.max_image_size_for_cropped_test))).tolist()
            image_size_np = [1] + image_size_for_tilers
            labels_size_np = [self.num_landmarks] + image_size_for_tilers
            full_labels_shape = (self.num_landmarks, ) + full_image.shape[1:]
        else:
            image_size_for_tilers = np.minimum(
                full_image.shape[:-1],
                list(reversed(self.max_image_size_for_cropped_test))).tolist()
            image_size_np = image_size_for_tilers + [1]
            labels_size_np = image_size_for_tilers + [self.num_landmarks]
            full_labels_shape = full_image.shape[:-1] + (self.num_landmarks, )
        image_tiler = ImageTiler(full_image.shape, image_size_np,
                                 self.cropped_inc, True, -1)
        # One identically configured tiler each for the prediction, local prediction and spatial prediction outputs.
        prediction_tiler, prediction_local_tiler, prediction_spatial_tiler = (
            ImageTiler(full_labels_shape, labels_size_np, self.cropped_inc,
                       True, -np.inf) for _ in range(3))

        # All tilers are advanced in lockstep; the loop rebinds the tiler names to the yielded items.
        for image_tiler, prediction_tiler, prediction_local_tiler, prediction_spatial_tiler in zip(
                image_tiler, prediction_tiler, prediction_local_tiler,
                prediction_spatial_tiler):
            current_image = image_tiler.get_current_data(full_image)
            predictions = []
            predictions_local = []
            predictions_spatial = []
            # NOTE(review): with several model files, every model is reloaded for every tile, and the last
            # model of the list presumably stays loaded afterwards.
            for load_model_filename in self.load_model_filenames:
                if len(self.load_model_filenames) > 1:
                    self.load_model(load_model_filename)
                prediction, prediction_local, prediction_spatial = self.call_model(
                    np.expand_dims(current_image, axis=0))
                predictions.append(prediction.numpy())
                predictions_local.append(prediction_local.numpy())
                predictions_spatial.append(prediction_spatial.numpy())
            # Average over the evaluated models.
            prediction = np.mean(predictions, axis=0)
            prediction_local = np.mean(predictions_local, axis=0)
            prediction_spatial = np.mean(predictions_spatial, axis=0)
            image_tiler.set_current_data(current_image)
            # Remove the batch axis added by np.expand_dims above.
            prediction_tiler.set_current_data(np.squeeze(prediction, axis=0))
            prediction_local_tiler.set_current_data(
                np.squeeze(prediction_local, axis=0))
            prediction_spatial_tiler.set_current_data(
                np.squeeze(prediction_spatial, axis=0))

        return image_tiler.output_image, prediction_tiler.output_image, prediction_local_tiler.output_image, prediction_spatial_tiler.output_image, transformation