コード例 #1
0
    def regress_by_class(self, rois, label, bbox_pred, img_meta):
        """Regress the bbox for the predicted class. Used in Cascade R-CNN.

        Args:
            rois (Tensor): shape (n, 4) or (n, 5) for 4-value (2D) boxes, or
                (n, 6) or (n, 7) for 6-value (3D) boxes. In the 5- and
                7-column forms the first column (presumably the batch index)
                is copied through unchanged and only the rest is regressed.
            label (Tensor): shape (n, ), integer class index per roi; only
                read when ``self.reg_class_agnostic`` is False.
            bbox_pred (Tensor): shape (n, box_dim * (#class+1)) or
                (n, box_dim), where box_dim is 4 for 4/5-column rois and 6
                for 6/7-column rois.
            img_meta (dict): Image meta info; ``img_meta['img_shape']`` is
                forwarded to ``delta2bbox3D``.

        Returns:
            Tensor: Regressed bboxes, the same shape as input rois.
        """
        num_cols = rois.size(1)
        assert num_cols in (4, 5, 6, 7)
        # 4/5 columns -> 4-value boxes; 6/7 columns -> 6-value boxes.
        # The odd widths carry one extra leading column.
        box_dim = 4 if num_cols in (4, 5) else 6
        has_extra_col = num_cols in (5, 7)

        if not self.reg_class_agnostic:
            # For each roi pick the box_dim consecutive deltas of its class:
            # columns label*box_dim .. label*box_dim + box_dim - 1.
            offsets = torch.arange(box_dim, device=label.device)
            inds = label.unsqueeze(1) * box_dim + offsets
            bbox_pred = torch.gather(bbox_pred, 1, inds)
        assert bbox_pred.size(1) == box_dim

        if has_extra_col:
            bboxes = delta2bbox3D(rois[:, 1:], bbox_pred, self.target_means,
                                  self.target_stds, img_meta['img_shape'])
            new_rois = torch.cat((rois[:, [0]], bboxes), dim=1)
        else:
            new_rois = delta2bbox3D(rois, bbox_pred, self.target_means,
                                    self.target_stds, img_meta['img_shape'])

        return new_rois
コード例 #2
0
    def get_det_bboxes(self,
                       rois,
                       cls_score,
                       bbox_pred,
                       img_shape,
                       scale_factor,
                       rescale=False,
                       cfg=None):
        """Decode box deltas into boxes and optionally run 3D multi-class NMS.

        Args:
            rois (Tensor): rois whose first column is dropped before decoding
                (presumably a batch index -- TODO confirm).
            cls_score (Tensor | list[Tensor] | None): classification logits.
                A list is averaged element-wise before the softmax over dim 1.
                When None, ``scores`` is None.
            bbox_pred (Tensor | None): regression deltas. When None, the rois
                (minus their first column) are used as the boxes.
            img_shape: forwarded to ``delta2bbox3D``.
            scale_factor: divisor applied to every box value when ``rescale``
                is True.
            rescale (bool): whether to divide the boxes by ``scale_factor``.
            cfg: when None, no NMS is run. Otherwise ``cfg.score_thr``,
                ``cfg.nms`` and ``cfg.max_per_img`` are passed to
                ``multiclass_nms_3d``.

        Returns:
            tuple: ``(bboxes, scores)`` when ``cfg`` is None, otherwise
            ``(det_bboxes, det_labels)`` from ``multiclass_nms_3d``.
        """
        if isinstance(cls_score, list):
            cls_score = sum(cls_score) / float(len(cls_score))
        scores = F.softmax(cls_score, dim=1) if cls_score is not None else None

        if bbox_pred is not None:
            bboxes = delta2bbox3D(rois[:, 1:], bbox_pred, self.target_means,
                                  self.target_stds, img_shape)
        else:
            # TODO: add clip here
            bboxes = rois[:, 1:]

        if rescale:
            # Out-of-place on purpose: when bbox_pred is None, ``bboxes`` is a
            # view of ``rois``, and ``/=`` would also rescale the caller's
            # rois.
            # TODO: only x/y should be rescaled; z (presumably columns 4:6 of
            # each 6-value box) should be left untouched.
            bboxes = bboxes / scale_factor

        if cfg is None:
            return bboxes, scores

        det_bboxes, det_labels = multiclass_nms_3d(bboxes, scores,
                                                   cfg.score_thr, cfg.nms,
                                                   cfg.max_per_img)
        return det_bboxes, det_labels
