Example #1
0
    def regress_by_class(self, rois, label, bbox_pred, img_meta):
        """Decode class-specific box deltas into refined RoIs (Cascade R-CNN).

        Args:
            rois (Tensor): RoIs of shape (n, 4), or (n, 5) where column 0 is
                carried over untouched and columns 1: are the boxes.
            label (Tensor): Predicted class index per RoI, shape (n, ).
            bbox_pred (Tensor): Box deltas of shape (n, 4 * (#class + 1)), or
                (n, 4) when ``self.reg_class_agnostic`` is set.
            img_meta (dict): Image meta info; its ``'img_shape'`` entry is
                forwarded to ``delta2bbox``.

        Returns:
            Tensor: Regressed RoIs with the same shape as ``rois``.
        """
        num_cols = rois.size(1)
        assert num_cols == 4 or num_cols == 5, repr(rois.shape)

        if not self.reg_class_agnostic:
            # Select the four delta columns that belong to each RoI's class.
            start = label * 4
            cols = torch.stack([start + offset for offset in range(4)], 1)
            bbox_pred = torch.gather(bbox_pred, 1, cols)
        assert bbox_pred.size(1) == 4

        has_extra_col = num_cols == 5
        boxes = rois[:, 1:] if has_extra_col else rois
        decoded = delta2bbox(boxes, bbox_pred, self.target_means,
                             self.target_stds, img_meta['img_shape'])
        if not has_extra_col:
            return decoded
        # Re-attach the leading column (kept as 2-D via the list index).
        return torch.cat((rois[:, [0]], decoded), dim=1)
Example #2
0
    def loss_single(self, anchors, cls_score, bbox_pred, centerness, labels,
                    label_weights, bbox_targets, num_total_samples, cfg):
        """Compute classification, box and centerness losses for one level.

        Args:
            anchors (Tensor): Anchors of this level, reshaped to (-1, 4).
            cls_score (Tensor): Classification logits, 4-D with channels in
                dim 1; flattened to (-1, self.cls_out_channels).
            bbox_pred (Tensor): Box deltas, 4-D with channels in dim 1;
                flattened to (-1, 4).
            centerness (Tensor): Centerness predictions, 4-D with channels
                in dim 1; flattened to (-1, ).
            labels (Tensor): Per-anchor labels; non-zero entries are treated
                as positives.
            label_weights (Tensor): Per-anchor classification weights.
            bbox_targets (Tensor): Encoded (delta) regression targets,
                reshaped to (-1, 4).
            num_total_samples (int | float): ``avg_factor`` for the
                classification and centerness losses.
            cfg: Unused here.

        Returns:
            tuple: ``(loss_cls, loss_bbox, loss_centerness,
            centerness_targets.sum())``. Without positives, the box and
            centerness losses are zero (still attached to the graph through
            ``* 0``). The centerness sum is then a zero tensor on the same
            device as ``bbox_targets``.
        """
        anchors = anchors.reshape(-1, 4)
        cls_score = cls_score.permute(0, 2, 3,
                                      1).reshape(-1, self.cls_out_channels)
        bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
        centerness = centerness.permute(0, 2, 3, 1).reshape(-1)
        bbox_targets = bbox_targets.reshape(-1, 4)
        labels = labels.reshape(-1)
        label_weights = label_weights.reshape(-1)

        # classification loss
        loss_cls = self.loss_cls(cls_score,
                                 labels,
                                 label_weights,
                                 avg_factor=num_total_samples)

        # Anchors with a non-zero label are positives.
        pos_inds = torch.nonzero(labels).squeeze(1)

        if pos_inds.numel() > 0:
            pos_bbox_targets = bbox_targets[pos_inds]
            pos_bbox_pred = bbox_pred[pos_inds]
            pos_anchors = anchors[pos_inds]
            pos_centerness = centerness[pos_inds]

            centerness_targets = self.centerness_target(
                pos_anchors, pos_bbox_targets)
            # Decode both predictions and targets so the box loss compares
            # absolute boxes rather than deltas.
            pos_decode_bbox_pred = delta2bbox(pos_anchors, pos_bbox_pred,
                                              self.target_means,
                                              self.target_stds)
            pos_decode_bbox_targets = delta2bbox(pos_anchors, pos_bbox_targets,
                                                 self.target_means,
                                                 self.target_stds)

            # regression loss, weighted per positive by its centerness target
            loss_bbox = self.loss_bbox(pos_decode_bbox_pred,
                                       pos_decode_bbox_targets,
                                       weight=centerness_targets,
                                       avg_factor=1.0)

            # centerness loss
            loss_centerness = self.loss_centerness(
                pos_centerness,
                centerness_targets,
                avg_factor=num_total_samples)

        else:
            loss_bbox = bbox_pred.sum() * 0
            loss_centerness = centerness.sum() * 0
            # Allocate on the inputs' device/dtype. A hard-coded ``.cuda()``
            # crashes on CPU-only runs.
            centerness_targets = bbox_targets.new_tensor(0.)

        return loss_cls, loss_bbox, loss_centerness, centerness_targets.sum()
