def forward(self, feat, x):
    """Convert the outputs of a centerness-style head into detections.

    Per feature level the centerness map is flattened to
    ``(batch, num_points)`` and passed through a sigmoid, while
    ``self.bbox_coder`` turns the classification / regression maps into
    scores and boxes. When ``nms_pre`` is positive, only the ``nms_pre``
    candidates with the highest ``score * centerness`` survive per level.
    All levels are then concatenated, filtered once more by raw class score
    and handed to ``self.rcnn_nms``.

    Args:
        feat: Features forwarded to the wrapped head ``self.module``.
        x: Network input tensor; only passed on to ``self.bbox_coder`` as
            ``input_x``.

    Returns:
        The ``(num_detected, proposals, scores, cls_id)`` tuple produced by
        ``self.rcnn_nms``.
    """
    head = self.module
    cls_scores, bbox_preds, centernesses = head(feat)
    level_anchors = self.anchor_generator(
        cls_scores, device=cls_scores[0].device)
    nms_pre = self.test_cfg.get('nms_pre', -1)

    score_chunks = []
    box_chunks = []
    ctr_chunks = []
    per_level = zip(cls_scores, bbox_preds, centernesses, level_anchors)
    for lvl_cls, lvl_reg, lvl_ctr, lvl_anchors in per_level:
        batch = lvl_ctr.shape[0]
        lvl_ctr = lvl_ctr.permute(0, 2, 3, 1).reshape(batch, -1).sigmoid()

        lvl_scores, lvl_boxes = self.bbox_coder(
            lvl_cls,
            lvl_reg,
            lvl_anchors,
            min_num_bboxes=-1,
            num_classes=lvl_cls.shape[1] * 4 // lvl_reg.shape[1],
            use_sigmoid_cls=True,
            input_x=x)

        if nms_pre > 0:
            # Pad axis 1 first -- presumably so that topk(nms_pre) is valid
            # even on small levels; see mm2trt_util.pad_with_value.
            lvl_scores = mm2trt_util.pad_with_value(
                lvl_scores, 1, nms_pre, 0.)
            lvl_ctr = mm2trt_util.pad_with_value(lvl_ctr, 1, nms_pre)
            lvl_boxes = mm2trt_util.pad_with_value(lvl_boxes, 1, nms_pre)

            # Rank candidates by their best class score times centerness.
            ranking = (lvl_scores * lvl_ctr[:, :, None]).max(dim=2)[0]
            keep = ranking.topk(nms_pre, dim=1)[1]
            lvl_boxes = mm2trt_util.gather_topk(lvl_boxes, 1, keep)
            lvl_scores = mm2trt_util.gather_topk(lvl_scores, 1, keep)
            lvl_ctr = mm2trt_util.gather_topk(lvl_ctr, 1, keep)

        score_chunks.append(lvl_scores)
        box_chunks.append(lvl_boxes)
        ctr_chunks.append(lvl_ctr)

    all_scores = torch.cat(score_chunks, dim=1)
    all_boxes = torch.cat(box_chunks, dim=1)
    # NOTE: the concatenated centerness is not used below; the original code
    # has the "scores * centerness" rescaling disabled at this point.
    all_ctr = torch.cat(ctr_chunks, dim=1)

    # Second, cross-level filter on the raw class scores.
    best_scores = all_scores.max(dim=2)[0]
    keep_num = min(max(1000, nms_pre), all_scores.shape[1])
    keep = best_scores.topk(keep_num, dim=1)[1]
    all_boxes = mm2trt_util.gather_topk(all_boxes, 1, keep)
    all_scores = mm2trt_util.gather_topk(all_scores, 1, keep)

    # Extend the class axis (dim 2) with zeros before NMS.
    all_scores = mm2trt_util.pad_with_value(all_scores, 2, 1, 0.)

    num_detected, proposals, scores, cls_id = self.rcnn_nms(
        all_scores, all_boxes, all_boxes.shape[1],
        self.test_cfg.max_per_img)
    return num_detected, proposals, scores, cls_id
