Example 1
        # Helpers such as assemble_rcnn_outputs_by_name, sem_seg_postprocess,
        # detector_postprocess, combine_semantic_and_instance_outputs and the
        # combine_* settings come from the enclosing scope, not from arguments.
        def f(batched_inputs, c2_inputs, c2_results):
            # The first two entries of each "im_info" row are (height, width).
            image_sizes = [[int(im[0]), int(im[1])] for im in c2_inputs["im_info"]]
            detector_results = assemble_rcnn_outputs_by_name(
                image_sizes, c2_results, force_mask_on=True
            )
            sem_seg_results = c2_results["sem_seg"]

            # copied from meta_arch/panoptic_fpn.py ...
            processed_results = []
            for sem_seg_result, detector_result, input_per_image, image_size in zip(
                sem_seg_results, detector_results, batched_inputs, image_sizes
            ):
                height = input_per_image.get("height", image_size[0])
                width = input_per_image.get("width", image_size[1])
                sem_seg_r = sem_seg_postprocess(sem_seg_result, image_size, height, width)
                detector_r = detector_postprocess(detector_result, height, width)

                processed_results.append({"sem_seg": sem_seg_r, "instances": detector_r})

                if combine_on:
                    panoptic_r = combine_semantic_and_instance_outputs(
                        detector_r,
                        sem_seg_r.argmax(dim=0),
                        combine_overlap_threshold,
                        combine_stuff_area_limit,
                        combine_instances_confidence_threshold,
                    )
                    processed_results[-1]["panoptic_seg"] = panoptic_r
            return processed_results
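
Both snippets rely on detectron2's sem_seg_postprocess / detector_postprocess helpers to map network outputs back to the original image resolution. As a rough sketch of what the semantic-branch resize does (a simplified version written here for illustration, not the library code itself), assuming a single-image logits tensor of shape (C, H, W):

import torch
import torch.nn.functional as F

def resize_sem_seg_logits(result, img_size, output_height, output_width):
    # Drop the padding that was added when images were batched together ...
    result = result[:, : img_size[0], : img_size[1]]
    # ... then bilinearly resize the logits to the requested output resolution.
    result = F.interpolate(
        result.unsqueeze(0).float(),
        size=(output_height, output_width),
        mode="bilinear",
        align_corners=False,
    )[0]
    return result
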
Example 2
    def forward(self, batched_inputs):
        """
        Args:
            batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
                Each item in the list contains the inputs for one image.

        For now, each item in the list is a dict that contains:
            image: Tensor, image in (C, H, W) format.
            instances: Instances, per-image ground-truth annotations (used in training).
            sem_seg: semantic segmentation ground truth.
            Other information that's included in the original dicts, such as:
                "height", "width" (int): the output resolution of the model, used in inference.
                    See :meth:`postprocess` for details.

        Returns:
            list[dict]: each dict is the results for one image. The dict
                contains the following keys:
                "instances": see :meth:`GeneralizedRCNN.forward` for its format.
                "sem_seg": see :meth:`SemanticSegmentor.forward` for its format.
                "panoptic_seg": available when `PANOPTIC_FPN.COMBINE.ENABLED`.
                    See the return value of
                    :func:`combine_semantic_and_instance_outputs` for its format.
        """
        # Preprocessing: move images to the model device, normalize them, and
        # pad them into a single batched tensor.
        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [self.normalizer(x) for x in images]
        images = ImageList.from_tensors(images,
                                        self.backbone.size_divisibility)
        features = self.backbone(images.tensor)

        # Optional semantic-segmentation head, only run when panoptic
        # combination (self.combine_on) is enabled.
        if self.combine_on:
            if "sem_seg" in batched_inputs[0]:
                gt_sem = [x["sem_seg"].to(self.device) for x in batched_inputs]
                gt_sem = ImageList.from_tensors(
                    gt_sem, self.backbone.size_divisibility,
                    self.panoptic_module.ignore_value).tensor
            else:
                gt_sem = None
            sem_seg_results, sem_seg_losses = self.panoptic_module(
                features, gt_sem)

        if "basis_sem" in batched_inputs[0]:
            basis_sem = [
                x["basis_sem"].to(self.device) for x in batched_inputs
            ]
            basis_sem = ImageList.from_tensors(basis_sem,
                                               self.backbone.size_divisibility,
                                               0).tensor
        else:
            basis_sem = None
        # Predict the bases (and any auxiliary outputs/losses) from the features.
        basis_out, basis_losses = self.basis_module(features, basis_sem)

        if "instances" in batched_inputs[0]:
            gt_instances = [
                x["instances"].to(self.device) for x in batched_inputs
            ]
        else:
            gt_instances = None
        # Generate proposals, then blend the predicted bases into per-instance
        # mask results.
        proposals, proposal_losses = self.proposal_generator(
            images, features, gt_instances, self.top_layer)
        detector_results, detector_losses = self.blender(
            basis_out["bases"], proposals, gt_instances)

        # In training mode, return all losses; detector losses are scaled by
        # self.instance_loss_weight.
        if self.training:
            losses = {}
            losses.update(basis_losses)
            losses.update({
                k: v * self.instance_loss_weight
                for k, v in detector_losses.items()
            })
            losses.update(proposal_losses)
            if self.combine_on:
                losses.update(sem_seg_losses)
            return losses

        # Inference: map per-image outputs back to the requested output resolution.
        processed_results = []
        for i, (detector_result, input_per_image, image_size) in enumerate(
                zip(detector_results, batched_inputs, images.image_sizes)):
            height = input_per_image.get("height", image_size[0])
            width = input_per_image.get("width", image_size[1])
            detector_r = detector_postprocess(detector_result, height, width)
            processed_result = {"instances": detector_r}
            if self.combine_on:
                sem_seg_r = sem_seg_postprocess(sem_seg_results[i], image_size,
                                                height, width)
                processed_result["sem_seg"] = sem_seg_r
            if "seg_thing_out" in basis_out:
                seg_thing_r = sem_seg_postprocess(basis_out["seg_thing_out"],
                                                  image_size, height, width)
                processed_result["sem_thing_seg"] = seg_thing_r
            if self.basis_module.visualize:
                processed_result["bases"] = basis_out["bases"]
            processed_results.append(processed_result)

            if self.combine_on:
                panoptic_r = combine_semantic_and_instance_outputs(
                    detector_r, sem_seg_r.argmax(dim=0),
                    self.combine_overlap_threshold,
                    self.combine_stuff_area_limit,
                    self.combine_instances_confidence_threshold)
                processed_results[-1]["panoptic_seg"] = panoptic_r
        return processed_results
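
A minimal sketch of how this forward() is typically driven at inference time, following the docstring above. The run_inference helper is an assumption made here for illustration; `model` is assumed to be an already-built instance of the class above with weights loaded:

import numpy as np
import torch

def run_inference(model, image_bgr: np.ndarray):
    # model: an already-built, weight-loaded instance of the model above
    # (assumed here), living on the correct device.
    height, width = image_bgr.shape[:2]
    inputs = {
        "image": torch.as_tensor(image_bgr.transpose(2, 0, 1).copy()),  # (C, H, W)
        "height": height,  # output resolution used by detector_postprocess
        "width": width,
    }
    model.eval()
    with torch.no_grad():
        outputs = model([inputs])
    # One result dict per input image; "sem_seg" / "panoptic_seg" only appear
    # when the combine_on (PANOPTIC_FPN.COMBINE.ENABLED) path is active.
    return outputs[0]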