Ejemplo n.º 1
0
    def make_training_pipeline(self, data_provider: Provider) -> Pipeline:
        """Create full training pipeline.

        Stages, in order: optional preloading, optional shuffling, imgaug
        augmentation, normalization, resizing, centroid finding, centroid
        confidence map generation, batching/repeating, and optional prefetching.

        Args:
            data_provider: A `Provider` that generates data examples, typically a
                `LabelsReader` instance.

        Returns:
            A `Pipeline` instance configured to produce all data keys required for
            training.

        Notes:
            This does not remap keys to model outputs. Use `KeyMapper` to pull out keys
            with the appropriate format for the instantiated `tf.keras.Model`.
        """
        pipeline = Pipeline(providers=data_provider)

        if self.optimization_config.preload_data:
            pipeline += Preloader()

        if self.optimization_config.online_shuffling:
            # Pass the buffer size by keyword. `Shuffler` takes a `shuffle` flag
            # and a `buffer_size`; passing the buffer size positionally binds it
            # to `shuffle` and leaves `buffer_size` at its default, silently
            # ignoring the configured value.
            pipeline += Shuffler(
                shuffle=True,
                buffer_size=self.optimization_config.shuffle_buffer_size,
            )

        pipeline += ImgaugAugmenter.from_config(
            self.optimization_config.augmentation_config)
        pipeline += Normalizer.from_config(self.data_config.preprocessing)
        pipeline += Resizer.from_config(self.data_config.preprocessing)

        # Locate instance centroids, then render them as confidence maps
        # (`centroids=True`) at the centroid head's sigma / output stride.
        pipeline += InstanceCentroidFinder.from_config(
            self.data_config.instance_cropping,
            skeletons=self.data_config.labels.skeletons,
        )
        pipeline += MultiConfidenceMapGenerator(
            sigma=self.centroid_confmap_head.sigma,
            output_stride=self.centroid_confmap_head.output_stride,
            centroids=True,
        )

        if len(data_provider) >= self.optimization_config.batch_size:
            # Batching before repeating is preferred since it preserves epoch boundaries
            # such that no sample is repeated within the epoch. But this breaks if there
            # are fewer samples than the batch size.
            pipeline += Batcher(batch_size=self.optimization_config.batch_size,
                                drop_remainder=True)
            pipeline += Repeater()

        else:
            pipeline += Repeater()
            pipeline += Batcher(batch_size=self.optimization_config.batch_size,
                                drop_remainder=True)

        if self.optimization_config.prefetch:
            pipeline += Prefetcher()

        return pipeline
Ejemplo n.º 2
0
    def make_base_pipeline(self, data_provider: Provider) -> Pipeline:
        """Build a pipeline that yields preprocessed input examples only.

        The provider's examples are normalized and then resized according to
        `self.data_config.preprocessing`; no targets are generated.

        Args:
            data_provider: A `Provider` that generates data examples, typically a
                `LabelsReader` instance.

        Returns:
            A `Pipeline` instance configured to produce input examples.
        """
        preprocessing_config = self.data_config.preprocessing
        base = Pipeline(providers=data_provider)
        # Order matters: normalize first, then resize.
        for transformer_cls in (Normalizer, Resizer):
            base += transformer_cls.from_config(preprocessing_config)
        return base
Ejemplo n.º 3
0
    def make_base_pipeline(self, data_provider: Provider) -> Pipeline:
        """Build a pipeline that yields preprocessed input examples only.

        Stages, in order:
          1. `SizeMatcher`, only when `resize_and_pad_to_target` is enabled.
          2. `Normalizer` then `Resizer`, both configured from
             `self.data_config.preprocessing`.
          3. `RandomCropper`, only when the augmentation config enables
             `random_crop`.

        Args:
            data_provider: A `Provider` that generates data examples, typically a
                `LabelsReader` instance.

        Returns:
            A `Pipeline` instance configured to produce input examples.
        """
        preprocessing_config = self.data_config.preprocessing
        augmentation_config = self.optimization_config.augmentation_config

        base = Pipeline(providers=data_provider)

        if preprocessing_config.resize_and_pad_to_target:
            base += SizeMatcher.from_config(
                config=preprocessing_config,
                provider=data_provider,
            )

        for transformer_cls in (Normalizer, Resizer):
            base += transformer_cls.from_config(preprocessing_config)

        if augmentation_config.random_crop:
            base += RandomCropper(
                crop_height=augmentation_config.random_crop_height,
                crop_width=augmentation_config.random_crop_width,
            )

        return base
