Ejemplo n.º 1
0
    def simple_test(self, img, img_meta, rescale=False):
        """Run inference without augmentation and return the first image's result.

        If the first meta dict contains a ``tile_offset`` key, the call is
        forwarded to ``aug_test`` with ``rescale=True`` (the ``rescale``
        argument given here is not used on that path).

        Args:
            img: Input handed to ``self.extract_feat``.
            img_meta: Sequence of meta dicts; only ``img_meta[0]`` is inspected
                here, the whole sequence is forwarded to ``get_bboxes``.
            rescale: Forwarded to the last refine head's ``get_bboxes``.

        Returns:
            The ``aug_test`` result on the tile path; otherwise the first
            element of the per-image ``rbbox2result`` outputs.
        """
        if 'tile_offset' in img_meta[0]:
            # Tile-cropped TTA: force aug_test instead of the plain path.
            return self.aug_test(
                imgs=[img], img_metas=[img_meta], rescale=True)

        feats = self.extract_feat(img)
        head_outs = self.bbox_head(feats)
        rois = self.bbox_head.filter_bboxes(*head_outs)
        # rois: list (indexed by images) of list (indexed by levels)
        num_stages = self.num_refine_stages
        for stage in range(num_stages):
            refined_feats = self.feat_refine_module[stage](feats, rois)
            head_outs = self.refine_head[stage](refined_feats)
            # Boxes are refined between stages only, not after the last one.
            if stage + 1 < num_stages:
                rois = self.refine_head[stage].refine_bboxes(
                    *head_outs, rois=rois)

        final_inputs = head_outs + (img_meta, self.test_cfg, rescale)
        per_image_dets = self.refine_head[-1].get_bboxes(
            *final_inputs, rois=rois)
        per_image_results = [
            rbbox2result(bboxes, labels, self.refine_head[-1].num_classes)
            for bboxes, labels in per_image_dets
        ]
        return per_image_results[0]
    def aug_test(self, imgs, img_metas, rescale=True):
        """Run tile-cropped test-time augmentation and merge the tile results.

        Tiles are forwarded through the network in chunks of ``AUG_BS``.
        Detections are collected from every refine stage, concatenated per
        tile, and handed to ``merge_tiles_aug_rbboxes``.

        Args:
            imgs: Sequence of tile tensors; ``imgs[0].size(0)`` must be 1.
            img_metas: Sequence of lists of meta dicts, parallel to ``imgs``.
                ``img_metas[0][0]`` must contain ``'tile_offset'``.
            rescale: Must be truthy.

        Returns:
            The ``rbbox2result`` output for the merged detections.

        Raises:
            AssertionError: If ``rescale`` is falsy or the first tile's batch
                size is not 1.
            NotImplementedError: If the metas carry no ``'tile_offset'`` key.
        """
        AUG_BS = 8  # number of tiles concatenated into one forward pass
        assert rescale, '''while r3det uses overlapped cropping augmentation by default,
        the result should be rescaled to input images sizes to simplify the test pipeline'''
        if 'tile_offset' not in img_metas[0][0]:
            raise NotImplementedError

        assert imgs[0].size(0) == 1, '''when using cropped tiles augmentation,
            image batch size must be set to 1'''

        num_stages = self.num_refine_stages
        tile_bboxes, tile_labels = [], []
        for start in range(0, len(imgs), AUG_BS):
            stop = start + AUG_BS
            chunk_metas = img_metas[start:stop]
            chunk_size = len(chunk_metas)
            batch = torch.cat(imgs[start:stop], dim=0)
            # Flatten the per-tile meta lists into one list for the batch.
            flat_metas = sum(chunk_metas, [])

            feats = self.extract_feat(batch)
            head_outs = self.bbox_head(feats)
            rois = self.bbox_head.filter_bboxes(*head_outs)
            # rois: list (indexed by images) of list (indexed by levels)
            stage_bboxes = [[] for _ in range(chunk_size)]
            stage_labels = [[] for _ in range(chunk_size)]
            for stage in range(num_stages):
                # The refine heads consume the extracted features directly;
                # the feat_refine_module call is deliberately not made here.
                # NOTE(review): confirm skipping feature refinement is
                # intentional.
                head_outs = self.refine_head[stage](feats)
                if stage + 1 < num_stages:
                    rois = self.refine_head[stage].refine_bboxes(
                        *head_outs, rois=rois)

                stage_inputs = head_outs + (flat_metas, self.test_cfg, False)
                stage_dets = self.refine_head[stage].get_bboxes(
                    *stage_inputs, rois=rois)
                # stage_dets: [(rbbox_aug0, class_aug0), (rbbox_aug1, ...), ...]
                for tile_idx in range(chunk_size):
                    tile_det = stage_dets[tile_idx]
                    stage_bboxes[tile_idx].append(tile_det[0])
                    stage_labels[tile_idx].append(tile_det[1])

            # Join the detections of all stages for each tile.
            for tile_idx in range(chunk_size):
                tile_bboxes.append(torch.cat(stage_bboxes[tile_idx]))
                tile_labels.append(torch.cat(stage_labels[tile_idx]))

        merged_bboxes, merged_labels = merge_tiles_aug_rbboxes(
            tile_bboxes, tile_labels, img_metas,
            self.test_cfg.merge_cfg, self.CLASSES)

        return rbbox2result(merged_bboxes, merged_labels,
                            self.refine_head[-1].num_classes)
