Example No. 1
    def inference(self, batched_inputs):

        images = self.preprocess_image(batched_inputs)
        features = self.backbone(images.tensor)

        if self.proposal_generator:
            proposals, _ = self.proposal_generator(images, features)
        else:
            raise NotImplementedError

        detector_results, pan_detector_results = self.roi_heads(
            images, features, proposals)
        sem_seg_results, _ = self.sem_seg_head(features)
        pan_seg_results, _ = self.panoptic_head(None, sem_seg_results,
                                                pan_detector_results)

        processed_results = []
        for sem_seg_result, detector_result, pan_seg_result, input_per_image, image_size in zip(
                sem_seg_results, detector_results, pan_seg_results,
                batched_inputs, images.image_sizes):
            processed_result = {}
            height = input_per_image.get("height")
            width = input_per_image.get("width")
            sem_seg_r = sem_seg_postprocess(sem_seg_result, image_size, height,
                                            width)
            detector_r = detector_postprocess(detector_result, height, width)
            processed_result.update({
                "sem_seg": sem_seg_r,
                "instances": detector_r
            })

            if self.combine_on:
                panoptic_r = combine_semantic_and_instance_outputs(
                    detector_r,
                    sem_seg_r.argmax(dim=0),
                    self.combine_overlap_threshold,
                    self.combine_stuff_area_limit,
                    self.combine_instances_confidence_threshold,
                )
            else:
                pan_pred = sem_seg_postprocess(pan_seg_result["pan_logit"],
                                               image_size, height, width)
                del pan_seg_result["pan_logit"]
                pan_seg_result["pan_pred"] = pan_pred.argmax(dim=0)
                panoptic_r = pan_seg_postprocess(pan_seg_result,
                                                 sem_seg_r.argmax(dim=0),
                                                 self.stuff_num_classes,
                                                 self.stuff_area_limit)
            processed_result.update({"panoptic_seg": panoptic_r})

            processed_results.append(processed_result)
        return processed_results
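
Every example in this listing funnels its raw predictions through `sem_seg_postprocess`. For orientation, here is a minimal standalone sketch of the crop-and-resize behaviour that such a helper provides, assuming the detectron2 convention of stripping the batch padding first and then bilinearly resizing the logits to the requested output resolution (the function name and shapes below are illustrative, not the library source):

import torch
import torch.nn.functional as F


def sem_seg_postprocess_sketch(result, img_size, output_height, output_width):
    # result: (C, H_pad, W_pad) logits for one image; img_size: (H, W) before padding.
    result = result[:, : img_size[0], : img_size[1]]  # drop the padding added for batching
    result = F.interpolate(
        result.unsqueeze(0).float(),
        size=(output_height, output_width),
        mode="bilinear",
        align_corners=False,
    )[0]
    return result  # (C, output_height, output_width)


logits = torch.randn(19, 512, 672)                    # padded per-image prediction
out = sem_seg_postprocess_sketch(logits, (480, 640), 720, 960)
assert out.shape == (19, 720, 960)
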
Example No. 2
        def f(batched_inputs, c2_inputs, c2_results):
            image_sizes = [[int(im[0]), int(im[1])] for im in c2_inputs["im_info"]]
            detector_results = assemble_rcnn_outputs_by_name(
                image_sizes, c2_results, force_mask_on=True
            )
            sem_seg_results = c2_results["sem_seg"]
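            # `combine_on` and the `combine_*` thresholds used below are free variables
            # that come from the enclosing scope (or module) where this helper is defined.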

            # copied from meta_arch/panoptic_fpn.py ...
            processed_results = []
            for sem_seg_result, detector_result, input_per_image, image_size in zip(
                sem_seg_results, detector_results, batched_inputs, image_sizes
            ):
                height = input_per_image.get("height", image_size[0])
                width = input_per_image.get("width", image_size[1])
                sem_seg_r = sem_seg_postprocess(sem_seg_result, image_size, height, width)
                detector_r = detector_postprocess(detector_result, height, width)

                processed_results.append({"sem_seg": sem_seg_r, "instances": detector_r})

                if combine_on:
                    panoptic_r = combine_semantic_and_instance_outputs(
                        detector_r,
                        sem_seg_r.argmax(dim=0),
                        combine_overlap_threshold,
                        combine_stuff_area_limit,
                        combine_instances_confidence_threshold,
                    )
                    processed_results[-1]["panoptic_seg"] = panoptic_r
            return processed_results
Example No. 3
    def __call__(
        self,
        batched_inputs: List[Dict[str, Any]],
        tensor_inputs: torch.Tensor,
        tensor_outputs: torch.Tensor,
    ) -> List[Dict[str, Any]]:
        """
        Rescales sem_seg logits to original image input resolution,
        and packages the logits into D2Go's expected output format.

        Args:
            inputs (List[Dict[str, Tensor]]): batched inputs from the dataloader.
            tensor_inputs (Tensor): tensorized inputs, e.g. from `PreprocessFunc`.
            tensor_outputs (Tensor): sem seg logits tensor from the model to process.

        Returns:
            processed_results (List[Dict]): List of D2Go output dicts ready to be used
                downstream in an Evaluator, for export, etc.
        """
        results = tensor_outputs  # nchw

        processed_results = []
        for result, input_per_image in zip(results, batched_inputs):
            height = input_per_image.get("height")
            width = input_per_image.get("width")
            image_tensor_shape = input_per_image["image"].shape
            image_size = (image_tensor_shape[1], image_tensor_shape[2])

            # D2's sem_seg_postprocess rescales sem seg masks to the
            # provided original input resolution.
            r = sem_seg_postprocess(result, image_size, height, width)
            processed_results.append({"sem_seg": r})
        return processed_results
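
As a rough, self-contained illustration of driving the same per-image loop outside the class (the shapes, the 19-class count, and the single-item batch are invented for this sketch; it only assumes detectron2 is installed):

import torch
from detectron2.modeling.postprocessing import sem_seg_postprocess

batched_inputs = [{"image": torch.zeros(3, 480, 640), "height": 720, "width": 960}]
tensor_outputs = torch.randn(1, 19, 512, 640)     # padded NCHW logits from the model

processed_results = []
for result, input_per_image in zip(tensor_outputs, batched_inputs):
    height = input_per_image["height"]
    width = input_per_image["width"]
    image_size = input_per_image["image"].shape[1:]   # (H, W) before padding
    processed_results.append(
        {"sem_seg": sem_seg_postprocess(result, image_size, height, width)}
    )

print(processed_results[0]["sem_seg"].shape)          # torch.Size([19, 720, 960])
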
Example No. 4
    def __call__(self, inputs, tensor_inputs, tensor_outputs):
        results = tensor_outputs  # nchw

        processed_results = []
        for result, input_per_image in zip(results, inputs):
            height = input_per_image.get("height")
            width = input_per_image.get("width")
            image_tensor_shape = input_per_image["image"].shape
            image_size = (image_tensor_shape[1], image_tensor_shape[2])

            r = sem_seg_postprocess(result, image_size, height, width)
            processed_results.append({"sem_seg": r})
        return processed_results
Example No. 5
    def forward(self, batched_inputs):
        """
        Args:
            batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
                Each item in the list contains the inputs for one image.
                For now, each item in the list is a dict that contains:

                * "image": Tensor, image in (C, H, W) format.
                * "sem_seg": semantic segmentation ground truth.
                * Other information that's included in the original dicts, such as:
                  "height", "width" (int): the output resolution of the model, used
                  in inference. See :meth:`postprocess` for details.
        Returns:
            list[dict]: Each dict is the output for one input image.
                The dict contains one key "sem_seg" whose value is a
                Tensor of the output resolution that represents the
                per-pixel segmentation prediction.
        """
        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [self.normalizer(x) for x in images]
        images = ImageList.from_tensors(images,
                                        self.backbone.size_divisibility)
        size = images.tensor.size()[-2:]
        features = self.backbone(images.tensor)

        if "sem_seg" in batched_inputs[0]:
            targets = [x["sem_seg"].to(self.device) for x in batched_inputs]
            targets = ImageList.from_tensors(
                targets, self.backbone.size_divisibility,
                self.sem_seg_head.ignore_value).tensor
        else:
            targets = None
        results, losses = self.sem_seg_head(features, size, targets)

        if self.training:
            return losses

        processed_results = []
        for result, input_per_image, image_size in zip(results, batched_inputs,
                                                       images.image_sizes):
            height = input_per_image.get("height")
            width = input_per_image.get("width")
            r = sem_seg_postprocess(result, image_size, height, width)
            processed_results.append({"sem_seg": r})
        return processed_results
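
Most of these `forward` methods lean on `ImageList.from_tensors` to batch variable-sized images. A small, hedged sketch of that padding behaviour (the sizes are invented, and the exact padded shape follows from the divisibility constraint):

import torch
from detectron2.structures import ImageList

imgs = [torch.zeros(3, 480, 640), torch.zeros(3, 500, 700)]
image_list = ImageList.from_tensors(imgs, 32)   # pad to a common, /32-divisible size

print(image_list.tensor.shape)    # torch.Size([2, 3, 512, 704]) after zero padding
print(image_list.image_sizes)     # [(480, 640), (500, 700)] -- the pre-padding sizes
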
Example No. 6
    def forward(self, batched_inputs):
        # complete image
        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [self.normalizer(x) for x in images]
        images = ImageList.from_tensors(images,
                                        self.inpaint_net.size_divisibility)
        # triplet input maps:
        # erased regions
        masks = [x["mask"].to(self.device) for x in batched_inputs]
        masks = ImageList.from_tensors(masks,
                                       self.inpaint_net.size_divisibility)
        # mask the input image with masks
        erased_ims = images.tensor * (1. - masks.tensor)
        # ones map
        ones_ims = [
            torch.ones_like(x["mask"].to(self.device)) for x in batched_inputs
        ]
        ones_ims = ImageList.from_tensors(ones_ims,
                                          self.inpaint_net.size_divisibility)
        # the conv layers use zero padding; this all-ones map marks the image boundary

        # generation process
        input_tensor = torch.cat([erased_ims, ones_ims.tensor, masks.tensor],
                                 dim=1)
        coarse_inp, fine_inp, offset_flow = self.inpaint_net(
            input_tensor, masks.tensor)
        # offset_flow is only used for visualization

        if self.training:
            raise NotImplementedError
        else:
            processed_results = []
            inpainted_im = erased_ims * (
                1. - masks.tensor) + fine_inp * masks.tensor
            for result, input_per_image, image_size in zip(
                    inpainted_im, batched_inputs, images.image_sizes):
                height = input_per_image.get("height")
                width = input_per_image.get("width")
                r = sem_seg_postprocess(result, image_size, height, width)
                # reuse the semantic segmentation postprocess; it just crops padding and resizes
                processed_results.append({"inpainted": r})
            return processed_results
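
The inference branch above composites the network output back into the unmasked pixels. In isolation, with a toy binary mask (all names here are placeholders), the rule looks like this; note that for a 0/1 mask the extra `* (1 - mask)` on the already-erased image is a no-op:

import torch

image = torch.rand(1, 3, 8, 8)
mask = torch.zeros(1, 1, 8, 8)
mask[..., 2:6, 2:6] = 1.0                    # 1 inside the hole, 0 elsewhere
prediction = torch.rand(1, 3, 8, 8)          # stands in for `fine_inp`

erased = image * (1.0 - mask)                # what the inpainting network actually sees
inpainted = erased * (1.0 - mask) + prediction * mask
# Original pixels are kept outside the hole, predicted pixels are used inside it.
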
Example No. 7
    def forward(self, batched_inputs: list):
        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [self.normalizer(x) for x in images]
        images = ImageList.from_tensors(images, size_divisibility=16)

        preds = self.run_model(images.tensor)

        if self.training:
            targets = [x["sem_seg"].to(self.device) for x in batched_inputs]
            targets = ImageList.from_tensors(targets,
                                             size_divisibility=16,
                                             pad_value=255).tensor
            return dict(loss_sem_seg=F.cross_entropy(
                preds, targets, reduction="mean", ignore_index=255))

        processed_preds = []
        for pred, input_per_image, image_size in zip(preds, batched_inputs,
                                                     images.image_sizes):
            height = input_per_image.get("height")
            width = input_per_image.get("width")
            r = sem_seg_postprocess(pred, image_size, height, width)
            processed_preds.append({"sem_seg": r})

        return processed_preds
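
The training branch above pads the ground truth with 255 and then excludes those pixels via `ignore_index=255`. A tiny self-contained illustration of that pattern (all tensors invented):

import torch
import torch.nn.functional as F

logits = torch.randn(1, 19, 4, 4)                        # (N, C, H, W) predictions
target = torch.full((1, 4, 4), 255, dtype=torch.long)    # everything padded/ignored
target[0, :2, :2] = 3                                    # a small labelled region
loss = F.cross_entropy(logits, target, reduction="mean", ignore_index=255)
# Only the four labelled pixels contribute; the padded pixels are skipped entirely.
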
Example No. 8
    def forward(self, batched_inputs):
        """
        Args:
            batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
                Each item in the list contains the inputs for one image.
                For now, each item in the list is a dict that contains:
                   * "image": Tensor, image in (C, H, W) format.
                   * "sem_seg": semantic segmentation ground truth
                   * "center": center points heatmap ground truth
                   * "offset": pixel offsets to center points ground truth
                   * Other information that's included in the original dicts, such as:
                     "height", "width" (int): the output resolution of the model (may be different
                     from input resolution), used in inference.
        Returns:
            list[dict]:
                each dict is the results for one image. The dict contains the following keys:

                * "panoptic_seg", "sem_seg": see documentation
                    :doc:`/tutorials/models` for the standard output format
                * "instances": available if ``predict_instances is True``. see documentation
                    :doc:`/tutorials/models` for the standard output format
        """
        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        # To avoid error in ASPP layer when input has different size.
        size_divisibility = (
            self.size_divisibility
            if self.size_divisibility > 0
            else self.backbone.size_divisibility
        )
        images = ImageList.from_tensors(images, size_divisibility)

        features = self.backbone(images.tensor)

        losses = {}
        if "sem_seg" in batched_inputs[0]:
            targets = [x["sem_seg"].to(self.device) for x in batched_inputs]
            targets = ImageList.from_tensors(
                targets, size_divisibility, self.sem_seg_head.ignore_value
            ).tensor
            if "sem_seg_weights" in batched_inputs[0]:
                # The default D2 DatasetMapper may not contain "sem_seg_weights"
                # Avoid error in testing when default DatasetMapper is used.
                weights = [x["sem_seg_weights"].to(self.device) for x in batched_inputs]
                weights = ImageList.from_tensors(weights, size_divisibility).tensor
            else:
                weights = None
        else:
            targets = None
            weights = None
        sem_seg_results, sem_seg_losses = self.sem_seg_head(features, targets, weights)
        losses.update(sem_seg_losses)

        if "center" in batched_inputs[0] and "offset" in batched_inputs[0]:
            center_targets = [x["center"].to(self.device) for x in batched_inputs]
            center_targets = ImageList.from_tensors(
                center_targets, size_divisibility
            ).tensor.unsqueeze(1)
            center_weights = [x["center_weights"].to(self.device) for x in batched_inputs]
            center_weights = ImageList.from_tensors(center_weights, size_divisibility).tensor

            offset_targets = [x["offset"].to(self.device) for x in batched_inputs]
            offset_targets = ImageList.from_tensors(offset_targets, size_divisibility).tensor
            offset_weights = [x["offset_weights"].to(self.device) for x in batched_inputs]
            offset_weights = ImageList.from_tensors(offset_weights, size_divisibility).tensor
        else:
            center_targets = None
            center_weights = None

            offset_targets = None
            offset_weights = None

        center_results, offset_results, center_losses, offset_losses = self.ins_embed_head(
            features, center_targets, center_weights, offset_targets, offset_weights
        )
        losses.update(center_losses)
        losses.update(offset_losses)

        if self.training:
            return losses

        if self.benchmark_network_speed:
            return []

        processed_results = []
        for sem_seg_result, center_result, offset_result, input_per_image, image_size in zip(
            sem_seg_results, center_results, offset_results, batched_inputs, images.image_sizes
        ):
            height = input_per_image.get("height")
            width = input_per_image.get("width")
            r = sem_seg_postprocess(sem_seg_result, image_size, height, width)
            c = sem_seg_postprocess(center_result, image_size, height, width)
            o = sem_seg_postprocess(offset_result, image_size, height, width)
            # Post-processing to get panoptic segmentation.
            panoptic_image, _ = get_panoptic_segmentation(
                r.argmax(dim=0, keepdim=True),
                c,
                o,
                thing_ids=self.meta.thing_dataset_id_to_contiguous_id.values(),
                label_divisor=self.meta.label_divisor,
                stuff_area=self.stuff_area,
                void_label=-1,
                threshold=self.threshold,
                nms_kernel=self.nms_kernel,
                top_k=self.top_k,
            )
            # For semantic segmentation evaluation.
            processed_results.append({"sem_seg": r})
            panoptic_image = panoptic_image.squeeze(0)
            semantic_prob = F.softmax(r, dim=0)

            # Write results to disk:
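            # NOTE: the visualization below writes to an absolute path on the original
            # author's machine; change the output directory before running this block.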
            img = input_per_image["image"]
            from detectron2.utils.visualizer import Visualizer
            from detectron2.data.detection_utils import convert_image_to_rgb
            from PIL import Image 
            import os

            img = convert_image_to_rgb(img.permute(1, 2, 0), self.input_format).astype("uint8")
            img = np.array(Image.fromarray(img).resize((width, height)))
            v_panoptic = Visualizer(img, self.meta)
            v_panoptic = v_panoptic.draw_panoptic_seg_predictions(panoptic_image.cpu(), None)
            pan_img = v_panoptic.get_image()
            image_path = input_per_image['file_name'].split(os.sep)
            image_name = os.path.splitext(image_path[-1])[0] 
            Image.fromarray(pan_img).save(os.path.join('/home/ahabbas/projects/conseg/affinityNet/output_pdl/coco/eval_vis', image_name + '_panoptic.png'))

            # For panoptic segmentation evaluation.
            processed_results[-1]["panoptic_seg"] = (panoptic_image, None)
            # For instance segmentation evaluation.
            if self.predict_instances:
                instances = []
                panoptic_image_cpu = panoptic_image.cpu().numpy()
                for panoptic_label in np.unique(panoptic_image_cpu):
                    if panoptic_label == -1:
                        continue
                    pred_class = panoptic_label // self.meta.label_divisor
                    isthing = pred_class in list(
                        self.meta.thing_dataset_id_to_contiguous_id.values()
                    )
                    # Get instance segmentation results.
                    if isthing:
                        instance = Instances((height, width))
                        # Evaluation code takes continuous id starting from 0
                        instance.pred_classes = torch.tensor(
                            [pred_class], device=panoptic_image.device
                        )
                        mask = panoptic_image == panoptic_label
                        instance.pred_masks = mask.unsqueeze(0)
                        # Average semantic probability
                        sem_scores = semantic_prob[pred_class, ...]
                        sem_scores = torch.mean(sem_scores[mask])
                        # Center point probability
                        mask_indices = torch.nonzero(mask).float()
                        center_y, center_x = (
                            torch.mean(mask_indices[:, 0]),
                            torch.mean(mask_indices[:, 1]),
                        )
                        center_scores = c[0, int(center_y.item()), int(center_x.item())]
                        # Confidence score is semantic prob * center prob.
                        instance.scores = torch.tensor(
                            [sem_scores * center_scores], device=panoptic_image.device
                        )
                        # Get bounding boxes
                        instance.pred_boxes = BitMasks(instance.pred_masks).get_bounding_boxes()
                        instances.append(instance)
                if len(instances) > 0:
                    processed_results[-1]["instances"] = Instances.cat(instances)

        return processed_results
Example No. 9
    def forward(self, batched_inputs):
        """
        Args:
            batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
                Each item in the list contains the inputs for one image.

                For now, each item in the list is a dict that contains:

                * "image": Tensor, image in (C, H, W) format.
                * "instances": Instances
                * "sem_seg": semantic segmentation ground truth.
                * Other information that's included in the original dicts, such as:
                  "height", "width" (int): the output resolution of the model, used in inference.
                  See :meth:`postprocess` for details.

        Returns:
            list[dict]:
                each dict is the results for one image. The dict contains the following keys:

                * "instances": see :meth:`GeneralizedRCNN.forward` for its format.
                * "sem_seg": see :meth:`SemanticSegmentor.forward` for its format.
                * "panoptic_seg": available when `PANOPTIC_FPN.COMBINE.ENABLED`.
                  See the return value of
                  :func:`combine_semantic_and_instance_outputs` for its format.
        """

        image_path = [x['file_name'] for x in batched_inputs]
        if self.training:
            flips = [x['flip'] for x in batched_inputs]
        else:
            flips = None

        if self.training:
            exemplar_input = self.get_exemplar_input(image_path, sample_size=1)
            if exemplar_input is not None:
                l = len(batched_inputs)
                batched_inputs = batched_inputs + exemplar_input
                images, features, proposals, gt_instances, gt_integral_sem_seg, _, losses = self._forward(
                    batched_inputs)
                exemplar_features = {}
                for k, v in features.items():
                    exemplar_features[k] = v[l:]
                exemplar_gt_instances = gt_instances[l:]
                image_path = [x['file_name'] for x in batched_inputs]
                if self.training:
                    exemplar_flips = [x['flip'] for x in batched_inputs]
                else:
                    exemplar_flips = None
                with torch.no_grad():
                    exemplar_info = self.roi_heads.get_box_features(
                        exemplar_features, exemplar_gt_instances)
                detector_results, detector_losses = self.roi_heads(
                    images,
                    features,
                    proposals,
                    gt_instances,
                    gt_integral_sem_seg,
                    image_path=image_path,
                    flips=exemplar_flips,
                    exemplar_info=exemplar_info)
                del exemplar_info, exemplar_input
            else:
                exemplar_info = None
                images, features, proposals, gt_instances, gt_integral_sem_seg, _, losses = self._forward(
                    batched_inputs)
                detector_results, detector_losses = self.roi_heads(
                    images,
                    features,
                    proposals,
                    gt_instances,
                    gt_integral_sem_seg,
                    image_path=image_path,
                    flips=flips,
                    exemplar_info=exemplar_info)
        else:
            exemplar_info = None
            images, features, proposals, gt_instances, gt_integral_sem_seg, sem_seg_results, losses = self._forward(
                batched_inputs)
            detector_results, detector_losses = self.roi_heads(
                images,
                features,
                proposals,
                gt_instances,
                gt_integral_sem_seg,
                image_path=image_path,
                flips=flips,
                exemplar_info=exemplar_info)

        if self.training:
            losses.update({
                k: v * self.instance_loss_weight
                for k, v in detector_losses.items()
            })
            return losses

        processed_results = []
        for sem_seg_result, detector_result, input_per_image, image_size in zip(
                sem_seg_results, detector_results, batched_inputs,
                images.image_sizes):
            height = input_per_image.get("height", image_size[0])
            width = input_per_image.get("width", image_size[1])
            sem_seg_r = sem_seg_postprocess(sem_seg_result, image_size, height,
                                            width)
            detector_r = detector_postprocess(detector_result, height, width)

            processed_results.append({
                "sem_seg": sem_seg_r,
                "instances": detector_r
            })

            if self.combine_on:
                panoptic_r = combine_semantic_and_instance_outputs(
                    detector_r,
                    sem_seg_r.argmax(dim=0),
                    self.combine_overlap_threshold,
                    self.combine_stuff_area_limit,
                    self.combine_instances_confidence_threshold,
                )
                processed_results[-1]["panoptic_seg"] = panoptic_r
        return processed_results
Example No. 10
    def forward(self, batched_inputs):
        """
        Args:
            batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
                Each item in the list contains the inputs for one image.
                For now, each item in the list is a dict that contains:
                   * "image": Tensor, image in (C, H, W) format.
                   * "sem_seg": semantic segmentation ground truth
                   * "center": center points heatmap ground truth
                   * "offset": pixel offsets to center points ground truth
                   * Other information that's included in the original dicts, such as:
                     "height", "width" (int): the output resolution of the model (may be different
                     from input resolution), used in inference.
        Returns:
            list[dict]:
              each dict is the results for one image. The dict contains the following keys:

                * "instances": see :meth:`GeneralizedRCNN.forward` for its format.
                * "sem_seg": see :meth:`SemanticSegmentor.forward` for its format.
                * "panoptic_seg": see :func:`combine_semantic_and_instance_outputs` for its format.
        """
        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        size_divisibility = self.backbone.size_divisibility
        images = ImageList.from_tensors(images, size_divisibility)

        features = self.backbone(images.tensor)

        losses = {}
        if "sem_seg" in batched_inputs[0]:
            targets = [x["sem_seg"].to(self.device) for x in batched_inputs]
            targets = ImageList.from_tensors(
                targets, size_divisibility,
                self.sem_seg_head.ignore_value).tensor
            if "sem_seg_weights" in batched_inputs[0]:
                # The default D2 DatasetMapper may not contain "sem_seg_weights"
                # Avoid error in testing when default DatasetMapper is used.
                weights = [
                    x["sem_seg_weights"].to(self.device)
                    for x in batched_inputs
                ]
                weights = ImageList.from_tensors(weights,
                                                 size_divisibility).tensor
            else:
                weights = None
        else:
            targets = None
            weights = None
        sem_seg_results, sem_seg_losses = self.sem_seg_head(
            features, targets, weights)
        losses.update(sem_seg_losses)

        if "center" in batched_inputs[0] and "offset" in batched_inputs[0]:
            center_targets = [
                x["center"].to(self.device) for x in batched_inputs
            ]
            center_targets = ImageList.from_tensors(
                center_targets, size_divisibility).tensor.unsqueeze(1)
            center_weights = [
                x["center_weights"].to(self.device) for x in batched_inputs
            ]
            center_weights = ImageList.from_tensors(center_weights,
                                                    size_divisibility).tensor

            offset_targets = [
                x["offset"].to(self.device) for x in batched_inputs
            ]
            offset_targets = ImageList.from_tensors(offset_targets,
                                                    size_divisibility).tensor
            offset_weights = [
                x["offset_weights"].to(self.device) for x in batched_inputs
            ]
            offset_weights = ImageList.from_tensors(offset_weights,
                                                    size_divisibility).tensor
        else:
            center_targets = None
            center_weights = None

            offset_targets = None
            offset_weights = None

        center_results, offset_results, center_losses, offset_losses = self.ins_embed_head(
            features, center_targets, center_weights, offset_targets,
            offset_weights)
        losses.update(center_losses)
        losses.update(offset_losses)

        if self.training:
            return losses

        processed_results = []
        for sem_seg_result, center_result, offset_result, input_per_image, image_size in zip(
                sem_seg_results, center_results, offset_results,
                batched_inputs, images.image_sizes):
            height = input_per_image.get("height")
            width = input_per_image.get("width")
            r = sem_seg_postprocess(sem_seg_result, image_size, height, width)
            c = sem_seg_postprocess(center_result, image_size, height, width)
            o = sem_seg_postprocess(offset_result, image_size, height, width)
            # Post-processing to get panoptic segmentation.
            panoptic_image, _ = get_panoptic_segmentation(
                r.argmax(dim=0, keepdim=True),
                c,
                o,
                thing_ids=self.meta.thing_dataset_id_to_contiguous_id.values(),
                label_divisor=self.meta.label_divisor,
                stuff_area=self.stuff_area,
                void_label=-1,
                threshold=self.threshold,
                nms_kernel=self.nms_kernel,
                top_k=self.top_k,
            )
            # For semantic segmentation evaluation.
            processed_results.append({"sem_seg": r})
            panoptic_image = panoptic_image.squeeze(0)
            semantic_prob = F.softmax(r, dim=0)
            # For panoptic segmentation evaluation.
            processed_results[-1]["panoptic_seg"] = (panoptic_image, None)
            # For instance segmentation evaluation.
            if self.predict_instances:
                instances = []
                panoptic_image_cpu = panoptic_image.cpu().numpy()
                for panoptic_label in np.unique(panoptic_image_cpu):
                    if panoptic_label == -1:
                        continue
                    pred_class = panoptic_label // self.meta.label_divisor
                    isthing = pred_class in list(
                        self.meta.thing_dataset_id_to_contiguous_id.values())
                    # Get instance segmentation results.
                    if isthing:
                        instance = Instances((height, width))
                        # Evaluation code takes continuous id starting from 0
                        instance.pred_classes = torch.tensor(
                            [pred_class], device=panoptic_image.device)
                        mask = panoptic_image == panoptic_label
                        instance.pred_masks = mask.unsqueeze(0)
                        # Average semantic probability
                        sem_scores = semantic_prob[pred_class, ...]
                        sem_scores = torch.mean(sem_scores[mask])
                        # Center point probability
                        mask_indices = torch.nonzero(mask).float()
                        center_y, center_x = (
                            torch.mean(mask_indices[:, 0]),
                            torch.mean(mask_indices[:, 1]),
                        )
                        center_scores = c[0,
                                          int(center_y.item()),
                                          int(center_x.item())]
                        # Confidence score is semantic prob * center prob.
                        instance.scores = torch.tensor(
                            [sem_scores * center_scores],
                            device=panoptic_image.device)
                        # Get bounding boxes
                        instance.pred_boxes = BitMasks(
                            instance.pred_masks).get_bounding_boxes()
                        instances.append(instance)
                if len(instances) > 0:
                    processed_results[-1]["instances"] = Instances.cat(
                        instances)

        return processed_results
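
The instance loop above decodes class and instance ids from a single panoptic label via `label_divisor`. A hedged arithmetic sketch of that encoding (1000 is a commonly used divisor and is assumed here; the real value comes from `self.meta.label_divisor`):

label_divisor = 1000
panoptic_label = 17 * label_divisor + 3         # encodes: contiguous class 17, instance 3

pred_class = panoptic_label // label_divisor    # -> 17
instance_id = panoptic_label % label_divisor    # -> 3
# Stuff regions typically carry instance_id 0, and -1 is reserved for void pixels.
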
Example No. 11
    def forward(self, batched_inputs):
        """
        Args:
            batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
                Each item in the list contains the inputs for one image.

                For now, each item in the list is a dict that contains:

                * "image": Tensor, image in (C, H, W) format.
                * "instances": Instances
                * "sem_seg": semantic segmentation ground truth.
                * Other information that's included in the original dicts, such as:
                  "height", "width" (int): the output resolution of the model, used in inference.
                  See :meth:`postprocess` for details.

        Returns:
            list[dict]:
                each dict is the results for one image. The dict contains the following keys:

                * "instances": see :meth:`GeneralizedRCNN.forward` for its format.
                * "sem_seg": see :meth:`SemanticSegmentor.forward` for its format.
                * "panoptic_seg": available when `PANOPTIC_FPN.COMBINE.ENABLED`.
                  See the return value of
                  :func:`combine_semantic_and_instance_outputs` for its format.
        """
        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        images = ImageList.from_tensors(images,
                                        self.backbone.size_divisibility)
        features = self.backbone(images.tensor)

        if "proposals" in batched_inputs[0]:
            proposals = [
                x["proposals"].to(self.device) for x in batched_inputs
            ]
            proposal_losses = {}

        if "sem_seg" in batched_inputs[0]:
            gt_sem_seg = [x["sem_seg"].to(self.device) for x in batched_inputs]
            gt_sem_seg = ImageList.from_tensors(
                gt_sem_seg, self.backbone.size_divisibility,
                self.sem_seg_head.ignore_value).tensor
        else:
            gt_sem_seg = None
        sem_seg_results, sem_seg_losses = self.sem_seg_head(
            features, gt_sem_seg)

        if "instances" in batched_inputs[0]:
            gt_instances = [
                x["instances"].to(self.device) for x in batched_inputs
            ]
        else:
            gt_instances = None
        if self.proposal_generator:
            proposals, proposal_losses = self.proposal_generator(
                images, features, gt_instances)
        detector_results, detector_losses = self.roi_heads(
            images, features, proposals, gt_instances)

        if self.training:
            losses = {}
            losses.update(sem_seg_losses)
            losses.update({
                k: v * self.instance_loss_weight
                for k, v in detector_losses.items()
            })
            losses.update(proposal_losses)
            return losses

        processed_results = []
        for sem_seg_result, detector_result, input_per_image, image_size in zip(
                sem_seg_results, detector_results, batched_inputs,
                images.image_sizes):
            height = input_per_image.get("height", image_size[0])
            width = input_per_image.get("width", image_size[1])
            sem_seg_r = sem_seg_postprocess(sem_seg_result, image_size, height,
                                            width)
            detector_r = detector_postprocess(detector_result, height, width)

            processed_results.append({
                "sem_seg": sem_seg_r,
                "instances": detector_r
            })

            if self.combine_on:
                panoptic_r = combine_semantic_and_instance_outputs(
                    detector_r,
                    sem_seg_r.argmax(dim=0),
                    self.combine_overlap_threshold,
                    self.combine_stuff_area_limit,
                    self.combine_instances_confidence_threshold,
                )
                processed_results[-1]["panoptic_seg"] = panoptic_r
        return processed_results
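
For reference, a concrete single-item `batched_inputs` list of the shape the docstring above describes (the values are placeholders; at inference time only `"image"` plus the optional `"height"`/`"width"` output resolution are needed):

import torch

batched_inputs = [{
    "image": torch.zeros(3, 480, 640, dtype=torch.uint8),  # (C, H, W) input image
    "height": 720,    # resolution the post-processed outputs should be resized to
    "width": 960,
}]
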
Example No. 12
    def forward(self, batched_inputs):
        # complete image
        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [self.normalizer(x) for x in images]
        images = ImageList.from_tensors(images,
                                        self.generator.size_divisibility)
        # triplet input maps:
        # erased regions
        masks = [x["mask"].to(self.device) for x in batched_inputs]
        masks = ImageList.from_tensors(masks, self.generator.size_divisibility)
        # mask the input image with masks
        erased_ims = images.tensor * (1. - masks.tensor)
        # ones map
        ones_ims = [
            torch.ones_like(x["mask"].to(self.device)) for x in batched_inputs
        ]
        ones_ims = ImageList.from_tensors(ones_ims,
                                          self.generator.size_divisibility)
        # the conv layers use zero padding; this all-ones map marks the image boundary

        # generation process
        input_tensor = torch.cat([erased_ims, ones_ims.tensor, masks.tensor],
                                 dim=1)
        coarse_inp, fine_inp, offset_flow = self.generator(
            input_tensor, masks.tensor)
        # offset_flow is only used for visualization

        if self.training:
            # reconstruction loss
            losses = {}
            losses["loss_coarse_rec"] = self.loss_rec_weight * torch.abs(
                images.tensor - coarse_inp).mean()
            losses["loss_fine_rec"] = self.loss_rec_weight * torch.abs(
                images.tensor - fine_inp).mean()

            # discriminator
            real_and_fake_ims = torch.cat([images.tensor, fine_inp], dim=0)
            real_and_fake_masks = torch.cat([masks.tensor, masks.tensor],
                                            dim=0)  # append masks
            disc_pred = self.discriminator(
                torch.cat([real_and_fake_ims, real_and_fake_masks], dim=1))
            pred_for_real, pred_for_fake = torch.split(disc_pred,
                                                       disc_pred.size(0) // 2,
                                                       dim=0)
            # TODO: perhaps configure the loss function
            g_loss, d_loss = self.get_discriminator_hinge_loss(
                pred_for_real, pred_for_fake)
            losses['loss_gen'] = self.loss_gan_weight * g_loss
            losses['loss_disc'] = d_loss
            losses["generator_loss"] = sum([
                losses[k]
                for k in ["loss_coarse_rec", "loss_fine_rec", "loss_gen"]
            ])
            losses["discriminator_loss"] = losses["loss_disc"]
            return losses
        else:
            processed_results = []
            inpainted_im = erased_ims * (
                1. - masks.tensor) + fine_inp * masks.tensor
            for result, input_per_image, image_size in zip(
                    inpainted_im, batched_inputs, images.image_sizes):
                height = input_per_image.get("height")
                width = input_per_image.get("width")
                r = sem_seg_postprocess(result, image_size, height, width)
                # reuse the semantic segmentation postprocess; it just crops padding and resizes
                processed_results.append({"inpainted": r})
            return processed_results
Example No. 13
    def forward(self, batched_inputs):
        """
        Args:
            batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
                Each item in the list contains the inputs for one image.

                For now, each item in the list is a dict that contains:

                * "image": Tensor, image in (C, H, W) format.
                * "instances": Instances
                * "sem_seg": semantic segmentation ground truth.
                * Other information that's included in the original dicts, such as:
                  "height", "width" (int): the output resolution of the model, used
                  in inference. See :meth:`postprocess` for details.

        Returns:
            list[dict]: each dict is the results for one image. The dict
                contains the following keys:
                "instances": see :meth:`GeneralizedRCNN.forward` for its format.
                "sem_seg": see :meth:`SemanticSegmentor.forward` for its format.
                "panoptic_seg": available when `PANOPTIC_FPN.COMBINE.ENABLED`.
                    See the return value of
                    :func:`combine_semantic_and_instance_outputs` for its format.
        """
        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [self.normalizer(x) for x in images]
        images = ImageList.from_tensors(images,
                                        self.backbone.size_divisibility)
        features = self.backbone(images.tensor)

        if self.combine_on:
            if "sem_seg" in batched_inputs[0]:
                gt_sem = [x["sem_seg"].to(self.device) for x in batched_inputs]
                gt_sem = ImageList.from_tensors(
                    gt_sem, self.backbone.size_divisibility,
                    self.panoptic_module.ignore_value).tensor
            else:
                gt_sem = None
            sem_seg_results, sem_seg_losses = self.panoptic_module(
                features, gt_sem)

        if "basis_sem" in batched_inputs[0]:
            basis_sem = [
                x["basis_sem"].to(self.device) for x in batched_inputs
            ]
            basis_sem = ImageList.from_tensors(basis_sem,
                                               self.backbone.size_divisibility,
                                               0).tensor
        else:
            basis_sem = None
        basis_out, basis_losses = self.basis_module(features, basis_sem)

        if "instances" in batched_inputs[0]:
            gt_instances = [
                x["instances"].to(self.device) for x in batched_inputs
            ]
        else:
            gt_instances = None
        proposals, proposal_losses = self.proposal_generator(
            images, features, gt_instances, self.top_layer)
        detector_results, detector_losses = self.blender(
            basis_out["bases"], proposals, gt_instances)

        if self.training:
            losses = {}
            losses.update(basis_losses)
            losses.update({
                k: v * self.instance_loss_weight
                for k, v in detector_losses.items()
            })
            losses.update(proposal_losses)
            if self.combine_on:
                losses.update(sem_seg_losses)
            return losses

        processed_results = []
        for i, (detector_result, input_per_image, image_size) in enumerate(
                zip(detector_results, batched_inputs, images.image_sizes)):
            height = input_per_image.get("height", image_size[0])
            width = input_per_image.get("width", image_size[1])
            detector_r = detector_postprocess(detector_result, height, width)
            processed_result = {"instances": detector_r}
            if self.combine_on:
                sem_seg_r = sem_seg_postprocess(sem_seg_results[i], image_size,
                                                height, width)
                processed_result["sem_seg"] = sem_seg_r
            if "seg_thing_out" in basis_out:
                seg_thing_r = sem_seg_postprocess(basis_out["seg_thing_out"],
                                                  image_size, height, width)
                processed_result["sem_thing_seg"] = seg_thing_r
            if self.basis_module.visualize:
                processed_result["bases"] = basis_out["bases"]
            processed_results.append(processed_result)

            if self.combine_on:
                panoptic_r = combine_semantic_and_instance_outputs(
                    detector_r, sem_seg_r.argmax(dim=0),
                    self.combine_overlap_threshold,
                    self.combine_stuff_area_limit,
                    self.combine_instances_confidence_threshold)
                processed_results[-1]["panoptic_seg"] = panoptic_r
        return processed_results