Ejemplo n.º 4
0
    def make_base_pipeline(self, data_provider: Provider) -> Pipeline:
        """Build a pipeline that yields preprocessed, instance-cropped inputs.

        Stages, in order: optional `SizeMatcher` (when
        `resize_and_pad_to_target` is enabled), `Normalizer`, `Resizer`,
        `InstanceCentroidFinder` and `InstanceCropper`. The last two are
        configured from `self.data_config.instance_cropping`.

        Args:
            data_provider: A `Provider` that generates data examples, typically a
                `LabelsReader` instance.

        Returns:
            A `Pipeline` instance configured to produce input examples.
        """
        preprocessing_config = self.data_config.preprocessing
        cropping_config = self.data_config.instance_cropping

        base = Pipeline(providers=data_provider)

        if preprocessing_config.resize_and_pad_to_target:
            base += SizeMatcher.from_config(
                config=preprocessing_config,
                provider=data_provider,
            )

        for transformer_cls in (Normalizer, Resizer):
            base += transformer_cls.from_config(preprocessing_config)

        # Centroids must be located before crops can be taken around them.
        base += InstanceCentroidFinder.from_config(
            cropping_config,
            skeletons=self.data_config.labels.skeletons,
        )
        base += InstanceCropper.from_config(cropping_config)

        return base
Ejemplo n.º 5
0
    def make_training_pipeline(self, data_provider: Provider) -> Pipeline:
        """Assemble the complete training pipeline for confidence maps + PAFs.

        Stages, in order:
          * optional preloading and shuffling,
          * augmentation (optional flip, imgaug, optional random crop),
          * preprocessing (normalize, optional size matching, resize),
          * targets (multi-instance confidence maps, optionally with offsets,
            and flattened part affinity fields),
          * batching/repeating, and optional prefetching.

        Args:
            data_provider: A `Provider` that generates data examples, typically a
                `LabelsReader` instance.

        Returns:
            A `Pipeline` instance configured to produce all data keys required for
            training.

        Notes:
            This does not remap keys to model outputs. Use `KeyMapper` to pull out keys
            with the appropriate format for the instantiated `tf.keras.Model`.
        """
        opt_config = self.optimization_config
        preprocessing_config = self.data_config.preprocessing

        pipeline = Pipeline(providers=data_provider)

        if opt_config.preload_data:
            pipeline += Preloader()
        if opt_config.online_shuffling:
            pipeline += Shuffler(
                shuffle=True,
                buffer_size=opt_config.shuffle_buffer_size,
            )

        # --- Augmentation ---
        augmentation = opt_config.augmentation_config
        if augmentation.random_flip:
            # Flip pairing is derived from the first skeleton only.
            pipeline += RandomFlipper.from_skeleton(
                self.data_config.labels.skeletons[0],
                horizontal=augmentation.flip_horizontal,
            )
        pipeline += ImgaugAugmenter.from_config(augmentation)
        if augmentation.random_crop:
            pipeline += RandomCropper(
                crop_height=augmentation.random_crop_height,
                crop_width=augmentation.random_crop_width,
            )

        # --- Preprocessing ---
        pipeline += Normalizer.from_config(preprocessing_config)
        if preprocessing_config.resize_and_pad_to_target:
            pipeline += SizeMatcher.from_config(
                config=preprocessing_config,
                provider=data_provider,
            )
        pipeline += Resizer.from_config(preprocessing_config)

        # --- Targets ---
        # Offsets are generated only when an offsets head is configured; the
        # threshold then comes from that head, otherwise it defaults to 1.0.
        offsets_head = self.offsets_head
        if offsets_head is None:
            with_offsets = False
            offsets_threshold = 1.0
        else:
            with_offsets = True
            offsets_threshold = offsets_head.sigma_threshold
        pipeline += MultiConfidenceMapGenerator(
            sigma=self.confmaps_head.sigma,
            output_stride=self.confmaps_head.output_stride,
            centroids=False,
            with_offsets=with_offsets,
            offsets_threshold=offsets_threshold,
        )
        pipeline += PartAffinityFieldsGenerator(
            sigma=self.pafs_head.sigma,
            output_stride=self.pafs_head.output_stride,
            skeletons=self.data_config.labels.skeletons,
            flatten_channels=True,
        )

        # --- Batching ---
        batch_size = opt_config.batch_size

        def _new_batcher():
            return Batcher(
                batch_size=batch_size,
                drop_remainder=True,
                unrag=True,
            )

        if len(data_provider) < batch_size:
            # With fewer samples than one batch, batching first would yield no
            # full batches, so repeat the stream before batching it.
            pipeline += Repeater()
            pipeline += _new_batcher()
        else:
            # Batch-then-repeat preserves epoch boundaries, so no sample is
            # repeated within an epoch.
            pipeline += _new_batcher()
            pipeline += Repeater()

        if opt_config.prefetch:
            pipeline += Prefetcher()

        return pipeline