def forward(self, feat, x):
    """Convert the outputs of an IoU-aware head into detections.

    Per level the IoU prediction is flattened to ``(batch, num_points)`` and
    passed through a sigmoid, and ``self.bbox_coder`` yields scores and
    boxes. Candidates are ranked by ``sqrt(score * iou)``; the same fused
    value replaces the class score after the levels are concatenated. The
    result goes through ``self.rcnn_nms`` and, if the wrapped head has
    ``with_score_voting`` set, through ``self.score_voting_batched``.

    Args:
        feat: Features forwarded to the wrapped head ``self.module``.
        x: Network input tensor; only passed on to ``self.bbox_coder`` as
            ``input_x``.

    Returns:
        Either the output of ``self.score_voting_batched`` or the
        ``(num_detected, proposals, scores, cls_id)`` tuple from
        ``self.rcnn_nms``.
    """
    head = self.module
    cls_scores, bbox_preds, iou_preds = head(feat)
    level_anchors = self.anchor_generator(
        cls_scores, device=cls_scores[0].device)
    nms_pre = self.test_cfg.get('nms_pre', -1)

    score_chunks = []
    box_chunks = []
    iou_chunks = []
    per_level = zip(cls_scores, bbox_preds, iou_preds, level_anchors)
    for lvl_cls, lvl_reg, lvl_iou, lvl_anchors in per_level:
        batch = lvl_iou.shape[0]
        lvl_iou = lvl_iou.permute(0, 2, 3, 1).reshape(batch, -1).sigmoid()

        lvl_scores, lvl_boxes = self.bbox_coder(
            lvl_cls,
            lvl_reg,
            lvl_anchors,
            min_num_bboxes=-1,
            num_classes=lvl_cls.shape[1] * 4 // lvl_reg.shape[1],
            use_sigmoid_cls=True,
            input_x=x)

        if nms_pre > 0:
            # Pad axis 1 first -- presumably so that topk(nms_pre) is valid
            # even on small levels; see mm2trt_util.pad_with_value.
            lvl_scores = mm2trt_util.pad_with_value(
                lvl_scores, 1, nms_pre, 0.)
            lvl_iou = mm2trt_util.pad_with_value(lvl_iou, 1, nms_pre)
            lvl_boxes = mm2trt_util.pad_with_value(lvl_boxes, 1, nms_pre)

            # Rank by the geometric mean of class score and predicted IoU.
            fused = (lvl_scores * lvl_iou[:, :, None]).sqrt()
            ranking = fused.max(dim=2)[0]
            keep = ranking.topk(nms_pre, dim=1)[1]
            lvl_boxes = mm2trt_util.gather_topk(lvl_boxes, 1, keep)
            lvl_scores = mm2trt_util.gather_topk(lvl_scores, 1, keep)
            lvl_iou = mm2trt_util.gather_topk(lvl_iou, 1, keep)

        score_chunks.append(lvl_scores)
        box_chunks.append(lvl_boxes)
        iou_chunks.append(lvl_iou)

    all_scores = torch.cat(score_chunks, dim=1)
    all_boxes = torch.cat(box_chunks, dim=1)
    all_iou = torch.cat(iou_chunks, dim=1)
    all_scores = (all_scores * all_iou[:, :, None]).sqrt()

    # Second, cross-level filter on the fused scores.
    best_scores = all_scores.max(dim=2)[0]
    keep_num = min(max(1000, nms_pre), all_scores.shape[1])
    keep = best_scores.topk(keep_num, dim=1)[1]
    all_boxes = mm2trt_util.gather_topk(all_boxes, 1, keep)
    all_scores = mm2trt_util.gather_topk(all_scores, 1, keep)

    # Extend the class axis (dim 2) with zeros before NMS.
    all_scores = mm2trt_util.pad_with_value(all_scores, 2, 1, 0.)

    num_detected, proposals, scores, cls_id = self.rcnn_nms(
        all_scores, all_boxes, all_boxes.shape[1],
        self.test_cfg.max_per_img)

    if not head.with_score_voting:
        return num_detected, proposals, scores, cls_id
    return self.score_voting_batched(
        num_detected, proposals, scores, cls_id,
        all_boxes, all_scores, self.test_cfg.score_thr)
def forward(self, cls_scores, bbox_preds, anchors, min_num_bboxes,
            num_classes, use_sigmoid_cls, input_x=None):
    """Decode one level of classification / regression maps.

    Args:
        cls_scores: Classification map; moved to channels-last and reshaped
            to ``(batch, -1, num_classes)``.
        bbox_preds: Regression map; moved to channels-last and reshaped to
            ``(batch, -1, 4)``.
        anchors: Anchors of this level; a leading batch axis is added.
        min_num_bboxes: When positive, scores and boxes are padded along
            axis 1 via ``util_ops.pad_with_value`` with this size argument.
        num_classes: Size of the last axis of the reshaped score tensor.
        use_sigmoid_cls: Use sigmoid when true, otherwise softmax over the
            class axis.
        input_x: Optional network input; its ``shape[2:]`` is forwarded to
            the decoder as ``max_shape``.

    Returns:
        ``(scores, proposals)`` where ``proposals`` carries an extra axis
        inserted at position 2.
    """
    batch = cls_scores.shape[0]
    logits = cls_scores.permute(0, 2, 3, 1).reshape(batch, -1, num_classes)
    scores = logits.sigmoid() if use_sigmoid_cls else logits.softmax(dim=2)

    deltas = bbox_preds.permute(0, 2, 3, 1).reshape(
        bbox_preds.shape[0], -1, 4)

    max_shape = input_x.shape[2:] if input_x is not None else None
    proposals = batched_blr2bboxes(
        anchors.unsqueeze(0),
        deltas,
        normalizer=self.normalizer,
        max_shape=max_shape)

    if min_num_bboxes > 0:
        scores = util_ops.pad_with_value(scores, 1, min_num_bboxes, 0)
        proposals = util_ops.pad_with_value(proposals, 1, min_num_bboxes)

    return scores, proposals.unsqueeze(2)
