def make_training_pipeline(self, data_provider: Provider) -> Pipeline:
    """Create full training pipeline.

    Args:
        data_provider: A `Provider` that generates data examples, typically a
            `LabelsReader` instance.

    Returns:
        A `Pipeline` instance configured to produce all data keys required for
        training.

    Notes:
        This does not remap keys to model outputs. Use `KeyMapper` to pull out keys
        with the appropriate format for the instantiated `tf.keras.Model`.
    """
    pipeline = Pipeline(providers=data_provider)

    if self.optimization_config.preload_data:
        pipeline += Preloader()

    if self.optimization_config.online_shuffling:
        # Name the arguments explicitly so the buffer size can never be bound to
        # whichever parameter `Shuffler` happens to declare first.
        pipeline += Shuffler(
            shuffle=True,
            buffer_size=self.optimization_config.shuffle_buffer_size,
        )

    pipeline += ImgaugAugmenter.from_config(
        self.optimization_config.augmentation_config
    )
    pipeline += Normalizer.from_config(self.data_config.preprocessing)
    pipeline += Resizer.from_config(self.data_config.preprocessing)

    pipeline += InstanceCentroidFinder.from_config(
        self.data_config.instance_cropping,
        skeletons=self.data_config.labels.skeletons,
    )
    pipeline += InstanceCropper.from_config(self.data_config.instance_cropping)
    pipeline += InstanceConfidenceMapGenerator(
        sigma=self.instance_confmap_head.sigma,
        output_stride=self.instance_confmap_head.output_stride,
        all_instances=False,
    )

    if len(data_provider) >= self.optimization_config.batch_size:
        # Batching before repeating is preferred since it preserves epoch boundaries
        # such that no sample is repeated within the epoch. But this breaks if there
        # are fewer samples than the batch size.
        pipeline += Batcher(
            batch_size=self.optimization_config.batch_size, drop_remainder=True
        )
        pipeline += Repeater()
    else:
        pipeline += Repeater()
        pipeline += Batcher(
            batch_size=self.optimization_config.batch_size, drop_remainder=True
        )

    if self.optimization_config.prefetch:
        pipeline += Prefetcher()

    return pipeline
def make_training_pipeline(self, data_provider: Provider) -> Pipeline:
    """Assemble the complete data pipeline used for training.

    Args:
        data_provider: A `Provider` that generates data examples, typically a
            `LabelsReader` instance.

    Returns:
        A `Pipeline` instance configured to produce all data keys required for
        training.

    Notes:
        Keys are not remapped to model outputs here. Use `KeyMapper` to pull out
        keys with the appropriate format for the instantiated `tf.keras.Model`.
    """
    opt_cfg = self.optimization_config
    pipeline = Pipeline(providers=data_provider)

    # Optional stages toggled by the optimization config.
    if opt_cfg.preload_data:
        pipeline += Preloader()
    if opt_cfg.online_shuffling:
        pipeline += Shuffler(shuffle=True, buffer_size=opt_cfg.shuffle_buffer_size)

    # Augmentation stages: optional flip, imgaug-based augmentation, optional crop.
    aug_cfg = opt_cfg.augmentation_config
    if aug_cfg.random_flip:
        pipeline += RandomFlipper.from_skeleton(
            self.data_config.labels.skeletons[0],
            horizontal=aug_cfg.flip_horizontal,
        )
    pipeline += ImgaugAugmenter.from_config(aug_cfg)
    if aug_cfg.random_crop:
        pipeline += RandomCropper(
            crop_height=aug_cfg.random_crop_height,
            crop_width=aug_cfg.random_crop_width,
        )

    # Preprocessing stages driven by the data config.
    preprocessing = self.data_config.preprocessing
    pipeline += Normalizer.from_config(preprocessing)
    if preprocessing.resize_and_pad_to_target:
        pipeline += SizeMatcher.from_config(
            config=preprocessing,
            provider=data_provider,
        )
    pipeline += Resizer.from_config(preprocessing)

    # Target generation; the offsets threshold falls back to 1.0 when no offsets
    # head is configured.
    has_offsets = self.offsets_head is not None
    if has_offsets:
        offsets_threshold = self.offsets_head.sigma_threshold
    else:
        offsets_threshold = 1.0
    pipeline += MultiConfidenceMapGenerator(
        sigma=self.confmaps_head.sigma,
        output_stride=self.confmaps_head.output_stride,
        centroids=False,
        with_offsets=has_offsets,
        offsets_threshold=offsets_threshold,
    )
    pipeline += PartAffinityFieldsGenerator(
        sigma=self.pafs_head.sigma,
        output_stride=self.pafs_head.output_stride,
        skeletons=self.data_config.labels.skeletons,
        flatten_channels=True,
    )

    # Batching before repeating is preferred since it preserves epoch boundaries
    # such that no sample is repeated within the epoch. But this breaks if there
    # are fewer samples than the batch size, so in that case repeat first.
    can_fill_batch = len(data_provider) >= opt_cfg.batch_size
    batcher = Batcher(
        batch_size=opt_cfg.batch_size,
        drop_remainder=True,
        unrag=True,
    )
    if can_fill_batch:
        pipeline += batcher
        pipeline += Repeater()
    else:
        pipeline += Repeater()
        pipeline += batcher

    if opt_cfg.prefetch:
        pipeline += Prefetcher()

    return pipeline