def get_det_bboxes(self,
                       rois,
                       cls_score,
                       bbox_pred,
                       img_shape,
                       scale_factor,
                       rescale=False,
                       cfg=None):
        """Decode RoIs into rotated boxes and run multiclass rotated NMS.

        Args:
            rois (Tensor): 2-D tensor with 5 or 6 columns. Column 0 is
                dropped. With 5 columns the remaining 4 values are converted
                through ``hbb2obb_v2``; with 6 columns the remaining 5 values
                are used as they are.
            cls_score (Tensor | list[Tensor] | None): Classification logits.
                A list is averaged element-wise before the softmax over
                dim 1. ``None`` yields ``scores = None``.
            bbox_pred (Tensor | None): Regression deltas. When ``None`` the
                boxes derived from ``rois`` are used without decoding.
            img_shape: Forwarded unchanged to the delta decoder.
            scale_factor: Divisor applied to columns 0-3 of every 5-column
                group when ``rescale`` is true.
            rescale (bool): Whether to divide the boxes by ``scale_factor``.
            cfg: Test config exposing ``score_thr``, ``nms`` and
                ``max_per_img``. It is required in practice even though the
                signature defaults it to ``None``.

        Returns:
            tuple: ``(det_bboxes, det_labels)`` exactly as returned by
            ``multiclass_nms_rbbox``.

        Raises:
            ValueError: If ``rois`` has neither 5 nor 6 columns.
        """
        if isinstance(cls_score, list):
            cls_score = sum(cls_score) / float(len(cls_score))
        scores = F.softmax(cls_score, dim=1) if cls_score is not None else None

        num_cols = rois.size(1)
        if num_cols == 5:
            obbs = hbb2obb_v2(rois[:, 1:])
        elif num_cols == 6:
            obbs = rois[:, 1:]
        else:
            # Fail loudly instead of dropping into an interactive debugger,
            # which would hang any non-interactive run and leave `obbs`
            # undefined afterwards.
            raise ValueError(
                'rois must have 5 or 6 columns, got {}'.format(num_cols))

        if bbox_pred is not None:
            if self.with_module:
                dbboxes = delta2dbbox(obbs, bbox_pred, self.target_means,
                                      self.target_stds, img_shape)
            else:
                dbboxes = delta2dbbox_v3(obbs, bbox_pred, self.target_means,
                                         self.target_stds, img_shape)
        else:
            # `obbs` may be a view of `rois`; clone so the in-place rescale
            # below cannot modify the caller's tensor.
            dbboxes = obbs.clone()
            # TODO: add clip here

        if rescale:
            # Scale the first four values of every 5-column box group; the
            # fifth value of each group is left untouched.
            dbboxes[:, 0::5] /= scale_factor
            dbboxes[:, 1::5] /= scale_factor
            dbboxes[:, 2::5] /= scale_factor
            dbboxes[:, 3::5] /= scale_factor

        det_bboxes, det_labels = multiclass_nms_rbbox(dbboxes, scores,
                                                      cfg.score_thr, cfg.nms,
                                                      cfg.max_per_img)
        return det_bboxes, det_labels