Example #3
0
    def get_bboxes_single(self,
                          cls_scores,
                          bbox_preds,
                          centernesses,
                          mlvl_anchors,
                          img_shape,
                          scale_factor,
                          cfg,
                          rescale=False):
        """Decode one image's multi-level predictions into final detections.

        Args:
            cls_scores (list[Tensor]): Per-level classification logits of a
                single image, channel-first.
            bbox_preds (list[Tensor]): Per-level box deltas, channel-first.
            centernesses (list[Tensor]): Per-level centerness logits,
                channel-first.
            mlvl_anchors (list[Tensor]): Per-level anchors aligned with the
                flattened predictions.
            img_shape (tuple): Image shape forwarded to ``delta2bbox``.
            scale_factor: Divisor applied to the boxes when ``rescale`` is
                True.
            cfg (dict): Test config; reads ``nms_pre``, ``score_thr``,
                ``nms`` and ``max_per_img``.
            rescale (bool): Whether to divide boxes by ``scale_factor``.

        Returns:
            tuple: ``(det_bboxes, det_labels)`` from ``multiclass_nms``, with
            the centerness passed as ``score_factors``.
        """
        assert len(cls_scores) == len(bbox_preds) == len(mlvl_anchors)
        boxes_per_level, scores_per_level, ctr_per_level = [], [], []
        for logit_map, delta_map, ctr_map, anchors in zip(
                cls_scores, bbox_preds, centernesses, mlvl_anchors):
            assert logit_map.size()[-2:] == delta_map.size()[-2:]

            # Move channels last and flatten to one row per prediction.
            probs = logit_map.permute(1, 2, 0).reshape(
                -1, self.cls_out_channels).sigmoid()
            deltas = delta_map.permute(1, 2, 0).reshape(-1, 4)
            ctr = ctr_map.permute(1, 2, 0).reshape(-1).sigmoid()

            nms_pre = cfg.get('nms_pre', -1)
            if 0 < nms_pre < probs.shape[0]:
                # Pre-select by the centerness-weighted best class score.
                ranking = (probs * ctr[:, None]).max(dim=1)[0]
                keep = ranking.topk(nms_pre)[1]
                anchors = anchors[keep, :]
                deltas = deltas[keep, :]
                probs = probs[keep, :]
                ctr = ctr[keep]

            boxes_per_level.append(
                delta2bbox(anchors, deltas, self.target_means,
                           self.target_stds, img_shape))
            scores_per_level.append(probs)
            ctr_per_level.append(ctr)

        all_boxes = torch.cat(boxes_per_level)
        if rescale:
            all_boxes /= all_boxes.new_tensor(scale_factor)

        all_scores = torch.cat(scores_per_level)
        # Sigmoid scores carry no background column; prepend one of zeros.
        bg_column = all_scores.new_zeros(all_scores.shape[0], 1)
        all_scores = torch.cat([bg_column, all_scores], dim=1)
        all_ctr = torch.cat(ctr_per_level)

        det_bboxes, det_labels = multiclass_nms(all_boxes,
                                                all_scores,
                                                cfg.score_thr,
                                                cfg.nms,
                                                cfg.max_per_img,
                                                score_factors=all_ctr)
        return det_bboxes, det_labels
