Example 1
0
    def make_base_pipeline(self, data_provider: Provider) -> Pipeline:
        """Create base pipeline with input data only.

        Args:
            data_provider: A `Provider` that generates data examples, typically a
                `LabelsReader` instance.

        Returns:
            A `Pipeline` instance configured to produce input examples.
        """
        preprocessing = self.data_config.preprocessing
        cropping = self.data_config.instance_cropping

        pipeline = Pipeline(providers=data_provider)
        pipeline += Normalizer.from_config(preprocessing)
        # Optionally pad/resize every image to a common target size before the
        # scale-based resizing step.
        if preprocessing.resize_and_pad_to_target:
            pipeline += SizeMatcher.from_config(
                config=preprocessing,
                provider=data_provider,
            )
        pipeline += Resizer.from_config(preprocessing)
        # Locate instance centroids, then crop a fixed-size box around each one.
        pipeline += InstanceCentroidFinder.from_config(
            cropping,
            skeletons=self.data_config.labels.skeletons,
        )
        pipeline += InstanceCropper.from_config(cropping)
        return pipeline
Example 2
0
    def make_base_pipeline(self, data_provider: Provider) -> Pipeline:
        """Create base pipeline with input data only.

        Args:
            data_provider: A `Provider` that generates data examples, typically a
                `LabelsReader` instance.

        Returns:
            A `Pipeline` instance configured to produce input examples.
        """
        preprocessing = self.data_config.preprocessing
        aug_config = self.optimization_config.augmentation_config

        pipeline = Pipeline(providers=data_provider)
        pipeline += Normalizer.from_config(preprocessing)
        # Optionally pad/resize every image to a common target size before the
        # scale-based resizing step.
        if preprocessing.resize_and_pad_to_target:
            pipeline += SizeMatcher.from_config(
                config=preprocessing,
                provider=data_provider,
            )
        pipeline += Resizer.from_config(preprocessing)
        # Random cropping runs last so crops are taken from the
        # normalized/resized images.
        if aug_config.random_crop:
            pipeline += RandomCropper(
                crop_height=aug_config.random_crop_height,
                crop_width=aug_config.random_crop_width,
            )
        return pipeline
Example 3
0
    def make_viz_pipeline(self, data_provider: Provider) -> Pipeline:
        """Create visualization pipeline.

        Args:
            data_provider: A `Provider` that generates data examples, typically a
                `LabelsReader` instance.

        Returns:
            A `Pipeline` instance configured to fetch data and for running inference to
            generate predictions useful for visualization during training.
        """
        preprocessing = self.data_config.preprocessing
        cropping = self.data_config.instance_cropping

        pipeline = Pipeline(data_provider)
        if preprocessing.resize_and_pad_to_target:
            pipeline += SizeMatcher.from_config(
                config=preprocessing,
                provider=data_provider,
            )
        pipeline += Normalizer.from_config(preprocessing)
        # Find centroids and crop instances just as in the base pipeline.
        pipeline += InstanceCentroidFinder.from_config(
            cropping,
            skeletons=self.data_config.labels.skeletons,
        )
        pipeline += InstanceCropper.from_config(cropping)
        # Loop over the dataset indefinitely and overlap data fetching with
        # downstream inference.
        pipeline += Repeater()
        pipeline += Prefetcher()
        return pipeline
