def simple_test(self, img, img_meta, proposals=None, rescale=False):
    """Test without augmentation.

    Args:
        img: batched input image passed to ``self.extract_feat``.
        img_meta: per-image meta information, forwarded to the RPN, bbox
            and mask test helpers.
        proposals: optional precomputed proposals. When ``None``, proposals
            are produced by ``self.simple_test_rpn`` with
            ``self.test_cfg.rpn``.
        rescale: forwarded unchanged to the bbox and mask test helpers.

    Returns:
        The packed bbox results when the detector has no mask branch,
        otherwise the tuple ``(bbox_results, segm_results)``.
    """
    assert self.with_bbox, "Bbox head must be implemented."

    feats = self.extract_feat(img)
    if proposals is None:
        rpn_proposals = self.simple_test_rpn(feats, img_meta,
                                             self.test_cfg.rpn)
    else:
        rpn_proposals = proposals

    boxes, labels = self.simple_test_bboxes(
        feats, img_meta, rpn_proposals, self.test_cfg.rcnn, rescale=rescale)
    packed_bboxes = bbox2result(boxes, labels, self.bbox_head.num_classes)

    if self.with_mask:
        packed_masks = self.simple_test_mask(
            feats, img_meta, boxes, labels, rescale=rescale)
        return packed_bboxes, packed_masks
    return packed_bboxes
def aug_test(self, imgs, img_metas, rescale=False):
    """Test with augmentations.

    If rescale is False, then returned bboxes and masks will fit the scale
    of imgs[0].

    Args:
        imgs: list of augmented image batches.
        img_metas: per-augmentation meta lists; the rescaling factor is read
            from ``img_metas[0][0]['scale_factor']``.
        rescale: when False, the boxes in the returned bbox results are
            multiplied by the first augmentation's scale factor.

    Returns:
        The packed bbox results, or ``(bbox_results, segm_results)`` when a
        mask branch is present.
    """
    # Features are recomputed for every stage instead of being cached,
    # to save memory.
    merged_proposals = self.aug_test_rpn(
        self.extract_feats(imgs), img_metas, self.test_cfg.rpn)
    det_bboxes, det_labels = self.aug_test_bboxes(
        self.extract_feats(imgs), img_metas, merged_proposals,
        self.test_cfg.rcnn)

    # Only a copy is rescaled for the bbox output: det_bboxes itself always
    # keeps the original scale because it is reused by the mask branch.
    if rescale:
        out_bboxes = det_bboxes
    else:
        out_bboxes = det_bboxes.clone()
        out_bboxes[:, :4] *= img_metas[0][0]['scale_factor']
    packed_bboxes = bbox2result(out_bboxes, det_labels,
                                self.bbox_head.num_classes)

    if not self.with_mask:
        return packed_bboxes
    packed_masks = self.aug_test_mask(
        self.extract_feats(imgs), img_metas, det_bboxes, det_labels)
    return packed_bboxes, packed_masks
def simple_test(self, img, img_meta, proposals=None, rescale=False):
    """Test without augmentation.

    Detections from the bbox head are obtained unrescaled
    (``rescale=False``) and then refined by the grid head. When ``rescale``
    is set, the refined boxes are divided by ``img_meta[0]['scale_factor']``
    afterwards.

    Args:
        img: batched input image passed to ``self.extract_feat``.
        img_meta: per-image meta information list.
        proposals: optional precomputed proposals. When ``None``, proposals
            are produced by ``self.simple_test_rpn``.
        rescale: whether to divide the refined boxes by the scale factor.

    Returns:
        Packed bbox results from ``bbox2result``.
    """
    assert self.with_bbox, "Bbox head must be implemented."

    x = self.extract_feat(img)
    proposal_list = self.simple_test_rpn(
        x, img_meta, self.test_cfg.rpn) if proposals is None else proposals

    # Rescaling is deferred until after grid refinement, so boxes are
    # requested unrescaled here.
    det_bboxes, det_labels = self.simple_test_bboxes(
        x, img_meta, proposal_list, self.test_cfg.rcnn, rescale=False)

    # pack rois into bboxes
    grid_rois = bbox2roi([det_bboxes[:, :4]])
    if grid_rois.shape[0] != 0:
        # Extract grid RoI features only when there are RoIs; this avoids
        # running the RoI extractor on an empty RoI set.
        grid_feats = self.grid_roi_extractor(
            x[:len(self.grid_roi_extractor.featmap_strides)], grid_rois)
        self.grid_head.test_mode = True
        grid_pred = self.grid_head(grid_feats)
        det_bboxes = self.grid_head.get_bboxes(det_bboxes,
                                               grid_pred['fused'], img_meta)
        if rescale:
            det_bboxes[:, :4] /= img_meta[0]['scale_factor']
    else:
        # No detections: use an empty slice of det_bboxes so that the
        # placeholder keeps the same column count, dtype and device.
        det_bboxes = det_bboxes[:0]

    bbox_results = bbox2result(det_bboxes, det_labels,
                               self.bbox_head.num_classes)
    return bbox_results