Example #4
0
 def get_bboxes_single(self,
                       cls_score_list,
                       bbox_pred_list,
                       mlvl_anchors,
                       img_shape,
                       scale_factor,
                       cfg,
                       rescale=False):
     """Transform outputs for a single batch item into labeled boxes.

     Args:
         cls_score_list (list[Tensor]): Per-level classification maps of one
             image, channel-first; flattened to (-1, self.cls_out_channels).
         bbox_pred_list (list[Tensor]): Per-level box deltas of one image,
             flattened to (-1, 4).
         mlvl_anchors (list[Tensor]): Per-level anchors aligned with the
             flattened predictions.
         img_shape (tuple): Image shape forwarded to ``delta2bbox``.
         scale_factor: Divisor applied to the boxes when ``rescale`` is True.
         cfg (dict): Test config; reads ``nms_pre``, ``score_thr``, ``nms``
             and ``max_per_img``.
         rescale (bool): Whether to divide boxes by ``scale_factor``.

     Returns:
         tuple: ``(det_bboxes, det_labels)`` from ``multiclass_nms``.
     """
     assert len(cls_score_list) == len(bbox_pred_list) == len(mlvl_anchors)
     per_level_boxes = []
     per_level_scores = []
     for level_score, level_delta, anchors in zip(cls_score_list,
                                                  bbox_pred_list,
                                                  mlvl_anchors):
         assert level_score.size()[-2:] == level_delta.size()[-2:]
         logits = level_score.permute(1, 2, 0).reshape(
             -1, self.cls_out_channels)
         if self.use_sigmoid_cls:
             scores = logits.sigmoid()
         else:
             scores = logits.softmax(-1)
         deltas = level_delta.permute(1, 2, 0).reshape(-1, 4)

         nms_pre = cfg.get('nms_pre', -1)
         if 0 < nms_pre < scores.shape[0]:
             # Rank by the best foreground score; with softmax, column 0 is
             # the background class and is left out of the ranking.
             if self.use_sigmoid_cls:
                 fg_scores = scores
             else:
                 fg_scores = scores[:, 1:]
             topk_inds = fg_scores.max(dim=1)[0].topk(nms_pre)[1]
             anchors = anchors[topk_inds, :]
             deltas = deltas[topk_inds, :]
             scores = scores[topk_inds, :]

         per_level_boxes.append(
             delta2bbox(anchors, deltas, self.target_means,
                        self.target_stds, img_shape))
         per_level_scores.append(scores)

     all_boxes = torch.cat(per_level_boxes)
     if rescale:
         all_boxes /= all_boxes.new_tensor(scale_factor)
     all_scores = torch.cat(per_level_scores)
     if self.use_sigmoid_cls:
         # Sigmoid scores carry no background column; prepend a zero one.
         bg_column = all_scores.new_zeros(all_scores.shape[0], 1)
         all_scores = torch.cat([bg_column, all_scores], dim=1)
     det_bboxes, det_labels = multiclass_nms(all_boxes, all_scores,
                                             cfg.score_thr, cfg.nms,
                                             cfg.max_per_img)
     return det_bboxes, det_labels
 def get_bboxes_single(self,
                       cls_scores,
                       bbox_preds,
                       mlvl_anchors,
                       img_shape,
                       scale_factor,
                       cfg,
                       rescale=False):
     """Generate RPN proposals for a single image.

     Args:
         cls_scores (list[Tensor]): Per-level objectness maps, channel-first.
         bbox_preds (list[Tensor]): Per-level box deltas, channel-first.
         mlvl_anchors (list[Tensor]): Per-level anchors aligned with the
             flattened predictions.
         img_shape (tuple): Image shape forwarded to ``delta2bbox``.
         scale_factor: Unused here.
         cfg (dict): Test config; reads ``nms_pre``, ``min_bbox_size``,
             ``nms_thr``, ``nms_post``, ``nms_across_levels`` and
             ``max_num``.
         rescale (bool): Unused here.

     Returns:
         Tensor: Proposals of shape (k, 5): four box coordinates followed by
         the score.
     """
     assert len(cls_scores) == len(bbox_preds) == len(mlvl_anchors)
     mlvl_proposals = []
     for rpn_cls_score, rpn_bbox_pred, anchors in zip(cls_scores, bbox_preds,
                                                      mlvl_anchors):
         assert rpn_cls_score.size()[-2:] == rpn_bbox_pred.size()[-2:]
         rpn_cls_score = rpn_cls_score.permute(1, 2, 0)
         if self.use_sigmoid_cls:
             rpn_cls_score = rpn_cls_score.reshape(-1)
             scores = rpn_cls_score.sigmoid()
         else:
             rpn_cls_score = rpn_cls_score.reshape(-1, 2)
             # Column 1 of the two-way softmax is used as the score.
             scores = rpn_cls_score.softmax(dim=1)[:, 1]
         rpn_bbox_pred = rpn_bbox_pred.permute(1, 2, 0).reshape(-1, 4)
         if cfg.nms_pre > 0 and scores.shape[0] > cfg.nms_pre:
             _, topk_inds = scores.topk(cfg.nms_pre)
             rpn_bbox_pred = rpn_bbox_pred[topk_inds, :]
             anchors = anchors[topk_inds, :]
             scores = scores[topk_inds]
         proposals = delta2bbox(anchors, rpn_bbox_pred, self.target_means,
                                self.target_stds, img_shape)
         if cfg.min_bbox_size > 0:
             w = proposals[:, 2] - proposals[:, 0] + 1
             h = proposals[:, 3] - proposals[:, 1] + 1
             # A boolean mask keeps ``proposals`` 2-D even when exactly one
             # box survives. ``nonzero(...).squeeze()`` would yield a 0-d
             # index there and drop the row dimension, breaking the cat below.
             valid_mask = ((w >= cfg.min_bbox_size) &
                           (h >= cfg.min_bbox_size))
             proposals = proposals[valid_mask]
             scores = scores[valid_mask]
         proposals = torch.cat([proposals, scores.unsqueeze(-1)], dim=-1)
         proposals, _ = nms(proposals, cfg.nms_thr)
         proposals = proposals[:cfg.nms_post, :]
         mlvl_proposals.append(proposals)
     proposals = torch.cat(mlvl_proposals, 0)
     if cfg.nms_across_levels:
         proposals, _ = nms(proposals, cfg.nms_thr)
         proposals = proposals[:cfg.max_num, :]
     else:
         scores = proposals[:, 4]
         num = min(cfg.max_num, proposals.shape[0])
         _, topk_inds = scores.topk(num)
         proposals = proposals[topk_inds, :]
     return proposals