コード例 #3
0
    def get_det_bboxes(self,
                       rois,
                       bbox_pred,
                       img_shape,
                       scale_factor,
                       rescale=False,
                       cfg=None,
                       cls_score=None):
        """Decode box deltas into boxes and optionally run 3D multi-class NMS.

        Args:
            rois (Tensor): rois whose first column is dropped before decoding
                (presumably a batch index -- TODO confirm).
            bbox_pred (Tensor | None): regression deltas. When None, the rois
                (minus their first column) are used as the boxes.
            img_shape: forwarded to ``delta2bbox3D``.
            scale_factor: divisor applied to every box value when ``rescale``
                is True.
            rescale (bool): whether to divide the boxes by ``scale_factor``.
            cfg: when None, the boxes are returned directly. Otherwise
                ``cfg.score_thr``, ``cfg.nms`` and ``cfg.max_per_img`` are
                passed to ``multiclass_nms_3d``.
            cls_score (Tensor | list[Tensor] | None): classification logits,
                required when ``cfg`` is given. A list is averaged
                element-wise before the softmax over dim 1.

        Returns:
            Tensor | tuple: the boxes when ``cfg`` is None, otherwise
            ``(det_bboxes, det_labels)`` from ``multiclass_nms_3d``.

        Raises:
            ValueError: if ``cfg`` is given without ``cls_score``.
        """
        if bbox_pred is not None:
            bboxes = delta2bbox3D(rois[:, 1:], bbox_pred, self.target_means,
                                  self.target_stds, img_shape)
        else:
            bboxes = rois[:, 1:]

        if rescale:
            # Out-of-place: ``bboxes`` may be a view of ``rois``.
            bboxes = bboxes / scale_factor

        if cfg is None:
            return bboxes

        # NMS needs per-class scores; previously ``scores`` was never defined
        # here, so this branch always failed with a NameError.
        if cls_score is None:
            raise ValueError(
                'cls_score is required when cfg is given (NMS needs scores)')
        if isinstance(cls_score, list):
            cls_score = sum(cls_score) / float(len(cls_score))
        scores = cls_score.softmax(dim=1)

        det_bboxes, det_labels = multiclass_nms_3d(bboxes, scores,
                                                   cfg.score_thr, cfg.nms,
                                                   cfg.max_per_img)
        return det_bboxes, det_labels
