Exemplo n.º 1
0
    def make_training_pipeline(self, data_provider: Provider) -> Pipeline:
        """Assemble the complete data pipeline used for training.

        Args:
            data_provider: The `Provider` that serves as the source of data
                examples; this is usually a `LabelsReader`.

        Returns:
            A `Pipeline` set up to emit every data key needed for training.

        Notes:
            Keys are not remapped to model outputs here. Apply a `KeyMapper`
            afterwards to pull out keys in the format expected by the instantiated
            `tf.keras.Model`.
        """
        opt_config = self.optimization_config
        preprocessing_config = self.data_config.preprocessing
        cropping_config = self.data_config.instance_cropping
        confmap_head = self.instance_confmap_head

        pipeline = Pipeline(providers=data_provider)

        # Optional input-side stages.
        if opt_config.preload_data:
            pipeline += Preloader()
        if opt_config.online_shuffling:
            pipeline += Shuffler(opt_config.shuffle_buffer_size)

        # Augmentation followed by image preprocessing.
        pipeline += ImgaugAugmenter.from_config(opt_config.augmentation_config)
        pipeline += Normalizer.from_config(preprocessing_config)
        pipeline += Resizer.from_config(preprocessing_config)

        # Locate instance centroids, crop around them, then build the confidence
        # maps for the cropped instance only (all_instances=False).
        pipeline += InstanceCentroidFinder.from_config(
            cropping_config,
            skeletons=self.data_config.labels.skeletons,
        )
        pipeline += InstanceCropper.from_config(cropping_config)
        pipeline += InstanceConfidenceMapGenerator(
            sigma=confmap_head.sigma,
            output_stride=confmap_head.output_stride,
            all_instances=False,
        )

        batch_size = opt_config.batch_size
        if len(data_provider) < batch_size:
            # Fewer samples than one batch: batching first would break, so the
            # data is repeated before it is batched.
            pipeline += Repeater()
            pipeline += Batcher(batch_size=batch_size, drop_remainder=True)
        else:
            # Preferred order: batching before repeating keeps epoch boundaries
            # intact, so no sample shows up twice within the same epoch.
            pipeline += Batcher(batch_size=batch_size, drop_remainder=True)
            pipeline += Repeater()

        if opt_config.prefetch:
            pipeline += Prefetcher()

        return pipeline
Exemplo n.º 2
0
    def make_training_pipeline(self, data_provider: Provider) -> Pipeline:
        """Assemble the complete data pipeline used for training.

        Args:
            data_provider: The `Provider` that serves as the source of data
                examples; this is usually a `LabelsReader`.

        Returns:
            A `Pipeline` set up to emit every data key needed for training.

        Notes:
            Keys are not remapped to model outputs here. Apply a `KeyMapper`
            afterwards to pull out keys in the format expected by the instantiated
            `tf.keras.Model`.
        """
        opt_config = self.optimization_config
        aug_config = opt_config.augmentation_config
        preprocessing_config = self.data_config.preprocessing

        pipeline = Pipeline(providers=data_provider)

        # Optional input-side stages.
        if opt_config.preload_data:
            pipeline += Preloader()
        if opt_config.online_shuffling:
            pipeline += Shuffler(
                shuffle=True,
                buffer_size=opt_config.shuffle_buffer_size,
            )

        # Augmentation: optional flip, imgaug-based augmentation, optional crop.
        if aug_config.random_flip:
            pipeline += RandomFlipper.from_skeleton(
                self.data_config.labels.skeletons[0],
                horizontal=aug_config.flip_horizontal,
            )
        pipeline += ImgaugAugmenter.from_config(aug_config)
        if aug_config.random_crop:
            pipeline += RandomCropper(
                crop_height=aug_config.random_crop_height,
                crop_width=aug_config.random_crop_width,
            )

        # Image preprocessing.
        pipeline += Normalizer.from_config(preprocessing_config)
        if preprocessing_config.resize_and_pad_to_target:
            pipeline += SizeMatcher.from_config(
                config=preprocessing_config,
                provider=data_provider,
            )
        pipeline += Resizer.from_config(preprocessing_config)

        # Offsets are generated only when an offsets head is configured; without
        # one the threshold falls back to 1.0.
        offsets_head = self.offsets_head
        has_offsets = offsets_head is not None
        if has_offsets:
            offsets_threshold = offsets_head.sigma_threshold
        else:
            offsets_threshold = 1.0

        # Training targets: confidence maps and part affinity fields.
        pipeline += MultiConfidenceMapGenerator(
            sigma=self.confmaps_head.sigma,
            output_stride=self.confmaps_head.output_stride,
            centroids=False,
            with_offsets=has_offsets,
            offsets_threshold=offsets_threshold,
        )
        pipeline += PartAffinityFieldsGenerator(
            sigma=self.pafs_head.sigma,
            output_stride=self.pafs_head.output_stride,
            skeletons=self.data_config.labels.skeletons,
            flatten_channels=True,
        )

        batch_size = opt_config.batch_size
        if len(data_provider) < batch_size:
            # Fewer samples than one batch: batching first would break, so the
            # data is repeated before it is batched.
            pipeline += Repeater()
            pipeline += Batcher(
                batch_size=batch_size,
                drop_remainder=True,
                unrag=True,
            )
        else:
            # Preferred order: batching before repeating keeps epoch boundaries
            # intact, so no sample shows up twice within the same epoch.
            pipeline += Batcher(
                batch_size=batch_size,
                drop_remainder=True,
                unrag=True,
            )
            pipeline += Repeater()

        if opt_config.prefetch:
            pipeline += Prefetcher()

        return pipeline