Exemplo n.º 1
0
    def make_pipeline(self, data_provider: Optional[Provider] = None) -> Pipeline:
        """Assemble the single-instance inference pipeline.

        Frames are normalized and resized per the confidence map preprocessing
        config, prefetched, and run through the confidence map model. Peaks are
        then located with `GlobalPeakFinder`, and only the output keys are kept.
        Finally, the predicted points are rescaled with
        `PointsRescaler(invert=True)` using the `"scale"` key.

        Args:
            data_provider: Optional provider installed as the pipeline's only
                data source. When `None`, the pipeline's providers are left
                untouched.

        Returns:
            The assembled `Pipeline`, which is also saved to `self.pipeline`.

        Note:
            When neither `ensure_rgb` nor `ensure_grayscale` is set in the
            confidence map preprocessing config, one of them is set here (in
            place) based on the model's input channel count.
        """
        pipeline = Pipeline()
        if data_provider is not None:
            pipeline.providers = [data_provider]

        preprocessing = self.confmap_config.data.preprocessing
        confmap_model = self.confmap_model

        # Choose a colorspace from the model input when none is configured:
        # a single input channel means grayscale, anything else means RGB.
        colorspace_configured = (
            preprocessing.ensure_rgb or preprocessing.ensure_grayscale
        )
        if not colorspace_configured:
            single_channel = confmap_model.keras_model.inputs[0].shape[-1] == 1
            if single_channel:
                preprocessing.ensure_grayscale = True
            else:
                preprocessing.ensure_rgb = True

        pipeline += Normalizer.from_config(preprocessing)
        pipeline += Resizer.from_config(preprocessing, points_key=None)
        pipeline += Prefetcher()

        points_key = "predicted_instance"
        confidences_key = "predicted_instance_confidences"
        confmaps_key = "predicted_instance_confidence_maps"

        pipeline += KerasModelPredictor(
            keras_model=confmap_model.keras_model,
            model_input_keys="image",
            model_output_keys=confmaps_key,
        )
        pipeline += GlobalPeakFinder(
            confmaps_key=confmaps_key,
            confmaps_stride=confmap_model.heads[0].output_stride,
            peaks_key=points_key,
            peak_vals_key=confidences_key,
            peak_threshold=self.peak_threshold,
            integral=self.integral_refinement,
            integral_patch_size=self.integral_patch_size,
        )

        output_keys = ["scale", "video_ind", "frame_ind", points_key, confidences_key]
        pipeline += KeyFilter(keep_keys=output_keys)
        pipeline += PointsRescaler(
            points_key=points_key, scale_key="scale", invert=True
        )

        self.pipeline = pipeline
        return pipeline
Exemplo n.º 2
0
    def make_pipeline(self):
        """Build the preprocessing and model inference pipeline.

        Images are normalized and resized according to
        `self.config.data.preprocessing` (with `keep_full_image=False` and no
        points rescaling). Then `self.model.keras_model` is run on the `"image"`
        key, and its outputs are stored under the keys returned by
        `self.head_specific_output_keys()`.

        Returns:
            The assembled `Pipeline`. It is also stored on `self.pipeline`.
        """
        pipeline = Pipeline()

        pipeline += Normalizer.from_config(self.config.data.preprocessing)
        pipeline += Resizer.from_config(
            self.config.data.preprocessing, keep_full_image=False, points_key=None
        )

        pipeline += KerasModelPredictor(
            keras_model=self.model.keras_model,
            model_input_keys="image",
            model_output_keys=self.head_specific_output_keys(),
        )

        self.pipeline = pipeline

        # Return the pipeline as well as storing it, so callers can use it
        # directly (previously this returned None).
        return pipeline
