Example #1
    def forward_for_single_feature_map(self, locations, box_cls, reg_pred,
                                       ext_pred, ctrness, image_sizes):
        N, C, H, W = box_cls.shape

        # put in the same format as locations
        box_cls = box_cls.view(N, C, H, W).permute(0, 2, 3, 1)
        box_cls = box_cls.reshape(N, -1, C).sigmoid()
        box_regression = reg_pred.view(N, 4, H, W).permute(0, 2, 3, 1)
        box_regression = box_regression.reshape(N, -1, 4)
        if ext_pred is not None:
            ext_regression = ext_pred.view(N, 4, H, W).permute(0, 2, 3, 1)
            ext_regression = ext_regression.reshape(N, -1, 4)
        ctrness = ctrness.view(N, 1, H, W).permute(0, 2, 3, 1)
        ctrness = ctrness.reshape(N, -1).sigmoid()

        # if self.thresh_with_ctr is True, we multiply the classification
        # scores with centerness scores before applying the threshold.
        if self.thresh_with_ctr:
            box_cls = box_cls * ctrness[:, :, None]
        candidate_inds = box_cls > self.pre_nms_thresh
        pre_nms_top_n = candidate_inds.view(N, -1).sum(1)
        pre_nms_top_n = pre_nms_top_n.clamp(max=self.pre_nms_top_n)

        if not self.thresh_with_ctr:
            box_cls = box_cls * ctrness[:, :, None]

        results = []
        for i in range(N):
            per_box_cls = box_cls[i]
            per_candidate_inds = candidate_inds[i]
            per_box_cls = per_box_cls[per_candidate_inds]

            per_candidate_nonzeros = per_candidate_inds.nonzero()
            per_box_loc = per_candidate_nonzeros[:, 0]
            per_class = per_candidate_nonzeros[:, 1]

            per_box_regression = box_regression[i]
            per_box_regression = per_box_regression[per_box_loc]
            per_locations = locations[per_box_loc]

            if ext_pred is not None:
                per_ext_regression = ext_regression[i]
                per_ext_regression = per_ext_regression[per_box_loc]

            per_pre_nms_top_n = pre_nms_top_n[i]

            if per_candidate_inds.sum().item() > per_pre_nms_top_n.item():
                per_box_cls, top_k_indices = \
                    per_box_cls.topk(per_pre_nms_top_n, sorted=False)
                per_class = per_class[top_k_indices]
                per_box_regression = per_box_regression[top_k_indices]
                if ext_pred is not None:
                    per_ext_regression = per_ext_regression[top_k_indices]
                per_locations = per_locations[top_k_indices]

            detections = torch.stack([
                per_locations[:, 0] - per_box_regression[:, 0],
                per_locations[:, 1] - per_box_regression[:, 1],
                per_locations[:, 0] + per_box_regression[:, 2],
                per_locations[:, 1] + per_box_regression[:, 3],
            ], dim=1)

            boxlist = Instances(image_sizes[i])
            boxlist.pred_boxes = Boxes(detections)
            boxlist.scores = torch.sqrt(per_box_cls)
            boxlist.pred_classes = per_class
            boxlist.locations = per_locations
            if ext_pred is not None:
                boxlist.ext_points = ExtremePoints.from_boxes(
                    boxlist.pred_boxes, per_ext_regression, per_locations)

            results.append(boxlist)

        return results
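
As a standalone illustration of the decoding step above (the tensor values below are made up for the example), each location (x, y) and its predicted (l, t, r, b) distances yield the box (x - l, y - t, x + r, y + b):

    import torch

    # Two hypothetical grid-cell centers and their (l, t, r, b) regressions.
    locations = torch.tensor([[32.0, 32.0], [96.0, 64.0]])        # (K, 2)
    reg = torch.tensor([[10.0, 8.0, 12.0, 6.0],
                        [20.0, 15.0, 5.0, 25.0]])                 # (K, 4)
    boxes = torch.stack([
        locations[:, 0] - reg[:, 0],   # x1
        locations[:, 1] - reg[:, 1],   # y1
        locations[:, 0] + reg[:, 2],   # x2
        locations[:, 1] + reg[:, 3],   # y2
    ], dim=1)
    # boxes -> [[22., 24., 44., 38.], [76., 49., 101., 89.]]
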
Example #2
    def forward(self, batched_inputs):
        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [self.normalizer(x) for x in images]
        images = ImageList.from_tensors(images, self.backbone.size_divisibility)

        features = self.backbone(images.tensor)

        if "instances" in batched_inputs[0] :
            gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
        elif "targets" in batched_inputs[0]:
            log_first_n(
                logging.WARN, "'targets' in the model inputs is now renamed to 'instances'!", n=10
            )
            gt_instances = [x["targets"].to(self.device) for x in batched_inputs]
        else:
            gt_instances = None

        proposals, proposal_losses = self.proposal_generator(images, features, gt_instances)

        if not self.training:
            if 'instance' in self.gt_input:
                assert gt_instances is not None

                for im_i in range(len(gt_instances)):
                    gt_instances_per_im = gt_instances[im_i]
                    bboxes = gt_instances_per_im.gt_boxes.tensor
                    instances_per_im = Instances(proposals[im_i]._image_size)
                    instances_per_im.pred_boxes = Boxes(bboxes)
                    instances_per_im.pred_classes = gt_instances_per_im.gt_classes
                    instances_per_im.scores = torch.ones_like(
                        gt_instances_per_im.gt_classes, device=bboxes.device)

                    if gt_instances_per_im.has("gt_masks"):
                        gt_masks = gt_instances_per_im.gt_masks
                        ext_pts_off = self.refinement_head.refine_head.get_simple_extreme_points(
                            gt_masks.polygons).to(bboxes.device)
                        ex_t = torch.stack([ext_pts_off[:, None, 0], bboxes[:, None, 1]], dim=2)
                        ex_l = torch.stack([bboxes[:, None, 0], ext_pts_off[:, None, 1]], dim=2)
                        ex_b = torch.stack([ext_pts_off[:, None, 2], bboxes[:, None, 3]], dim=2)
                        ex_r = torch.stack([bboxes[:, None, 2], ext_pts_off[:, None, 3]], dim=2)
                        instances_per_im.ext_points = ExtremePoints(
                            torch.cat([ex_t, ex_l, ex_b, ex_r], dim=1))
                    else:
                        quad = self.refinement_head.refine_head.get_quadrangle(bboxes).view(-1, 4, 2)
                        instances_per_im.ext_points = ExtremePoints(quad)

                    proposals[im_i] = instances_per_im

        head_losses, proposals = self.refinement_head(features, proposals, gt_instances)

        # In RPN-only models the proposals computed during training are never used,
        # but this model does use them; computing them makes RPN-only models about 5% slower.
        if self.training:
            proposal_losses.update(head_losses)
            return proposal_losses

        processed_results = []
        for results_per_image, input_per_image, image_size in zip(
                proposals, batched_inputs, images.image_sizes
        ):
            height = input_per_image.get("height", image_size[0])
            width = input_per_image.get("width", image_size[1])
            instance_r = detector_postprocess(results_per_image,
                                              height,
                                              width)
            processed_results.append(
                {"instances": instance_r}
            )

        return processed_results
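
For context, a minimal sketch of the preprocessing contract this forward relies on; the mean/std values are illustrative, and the normalizer is assumed to be a per-channel (x - mean) / std:

    import torch
    from detectron2.structures import ImageList

    # Two CHW images of different sizes, as they would arrive in batched_inputs.
    imgs = [torch.rand(3, 480, 640), torch.rand(3, 512, 512)]
    pixel_mean = torch.tensor([103.53, 116.28, 123.675]).view(3, 1, 1)
    pixel_std = torch.tensor([57.375, 57.12, 58.395]).view(3, 1, 1)

    def normalizer(x):
        return (x - pixel_mean) / pixel_std

    images = ImageList.from_tensors([normalizer(x) for x in imgs],
                                    size_divisibility=32)
    # images.tensor is (2, 3, 512, 640), zero-padded to a common size;
    # images.image_sizes keeps each image's true (h, w) for postprocessing.
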
Example #3
    def forward(self, features, pred_instances=None, targets=None):

        for i, f in enumerate(self.in_features):
            if i == 0:
                x = self.scale_heads[i](features[f])
            else:
                x = x + self.scale_heads[i](features[f])

        pred_logits = self.predictor(x)
        pred_edge = pred_logits.sigmoid()

        att_map = self.attender(1 - pred_edge)  # regions that need evolution

        if self.training:
            edge_target = targets[0]
            pred_edge_full = F.interpolate(
                pred_edge,
                scale_factor=self.common_stride,
                mode="bilinear",
                align_corners=False,
            )
            snake_input = torch.cat([att_map, x], dim=1)

            # Quick fix for batches that do not have poly after filtering
            try:
                _, poly_loss = self.refine_head(snake_input, None, targets[1])
            except Exception:
                poly_loss = {}

            edge_loss = self.loss(pred_edge_full,
                                  edge_target) * self.loss_weight
            poly_loss.update({
                "loss_edge_det": edge_loss,
            })
            return [], poly_loss, []
        else:

            snake_input = torch.cat([att_map, x], dim=1)

            if "instance" in self.gt_input:
                assert targets[1][0] is not None

                for im_i in range(len(targets[1][0])):
                    gt_instances_per_im = targets[1][0][im_i]
                    bboxes = gt_instances_per_im.gt_boxes.tensor
                    instances_per_im = Instances(
                        pred_instances[im_i]._image_size)
                    instances_per_im.pred_boxes = Boxes(bboxes)
                    instances_per_im.pred_classes = gt_instances_per_im.gt_classes
                    instances_per_im.scores = torch.ones_like(
                        gt_instances_per_im.gt_classes, device=bboxes.device)
                    if gt_instances_per_im.has("gt_masks"):
                        gt_masks = gt_instances_per_im.gt_masks
                        ext_pts_off = self.refine_head.get_simple_extreme_points(
                            gt_masks.polygons).to(bboxes.device)
                        ex_t = torch.stack(
                            [ext_pts_off[:, None, 0], bboxes[:, None, 1]],
                            dim=2)
                        ex_l = torch.stack(
                            [bboxes[:, None, 0], ext_pts_off[:, None, 1]],
                            dim=2)
                        ex_b = torch.stack(
                            [ext_pts_off[:, None, 2], bboxes[:, None, 3]],
                            dim=2)
                        ex_r = torch.stack(
                            [bboxes[:, None, 2], ext_pts_off[:, None, 3]],
                            dim=2)
                        instances_per_im.ext_points = ExtremePoints(
                            torch.cat([ex_t, ex_l, ex_b, ex_r], dim=1))

                    pred_instances[im_i] = instances_per_im

            new_instances, _ = self.refine_head(snake_input, pred_instances,
                                                None)

            pred_edge = att_map

            return pred_edge, {}, new_instances
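
The ex_t/ex_l/ex_b/ex_r stacking used above (and in the other forwards) turns per-instance boxes plus extreme-point offsets into an (N, 4, 2) point tensor; a small shape check with made-up values:

    import torch

    bboxes = torch.tensor([[10., 20., 50., 80.]])        # (N, 4): x1, y1, x2, y2
    ext_pts_off = torch.tensor([[30., 45., 25., 48.]])   # (N, 4): top-x, left-y, bottom-x, right-y

    ex_t = torch.stack([ext_pts_off[:, None, 0], bboxes[:, None, 1]], dim=2)  # (N, 1, 2) top
    ex_l = torch.stack([bboxes[:, None, 0], ext_pts_off[:, None, 1]], dim=2)  # (N, 1, 2) left
    ex_b = torch.stack([ext_pts_off[:, None, 2], bboxes[:, None, 3]], dim=2)  # (N, 1, 2) bottom
    ex_r = torch.stack([bboxes[:, None, 2], ext_pts_off[:, None, 3]], dim=2)  # (N, 1, 2) right
    ext_points = torch.cat([ex_t, ex_l, ex_b, ex_r], dim=1)                   # (N, 4, 2)
    # ext_points[0] -> [[30., 20.], [10., 45.], [25., 80.], [50., 48.]]
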
Example #4
    def forward(self, features, pred_instances=None, targets=None):
        if self.training:
            training_targets = self.compute_targets_for_polys(targets)
            locations, reg_targets, scales, image_idx = (
                training_targets["octagon_locs"],
                training_targets["octagon_targets"],
                training_targets["scales"],
                training_targets["image_idx"],
            )
            init_locations, init_targets = (
                training_targets["quadrangle_locs"],
                training_targets["quadrangle_targets"],
            )

        else:
            assert pred_instances is not None
            init_locations, image_idx = self.sample_quadrangles_fast(
                pred_instances)
            if len(init_locations) == 0:
                return pred_instances, {}

        # enhance bottom features TODO: maybe reduce later
        for i in range(self.num_convs):
            features = self.bottom_out[i](features)

        pred_exts = self.init(self.init_snake, features, init_locations,
                              image_idx)

        if not self.training:
            # features are assumed to be at 1/4 of the padded input resolution
            h = features.shape[2] * 4
            w = features.shape[3] * 4

            poly_sample_locations = []
            for i, instance_per_im in enumerate(pred_instances):
                pred_exts_per_im = pred_exts[image_idx == i]  # N x 4 x 2
                pred_exts_per_im[..., 0] = torch.clamp(
                    pred_exts_per_im[..., 0], min=0, max=w - 1)
                pred_exts_per_im[..., 1] = torch.clamp(
                    pred_exts_per_im[..., 1], min=0, max=h - 1)
                if not instance_per_im.has("ext_points"):
                    instance_per_im.ext_points = ExtremePoints(
                        pred_exts_per_im)
                    poly_sample_locations.append(
                        self.get_octagon(pred_exts_per_im, self.num_sampling))
                else:  # NOTE: For GT Input testing
                    poly_sample_locations.append(
                        self.get_octagon(instance_per_im.ext_points.tensor,
                                         self.num_sampling))
            locations = cat(poly_sample_locations, dim=0)

        location_preds = []

        for i in range(len(self.num_iter)):
            deformer = getattr(self, "deformer" + str(i))
            if i == 0:
                pred_location = self.evolve(deformer, features, locations,
                                            image_idx)
            else:
                pred_location = self.evolve(deformer, features, pred_location,
                                            image_idx)
            location_preds.append(pred_location)

        if self.training:
            evolve_loss = 0
            for pred in location_preds:
                evolve_loss += self.loss_reg(
                    pred / scales[:, None, None],
                    reg_targets / scales[:, None, None],
                ) / 3

            init_loss = self.loss_reg(pred_exts / scales[:, None, None],
                                      init_targets / scales[:, None, None])
            losses = {
                "loss_evolve": evolve_loss * self.refine_loss_weight,
                "loss_init": init_loss * self.refine_loss_weight,
            }
            return [], losses
        else:
            new_instances = self.predict_postprocess(pred_instances, locations,
                                                     location_preds, image_idx)
            return new_instances, {}
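
The evolution loop feeds each stage's output contour into the next deformer; a toy sketch of that iterate-and-collect pattern (ToyDeformer is a stand-in, not the real snake block):

    import torch
    import torch.nn as nn

    class ToyDeformer(nn.Module):
        """Stand-in deformer: predicts small per-vertex offsets."""
        def forward(self, pts):
            return pts + 0.1 * torch.randn_like(pts)

    deformers = nn.ModuleList([ToyDeformer() for _ in range(3)])
    pts = torch.rand(5, 16, 2) * 100   # (K instances, S vertices, xy)

    location_preds = []
    for deformer in deformers:
        pts = deformer(pts)            # each stage refines the previous contour
        location_preds.append(pts)
    # In training, every intermediate contour in location_preds is supervised.
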
Example #5
    def compute_targets_for_polys(self, targets):
        init_sample_locations = []
        init_sample_targets = []
        poly_sample_locations = []
        poly_sample_targets = []
        image_index = []
        scales = []

        # per image
        for im_i in range(len(targets)):
            targets_per_im = targets[im_i]
            bboxes = targets_per_im.gt_boxes.tensor

            # no gt
            if bboxes.numel() == 0:
                continue

            gt_masks = targets_per_im.gt_masks

            # use this as a scaling
            ws = bboxes[:, 2] - bboxes[:, 0]
            hs = bboxes[:, 3] - bboxes[:, 1]

            quadrangle = self.get_quadrangle(bboxes).cpu().numpy().reshape(
                -1, 4, 2)  # (k, 4, 2)

            if self.initial == "octagon":
                # [t_H_off, l_V_off, b_H_off, r_V_off]
                ext_pts_off = self.get_simple_extreme_points(
                    gt_masks.polygons).to(bboxes.device)
                ex_t = torch.stack(
                    [ext_pts_off[:, None, 0], bboxes[:, None, 1]], dim=2)
                ex_l = torch.stack(
                    [bboxes[:, None, 0], ext_pts_off[:, None, 1]], dim=2)
                ex_b = torch.stack(
                    [ext_pts_off[:, None, 2], bboxes[:, None, 3]], dim=2)
                ex_r = torch.stack(
                    [bboxes[:, None, 2], ext_pts_off[:, None, 3]], dim=2)

                # k x 4 x 2
                ext_points = torch.cat([ex_t, ex_l, ex_b, ex_r], dim=1)

                # get_octagons() gives N x 16 (ccw); reshape to (k, 8, 2)
                octagons = (ExtremePoints(ext_points).get_octagons()
                            .cpu().numpy().reshape(-1, 8, 2))
            else:
                raise ValueError("Invalid initial input!")

            # List[np.ndarray], element shape: (P, 2) OR None
            contours = self.get_simple_contour(gt_masks)

            # per instance
            for (quad, octagon, cnt, ext, w, h) in zip(quadrangle, octagons,
                                                       contours, ext_points,
                                                       ws, hs):
                if cnt is None:
                    continue

                # used for normalization
                scale = torch.min(w, h)

                # make it clockwise
                cnt = cnt[::-1] if Polygon(cnt).exterior.is_ccw else cnt
                assert not Polygon(cnt).exterior.is_ccw, \
                    "1) contour must be clock-wise!"

                # sampling from quadrangle
                quad_sampled_pts = self.uniform_sample(quad, 40)

                # sampling from octagon
                oct_sampled_pts = self.uniform_sample(octagon, self.num_sampling)

                if Polygon(oct_sampled_pts).exterior.is_ccw:
                    oct_sampled_pts = oct_sampled_pts[::-1]
                assert not Polygon(oct_sampled_pts).exterior.is_ccw, \
                    "1) contour must be clock-wise!"

                # sampling from ground truth
                oct_sampled_targets = self.uniform_sample(
                    cnt,
                    len(cnt) * self.num_sampling)  # (big, 2)
                # i) roll the dense ground-truth samples so they start at the
                # point nearest to the first octagon vertex, yielding an ordered
                # point set aligned with oct_sampled_pts
                tt_idx = np.argmin(
                    np.power(oct_sampled_targets - oct_sampled_pts[0],
                             2).sum(axis=1))
                oct_sampled_targets = np.roll(oct_sampled_targets,
                                              -tt_idx,
                                              axis=0)[::len(cnt)]

                # assert not Polygon(oct_sampled_targets).exterior.is_ccw, '2) contour must be clock-wise!'

                quad_sampled_pts = torch.tensor(quad_sampled_pts,
                                                device=bboxes.device)
                oct_sampled_pts = torch.tensor(oct_sampled_pts,
                                               device=bboxes.device)
                oct_sampled_targets = torch.tensor(oct_sampled_targets,
                                                   device=bboxes.device)


                init_sample_locations.append(quad_sampled_pts)
                init_sample_targets.append(ext)
                poly_sample_locations.append(oct_sampled_pts)
                poly_sample_targets.append(oct_sampled_targets)
                image_index.append(im_i)
                scales.append(scale)

        # NOTE: stacking fails if every image was skipped (no GT / no contours);
        # callers should guard against empty batches.
        init_sample_locations = torch.stack(init_sample_locations, dim=0)
        init_sample_targets = torch.stack(init_sample_targets, dim=0)
        poly_sample_locations = torch.stack(poly_sample_locations, dim=0)
        poly_sample_targets = torch.stack(poly_sample_targets, dim=0)
        image_index = torch.tensor(image_index, device=bboxes.device)
        scales = torch.stack(scales, dim=0)
        return {
            "quadrangle_locs": init_sample_locations,
            "quadrangle_targets": init_sample_targets,
            "octagon_locs": poly_sample_locations,
            "octagon_targets": poly_sample_targets,
            "scales": scales,
            "image_idx": image_index,
        }
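
The alignment trick in compute_targets_for_polys (densely sample the ground-truth contour, rotate it so it starts at the point nearest the first octagon vertex, then take every len(cnt)-th point) in isolation, with a toy square contour standing in for uniform_sample's output:

    import numpy as np

    num_sampling = 4
    cnt = np.array([[0., 0.], [4., 0.], [4., 4.], [0., 4.]])  # toy contour, len(cnt) = 4
    dense = np.repeat(cnt, num_sampling, axis=0)              # stands in for uniform_sample
    first_vertex = np.array([4., 4.2])                        # first sampled octagon point

    tt_idx = np.argmin(np.power(dense - first_vertex, 2).sum(axis=1))
    aligned = np.roll(dense, -tt_idx, axis=0)[::len(cnt)]     # (num_sampling, 2)
    # aligned starts at the ground-truth point closest to first_vertex, giving
    # an ordered one-to-one regression target for the sampled octagon points.
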
Example #6
    def forward(self, features, pred_instances=None, targets=None):
        if self.edge_on:
            with timer.env("pfpn_back"):
                for i, f in enumerate(self.in_features):
                    if i == 0:
                        x = self.scale_heads[i](features[f])
                    else:
                        x = x + self.scale_heads[i](features[f])

        if self.edge_on:
            with timer.env("edge"):
                pred_logits = self.predictor(x)
                pred_edge = pred_logits.sigmoid()
                if self.attention:
                    att_map = self.attender(1 - pred_edge)  # regions that need evolution

        if self.training:
            edge_target = targets[0]
            if self.edge_in:
                edge_prior = targets[0].unsqueeze(1).float().clone()  # (B, 1, H, W)
                edge_prior[edge_prior == self.ignore_value] = 0  # remove ignore value

                edge_prior = self.mean_filter(edge_prior)
                edge_prior = F.interpolate(
                    edge_prior,
                    scale_factor=1 / self.common_stride,
                    mode="bilinear",
                    align_corners=False,
                )
                edge_prior[edge_prior > 0] = 1

                if self.strong_feat:
                    snake_input = torch.cat([edge_prior, x], dim=1)
                else:
                    snake_input = torch.cat([edge_prior, features["p2"]], dim=1)
            else:
                if self.strong_feat:
                    snake_input = x
                else:
                    snake_input = features["p2"]

            if self.edge_on:
                pred_edge_full = F.interpolate(
                    pred_edge,
                    scale_factor=self.common_stride,
                    mode="bilinear",
                    align_corners=False,
                )

            if self.selective_refine:
                edge_prior = targets[0].unsqueeze(1).float().clone()  # (B, 1, H, W)
                edge_prior[edge_prior == self.ignore_value] = 0  # remove ignore value
                edge_prior = self.dilate_filter(edge_prior)
                edge_prior[edge_prior > 0] = 1
                edge_prior = F.interpolate(
                    edge_prior,
                    scale_factor=1 / self.common_stride,
                    mode="bilinear",
                    align_corners=False,
                )
                if self.strong_feat:
                    snake_input = torch.cat([edge_prior, x], dim=1)
                else:
                    if self.pred_edge:
                        snake_input = torch.cat(
                            [edge_prior, pred_logits, features["p2"]], dim=1
                        )
                    else:
                        snake_input = torch.cat([edge_prior, features["p2"]], dim=1)

            if self.attention:
                if self.strong_feat:
                    snake_input = torch.cat([att_map, x], dim=1)
                else:
                    # the pred_edge option is not handled here
                    snake_input = torch.cat([att_map, features["p2"]], dim=1)

            ### Quick fix for batches that do not have poly after filtering
            _, poly_loss = self.refine_head(snake_input, None, targets[1])

            if self.edge_on:
                edge_loss = self.loss(pred_edge_full, edge_target) * self.loss_weight
                poly_loss.update(
                    {
                        "loss_edge_det": edge_loss,
                    }
                )

            return [], poly_loss, []
        else:
            if self.edge_in or self.selective_refine:
                if self.edge_map_thre > 0:
                    pred_edge = (pred_edge > self.edge_map_thre).float()

                if "edge" in self.gt_input:
                    assert targets[0] is not None
                    pred_edge = targets[0].unsqueeze(1).float().clone()
                    pred_edge[pred_edge == self.ignore_value] = 0  # remove ignore value

                    if self.selective_refine:
                        pred_edge = self.dilate_filter(pred_edge)

                    pred_edge = F.interpolate(
                        pred_edge,
                        scale_factor=1 / self.common_stride,
                        mode="bilinear",
                        align_corners=False,
                    )

                    pred_edge[pred_edge > 0] = 1
                if self.strong_feat:
                    snake_input = torch.cat([pred_edge, x], dim=1)
                else:
                    snake_input = torch.cat([pred_edge, features["p2"]], dim=1)
            else:
                if self.strong_feat:
                    snake_input = x
                else:
                    snake_input = features["p2"]

            if self.attention:
                if self.strong_feat:
                    snake_input = torch.cat([att_map, x], dim=1)
                else:
                    # the pred_edge option is not handled here
                    snake_input = torch.cat([att_map, features["p2"]], dim=1)

            if "instance" in self.gt_input:
                assert targets[1][0] is not None

                for im_i in range(len(targets[1][0])):
                    gt_instances_per_im = targets[1][0][im_i]
                    bboxes = gt_instances_per_im.gt_boxes.tensor
                    instances_per_im = Instances(pred_instances[im_i]._image_size)
                    instances_per_im.pred_boxes = Boxes(bboxes)
                    instances_per_im.pred_classes = gt_instances_per_im.gt_classes
                    instances_per_im.scores = torch.ones_like(
                        gt_instances_per_im.gt_classes, device=bboxes.device
                    )
                    if gt_instances_per_im.has("gt_masks"):
                        gt_masks = gt_instances_per_im.gt_masks
                        ext_pts_off = self.refine_head.get_simple_extreme_points(
                            gt_masks.polygons
                        ).to(bboxes.device)
                        ex_t = torch.stack(
                            [ext_pts_off[:, None, 0], bboxes[:, None, 1]], dim=2
                        )
                        ex_l = torch.stack(
                            [bboxes[:, None, 0], ext_pts_off[:, None, 1]], dim=2
                        )
                        ex_b = torch.stack(
                            [ext_pts_off[:, None, 2], bboxes[:, None, 3]], dim=2
                        )
                        ex_r = torch.stack(
                            [bboxes[:, None, 2], ext_pts_off[:, None, 3]], dim=2
                        )
                        instances_per_im.ext_points = ExtremePoints(
                            torch.cat([ex_t, ex_l, ex_b, ex_r], dim=1)
                        )

                        # TODO: NOTE: Test for theoretic limit. #####
                        # contours = self.refine_head.get_simple_contour(gt_masks)
                        # poly_sample_targets = []
                        # for i, cnt in enumerate(contours):
                        #     if cnt is None:
                        #         xmin, ymin = bboxes[:, 0], bboxes[:, 1]  # (n,)
                        #         xmax, ymax = bboxes[:, 2], bboxes[:, 3]  # (n,)
                        #         box = [
                        #             xmax, ymin, xmin, ymin, xmin, ymax, xmax, ymax
                        #         ]
                        #         box = torch.stack(box, dim=1).view(-1, 4, 2)
                        #         sampled_box = self.refine_head.uniform_upsample(box[None],
                        #                                                         self.refine_head.num_sampling)
                        #         poly_sample_targets.append(sampled_box[i])
                        #         # print(sampled_box.shape)
                        #         continue
                        #
                        #     # 1) uniform-sample
                        #     oct_sampled_targets = self.refine_head.uniform_sample(cnt,
                        #                                                           len(cnt) * self.refine_head.num_sampling)  # (big, 2)
                        #     tt_idx = np.random.randint(len(oct_sampled_targets))
                        #     oct_sampled_targets = np.roll(oct_sampled_targets, -tt_idx, axis=0)[::len(cnt)]
                        #     oct_sampled_targets = torch.tensor(oct_sampled_targets, device=bboxes.device)
                        #     poly_sample_targets.append(oct_sampled_targets)
                        #     # print(oct_sampled_targets.shape)
                        #
                        #     # 2) polar-sample
                        #     # ...
                        # poly_sample_targets = torch.stack(poly_sample_targets, dim=0)
                        # instances_per_im.pred_polys = PolygonPoints(poly_sample_targets)
                        # TODO: NOTE: Test for theoretic limit. #####

                    pred_instances[im_i] = instances_per_im

            new_instances, _ = self.refine_head(snake_input, pred_instances, None)
            if not self.edge_on:
                # dummy edge map to keep the return signature consistent when
                # no edge branch is computed
                pred_edge = torch.rand(1, 1, 5, 5, device=snake_input.device)

            if self.attention:
                pred_edge = att_map

            return pred_edge, {}, new_instances
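
A standalone sketch of the edge-prior preparation both branches share, assuming common_stride = 4 and using a 3x3 max-pool as a stand-in for the dilation filter:

    import torch
    import torch.nn.functional as F

    common_stride, ignore_value = 4, 255
    edge_gt = torch.randint(0, 2, (2, 128, 128)).float()   # (B, H, W) edge labels
    edge_gt[:, :2, :] = ignore_value                       # some ignored pixels

    prior = edge_gt.unsqueeze(1).clone()                   # (B, 1, H, W)
    prior[prior == ignore_value] = 0                       # drop ignore labels
    prior = F.max_pool2d(prior, 3, stride=1, padding=1)    # stand-in dilation
    prior = F.interpolate(prior, scale_factor=1 / common_stride,
                          mode="bilinear", align_corners=False)
    prior[prior > 0] = 1                                   # re-binarize at feature scale
    # prior is now (B, 1, H/4, W/4), ready to be concatenated with features["p2"].
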