def simple_test(self, img, img_meta, rescale=False):
    """Run detection on a single, non-augmented input.

    If the first entry of ``img_meta`` contains a ``'tile_offset'`` key the
    call is forwarded to ``self.aug_test`` with the image and its meta each
    wrapped in a one-element list and ``rescale=True``; the ``rescale``
    argument given here is not used on that path.

    Otherwise features are extracted once, ``self.bbox_head`` yields the
    initial outputs that ``filter_bboxes`` turns into rois, and every one of
    the ``self.num_refine_stages`` stages recomputes features with
    ``self.feat_refine_module[stage]`` and outputs with
    ``self.refine_head[stage]``. Rois are refined after each stage except
    the last one.

    Args:
        img: Input passed to ``self.extract_feat``.
        img_meta: Indexable collection of per-image meta mappings; only
            ``img_meta[0]`` is inspected here, the whole object is passed
            on to ``get_bboxes``.
        rescale: Forwarded to ``self.refine_head[-1].get_bboxes``.

    Returns:
        The ``rbbox2result`` output for the first ``(bboxes, labels)`` pair
        returned by ``get_bboxes``, or whatever ``self.aug_test`` returns on
        the tile-offset path.
    """
    if 'tile_offset' in img_meta[0]:
        # Tile-cropped TTA: always go through aug_test instead of simple_test.
        return self.aug_test(imgs=[img], img_metas=[img_meta], rescale=True)

    feats = self.extract_feat(img)
    head_outs = self.bbox_head(feats)
    # rois: list (indexed by images) of list (indexed by levels)
    rois = self.bbox_head.filter_bboxes(*head_outs)

    for stage in range(self.num_refine_stages):
        refined_feats = self.feat_refine_module[stage](feats, rois)
        head_outs = self.refine_head[stage](refined_feats)
        # The last stage keeps the rois it was given; only earlier stages
        # hand refined rois to their successor.
        if stage + 1 < self.num_refine_stages:
            rois = self.refine_head[stage].refine_bboxes(*head_outs, rois=rois)

    get_bboxes_args = head_outs + (img_meta, self.test_cfg, rescale)
    detections = self.refine_head[-1].get_bboxes(*get_bboxes_args, rois=rois)
    per_image_results = [
        rbbox2result(bboxes, labels, self.refine_head[-1].num_classes)
        for bboxes, labels in detections
    ]
    return per_image_results[0]