Example 4
0
def test_size_matcher():
    """Check `SizeMatcher` padding/resizing on two videos of different sizes."""
    # Build fake labels from two videos: a 320x560 grayscale MP4 and a 512x512
    # HDF5 dataset, each with one single-node instance.
    skeleton = sleap.Skeleton.from_names_and_edge_inds(["A"])
    videos = [
        sleap.Video.from_filename(TEST_SMALL_ROBOT_MP4_FILE, grayscale=True),
        sleap.Video.from_filename(
            TEST_H5_FILE, dataset="/box", input_format="channels_first"
        ),
    ]
    labels = sleap.Labels(
        [
            sleap.LabeledFrame(
                frame_idx=0,
                video=video,
                instances=[
                    sleap.Instance.from_pointsarray(
                        np.array([[128, 128]]), skeleton=skeleton
                    )
                ],
            )
            for video in videos
        ]
    )

    # Sanity-check the raw (unmatched) image sizes.
    labels_reader = providers.LabelsReader(labels)
    ds = labels_reader.make_dataset()
    ds_iter = iter(ds)
    assert next(ds_iter)["image"].shape == (320, 560, 1)
    assert next(ds_iter)["image"].shape == (512, 512, 1)

    def check_padding(image, from_y, to_y, from_x, to_x):
        # The padded region must be exactly zero-filled.
        assert (image.numpy()[from_y:to_y, from_x:to_x] == 0).all()

    def matched(height, width):
        # Return an iterator over `ds` transformed to the given target size.
        matcher = SizeMatcher(max_image_height=height, max_image_width=width)
        return iter(matcher.transform_dataset(ds))

    # Target dims not strictly larger than actual image dims.
    it = matched(560, 560)
    im1 = next(it)["image"]
    assert im1.shape == (560, 560, 1)
    check_padding(im1, 321, 560, 0, 560)  # padding on the bottom
    im2 = next(it)["image"]
    assert im2.shape == (560, 560, 1)

    # Variant 2.
    it = matched(320, 560)
    im1 = next(it)["image"]
    assert im1.shape == (320, 560, 1)
    im2 = next(it)["image"]
    assert im2.shape == (320, 560, 1)
    check_padding(im2, 0, 320, 321, 560)  # padding on the right

    # Target is 'max' in both dimensions.
    it = matched(512, 560)
    im1 = next(it)["image"]
    assert im1.shape == (512, 560, 1)
    check_padding(im1, 320, 512, 0, 560)  # padding on the bottom
    im2 = next(it)["image"]
    assert im2.shape == (512, 560, 1)
    check_padding(im2, 0, 512, 512, 560)  # padding on the right

    # Target larger in both dimensions.
    it = matched(750, 750)
    im1 = next(it)["image"]
    assert im1.shape == (750, 750, 1)
    check_padding(im1, 700, 750, 0, 750)  # padding on the bottom
    im2 = next(it)["image"]
    assert im2.shape == (750, 750, 1)
Example 5
0
    def make_training_pipeline(self, data_provider: Provider) -> Pipeline:
        """Create full training pipeline.

        Args:
            data_provider: A `Provider` that generates data examples, typically a
                `LabelsReader` instance.

        Returns:
            A `Pipeline` instance configured to produce all data keys required for
            training.

        Notes:
            This does not remap keys to model outputs. Use `KeyMapper` to pull out keys
            with the appropriate format for the instantiated `tf.keras.Model`.
        """
        opt = self.optimization_config
        preprocessing = self.data_config.preprocessing

        pipeline = Pipeline(providers=data_provider)

        if opt.preload_data:
            pipeline += Preloader()

        if opt.online_shuffling:
            pipeline += Shuffler(shuffle=True, buffer_size=opt.shuffle_buffer_size)

        # Augmentation: flips must be symmetry-aware, so they are derived from
        # the skeleton definition.
        aug_config = opt.augmentation_config
        if aug_config.random_flip:
            pipeline += RandomFlipper.from_skeleton(
                self.data_config.labels.skeletons[0],
                horizontal=aug_config.flip_horizontal,
            )
        pipeline += ImgaugAugmenter.from_config(aug_config)
        if aug_config.random_crop:
            pipeline += RandomCropper(
                crop_height=aug_config.random_crop_height,
                crop_width=aug_config.random_crop_width,
            )

        # Preprocessing after augmentation.
        pipeline += Normalizer.from_config(preprocessing)
        if preprocessing.resize_and_pad_to_target:
            pipeline += SizeMatcher.from_config(
                config=preprocessing,
                provider=data_provider,
            )
        pipeline += Resizer.from_config(preprocessing)

        # Target generation: confidence maps (optionally with offset maps) and
        # part affinity fields.
        if self.offsets_head is not None:
            offsets_threshold = self.offsets_head.sigma_threshold
        else:
            offsets_threshold = 1.0
        pipeline += MultiConfidenceMapGenerator(
            sigma=self.confmaps_head.sigma,
            output_stride=self.confmaps_head.output_stride,
            centroids=False,
            with_offsets=self.offsets_head is not None,
            offsets_threshold=offsets_threshold,
        )
        pipeline += PartAffinityFieldsGenerator(
            sigma=self.pafs_head.sigma,
            output_stride=self.pafs_head.output_stride,
            skeletons=self.data_config.labels.skeletons,
            flatten_channels=True,
        )

        if len(data_provider) >= opt.batch_size:
            # Batching before repeating is preferred since it preserves epoch
            # boundaries such that no sample is repeated within the epoch. But
            # this breaks if there are fewer samples than the batch size.
            pipeline += Batcher(
                batch_size=opt.batch_size,
                drop_remainder=True,
                unrag=True,
            )
            pipeline += Repeater()
        else:
            pipeline += Repeater()
            pipeline += Batcher(
                batch_size=opt.batch_size,
                drop_remainder=True,
                unrag=True,
            )

        if opt.prefetch:
            pipeline += Prefetcher()

        return pipeline