Example #6
0
    def centerness_target(self, anchors, bbox_targets):
        """Compute centerness targets for (positive) anchors.

        ``bbox_targets`` are decoded back into boxes with ``delta2bbox``. For
        each anchor center, the distances to the decoded box edges are
        combined as ``sqrt(min(l, r) / max(l, r) * min(t, b) / max(t, b))``.

        Args:
            anchors (Tensor): Anchors whose columns 0/2 are x-coordinates and
                1/3 are y-coordinates.
            bbox_targets (Tensor): Encoded regression targets for ``anchors``.

        Returns:
            Tensor: One centerness value per anchor.

        Only positive samples should be passed; otherwise the ratios may be
        NaN, which the final assertion rejects.
        """
        decoded = delta2bbox(anchors, bbox_targets, self.target_means,
                             self.target_stds)
        ctr_x = (anchors[:, 2] + anchors[:, 0]) / 2
        ctr_y = (anchors[:, 3] + anchors[:, 1]) / 2

        horizontal = torch.stack(
            [ctr_x - decoded[:, 0], decoded[:, 2] - ctr_x], dim=1)
        vertical = torch.stack(
            [ctr_y - decoded[:, 1], decoded[:, 3] - ctr_y], dim=1)

        def _ratio(pair):
            # Smaller distance over larger distance, row by row.
            return pair.min(dim=-1)[0] / pair.max(dim=-1)[0]

        centerness = torch.sqrt(_ratio(horizontal) * _ratio(vertical))
        assert not torch.isnan(centerness).any()
        return centerness
Example #7
0
    def get_det_bboxes(self,
                       rois,
                       cls_score,
                       bbox_pred,
                       img_shape,
                       scale_factor,
                       rescale=False,
                       cfg=None):
        """Convert RoI head outputs into (optionally NMS-filtered) detections.

        Args:
            rois (Tensor): RoIs whose columns 1: are the boxes.
            cls_score (Tensor | list[Tensor] | None): Classification logits;
                a list is averaged before the softmax.
            bbox_pred (Tensor | None): Box deltas. When None, the RoI boxes
                are used directly and clipped to ``img_shape``.
            img_shape (tuple | None): ``img_shape[1] - 1`` bounds columns 0/2
                and ``img_shape[0] - 1`` bounds columns 1/3 when clipping.
            scale_factor (float | int | ndarray): Divisor applied when
                ``rescale`` is True. A scalar divides every coordinate. An
                ndarray is broadcast over each group of 4 coordinates.
            rescale (bool): Whether to divide boxes by ``scale_factor``.
            cfg (dict | None): Test config. When None, raw boxes and scores
                are returned without NMS.

        Returns:
            tuple: ``(bboxes, scores)`` when ``cfg`` is None, otherwise
            ``(det_bboxes, det_labels)`` from ``multiclass_nms``.
        """
        if isinstance(cls_score, list):
            cls_score = sum(cls_score) / float(len(cls_score))
        scores = F.softmax(cls_score, dim=1) if cls_score is not None else None

        if bbox_pred is not None:
            bboxes = delta2bbox(rois[:, 1:], bbox_pred, self.target_means,
                                self.target_stds, img_shape)
        else:
            bboxes = rois[:, 1:].clone()
            if img_shape is not None:
                # Indexing with a list returns a copy, so an in-place
                # ``clamp_`` on ``bboxes[:, [0, 2]]`` would be discarded.
                # Write the clamped values back explicitly.
                bboxes[:, [0, 2]] = bboxes[:, [0, 2]].clamp(
                    min=0, max=img_shape[1] - 1)
                bboxes[:, [1, 3]] = bboxes[:, [1, 3]].clamp(
                    min=0, max=img_shape[0] - 1)

        if rescale:
            if isinstance(scale_factor, (int, float)):
                bboxes /= scale_factor
            else:
                scale_factor = torch.from_numpy(scale_factor).to(bboxes.device)
                bboxes = (bboxes.view(bboxes.size(0), -1, 4) /
                          scale_factor).view(bboxes.size()[0], -1)

        if cfg is None:
            return bboxes, scores
        det_bboxes, det_labels = multiclass_nms(bboxes, scores,
                                                cfg.score_thr, cfg.nms,
                                                cfg.max_per_img)
        return det_bboxes, det_labels
