def forward(self, feat, x):
        """Decode the wrapped head's per-level outputs and run ``rcnn_nms``.

        ``self.module`` is called on ``feat`` and must return three per-level
        sequences: class scores, box predictions and centerness maps. Each
        level is decoded by ``self.bbox_coder``; when ``nms_pre`` is positive
        the level is reduced to its ``nms_pre`` best candidates, ranked by
        ``score * centerness``. All levels are then concatenated, reduced
        once more by raw class score, and handed to ``self.rcnn_nms``.

        Args:
            feat: Input forwarded unchanged to ``self.module``.
            x: Forwarded to ``self.bbox_coder`` as ``input_x``.

        Returns:
            The 4-tuple ``(num_detected, proposals, scores, cls_id)`` returned
            by ``self.rcnn_nms``.
        """
        cls_scores, bbox_preds, centernesses = self.module(feat)
        mlvl_anchors = self.anchor_generator(cls_scores,
                                             device=cls_scores[0].device)
        nms_pre = self.test_cfg.get('nms_pre', -1)

        per_level_scores = []
        per_level_boxes = []
        per_level_ctr = []
        levels = zip(cls_scores, bbox_preds, centernesses, mlvl_anchors)
        for lvl_cls, lvl_reg, lvl_ctr, lvl_anchors in levels:
            # Flatten the centerness map to (batch, -1) and squash to (0, 1).
            batch = lvl_ctr.shape[0]
            lvl_ctr = lvl_ctr.permute(0, 2, 3, 1).reshape(batch, -1)
            lvl_ctr = lvl_ctr.sigmoid()

            # Class count is derived from the channel ratio of the two maps.
            class_count = lvl_cls.shape[1] * 4 // lvl_reg.shape[1]
            lvl_scores, lvl_boxes = self.bbox_coder(lvl_cls,
                                                    lvl_reg,
                                                    lvl_anchors,
                                                    min_num_bboxes=-1,
                                                    num_classes=class_count,
                                                    use_sigmoid_cls=True,
                                                    input_x=x)

            if nms_pre > 0:
                # Pad along dim 1 before topk so that asking for nms_pre
                # entries is valid (presumably pads up to nms_pre entries;
                # confirm against mm2trt_util.pad_with_value).
                lvl_scores = mm2trt_util.pad_with_value(
                    lvl_scores, 1, nms_pre, 0.)
                lvl_ctr = mm2trt_util.pad_with_value(lvl_ctr, 1, nms_pre)
                lvl_boxes = mm2trt_util.pad_with_value(lvl_boxes, 1, nms_pre)

                # Rank candidates by their best centerness-weighted score.
                weighted = lvl_scores * lvl_ctr[:, :, None]
                ranking, _ = weighted.max(dim=2)
                _, keep = ranking.topk(nms_pre, dim=1)
                lvl_boxes = mm2trt_util.gather_topk(lvl_boxes, 1, keep)
                lvl_scores = mm2trt_util.gather_topk(lvl_scores, 1, keep)
                lvl_ctr = mm2trt_util.gather_topk(lvl_ctr, 1, keep)

            per_level_scores.append(lvl_scores)
            per_level_boxes.append(lvl_boxes)
            per_level_ctr.append(lvl_ctr)

        all_scores = torch.cat(per_level_scores, dim=1)
        all_boxes = torch.cat(per_level_boxes, dim=1)
        # Concatenated for parity with the per-level lists; the centerness is
        # not multiplied into the final scores in this method.
        all_ctr = torch.cat(per_level_ctr, dim=1)  # noqa: F841

        # Second, cross-level reduction based on the raw class scores.
        best_scores, _ = all_scores.max(dim=2)
        topk_pre = max(1000, nms_pre)
        keep_count = min(topk_pre, all_scores.shape[1])
        _, keep = best_scores.topk(keep_count, dim=1)
        all_boxes = mm2trt_util.gather_topk(all_boxes, 1, keep)
        all_scores = mm2trt_util.gather_topk(all_scores, 1, keep)

        all_scores = mm2trt_util.pad_with_value(all_scores, 2, 1, 0.)

        box_count = all_boxes.shape[1]
        num_detected, proposals, scores, cls_id = self.rcnn_nms(
            all_scores, all_boxes, box_count, self.test_cfg.max_per_img)

        return num_detected, proposals, scores, cls_id
    def forward(self, feat, x):
        """Decode head outputs, fuse class and IoU scores, then run NMS.

        ``self.module`` is called on ``feat`` and must return three per-level
        sequences: class scores, box predictions and IoU predictions. Each
        level is decoded by ``self.bbox_coder``; when ``nms_pre`` is positive
        the level keeps its ``nms_pre`` best candidates ranked by
        ``sqrt(score * iou)``. After concatenation the final scores are
        replaced by ``sqrt(score * iou)``, reduced again and passed to
        ``self.rcnn_nms``. If ``module.with_score_voting`` is truthy the NMS
        result is post-processed by ``self.score_voting_batched``.

        Args:
            feat: Input forwarded unchanged to ``self.module``.
            x: Forwarded to ``self.bbox_coder`` as ``input_x``.

        Returns:
            Either the result of ``self.score_voting_batched`` or the 4-tuple
            ``(num_detected, proposals, scores, cls_id)`` from
            ``self.rcnn_nms``.
        """
        head = self.module
        cls_scores, bbox_preds, iou_preds = head(feat)

        mlvl_anchors = self.anchor_generator(cls_scores,
                                             device=cls_scores[0].device)
        nms_pre = self.test_cfg.get('nms_pre', -1)

        per_level_scores = []
        per_level_boxes = []
        per_level_iou = []
        levels = zip(cls_scores, bbox_preds, iou_preds, mlvl_anchors)
        for lvl_cls, lvl_reg, lvl_iou, lvl_anchors in levels:
            # Flatten the IoU map to (batch, -1) and squash to (0, 1).
            batch = lvl_iou.shape[0]
            lvl_iou = lvl_iou.permute(0, 2, 3, 1).reshape(batch, -1)
            lvl_iou = lvl_iou.sigmoid()

            # Class count is derived from the channel ratio of the two maps.
            class_count = lvl_cls.shape[1] * 4 // lvl_reg.shape[1]
            lvl_scores, lvl_boxes = self.bbox_coder(lvl_cls,
                                                    lvl_reg,
                                                    lvl_anchors,
                                                    min_num_bboxes=-1,
                                                    num_classes=class_count,
                                                    use_sigmoid_cls=True,
                                                    input_x=x)

            if nms_pre > 0:
                # Pad along dim 1 so the topk request below is valid.
                lvl_scores = mm2trt_util.pad_with_value(
                    lvl_scores, 1, nms_pre, 0.)
                lvl_iou = mm2trt_util.pad_with_value(lvl_iou, 1, nms_pre)
                lvl_boxes = mm2trt_util.pad_with_value(lvl_boxes, 1, nms_pre)

                # Rank by the geometric mean of class score and IoU.
                fused = (lvl_scores * lvl_iou[:, :, None]).sqrt()
                ranking, _ = fused.max(dim=2)
                _, keep = ranking.topk(nms_pre, dim=1)
                lvl_boxes = mm2trt_util.gather_topk(lvl_boxes, 1, keep)
                lvl_scores = mm2trt_util.gather_topk(lvl_scores, 1, keep)
                lvl_iou = mm2trt_util.gather_topk(lvl_iou, 1, keep)

            per_level_scores.append(lvl_scores)
            per_level_boxes.append(lvl_boxes)
            per_level_iou.append(lvl_iou)

        all_scores = torch.cat(per_level_scores, dim=1)
        all_boxes = torch.cat(per_level_boxes, dim=1)
        all_iou = torch.cat(per_level_iou, dim=1)

        # From here on the scores are the fused sqrt(score * iou) values.
        all_scores = (all_scores * all_iou[:, :, None]).sqrt()
        best_scores, _ = all_scores.max(dim=2)
        topk_pre = max(1000, nms_pre)
        keep_count = min(topk_pre, all_scores.shape[1])
        _, keep = best_scores.topk(keep_count, dim=1)
        all_boxes = mm2trt_util.gather_topk(all_boxes, 1, keep)
        all_scores = mm2trt_util.gather_topk(all_scores, 1, keep)

        all_scores = mm2trt_util.pad_with_value(all_scores, 2, 1, 0.)

        box_count = all_boxes.shape[1]
        num_detected, proposals, scores, cls_id = self.rcnn_nms(
            all_scores, all_boxes, box_count, self.test_cfg.max_per_img)

        if not head.with_score_voting:
            return num_detected, proposals, scores, cls_id

        # Score voting receives both the NMS output and the pre-NMS tensors.
        return self.score_voting_batched(num_detected, proposals, scores,
                                         cls_id, all_boxes, all_scores,
                                         self.test_cfg.score_thr)
    def forward(self,
                cls_scores,
                bbox_preds,
                anchors,
                min_num_bboxes,
                num_classes,
                use_sigmoid_cls,
                input_x=None):
        """Turn one level of raw head outputs into scores and boxes.

        Args:
            cls_scores: 4-D class map; permuted with ``(0, 2, 3, 1)`` and
                reshaped to ``(batch, -1, num_classes)``.
            bbox_preds: 4-D regression map; permuted the same way and
                reshaped to ``(batch, -1, 4)``.
            anchors: Anchors for this level; a leading dimension is added
                before decoding.
            min_num_bboxes: When positive, scores and proposals are padded
                along dim 1 via ``util_ops.pad_with_value`` with this value.
            num_classes: Size of the last axis of the reshaped class map.
            use_sigmoid_cls: Apply sigmoid when truthy, otherwise softmax
                over the last axis.
            input_x: Optional tensor whose ``shape[2:]`` is passed to the
                decoder as ``max_shape``; ``None`` passes ``max_shape=None``.

        Returns:
            ``(scores, proposals)`` where ``proposals`` has an extra axis
            inserted at position 2.
        """
        batch = cls_scores.shape[0]
        logits = cls_scores.permute(0, 2, 3, 1).reshape(
            batch, -1, num_classes)
        scores = logits.sigmoid() if use_sigmoid_cls else logits.softmax(
            dim=2)

        deltas = bbox_preds.permute(0, 2, 3, 1).reshape(
            bbox_preds.shape[0], -1, 4)
        batched_anchors = anchors.unsqueeze(0)

        if input_x is None:
            max_shape = None
        else:
            max_shape = input_x.shape[2:]

        proposals = batched_blr2bboxes(batched_anchors,
                                       deltas,
                                       normalizer=self.normalizer,
                                       max_shape=max_shape)

        if min_num_bboxes > 0:
            scores = util_ops.pad_with_value(scores, 1, min_num_bboxes, 0)
            proposals = util_ops.pad_with_value(proposals, 1, min_num_bboxes)

        return scores, proposals.unsqueeze(2)
    def forward(self, feat, x):
        """Decode point-based head outputs into boxes and run ``rcnn_nms``.

        ``self.module`` is called on ``feat``. It may return either three
        sequences (the middle one is ignored) or two; in both cases the
        first is the per-level class scores and the last the per-level
        refined box predictions. Each level optionally keeps its ``nms_pre``
        best candidates, distances are converted to boxes with
        ``batched_distance2bbox`` (clipped with ``x.shape[2:]``), and the
        concatenated result is reduced once more before NMS.

        Args:
            feat: Input forwarded unchanged to ``self.module``.
            x: Tensor whose ``shape[2:]`` is passed to
                ``batched_distance2bbox``.

        Returns:
            The 4-tuple ``(num_detected, proposals, scores, cls_id)`` returned
            by ``self.rcnn_nms``.
        """
        head = self.module
        test_cfg = self.test_cfg
        head_outputs = head(feat)

        if len(head_outputs) == 3:
            # Older output layout: the middle entry is not needed here.
            cls_scores, _, bbox_preds_refine = head_outputs
        else:
            # Newer output layout: scores and refined predictions only.
            cls_scores, bbox_preds_refine = head_outputs

        mlvl_points = self.get_points(cls_scores)
        # The setting does not change between levels, so read it once.
        nms_pre = test_cfg.get('nms_pre', -1)

        per_level_boxes = []
        per_level_scores = []
        levels = zip(cls_scores, bbox_preds_refine, mlvl_points)
        for lvl_cls, lvl_reg, lvl_points in levels:
            lvl_scores = lvl_cls.permute(0, 2, 3, 1).reshape(
                lvl_cls.shape[0], -1, head.cls_out_channels)
            lvl_scores = lvl_scores.sigmoid()
            lvl_reg = lvl_reg.permute(0, 2, 3, 1).reshape(
                lvl_reg.shape[0], -1, 4)
            # Broadcast the points over the batch axis of the predictions.
            lvl_points = lvl_points.unsqueeze(0).expand_as(lvl_reg[:, :, :2])

            if nms_pre > 0:
                # Pad along dim 1 so the topk request below is valid.
                lvl_scores = mm2trt_util.pad_with_value(
                    lvl_scores, 1, nms_pre, 0.)
                lvl_reg = mm2trt_util.pad_with_value(lvl_reg, 1, nms_pre)
                lvl_points = mm2trt_util.pad_with_value(
                    lvl_points, 1, nms_pre)

                # Keep the nms_pre candidates with the best class score.
                ranking, _ = lvl_scores.max(dim=2)
                _, keep = ranking.topk(nms_pre, dim=1)
                lvl_points = mm2trt_util.gather_topk(lvl_points, 1, keep)
                lvl_reg = mm2trt_util.gather_topk(lvl_reg, 1, keep)
                lvl_scores = mm2trt_util.gather_topk(lvl_scores, 1, keep)

            per_level_boxes.append(
                batched_distance2bbox(lvl_points, lvl_reg, x.shape[2:]))
            per_level_scores.append(lvl_scores)

        all_boxes = torch.cat(per_level_boxes, dim=1)
        all_scores = torch.cat(per_level_scores, dim=1)

        all_proposals = all_boxes.unsqueeze(2)

        # Second, cross-level reduction.
        best_scores, _ = all_scores.max(dim=2)
        topk_pre = max(1000, nms_pre)
        keep_count = min(topk_pre, all_scores.shape[1])
        _, keep = best_scores.topk(keep_count, dim=1)
        all_proposals = mm2trt_util.gather_topk(all_proposals, 1, keep)
        all_scores = mm2trt_util.gather_topk(all_scores, 1, keep)

        box_count = all_proposals.shape[1]
        num_detected, proposals, scores, cls_id = self.rcnn_nms(
            all_scores, all_proposals, box_count, self.test_cfg.max_per_img)

        return num_detected, proposals, scores, cls_id
    def forward(self, feats, x):
        """Decode per-level prediction maps into boxes and run ``rcnn_nms``.

        The first element of ``self.module(feats)`` is taken as the list of
        per-level prediction maps. Each map is reshaped to
        ``(batch, -1, self.num_attrib)`` and split along the last axis into
        box attributes (``0:4``), confidence (``4``) and class scores
        (``5:``). Boxes are decoded with ``self.bbox_coder.decode``,
        confidences below ``conf_thr`` are zeroed, and the class scores are
        multiplied by the confidence before NMS.

        Args:
            feats: Input forwarded unchanged to ``self.module``.
            x: Not used by this method.

        Returns:
            The 4-tuple ``(num_detected, proposals, scores, cls_id)`` returned
            by ``self.rcnn_nms``.
        """
        module = self.module
        cfg = self.test_cfg
        # Both settings are identical for every level, so read them once.
        nms_pre = cfg.get('nms_pre', -1)
        conf_thr = cfg.get('conf_thr', -1)

        pred_maps_list = module(feats)[0]

        multi_lvl_anchors = self.anchor_generator(
            pred_maps_list, device=pred_maps_list[0].device)

        multi_lvl_bboxes = []
        multi_lvl_cls_scores = []
        multi_lvl_conf_scores = []
        for i in range(self.num_levels):
            pred_map = pred_maps_list[i]
            stride = self.featmap_strides[i]
            batch_size = pred_map.shape[0]
            pred_map = pred_map.permute(0, 2, 3, 1).reshape(
                batch_size, -1, self.num_attrib)

            # The first two attributes get exactly ONE sigmoid, applied out
            # of place. Previously they were overwritten in place with their
            # sigmoid and then passed through sigmoid a second time, so the
            # decoder received sigmoid(sigmoid(.)).
            xy_pred = torch.sigmoid(pred_map[..., :2])
            wh_pred = pred_map[..., 2:4]
            pred_map_proposal = torch.cat([xy_pred, wh_pred], dim=-1)
            anchors = multi_lvl_anchors[i].unsqueeze(0).expand_as(
                pred_map_proposal)
            bbox_pred = self.bbox_coder.decode(anchors, pred_map_proposal,
                                               stride)

            conf_pred = torch.sigmoid(pred_map[..., 4]).view(batch_size, -1)
            cls_pred = torch.sigmoid(pred_map[..., 5:]).view(
                batch_size, -1, self.num_classes)

            if nms_pre > 0:
                # Pad along dim 1 so the topk request below is valid.
                conf_pred = mm2trt_util.pad_with_value(
                    conf_pred, 1, nms_pre, 0.)
                cls_pred = mm2trt_util.pad_with_value(cls_pred, 1, nms_pre)
                bbox_pred = mm2trt_util.pad_with_value(bbox_pred, 1, nms_pre)

                # Keep the nms_pre candidates with the highest confidence.
                _, topk_inds = conf_pred.topk(nms_pre, dim=1)
                conf_pred = mm2trt_util.gather_topk(conf_pred, 1, topk_inds)
                cls_pred = mm2trt_util.gather_topk(cls_pred, 1, topk_inds)
                bbox_pred = mm2trt_util.gather_topk(bbox_pred, 1, topk_inds)

            # Zero out (rather than drop) low-confidence entries; the tensor
            # length is unchanged.
            conf_inds = conf_pred.ge(conf_thr).float()
            conf_pred = conf_pred * conf_inds

            multi_lvl_bboxes.append(bbox_pred)
            multi_lvl_cls_scores.append(cls_pred)
            multi_lvl_conf_scores.append(conf_pred)

        multi_lvl_bboxes = torch.cat(multi_lvl_bboxes, dim=1)
        multi_lvl_cls_scores = torch.cat(multi_lvl_cls_scores, dim=1)
        multi_lvl_conf_scores = torch.cat(multi_lvl_conf_scores, dim=1)

        # Final per-class score is class score times confidence.
        multi_lvl_cls_scores = (multi_lvl_cls_scores *
                                multi_lvl_conf_scores.unsqueeze(2))
        multi_lvl_bboxes = multi_lvl_bboxes.unsqueeze(2)
        num_bboxes = multi_lvl_bboxes.shape[1]
        num_detected, proposals, scores, cls_id = self.rcnn_nms(
            multi_lvl_cls_scores, multi_lvl_bboxes, num_bboxes,
            self.test_cfg.max_per_img)
        return num_detected, proposals, scores, cls_id
    def forward(self, feat, x):
        """Decode guided-anchor head outputs and run ``rcnn_nms``.

        ``self.module`` is called on ``feat`` and must return four per-level
        sequences: class scores, box predictions, shape predictions and
        location predictions. ``self.get_anchors`` supplies per-level
        anchors and masks; scores are multiplied by the mask, each level
        optionally keeps its ``nms_pre`` best candidates, and boxes are
        decoded with ``self.bbox_coder.decode`` clipped to ``x.shape[2:]``.

        Args:
            feat: Input forwarded unchanged to ``self.module``.
            x: Tensor whose ``shape[2:]`` is used as ``max_shape`` when
                decoding.

        Returns:
            The 4-tuple ``(num_detected, proposals, scores, cls_id)`` returned
            by ``self.rcnn_nms``.
        """
        img_shape = x.shape[2:]
        module = self.module
        cfg = self.test_cfg

        cls_scores, bbox_preds, shape_preds, loc_preds = module(feat)

        _, mlvl_anchors, mlvl_masks = self.get_anchors(cls_scores,
                                                       shape_preds,
                                                       loc_preds,
                                                       use_loc_filter=True)

        mlvl_scores = []
        mlvl_proposals = []
        nms_pre = cfg.get('nms_pre', -1)
        for cls_score, bbox_pred, anchors, mask in zip(cls_scores, bbox_preds,
                                                       mlvl_anchors,
                                                       mlvl_masks):
            logits = cls_score.permute(0, 2, 3, 1).reshape(
                cls_score.shape[0], -1, module.cls_out_channels)
            # Apply exactly one activation to the raw logits. Previously a
            # sigmoid was applied during the reshape and then sigmoid/softmax
            # was applied again on top of it.
            if module.use_sigmoid_cls:
                scores = logits.sigmoid()
            else:
                scores = logits.softmax(-1)

            # Suppress the scores of locations filtered out by the mask.
            scores = scores * mask.unsqueeze(2)
            bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(
                bbox_pred.shape[0], -1, 4)

            if nms_pre > 0:
                # Pad along dim 1 so the topk request below is valid.
                scores = mm2trt_util.pad_with_value(scores, 1, nms_pre, 0.)
                bbox_pred = mm2trt_util.pad_with_value(bbox_pred, 1, nms_pre)
                anchors = mm2trt_util.pad_with_value(anchors, 1, nms_pre)

                # Keep the nms_pre candidates with the best class score.
                max_scores, _ = scores.max(dim=2)
                _, topk_inds = max_scores.topk(nms_pre, dim=1)
                bbox_pred = mm2trt_util.gather_topk(bbox_pred, 1, topk_inds)
                scores = mm2trt_util.gather_topk(scores, 1, topk_inds)
                anchors = mm2trt_util.gather_topk(anchors, 1, topk_inds)

            proposals = self.bbox_coder.decode(anchors,
                                               bbox_pred,
                                               max_shape=img_shape)

            mlvl_scores.append(scores)
            mlvl_proposals.append(proposals)

        mlvl_scores = torch.cat(mlvl_scores, dim=1)
        mlvl_proposals = torch.cat(mlvl_proposals, dim=1)
        mlvl_proposals = mlvl_proposals.unsqueeze(2)

        # Second, cross-level reduction.
        max_scores, _ = mlvl_scores.max(dim=2)
        topk_pre = max(1000, nms_pre)
        _, topk_inds = max_scores.topk(min(topk_pre, mlvl_scores.shape[1]),
                                       dim=1)
        mlvl_scores = mm2trt_util.gather_topk(mlvl_scores, 1, topk_inds)
        mlvl_proposals = mm2trt_util.gather_topk(mlvl_proposals, 1, topk_inds)

        num_bboxes = mlvl_proposals.shape[1]
        num_detected, proposals, scores, cls_id = self.rcnn_nms(
            mlvl_scores, mlvl_proposals, num_bboxes, self.test_cfg.max_per_img)

        return num_detected, proposals, scores, cls_id