Exemplo n.º 3
0
    def make_pipeline(self, data_provider: Optional[Provider] = None) -> Pipeline:
        """Assemble the bottom-up inference pipeline.

        Frames are normalized, resized and prefetched. The bottom-up model then
        predicts confidence maps and part affinity fields, local peaks are
        extracted, and examples with no peaks are dropped. Peaks are grouped
        into instances with `PartAffinityFieldInstanceGrouper`, only the output
        keys (plus `"image"` when the tracker uses images) are kept, and the
        instance points are rescaled with `PointsRescaler(invert=True)` using
        the `"scale"` key.

        Args:
            data_provider: Optional provider installed as the pipeline's only
                data source. When `None`, the pipeline's providers are left
                untouched.

        Returns:
            The assembled `Pipeline`, which is also saved to `self.pipeline`.

        Note:
            When neither `ensure_rgb` nor `ensure_grayscale` is set in the
            bottom-up preprocessing config, one of them is set here (in place)
            based on the model's input channel count.
        """
        pipeline = Pipeline()
        if data_provider is not None:
            pipeline.providers = [data_provider]

        preprocessing = self.bottomup_config.data.preprocessing
        model = self.bottomup_model

        # Choose a colorspace from the model input when none is configured:
        # a single input channel means grayscale, anything else means RGB.
        colorspace_configured = (
            preprocessing.ensure_rgb or preprocessing.ensure_grayscale
        )
        if not colorspace_configured:
            single_channel = model.keras_model.inputs[0].shape[-1] == 1
            if single_channel:
                preprocessing.ensure_grayscale = True
            else:
                preprocessing.ensure_rgb = True

        pipeline += Normalizer.from_config(preprocessing)
        pipeline += Resizer.from_config(
            preprocessing, keep_full_image=False, points_key=None
        )
        pipeline += Prefetcher()

        confmaps_key = "predicted_confidence_maps"
        pafs_key = "predicted_part_affinity_fields"
        peaks_key = "predicted_peaks"
        peak_scores_key = "predicted_peak_confidences"
        channel_inds_key = "predicted_peak_channel_inds"

        pipeline += KerasModelPredictor(
            keras_model=model.keras_model,
            model_input_keys="image",
            model_output_keys=[confmaps_key, pafs_key],
        )
        pipeline += LocalPeakFinder(
            confmaps_key=confmaps_key,
            confmaps_stride=model.heads[0].output_stride,
            peak_threshold=self.peak_threshold,
            peaks_key=peaks_key,
            peak_vals_key=peak_scores_key,
            peak_sample_inds_key="predicted_peak_sample_inds",
            peak_channel_inds_key=channel_inds_key,
            keep_confmaps=False,
        )

        # Skip examples in which no peaks were found.
        pipeline += LambdaFilter(filter_fn=lambda ex: len(ex["predicted_peaks"]) > 0)

        pipeline += PartAffinityFieldInstanceGrouper.from_config(
            self.bottomup_config.model.heads.multi_instance,
            max_edge_length=128,
            min_edge_score=0.05,
            n_points=10,
            min_instance_peaks=0,
            peaks_key=peaks_key,
            peak_scores_key=peak_scores_key,
            channel_inds_key=channel_inds_key,
            pafs_key=pafs_key,
            predicted_instances_key="predicted_instances",
            predicted_peak_scores_key="predicted_peak_scores",
            predicted_instance_scores_key="predicted_instance_scores",
            keep_pafs=False,
        )

        output_keys = [
            "scale",
            "video_ind",
            "frame_ind",
            "predicted_instances",
            "predicted_peak_scores",
            "predicted_instance_scores",
        ]
        # Trackers that consume images need the frame carried through.
        tracker_needs_image = self.tracker and self.tracker.uses_image
        if tracker_needs_image:
            output_keys.append("image")
        pipeline += KeyFilter(keep_keys=output_keys)

        pipeline += PointsRescaler(
            points_key="predicted_instances", scale_key="scale", invert=True
        )

        self.pipeline = pipeline
        return pipeline