def forward(self, feat, x):
    """Convert refined point-based head outputs into detections.

    The wrapped head may return either three outputs (older layout, where
    the middle one is ignored) or two (newer layout). Per level, scores are
    sigmoid-activated, the best ``nms_pre`` candidates are kept, and
    ``batched_distance2bbox`` turns points plus predicted distances into
    boxes limited by the input's spatial size. After concatenation a second
    top-k is applied and the result is passed to ``self.rcnn_nms``.

    Args:
        feat: Features forwarded to the wrapped head ``self.module``.
        x: Network input tensor; ``x.shape[2:]`` is given to the decoder.

    Returns:
        The ``(num_detected, proposals, scores, cls_id)`` tuple produced by
        ``self.rcnn_nms``.
    """
    head = self.module
    cfg = self.test_cfg
    nms_pre = cfg.get('nms_pre', -1)

    outputs = head(feat)
    if len(outputs) == 3:
        # Older heads also return the un-refined box predictions.
        cls_scores, _, refined_preds = outputs
    else:
        cls_scores, refined_preds = outputs

    level_points = self.get_points(cls_scores)

    box_chunks = []
    score_chunks = []
    for lvl_cls, lvl_reg, lvl_pts in zip(cls_scores, refined_preds,
                                         level_points):
        lvl_scores = lvl_cls.permute(0, 2, 3, 1).reshape(
            lvl_cls.shape[0], -1, head.cls_out_channels).sigmoid()
        lvl_reg = lvl_reg.permute(0, 2, 3, 1).reshape(
            lvl_reg.shape[0], -1, 4)
        lvl_pts = lvl_pts.unsqueeze(0).expand_as(lvl_reg[:, :, :2])

        if nms_pre > 0:
            # Pad with zeros so that topk(nms_pre) is always possible.
            lvl_scores = mm2trt_util.pad_with_value(
                lvl_scores, 1, nms_pre, 0.)
            lvl_reg = mm2trt_util.pad_with_value(lvl_reg, 1, nms_pre)
            lvl_pts = mm2trt_util.pad_with_value(lvl_pts, 1, nms_pre)

            ranking = lvl_scores.max(dim=2)[0]
            keep = ranking.topk(nms_pre, dim=1)[1]
            lvl_pts = mm2trt_util.gather_topk(lvl_pts, 1, keep)
            lvl_reg = mm2trt_util.gather_topk(lvl_reg, 1, keep)
            lvl_scores = mm2trt_util.gather_topk(lvl_scores, 1, keep)

        box_chunks.append(
            batched_distance2bbox(lvl_pts, lvl_reg, x.shape[2:]))
        score_chunks.append(lvl_scores)

    all_scores = torch.cat(score_chunks, dim=1)
    all_boxes = torch.cat(box_chunks, dim=1).unsqueeze(2)

    # Second, cross-level filter.
    best_scores = all_scores.max(dim=2)[0]
    keep_num = min(max(1000, nms_pre), all_scores.shape[1])
    keep = best_scores.topk(keep_num, dim=1)[1]
    all_boxes = mm2trt_util.gather_topk(all_boxes, 1, keep)
    all_scores = mm2trt_util.gather_topk(all_scores, 1, keep)

    num_detected, proposals, scores, cls_id = self.rcnn_nms(
        all_scores, all_boxes, all_boxes.shape[1],
        self.test_cfg.max_per_img)
    return num_detected, proposals, scores, cls_id
def forward(self, feats, x):
    """Convert YOLO-style prediction maps into detections.

    Each level's prediction map is reshaped to
    ``(batch, -1, self.num_attrib)``. Channels 0-1 (centre offsets) are
    sigmoid-activated, channels 2-3 are used raw, channel 4 is the
    objectness logit and channels 5+ are the class logits. Boxes are decoded
    with ``self.bbox_coder.decode``; candidates are optionally reduced to
    the ``nms_pre`` most confident ones, objectness below ``conf_thr`` is
    zeroed, and class scores are multiplied by objectness before
    ``self.rcnn_nms`` runs.

    The centre offsets are activated exactly once. The previous code applied
    the sigmoid in place and then a second time when assembling the decoder
    input, which distorted the decoded box centres.

    Args:
        feats: Features forwarded to the wrapped head ``self.module``.
        x: Network input tensor (unused here).

    Returns:
        The ``(num_detected, proposals, scores, cls_id)`` tuple produced by
        ``self.rcnn_nms``.
    """
    module = self.module
    cfg = self.test_cfg
    nms_pre = cfg.get('nms_pre', -1)
    conf_thr = cfg.get('conf_thr', -1)

    pred_maps_list = module(feats)[0]
    multi_lvl_anchors = self.anchor_generator(
        pred_maps_list, device=pred_maps_list[0].device)

    multi_lvl_bboxes = []
    multi_lvl_cls_scores = []
    multi_lvl_conf_scores = []
    for i in range(self.num_levels):
        pred_map = pred_maps_list[i]
        stride = self.featmap_strides[i]
        batch_size = pred_map.shape[0]
        pred_map = pred_map.permute(0, 2, 3, 1).reshape(
            batch_size, -1, self.num_attrib)

        # Single sigmoid on the centre offsets; width/height stay raw.
        pred_xy = torch.sigmoid(pred_map[..., :2])
        pred_wh = pred_map[..., 2:4]
        pred_map_proposal = torch.cat([pred_xy, pred_wh], dim=-1)

        anchors = multi_lvl_anchors[i].unsqueeze(0).expand_as(
            pred_map_proposal)
        bbox_pred = self.bbox_coder.decode(anchors, pred_map_proposal,
                                           stride)

        conf_pred = torch.sigmoid(pred_map[..., 4]).view(batch_size, -1)
        cls_pred = torch.sigmoid(pred_map[..., 5:]).view(
            batch_size, -1, self.num_classes)

        if nms_pre > 0:
            # Pad axis 1 first -- presumably so that topk(nms_pre) is valid
            # even on small levels; see mm2trt_util.pad_with_value.
            conf_pred = mm2trt_util.pad_with_value(conf_pred, 1, nms_pre, 0.)
            cls_pred = mm2trt_util.pad_with_value(cls_pred, 1, nms_pre)
            bbox_pred = mm2trt_util.pad_with_value(bbox_pred, 1, nms_pre)

            _, topk_inds = conf_pred.topk(nms_pre, dim=1)
            conf_pred = mm2trt_util.gather_topk(conf_pred, 1, topk_inds)
            cls_pred = mm2trt_util.gather_topk(cls_pred, 1, topk_inds)
            bbox_pred = mm2trt_util.gather_topk(bbox_pred, 1, topk_inds)

        # Zero the objectness of candidates below the confidence threshold
        # (keeps tensor sizes unchanged instead of dropping entries).
        conf_inds = conf_pred.ge(conf_thr).float()
        conf_pred = conf_pred * conf_inds

        multi_lvl_bboxes.append(bbox_pred)
        multi_lvl_cls_scores.append(cls_pred)
        multi_lvl_conf_scores.append(conf_pred)

    multi_lvl_bboxes = torch.cat(multi_lvl_bboxes, dim=1)
    multi_lvl_cls_scores = torch.cat(multi_lvl_cls_scores, dim=1)
    multi_lvl_conf_scores = torch.cat(multi_lvl_conf_scores, dim=1)

    # Final score = class probability * objectness.
    multi_lvl_cls_scores = (
        multi_lvl_cls_scores * multi_lvl_conf_scores.unsqueeze(2))
    multi_lvl_bboxes = multi_lvl_bboxes.unsqueeze(2)

    num_bboxes = multi_lvl_bboxes.shape[1]
    num_detected, proposals, scores, cls_id = self.rcnn_nms(
        multi_lvl_cls_scores, multi_lvl_bboxes, num_bboxes,
        self.test_cfg.max_per_img)
    return num_detected, proposals, scores, cls_id
