Example #1
# Imports assumed from the sleap.nn.data package layout.
# `min_labels` is a pytest fixture supplying a minimal `Labels` object.
import numpy as np
import tensorflow as tf

from sleap.nn.data import instance_centroids, providers
from sleap.nn.data.confidence_maps import MultiConfidenceMapGenerator

def test_multi_confidence_map_generator_centroids(min_labels):
    labels_reader = providers.LabelsReader(min_labels)
    instance_centroid_finder = instance_centroids.InstanceCentroidFinder(
        center_on_anchor_part=True,
        anchor_part_names="A",
        skeletons=labels_reader.labels.skeletons,
    )
    multi_confmap_generator = MultiConfidenceMapGenerator(
        sigma=5, output_stride=2, centroids=True
    )
    ds = labels_reader.make_dataset()
    ds = instance_centroid_finder.transform_dataset(ds)
    ds = multi_confmap_generator.transform_dataset(ds)
    example = next(iter(ds))

    assert example["centroid_confidence_maps"].shape == (192, 192, 1)
    assert example["centroid_confidence_maps"].dtype == tf.float32

    centroids = example["centroids"].numpy() / multi_confmap_generator.output_stride
    centroid_cms = example["centroid_confidence_maps"].numpy()

    np.testing.assert_allclose(
        centroid_cms[int(centroids[0, 1]), int(centroids[0, 0]), :], [0.9811318]
    )
    np.testing.assert_allclose(
        centroid_cms[int(centroids[1, 1]), int(centroids[1, 0]), :], [0.8642299]
    )
Example #2
# Imports assumed from the sleap.nn.data package layout.
import numpy as np
import tensorflow as tf

from sleap.nn.data import providers
from sleap.nn.data.confidence_maps import MultiConfidenceMapGenerator

def test_multi_confidence_map_generator(min_labels):
    labels_reader = providers.LabelsReader(min_labels)
    multi_confmap_generator = MultiConfidenceMapGenerator(
        sigma=3, output_stride=2, centroids=False
    )
    ds = labels_reader.make_dataset()
    ds = multi_confmap_generator.transform_dataset(ds)
    example = next(iter(ds))

    assert example["confidence_maps"].shape == (192, 192, 2)
    assert example["confidence_maps"].dtype == tf.float32

    instances = example["instances"].numpy() / multi_confmap_generator.output_stride
    cms = example["confidence_maps"].numpy()

    np.testing.assert_allclose(
        cms[int(instances[0, 0, 1]), int(instances[0, 0, 0]), :], [0.948463, 0.0]
    )
    np.testing.assert_allclose(
        cms[int(instances[1, 0, 1]), int(instances[1, 0, 0]), :], [0.66676116, 0.0]
    )

    np.testing.assert_allclose(
        cms[int(instances[0, 1, 1]), int(instances[0, 1, 0]), :], [0.0, 0.9836702]
    )
    np.testing.assert_allclose(
        cms[int(instances[1, 1, 1]), int(instances[1, 1, 0]), :], [0.0, 0.8815618]
    )
Example #3
    def make_training_pipeline(self, data_provider: Provider) -> Pipeline:
        """Create full training pipeline.

        Args:
            data_provider: A `Provider` that generates data examples, typically a
                `LabelsReader` instance.

        Returns:
            A `Pipeline` instance configured to produce all data keys required for
            training.

        Notes:
            This does not remap keys to model outputs. Use `KeyMapper` to pull out keys
            with the appropriate format for the instantiated `tf.keras.Model`.
        """
        pipeline = Pipeline(providers=data_provider)

        if self.optimization_config.preload_data:
            pipeline += Preloader()

        if self.optimization_config.online_shuffling:
            pipeline += Shuffler(self.optimization_config.shuffle_buffer_size)

        pipeline += ImgaugAugmenter.from_config(
            self.optimization_config.augmentation_config)
        pipeline += Normalizer.from_config(self.data_config.preprocessing)
        pipeline += Resizer.from_config(self.data_config.preprocessing)

        pipeline += MultiConfidenceMapGenerator(
            sigma=self.confmaps_head.sigma,
            output_stride=self.confmaps_head.output_stride,
            centroids=False,
        )
        pipeline += PartAffinityFieldsGenerator(
            sigma=self.pafs_head.sigma,
            output_stride=self.pafs_head.output_stride,
            skeletons=self.data_config.labels.skeletons,
            flatten_channels=True,
        )

        if len(data_provider) >= self.optimization_config.batch_size:
            # Batching before repeating is preferred since it preserves epoch boundaries
            # such that no sample is repeated within the epoch. But this breaks if there
            # are fewer samples than the batch size.
            pipeline += Batcher(batch_size=self.optimization_config.batch_size,
                                drop_remainder=True)
            pipeline += Repeater()

        else:
            pipeline += Repeater()
            pipeline += Batcher(batch_size=self.optimization_config.batch_size,
                                drop_remainder=True)

        if self.optimization_config.prefetch:
            pipeline += Prefetcher()

        return pipeline
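The batch-then-repeat ordering described in the comment above is easy to see on a toy tf.data pipeline; this sketch is illustrative only:

import tensorflow as tf

ds = tf.data.Dataset.range(5)

# Batch first: epoch boundaries are preserved (the remainder element 4 is
# dropped), so no sample repeats within an epoch.
print([b.numpy().tolist() for b in ds.batch(2, drop_remainder=True).repeat(2)])
# [[0, 1], [2, 3], [0, 1], [2, 3]]

# Repeat first: batches can straddle epoch boundaries, which is why this
# ordering is only used when the dataset is smaller than the batch size.
print([b.numpy().tolist() for b in ds.repeat(2).batch(2, drop_remainder=True)])
# [[0, 1], [2, 3], [4, 0], [1, 2], [3, 4]]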
Example #4
    def make_training_pipeline(self, data_provider: Provider) -> Pipeline:
        """Create full training pipeline.

        Args:
            data_provider: A `Provider` that generates data examples, typically a
                `LabelsReader` instance.

        Returns:
            A `Pipeline` instance configured to produce all data keys required for
            training.

        Notes:
            This does not remap keys to model outputs. Use `KeyMapper` to pull out keys
            with the appropriate format for the instantiated `tf.keras.Model`.
        """
        pipeline = Pipeline(providers=data_provider)

        if self.optimization_config.preload_data:
            pipeline += Preloader()

        if self.optimization_config.online_shuffling:
            pipeline += Shuffler(
                shuffle=True,
                buffer_size=self.optimization_config.shuffle_buffer_size)
        if self.optimization_config.augmentation_config.random_flip:
            pipeline += RandomFlipper.from_skeleton(
                self.data_config.labels.skeletons[0],
                horizontal=self.optimization_config.augmentation_config.flip_horizontal,
            )
        pipeline += ImgaugAugmenter.from_config(
            self.optimization_config.augmentation_config)
        if self.optimization_config.augmentation_config.random_crop:
            pipeline += RandomCropper(
                crop_height=self.optimization_config.augmentation_config.random_crop_height,
                crop_width=self.optimization_config.augmentation_config.random_crop_width,
            )
        pipeline += Normalizer.from_config(self.data_config.preprocessing)
        if self.data_config.preprocessing.resize_and_pad_to_target:
            pipeline += SizeMatcher.from_config(
                config=self.data_config.preprocessing,
                provider=data_provider,
            )
        pipeline += Resizer.from_config(self.data_config.preprocessing)
        pipeline += InstanceCentroidFinder.from_config(
            self.data_config.instance_cropping,
            skeletons=self.data_config.labels.skeletons,
        )
        pipeline += MultiConfidenceMapGenerator(
            sigma=self.centroid_confmap_head.sigma,
            output_stride=self.centroid_confmap_head.output_stride,
            centroids=True,
            with_offsets=self.offsets_head is not None,
            offsets_threshold=(self.offsets_head.sigma_threshold
                               if self.offsets_head is not None else 1.0),
        )

        if len(data_provider) >= self.optimization_config.batch_size:
            # Batching before repeating is preferred since it preserves epoch boundaries
            # such that no sample is repeated within the epoch. But this breaks if there
            # are fewer samples than the batch size.
            pipeline += Batcher(
                batch_size=self.optimization_config.batch_size,
                drop_remainder=True,
                unrag=True,
            )
            pipeline += Repeater()

        else:
            pipeline += Repeater()
            pipeline += Batcher(
                batch_size=self.optimization_config.batch_size,
                drop_remainder=True,
                unrag=True,
            )

        if self.optimization_config.prefetch:
            pipeline += Prefetcher()

        return pipeline
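Either variant of make_training_pipeline is consumed the same way. A hypothetical sketch, reusing the reader from the sketch after Example #2 and assuming a `Pipeline.make_dataset()` method plus an owning object (here called `bottomup`) that defines `make_training_pipeline`:

pipeline = bottomup.make_training_pipeline(reader)  # `bottomup` is hypothetical
ds = pipeline.make_dataset()  # assumed Pipeline API
batch = next(iter(ds))
print(batch["image"].shape)  # (batch_size, height, width, channels)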