def aug_test(self, imgs, img_metas, rescale=True):
    """Run detection over tile-cropped augmentations and merge the results.

    Augmentations are processed in groups of up to eight: the image tensors
    of a group are concatenated along dim 0 and their meta lists are
    flattened into one list. For every refine stage the detections returned
    by ``self.refine_head[stage].get_bboxes`` are collected per
    augmentation, then concatenated across stages. The per-augmentation
    results of all groups are finally merged by ``merge_tiles_aug_rbboxes``
    and converted with ``rbbox2result``.

    Args:
        imgs: Sequence of image tensors, one per augmentation; ``imgs[0]``
            must have size 1 along dim 0.
        img_metas: Sequence (one entry per augmentation) of lists of meta
            mappings; ``img_metas[0][0]`` must contain ``'tile_offset'``.
        rescale: Must be truthy; it is only checked, ``get_bboxes`` is
            always called with ``False`` in the rescale position.

    Returns:
        The value returned by ``rbbox2result`` for the merged detections.

    Raises:
        AssertionError: If ``rescale`` is falsy or ``imgs[0].size(0) != 1``.
        NotImplementedError: If ``img_metas[0][0]`` has no ``'tile_offset'``.
    """
    aug_batch_size = 8
    assert rescale, (
        'while r3det uses overlapped cropping augmentation by default, '
        'the result should be rescaled to input images sizes '
        'to simplify the test pipeline')

    if 'tile_offset' not in img_metas[0][0]:
        raise NotImplementedError

    assert imgs[0].size(0) == 1, (
        'when using cropped tiles augmentation, '
        'image batch size must be set to 1')

    merged_bboxes, merged_labels = [], []
    for start in range(0, len(imgs), aug_batch_size):
        stop = start + aug_batch_size
        meta_group = img_metas[start:stop]
        group_size = len(meta_group)
        batch = torch.cat(imgs[start:stop], dim=0)
        flat_metas = sum(meta_group, [])

        feats = self.extract_feat(batch)
        head_outs = self.bbox_head(feats)
        # rois: list (indexed by images) of list (indexed by levels)
        rois = self.bbox_head.filter_bboxes(*head_outs)

        bboxes_per_aug = [[] for _ in range(group_size)]
        labels_per_aug = [[] for _ in range(group_size)]
        for stage in range(self.num_refine_stages):
            head = self.refine_head[stage]
            # self.feat_refine_module is bypassed on this path: every refine
            # head consumes the extracted features directly.
            head_outs = head(feats)
            if stage + 1 < self.num_refine_stages:
                rois = head.refine_bboxes(*head_outs, rois=rois)
            get_bboxes_args = head_outs + (flat_metas, self.test_cfg, False)
            # [(rbbox_aug0, class_aug0), (rbbox_aug1, class_aug1), ...]
            stage_dets = head.get_bboxes(*get_bboxes_args, rois=rois)
            for aug in range(group_size):
                bboxes_per_aug[aug].append(stage_dets[aug][0])
                labels_per_aug[aug].append(stage_dets[aug][1])

        # Collapse the per-stage pieces into one tensor per augmentation.
        merged_bboxes += [torch.cat(pieces) for pieces in bboxes_per_aug]
        merged_labels += [torch.cat(pieces) for pieces in labels_per_aug]

    final_bboxes, final_labels = merge_tiles_aug_rbboxes(
        merged_bboxes, merged_labels, img_metas,
        self.test_cfg.merge_cfg, self.CLASSES)
    return rbbox2result(final_bboxes, final_labels,
                        self.refine_head[-1].num_classes)
def simple_test(self, img, img_meta, proposals=None, rescale=False):
    """Test without augmentation.

    Proposals come from ``self.simple_test_rpn`` unless ``proposals`` is
    supplied. They are converted to rotated rois, RoI features are
    extracted, optionally passed through ``self.shared_head`` (once for the
    classification branch and once for the regression branch), and fed to
    ``self.bbox_head``. Detections are decoded against the ``'ori_shape'``
    and ``'scale_factor'`` entries of ``img_meta[0]``.

    Args:
        img: Input passed to ``self.extract_feat``.
        img_meta: Indexable collection of meta mappings; ``img_meta[0]``
            must provide ``'ori_shape'`` and ``'scale_factor'``.
        proposals: Precomputed proposals used in place of the RPN output
            when not ``None``.
        rescale: Not read by this method; ``get_det_rbbox2rbbox`` is always
            called with ``rescale=True``.

    Returns:
        The ``rbbox2result`` output when ``self.with_mask`` is falsy,
        otherwise ``None``.

    Raises:
        AssertionError: If ``self.with_bbox`` is falsy.
    """
    assert self.with_bbox, "Bbox head must be implemented."

    feats = self.extract_feat(img)

    if proposals is None:
        proposal_rotate_list = self.simple_test_rpn(
            feats, img_meta, self.test_cfg.rpn)
    else:
        proposal_rotate_list = proposals

    if self.with_bbox:
        rois = rbboxPoly2rroiRec(proposal_rotate_list)
        roi_feats = self.bbox_roi_extractor(
            feats[:self.bbox_roi_extractor.num_inputs], rois)
        if self.with_shared_head:
            # Both branches start from the same RoI features; the shared
            # head is applied separately for each branch.
            cls_feats = self.shared_head(roi_feats)
            reg_feats = self.shared_head(roi_feats)
        else:
            cls_feats = reg_feats = roi_feats

        cls_score, xy_pred, wh_pred, theta_pred = self.bbox_head(
            cls_feats, reg_feats)

        first_meta = img_meta[0]
        det_bboxes, det_labels = self.bbox_head.get_det_rbbox2rbbox(
            rois,
            cls_score,
            xy_pred,
            wh_pred,
            theta_pred,
            first_meta['ori_shape'],
            first_meta['scale_factor'],
            rescale=True,
            cfg=self.test_cfg.rcnn)
        rbbox_results = rbbox2result(det_bboxes, det_labels,
                                     self.bbox_head.num_classes)

    if self.with_mask:
        return None
    return rbbox_results