def forward(self, feat, x):
    """Convert guided-anchor head outputs into detections.

    ``self.get_anchors`` supplies per-level anchors and location masks.
    Class logits are reshaped to ``(batch, -1, cls_out_channels)`` and
    activated once (sigmoid or softmax, depending on
    ``module.use_sigmoid_cls``), then multiplied by the mask. The best
    ``nms_pre`` candidates per level are decoded with
    ``self.bbox_coder.decode``; after concatenation a second top-k is
    applied and the result is passed to ``self.rcnn_nms``.

    The previous code applied ``.sigmoid()`` to the logits and then applied
    sigmoid / softmax again, so the scores were activated twice.

    Args:
        feat: Features forwarded to the wrapped head ``self.module``.
        x: Network input tensor; ``x.shape[2:]`` is used as ``max_shape``
            when decoding boxes.

    Returns:
        The ``(num_detected, proposals, scores, cls_id)`` tuple produced by
        ``self.rcnn_nms``.
    """
    img_shape = x.shape[2:]
    module = self.module
    cfg = self.test_cfg

    cls_scores, bbox_preds, shape_preds, loc_preds = module(feat)
    _, mlvl_anchors, mlvl_masks = self.get_anchors(
        cls_scores, shape_preds, loc_preds, use_loc_filter=True)

    mlvl_scores = []
    mlvl_proposals = []
    nms_pre = cfg.get('nms_pre', -1)
    for cls_score, bbox_pred, anchors, mask in zip(cls_scores, bbox_preds,
                                                   mlvl_anchors, mlvl_masks):
        logits = cls_score.permute(0, 2, 3, 1).reshape(
            cls_score.shape[0], -1, module.cls_out_channels)
        # Activate the raw logits exactly once.
        if module.use_sigmoid_cls:
            scores = logits.sigmoid()
        else:
            scores = logits.softmax(-1)
        # Suppress locations rejected by the location filter.
        scores = scores * mask.unsqueeze(2)

        bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(
            bbox_pred.shape[0], -1, 4)

        if nms_pre > 0:
            # Pad with zeros so that topk(nms_pre) is always possible.
            scores = mm2trt_util.pad_with_value(scores, 1, nms_pre, 0.)
            bbox_pred = mm2trt_util.pad_with_value(bbox_pred, 1, nms_pre)
            anchors = mm2trt_util.pad_with_value(anchors, 1, nms_pre)

            max_scores, _ = scores.max(dim=2)
            _, topk_inds = max_scores.topk(nms_pre, dim=1)
            bbox_pred = mm2trt_util.gather_topk(bbox_pred, 1, topk_inds)
            scores = mm2trt_util.gather_topk(scores, 1, topk_inds)
            anchors = mm2trt_util.gather_topk(anchors, 1, topk_inds)

        proposals = self.bbox_coder.decode(
            anchors, bbox_pred, max_shape=img_shape)

        mlvl_scores.append(scores)
        mlvl_proposals.append(proposals)

    mlvl_scores = torch.cat(mlvl_scores, dim=1)
    mlvl_proposals = torch.cat(mlvl_proposals, dim=1).unsqueeze(2)

    # Second, cross-level filter.
    max_scores, _ = mlvl_scores.max(dim=2)
    topk_pre = max(1000, nms_pre)
    _, topk_inds = max_scores.topk(
        min(topk_pre, mlvl_scores.shape[1]), dim=1)
    mlvl_scores = mm2trt_util.gather_topk(mlvl_scores, 1, topk_inds)
    mlvl_proposals = mm2trt_util.gather_topk(mlvl_proposals, 1, topk_inds)

    num_bboxes = mlvl_proposals.shape[1]
    num_detected, proposals, scores, cls_id = self.rcnn_nms(
        mlvl_scores, mlvl_proposals, num_bboxes, self.test_cfg.max_per_img)
    return num_detected, proposals, scores, cls_id