Esempio n. 7
0
    def forward(self, feat, x):
        """Decode point-set head outputs into boxes and run ``rcnn_nms``.

        ``self.module`` is called on ``feat`` and must return three per-level
        sequences; the first is the class scores and the third the refined
        point predictions (the second is ignored). Point sets are converted
        to boxes by ``module.points2bbox``, scaled by the level's point
        stride, shifted by the point centres and clamped to the image size
        taken from ``x.shape[2:]``.

        Args:
            feat: Input forwarded unchanged to ``self.module``.
            x: Tensor whose ``shape[2:]`` gives the clamp limits (index 0 for
                the y coordinates, index 1 for the x coordinates).

        Returns:
            The 4-tuple ``(num_detected, proposals, scores, cls_id)`` returned
            by ``self.rcnn_nms``.
        """
        img_shape = x.shape[2:]
        module = self.module
        cfg = self.test_cfg
        # The setting does not change between levels, so read it once.
        nms_pre = cfg.get('nms_pre', -1)

        cls_scores, _, pts_preds_refine = module(feat)

        bbox_preds_refine = [
            module.points2bbox(pts_pred_refine)
            for pts_pred_refine in pts_preds_refine
        ]

        mlvl_points = [
            self.point_generators[i](cls_score, module.point_strides[i])
            for i, cls_score in enumerate(cls_scores)
        ]

        mlvl_bboxes = []
        mlvl_scores = []
        for i_lvl, (cls_score, bbox_pred, points) in enumerate(
                zip(cls_scores, bbox_preds_refine, mlvl_points)):
            logits = cls_score.permute(0, 2, 3, 1).reshape(
                cls_score.shape[0], -1, module.cls_out_channels)
            # Apply exactly one activation to the raw logits. Previously a
            # sigmoid was applied during the reshape and then sigmoid/softmax
            # was applied again on top of it.
            if module.use_sigmoid_cls:
                scores = logits.sigmoid()
            else:
                # Drop the last channel of the softmax output.
                scores = logits.softmax(-1)[:, :, :-1]
            bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(
                bbox_pred.shape[0], -1, 4)
            # Only the first two point columns are used as the centre.
            points = points[:, :2].unsqueeze(0).expand_as(bbox_pred[:, :, :2])

            if nms_pre > 0:
                # Pad along dim 1 so the topk request below is valid.
                scores = mm2trt_util.pad_with_value(scores, 1, nms_pre, 0.)
                bbox_pred = mm2trt_util.pad_with_value(bbox_pred, 1, nms_pre)
                points = mm2trt_util.pad_with_value(points, 1, nms_pre)

                # Keep the nms_pre candidates with the best class score.
                max_scores, _ = scores.max(dim=2)
                _, topk_inds = max_scores.topk(nms_pre, dim=1)
                bbox_pred = mm2trt_util.gather_topk(bbox_pred, 1, topk_inds)
                scores = mm2trt_util.gather_topk(scores, 1, topk_inds)
                points = mm2trt_util.gather_topk(points, 1, topk_inds)

            # Boxes are offsets (in stride units) around the point centre.
            bbox_pos_center = torch.cat([points[:, :, :2], points[:, :, :2]],
                                        dim=2)
            bboxes = bbox_pred * module.point_strides[i_lvl] + bbox_pos_center
            x1 = bboxes[:, :, 0].clamp(min=0, max=img_shape[1])
            y1 = bboxes[:, :, 1].clamp(min=0, max=img_shape[0])
            x2 = bboxes[:, :, 2].clamp(min=0, max=img_shape[1])
            y2 = bboxes[:, :, 3].clamp(min=0, max=img_shape[0])
            bboxes = torch.stack([x1, y1, x2, y2], dim=-1)

            mlvl_bboxes.append(bboxes)
            mlvl_scores.append(scores)

        mlvl_bboxes = torch.cat(mlvl_bboxes, dim=1)
        mlvl_scores = torch.cat(mlvl_scores, dim=1)
        mlvl_bboxes = mlvl_bboxes.unsqueeze(2)

        # Second, cross-level reduction down to nms_pre candidates.
        if nms_pre > 0:
            max_scores, _ = mlvl_scores.max(dim=2)
            _, topk_inds = max_scores.topk(nms_pre, dim=1)
            mlvl_bboxes = mm2trt_util.gather_topk(mlvl_bboxes, 1, topk_inds)
            mlvl_scores = mm2trt_util.gather_topk(mlvl_scores, 1, topk_inds)

        num_bboxes = mlvl_bboxes.shape[1]
        num_detected, proposals, scores, cls_id = self.rcnn_nms(
            mlvl_scores, mlvl_bboxes, num_bboxes, self.test_cfg.max_per_img)
        return num_detected, proposals, scores, cls_id
    def forward(self, feat, x):
        """Decode distribution-style box predictions and run ``rcnn_nms``.

        ``self.module`` is called on ``feat`` and must return per-level class
        scores and box predictions. For every level the box prediction is
        passed through ``self.batched_integral`` (with ``module.integral``)
        and scaled by the first component of that level's anchor stride. The
        resulting distances are applied to the anchor centres with
        ``batched_distance2bbox``, clipped to ``x.shape[2:]``.

        Args:
            feat: Input forwarded unchanged to ``self.module``.
            x: Tensor whose ``shape[2:]`` is used as ``max_shape``.

        Returns:
            The 4-tuple ``(num_detected, proposals, scores, cls_id)`` returned
            by ``self.rcnn_nms``.
        """
        head = self.module
        image_hw = x.shape[2:]

        cls_scores, bbox_preds = head(feat)
        mlvl_anchors = self.anchor_generator(cls_scores,
                                             device=cls_scores[0].device)
        nms_pre = self.test_cfg.get('nms_pre', -1)

        per_level_scores = []
        per_level_boxes = []
        for lvl, lvl_cls in enumerate(cls_scores):
            lvl_stride = head.anchor_generator.strides[lvl]
            lvl_anchors = mlvl_anchors[lvl].unsqueeze(0)

            lvl_scores = lvl_cls.permute(0, 2, 3, 1).reshape(
                lvl_cls.shape[0], -1, head.cls_out_channels)
            lvl_scores = lvl_scores.sigmoid()

            # Reduce the predicted distribution to distances, then convert
            # from stride units by multiplying with the level stride.
            lvl_dist = bbox_preds[lvl].permute(0, 2, 3, 1)
            lvl_dist = self.batched_integral(head.integral, lvl_dist)
            lvl_dist = lvl_dist * lvl_stride[0]

            if nms_pre > 0:
                # Pad along dim 1 so the topk request below is valid.
                lvl_scores = mm2trt_util.pad_with_value(
                    lvl_scores, 1, nms_pre, 0.)
                lvl_dist = mm2trt_util.pad_with_value(lvl_dist, 1, nms_pre)
                lvl_anchors = mm2trt_util.pad_with_value(
                    lvl_anchors, 1, nms_pre)

                # Keep the nms_pre candidates with the best class score.
                ranking, _ = lvl_scores.max(dim=2)
                _, keep = ranking.topk(nms_pre, dim=1)
                lvl_scores = mm2trt_util.gather_topk(lvl_scores, 1, keep)
                lvl_dist = mm2trt_util.gather_topk(lvl_dist, 1, keep)
                lvl_anchors = mm2trt_util.gather_topk(lvl_anchors, 1, keep)

            centers = self.batched_anchor_center(lvl_anchors)
            lvl_boxes = batched_distance2bbox(centers,
                                              lvl_dist,
                                              max_shape=image_hw)

            per_level_scores.append(lvl_scores)
            per_level_boxes.append(lvl_boxes)

        all_scores = torch.cat(per_level_scores, dim=1)
        all_boxes = torch.cat(per_level_boxes, dim=1).unsqueeze(2)

        # Second, cross-level reduction.
        topk_pre = max(1000, nms_pre)
        best_scores, _ = all_scores.max(dim=2)
        keep_count = min(topk_pre, all_scores.size(1))
        _, keep = best_scores.topk(keep_count, dim=1)
        all_boxes = mm2trt_util.gather_topk(all_boxes, 1, keep)
        all_scores = mm2trt_util.gather_topk(all_scores, 1, keep)

        box_count = all_boxes.shape[1]
        num_detected, proposals, scores, cls_id = self.rcnn_nms(
            all_scores, all_boxes, box_count, self.test_cfg.max_per_img)

        return num_detected, proposals, scores, cls_id