Exemplo n.º 4
0
    def make_pipeline(self, data_provider: Optional[Provider] = None) -> Pipeline:
        """Build the top-down (centroid -> centered instance) inference pipeline.

        Each stage can be driven by a trained model or by ground truth from the
        data:

        * Centroids come from `self.centroid_model` when it is set, and
          instances are then cropped around the predicted centroids. Otherwise
          centroids are generated from the data with `InstanceCentroidFinder`
          and crops are made with `InstanceCropper`.
        * Instance points come from `self.confmap_model` when it is set.
          Otherwise they are mocked from the data's `"instances"` key with
          `MockGlobalPeakFinder`.

        Args:
            data_provider: Optional provider installed as the pipeline's only
                data source. It must be given (and expose `labels`) whenever
                labels are needed: when `self.centroid_model` is `None`
                (skeletons are read from the labels), or when
                `self.confmap_config` is `None` (the crop size is inferred from
                the labels).

        Returns:
            The assembled `Pipeline`, which is also saved to `self.pipeline`.

        Raises:
            ValueError: If `data_provider` is `None` but labels are required.

        Note:
            When a preprocessing config sets neither `ensure_rgb` nor
            `ensure_grayscale`, one of them is set here (in place) based on the
            corresponding model's input channel count.
        """
        # Fail fast with a clear message instead of hitting `None.labels`
        # midway through construction.
        if data_provider is None and (
            self.centroid_model is None or self.confmap_config is None
        ):
            raise ValueError(
                "A data_provider with labels is required when no centroid model "
                "is set or when no confidence map model config is available."
            )

        # `self.tracker` may be unset; raw frames are only kept when the
        # tracker consumes images.
        keep_original_image = bool(self.tracker and self.tracker.uses_image)

        def infer_colorspace(config, model):
            """Set RGB/grayscale preprocessing from the model's input channels.

            Leaves the config unchanged if it already requests a colorspace or
            if there is no model to infer one from.
            """
            preprocessing = config.data.preprocessing
            if preprocessing.ensure_rgb or preprocessing.ensure_grayscale:
                return
            if model is None:
                return
            if model.keras_model.inputs[0].shape[-1] == 1:
                preprocessing.ensure_grayscale = True
            else:
                preprocessing.ensure_rgb = True

        pipeline = Pipeline()
        if data_provider is not None:
            pipeline.providers = [data_provider]

        pipeline += Prefetcher()

        # Copy "image"/"scale" to "full_image"/"full_image_scale"
        # (drop_old=False keeps the originals). The centered-instance branch
        # preprocesses and crops this copy.
        pipeline += KeyRenamer(
            old_key_names=["image", "scale"],
            new_key_names=["full_image", "full_image_scale"],
            drop_old=False,
        )

        if keep_original_image:
            # Keep an unprocessed copy of each frame for the tracker.
            pipeline += KeyRenamer(
                old_key_names=["image", "scale"],
                new_key_names=["original_image", "original_image_scale"],
                drop_old=False,
            )
            # NOTE(review): presumably moves this tensor off the accelerator
            # device — confirm against KeyDeviceMover.
            pipeline += KeyDeviceMover(["original_image"])

        if self.confmap_config is not None:
            infer_colorspace(self.confmap_config, self.confmap_model)

            pipeline += Normalizer.from_config(
                self.confmap_config.data.preprocessing, image_key="full_image"
            )

            # Only rescale the "instances" points when no centroid model is set.
            points_key = "instances" if self.centroid_model is None else None
            pipeline += Resizer.from_config(
                self.confmap_config.data.preprocessing,
                points_key=points_key,
                image_key="full_image",
                scale_key="full_image_scale",
            )

        if self.centroid_model is not None:
            infer_colorspace(self.centroid_config, self.centroid_model)

            pipeline += Normalizer.from_config(
                self.centroid_config.data.preprocessing, image_key="image"
            )
            pipeline += Resizer.from_config(
                self.centroid_config.data.preprocessing, points_key=None
            )

            # Predict centroids using model.
            pipeline += KerasModelPredictor(
                keras_model=self.centroid_model.keras_model,
                model_input_keys="image",
                model_output_keys="predicted_centroid_confidence_maps",
            )

            pipeline += LocalPeakFinder(
                confmaps_stride=self.centroid_model.heads[0].output_stride,
                peak_threshold=self.peak_threshold,
                confmaps_key="predicted_centroid_confidence_maps",
                peaks_key="predicted_centroids",
                peak_vals_key="predicted_centroid_confidences",
                peak_sample_inds_key="predicted_centroid_sample_inds",
                peak_channel_inds_key="predicted_centroid_channel_inds",
                keep_confmaps=False,
            )

            # Skip frames in which no centroids were found.
            pipeline += LambdaFilter(
                filter_fn=lambda ex: len(ex["predicted_centroids"]) > 0
            )

            if self.confmap_config is not None:
                crop_size = self.confmap_config.data.instance_cropping.crop_size
            else:
                # data_provider is non-None here (validated at the top).
                crop_size = sleap.nn.data.instance_cropping.find_instance_crop_size(
                    data_provider.labels
                )

            pipeline += PredictedInstanceCropper(
                crop_width=crop_size,
                crop_height=crop_size,
                centroids_key="predicted_centroids",
                centroid_confidences_key="predicted_centroid_confidences",
                full_image_key="full_image",
                full_image_scale_key="full_image_scale",
                # Ground-truth instances are needed later by
                # MockGlobalPeakFinder when there is no confmap model.
                keep_instances_gt=self.confmap_model is None,
                other_keys_to_keep=["original_image"] if keep_original_image else None,
            )
            if keep_original_image:
                pipeline += KeyDeviceMover(["original_image"])

        else:
            # Generate ground truth centroids and crops.
            # data_provider is non-None here (validated at the top).
            anchor_part = self.confmap_config.data.instance_cropping.center_on_part
            pipeline += InstanceCentroidFinder(
                center_on_anchor_part=anchor_part is not None,
                anchor_part_names=anchor_part,
                skeletons=data_provider.labels.skeletons,
            )
            # Crop from the preprocessed full image, replacing the original
            # "image"/"scale" keys.
            pipeline += KeyRenamer(
                old_key_names=["full_image", "full_image_scale"],
                new_key_names=["image", "scale"],
                drop_old=True,
            )
            pipeline += InstanceCropper(
                crop_width=self.confmap_config.data.instance_cropping.crop_size,
                crop_height=self.confmap_config.data.instance_cropping.crop_size,
                mock_centroid_confidence=True,
            )

        if self.confmap_model is not None:
            # Predict confidence maps using model, batching crops if requested.
            if self.batch_size > 1:
                pipeline += sleap.nn.data.pipelines.Batcher(
                    batch_size=self.batch_size, drop_remainder=False
                )
            pipeline += KerasModelPredictor(
                keras_model=self.confmap_model.keras_model,
                model_input_keys="instance_image",
                model_output_keys="predicted_instance_confidence_maps",
            )
            if self.batch_size > 1:
                pipeline += sleap.nn.data.pipelines.Unbatcher()
            pipeline += GlobalPeakFinder(
                confmaps_key="predicted_instance_confidence_maps",
                peaks_key="predicted_center_instance_points",
                confmaps_stride=self.confmap_model.heads[0].output_stride,
                peak_threshold=self.peak_threshold,
                integral=self.integral_refinement,
                integral_patch_size=self.integral_patch_size,
                keep_confmaps=False,
            )

        else:
            # Generate ground truth instance points.
            pipeline += MockGlobalPeakFinder(
                all_peaks_in_key="instances",
                peaks_out_key="predicted_center_instance_points",
                peak_vals_key="predicted_center_instance_confidences",
                keep_confmaps=False,
            )

        keep_keys = [
            "bbox",
            "center_instance_ind",
            "centroid",
            "centroid_confidence",
            "scale",
            "video_ind",
            "frame_ind",
            "predicted_center_instance_points",
            "predicted_center_instance_confidences",
        ]

        if keep_original_image:
            keep_keys.append("original_image")

        pipeline += KeyFilter(keep_keys=keep_keys)

        pipeline += PredictedCenterInstanceNormalizer(
            centroid_key="centroid",
            centroid_confidence_key="centroid_confidence",
            peaks_key="predicted_center_instance_points",
            peak_confidences_key="predicted_center_instance_confidences",
            new_centroid_key="predicted_centroid",
            new_centroid_confidence_key="predicted_centroid_confidence",
            new_peaks_key="predicted_instance",
            new_peak_confidences_key="predicted_instance_confidences",
        )

        self.pipeline = pipeline

        return pipeline