def forward(self, feat, x):
    """Convert point-set head outputs into detections.

    Refined point predictions are converted to box offsets with
    ``module.points2bbox``. Per level, class logits are reshaped to
    ``(batch, -1, cls_out_channels)`` and activated once (sigmoid, or
    softmax with the last channel dropped). The best ``nms_pre`` candidates
    are kept, offsets are scaled by the level stride and added to the point
    centres, and the boxes are clamped to the input's spatial size. After
    concatenation another top-k is applied (only when ``nms_pre`` is
    positive) before ``self.rcnn_nms`` runs.

    The previous code applied ``.sigmoid()`` to the logits and then applied
    sigmoid / softmax again, so the scores were activated twice.

    Args:
        feat: Features forwarded to the wrapped head ``self.module``.
        x: Network input tensor; ``x.shape[2:]`` is read as
            ``(height, width)`` for clamping.

    Returns:
        The ``(num_detected, proposals, scores, cls_id)`` tuple produced by
        ``self.rcnn_nms``.
    """
    img_shape = x.shape[2:]
    module = self.module
    cfg = self.test_cfg
    nms_pre = cfg.get('nms_pre', -1)

    cls_scores, _, pts_preds_refine = module(feat)
    bbox_preds_refine = [
        module.points2bbox(pts_pred_refine)
        for pts_pred_refine in pts_preds_refine
    ]
    mlvl_points = [
        self.point_generators[i](cls_score, module.point_strides[i])
        for i, cls_score in enumerate(cls_scores)
    ]

    mlvl_bboxes = []
    mlvl_scores = []
    for i_lvl, (cls_score, bbox_pred, points) in enumerate(
            zip(cls_scores, bbox_preds_refine, mlvl_points)):
        logits = cls_score.permute(0, 2, 3, 1).reshape(
            cls_score.shape[0], -1, module.cls_out_channels)
        # Activate the raw logits exactly once.
        if module.use_sigmoid_cls:
            scores = logits.sigmoid()
        else:
            scores = logits.softmax(-1)[:, :, :-1]

        bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(
            bbox_pred.shape[0], -1, 4)
        # Only the first two columns of the generated points are used.
        points = points[:, :2].unsqueeze(0).expand_as(bbox_pred[:, :, :2])

        if nms_pre > 0:
            # Pad with zeros so that topk(nms_pre) is always possible.
            scores = mm2trt_util.pad_with_value(scores, 1, nms_pre, 0.)
            bbox_pred = mm2trt_util.pad_with_value(bbox_pred, 1, nms_pre)
            points = mm2trt_util.pad_with_value(points, 1, nms_pre)

            max_scores, _ = scores.max(dim=2)
            _, topk_inds = max_scores.topk(nms_pre, dim=1)
            bbox_pred = mm2trt_util.gather_topk(bbox_pred, 1, topk_inds)
            scores = mm2trt_util.gather_topk(scores, 1, topk_inds)
            points = mm2trt_util.gather_topk(points, 1, topk_inds)

        # Offsets are expressed in stride units relative to the point.
        bbox_pos_center = torch.cat(
            [points[:, :, :2], points[:, :, :2]], dim=2)
        bboxes = bbox_pred * module.point_strides[i_lvl] + bbox_pos_center
        x1 = bboxes[:, :, 0].clamp(min=0, max=img_shape[1])
        y1 = bboxes[:, :, 1].clamp(min=0, max=img_shape[0])
        x2 = bboxes[:, :, 2].clamp(min=0, max=img_shape[1])
        y2 = bboxes[:, :, 3].clamp(min=0, max=img_shape[0])
        bboxes = torch.stack([x1, y1, x2, y2], dim=-1)

        mlvl_bboxes.append(bboxes)
        mlvl_scores.append(scores)

    mlvl_bboxes = torch.cat(mlvl_bboxes, dim=1)
    mlvl_scores = torch.cat(mlvl_scores, dim=1)
    mlvl_bboxes = mlvl_bboxes.unsqueeze(2)

    # Second, cross-level filter.
    if nms_pre > 0:
        max_scores, _ = mlvl_scores.max(dim=2)
        _, topk_inds = max_scores.topk(nms_pre, dim=1)
        mlvl_bboxes = mm2trt_util.gather_topk(mlvl_bboxes, 1, topk_inds)
        mlvl_scores = mm2trt_util.gather_topk(mlvl_scores, 1, topk_inds)

    num_bboxes = mlvl_bboxes.shape[1]
    num_detected, proposals, scores, cls_id = self.rcnn_nms(
        mlvl_scores, mlvl_bboxes, num_bboxes, self.test_cfg.max_per_img)
    return num_detected, proposals, scores, cls_id
