Exemple #1
0
    def make_viz_pipeline(
        self, data_provider: Provider, keras_model: tf.keras.Model
    ) -> Pipeline:
        """Build the pipeline used to generate predictions for visualization.

        The base pipeline for `data_provider` is extended, in order, with a
        `Prefetcher`, a `Repeater`, a `KerasModelPredictor` wrapping `keras_model`
        and a `GlobalPeakFinder` configured to read the predictor's output key.

        Args:
            data_provider: A `Provider` that generates data examples, typically a
                `LabelsReader` instance.
            keras_model: A `tf.keras.Model` that can be used for inference.

        Returns:
            A `Pipeline` instance configured to fetch data and run inference to
            generate predictions useful for visualization during training. The
            predictor is configured with the output key
            `"predicted_confidence_maps"` and the peak finder with the keys
            `"predicted_points"` and `"predicted_confidences"`.
        """
        viz_pipeline = self.make_base_pipeline(data_provider=data_provider)
        viz_pipeline += Prefetcher()
        viz_pipeline += Repeater()

        # Inference stage: maps the "image" key to predicted confidence maps.
        predictor = KerasModelPredictor(
            keras_model=keras_model,
            model_input_keys="image",
            model_output_keys="predicted_confidence_maps",
        )
        viz_pipeline += predictor

        # Peak finding stage: uses the head's output stride and a fixed
        # threshold of 0.2.
        peak_finder = GlobalPeakFinder(
            confmaps_key="predicted_confidence_maps",
            confmaps_stride=self.single_instance_confmap_head.output_stride,
            peak_threshold=0.2,
            peaks_key="predicted_points",
            peak_vals_key="predicted_confidences",
        )
        viz_pipeline += peak_finder

        return viz_pipeline
Exemple #2
0
    def make_viz_pipeline(
        self, data_provider: Provider, keras_model: tf.keras.Model
    ) -> Pipeline:
        """Build the pipeline used to generate centroid predictions for visualization.

        The base pipeline for `data_provider` is extended, in order, with a
        `Prefetcher`, a `Repeater`, an optional `RandomCropper` (only when the
        augmentation config has `random_crop` set), a `KerasModelPredictor`
        wrapping `keras_model` and a `LocalPeakFinder` configured to read the
        predictor's output key.

        Args:
            data_provider: A `Provider` that generates data examples, typically a
                `LabelsReader` instance.
            keras_model: A `tf.keras.Model` that can be used for inference.

        Returns:
            A `Pipeline` instance configured to fetch data and run inference to
            generate predictions useful for visualization during training.
        """
        viz_pipeline = self.make_base_pipeline(data_provider=data_provider)
        viz_pipeline += Prefetcher()
        viz_pipeline += Repeater()

        # Random cropping is applied only when enabled in the augmentation config;
        # the crop size comes from that same config.
        aug_config = self.optimization_config.augmentation_config
        if aug_config.random_crop:
            cropper = RandomCropper(
                crop_height=aug_config.random_crop_height,
                crop_width=aug_config.random_crop_width,
            )
            viz_pipeline += cropper

        # Inference stage: maps the "image" key to predicted centroid confidence maps.
        predictor = KerasModelPredictor(
            keras_model=keras_model,
            model_input_keys="image",
            model_output_keys="predicted_centroid_confidence_maps",
        )
        viz_pipeline += predictor

        # Peak finding stage: uses the centroid head's output stride and a fixed
        # threshold of 0.2.
        peak_finder = LocalPeakFinder(
            confmaps_key="predicted_centroid_confidence_maps",
            confmaps_stride=self.centroid_confmap_head.output_stride,
            peak_threshold=0.2,
            peaks_key="predicted_centroids",
            peak_vals_key="predicted_centroid_confidences",
            peak_sample_inds_key="predicted_centroid_sample_inds",
            peak_channel_inds_key="predicted_centroid_channel_inds",
        )
        viz_pipeline += peak_finder

        return viz_pipeline