def make_base_pipeline(self, data_provider: Provider) -> Pipeline:
    """Create base pipeline with input data only.

    Args:
        data_provider: A `Provider` that generates data examples, typically a
            `LabelsReader` instance.

    Returns:
        A `Pipeline` instance configured to produce input examples.
    """
    pipeline = Pipeline(providers=data_provider)
    pipeline += Normalizer.from_config(self.data_config.preprocessing)
    if self.data_config.preprocessing.resize_and_pad_to_target:
        pipeline += SizeMatcher.from_config(
            config=self.data_config.preprocessing,
            provider=data_provider,
        )
    pipeline += Resizer.from_config(self.data_config.preprocessing)
    pipeline += InstanceCentroidFinder.from_config(
        self.data_config.instance_cropping,
        skeletons=self.data_config.labels.skeletons,
    )
    pipeline += InstanceCropper.from_config(self.data_config.instance_cropping)
    return pipeline
def make_base_pipeline(self, data_provider: Provider) -> Pipeline:
    """Create base pipeline with input data only.

    Args:
        data_provider: A `Provider` that generates data examples, typically a
            `LabelsReader` instance.

    Returns:
        A `Pipeline` instance configured to produce input examples.
    """
    pipeline = Pipeline(providers=data_provider)
    pipeline += Normalizer.from_config(self.data_config.preprocessing)
    if self.data_config.preprocessing.resize_and_pad_to_target:
        pipeline += SizeMatcher.from_config(
            config=self.data_config.preprocessing,
            provider=data_provider,
        )
    pipeline += Resizer.from_config(self.data_config.preprocessing)
    if self.optimization_config.augmentation_config.random_crop:
        pipeline += RandomCropper(
            crop_height=self.optimization_config.augmentation_config.random_crop_height,
            crop_width=self.optimization_config.augmentation_config.random_crop_width,
        )
    return pipeline
def make_viz_pipeline(self, data_provider: Provider) -> Pipeline:
    """Create visualization pipeline.

    Args:
        data_provider: A `Provider` that generates data examples, typically a
            `LabelsReader` instance.

    Returns:
        A `Pipeline` instance configured to fetch data and run inference to
        generate predictions useful for visualization during training.
    """
    pipeline = Pipeline(data_provider)
    if self.data_config.preprocessing.resize_and_pad_to_target:
        pipeline += SizeMatcher.from_config(
            config=self.data_config.preprocessing,
            provider=data_provider,
        )
    pipeline += Normalizer.from_config(self.data_config.preprocessing)
    pipeline += InstanceCentroidFinder.from_config(
        self.data_config.instance_cropping,
        skeletons=self.data_config.labels.skeletons,
    )
    pipeline += InstanceCropper.from_config(self.data_config.instance_cropping)
    pipeline += Repeater()
    pipeline += Prefetcher()
    return pipeline
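The builders above compose a `Pipeline` by appending transformers with `+=`, each of which rewrites a stream of example dictionaries. As a rough, self-contained sketch of that composition pattern (the names `MiniPipeline`, `MiniTransformer`, and `Scaler` are hypothetical stand-ins, not SLEAP's actual classes):

```python
from typing import Iterable, Iterator, List


class MiniTransformer:
    """Hypothetical stand-in for a pipeline transformer: maps an example
    stream to a new example stream."""

    def transform(self, examples: Iterator[dict]) -> Iterator[dict]:
        raise NotImplementedError


class Scaler(MiniTransformer):
    """Toy normalizer: scales the 'image' value by a constant factor."""

    def __init__(self, factor: float):
        self.factor = factor

    def transform(self, examples: Iterator[dict]) -> Iterator[dict]:
        for ex in examples:
            yield {**ex, "image": ex["image"] * self.factor}


class MiniPipeline:
    """Hypothetical stand-in for Pipeline: holds a provider and an ordered
    chain of transformers appended via `+=`."""

    def __init__(self, provider: Iterable[dict]):
        self.provider = provider
        self.transformers: List[MiniTransformer] = []

    def __iadd__(self, transformer: MiniTransformer) -> "MiniPipeline":
        # `pipeline += t` appends and rebinds to the same object.
        self.transformers.append(transformer)
        return self

    def run(self) -> Iterator[dict]:
        stream = iter(self.provider)
        for t in self.transformers:
            stream = t.transform(stream)
        return stream


pipeline = MiniPipeline([{"image": 10}, {"image": 20}])
pipeline += Scaler(0.5)
print([ex["image"] for ex in pipeline.run()])  # → [5.0, 10.0]
```

Because `__iadd__` returns `self`, the `pipeline += transformer` statements in the methods above all mutate one pipeline object rather than creating copies, which is what lets the `if` branches conditionally splice in steps like `SizeMatcher` or `RandomCropper`.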
def test_size_matcher():
    # Create some fake data using two differently sized videos.
    skeleton = sleap.Skeleton.from_names_and_edge_inds(["A"])
    labels = sleap.Labels(
        [
            sleap.LabeledFrame(
                frame_idx=0,
                video=sleap.Video.from_filename(
                    TEST_SMALL_ROBOT_MP4_FILE, grayscale=True
                ),
                instances=[
                    sleap.Instance.from_pointsarray(
                        np.array([[128, 128]]), skeleton=skeleton
                    )
                ],
            ),
            sleap.LabeledFrame(
                frame_idx=0,
                video=sleap.Video.from_filename(
                    TEST_H5_FILE, dataset="/box", input_format="channels_first"
                ),
                instances=[
                    sleap.Instance.from_pointsarray(
                        np.array([[128, 128]]), skeleton=skeleton
                    )
                ],
            ),
        ]
    )

    # Create a loader for those labels.
    labels_reader = providers.LabelsReader(labels)
    ds = labels_reader.make_dataset()
    ds_iter = iter(ds)
    assert next(ds_iter)["image"].shape == (320, 560, 1)
    assert next(ds_iter)["image"].shape == (512, 512, 1)

    def check_padding(image, from_y, to_y, from_x, to_x):
        assert (image.numpy()[from_y:to_y, from_x:to_x] == 0).all()

    # Check SizeMatcher when the target dims are not strictly larger than the
    # actual image dims.
    size_matcher = SizeMatcher(max_image_height=560, max_image_width=560)
    transform_iter = iter(size_matcher.transform_dataset(ds))
    im1 = next(transform_iter)["image"]
    assert im1.shape == (560, 560, 1)
    # Padding should be on the bottom.
    check_padding(im1, 321, 560, 0, 560)
    im2 = next(transform_iter)["image"]
    assert im2.shape == (560, 560, 1)

    # Variant 2.
    size_matcher = SizeMatcher(max_image_height=320, max_image_width=560)
    transform_iter = iter(size_matcher.transform_dataset(ds))
    im1 = next(transform_iter)["image"]
    assert im1.shape == (320, 560, 1)
    im2 = next(transform_iter)["image"]
    assert im2.shape == (320, 560, 1)
    # Padding should be on the right.
    check_padding(im2, 0, 320, 321, 560)

    # Check SizeMatcher when the target is 'max' in both dimensions.
    size_matcher = SizeMatcher(max_image_height=512, max_image_width=560)
    transform_iter = iter(size_matcher.transform_dataset(ds))
    im1 = next(transform_iter)["image"]
    assert im1.shape == (512, 560, 1)
    # Check padding is on the bottom.
    check_padding(im1, 320, 512, 0, 560)
    im2 = next(transform_iter)["image"]
    assert im2.shape == (512, 560, 1)
    # Check padding is on the right.
    check_padding(im2, 0, 512, 512, 560)

    # Check SizeMatcher when the target is larger in both dimensions.
    size_matcher = SizeMatcher(max_image_height=750, max_image_width=750)
    transform_iter = iter(size_matcher.transform_dataset(ds))
    im1 = next(transform_iter)["image"]
    assert im1.shape == (750, 750, 1)
    # Check padding is on the bottom.
    check_padding(im1, 700, 750, 0, 750)
    im2 = next(transform_iter)["image"]
    assert im2.shape == (750, 750, 1)
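The padding assertions in the test above follow from scale-to-fit-then-pad arithmetic: each image is scaled by the largest factor that fits both target dimensions, and the leftover rows/columns are zero-padded on the bottom and right. A pure-Python sketch of that computation (`fit_and_pad_dims` is a hypothetical helper for illustration, not the `SizeMatcher` API):

```python
def fit_and_pad_dims(h, w, target_h, target_w):
    """Return (scaled_h, scaled_w) after scaling (h, w) by the largest
    factor that fits inside (target_h, target_w). The remaining
    (target_h - scaled_h) rows pad the bottom and the remaining
    (target_w - scaled_w) columns pad the right."""
    scale = min(target_h / h, target_w / w)
    return int(h * scale), int(w * scale)


# 320 x 560 robot frames into a 560 x 560 target: the limiting factor is the
# width (scale = 1.0), so rows 320..560 are bottom padding, matching the
# check_padding(im1, 321, 560, 0, 560) assertion above.
print(fit_and_pad_dims(320, 560, 560, 560))  # → (320, 560)

# 512 x 512 frames into a 320 x 560 target: scaled by 320/512 = 0.625 to
# 320 x 320, so columns 320..560 are right padding.
print(fit_and_pad_dims(512, 512, 320, 560))  # → (320, 320)
```

This also explains why the 512 x 512 frame needs no padding check in the first variant: scaled by 560/512 it exactly fills the 560 x 560 target.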
def make_training_pipeline(self, data_provider: Provider) -> Pipeline:
    """Create full training pipeline.

    Args:
        data_provider: A `Provider` that generates data examples, typically a
            `LabelsReader` instance.

    Returns:
        A `Pipeline` instance configured to produce all data keys required for
        training.

    Notes:
        This does not remap keys to model outputs. Use `KeyMapper` to pull out
        keys with the appropriate format for the instantiated `tf.keras.Model`.
    """
    pipeline = Pipeline(providers=data_provider)
    if self.optimization_config.preload_data:
        pipeline += Preloader()
    if self.optimization_config.online_shuffling:
        pipeline += Shuffler(
            shuffle=True, buffer_size=self.optimization_config.shuffle_buffer_size
        )

    aug_config = self.optimization_config.augmentation_config
    if aug_config.random_flip:
        pipeline += RandomFlipper.from_skeleton(
            self.data_config.labels.skeletons[0],
            horizontal=aug_config.flip_horizontal,
        )
    pipeline += ImgaugAugmenter.from_config(aug_config)
    if aug_config.random_crop:
        pipeline += RandomCropper(
            crop_height=aug_config.random_crop_height,
            crop_width=aug_config.random_crop_width,
        )

    pipeline += Normalizer.from_config(self.data_config.preprocessing)
    if self.data_config.preprocessing.resize_and_pad_to_target:
        pipeline += SizeMatcher.from_config(
            config=self.data_config.preprocessing,
            provider=data_provider,
        )
    pipeline += Resizer.from_config(self.data_config.preprocessing)

    pipeline += MultiConfidenceMapGenerator(
        sigma=self.confmaps_head.sigma,
        output_stride=self.confmaps_head.output_stride,
        centroids=False,
        with_offsets=self.offsets_head is not None,
        offsets_threshold=self.offsets_head.sigma_threshold
        if self.offsets_head is not None
        else 1.0,
    )
    pipeline += PartAffinityFieldsGenerator(
        sigma=self.pafs_head.sigma,
        output_stride=self.pafs_head.output_stride,
        skeletons=self.data_config.labels.skeletons,
        flatten_channels=True,
    )

    if len(data_provider) >= self.optimization_config.batch_size:
        # Batching before repeating is preferred since it preserves epoch
        # boundaries such that no sample is repeated within the epoch. But this
        # breaks if there are fewer samples than the batch size.
        pipeline += Batcher(
            batch_size=self.optimization_config.batch_size,
            drop_remainder=True,
            unrag=True,
        )
        pipeline += Repeater()
    else:
        pipeline += Repeater()
        pipeline += Batcher(
            batch_size=self.optimization_config.batch_size,
            drop_remainder=True,
            unrag=True,
        )

    if self.optimization_config.prefetch:
        pipeline += Prefetcher()

    return pipeline
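The batch-then-repeat branch above exists because batching first keeps every batch inside one epoch, while repeating first lets a batch straddle the epoch boundary whenever the dataset size is not a multiple of the batch size (and, with `drop_remainder=True`, repeating first is the only way to form any full batch when there are fewer samples than the batch size). A small pure-Python sketch of the difference, using list-based stand-ins rather than `tf.data`:

```python
from itertools import chain, islice


def batched(seq, batch_size):
    """Group an iterable into fixed-size batches, dropping the remainder
    (as in Batcher(drop_remainder=True))."""
    it = iter(seq)
    while True:
        batch = list(islice(it, batch_size))
        if len(batch) < batch_size:
            return
        yield batch


def repeated(seq, times):
    """Stand-in for Repeater(): replay the dataset for `times` epochs."""
    return chain.from_iterable([seq] * times)


samples = [0, 1, 2, 3, 4]  # 5 samples, batch_size=2

# Batch first, then repeat: no batch crosses an epoch boundary, but the
# leftover sample (4) is dropped in every epoch.
print(list(batched(samples, 2)) * 2)
# → [[0, 1], [2, 3], [0, 1], [2, 3]]

# Repeat first, then batch: nothing is dropped across epochs, but the
# batch [4, 0] straddles the boundary between the first and second epoch.
print(list(batched(repeated(samples, 2), 2)))
# → [[0, 1], [2, 3], [4, 0], [1, 2], [3, 4]]
```

With fewer samples than the batch size (say 1 sample, batch size 2), batch-then-repeat yields no batches at all, which is why the method falls back to repeat-then-batch in that case.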