def im_detect_bbox_aug(model, images, device):
    """Detect boxes under several test-time transforms and merge them per image.

    The identity transform always runs first. A horizontal flip, extra
    scales, and flipped versions of those scales are added depending on the
    module-level ``cfg.TEST.BBOX_AUG`` settings. Every later result is
    resized to the size of the identity result for the same image, then the
    boxes and ``'scores'`` fields of all transforms are concatenated and
    handed to the ROI box post-processor's ``filter_results``.

    Args:
        model: Forwarded unchanged to the ``im_detect_bbox*`` helpers.
        images: Batch of inputs; must support ``len()``, which fixes the
            number of per-image accumulators.
        device: Forwarded unchanged to the ``im_detect_bbox*`` helpers.

    Returns:
        list: One ``post_processor.filter_results(...)`` output per image,
        in input order.
    """
    # One accumulator per image; its first entry is the identity-transform result.
    per_image = [[] for _ in range(len(images))]

    def _collect(detections):
        """Append one transform's detections to the per-image accumulators."""
        for idx, detection in enumerate(detections):
            collected = per_image[idx]
            if collected:
                # Bring every later transform to the size of the identity result.
                detection = detection.resize(collected[0].size)
            collected.append(detection)

    # Original image (identity transform).
    _collect(
        im_detect_bbox(
            model,
            images,
            cfg.INPUT.MIN_SIZE_TEST,
            cfg.INPUT.MAX_SIZE_TEST,
            device,
        )
    )

    # Horizontally flipped image.
    if cfg.TEST.BBOX_AUG.H_FLIP:
        _collect(
            im_detect_bbox_hflip(
                model,
                images,
                cfg.INPUT.MIN_SIZE_TEST,
                cfg.INPUT.MAX_SIZE_TEST,
                device,
            )
        )

    # Additional scales, each optionally with its own horizontal flip.
    for scale in cfg.TEST.BBOX_AUG.SCALES:
        max_size = cfg.TEST.BBOX_AUG.MAX_SIZE
        _collect(im_detect_bbox_scale(model, images, scale, max_size, device))
        if cfg.TEST.BBOX_AUG.SCALE_H_FLIP:
            _collect(
                im_detect_bbox_scale(
                    model, images, scale, max_size, device, hflip=True
                )
            )

    # Merge the detections gathered under the different augmentation settings.
    merged = []
    for per_transform in per_image:
        all_boxes = torch.cat([item.bbox for item in per_transform])
        all_scores = torch.cat(
            [item.get_field('scores') for item in per_transform]
        )
        # Size and mode are taken from the identity-transform result.
        reference = per_transform[0]
        combined = BoxList(all_boxes, reference.size, reference.mode)
        combined.add_field('scores', all_scores)
        merged.append(combined)

    # Filter each merged box list through the ROI box post-processor.
    post_processor = make_roi_box_post_processor(cfg)
    return [
        post_processor.filter_results(
            combined, cfg.MODEL.ROI_BOX_HEAD.NUM_CLASSES
        )
        for combined in merged
    ]
def __init__(self, cfg, in_channels):
    """Build the feature extractor, predictor and post-processor of the box head.

    Args:
        cfg: Configuration object passed to each factory; its ``WEAK``
            attribute is copied onto the instance.
        in_channels: Forwarded to ``make_roi_box_feature_extractor``.
    """
    super(ROIBoxHead, self).__init__()

    extractor = make_roi_box_feature_extractor(cfg, in_channels)
    self.feature_extractor = extractor
    # The predictor is sized from the extractor's reported output channels.
    self.predictor = FastRCNNPredictor(cfg, extractor.out_channels)
    # assert cfg.MODEL.CLS_AGNOSTIC_BBOX_REG
    self.post_processor = make_roi_box_post_processor(cfg)
    # No loss evaluator is created here; the original construction is left
    # commented out below.
    # self.loss_evaluator = make_roi_box_loss_evaluator(cfg, allow_low_quality_matches=True)
    self.WEAK = cfg.WEAK
def __init__(self, cfg):
    """Build the box head's sub-components and read the optional filter mode.

    Args:
        cfg: Configuration object passed to each factory. Its
            ``FILTERMODE`` attribute is optional; when it is absent,
            ``filter_mode`` is set to ``False``.
    """
    super(ROIBoxHead, self).__init__()

    self.feature_extractor = make_roi_box_feature_extractor(cfg)
    self.predictor = make_roi_box_predictor(cfg)
    self.post_processor = make_roi_box_post_processor(cfg)
    self.loss_evaluator = make_roi_box_loss_evaluator(cfg)

    # Configurations without a FILTERMODE entry fall back to False.
    self.filter_mode = getattr(cfg, "FILTERMODE", False)