Exemple #2
0
    def get_bboxes_single(self,
                          cls_scores,
                          bbox_preds,
                          mlvl_anchors,
                          img_shape,
                          scale_factor,
                          cfg,
                          rescale=False):
        """Turn per-level head outputs for one image into rotated detections.

        For every level the class map is flattened to
        ``(-1, self.cls_out_channels)`` and the regression map to ``(-1, 5)``.
        If ``cfg`` provides a positive ``nms_pre`` smaller than the number of
        candidates, only the top ``nms_pre`` candidates by best class score
        are kept. Anchors go through ``hbb2obb_v2`` and are decoded with
        ``delta2dbbox`` (``self.with_module`` true) or ``delta2dbbox_v3``.
        The levels are then concatenated, optionally rescaled, and passed to
        ``multiclass_nms_rbbox``.

        Args:
            cls_scores (list[Tensor]): One class map per level.
            bbox_preds (list[Tensor]): One regression map per level, with the
                same last two dimensions as the matching class map.
            mlvl_anchors (list[Tensor]): One anchor tensor per level.
            img_shape: Forwarded unchanged to the delta decoder.
            scale_factor: Divisor for the first four box columns when
                ``rescale`` is true.
            cfg: Config supporting ``get('nms_pre', -1)`` and exposing
                ``score_thr``, ``nms`` and ``max_per_img``.
            rescale (bool): Whether to divide the boxes by ``scale_factor``.

        Returns:
            tuple: ``(det_bboxes, det_labels)`` as returned by
            ``multiclass_nms_rbbox``.
        """
        assert len(cls_scores) == len(bbox_preds) == len(mlvl_anchors)
        boxes_per_level = []
        scores_per_level = []
        for level_cls, level_reg, level_anchors in zip(
                cls_scores, bbox_preds, mlvl_anchors):
            assert level_cls.size()[-2:] == level_reg.size()[-2:]
            level_cls = level_cls.permute(1, 2, 0)
            level_cls = level_cls.reshape(-1, self.cls_out_channels)
            level_scores = (level_cls.sigmoid()
                            if self.use_sigmoid_cls else level_cls.softmax(-1))
            level_reg = level_reg.permute(1, 2, 0).reshape(-1, 5)

            # Optional pre-NMS top-k filtering on the best class score.
            top_k = cfg.get('nms_pre', -1)
            if 0 < top_k < level_scores.shape[0]:
                # With softmax scores column 0 is skipped when ranking.
                ranked = (level_scores
                          if self.use_sigmoid_cls else level_scores[:, 1:])
                best_scores, _ = ranked.max(dim=1)
                _, keep = best_scores.topk(top_k)
                level_anchors = level_anchors[keep, :]
                level_reg = level_reg[keep, :]
                level_scores = level_scores[keep, :]

            rotated_anchors = hbb2obb_v2(level_anchors)
            decode = delta2dbbox if self.with_module else delta2dbbox_v3
            boxes_per_level.append(
                decode(rotated_anchors, level_reg, self.target_means,
                       self.target_stds, img_shape))
            scores_per_level.append(level_scores)

        all_boxes = torch.cat(boxes_per_level)
        if rescale:
            # Only the first four columns are divided by the scale factor.
            all_boxes[:, :4] /= all_boxes.new_tensor(scale_factor)
        all_scores = torch.cat(scores_per_level)
        if self.use_sigmoid_cls:
            # Prepend a zero column so sigmoid scores share the softmax
            # layout in which column 0 precedes the real classes.
            zero_col = all_scores.new_zeros((all_scores.shape[0], 1))
            all_scores = torch.cat([zero_col, all_scores], dim=1)
        return multiclass_nms_rbbox(all_boxes, all_scores, cfg.score_thr,
                                    cfg.nms, cfg.max_per_img)
Exemple #3
0
    def regress_by_class_rbbox(self, rois, label, bbox_pred, img_meta):
        """Refine each RoI with the deltas of its predicted class.

        Used in Cascade R-CNN.

        Args:
            rois (Tensor): shape (n, 5) or (n, 6). With 6 columns, column 0
                is kept aside and re-attached to the refined boxes.
            label (Tensor): shape (n, ), class index per RoI.
            bbox_pred (Tensor): shape (n, 5*(#class+1)) or (n, 5).
            img_meta (dict): Image meta info; ``img_meta['img_shape']`` is
                forwarded to the delta decoder.

        Returns:
            Tensor: Refined boxes after ``choose_best_Rroi_batch``; for
            6-column input the original first column is prepended again.
        """
        assert rois.size(1) in (5, 6)

        if not self.reg_class_agnostic:
            # Select the 5 delta columns that belong to each RoI's class.
            first_col = label * 5
            class_cols = torch.stack(
                [first_col + offset for offset in range(5)], 1)
            bbox_pred = torch.gather(bbox_pred, 1, class_cols)
        assert bbox_pred.size(1) == 5

        has_leading_col = rois.size(1) != 5
        decode = delta2dbbox if self.with_module else delta2dbbox_v3
        boxes_in = rois[:, 1:] if has_leading_col else rois
        refined = decode(boxes_in, bbox_pred, self.target_means,
                         self.target_stds, img_meta['img_shape'])
        # choose best Rroi
        refined = choose_best_Rroi_batch(refined)

        if has_leading_col:
            return torch.cat((rois[:, [0]], refined), dim=1)
        return refined