Example #8
0
 def loss_shape_single(self, shape_pred, bbox_anchors, bbox_gts,
                       anchor_weights, anchor_total_num):
     """Compute the anchor-shape loss (``self.loss_shape``) for one level.

     The shape predictions fill the last two columns of an otherwise zero
     delta tensor (presumably dw/dh in ``delta2bbox``'s layout). That tensor
     is decoded against the anchors and compared with the target boxes.

     Args:
         shape_pred (Tensor): Shape predictions, 4-D with channels in dim 1;
             flattened channel-last to (-1, 2).
         bbox_anchors (Tensor): Anchors, flattened to (-1, 4).
         bbox_gts (Tensor): Target boxes, flattened to (-1, 4).
         anchor_weights (Tensor): Per-anchor weights, flattened to (-1, 4).
             Rows whose first weight is not positive are skipped.
         anchor_total_num (int | float): ``avg_factor`` for the loss.

     Returns:
         Tensor: Output of ``self.loss_shape``.
     """
     shape_flat = shape_pred.permute(0, 2, 3, 1).contiguous().view(-1, 2)
     anchors = bbox_anchors.contiguous().view(-1, 4)
     targets = bbox_gts.contiguous().view(-1, 4)
     weights = anchor_weights.contiguous().view(-1, 4)

     deltas = anchors.new_full(anchors.size(), 0)
     deltas[:, 2:] += shape_flat

     # Drop negative samples (first weight <= 0) up front to make the loss
     # cheaper to evaluate.
     keep = torch.nonzero(weights[:, 0] > 0).squeeze(1)
     decoded = delta2bbox(anchors[keep],
                          deltas[keep],
                          self.anchoring_means,
                          self.anchoring_stds,
                          wh_ratio_clip=1e-6)
     return self.loss_shape(decoded,
                            targets[keep],
                            weights[keep],
                            avg_factor=anchor_total_num)
Example #9
0
    def get_guided_anchors_single(self,
                                  squares,
                                  shape_pred,
                                  loc_pred,
                                  use_loc_filter=False):
        """Get guided anchors and loc masks for a single level.

        Args:
            squares (Tensor): Square anchors of a single level, one row per
                entry of the flattened ``mask``.
            shape_pred (Tensor): Shape predictions of a single level,
                channel-first; flattened channel-last to (-1, 2).
            loc_pred (Tensor): Location logits of a single level,
                channel-first. After moving channels last, it is expanded
                across ``self.num_anchors``, so it must have 1 (or
                ``num_anchors``) channels.
            use_loc_filter (bool): If True, keep only locations whose sigmoid
                location score is ``>= self.loc_filter_thr``. Otherwise keep
                every location.

        Returns:
            tuple[Tensor, Tensor]: ``(guided_anchors, mask)``. ``mask`` is
            the flattened boolean keep-mask, and ``guided_anchors`` are the
            decoded anchors for the kept rows.
        """
        # calculate location filtering mask (detached: no gradient flows)
        loc_pred = loc_pred.sigmoid().detach()
        if use_loc_filter:
            loc_mask = loc_pred >= self.loc_filter_thr
        else:
            # Sigmoid outputs are >= 0, so this keeps every location.
            loc_mask = loc_pred >= 0.0
        mask = loc_mask.permute(1, 2, 0).expand(-1, -1, self.num_anchors)
        mask = mask.contiguous().view(-1)
        # calculate guided anchors
        squares = squares[mask]
        anchor_deltas = shape_pred.permute(1, 2, 0).contiguous().view(
            -1, 2).detach()[mask]
        # Only the last two delta columns come from the shape prediction.
        # The first two stay zero (presumably dx/dy, keeping the square's
        # center; assumes delta2bbox's (dx, dy, dw, dh) layout).
        bbox_deltas = anchor_deltas.new_full(squares.size(), 0)
        bbox_deltas[:, 2:] = anchor_deltas
        guided_anchors = delta2bbox(squares,
                                    bbox_deltas,
                                    self.anchoring_means,
                                    self.anchoring_stds,
                                    wh_ratio_clip=1e-6)
        return guided_anchors, mask