Esempio n. 9
0
    def forward(self, feat, x):
        """Decode bucketed box predictions and run ``rcnn_nms``.

        ``self.module`` is called on ``feat`` and must return per-level class
        scores and box predictions, where every box prediction is a pair
        whose element 0 is the bucket classification map and element 1 the
        bucket regression map. Both maps are reshaped to
        ``(batch, -1, self.side_num * 4)``. ``self.bbox_coder.decode`` turns
        them into boxes plus a confidence that is multiplied into the class
        scores before NMS.

        Args:
            feat: Sequence of feature maps; ``feat[0].size(0)`` is used as
                the batch size and the whole sequence is forwarded to
                ``self.module``.
            x: Tensor whose ``shape[2:]`` is used as ``max_shape``.

        Returns:
            The 4-tuple ``(num_detected, proposals, scores, cls_id)`` returned
            by ``self.rcnn_nms``.
        """
        batch_size = feat[0].size(0)
        module = self.module
        img_shape = x.shape[2:]
        cfg = self.test_cfg
        nms_pre = cfg.get('nms_pre', -1)

        cls_scores, bbox_preds = module(feat)

        mlvl_anchors = self.square_anchor_generator(
            cls_scores, device=cls_scores[0].device)

        mlvl_scores = []
        mlvl_bboxes = []
        mlvl_confids = []

        bbox_cls_preds = [bb[0] for bb in bbox_preds]
        bbox_reg_preds = [bb[1] for bb in bbox_preds]
        for cls_score, bbox_cls_pred, bbox_reg_pred, anchors in zip(
                cls_scores, bbox_cls_preds, bbox_reg_preds, mlvl_anchors):
            cls_score = cls_score.permute(0, 2, 3,
                                          1).reshape(batch_size, -1,
                                                     module.cls_out_channels)

            if self.use_sigmoid_cls:
                scores = cls_score.sigmoid()
            else:
                scores = cls_score.softmax(-1)

            bbox_cls_pred = bbox_cls_pred.permute(0, 2, 3, 1).reshape(
                batch_size, -1, self.side_num * 4)
            bbox_reg_pred = bbox_reg_pred.permute(0, 2, 3, 1).reshape(
                batch_size, -1, self.side_num * 4)
            anchors = anchors.unsqueeze(0).expand_as(bbox_cls_pred[:, :, :4])

            if nms_pre > 0:
                # Pad along dim 1 so the topk request below is valid.
                scores = mm2trt_util.pad_with_value(scores, 1, nms_pre, 0.)
                bbox_cls_pred = mm2trt_util.pad_with_value(
                    bbox_cls_pred, 1, nms_pre)
                bbox_reg_pred = mm2trt_util.pad_with_value(
                    bbox_reg_pred, 1, nms_pre)
                anchors = mm2trt_util.pad_with_value(anchors, 1, nms_pre)

                if self.use_sigmoid_cls:
                    max_scores, _ = scores.max(dim=2)
                else:
                    # Exclude the last CLASS channel (presumably background)
                    # from the ranking. The slice must act on the class axis
                    # (dim 2); slicing dim 1 would drop a candidate instead
                    # and leave the extra channel in the max.
                    max_scores, _ = scores[:, :, :-1].max(dim=2)

                _, topk_inds = max_scores.topk(nms_pre, dim=1)
                scores = mm2trt_util.gather_topk(scores, 1, topk_inds)
                bbox_cls_pred = mm2trt_util.gather_topk(
                    bbox_cls_pred, 1, topk_inds)
                bbox_reg_pred = mm2trt_util.gather_topk(
                    bbox_reg_pred, 1, topk_inds)
                anchors = mm2trt_util.gather_topk(anchors, 1, topk_inds)

            # Separate name so the per-level pair does not shadow the
            # head-level ``bbox_preds`` sequence.
            level_bbox_preds = [
                bbox_cls_pred.contiguous(),
                bbox_reg_pred.contiguous()
            ]

            bboxes, confids = self.bbox_coder.decode(anchors.contiguous(),
                                                     level_bbox_preds,
                                                     max_shape=img_shape)

            mlvl_bboxes.append(bboxes)
            mlvl_scores.append(scores)
            mlvl_confids.append(confids)

        mlvl_bboxes = torch.cat(mlvl_bboxes, dim=1)
        mlvl_scores = torch.cat(mlvl_scores, dim=1)
        mlvl_confids = torch.cat(mlvl_confids, dim=1)

        mlvl_bboxes = mlvl_bboxes.unsqueeze(2)
        # Weight every class score by the decoder's box confidence.
        mlvl_scores = mlvl_scores * mlvl_confids.unsqueeze(-1)

        # Second, cross-level reduction.
        max_scores, _ = mlvl_scores.max(dim=2)
        topk_pre = max(1000, nms_pre)
        _, topk_inds = max_scores.topk(min(topk_pre, mlvl_scores.shape[1]),
                                       dim=1)
        mlvl_scores = mm2trt_util.gather_topk(mlvl_scores, 1, topk_inds)
        mlvl_bboxes = mm2trt_util.gather_topk(mlvl_bboxes, 1, topk_inds)

        if self.use_sigmoid_cls:
            # Append one all-zero score column so the class axis has
            # num_classes + 1 entries, matching the repeated boxes below.
            padding = mlvl_scores[:, :, :1] * 0
            mlvl_scores = torch.cat([mlvl_scores, padding], dim=2)
        mlvl_bboxes = mlvl_bboxes.repeat(1, 1, self.num_classes + 1, 1)

        num_bboxes = mlvl_bboxes.shape[1]
        num_detected, proposals, scores, cls_id = self.rcnn_nms(
            mlvl_scores, mlvl_bboxes, num_bboxes, self.test_cfg.max_per_img)

        return num_detected, proposals, scores, cls_id
    def forward(self, feat, x):
        """Produce region proposals from a guided-anchor RPN head.

        ``self.module`` is called on ``feat`` and must return four per-level
        sequences: class scores, box predictions, shape predictions and
        location predictions. ``self.get_anchors`` supplies per-level guided
        anchors and location masks. For every level the masked objectness
        scores optionally select the ``nms_pre`` best candidates, boxes are
        decoded with ``self.bbox_coder.decode`` and passed through
        ``self.rpn_nms``. The per-level results are concatenated and the
        ``self.test_cfg.max_num`` best-scoring proposals are returned.

        Args:
            feat: Input forwarded unchanged to ``self.module``.
            x: Tensor whose ``shape[2:]`` is used as ``max_shape`` when
                decoding.

        Returns:
            The selected proposals, indexed from the concatenated per-level
            NMS output.
        """
        img_shape = x.shape[2:]
        module = self.module

        cls_scores, bbox_preds, shape_preds, loc_preds = module(feat)

        _, guided_anchors, loc_masks = self.get_anchors(cls_scores,
                                                        shape_preds,
                                                        loc_preds,
                                                        use_loc_filter=True)

        mlvl_scores = []
        mlvl_proposals = []
        nms_pre = self.test_cfg.get('nms_pre', -1)
        for idx, rpn_cls_score in enumerate(cls_scores):
            rpn_bbox_pred = bbox_preds[idx]
            anchors = guided_anchors[idx]
            mask = loc_masks[idx]

            rpn_cls_score = rpn_cls_score.permute(0, 2, 3, 1)
            if self.use_sigmoid_cls:
                rpn_cls_score = rpn_cls_score.reshape(rpn_cls_score.shape[0],
                                                      -1)
                scores = rpn_cls_score.sigmoid()
            else:
                rpn_cls_score = rpn_cls_score.reshape(rpn_cls_score.shape[0],
                                                      -1, 2)
                # Take channel 0 WITHOUT keeping the class axis so that
                # ``scores`` is 2-D like in the sigmoid branch; the mask
                # multiply, the dim-1 topk and the unsqueeze(-1) below all
                # rely on that layout.
                scores = rpn_cls_score.softmax(dim=2)[:, :, 0]
            scores = scores * mask

            # Bound unconditionally: the decode call below needs it even when
            # the nms_pre pre-selection is skipped.
            bbox_pred = rpn_bbox_pred.permute(0, 2, 3, 1).reshape(
                rpn_bbox_pred.size(0), -1, 4)
            if nms_pre > 0:
                # Pad along dim 1 so the topk request below is valid.
                scores = mm2trt_util.pad_with_value(scores, 1, nms_pre, 0.)
                bbox_pred = mm2trt_util.pad_with_value(bbox_pred, 1, nms_pre)
                anchors = mm2trt_util.pad_with_value(anchors, 1, nms_pre)

                # Scores are already one value per candidate, so they are
                # ranked directly.
                _, topk_inds = scores.topk(nms_pre, dim=1)
                anchors = mm2trt_util.gather_topk(anchors, 1, topk_inds)
                bbox_pred = mm2trt_util.gather_topk(bbox_pred, 1, topk_inds)
                scores = mm2trt_util.gather_topk(scores, 1, topk_inds)

            proposals = self.bbox_coder.decode(anchors,
                                               bbox_pred,
                                               max_shape=img_shape)

            # Add the class axis expected by the NMS call.
            scores = scores.unsqueeze(-1)
            proposals = proposals.unsqueeze(2)
            _, proposals, scores, _ = self.rpn_nms(scores, proposals,
                                                   self.test_cfg.nms_pre,
                                                   self.test_cfg.nms_post)

            # NOTE(review): squeeze(0) only removes the leading axis when its
            # size is 1 - this looks like it assumes batch size 1; confirm.
            mlvl_scores.append(scores.squeeze(0))
            mlvl_proposals.append(proposals.squeeze(0))

        scores = torch.cat(mlvl_scores, dim=0)
        proposals = torch.cat(mlvl_proposals, dim=0)

        # Final cross-level selection of the best max_num proposals.
        _, topk_inds = scores.topk(self.test_cfg.max_num)
        proposals = proposals[topk_inds, :]

        return proposals
    def forward(self, feat, x):
        """Decode grid-point box predictions and run ``rcnn_nms``.

        ``self.module`` is called on ``feat`` and must return per-level class
        scores and box predictions. ``self.get_points`` yields one ``(y, x)``
        pair of coordinate tensors per level. Box predictions are
        exponentiated, scaled by the level's base edge length and applied
        around ``stride * (coord + 0.5)``; the corners are clamped to
        ``[0, size - 1]`` using ``x.shape[2:]``.

        Args:
            feat: Input forwarded unchanged to ``self.module``.
            x: Tensor whose ``shape[2:]`` gives the clamp limits (index 0 for
                the y coordinates, index 1 for the x coordinates).

        Returns:
            The 4-tuple ``(num_detected, proposals, scores, cls_id)`` returned
            by ``self.rcnn_nms``.
        """
        image_hw = x.shape[2:]
        head = self.module
        test_cfg = self.test_cfg
        cls_scores, bbox_preds = head(feat)
        mlvl_points = self.get_points(cls_scores, flatten=True)
        # The setting does not change between levels, so read it once.
        nms_pre = test_cfg.get('nms_pre', -1)

        max_x = image_hw[1] - 1
        max_y = image_hw[0] - 1

        per_level_boxes = []
        per_level_scores = []
        levels = zip(cls_scores, bbox_preds, head.strides,
                     head.base_edge_list, mlvl_points)
        for lvl_cls, lvl_reg, stride, base_len, (grid_y, grid_x) in levels:
            lvl_scores = lvl_cls.permute(0, 2, 3, 1).reshape(
                lvl_cls.shape[0], -1, head.cls_out_channels)
            lvl_scores = lvl_scores.sigmoid()
            lvl_reg = lvl_reg.permute(0, 2, 3, 1).reshape(
                lvl_reg.shape[0], -1, 4)
            lvl_reg = lvl_reg.exp()
            # Shift the grid coordinates by half a cell.
            grid_x = grid_x.unsqueeze(0) + 0.5
            grid_y = grid_y.unsqueeze(0) + 0.5

            if nms_pre > 0:
                # Pad along dim 1 so the topk request below is valid.
                lvl_scores = mm2trt_util.pad_with_value(
                    lvl_scores, 1, nms_pre, 0.)
                lvl_reg = mm2trt_util.pad_with_value(lvl_reg, 1, nms_pre)
                grid_y = mm2trt_util.pad_with_value(grid_y, 1, nms_pre)
                grid_x = mm2trt_util.pad_with_value(grid_x, 1, nms_pre)

                # Keep the nms_pre candidates with the best class score.
                ranking, _ = lvl_scores.max(dim=2)
                _, keep = ranking.topk(nms_pre, dim=1)
                lvl_reg = mm2trt_util.gather_topk(lvl_reg, 1, keep)
                lvl_scores = mm2trt_util.gather_topk(lvl_scores, 1, keep)
                grid_y = mm2trt_util.gather_topk(grid_y, 1, keep)
                grid_x = mm2trt_util.gather_topk(grid_x, 1, keep)

            # Corner = scaled grid position -/+ scaled side length, clamped.
            left = stride * grid_x - base_len * lvl_reg[:, :, 0]
            top = stride * grid_y - base_len * lvl_reg[:, :, 1]
            right = stride * grid_x + base_len * lvl_reg[:, :, 2]
            bottom = stride * grid_y + base_len * lvl_reg[:, :, 3]
            x1 = left.clamp(min=0, max=max_x)
            y1 = top.clamp(min=0, max=max_y)
            x2 = right.clamp(min=0, max=max_x)
            y2 = bottom.clamp(min=0, max=max_y)

            per_level_boxes.append(torch.stack([x1, y1, x2, y2], -1))
            per_level_scores.append(lvl_scores)

        all_boxes = torch.cat(per_level_boxes, dim=1)
        all_scores = torch.cat(per_level_scores, dim=1)

        all_proposals = all_boxes.unsqueeze(2)

        # Second, cross-level reduction.
        best_scores, _ = all_scores.max(dim=2)
        topk_pre = max(1000, nms_pre)
        keep_count = min(topk_pre, all_scores.shape[1])
        _, keep = best_scores.topk(keep_count, dim=1)
        all_proposals = mm2trt_util.gather_topk(all_proposals, 1, keep)
        all_scores = mm2trt_util.gather_topk(all_scores, 1, keep)

        box_count = all_proposals.shape[1]
        num_detected, proposals, scores, cls_id = self.rcnn_nms(
            all_scores, all_proposals, box_count, self.test_cfg.max_per_img)

        return num_detected, proposals, scores, cls_id