Example #1
0
    def make_base_pipeline(self, data_provider: Provider) -> Pipeline:
        """Build the pipeline that produces input examples only.

        Stages are appended in this order: normalization, optional size matching,
        resizing and, when enabled in the augmentation config, random cropping.

        Args:
            data_provider: The `Provider` used as the data source of the pipeline.
                It is also passed to `SizeMatcher.from_config` when
                `resize_and_pad_to_target` is enabled.

        Returns:
            A `Pipeline` instance with the input preprocessing stages attached.
        """
        base = Pipeline(providers=data_provider)

        preproc_config = self.data_config.preprocessing
        base += Normalizer.from_config(preproc_config)
        if preproc_config.resize_and_pad_to_target:
            base += SizeMatcher.from_config(
                config=preproc_config,
                provider=data_provider,
            )
        base += Resizer.from_config(preproc_config)

        aug_config = self.optimization_config.augmentation_config
        if not aug_config.random_crop:
            # Cropping is opt-in; without it the pipeline ends at the resizer.
            return base

        base += RandomCropper(
            crop_height=aug_config.random_crop_height,
            crop_width=aug_config.random_crop_width,
        )
        return base
Example #2
0
    def make_viz_pipeline(self, data_provider: Provider,
                          keras_model: tf.keras.Model) -> Pipeline:
        """Build the pipeline used to visualize predictions during training.

        Starts from `make_base_pipeline`, then appends prefetching, repeating,
        optional random cropping, model inference and local peak finding on the
        predicted centroid confidence maps.

        Args:
            data_provider: The `Provider` used as the data source, forwarded to
                `make_base_pipeline`.
            keras_model: A `tf.keras.Model` handed to `KerasModelPredictor` to
                run inference on the `"image"` key.

        Returns:
            A `Pipeline` instance that fetches data and generates the predicted
            centroid outputs.
        """
        viz = self.make_base_pipeline(data_provider=data_provider)
        viz += Prefetcher()
        viz += Repeater()

        # NOTE(review): if this class's make_base_pipeline already appends a
        # RandomCropper, this adds a second one -- confirm that is intended.
        aug_config = self.optimization_config.augmentation_config
        if aug_config.random_crop:
            viz += RandomCropper(
                crop_height=aug_config.random_crop_height,
                crop_width=aug_config.random_crop_width,
            )

        # The predictor writes under this key and the peak finder reads from it.
        confmaps_key = "predicted_centroid_confidence_maps"
        viz += KerasModelPredictor(
            keras_model=keras_model,
            model_input_keys="image",
            model_output_keys=confmaps_key,
        )

        confmaps_stride = self.centroid_confmap_head.output_stride
        viz += LocalPeakFinder(
            confmaps_stride=confmaps_stride,
            peak_threshold=0.2,
            confmaps_key=confmaps_key,
            peaks_key="predicted_centroids",
            peak_vals_key="predicted_centroid_confidences",
            peak_sample_inds_key="predicted_centroid_sample_inds",
            peak_channel_inds_key="predicted_centroid_channel_inds",
        )
        return viz
Example #3
0
    def make_training_pipeline(self, data_provider: Provider) -> Pipeline:
        """Build the complete training pipeline.

        Stages are appended in this order: optional preloading, optional
        shuffling, augmentation (optional flip, imgaug, optional random crop),
        normalization, optional size matching, resizing, confidence map and part
        affinity field generation, batching/repeating and optional prefetching.

        Args:
            data_provider: The `Provider` used as the data source. Its length is
                compared against the batch size to decide whether batching or
                repeating comes first.

        Returns:
            A `Pipeline` instance producing the data keys generated by the stages
            listed above.

        Notes:
            No key remapping to model outputs happens here. Use `KeyMapper` to
            pull out keys in the format expected by the instantiated
            `tf.keras.Model`.
        """
        opt_config = self.optimization_config
        aug_config = opt_config.augmentation_config
        preproc_config = self.data_config.preprocessing
        skeletons = self.data_config.labels.skeletons

        training = Pipeline(providers=data_provider)

        # Optional data loading stages.
        if opt_config.preload_data:
            training += Preloader()
        if opt_config.online_shuffling:
            training += Shuffler(
                shuffle=True,
                buffer_size=opt_config.shuffle_buffer_size,
            )

        # Augmentation stages.
        if aug_config.random_flip:
            training += RandomFlipper.from_skeleton(
                skeletons[0],
                horizontal=aug_config.flip_horizontal,
            )
        training += ImgaugAugmenter.from_config(aug_config)
        if aug_config.random_crop:
            training += RandomCropper(
                crop_height=aug_config.random_crop_height,
                crop_width=aug_config.random_crop_width,
            )

        # Input preprocessing stages.
        training += Normalizer.from_config(preproc_config)
        if preproc_config.resize_and_pad_to_target:
            training += SizeMatcher.from_config(
                config=preproc_config,
                provider=data_provider,
            )
        training += Resizer.from_config(preproc_config)

        # Target generation. Offsets are only requested when an offsets head is
        # configured; otherwise the threshold falls back to 1.0.
        use_offsets = self.offsets_head is not None
        if use_offsets:
            offsets_threshold = self.offsets_head.sigma_threshold
        else:
            offsets_threshold = 1.0
        training += MultiConfidenceMapGenerator(
            sigma=self.confmaps_head.sigma,
            output_stride=self.confmaps_head.output_stride,
            centroids=False,
            with_offsets=use_offsets,
            offsets_threshold=offsets_threshold,
        )
        training += PartAffinityFieldsGenerator(
            sigma=self.pafs_head.sigma,
            output_stride=self.pafs_head.output_stride,
            skeletons=skeletons,
            flatten_channels=True,
        )

        # Batching before repeating is preferred since it preserves epoch
        # boundaries such that no sample is repeated within the epoch. But this
        # breaks if there are fewer samples than the batch size, in which case
        # the repeater has to come first.
        batch_before_repeat = len(data_provider) >= opt_config.batch_size
        if not batch_before_repeat:
            training += Repeater()
        training += Batcher(
            batch_size=opt_config.batch_size,
            drop_remainder=True,
            unrag=True,
        )
        if batch_before_repeat:
            training += Repeater()

        if opt_config.prefetch:
            training += Prefetcher()

        return training