def nms_detections(pred_boxes, scores, nms_thresh, inds=None):
    """Run non-maximum suppression and return the surviving detections.

    Each box is paired with its score as an extra trailing column, the
    combined rows are handed to ``nms``, and the indices it returns are used
    to select rows from the inputs.

    Args:
        pred_boxes: 2-D tensor with one box per row (presumably ``(N, 4)``
            corner coordinates -- confirm against ``nms``).
        scores: 1-D tensor holding one score per row of ``pred_boxes``.
        nms_thresh: Threshold forwarded unchanged to ``nms``.
        inds: Optional tensor aligned row-for-row with ``pred_boxes``; when
            given, it is filtered with the same kept indices.

    Returns:
        ``(boxes, scores)`` when ``inds`` is ``None``, otherwise
        ``(boxes, scores, inds)``, each restricted to the kept rows.
    """
    # Append the score as one extra column so every row is "box + score".
    dets = torch.cat((pred_boxes, scores.unsqueeze(1)), 1)
    keep = nms(dets, nms_thresh).long().view(-1)
    if inds is None:
        return pred_boxes[keep], scores[keep]
    return pred_boxes[keep], scores[keep], inds[keep]
    def forward(self, input):
        """Convert RPN outputs into a fixed-size batch of region proposals.

        For every feature-map cell the base anchors are shifted onto the cell,
        the predicted box deltas are applied, and the result is clipped to the
        image. Per image, the boxes are then ranked by foreground score, the
        best ``pre_nms_topN`` are kept, NMS is applied, and at most
        ``post_nms_topN`` survivors are written to the output. Filtering out
        boxes smaller than ``RPN_MIN_SIZE`` is not performed here.

        Args:
            input: Indexable holding four items:
                ``input[0]`` -- class probabilities; the channels from
                ``self._num_anchors`` onward are used as foreground scores.
                ``input[1]`` -- predicted box deltas, laid out so that
                ``permute(0, 2, 3, 1)`` followed by ``view(batch, -1, 4)``
                lines them up with the generated anchors.
                ``input[2]`` -- ``im_info``, forwarded to ``clip_boxes``.
                ``input[3]`` -- key into ``cfg`` selecting the RPN settings
                (e.g. ``'TRAIN'`` or ``'TEST'``).

        Returns:
            Tensor of shape ``(batch_size, post_nms_topN, 5)`` allocated with
            ``scores.new``. Column 0 of every row holds the batch index;
            columns 1-4 hold the kept proposals, and rows past the number of
            kept proposals are left at zero in those columns.
        """
        # The first _num_anchors channels are background probabilities and
        # the second set are foreground probabilities; only the latter rank
        # the proposals.
        scores = input[0][:, self._num_anchors:, :, :]
        bbox_deltas = input[1]
        im_info = input[2]
        cfg_key = input[3]

        pre_nms_topN = cfg[cfg_key].RPN_PRE_NMS_TOP_N
        post_nms_topN = cfg[cfg_key].RPN_POST_NMS_TOP_N
        nms_thresh = cfg[cfg_key].RPN_NMS_THRESH

        batch_size = bbox_deltas.size(0)

        # Build one (x, y, x, y) offset per feature-map cell, scaled by the
        # feature stride; the result has one row per cell.
        feat_height, feat_width = scores.size(2), scores.size(3)
        shift_x = np.arange(0, feat_width) * self._feat_stride
        shift_y = np.arange(0, feat_height) * self._feat_stride
        shift_x, shift_y = np.meshgrid(shift_x, shift_y)
        shifts = torch.from_numpy(
            np.vstack((shift_x.ravel(), shift_y.ravel(),
                       shift_x.ravel(), shift_y.ravel())).transpose())
        shifts = shifts.contiguous().type_as(scores).float()

        A = self._num_anchors
        K = shifts.size(0)

        # Broadcast (1, A, 4) base anchors against (K, 1, 4) shifts to get all
        # K * A anchors, then share them across the batch without copying.
        self._anchors = self._anchors.type_as(scores)
        anchors = self._anchors.view(1, A, 4) + shifts.view(K, 1, 4)
        anchors = anchors.view(1, K * A, 4).expand(batch_size, K * A, 4)

        # Transpose and reshape the predicted deltas and the scores so they
        # follow the same ordering as the anchors.
        bbox_deltas = bbox_deltas.permute(0, 2, 3, 1).contiguous()
        bbox_deltas = bbox_deltas.view(batch_size, -1, 4)

        scores = scores.permute(0, 2, 3, 1).contiguous()
        scores = scores.view(batch_size, -1)

        # Convert anchors into proposals via the bbox transformations, then
        # clip the predicted boxes to the image.
        proposals = bbox_transform_inv(anchors, bbox_deltas, batch_size)
        proposals = clip_boxes(proposals, im_info, batch_size)

        scores_keep = scores
        proposals_keep = proposals
        # Sort every image's scores from highest to lowest.
        _, order = torch.sort(scores_keep, 1, True)

        blob = scores.new(batch_size, post_nms_topN, 5).zero_()
        for i in range(batch_size):
            proposals_single = proposals_keep[i]
            scores_single = scores_keep[i]

            # Take the top pre_nms_topN proposals before NMS.
            order_single = order[i]
            if pre_nms_topN > 0 and pre_nms_topN < scores_keep.numel():
                order_single = order_single[:pre_nms_topN]

            proposals_single = proposals_single[order_single, :]
            scores_single = scores_single[order_single].view(-1, 1)

            # Apply NMS on "box + score" rows, then keep at most
            # post_nms_topN of the survivors.
            dets_single = torch.cat((proposals_single, scores_single), 1)
            keep_idx_i = nms(dets_single, nms_thresh,
                             force_cpu=not cfg.USE_GPU_NMS)
            keep_idx_i = keep_idx_i.long().view(-1)

            if post_nms_topN > 0:
                keep_idx_i = keep_idx_i[:post_nms_topN]
            proposals_single = proposals_single[keep_idx_i, :]
            scores_single = scores_single[keep_idx_i, :]

            # Tag every row with the batch index; unused trailing rows keep
            # their zero padding in the box columns.
            num_proposal = proposals_single.size(0)
            blob[i, :, 0] = i
            blob[i, :num_proposal, 1:] = proposals_single

        return blob
            if vis:
                im = cv2.imread(imdb.image_path_at(i))
                im2show = np.copy(im)
            # jth class
            for j in range(imdb.num_classes):
                inds = torch.nonzero(scores[:, j] > thresh).view(-1)
                if inds.numel() > 0:
                    cls_scores = scores[:, j][inds]
                    _, order = torch.sort(cls_scores, 0, True)
                    cls_boxes = pred_boxes[inds][:, j * 4:(j + 1) * 4]
                    # N x 5
                    cls_dets =, cls_scores.unsqueeze(1)),
                    cls_dets = cls_dets[order]
                    keep = nms(cls_dets, cfg.TEST.NMS)
                    cls_dets = cls_dets[keep.view(-1).long()]
                    if vis and j != 0:
                        im2show = vis_detections(im2show, imdb.classes[j],
                                                 cls_dets.cpu().numpy(), 0.3)
                    # num_class , num_images
                    all_boxes[j][i] = cls_dets.cpu().numpy()
                    all_boxes[j][i] = empty_array

            # limit to max_object detections over all classes
            if max_object > 0:
                image_scores = np.hstack(
                    [all_boxes[j][i][:, -1] for j in range(imdb.num_classes)])
                if len(image_scores) > max_object:
                    image_thresh = np.sort(image_scores)[-max_object]