def simple_test(self, img, img_meta, rescale=False):
    """Single-stage test without augmentation.

    Runs the backbone and the bbox head, decodes detections with
    ``self.bbox_head.get_bboxes``, packs every image's detections with
    ``bbox2result`` and returns the packed result of the first image.

    Args:
        img: batched input image passed to ``self.extract_feat``.
        img_meta: per-image meta information list.
        rescale: forwarded to ``self.bbox_head.get_bboxes``.

    Returns:
        The packed bbox result for the first image in the batch.
    """
    feats = self.extract_feat(img)
    head_outs = self.bbox_head(feats)
    decode_args = head_outs + (img_meta, self.test_cfg, rescale)
    per_image_dets = self.bbox_head.get_bboxes(*decode_args)
    n_cls = self.bbox_head.num_classes
    packed = [bbox2result(boxes, labels, n_cls)
              for boxes, labels in per_image_dets]
    return packed[0]
def simple_test(self, img, img_meta, proposals=None, rescale=False):
    """Test without augmentation for the multi-stage (cascade) detector.

    The RPN is run unless ``proposals`` is given. Then every cascade stage
    scores and regresses the current RoIs, and the RoIs are refined between
    stages. When ``self.test_cfg.keep_all_stages`` is set, detections (and
    masks) are also produced for each individual stage. The "ensemble"
    result is always produced:
    - bboxes come from the average of all stages' classification scores;
    - masks come from merging the mask predictions of every stage's mask
      head.

    Args:
        img: batched input image passed to ``self.extract_feat``.
        img_meta: per-image meta information list. Only ``img_meta[0]`` is
            read directly (keys ``'img_shape'``, ``'ori_shape'`` and
            ``'scale_factor'``).
        proposals: optional precomputed proposals. When ``None``, proposals
            are produced by ``self.simple_test_rpn``.
        rescale: forwarded to bbox/mask post-processing.

    Returns:
        If ``keep_all_stages`` is false: the ensemble bbox result, or
        ``(bbox_result, segm_result)`` when a mask branch is present.
        Otherwise: a dict keyed by ``'stage{i}'`` and ``'ensemble'`` with
        values of the same kind.
    """
    x = self.extract_feat(img)
    proposal_list = self.simple_test_rpn(
        x, img_meta, self.test_cfg.rpn) if proposals is None else proposals

    if self.with_semantic:
        _, semantic_feat = self.semantic_head(x)
    else:
        semantic_feat = None

    img_shape = img_meta[0]['img_shape']
    ori_shape = img_meta[0]['ori_shape']
    scale_factor = img_meta[0]['scale_factor']

    # "ms" in variable names means multi-stage
    ms_bbox_result = {}
    ms_segm_result = {}
    ms_scores = []
    rcnn_test_cfg = self.test_cfg.rcnn

    rois = bbox2roi(proposal_list)
    for i in range(self.num_stages):
        bbox_head = self.bbox_head[i]
        cls_score, bbox_pred = self._bbox_forward_test(
            i, x, rois, semantic_feat=semantic_feat)
        ms_scores.append(cls_score)

        if self.test_cfg.keep_all_stages:
            # The test config is passed as ``cfg``, matching the ensemble
            # ``get_det_bboxes`` call below (this call previously used the
            # inconsistent keyword ``nms_cfg``).
            det_bboxes, det_labels = bbox_head.get_det_bboxes(
                rois,
                cls_score,
                bbox_pred,
                img_shape,
                scale_factor,
                rescale=rescale,
                cfg=rcnn_test_cfg)
            bbox_result = bbox2result(det_bboxes, det_labels,
                                      bbox_head.num_classes)
            ms_bbox_result['stage{}'.format(i)] = bbox_result

            if self.with_mask:
                mask_head = self.mask_head[i]
                if det_bboxes.shape[0] == 0:
                    segm_result = [
                        [] for _ in range(mask_head.num_classes - 1)
                    ]
                else:
                    # When rescale is set, boxes are multiplied back by
                    # scale_factor before mask prediction (presumably
                    # returning them to the network input scale).
                    _bboxes = (det_bboxes[:, :4] * scale_factor
                               if rescale else det_bboxes)
                    mask_pred = self._mask_forward_test(
                        i, x, _bboxes, semantic_feat=semantic_feat)
                    segm_result = mask_head.get_seg_masks(
                        mask_pred, _bboxes, det_labels, rcnn_test_cfg,
                        ori_shape, scale_factor, rescale)
                ms_segm_result['stage{}'.format(i)] = segm_result

        # Refine RoIs for the next stage using the predicted class's
        # regression; the last stage's RoIs are left unchanged.
        if i < self.num_stages - 1:
            bbox_label = cls_score.argmax(dim=1)
            rois = bbox_head.regress_by_class(rois, bbox_label, bbox_pred,
                                              img_meta[0])

    # Ensemble: average the classification scores of all stages. The box
    # regression is the last stage's bbox_pred, applied to the RoIs that
    # were fed to the last stage.
    cls_score = sum(ms_scores) / float(len(ms_scores))
    det_bboxes, det_labels = self.bbox_head[-1].get_det_bboxes(
        rois,
        cls_score,
        bbox_pred,
        img_shape,
        scale_factor,
        rescale=rescale,
        cfg=rcnn_test_cfg)
    bbox_result = bbox2result(det_bboxes, det_labels,
                              self.bbox_head[-1].num_classes)
    ms_bbox_result['ensemble'] = bbox_result

    if self.with_mask:
        if det_bboxes.shape[0] == 0:
            segm_result = [
                [] for _ in range(self.mask_head[-1].num_classes - 1)
            ]
        else:
            _bboxes = (det_bboxes[:, :4] * scale_factor
                       if rescale else det_bboxes)
            mask_rois = bbox2roi([_bboxes])
            aug_masks = []
            # RoI features for the mask heads are extracted once, using the
            # last stage's mask RoI extractor.
            mask_roi_extractor = self.mask_roi_extractor[-1]
            mask_feats = mask_roi_extractor(
                x[:len(mask_roi_extractor.featmap_strides)], mask_rois)
            if self.with_semantic and 'mask' in self.semantic_fusion:
                mask_semantic_feat = self.semantic_roi_extractor(
                    [semantic_feat], mask_rois)
                mask_feats += mask_semantic_feat
            # Run every stage's mask head on the same features. With
            # mask_info_flow, each head also receives the previous head's
            # returned feature.
            last_feat = None
            for i in range(self.num_stages):
                mask_head = self.mask_head[i]
                if self.mask_info_flow:
                    mask_pred, last_feat = mask_head(mask_feats, last_feat)
                else:
                    mask_pred = mask_head(mask_feats)
                aug_masks.append(mask_pred.sigmoid().cpu().numpy())
            merged_masks = merge_aug_masks(aug_masks,
                                           [img_meta] * self.num_stages,
                                           self.test_cfg.rcnn)
            segm_result = self.mask_head[-1].get_seg_masks(
                merged_masks, _bboxes, det_labels, rcnn_test_cfg,
                ori_shape, scale_factor, rescale)
        ms_segm_result['ensemble'] = segm_result

    if not self.test_cfg.keep_all_stages:
        if self.with_mask:
            results = (ms_bbox_result['ensemble'],
                       ms_segm_result['ensemble'])
        else:
            results = ms_bbox_result['ensemble']
    else:
        if self.with_mask:
            results = {
                stage: (ms_bbox_result[stage], ms_segm_result[stage])
                for stage in ms_bbox_result
            }
        else:
            results = ms_bbox_result

    return results