def make_base_pipeline(self, data_provider: Provider) -> Pipeline:
    """Create base pipeline with input data only.

    Args:
        data_provider: A `Provider` that generates data examples, typically a
            `LabelsReader` instance.

    Returns:
        A `Pipeline` instance configured to produce input examples.
    """
    pipeline = Pipeline(providers=data_provider)
    pipeline += Normalizer.from_config(self.data_config.preprocessing)
    if self.data_config.preprocessing.resize_and_pad_to_target:
        pipeline += SizeMatcher.from_config(
            config=self.data_config.preprocessing,
            provider=data_provider,
        )
    pipeline += Resizer.from_config(self.data_config.preprocessing)
    if self.optimization_config.augmentation_config.random_crop:
        pipeline += RandomCropper(
            crop_height=self.optimization_config.augmentation_config.random_crop_height,
            crop_width=self.optimization_config.augmentation_config.random_crop_width,
        )
    return pipeline
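# A minimal usage sketch for the base pipeline, assuming SLEAP-style imports
# and an already constructed pipeline builder. The names `builder` and
# `labels_reader` are illustrative, not defined in this file:
#
#     from sleap.nn.data.providers import LabelsReader
#
#     labels_reader = LabelsReader.from_filename("labels.train.slp")
#     base_pipeline = builder.make_base_pipeline(data_provider=labels_reader)
#     ds = base_pipeline.make_dataset()  # tf.data.Dataset of example dicts
#     example = next(iter(ds))  # e.g., example["image"] after normalize/resize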
def make_viz_pipeline(
    self, data_provider: Provider, keras_model: tf.keras.Model
) -> Pipeline:
    """Create visualization pipeline.

    Args:
        data_provider: A `Provider` that generates data examples, typically a
            `LabelsReader` instance.
        keras_model: A `tf.keras.Model` that can be used for inference.

    Returns:
        A `Pipeline` instance configured to fetch data and run inference to
        generate predictions useful for visualization during training.
    """
    pipeline = self.make_base_pipeline(data_provider=data_provider)
    pipeline += Prefetcher()
    pipeline += Repeater()
    if self.optimization_config.augmentation_config.random_crop:
        pipeline += RandomCropper(
            crop_height=self.optimization_config.augmentation_config.random_crop_height,
            crop_width=self.optimization_config.augmentation_config.random_crop_width,
        )
    pipeline += KerasModelPredictor(
        keras_model=keras_model,
        model_input_keys="image",
        model_output_keys="predicted_centroid_confidence_maps",
    )
    pipeline += LocalPeakFinder(
        confmaps_stride=self.centroid_confmap_head.output_stride,
        peak_threshold=0.2,
        confmaps_key="predicted_centroid_confidence_maps",
        peaks_key="predicted_centroids",
        peak_vals_key="predicted_centroid_confidences",
        peak_sample_inds_key="predicted_centroid_sample_inds",
        peak_channel_inds_key="predicted_centroid_channel_inds",
    )
    return pipeline
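# Sketch of consuming the visualization pipeline, assuming `builder`,
# `labels_reader`, and `keras_model` already exist (names are illustrative).
# The predicted peak keys below are the ones set on `LocalPeakFinder` above:
#
#     viz_pipeline = builder.make_viz_pipeline(
#         data_provider=labels_reader, keras_model=keras_model
#     )
#     viz_ds = viz_pipeline.make_dataset()
#     example = next(iter(viz_ds))
#     centroids = example["predicted_centroids"]  # peak coordinates
#     scores = example["predicted_centroid_confidences"]  # peak values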
def make_training_pipeline(self, data_provider: Provider) -> Pipeline:
    """Create full training pipeline.

    Args:
        data_provider: A `Provider` that generates data examples, typically a
            `LabelsReader` instance.

    Returns:
        A `Pipeline` instance configured to produce all data keys required for
        training.

    Notes:
        This does not remap keys to model outputs. Use `KeyMapper` to pull out
        keys with the appropriate format for the instantiated
        `tf.keras.Model`.
    """
    pipeline = Pipeline(providers=data_provider)
    if self.optimization_config.preload_data:
        pipeline += Preloader()
    if self.optimization_config.online_shuffling:
        pipeline += Shuffler(
            shuffle=True, buffer_size=self.optimization_config.shuffle_buffer_size
        )

    aug_config = self.optimization_config.augmentation_config
    if aug_config.random_flip:
        pipeline += RandomFlipper.from_skeleton(
            self.data_config.labels.skeletons[0],
            horizontal=aug_config.flip_horizontal,
        )
    pipeline += ImgaugAugmenter.from_config(aug_config)
    if aug_config.random_crop:
        pipeline += RandomCropper(
            crop_height=aug_config.random_crop_height,
            crop_width=aug_config.random_crop_width,
        )

    pipeline += Normalizer.from_config(self.data_config.preprocessing)
    if self.data_config.preprocessing.resize_and_pad_to_target:
        pipeline += SizeMatcher.from_config(
            config=self.data_config.preprocessing,
            provider=data_provider,
        )
    pipeline += Resizer.from_config(self.data_config.preprocessing)

    pipeline += MultiConfidenceMapGenerator(
        sigma=self.confmaps_head.sigma,
        output_stride=self.confmaps_head.output_stride,
        centroids=False,
        with_offsets=self.offsets_head is not None,
        offsets_threshold=self.offsets_head.sigma_threshold
        if self.offsets_head is not None
        else 1.0,
    )
    pipeline += PartAffinityFieldsGenerator(
        sigma=self.pafs_head.sigma,
        output_stride=self.pafs_head.output_stride,
        skeletons=self.data_config.labels.skeletons,
        flatten_channels=True,
    )

    if len(data_provider) >= self.optimization_config.batch_size:
        # Batching before repeating is preferred since it preserves epoch
        # boundaries such that no sample is repeated within the epoch. But this
        # breaks if there are fewer samples than the batch size.
        pipeline += Batcher(
            batch_size=self.optimization_config.batch_size,
            drop_remainder=True,
            unrag=True,
        )
        pipeline += Repeater()
    else:
        pipeline += Repeater()
        pipeline += Batcher(
            batch_size=self.optimization_config.batch_size,
            drop_remainder=True,
            unrag=True,
        )

    if self.optimization_config.prefetch:
        pipeline += Prefetcher()

    return pipeline
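# Sketch of wiring the training pipeline into `tf.keras.Model.fit`, using the
# `KeyMapper` remapping referenced in the docstring. The input/output key
# names and head names below are illustrative assumptions that depend on the
# instantiated model, not values taken from this file:
#
#     train_pipeline = builder.make_training_pipeline(data_provider=labels_reader)
#     train_pipeline += KeyMapper(
#         [
#             {"image": "input"},
#             {
#                 "confidence_maps": "MultiInstanceConfmapsHead",
#                 "part_affinity_fields": "PartAffinityFieldsHead",
#             },
#         ]
#     )
#     train_ds = train_pipeline.make_dataset()
#     keras_model.fit(train_ds, epochs=..., steps_per_epoch=...)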