def aug_test(self, imgs, img_metas, rescale=False):
    """Test with augmentations.

    Proposals are produced once by ``self.aug_test_rotate_rpn``. For every
    augmentation the first proposal set (its first eight columns) is mapped
    with ``bbox_mapping`` using that augmentation's ``'img_shape'``,
    ``'scale_factor'`` and ``'flip'`` meta entries, turned into rotated
    rois, and scored by ``self.bbox_head``. The per-augmentation boxes and
    scores are merged with ``merge_aug_rotate_bboxes`` and filtered by
    ``multiclass_poly_nms_8_points`` using ``self.test_cfg.rcnn``.

    Args:
        imgs: Augmented inputs passed to ``self.extract_feats``.
        img_metas: One entry per augmentation; only element ``[0]`` of each
            entry is read here.
        rescale: Not read by this method; ``get_det_rbbox2rbbox`` is always
            called with ``rescale=False``.

    Returns:
        The ``rbbox2result`` output when ``self.with_mask`` is falsy,
        otherwise ``None``.
    """
    # recompute feats to save memory
    proposal_list = self.aug_test_rotate_rpn(
        self.extract_feats(imgs), img_metas, self.test_cfg.rpn)

    boxes_per_aug = []
    scores_per_aug = []
    for feats, aug_meta in zip(self.extract_feats(imgs), img_metas):
        # only one image in the batch
        meta = aug_meta[0]
        img_shape = meta['img_shape']
        scale_factor = meta['scale_factor']
        flip = meta['flip']

        # TODO more flexible
        mapped = bbox_mapping(proposal_list[0][:, :8], img_shape,
                              scale_factor, flip)
        rois = rbboxPoly2rroiRec([mapped])

        roi_feats = self.bbox_roi_extractor(
            feats[:self.bbox_roi_extractor.num_inputs], rois)
        if self.with_shared_head:
            # Same RoI features for both branches; the shared head is run
            # once per branch.
            cls_feats = self.shared_head(roi_feats)
            reg_feats = self.shared_head(roi_feats)
        else:
            cls_feats = reg_feats = roi_feats

        cls_score, xy_pred, wh_pred, theta_pred = self.bbox_head(
            cls_feats, reg_feats)

        bboxes, scores = self.bbox_head.get_det_rbbox2rbbox(
            rois,
            cls_score,
            xy_pred,
            wh_pred,
            theta_pred,
            img_shape,
            scale_factor,
            rescale=False,
            cfg=None)
        boxes_per_aug.append(bboxes)
        scores_per_aug.append(scores)

    rcnn_cfg = self.test_cfg.rcnn
    merged_bboxes, merged_scores = merge_aug_rotate_bboxes(
        boxes_per_aug, scores_per_aug, img_metas, rcnn_cfg)
    det_bboxes, det_labels = multiclass_poly_nms_8_points(
        merged_bboxes,
        merged_scores,
        self.test_cfg.rcnn.score_thr,
        self.test_cfg.rcnn.nms,
        max_num=self.test_cfg.rcnn.max_per_img)

    bbox_results = rbbox2result(det_bboxes, det_labels,
                                self.bbox_head.num_classes)

    # det_bboxes always keep the original scale
    if not self.with_mask:
        return bbox_results
    return None