Example #10
0
 def get_bboxes_single(self,
                       cls_scores,
                       bbox_preds,
                       mlvl_anchors,
                       mlvl_masks,
                       img_shape,
                       scale_factor,
                       cfg,
                       rescale=False):
     """Decode guided-anchor head outputs of one image into detections.

     Args:
         cls_scores (list[Tensor]): Per-level classification maps,
             channel-first; flattened to (-1, self.cls_out_channels).
         bbox_preds (list[Tensor]): Per-level box deltas, flattened to
             (-1, 4).
         mlvl_anchors (list[Tensor]): Per-level anchors, already restricted
             to the locations selected by the matching mask.
         mlvl_masks (list[Tensor]): Per-level masks over the flattened
             predictions. Levels whose mask sums to 0 are skipped.
         img_shape (tuple): Image shape forwarded to ``delta2bbox``.
         scale_factor: Divisor applied to the boxes when ``rescale`` is True.
         cfg (dict): Test config; reads ``nms_pre``, ``score_thr``, ``nms``
             and ``max_per_img``.
         rescale (bool): Whether to divide boxes by ``scale_factor``.

     Returns:
         tuple: ``(det_bboxes, det_labels)`` from ``multiclass_nms``.
     """
     assert len(cls_scores) == len(bbox_preds) == len(mlvl_anchors)
     mlvl_bboxes = []
     mlvl_scores = []
     for cls_score, bbox_pred, anchors, mask in zip(cls_scores, bbox_preds,
                                                    mlvl_anchors,
                                                    mlvl_masks):
         assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
         # if no location is kept, skip this level.
         if mask.sum() == 0:
             continue
         # reshape scores and bbox_pred
         cls_score = cls_score.permute(1, 2,
                                       0).reshape(-1, self.cls_out_channels)
         if self.use_sigmoid_cls:
             scores = cls_score.sigmoid()
         else:
             scores = cls_score.softmax(-1)
         bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
         # filter scores, bbox_pred w.r.t. mask.
         # anchors are filtered in get_anchors() beforehand.
         # Row-selecting a 2-D tensor with ``[mask, :]`` always yields a
         # result of at least 1 dim, so no 0-d special case is needed.
         scores = scores[mask, :]
         bbox_pred = bbox_pred[mask, :]
         # filter anchors, bbox_pred, scores w.r.t. scores
         nms_pre = cfg.get('nms_pre', -1)
         if nms_pre > 0 and scores.shape[0] > nms_pre:
             # Rank by best score; softmax column 0 is left out.
             if self.use_sigmoid_cls:
                 max_scores, _ = scores.max(dim=1)
             else:
                 max_scores, _ = scores[:, 1:].max(dim=1)
             _, topk_inds = max_scores.topk(nms_pre)
             anchors = anchors[topk_inds, :]
             bbox_pred = bbox_pred[topk_inds, :]
             scores = scores[topk_inds, :]
         bboxes = delta2bbox(anchors, bbox_pred, self.target_means,
                             self.target_stds, img_shape)
         mlvl_bboxes.append(bboxes)
         mlvl_scores.append(scores)
     if not mlvl_bboxes:
         # Every level was masked out. Feed correctly-shaped empty tensors
         # forward instead of letting torch.cat([]) raise.
         # NOTE(review): relies on multiclass_nms accepting zero boxes.
         ref = cls_scores[0]
         mlvl_bboxes.append(ref.new_zeros((0, 4)))
         mlvl_scores.append(ref.new_zeros((0, self.cls_out_channels)))
     mlvl_bboxes = torch.cat(mlvl_bboxes)
     if rescale:
         mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor)
     mlvl_scores = torch.cat(mlvl_scores)
     if self.use_sigmoid_cls:
         # Sigmoid scores carry no background column; prepend a zero one.
         padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
         mlvl_scores = torch.cat([padding, mlvl_scores], dim=1)
     # multi class NMS
     det_bboxes, det_labels = multiclass_nms(mlvl_bboxes, mlvl_scores,
                                             cfg.score_thr, cfg.nms,
                                             cfg.max_per_img)
     return det_bboxes, det_labels