def forward(self, feat, x):
    """Convert distribution-regression head outputs into detections.

    Per level, class scores are sigmoid-activated, the regression map is
    reduced with ``self.batched_integral(module.integral, ...)`` and scaled
    by the first component of the level stride, and the best ``nms_pre``
    candidates are kept. Boxes are obtained with ``batched_distance2bbox``
    from the anchor centres. After concatenation a second top-k is applied
    and the result is passed to ``self.rcnn_nms``.

    Args:
        feat: Features forwarded to the wrapped head ``self.module``.
        x: Network input tensor; ``x.shape[2:]`` is used as ``max_shape``.

    Returns:
        The ``(num_detected, proposals, scores, cls_id)`` tuple produced by
        ``self.rcnn_nms``.
    """
    img_shape = x.shape[2:]
    head = self.module
    cfg = self.test_cfg

    cls_scores, bbox_preds = head(feat)
    level_anchors = self.anchor_generator(
        cls_scores, device=cls_scores[0].device)
    nms_pre = cfg.get('nms_pre', -1)

    score_chunks = []
    box_chunks = []
    for idx, lvl_cls in enumerate(cls_scores):
        lvl_reg = bbox_preds[idx]
        lvl_anchors = level_anchors[idx].unsqueeze(0)
        lvl_stride = head.anchor_generator.strides[idx]

        lvl_scores = lvl_cls.permute(0, 2, 3, 1).reshape(
            lvl_cls.shape[0], -1, head.cls_out_channels).sigmoid()
        distances = self.batched_integral(
            head.integral, lvl_reg.permute(0, 2, 3, 1)) * lvl_stride[0]

        if nms_pre > 0:
            # Pad with zeros so that topk(nms_pre) is always possible.
            lvl_scores = mm2trt_util.pad_with_value(
                lvl_scores, 1, nms_pre, 0.)
            distances = mm2trt_util.pad_with_value(distances, 1, nms_pre)
            lvl_anchors = mm2trt_util.pad_with_value(lvl_anchors, 1, nms_pre)

            ranking = lvl_scores.max(dim=2)[0]
            keep = ranking.topk(nms_pre, dim=1)[1]
            lvl_scores = mm2trt_util.gather_topk(lvl_scores, 1, keep)
            distances = mm2trt_util.gather_topk(distances, 1, keep)
            lvl_anchors = mm2trt_util.gather_topk(lvl_anchors, 1, keep)

        centers = self.batched_anchor_center(lvl_anchors)
        lvl_boxes = batched_distance2bbox(
            centers, distances, max_shape=img_shape)

        score_chunks.append(lvl_scores)
        box_chunks.append(lvl_boxes)

    all_scores = torch.cat(score_chunks, dim=1)
    all_boxes = torch.cat(box_chunks, dim=1).unsqueeze(2)

    # Second, cross-level filter.
    keep_num = min(max(1000, nms_pre), all_scores.size(1))
    best_scores = all_scores.max(dim=2)[0]
    keep = best_scores.topk(keep_num, dim=1)[1]
    all_boxes = mm2trt_util.gather_topk(all_boxes, 1, keep)
    all_scores = mm2trt_util.gather_topk(all_scores, 1, keep)

    num_detected, proposals, scores, cls_id = self.rcnn_nms(
        all_scores, all_boxes, all_boxes.shape[1],
        self.test_cfg.max_per_img)
    return num_detected, proposals, scores, cls_id
def forward(self, feat, x):
    """Convert bucketing-head outputs into detections.

    Each entry of the head's ``bbox_preds`` is a pair: element 0 is the
    bucket classification map and element 1 the bucket regression map. Per
    level, class scores are activated (sigmoid or softmax), the best
    ``nms_pre`` candidates are kept, and ``self.bbox_coder.decode`` returns
    boxes plus a localisation confidence. Final scores are the class scores
    multiplied by that confidence. After a second top-k the scores get an
    extra zero class column (sigmoid case only), the boxes are repeated
    ``num_classes + 1`` times along axis 2, and ``self.rcnn_nms`` runs.

    In the softmax case the candidate ranking ignores the last class
    channel. The previous code sliced ``scores[:, :-1]``, which on the 3-D
    ``(batch, num_anchors, classes)`` tensor removed the last *anchor*
    instead and produced a ranking that no longer lined up with the tensors
    being gathered.

    Args:
        feat: Sequence of feature maps; ``feat[0].size(0)`` is taken as the
            batch size.
        x: Network input tensor; ``x.shape[2:]`` is used as ``max_shape``.

    Returns:
        The ``(num_detected, proposals, scores, cls_id)`` tuple produced by
        ``self.rcnn_nms``.
    """
    batch_size = feat[0].size(0)
    module = self.module
    img_shape = x.shape[2:]
    cfg = self.test_cfg
    nms_pre = cfg.get('nms_pre', -1)

    cls_scores, bbox_preds = module(feat)
    mlvl_anchors = self.square_anchor_generator(
        cls_scores, device=cls_scores[0].device)

    bbox_cls_preds = [bb[0] for bb in bbox_preds]
    bbox_reg_preds = [bb[1] for bb in bbox_preds]

    mlvl_scores = []
    mlvl_bboxes = []
    mlvl_confids = []
    for cls_score, bbox_cls_pred, bbox_reg_pred, anchors in zip(
            cls_scores, bbox_cls_preds, bbox_reg_preds, mlvl_anchors):
        cls_score = cls_score.permute(0, 2, 3, 1).reshape(
            batch_size, -1, module.cls_out_channels)
        if self.use_sigmoid_cls:
            scores = cls_score.sigmoid()
        else:
            scores = cls_score.softmax(-1)

        bbox_cls_pred = bbox_cls_pred.permute(0, 2, 3, 1).reshape(
            batch_size, -1, self.side_num * 4)
        bbox_reg_pred = bbox_reg_pred.permute(0, 2, 3, 1).reshape(
            batch_size, -1, self.side_num * 4)
        anchors = anchors.unsqueeze(0).expand_as(bbox_cls_pred[:, :, :4])

        if nms_pre > 0:
            # Pad so that axis 1 is large enough for topk(nms_pre).
            scores = mm2trt_util.pad_with_value(scores, 1, nms_pre, 0.)
            bbox_cls_pred = mm2trt_util.pad_with_value(
                bbox_cls_pred, 1, nms_pre)
            bbox_reg_pred = mm2trt_util.pad_with_value(
                bbox_reg_pred, 1, nms_pre)
            anchors = mm2trt_util.pad_with_value(anchors, 1, nms_pre)

            if self.use_sigmoid_cls:
                max_scores, _ = scores.max(dim=2)
            else:
                # Skip the last class channel (not the last anchor) when
                # ranking candidates.
                max_scores, _ = scores[:, :, :-1].max(dim=2)
            _, topk_inds = max_scores.topk(nms_pre, dim=1)

            scores = mm2trt_util.gather_topk(scores, 1, topk_inds)
            bbox_cls_pred = mm2trt_util.gather_topk(
                bbox_cls_pred, 1, topk_inds)
            bbox_reg_pred = mm2trt_util.gather_topk(
                bbox_reg_pred, 1, topk_inds)
            anchors = mm2trt_util.gather_topk(anchors, 1, topk_inds)

        level_preds = [
            bbox_cls_pred.contiguous(),
            bbox_reg_pred.contiguous()
        ]
        bboxes, confids = self.bbox_coder.decode(
            anchors.contiguous(), level_preds, max_shape=img_shape)

        mlvl_bboxes.append(bboxes)
        mlvl_scores.append(scores)
        mlvl_confids.append(confids)

    mlvl_bboxes = torch.cat(mlvl_bboxes, dim=1)
    mlvl_scores = torch.cat(mlvl_scores, dim=1)
    mlvl_confids = torch.cat(mlvl_confids, dim=1)

    mlvl_bboxes = mlvl_bboxes.unsqueeze(2)
    # Weight the class scores by the decoder's localisation confidence.
    mlvl_scores = mlvl_scores * mlvl_confids.unsqueeze(-1)

    # Second, cross-level filter.
    max_scores, _ = mlvl_scores.max(dim=2)
    topk_pre = max(1000, nms_pre)
    _, topk_inds = max_scores.topk(
        min(topk_pre, mlvl_scores.shape[1]), dim=1)
    mlvl_scores = mm2trt_util.gather_topk(mlvl_scores, 1, topk_inds)
    mlvl_bboxes = mm2trt_util.gather_topk(mlvl_bboxes, 1, topk_inds)

    if self.use_sigmoid_cls:
        # Append an all-zero class column so the class axis has
        # num_classes + 1 entries, as it already has in the softmax case.
        padding = mlvl_scores[:, :, :1] * 0
        mlvl_scores = torch.cat([mlvl_scores, padding], dim=2)

    mlvl_bboxes = mlvl_bboxes.repeat(1, 1, self.num_classes + 1, 1)

    num_bboxes = mlvl_bboxes.shape[1]
    num_detected, proposals, scores, cls_id = self.rcnn_nms(
        mlvl_scores, mlvl_bboxes, num_bboxes, self.test_cfg.max_per_img)
    return num_detected, proposals, scores, cls_id