def test():
    """Run the detector over a test set, save the detections and evaluate.

    Reads options from ``parse_args()``, builds the dataset named by
    ``args.dataset``, runs ``FasterRCNN`` on one item per iteration, applies
    per-class NMS to the predicted boxes, writes ``all_boxes`` to
    ``detections.pkl`` inside ``args.output_dir`` and finally calls
    ``imdb.evaluate_detections``.

    NOTE(review): several statements in this function are truncated (empty
    right-hand sides, unclosed calls, ``if`` headers without a body), so the
    function does not parse as it stands. Each spot is flagged inline.
    """
    args = parse_args()

    # prepare data
    print('load data')
    if args.dataset == 'voc07test':
        dataset_name = 'voc_2007_test'
    # NOTE(review): as written the raise below runs inside the 'voc12test'
    # branch and no branch handles other dataset names (dataset_name would be
    # unbound); an `else:` line looks to be missing before the raise.
    elif args.dataset == 'voc12test':
        dataset_name = 'voc_2012_test'
        raise NotImplementedError

    # Mutates the global config before the roidb is built.
    cfg.TRAIN.USE_FLIPPED = False

    imdb, roidb = combined_roidb(dataset_name)
    test_dataset = RoiDataset(roidb)
    # NOTE(review): the DataLoader call below is never closed; its remaining
    # arguments appear to be missing.
    test_dataloader = DataLoader(dataset=test_dataset,
    test_data_iter = iter(test_dataloader)

    # load model
    model = FasterRCNN(backbone=args.backbone)
    model_name = '0712_faster_rcnn101_epoch_{}.pth'.format(args.check_epoch)
    # NOTE(review): this `if` has no body -- presumably it should create
    # args.output_dir; TODO confirm.
    if not os.path.exists(args.output_dir):
    # NOTE(review): model_path is computed but nothing visible in this
    # function loads a checkpoint from it into `model`.
    model_path = os.path.join(args.output_dir, model_name)

    if args.use_gpu:
        model = model.cuda()


    num_images = len(imdb.image_index)
    det_file_path = os.path.join(args.output_dir, 'detections.pkl')

    # all_boxes[class_index][image_index] holds that class's detections for
    # that image.
    all_boxes = [[[] for _ in range(num_images)]
                 for _ in range(imdb.num_classes)]

    # Zero-row, five-column array used as the "no detections" placeholder.
    empty_array = np.transpose(np.array([[], [], [], [], []]), (1, 0))


    for i in range(num_images):
        # NOTE(review): start_time is reassigned on every iteration, so the
        # "test time" printed at the end only spans from the start of the last
        # iteration; it is also unbound when num_images is 0.
        start_time = time.time()
        im_data, gt_boxes, im_info = next(test_data_iter)
        if args.use_gpu:
            im_data = im_data.cuda()
            gt_boxes = gt_boxes.cuda()
            im_info = im_info.cuda()

        im_data_variable = Variable(im_data)

        det_tic = time.time()
        # The model returns eight values; only the first three are used here.
        rois, faster_rcnn_cls_prob, faster_rcnn_reg, _, _, _, _, _ = model(
            im_data_variable, gt_boxes, im_info)

        # NOTE(review): the right-hand sides of the next three assignments are
        # missing -- presumably derived from faster_rcnn_cls_prob, rois and
        # faster_rcnn_reg respectively; TODO confirm.
        scores =
        boxes =[:, 1:]
        boxes_deltas =

        # NOTE(review): the block below is indented one level deeper with no
        # visible guard; a conditional header appears to be missing. It scales
        # the deltas by cfg.TRAIN.BBOX_NORMALIZE_STDS and adds
        # cfg.TRAIN.BBOX_NORMALIZE_MEANS, and calls .cuda() regardless of
        # args.use_gpu.
            boxes_deltas = boxes_deltas.view(-1, 4) * torch.FloatTensor(cfg.TRAIN.BBOX_NORMALIZE_STDS).cuda() \
                         + torch.FloatTensor(cfg.TRAIN.BBOX_NORMALIZE_MEANS).cuda()
            boxes_deltas = boxes_deltas.view(-1, 4 * imdb.num_classes)

        pred_boxes = bbox_transform_inv_cls(boxes, boxes_deltas)

        pred_boxes = clip_boxes_cls(pred_boxes, im_info[0])

        # Presumably im_info[0][2] is the image resize scale, so this maps the
        # boxes back to original-image coordinates -- TODO confirm.
        pred_boxes /= im_info[0][2].item()

        det_toc = time.time()
        # NOTE(review): this is start minus end, so detect_time is never
        # positive; the operands look reversed.
        detect_time = det_tic - det_toc
        nms_tic = time.time()

        if args.vis:
            # NOTE(review): right-hand side missing.
            im_show =

        # Class index 0 is skipped (presumably the background class).
        for j in range(1, imdb.num_classes):
            # Indices of the rows whose class-j score exceeds args.thresh.
            inds = torch.nonzero(scores[:, j] > args.thresh).view(-1)

            if inds.numel() > 0:
                cls_score = scores[:, j][inds]
                _, order = torch.sort(cls_score, 0, True)
                # Columns j*4 .. j*4+3 hold the box predicted for class j.
                cls_boxes = pred_boxes[inds][:, j * 4:(j + 1) * 4]
                # NOTE(review): the start of this expression is missing --
                # presumably cls_boxes concatenated with the scores along
                # dim 1; TODO confirm.
                cls_dets =, cls_score.unsqueeze(1)), 1)
                cls_dets = cls_dets[order]
                # The NMS threshold is hard-coded to 0.3 here.
                keep = nms(cls_dets, 0.3)
                cls_dets = cls_dets[keep.view(-1).long()]
                if args.vis:
                    cls_name_dets = np.repeat(j, cls_dets.size(0))
                    # NOTE(review): the call below is never closed; its
                    # remaining arguments are missing.
                    im_show = draw_detection_boxes(im_show,
                                                   cls_name_dets, imdb.classes,
                # NOTE(review): as written the second assignment always
                # overwrites the first with empty_array; an `else:` for the
                # `inds.numel() > 0` test looks to be missing.
                all_boxes[j][i] = cls_dets.cpu().numpy()
                all_boxes[j][i] = empty_array

        # Cap the number of detections per image: find the score of the
        # max_per_image-th best detection across classes and drop everything
        # scoring below it (ties at the threshold are all kept).
        if args.max_per_image > 0:
            image_scores = np.hstack(
                [all_boxes[j][i][:, -1] for j in range(1, imdb.num_classes)])
            if len(image_scores) > args.max_per_image:
                image_thresh = np.sort(image_scores)[-args.max_per_image]
                for j in range(1, imdb.num_classes):
                    keep = np.where(all_boxes[j][i][:, -1] >= image_thresh)[0]
                    all_boxes[j][i] = all_boxes[j][i][keep, :]

        # NOTE(review): this `if` has no body; the visualisation statements
        # appear to be missing.
        if args.vis:
        nms_toc = time.time()
        # NOTE(review): start minus end again, so nms_time is never positive.
        nms_time = nms_tic - nms_toc
        # The trailing '\r' rewrites the same console line for each image.
        sys.stdout.write('im_detect: {:d}/{:d} {:.3f}s {:.3f}s   \r' \
                         .format(i + 1, num_images, detect_time, nms_time))

    # Persist the raw detections before running the evaluation.
    with open(det_file_path, 'wb') as f:
        pickle.dump(all_boxes, f, pickle.HIGHEST_PROTOCOL)

    print('Evaluating detections')
    imdb.evaluate_detections(all_boxes, args.output_dir)

    end_time = time.time()
    print("test time: %0.4fs" % (end_time - start_time))