Ejemplo n.º 6
0
def make_datagen_results(reader: LabelsReader, cfg: TrainingJobConfig):
    """Generate a preview of the training data produced for a model config.

    Builds a pipeline that mirrors the target generation for the head type
    selected in `cfg.model.heads`. It then runs the pipeline over `reader` and
    collects up to `MAX_FRAMES_TO_PREVIEW` examples.

    Args:
        reader: Source of labeled examples.
        cfg: Training job configuration. It is deep-copied first, so filling in
            missing values (`pad_to_stride`, `crop_size`) here never mutates
            the caller's config.

    Returns:
        A dict mapping "image", "confmap" and, for multi-instance heads only,
        "paf" to arrays stacked along a new leading axis (one entry per
        previewed example). The dict is empty when the head type is none of
        the three handled below.
    """
    cfg = copy.deepcopy(cfg)
    # Maps preview output name -> key of the example produced by the pipeline.
    output_keys = dict()

    if cfg.data.preprocessing.pad_to_stride is None:
        cfg.data.preprocessing.pad_to_stride = (
            cfg.model.backbone.which_oneof().max_stride)

    pipeline = pipelines.Pipeline(reader)
    pipeline += Resizer.from_config(cfg.data.preprocessing)

    head_config = cfg.model.heads.which_oneof()
    if isinstance(head_config, CentroidsHeadConfig):
        pipeline += pipelines.InstanceCentroidFinder.from_config(
            cfg.data.instance_cropping, skeletons=reader.labels.skeletons)
        pipeline += pipelines.MultiConfidenceMapGenerator(
            sigma=cfg.model.heads.centroid.sigma,
            output_stride=cfg.model.heads.centroid.output_stride,
            centroids=True,
        )

        output_keys["image"] = "image"
        output_keys["confmap"] = "centroid_confidence_maps"

    elif isinstance(head_config, CenteredInstanceConfmapsHeadConfig):
        if cfg.data.instance_cropping.crop_size is None:
            cfg.data.instance_cropping.crop_size = find_instance_crop_size(
                labels=reader.labels,
                padding=cfg.data.instance_cropping.crop_size_detection_padding,
                maximum_stride=cfg.model.backbone.which_oneof().max_stride,
            )

        pipeline += pipelines.InstanceCentroidFinder.from_config(
            cfg.data.instance_cropping, skeletons=reader.labels.skeletons)
        pipeline += pipelines.InstanceCropper.from_config(
            cfg.data.instance_cropping)
        pipeline += pipelines.InstanceConfidenceMapGenerator(
            sigma=cfg.model.heads.centered_instance.sigma,
            output_stride=cfg.model.heads.centered_instance.output_stride,
        )

        output_keys["image"] = "instance_image"
        output_keys["confmap"] = "instance_confidence_maps"

    elif isinstance(head_config, MultiInstanceConfig):
        output_keys["image"] = "image"
        output_keys["confmap"] = "confidence_maps"
        # The PAF generator emits its output under "part_affinity_fields";
        # mapping "paf" to "confidence_maps" would silently preview the
        # confidence maps twice instead of the PAFs.
        output_keys["paf"] = "part_affinity_fields"

        pipeline += pipelines.MultiConfidenceMapGenerator(
            sigma=cfg.model.heads.multi_instance.confmaps.sigma,
            output_stride=cfg.model.heads.multi_instance.confmaps.
            output_stride,
        )
        pipeline += pipelines.PartAffinityFieldsGenerator(
            sigma=cfg.model.heads.multi_instance.pafs.sigma,
            output_stride=cfg.model.heads.multi_instance.pafs.output_stride,
            skeletons=reader.labels.skeletons,
            # Match the channel layout used by the training pipeline.
            flatten_channels=True,
        )

    ds = pipeline.make_dataset()

    output_lists = defaultdict(list)
    for frame_count, example in enumerate(ds, start=1):
        for key, from_key in output_keys.items():
            output_lists[key].append(example[from_key])
        if frame_count == MAX_FRAMES_TO_PREVIEW:
            break

    return {key: np.stack(values) for key, values in output_lists.items()}
