Example #1
    def forward(self, batched_inputs):
        """
        Args:
            Same as in :class:`GeneralizedRCNN.forward`

        Returns:
            list[dict]:
                Each dict is the output for one input image.
                The dict contains one key "proposals" whose value is a
                :class:`Instances` with keys "proposal_boxes" and "objectness_logits".
        """
        with timer.env("pre_process"):
            images = [x["image"].to(self.device) for x in batched_inputs]
            images = [self.normalizer(x) for x in images]
            images = ImageList.from_tensors(images,
                                            self.backbone.size_divisibility)

        with timer.env("backbone"):
            features = self.backbone(images.tensor)

        if "instances" in batched_inputs[0]:
            gt_instances = [
                x["instances"].to(self.device) for x in batched_inputs
            ]
        elif "targets" in batched_inputs[0]:
            log_first_n(
                logging.WARN,
                "'targets' in the model inputs is now renamed to 'instances'!",
                n=10)
            gt_instances = [
                x["targets"].to(self.device) for x in batched_inputs
            ]
        else:
            gt_instances = None

        with timer.env("fcos"):
            proposals, proposal_losses = self.proposal_generator(
                images, features, gt_instances)
        # In training, the proposals are not useful at all but we generate them anyway.
        # This makes RPN-only models about 5% slower.
        if self.training:
            return proposal_losses

        processed_results = []
        with timer.env("post_process"):
            for results_per_image, input_per_image, image_size in zip(
                    proposals, batched_inputs, images.image_sizes):
                height = input_per_image.get("height", image_size[0])
                width = input_per_image.get("width", image_size[1])
                r = detector_postprocess(results_per_image, height, width)
                processed_results.append({"proposals": r})
        return processed_results
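
A minimal inference sketch for the proposal-only `forward` above. `model` and the tensor shapes are assumptions; in the real pipeline `batched_inputs` is produced by detectron2's DatasetMapper and dataloader.

import torch

# Hypothetical single-image batch; "height"/"width" request the original resolution
# that detector_postprocess resizes the proposals back to.
batched_inputs = [{
    "image": torch.randint(0, 256, (3, 800, 1216), dtype=torch.uint8),
    "height": 720,
    "width": 1280,
}]
model.eval()  # in training mode the same call would return the loss dict instead
with torch.no_grad():
    outputs = model(batched_inputs)
print(outputs[0]["proposals"].proposal_boxes)  # boxes in the 1280x720 output frame
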
Example #2
    def single_test(self, batched_inputs):
        assert len(batched_inputs) == 1
        with timer.env("preprocess"):
            images = batched_inputs[0]["image"].to(self.device)
            images = self.normalizer(images)
            images = ImageList.from_tensors([images],
                                            self.backbone.size_divisibility)

        with timer.env("backbone"):
            features = self.backbone(images.tensor)

        gt_instances = None
        gt_sem_seg = None

        with timer.env("fcose"):
            proposals, proposal_losses = self.proposal_generator(
                images, features, gt_instances)

        if self.mask_result_src != "BOX":
            edge_map, head_losses, proposals = self.refinement_head(
                features, proposals,
                (gt_sem_seg, [gt_instances, images.image_sizes]))

        with timer.env("postprocess"):
            height = batched_inputs[0].get("height", images.image_sizes[0][0])
            width = batched_inputs[0].get("width", images.image_sizes[0][1])
            instance_r = detector_postprocess(
                self.semantic_filter,
                self.semantic_filter_th,
                self.mask_result_src,
                proposals[0],
                height,
                width,
                self.roi_size,
                self.need_concave_hull,
                self.re_compute_box,
            )
            processed_results = [{"instances": instance_r}]
            return processed_results
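
A rough latency probe around `single_test`, shown as a sketch: `model` and `batched_inputs` are assumed to exist already, and the `torch.cuda.synchronize()` calls only matter when running on GPU.

import time

import torch

model.eval()
with torch.no_grad():
    torch.cuda.synchronize()
    start = time.perf_counter()
    results = model.single_test(batched_inputs)  # exactly one sample, as asserted above
    torch.cuda.synchronize()
elapsed_ms = (time.perf_counter() - start) * 1000.0
print(f"single-image inference: {elapsed_ms:.1f} ms, "
      f"{len(results[0]['instances'])} detected instances")
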
Example #3
def detector_postprocess(semantic_filter, semantic_filter_th, mask_result_src,
                         results, output_height, output_width, roi_size,
                         need_concave_hull, re_comp_box):
    """
    Resize the output instances.
    The input images are often resized when entering an object detector.
    As a result, we often need the outputs of the detector in a different
    resolution from its inputs.

    This function will resize the raw outputs of an R-CNN detector
    to produce outputs according to the desired output resolution.

    Args:
        results (Instances): the raw outputs from the detector.
            `results.image_size` contains the input image resolution the detector sees.
            This object might be modified in-place.
        output_height, output_width: the desired output resolution.
        mask_result_src: one of "NO", "BOX", "OCT_BIT" or "OCT_RLE"; controls how
            `pred_masks` are produced (none, from boxes, or from extreme-point octagons).
        re_comp_box: if True, recompute `pred_boxes` from the scaled `pred_polys`.
        semantic_filter, semantic_filter_th, roi_size, need_concave_hull: only used by
            the commented-out filtering/ROI variants below; ignored in the active path.

    Returns:
        Instances: the resized output from the model, based on the output resolution
    """
    # the results.image_size here is the one the model saw, typically (800, xxxx)

    # with timer.env('postprocess_sub1_get'):
    scale_x, scale_y = (output_width / results.image_size[1],
                        output_height / results.image_size[0])
    results = Instances((output_height, output_width), **results.get_fields())

    if results.has("pred_boxes"):
        output_boxes = results.pred_boxes
    elif results.has("proposal_boxes"):
        output_boxes = results.proposal_boxes

    # with timer.env('postprocess_sub2_scale'):
    output_boxes.scale(scale_x, scale_y)
    # now the results.image_size is the one of raw input image
    # with timer.env('postprocess_sub3_clip'):
    output_boxes.clip(results.image_size)

    # with timer.env('postprocess_sub4_filter'):
    results = results[output_boxes.nonempty()]

    # with timer.env('postprocess_cp2'):
    if results.has("pred_polys"):
        if results.has("pred_path"):
            with timer.env('extra'):
                snake_path = results.pred_path
                for i in range(snake_path.size(1)):  # number of evolution iterations
                    current_poly = PolygonPoints(snake_path[:, i, :, :])
                    current_poly.scale(scale_x, scale_y)
                    current_poly.clip(results.image_size)
                    snake_path[:, i, :, :] = current_poly.tensor

        # TODO: Note that we did not scale exts (no need if not for evaluation)
        if results.has("ext_points"):
            results.ext_points.scale(scale_x, scale_y)

        results.pred_polys.scale(scale_x, scale_y)

        if re_comp_box:
            results.pred_boxes = Boxes(results.pred_polys.get_box())

        # results.pred_polys.clip(results.image_size)
        # results.pred_masks = get_polygon_rles(results.pred_polys.flatten(),
        #                                       (output_height, output_width))

        return results

    # if semantic_filter and results.has("ext_points"):
    #     if len(results) > 0:
    #         output_ext_points = results.ext_points
    #         keep_on_edge = output_ext_points.onedge(edge_map,
    #                                                 (output_height, output_width),
    #                                                 threshold=semantic_filter_th)
    #         re_weight = keep_on_edge.float() * 0.1 + 0.9
    #         results.scores *= re_weight

    if mask_result_src == 'NO':
        return results

    if mask_result_src == 'BOX':
        results.pred_masks = get_bbox_rles(results.pred_boxes.tensor,
                                           (output_height, output_width))

    elif results.has("ext_points"):
        # directly from extreme points to get these results as masks
        results.ext_points.scale(scale_x, scale_y)
        results.ext_points.fit_to_box()

        if mask_result_src == 'OCT_BIT':
            results.pred_masks = get_octagon_mask(
                results.ext_points.get_octagons(),
                (output_height, output_width))
        elif mask_result_src == 'OCT_RLE':
            results.pred_masks = get_octagon_rles(
                results.ext_points.get_octagons(),
                (output_height, output_width))
        # elif mask_result_src == 'MASK':
        #     aligned_ext_pts = results.ext_points.align(roi_size).cpu()
        #     batch_inds = torch.tensor([[0.]], device=results.ext_points.device).expand(len(results), 1)
        #     rois = torch.cat([batch_inds, results.pred_boxes.tensor], dim=1)  # Nx5
        #     roi_edge_map = ROIAlign(
        #         (roi_size, roi_size), 1.0, 0, aligned=False
        #     ).forward(edge_map[None, None, :, :].clone(), rois).squeeze(1)
        #     roi_edge_map = roi_edge_map.cpu()   # (D, roi_size, roi_size)
        #
        #     pois = get_masks(roi_edge_map, aligned_ext_pts, results.pred_boxes.tensor.cpu(), roi_size)
        #
        #     results.pred_masks = get_polygon_rles(pois, (output_height, output_width))

    return results
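
A self-contained sketch of the box-rescaling path of the function above, with made-up numbers: the detector saw an 800x1216 resized frame while the desired output is 720x1280, so boxes are scaled by (1280/1216, 720/800) and then clipped. `roi_size`, `need_concave_hull` and the semantic-filter arguments are placeholders here, since this path does not use them.

import torch
from detectron2.structures import Boxes, Instances

results = Instances((800, 1216))  # resolution the model saw
results.pred_boxes = Boxes(torch.tensor([[100., 200., 300., 400.]]))
results.scores = torch.tensor([0.9])

out = detector_postprocess(
    semantic_filter=False, semantic_filter_th=0.0, mask_result_src="NO",
    results=results, output_height=720, output_width=1280,
    roi_size=28, need_concave_hull=False, re_comp_box=False,
)
# x coordinates scaled by 1280/1216 and y by 720/800 -> [[105.3, 180.0, 315.8, 360.0]]
print(out.pred_boxes.tensor)
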
Example #4
    def forward(self, batched_inputs):
        """
        Args:
            Same as in :class:`GeneralizedRCNN.forward`

        Returns:
            list[dict]:
                Each dict is the output for one input image.
                The dict contains the keys "instances" and "edge_map": an
                :class:`Instances` object with the refined predictions, and the
                post-processed edge map for that image.
        """
        if not self.training and not self.visualize_path:
            return self.single_test(batched_inputs)

        with timer.env("preprocess"):
            images = [x["image"].to(self.device) for x in batched_inputs]
            images = [self.normalizer(x) for x in images]
            images = ImageList.from_tensors(images,
                                            self.backbone.size_divisibility)

        with timer.env("backbone"):
            features = self.backbone(images.tensor)

        if "instances" in batched_inputs[0]:
            gt_instances = [
                x["instances"].to(self.device) for x in batched_inputs
            ]
        elif "targets" in batched_inputs[0]:
            log_first_n(
                logging.WARN,
                "'targets' in the model inputs is now renamed to 'instances'!",
                n=10,
            )
            gt_instances = [
                x["targets"].to(self.device) for x in batched_inputs
            ]
        else:
            gt_instances = None

        if "sem_seg" in batched_inputs[0]:
            gt_sem_seg = [x["sem_seg"].to(self.device) for x in batched_inputs]
            gt_sem_seg = ImageList.from_tensors(
                gt_sem_seg,
                self.backbone.size_divisibility,
                self.refinement_head.ignore_value,
            ).tensor
        else:
            gt_sem_seg = None

        with timer.env("fcose"):
            proposals, proposal_losses = self.proposal_generator(
                images, features, gt_instances)
        edge_map, head_losses, proposals = self.refinement_head(
            features, proposals,
            (gt_sem_seg, [gt_instances, images.image_sizes]))

        # In RPN-only models the training-time proposals are not used at all (which makes
        # them about 5% slower); here, however, they are consumed by the refinement head.
        if self.training:
            timer.reset()
            proposal_losses.update(head_losses)
            return proposal_losses

        processed_results = []

        with timer.env("postprocess"):
            for per_edge_map, results_per_image, input_per_image, image_size in zip(
                    edge_map, proposals, batched_inputs, images.image_sizes):
                height = input_per_image.get("height", image_size[0])
                width = input_per_image.get("width", image_size[1])
                # TODO (OPT): no need to interpolate and then resize back for a real speed test
                with timer.env("extra"):
                    edge_map_r = edge_map_postprocess(per_edge_map, image_size,
                                                      height, width)
                instance_r = detector_postprocess(
                    self.semantic_filter,
                    self.semantic_filter_th,
                    self.mask_result_src,
                    results_per_image,
                    height,
                    width,
                    self.roi_size,
                    self.need_concave_hull,
                    self.re_compute_box,
                )
                processed_results.append({
                    "instances": instance_r,
                    "edge_map": edge_map_r,
                })
        return processed_results
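
For reference, a hedged sketch of how the training branch above is typically driven; `data_loader` and `optimizer` are assumptions, and the loader is expected to yield detectron2-format dicts containing "image", "instances" and (optionally) "sem_seg".

model.train()
for batched_inputs in data_loader:
    loss_dict = model(batched_inputs)  # proposal_losses updated with head_losses
    losses = sum(loss_dict.values())
    optimizer.zero_grad()
    losses.backward()
    optimizer.step()
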
Example #5
    def forward(self, features, pred_instances=None, targets=None):
        if self.edge_on:
            with timer.env("pfpn_back"):
                for i, f in enumerate(self.in_features):
                    if i == 0:
                        x = self.scale_heads[i](features[f])
                    else:
                        x = x + self.scale_heads[i](features[f])

        if self.edge_on:
            with timer.env("edge"):
                pred_logits = self.predictor(x)
                pred_edge = pred_logits.sigmoid()
                if self.attention:
                    # print('pred edge', pred_edge)
                    att_map = self.attender(
                        1 - pred_edge
                    )  # regions that need evolution

        if self.training:
            edge_target = targets[0]
            if self.edge_in:
                edge_prior = targets[0].unsqueeze(1).float().clone()  # (B, 1, H, W)
                edge_prior[edge_prior == self.ignore_value] = 0  # remove ignore value

                edge_prior = self.mean_filter(edge_prior)
                edge_prior = F.interpolate(
                    edge_prior,
                    scale_factor=1 / self.common_stride,
                    mode="bilinear",
                    align_corners=False,
                )
                edge_prior[edge_prior > 0] = 1

                if self.strong_feat:
                    snake_input = torch.cat([edge_prior, x], dim=1)
                else:
                    snake_input = torch.cat([edge_prior, features["p2"]], dim=1)
            else:
                if self.strong_feat:
                    snake_input = x
                else:
                    snake_input = features["p2"]

            if self.edge_on:
                pred_edge_full = F.interpolate(
                    pred_edge,
                    scale_factor=self.common_stride,
                    mode="bilinear",
                    align_corners=False,
                )

            if self.selective_refine:
                edge_prior = targets[0].unsqueeze(1).float().clone()  # (B, 1, H, W)
                edge_prior[edge_prior == self.ignore_value] = 0  # remove ignore value
                edge_prior = self.dilate_filter(edge_prior)
                # edge_prior = self.dilate_filter(edge_prior)
                # edge_target = edge_prior.clone()
                edge_prior[edge_prior > 0] = 1
                edge_prior = F.interpolate(
                    edge_prior,
                    scale_factor=1 / self.common_stride,
                    mode="bilinear",
                    align_corners=False,
                )
                if self.strong_feat:
                    snake_input = torch.cat([edge_prior, x], dim=1)
                else:
                    if self.pred_edge:
                        snake_input = torch.cat(
                            [edge_prior, pred_logits, features["p2"]], dim=1
                        )
                    else:
                        snake_input = torch.cat([edge_prior, features["p2"]], dim=1)

            if self.attention:
                if self.strong_feat:
                    snake_input = torch.cat([att_map, x], dim=1)
                else:
                    # the pred_edge option is not handled here for now
                    snake_input = torch.cat([att_map, features["p2"]], dim=1)

            ### Quick fix for batches that have no polygons left after filtering
            _, poly_loss = self.refine_head(snake_input, None, targets[1])

            if self.edge_on:
                edge_loss = self.loss(pred_edge_full, edge_target) * self.loss_weight
                poly_loss.update(
                    {
                        "loss_edge_det": edge_loss,
                    }
                )

            return [], poly_loss, []
        else:
            if self.edge_in or self.selective_refine:
                if self.edge_map_thre > 0:
                    pred_edge = (pred_edge > self.edge_map_thre).float()

                if "edge" in self.gt_input:
                    assert targets[0] is not None
                    pred_edge = targets[0].unsqueeze(1).float().clone()
                    pred_edge[pred_edge == self.ignore_value] = 0  # remove ignore value

                    if self.selective_refine:
                        pred_edge = self.dilate_filter(pred_edge)
                        # pred_edge = self.dilate_filter(pred_edge)

                    pred_edge = F.interpolate(
                        pred_edge,
                        scale_factor=1 / self.common_stride,
                        mode="bilinear",
                        align_corners=False,
                    )

                    pred_edge[pred_edge > 0] = 1
                if self.strong_feat:
                    snake_input = torch.cat([pred_edge, x], dim=1)
                else:
                    snake_input = torch.cat([pred_edge, features["p2"]], dim=1)
            else:
                if self.strong_feat:
                    snake_input = x
                else:
                    snake_input = features["p2"]

            if self.attention:
                if self.strong_feat:
                    snake_input = torch.cat([att_map, x], dim=1)
                else:
                    # the pred_edge option is not handled here for now
                    snake_input = torch.cat([att_map, features["p2"]], dim=1)

            if "instance" in self.gt_input:
                assert targets[1][0] is not None

                for im_i in range(len(targets[1][0])):
                    gt_instances_per_im = targets[1][0][im_i]
                    bboxes = gt_instances_per_im.gt_boxes.tensor
                    instances_per_im = Instances(pred_instances[im_i]._image_size)
                    instances_per_im.pred_boxes = Boxes(bboxes)
                    instances_per_im.pred_classes = gt_instances_per_im.gt_classes
                    instances_per_im.scores = torch.ones_like(
                        gt_instances_per_im.gt_classes, device=bboxes.device
                    )
                    if gt_instances_per_im.has("gt_masks"):
                        gt_masks = gt_instances_per_im.gt_masks
                        ext_pts_off = self.refine_head.get_simple_extreme_points(
                            gt_masks.polygons
                        ).to(bboxes.device)
                        ex_t = torch.stack(
                            [ext_pts_off[:, None, 0], bboxes[:, None, 1]], dim=2
                        )
                        ex_l = torch.stack(
                            [bboxes[:, None, 0], ext_pts_off[:, None, 1]], dim=2
                        )
                        ex_b = torch.stack(
                            [ext_pts_off[:, None, 2], bboxes[:, None, 3]], dim=2
                        )
                        ex_r = torch.stack(
                            [bboxes[:, None, 2], ext_pts_off[:, None, 3]], dim=2
                        )
                        instances_per_im.ext_points = ExtremePoints(
                            torch.cat([ex_t, ex_l, ex_b, ex_r], dim=1)
                        )

                        # TODO: NOTE: Test for theoretic limit. #####
                        # contours = self.refine_head.get_simple_contour(gt_masks)
                        # poly_sample_targets = []
                        # for i, cnt in enumerate(contours):
                        #     if cnt is None:
                        #         xmin, ymin = bboxes[:, 0], bboxes[:, 1]  # (n,)
                        #         xmax, ymax = bboxes[:, 2], bboxes[:, 3]  # (n,)
                        #         box = [
                        #             xmax, ymin, xmin, ymin, xmin, ymax, xmax, ymax
                        #         ]
                        #         box = torch.stack(box, dim=1).view(-1, 4, 2)
                        #         sampled_box = self.refine_head.uniform_upsample(box[None],
                        #                                                         self.refine_head.num_sampling)
                        #         poly_sample_targets.append(sampled_box[i])
                        #         # print(sampled_box.shape)
                        #         continue
                        #
                        #     # 1) uniform-sample
                        #     oct_sampled_targets = self.refine_head.uniform_sample(cnt,
                        #                                                           len(cnt) * self.refine_head.num_sampling)  # (big, 2)
                        #     tt_idx = np.random.randint(len(oct_sampled_targets))
                        #     oct_sampled_targets = np.roll(oct_sampled_targets, -tt_idx, axis=0)[::len(cnt)]
                        #     oct_sampled_targets = torch.tensor(oct_sampled_targets, device=bboxes.device)
                        #     poly_sample_targets.append(oct_sampled_targets)
                        #     # print(oct_sampled_targets.shape)
                        #
                        #     # 2) polar-sample
                        #     # ...
                        # poly_sample_targets = torch.stack(poly_sample_targets, dim=0)
                        # instances_per_im.pred_polys = PolygonPoints(poly_sample_targets)
                        # TODO: NOTE: Test for theoretic limit. #####

                    pred_instances[im_i] = instances_per_im

            new_instances, _ = self.refine_head(snake_input, pred_instances, None)
            # new_instances = pred_instances
            if not self.edge_on:
                # dummy placeholder so the return signature stays consistent
                pred_edge = torch.rand(1, 1, 5, 5, device=snake_input.device)

            if self.attention:
                pred_edge = att_map

            return pred_edge, {}, new_instances
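
A short sketch of the head's two return contracts, mirroring how Example #2 and Example #4 call it; `refinement_head`, `features`, `proposals`, `targets` and `images` are assumed to come from the surrounding model.

if refinement_head.training:
    # targets = (gt_sem_seg, [gt_instances, images.image_sizes])
    _, head_losses, _ = refinement_head(features, None, targets)
    total_loss = sum(head_losses.values())
else:
    edge_map, _, refined_instances = refinement_head(
        features, proposals, (None, [None, images.image_sizes]))
    # refined_instances: per-image Instances with refined polygons/boxes
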