コード例 #4
0
    def get_bboxes_single(self,
                          cls_scores,
                          bbox_preds,
                          mlvl_anchors,
                          img_shape,
                          scale_factor,
                          cfg,
                          rescale=False):
        """Turn per-level RPN outputs for one image into proposals.

        For every level:
        1. Flatten the scores and the 6-value box deltas.
        2. Optionally pre-filter them (``self.pos_indices`` or
           ``self.pos_indices_test``, then top-``cfg.nms_pre``).
        3. Decode the deltas with ``delta2bbox3D``.
        4. Drop boxes smaller than ``cfg.min_bbox_size``.
        5. Run ``nms``.

        The levels are then merged and capped at ``cfg.max_num``.

        Args:
            cls_scores (list[Tensor]): per-level classification outputs.
            bbox_preds (list[Tensor]): per-level regression outputs. Their
                last three dims must match those of ``cls_scores``.
            mlvl_anchors (list[Tensor]): per-level anchors.
            img_shape: forwarded to ``delta2bbox3D``.
            scale_factor: not used by this method.
            cfg: reads ``nms_pre``, ``min_bbox_size``, ``nms_thr``,
                ``nms_post``, ``nms_across_levels`` and ``max_num``.
            rescale (bool): not used by this method.

        Returns:
            tuple[Tensor, Tensor]: ``(proposals, anchors_levels)``.
            Column 6 of each proposal row is its score, presumably preceded
            by 6 box values. ``anchors_levels`` concatenates the anchors
            kept by each level's pre-NMS filtering, i.e. before the size
            filter and NMS, so it is not row-aligned with ``proposals``.
        """
        mlvl_proposals = []
        anchors_levels = []
        for idx in range(len(cls_scores)):
            rpn_cls_score = cls_scores[idx]
            rpn_bbox_pred = bbox_preds[idx]
            assert rpn_cls_score.size()[-3:] == rpn_bbox_pred.size()[-3:]
            anchors = mlvl_anchors[idx]
            # NOTE(review): permute(2, 3, 1, 0) assumes a specific 4-D layout
            # of each level's output -- confirm against the head's forward.
            rpn_cls_score = rpn_cls_score.permute(2, 3, 1, 0)
            if self.use_sigmoid_cls:
                rpn_cls_score = rpn_cls_score.reshape(-1)
                scores = rpn_cls_score.sigmoid()
            else:
                rpn_cls_score = rpn_cls_score.reshape(-1, 2)
                scores = rpn_cls_score.softmax(dim=1)[:, 1]
            rpn_bbox_pred = rpn_bbox_pred.permute(2, 3, 1, 0).reshape(-1, 6)
            if cfg.nms_pre > 0 and scores.shape[0] > cfg.nms_pre:
                # filter out all the negative anchors
                if self.pos_indices is not None and self.pos_indices[
                        idx].shape == scores.shape:
                    pos_indices = self.pos_indices[idx]
                    scores = scores[pos_indices]
                    rpn_bbox_pred = rpn_bbox_pred[pos_indices]
                    anchors = anchors[pos_indices]
                elif self.pos_indices_test is not None and self.pos_indices_test[
                        idx].shape == scores.shape:
                    pos_indices = self.pos_indices_test[idx]
                    scores = scores[pos_indices]
                    rpn_bbox_pred = rpn_bbox_pred[pos_indices]
                    anchors = anchors[pos_indices]

                if scores.shape[0] > cfg.nms_pre:
                    _, topk_inds = scores.topk(cfg.nms_pre)
                    rpn_bbox_pred = rpn_bbox_pred[topk_inds, :]
                    anchors = anchors[topk_inds, :]
                    scores = scores[topk_inds]

            proposals = delta2bbox3D(anchors, rpn_bbox_pred, self.target_means,
                                     self.target_stds, img_shape)
            if cfg.min_bbox_size > 0:
                # NOTE(review): only columns 0-3 are size-checked; any depth
                # extent is not filtered.
                w = proposals[:, 2] - proposals[:, 0] + 1
                h = proposals[:, 3] - proposals[:, 1] + 1
                # Boolean mask rather than nonzero().squeeze(): squeezing a
                # single surviving index gave a 0-d tensor, which dropped the
                # row dimension and broke the torch.cat below.
                valid = (w >= cfg.min_bbox_size) & (h >= cfg.min_bbox_size)
                proposals = proposals[valid, :]
                scores = scores[valid]
            proposals = torch.cat([proposals, scores.unsqueeze(-1)], dim=-1)
            proposals, _ = nms(proposals, cfg.nms_thr)
            proposals = proposals[:cfg.nms_post, :]
            mlvl_proposals.append(proposals)
            anchors_levels.append(anchors)
        anchors_levels = torch.cat(anchors_levels, 0)
        proposals = torch.cat(mlvl_proposals, 0)
        if cfg.nms_across_levels:
            proposals, _ = nms(proposals, cfg.nms_thr)
            proposals = proposals[:cfg.max_num, :]
        else:
            scores = proposals[:, 6]
            num = min(cfg.max_num, proposals.shape[0])
            _, topk_inds = scores.topk(num)
            proposals = proposals[topk_inds, :]
        return proposals, anchors_levels
コード例 #5
0
 def convert_adjustments_to_bboxes(self, rois, bbox_pred, img_shape):
     """Decode ``bbox_pred`` deltas against ``rois`` into boxes.

     The first column of ``rois`` is dropped before decoding (presumably a
     batch index -- TODO confirm). Decoding is delegated to ``delta2bbox3D``
     with this object's ``target_means`` and ``target_stds``.
     """
     boxes_without_idx = rois[:, 1:]
     means, stds = self.target_means, self.target_stds
     return delta2bbox3D(boxes_without_idx, bbox_pred, means, stds, img_shape)