Ejemplo n.º 7
0
def make_datagen_results(reader: LabelsReader,
                         cfg: TrainingJobConfig) -> dict:
    """Get a (subset of) the images and targets generated for training.

    Builds a pipeline that mirrors the target generation for the head type
    selected in `cfg.model.heads`. It then runs the pipeline over `reader` and
    collects up to `MAX_FRAMES_TO_PREVIEW` examples.

    Args:
        reader: Source of labeled examples.
        cfg: Training job configuration. It is deep-copied first, so filling in
            missing values (`pad_to_stride`, `crop_size`) here never mutates
            the caller's config.

    Returns:
        A dict mapping "image", "confmap" and, for multi-instance heads only,
        "paf" to `np.ndarray`s stacked along a new leading axis (one entry per
        previewed example). The dict is empty when the head type is none of
        the three handled below.

    TODO: Refactor so we can get this data without digging into details of the
      specific pipelines (e.g., key for confmaps depends on head type).
    """
    cfg = copy.deepcopy(cfg)
    # Maps preview output name -> key of the example produced by the pipeline.
    output_keys = dict()

    if cfg.data.preprocessing.pad_to_stride is None:
        cfg.data.preprocessing.pad_to_stride = (
            cfg.model.backbone.which_oneof().max_stride)

    pipeline = pipelines.Pipeline(reader)
    pipeline += Resizer.from_config(cfg.data.preprocessing)

    head_config = cfg.model.heads.which_oneof()
    if isinstance(head_config, CentroidsHeadConfig):
        pipeline += pipelines.InstanceCentroidFinder.from_config(
            cfg.data.instance_cropping, skeletons=reader.labels.skeletons)
        pipeline += pipelines.MultiConfidenceMapGenerator(
            sigma=cfg.model.heads.centroid.sigma,
            output_stride=cfg.model.heads.centroid.output_stride,
            centroids=True,
        )

        output_keys["image"] = "image"
        output_keys["confmap"] = "centroid_confidence_maps"

    elif isinstance(head_config, CenteredInstanceConfmapsHeadConfig):
        if cfg.data.instance_cropping.crop_size is None:
            cfg.data.instance_cropping.crop_size = find_instance_crop_size(
                labels=reader.labels,
                padding=cfg.data.instance_cropping.crop_size_detection_padding,
                maximum_stride=cfg.model.backbone.which_oneof().max_stride,
            )

        pipeline += pipelines.InstanceCentroidFinder.from_config(
            cfg.data.instance_cropping, skeletons=reader.labels.skeletons)
        pipeline += pipelines.InstanceCropper.from_config(
            cfg.data.instance_cropping)
        pipeline += pipelines.InstanceConfidenceMapGenerator(
            sigma=cfg.model.heads.centered_instance.sigma,
            output_stride=cfg.model.heads.centered_instance.output_stride,
        )

        output_keys["image"] = "instance_image"
        output_keys["confmap"] = "instance_confidence_maps"

    elif isinstance(head_config, MultiInstanceConfig):
        output_keys["image"] = "image"
        output_keys["confmap"] = "confidence_maps"
        output_keys["paf"] = "part_affinity_fields"

        pipeline += pipelines.MultiConfidenceMapGenerator(
            sigma=cfg.model.heads.multi_instance.confmaps.sigma,
            output_stride=cfg.model.heads.multi_instance.confmaps.
            output_stride,
        )
        pipeline += pipelines.PartAffinityFieldsGenerator(
            sigma=cfg.model.heads.multi_instance.pafs.sigma,
            output_stride=cfg.model.heads.multi_instance.pafs.output_stride,
            skeletons=reader.labels.skeletons,
            flatten_channels=True,
        )

    ds = pipeline.make_dataset()

    output_lists = defaultdict(list)
    for frame_count, example in enumerate(ds, start=1):
        for key, from_key in output_keys.items():
            output_lists[key].append(example[from_key])
        if frame_count == MAX_FRAMES_TO_PREVIEW:
            break

    return {key: np.stack(values) for key, values in output_lists.items()}