def forward(self, feat, x):
    """Produce region proposals from a guided-anchor RPN head.

    ``self.get_anchors`` supplies per-level anchors and location masks. Per
    level, the objectness score is computed as a 2-D
    ``(batch, num_anchors)`` tensor, multiplied by the mask, reduced to the
    best ``nms_pre`` candidates (when ``nms_pre`` is positive), decoded
    with ``self.bbox_coder.decode`` and filtered by ``self.rpn_nms``.
    Finally the per-level results are concatenated and the
    ``test_cfg.max_num`` highest scoring proposals are returned.

    Fixes compared with the previous version:

    * ``bbox_pred`` was only assigned inside the ``nms_pre > 0`` branch, so
      disabling ``nms_pre`` raised ``NameError``. The reshaped regression
      output is now used directly in that case.
    * The softmax branch produced a 3-D ``(batch, num_anchors, 1)`` tensor
      while everything downstream (mask multiplication, ``topk`` over
      axis 1, ``unsqueeze(-1)``) expects the 2-D layout of the sigmoid
      branch. Channel 0 is now selected by index, which keeps the same
      values with the expected shape.

    Args:
        feat: Features forwarded to the wrapped head ``self.module``.
        x: Network input tensor; ``x.shape[2:]`` is used as ``max_shape``.

    Returns:
        The selected proposals, indexed out of the concatenated per-level
        proposals along axis 0.
    """
    img_shape = x.shape[2:]
    module = self.module
    cls_scores, bbox_preds, shape_preds, loc_preds = module(feat)

    _, guided_anchors, loc_masks = self.get_anchors(
        cls_scores, shape_preds, loc_preds, use_loc_filter=True)

    mlvl_scores = []
    mlvl_proposals = []
    nms_pre = self.test_cfg.get('nms_pre', -1)
    for rpn_cls_score, rpn_bbox_pred, anchors, mask in zip(
            cls_scores, bbox_preds, guided_anchors, loc_masks):
        rpn_cls_score = rpn_cls_score.permute(0, 2, 3, 1)
        if self.use_sigmoid_cls:
            rpn_cls_score = rpn_cls_score.reshape(
                rpn_cls_score.shape[0], -1)
            scores = rpn_cls_score.sigmoid()
        else:
            rpn_cls_score = rpn_cls_score.reshape(
                rpn_cls_score.shape[0], -1, 2)
            # Keep channel 0 as a 2-D tensor, like the sigmoid branch.
            scores = rpn_cls_score.softmax(dim=2)[:, :, 0]

        # mask is assumed to broadcast against (batch, num_anchors).
        scores = scores * mask

        bbox_pred = rpn_bbox_pred.permute(0, 2, 3, 1).reshape(
            rpn_bbox_pred.size(0), -1, 4)

        if nms_pre > 0:
            # Pad with zeros so that topk(nms_pre) is always possible.
            scores = mm2trt_util.pad_with_value(scores, 1, nms_pre, 0.)
            bbox_pred = mm2trt_util.pad_with_value(bbox_pred, 1, nms_pre)
            anchors = mm2trt_util.pad_with_value(anchors, 1, nms_pre)

            # scores is already one value per anchor, no class reduction.
            _, topk_inds = scores.topk(nms_pre, dim=1)
            anchors = mm2trt_util.gather_topk(anchors, 1, topk_inds)
            bbox_pred = mm2trt_util.gather_topk(bbox_pred, 1, topk_inds)
            scores = mm2trt_util.gather_topk(scores, 1, topk_inds)

        proposals = self.bbox_coder.decode(
            anchors, bbox_pred, max_shape=img_shape)

        # rpn_nms is called with a trailing class axis on the scores and a
        # class axis at position 2 on the boxes.
        scores = scores.unsqueeze(-1)
        proposals = proposals.unsqueeze(2)
        _, proposals, scores, _ = self.rpn_nms(
            scores, proposals, self.test_cfg.nms_pre,
            self.test_cfg.nms_post)

        mlvl_scores.append(scores.squeeze(0))
        mlvl_proposals.append(proposals.squeeze(0))

    scores = torch.cat(mlvl_scores, dim=0)
    proposals = torch.cat(mlvl_proposals, dim=0)

    _, topk_inds = scores.topk(self.test_cfg.max_num)
    proposals = proposals[topk_inds, :]
    return proposals