Example #11
0
    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             cfg,
             gt_bboxes_ignore=None):
        """Compute the FreeAnchor positive and negative bag losses.

        Args:
            cls_scores (list[Tensor]): Per-level classification logits. Each
                is permuted channel-last and reshaped to
                (batch, -1, self.cls_out_channels).
            bbox_preds (list[Tensor]): Per-level box deltas, reshaped to
                (batch, -1, 4).
            gt_bboxes (list[Tensor]): Ground-truth boxes per image.
            gt_labels (list[Tensor]): Ground-truth labels per image. They are
                shifted down by 1 to index the columns of the class
                probabilities. The caller's tensors are not modified.
            img_metas (list[dict]): Image meta info, passed to get_anchors.
            cfg: Unused here.
            gt_bboxes_ignore: Unused here.

        Returns:
            dict: ``positive_bag_loss`` and ``negative_bag_loss``.
        """
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == len(self.anchor_generators)

        anchor_list, _ = self.get_anchors(featmap_sizes, img_metas)
        anchors = [torch.cat(anchor) for anchor in anchor_list]

        # concatenate each level
        cls_scores = [
            cls.permute(0, 2, 3, 1).reshape(cls.size(0), -1,
                                            self.cls_out_channels)
            for cls in cls_scores
        ]
        bbox_preds = [
            bbox_pred.permute(0, 2, 3, 1).reshape(bbox_pred.size(0), -1, 4)
            for bbox_pred in bbox_preds
        ]
        cls_scores = torch.cat(cls_scores, dim=1)
        bbox_preds = torch.cat(bbox_preds, dim=1)

        cls_prob = torch.sigmoid(cls_scores)
        box_prob = []
        num_pos = 0
        positive_losses = []
        for anchors_, gt_labels_, gt_bboxes_, cls_prob_, bbox_preds_ in zip(
                anchors, gt_labels, gt_bboxes, cls_prob, bbox_preds):
            # Shift labels to 0-based column indices into ``cls_prob``. This
            # is out-of-place on purpose: ``gt_labels_ -= 1`` would modify
            # the caller's label tensors and shift them again on every call.
            gt_labels_ = gt_labels_ - 1

            with torch.no_grad():
                # box_localization: a_{j}^{loc}, shape: [j, 4]
                pred_boxes = delta2bbox(anchors_, bbox_preds_,
                                        self.target_means, self.target_stds)

                # object_box_iou: IoU_{ij}^{loc}, shape: [i, j]
                object_box_iou = bbox_overlaps(gt_bboxes_, pred_boxes)

                # object_box_prob: P{a_{j} -> b_{i}}, shape: [i, j]
                t1 = self.bbox_thr
                t2 = object_box_iou.max(
                    dim=1, keepdim=True).values.clamp(min=t1 + 1e-12)
                object_box_prob = ((object_box_iou - t1) / (t2 - t1)).clamp(
                    min=0, max=1)

                # object_cls_box_prob: P{a_{j} -> b_{i}}, shape: [i, c, j]
                num_obj = gt_labels_.size(0)
                indices = torch.stack(
                    [torch.arange(num_obj).type_as(gt_labels_), gt_labels_],
                    dim=0)
                object_cls_box_prob = torch.sparse_coo_tensor(
                    indices, object_box_prob)

                # image_box_iou: P{a_{j} \in A_{+}}, shape: [c, j]
                # The "start" ... "end" section below computes the same
                # result as:
                #     image_box_iou = torch.sparse.max(object_cls_box_prob,
                #                                      dim=0).t()
                # start
                box_cls_prob = torch.sparse.sum(object_cls_box_prob,
                                                dim=0).to_dense()

                indices = torch.nonzero(box_cls_prob).t_()
                if indices.numel() == 0:
                    image_box_prob = torch.zeros(
                        anchors_.size(0),
                        self.cls_out_channels).type_as(object_box_prob)
                else:
                    nonzero_box_prob = torch.where(
                        (gt_labels_.unsqueeze(dim=-1) == indices[0]),
                        object_box_prob[:, indices[1]],
                        torch.tensor(
                            [0]).type_as(object_box_prob)).max(dim=0).values

                    # upmap to shape [j, c]
                    image_box_prob = torch.sparse_coo_tensor(
                        indices.flip([0]),
                        nonzero_box_prob,
                        size=(anchors_.size(0),
                              self.cls_out_channels)).to_dense()
                # end

                box_prob.append(image_box_prob)

            # construct bags for objects
            match_quality_matrix = bbox_overlaps(gt_bboxes_, anchors_)
            _, matched = torch.topk(match_quality_matrix,
                                    self.pre_anchor_topk,
                                    dim=1,
                                    sorted=False)
            del match_quality_matrix

            # matched_cls_prob: P_{ij}^{cls}
            matched_cls_prob = torch.gather(
                cls_prob_[matched], 2,
                gt_labels_.view(-1, 1, 1).repeat(1, self.pre_anchor_topk,
                                                 1)).squeeze(2)

            # matched_box_prob: P_{ij}^{loc}
            matched_anchors = anchors_[matched]
            matched_object_targets = bbox2delta(
                matched_anchors,
                gt_bboxes_.unsqueeze(dim=1).expand_as(matched_anchors),
                self.target_means, self.target_stds)
            loss_bbox = self.loss_bbox(bbox_preds_[matched],
                                       matched_object_targets,
                                       reduction_override='none').sum(-1)
            matched_box_prob = torch.exp(-loss_bbox)

            # positive_losses: {-log( Mean-max(P_{ij}^{cls} * P_{ij}^{loc}) )}
            num_pos += len(gt_bboxes_)
            positive_losses.append(
                self.positive_bag_loss(matched_cls_prob, matched_box_prob))
        positive_loss = torch.cat(positive_losses).sum() / max(1, num_pos)

        # box_prob: P{a_{j} \in A_{+}}
        box_prob = torch.stack(box_prob, dim=0)

        # negative_loss:
        # \sum_{j}{ FL((1 - P{a_{j} \in A_{+}}) * (1 - P_{j}^{bg})) } / n||B||
        negative_loss = self.negative_bag_loss(cls_prob, box_prob).sum() / max(
            1, num_pos * self.pre_anchor_topk)

        losses = {
            'positive_bag_loss': positive_loss,
            'negative_bag_loss': negative_loss
        }
        return losses