Ejemplo n.º 3
0
    def simple_test(self, img, img_meta, proposals=None, rescale=False):
        """Test without augmentation.

        Args:
            img: Input handed to ``self.extract_feat``.
            img_meta: Sequence of meta dicts; ``img_meta[0]`` supplies
                ``'ori_shape'`` and ``'scale_factor'``.
            proposals: Precomputed proposals. When ``None`` they are obtained
                from ``self.simple_test_rpn``.
            rescale: Not referenced in this method; the bbox head is always
                called with ``rescale=True``.

        Returns:
            The ``rbbox2result`` output when ``self.with_mask`` is falsy,
            otherwise ``None``.

        Raises:
            AssertionError: If ``self.with_bbox`` is falsy.
        """
        assert self.with_bbox, "Bbox head must be implemented."

        feats = self.extract_feat(img)

        if proposals is None:
            proposal_rotate_list = self.simple_test_rpn(
                feats, img_meta, self.test_cfg.rpn)
        else:
            proposal_rotate_list = proposals

        if self.with_bbox:
            rois = rbboxPoly2rroiRec(proposal_rotate_list)
            roi_feats = self.bbox_roi_extractor(
                feats[:self.bbox_roi_extractor.num_inputs], rois)
            cls_feats = roi_feats
            reg_feats = roi_feats
            if self.with_shared_head:
                # The shared head is run once per branch on the same RoI
                # features.
                cls_feats = self.shared_head(roi_feats)
                reg_feats = self.shared_head(roi_feats)
            head_preds = self.bbox_head(cls_feats, reg_feats)
            cls_score, bbox_xy_pred, bbox_wh_pred, bbox_theta_pred = head_preds

            first_meta = img_meta[0]
            det_bboxes, det_labels = self.bbox_head.get_det_rbbox2rbbox(
                rois,
                cls_score,
                bbox_xy_pred,
                bbox_wh_pred,
                bbox_theta_pred,
                first_meta['ori_shape'],
                first_meta['scale_factor'],
                rescale=True,
                cfg=self.test_cfg.rcnn)
            rbbox_results = rbbox2result(det_bboxes, det_labels,
                                         self.bbox_head.num_classes)

        if self.with_mask:
            # No result is produced on the mask path.
            return None
        return rbbox_results
Ejemplo n.º 4
0
    def aug_test(self, imgs, img_metas, rescale=False):
        """Test with augmentations.

        Each augmented view is run through the RoI extractor and bbox head;
        the per-view outputs are merged with ``merge_aug_rotate_bboxes`` and
        filtered by ``multiclass_poly_nms_8_points``.

        Args:
            imgs: Augmented inputs handed to ``self.extract_feats``.
            img_metas: Per-view lists of meta dicts; the first dict of each
                supplies ``'img_shape'``, ``'scale_factor'`` and ``'flip'``.
            rescale: Not referenced in this method.

        Returns:
            The ``rbbox2result`` output when ``self.with_mask`` is falsy,
            otherwise ``None``.
        """
        # Features are extracted here and again in the loop below
        # (recomputed to save memory).
        proposal_list = self.aug_test_rotate_rpn(
            self.extract_feats(imgs), img_metas, self.test_cfg.rpn)

        aug_bboxes, aug_scores = [], []
        for feats, view_metas in zip(self.extract_feats(imgs), img_metas):
            # only one image in the batch
            meta = view_metas[0]
            img_shape = meta['img_shape']
            scale_factor = meta['scale_factor']
            flip = meta['flip']
            # TODO more flexible
            mapped = bbox_mapping(
                proposal_list[0][:, :8], img_shape, scale_factor, flip)
            rois = rbboxPoly2rroiRec([mapped])

            roi_feats = self.bbox_roi_extractor(
                feats[:self.bbox_roi_extractor.num_inputs], rois)
            cls_feats = roi_feats
            reg_feats = roi_feats
            if self.with_shared_head:
                cls_feats = self.shared_head(roi_feats)
                reg_feats = self.shared_head(roi_feats)
            head_preds = self.bbox_head(cls_feats, reg_feats)
            cls_score, bbox_xy_pred, bbox_wh_pred, bbox_theta_pred = head_preds

            view_bboxes, view_scores = self.bbox_head.get_det_rbbox2rbbox(
                rois,
                cls_score,
                bbox_xy_pred,
                bbox_wh_pred,
                bbox_theta_pred,
                img_shape,
                scale_factor,
                rescale=False,
                cfg=None)
            aug_bboxes.append(view_bboxes)
            aug_scores.append(view_scores)

        rcnn_cfg = self.test_cfg.rcnn
        merged_bboxes, merged_scores = merge_aug_rotate_bboxes(
            aug_bboxes, aug_scores, img_metas, rcnn_cfg)
        det_bboxes, det_labels = multiclass_poly_nms_8_points(
            merged_bboxes,
            merged_scores,
            rcnn_cfg.score_thr,
            rcnn_cfg.nms,
            max_num=rcnn_cfg.max_per_img)

        bbox_results = rbbox2result(det_bboxes, det_labels,
                                    self.bbox_head.num_classes)

        if self.with_mask:
            # No result is produced on the mask path.
            return None
        return bbox_results