def forward(self, feat, x):
    """Convert point-based head outputs (exp-encoded offsets) to detections.

    Per level, class scores are sigmoid-activated and the regression output
    is exponentiated. Grid coordinates from ``self.get_points`` are shifted
    by 0.5, the best ``nms_pre`` candidates are kept, and boxes are built
    as ``stride * coord -/+ base_len * offset`` clamped to
    ``[0, size - 1]`` of the input's spatial shape. After concatenation a
    second top-k is applied and the result is passed to ``self.rcnn_nms``.

    Args:
        feat: Features forwarded to the wrapped head ``self.module``.
        x: Network input tensor; ``x.shape[2:]`` is read as
            ``(height, width)`` for clamping.

    Returns:
        The ``(num_detected, proposals, scores, cls_id)`` tuple produced by
        ``self.rcnn_nms``.
    """
    img_h, img_w = x.shape[2:][0], x.shape[2:][1]
    head = self.module
    cfg = self.test_cfg
    nms_pre = cfg.get('nms_pre', -1)

    cls_scores, bbox_preds = head(feat)
    level_points = self.get_points(cls_scores, flatten=True)

    box_chunks = []
    score_chunks = []
    per_level = zip(cls_scores, bbox_preds, head.strides,
                    head.base_edge_list, level_points)
    for lvl_cls, lvl_reg, stride, base_len, (grid_y, grid_x) in per_level:
        lvl_scores = lvl_cls.permute(0, 2, 3, 1).reshape(
            lvl_cls.shape[0], -1, head.cls_out_channels).sigmoid()
        offsets = lvl_reg.permute(0, 2, 3, 1).reshape(
            lvl_reg.shape[0], -1, 4).exp()

        # Shift grid coordinates by half a cell and add a batch axis.
        grid_x = grid_x.unsqueeze(0) + 0.5
        grid_y = grid_y.unsqueeze(0) + 0.5

        if nms_pre > 0:
            # Pad with zeros so that topk(nms_pre) is always possible.
            lvl_scores = mm2trt_util.pad_with_value(
                lvl_scores, 1, nms_pre, 0.)
            offsets = mm2trt_util.pad_with_value(offsets, 1, nms_pre)
            grid_y = mm2trt_util.pad_with_value(grid_y, 1, nms_pre)
            grid_x = mm2trt_util.pad_with_value(grid_x, 1, nms_pre)

            ranking = lvl_scores.max(dim=2)[0]
            keep = ranking.topk(nms_pre, dim=1)[1]
            offsets = mm2trt_util.gather_topk(offsets, 1, keep)
            lvl_scores = mm2trt_util.gather_topk(lvl_scores, 1, keep)
            grid_y = mm2trt_util.gather_topk(grid_y, 1, keep)
            grid_x = mm2trt_util.gather_topk(grid_x, 1, keep)

        left = stride * grid_x - base_len * offsets[:, :, 0]
        top = stride * grid_y - base_len * offsets[:, :, 1]
        right = stride * grid_x + base_len * offsets[:, :, 2]
        bottom = stride * grid_y + base_len * offsets[:, :, 3]
        lvl_boxes = torch.stack([
            left.clamp(min=0, max=img_w - 1),
            top.clamp(min=0, max=img_h - 1),
            right.clamp(min=0, max=img_w - 1),
            bottom.clamp(min=0, max=img_h - 1),
        ], -1)

        box_chunks.append(lvl_boxes)
        score_chunks.append(lvl_scores)

    all_scores = torch.cat(score_chunks, dim=1)
    all_boxes = torch.cat(box_chunks, dim=1).unsqueeze(2)

    # Second, cross-level filter.
    best_scores = all_scores.max(dim=2)[0]
    keep_num = min(max(1000, nms_pre), all_scores.shape[1])
    keep = best_scores.topk(keep_num, dim=1)[1]
    all_boxes = mm2trt_util.gather_topk(all_boxes, 1, keep)
    all_scores = mm2trt_util.gather_topk(all_scores, 1, keep)

    num_detected, proposals, scores, cls_id = self.rcnn_nms(
        all_scores, all_boxes, all_boxes.shape[1],
        self.test_cfg.max_per_img)
    return num_detected, proposals, scores, cls_id