Exemple #4
0
 def get_ori_rbboxes(self, rois, delta_rbboxes, img_shape, v2=False):
     """Decode rotated-box deltas against the given RoIs.

     Column 0 of ``rois`` is dropped before decoding. ``delta2dbbox_v2`` is
     used when ``v2`` is truthy, ``delta2dbbox_v3`` otherwise; both receive
     ``self.target_means``, ``self.target_stds`` and ``img_shape``.

     Returns:
         The decoder's output, unchanged.
     """
     decode = delta2dbbox_v2 if v2 else delta2dbbox_v3
     return decode(rois[:, 1:], delta_rbboxes, self.target_means,
                   self.target_stds, img_shape)
Exemple #5
0
    def get_det_bboxes(self,
                       rois,
                       cls_score,
                       bbox_pred,
                       img_shape,
                       scale_factor,
                       rescale=False,
                       cfg=None):
        """Decode RoIs into rotated boxes and run multiclass rotated NMS.

        When ``mmdet.MARK.PRINT_RBBOX_HEAD_RS_LOSS`` is truthy, the scores of
        predictions whose arg-max class is not 0 are also printed for
        debugging.

        Args:
            rois (Tensor): 2-D tensor with 5 or 6 columns. Column 0 is
                dropped. With 5 columns the remaining 4 values are converted
                through ``hbb2obb_v2``; with 6 columns the remaining 5 values
                are used as they are.
            cls_score (Tensor | list[Tensor] | None): Classification logits.
                A list is averaged element-wise before the softmax over
                dim 1. ``None`` yields ``scores = None``.
            bbox_pred (Tensor | None): Regression deltas. When ``None`` the
                boxes derived from ``rois`` are used without decoding.
            img_shape: Forwarded unchanged to the delta decoder.
            scale_factor: Divisor applied to columns 0-3 of every 5-column
                group when ``rescale`` is true.
            rescale (bool): Whether to divide the boxes by ``scale_factor``.
            cfg: Test config exposing ``score_thr``, ``nms`` and
                ``max_per_img``. It is required in practice even though the
                signature defaults it to ``None``.

        Returns:
            tuple: ``(det_bboxes, det_labels)`` exactly as returned by
            ``multiclass_nms_rbbox``.

        Raises:
            ValueError: If ``rois`` has neither 5 nor 6 columns.
        """
        if isinstance(cls_score, list):
            cls_score = sum(cls_score) / float(len(cls_score))
        scores = F.softmax(cls_score, dim=1) if cls_score is not None else None

        num_cols = rois.size(1)
        if num_cols == 5:
            obbs = hbb2obb_v2(rois[:, 1:])
        elif num_cols == 6:
            obbs = rois[:, 1:]
        else:
            # Fail loudly instead of dropping into an interactive debugger,
            # which would hang any non-interactive run and leave `obbs`
            # undefined afterwards.
            raise ValueError(
                'rois must have 5 or 6 columns, got {}'.format(num_cols))

        if bbox_pred is not None:
            if self.with_module:
                dbboxes = delta2dbbox(obbs, bbox_pred, self.target_means,
                                      self.target_stds, img_shape)
            else:
                dbboxes = delta2dbbox_v3(obbs, bbox_pred, self.target_means,
                                         self.target_stds, img_shape)
        else:
            # `obbs` may be a view of `rois`; clone so the in-place rescale
            # below cannot modify the caller's tensor.
            dbboxes = obbs.clone()
            # TODO: add clip here

        if rescale:
            # Scale the first four values of every 5-column box group; the
            # fifth value of each group is left untouched.
            dbboxes[:, 0::5] /= scale_factor
            dbboxes[:, 1::5] /= scale_factor
            dbboxes[:, 2::5] /= scale_factor
            dbboxes[:, 3::5] /= scale_factor

        det_bboxes, det_labels = multiclass_nms_rbbox(dbboxes, scores,
                                                      cfg.score_thr, cfg.nms,
                                                      cfg.max_per_img)

        # Optional debug output, switched on by a project-level flag.
        from mmdet.MARK import PRINT_RBBOX_HEAD_RS_LOSS
        if PRINT_RBBOX_HEAD_RS_LOSS and scores is not None:
            pred_score, pred_label = torch.max(scores, dim=1)

            # Foreground predictions: arg-max class differs from 0.
            pred_f_indices = pred_label != 0

            if torch.sum(pred_f_indices) > 0:
                print('#' * 80)
                print('for pred score: ', pred_score[pred_f_indices])
                print('#' * 80)

        return det_bboxes, det_labels