def nms_detections(pred_boxes, scores, nms_thresh, inds=None):
    # Stack boxes and scores into an (N, 5) dets tensor, run NMS, and
    # index every input by the surviving rows.
    dets = torch.cat((pred_boxes, scores.unsqueeze(1)), 1)
    keep = nms(dets, nms_thresh).long().view(-1)
    if inds is None:
        return pred_boxes[keep], scores[keep]
    return pred_boxes[keep], scores[keep], inds[keep]
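# --- illustrative only, not part of the repo ---
# A minimal usage sketch for nms_detections, assuming this file's
# nms(dets, thresh) convention where dets is an (N, 5) tensor of
# (x1, y1, x2, y2, score). The data and helper name are hypothetical.
def _demo_nms_detections():
    # two heavily overlapping boxes and one distant box
    boxes = torch.tensor([[10., 10., 50., 50.],
                          [12., 12., 52., 52.],
                          [100., 100., 150., 150.]])
    scores = torch.tensor([0.9, 0.8, 0.7])
    # with IoU threshold 0.5 the two overlapping boxes collapse to one
    kept_boxes, kept_scores = nms_detections(boxes, scores, nms_thresh=0.5)
    print(kept_boxes.shape, kept_scores.shape)  # expected: (2, 4), (2,)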
def forward(self, input):
    # Algorithm:
    #
    # for each (H, W) location i
    #   generate A anchor boxes centered on cell i
    #   apply predicted bbox deltas at cell i to each of the A anchors
    # clip predicted boxes to image
    # remove predicted boxes with either height or width < threshold
    # sort all (proposal, score) pairs by score from highest to lowest
    # take top pre_nms_topN proposals before NMS
    # apply NMS with threshold 0.7 to remaining proposals
    # take post_nms_topN proposals after NMS
    # return the top proposals (-> RoIs top, scores top)

    # the first set of _num_anchors channels are bg probs,
    # the second set are the fg probs
    scores = input[0][:, self._num_anchors:, :, :]
    bbox_deltas = input[1]
    im_info = input[2]
    cfg_key = input[3]  # either 'TRAIN' or 'TEST'

    pre_nms_topN = cfg[cfg_key].RPN_PRE_NMS_TOP_N
    post_nms_topN = cfg[cfg_key].RPN_POST_NMS_TOP_N
    nms_thresh = cfg[cfg_key].RPN_NMS_THRESH
    min_size = cfg[cfg_key].RPN_MIN_SIZE

    batch_size = bbox_deltas.size(0)

    # enumerate all shifts of the anchor set over the feature map grid
    feat_height, feat_width = scores.size(2), scores.size(3)
    shift_x = np.arange(0, feat_width) * self._feat_stride
    shift_y = np.arange(0, feat_height) * self._feat_stride
    shift_x, shift_y = np.meshgrid(shift_x, shift_y)
    shifts = torch.from_numpy(
        np.vstack((shift_x.ravel(), shift_y.ravel(),
                   shift_x.ravel(), shift_y.ravel())).transpose())
    shifts = shifts.contiguous().type_as(scores).float()

    A = self._num_anchors
    K = shifts.size(0)

    self._anchors = self._anchors.type_as(scores)
    anchors = self._anchors.view(1, A, 4) + shifts.view(K, 1, 4)
    anchors = anchors.view(1, K * A, 4).expand(batch_size, K * A, 4)

    # Transpose and reshape predicted bbox transformations to get them
    # into the same order as the anchors:
    bbox_deltas = bbox_deltas.permute(0, 2, 3, 1).contiguous()
    bbox_deltas = bbox_deltas.view(batch_size, -1, 4)

    # Same story for the scores:
    scores = scores.permute(0, 2, 3, 1).contiguous()
    scores = scores.view(batch_size, -1)

    # 1. convert anchors into proposals via bbox transformations
    proposals = bbox_transform_inv(anchors, bbox_deltas, batch_size)

    # 2. clip predicted boxes to image
    proposals = clip_boxes(proposals, im_info, batch_size)

    # NOTE: step 3 (removing boxes smaller than min_size, rescaled by the
    # image scale in im_info[:, 2]) is skipped in this batched
    # implementation, because it would leave a ragged number of proposals
    # per batch element.
    scores_keep = scores
    proposals_keep = proposals
    _, order = torch.sort(scores_keep, 1, True)

    blob = scores.new(batch_size, post_nms_topN, 5).zero_()
    for i in range(batch_size):
        proposals_single = proposals_keep[i]
        scores_single = scores_keep[i]

        # 4. sort all (proposal, score) pairs by score from highest to lowest
        # 5. take top pre_nms_topN (e.g. 6000)
        order_single = order[i]
        if 0 < pre_nms_topN < scores_keep.numel():
            order_single = order_single[:pre_nms_topN]

        proposals_single = proposals_single[order_single, :]
        scores_single = scores_single[order_single].view(-1, 1)

        # 6. apply nms (e.g. threshold = 0.7)
        # 7. take post_nms_topN (e.g. 300)
        # 8. return the top proposals (-> RoIs top)
        keep_idx_i = nms(torch.cat((proposals_single, scores_single), 1),
                         nms_thresh, force_cpu=not cfg.USE_GPU_NMS)
        keep_idx_i = keep_idx_i.long().view(-1)

        if post_nms_topN > 0:
            keep_idx_i = keep_idx_i[:post_nms_topN]
        proposals_single = proposals_single[keep_idx_i, :]
        scores_single = scores_single[keep_idx_i, :]

        # pad with zeros at the end; column 0 carries the batch index
        num_proposal = proposals_single.size(0)
        blob[i, :, 0] = i
        blob[i, :num_proposal, 1:] = proposals_single

    return blob
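# --- illustrative only, not part of the repo ---
# For reference, a minimal single-image sketch of the delta decoding that
# bbox_transform_inv performs above, using the standard Faster R-CNN
# parameterization (the repo's batched version adds a batch dimension):
def bbox_transform_inv_sketch(boxes, deltas):
    # boxes: (N, 4) anchors as (x1, y1, x2, y2); deltas: (N, 4) as (dx, dy, dw, dh)
    widths = boxes[:, 2] - boxes[:, 0] + 1.0
    heights = boxes[:, 3] - boxes[:, 1] + 1.0
    ctr_x = boxes[:, 0] + 0.5 * widths
    ctr_y = boxes[:, 1] + 0.5 * heights

    dx, dy, dw, dh = deltas[:, 0], deltas[:, 1], deltas[:, 2], deltas[:, 3]

    # shift the anchor center, rescale its width/height
    pred_ctr_x = dx * widths + ctr_x
    pred_ctr_y = dy * heights + ctr_y
    pred_w = torch.exp(dw) * widths
    pred_h = torch.exp(dh) * heights

    # convert back to corner coordinates
    return torch.stack((pred_ctr_x - 0.5 * pred_w,
                        pred_ctr_y - 0.5 * pred_h,
                        pred_ctr_x + 0.5 * pred_w,
                        pred_ctr_y + 0.5 * pred_h), dim=1)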
if vis:
    im = cv2.imread(imdb.image_path_at(i))
    im2show = np.copy(im)

# per-class NMS: j-th class
for j in range(imdb.num_classes):
    inds = torch.nonzero(scores[:, j] > thresh).view(-1)
    if inds.numel() > 0:
        cls_scores = scores[:, j][inds]
        _, order = torch.sort(cls_scores, 0, True)
        cls_boxes = pred_boxes[inds][:, j * 4:(j + 1) * 4]

        # N x 5 dets: (x1, y1, x2, y2, score), sorted by score
        cls_dets = torch.cat((cls_boxes, cls_scores.unsqueeze(1)), 1)
        cls_dets = cls_dets[order]
        keep = nms(cls_dets, cfg.TEST.NMS)
        cls_dets = cls_dets[keep.view(-1).long()]
        if vis and j != 0:
            im2show = vis_detections(im2show, imdb.classes[j],
                                     cls_dets.cpu().numpy(), 0.3)
        # all_boxes is indexed [num_classes][num_images]
        all_boxes[j][i] = cls_dets.cpu().numpy()
    else:
        all_boxes[j][i] = empty_array

# limit to max_object detections over all classes
if max_object > 0:
    image_scores = np.hstack(
        [all_boxes[j][i][:, -1] for j in range(imdb.num_classes)])
    if len(image_scores) > max_object:
        # the max_object-th highest score becomes the per-image threshold
        image_thresh = np.sort(image_scores)[-max_object]
        for j in range(imdb.num_classes):
            keep = np.where(all_boxes[j][i][:, -1] >= image_thresh)[0]
            all_boxes[j][i] = all_boxes[j][i][keep, :]
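# --- illustrative only, not part of the repo ---
# The nms calls above delegate to the repo's compiled kernel; for reference,
# a plain-numpy sketch of the greedy algorithm it implements (assuming
# numpy is imported as np, as elsewhere in this file):
def nms_sketch(dets, thresh):
    # dets: (N, 5) array of (x1, y1, x2, y2, score)
    x1, y1, x2, y2, scores = dets[:, 0], dets[:, 1], dets[:, 2], dets[:, 3], dets[:, 4]
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = scores.argsort()[::-1]
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        # IoU of the highest-scoring remaining box with the rest
        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])
        w = np.maximum(0.0, xx2 - xx1 + 1)
        h = np.maximum(0.0, yy2 - yy1 + 1)
        inter = w * h
        iou = inter / (areas[i] + areas[order[1:]] - inter)
        # drop everything that overlaps the kept box too much
        order = order[1:][iou <= thresh]
    return keep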
def test():
    args = parse_args()

    # prepare data
    print('load data')
    if args.dataset == 'voc07test':
        dataset_name = 'voc_2007_test'
    elif args.dataset == 'voc12test':
        dataset_name = 'voc_2012_test'
    else:
        raise NotImplementedError

    cfg.TRAIN.USE_FLIPPED = False
    imdb, roidb = combined_roidb(dataset_name)
    test_dataset = RoiDataset(roidb)
    test_dataloader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)
    test_data_iter = iter(test_dataloader)

    # load model
    model = FasterRCNN(backbone=args.backbone)
    model_name = '0712_faster_rcnn101_epoch_{}.pth'.format(args.check_epoch)
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    model_path = os.path.join(args.output_dir, model_name)
    model.load_state_dict(torch.load(model_path)['model'])
    if args.use_gpu:
        model = model.cuda()
    model.eval()

    num_images = len(imdb.image_index)
    det_file_path = os.path.join(args.output_dir, 'detections.pkl')

    # all_boxes is indexed [class][image] and holds (N, 5) arrays
    all_boxes = [[[] for _ in range(num_images)]
                 for _ in range(imdb.num_classes)]
    empty_array = np.transpose(np.array([[], [], [], [], []]), (1, 0))

    torch.set_grad_enabled(False)
    # start the clock before the loop so the final print covers the whole run
    start_time = time.time()
    for i in range(num_images):
        im_data, gt_boxes, im_info = next(test_data_iter)
        if args.use_gpu:
            im_data = im_data.cuda()
            gt_boxes = gt_boxes.cuda()
            im_info = im_info.cuda()
        im_data_variable = Variable(im_data)

        det_tic = time.time()
        rois, faster_rcnn_cls_prob, faster_rcnn_reg, _, _, _, _, _ = model(
            im_data_variable, gt_boxes, im_info)

        scores = faster_rcnn_cls_prob.data
        boxes = rois.data[:, 1:]
        boxes_deltas = faster_rcnn_reg.data

        if cfg.TRAIN.BBOX_NORMALIZE_TARGETS_PRECOMPUTED:
            # unnormalize the regression targets
            boxes_deltas = boxes_deltas.view(-1, 4) * torch.FloatTensor(cfg.TRAIN.BBOX_NORMALIZE_STDS).cuda() \
                           + torch.FloatTensor(cfg.TRAIN.BBOX_NORMALIZE_MEANS).cuda()
            boxes_deltas = boxes_deltas.view(-1, 4 * imdb.num_classes)

        pred_boxes = bbox_transform_inv_cls(boxes, boxes_deltas)
        pred_boxes = clip_boxes_cls(pred_boxes, im_info[0])
        # rescale boxes back to the original image size
        pred_boxes /= im_info[0][2].item()
        det_toc = time.time()
        detect_time = det_toc - det_tic

        nms_tic = time.time()
        if args.vis:
            im_show = Image.open(imdb.image_path_at(i))

        # skip j = 0, the background class
        for j in range(1, imdb.num_classes):
            inds = torch.nonzero(scores[:, j] > args.thresh).view(-1)
            if inds.numel() > 0:
                cls_score = scores[:, j][inds]
                _, order = torch.sort(cls_score, 0, True)
                cls_boxes = pred_boxes[inds][:, j * 4:(j + 1) * 4]
                cls_dets = torch.cat((cls_boxes, cls_score.unsqueeze(1)), 1)
                cls_dets = cls_dets[order]
                keep = nms(cls_dets, 0.3)
                cls_dets = cls_dets[keep.view(-1).long()]
                if args.vis:
                    cls_name_dets = np.repeat(j, cls_dets.size(0))
                    im_show = draw_detection_boxes(im_show, cls_dets.cpu().numpy(),
                                                   cls_name_dets, imdb.classes, 0.5)
                all_boxes[j][i] = cls_dets.cpu().numpy()
            else:
                all_boxes[j][i] = empty_array

        # limit to max_per_image detections over all classes
        if args.max_per_image > 0:
            image_scores = np.hstack(
                [all_boxes[j][i][:, -1] for j in range(1, imdb.num_classes)])
            if len(image_scores) > args.max_per_image:
                image_thresh = np.sort(image_scores)[-args.max_per_image]
                for j in range(1, imdb.num_classes):
                    keep = np.where(all_boxes[j][i][:, -1] >= image_thresh)[0]
                    all_boxes[j][i] = all_boxes[j][i][keep, :]

        if args.vis:
            plt.imshow(im_show)
            plt.show()
        nms_toc = time.time()
        nms_time = nms_toc - nms_tic

        sys.stdout.write('im_detect: {:d}/{:d} {:.3f}s {:.3f}s \r'
                         .format(i + 1, num_images, detect_time, nms_time))
        sys.stdout.flush()

    with open(det_file_path, 'wb') as f:
        pickle.dump(all_boxes, f, pickle.HIGHEST_PROTOCOL)

    print('Evaluating detections')
    imdb.evaluate_detections(all_boxes, args.output_dir)

    end_time = time.time()
    print("test time: %0.4fs" % (end_time - start_time))
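# --- illustrative only, not verified against the repo ---
# Assuming parse_args() exposes the args.* attributes used above as
# command-line flags (a guess from the attribute names, not confirmed),
# an evaluation run might look like:
#
#   python test.py --dataset voc07test --backbone res101 \
#                  --check_epoch 12 --output_dir output --use_gpu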