Example #12
0
 def get_bboxes_single(self,
                       cls_scores,
                       bbox_preds,
                       mlvl_anchors,
                       mlvl_masks,
                       img_shape,
                       scale_factor,
                       cfg,
                       rescale=False):
     """Generate guided-anchor RPN proposals for a single image.

     Args:
         cls_scores (list[Tensor]): Per-level objectness maps, channel-first.
         bbox_preds (list[Tensor]): Per-level box deltas, channel-first.
         mlvl_anchors (list[Tensor]): Per-level anchors, already restricted
             to the locations selected by the matching mask.
         mlvl_masks (list[Tensor]): Per-level masks over the flattened
             predictions. Levels whose mask sums to 0 are skipped.
         img_shape (tuple): Image shape forwarded to ``delta2bbox``.
         scale_factor: Unused here.
         cfg (dict): Test config; reads ``nms_pre``, ``min_bbox_size``,
             ``nms_thr``, ``nms_post``, ``nms_across_levels`` and
             ``max_num``.
         rescale (bool): Unused here.

     Returns:
         Tensor: Proposals of shape (k, 5): four box coordinates followed by
         the score. Empty (0, 5) when every level is masked out.
     """
     assert (len(cls_scores) == len(bbox_preds) == len(mlvl_anchors) ==
             len(mlvl_masks))
     mlvl_proposals = []
     for rpn_cls_score, rpn_bbox_pred, anchors, mask in zip(
             cls_scores, bbox_preds, mlvl_anchors, mlvl_masks):
         assert rpn_cls_score.size()[-2:] == rpn_bbox_pred.size()[-2:]
         # if no location is kept, skip this level.
         if mask.sum() == 0:
             continue
         rpn_cls_score = rpn_cls_score.permute(1, 2, 0)
         if self.use_sigmoid_cls:
             rpn_cls_score = rpn_cls_score.reshape(-1)
             scores = rpn_cls_score.sigmoid()
         else:
             rpn_cls_score = rpn_cls_score.reshape(-1, 2)
             # Column 1 of the two-way softmax is used as the score.
             scores = rpn_cls_score.softmax(dim=1)[:, 1]
         # filter scores, bbox_pred w.r.t. mask.
         # anchors are filtered in get_anchors() beforehand.
         scores = scores[mask]
         rpn_bbox_pred = rpn_bbox_pred.permute(1, 2, 0).reshape(-1,
                                                                4)[mask, :]
         if scores.dim() == 0:
             # Restore the leading row dimension if indexing collapsed it.
             rpn_bbox_pred = rpn_bbox_pred.unsqueeze(0)
             anchors = anchors.unsqueeze(0)
             scores = scores.unsqueeze(0)
         # filter anchors, bbox_pred, scores w.r.t. scores
         if cfg.nms_pre > 0 and scores.shape[0] > cfg.nms_pre:
             _, topk_inds = scores.topk(cfg.nms_pre)
             rpn_bbox_pred = rpn_bbox_pred[topk_inds, :]
             anchors = anchors[topk_inds, :]
             scores = scores[topk_inds]
         # get proposals w.r.t. anchors and rpn_bbox_pred
         proposals = delta2bbox(anchors, rpn_bbox_pred, self.target_means,
                                self.target_stds, img_shape)
         # filter out too small bboxes
         if cfg.min_bbox_size > 0:
             w = proposals[:, 2] - proposals[:, 0] + 1
             h = proposals[:, 3] - proposals[:, 1] + 1
             # A boolean mask keeps ``proposals`` 2-D even when exactly one
             # box survives. ``nonzero(...).squeeze()`` would yield a 0-d
             # index there and drop the row dimension, breaking the cat below.
             valid_mask = ((w >= cfg.min_bbox_size) &
                           (h >= cfg.min_bbox_size))
             proposals = proposals[valid_mask]
             scores = scores[valid_mask]
         proposals = torch.cat([proposals, scores.unsqueeze(-1)], dim=-1)
         # NMS in current level
         proposals, _ = nms(proposals, cfg.nms_thr)
         proposals = proposals[:cfg.nms_post, :]
         mlvl_proposals.append(proposals)
     if not mlvl_proposals:
         # Every level was masked out; torch.cat([]) would raise here.
         return cls_scores[0].new_zeros((0, 5))
     proposals = torch.cat(mlvl_proposals, 0)
     if cfg.nms_across_levels:
         # NMS across multi levels
         proposals, _ = nms(proposals, cfg.nms_thr)
         proposals = proposals[:cfg.max_num, :]
     else:
         scores = proposals[:, 4]
         num = min(cfg.max_num, proposals.shape[0])
         _, topk_inds = scores.topk(num)
         proposals = proposals[topk_inds, :]
     return proposals