コード例 #1
0
def main(args):
    """Check that a served model returns the same results as local inference.

    The image ``args.img`` is run through a locally built model and is also
    posted to ``http://<args.inference_addr>/predictions/<args.model_name>``.
    Both results are displayed and then compared key by key.

    Args:
        args: Parsed arguments providing ``config``, ``checkpoint``,
            ``device``, ``img``, ``score_thr``, ``inference_addr`` and
            ``model_name``.

    Raises:
        AssertionError: If the two results differ in keys, number of entries
            or values.
        TypeError: If an entry is neither a numeric nor a string sequence.
        requests.HTTPError: If the server replies with an error status.
    """
    # build the model from a config file and a checkpoint file
    model = init_detector(args.config, args.checkpoint, device=args.device)
    # test a single image
    model_results = model_inference(model, args.img)
    model.show_result(
        args.img,
        model_results,
        win_name='model_results',
        show=True,
        score_thr=args.score_thr)

    url = 'http://' + args.inference_addr + '/predictions/' + args.model_name
    with open(args.img, 'rb') as image:
        response = requests.post(url, image)
    # Fail on an HTTP error here instead of trying to interpret an error
    # payload as a prediction further down.
    response.raise_for_status()
    serve_results = response.json()
    model.show_result(
        args.img,
        serve_results,
        show=True,
        win_name='serve_results',
        score_thr=args.score_thr)

    assert serve_results.keys() == model_results.keys()
    for key in serve_results.keys():
        # zip() stops at the shorter sequence, so a missing or extra entry
        # would otherwise go unnoticed.
        assert len(model_results[key]) == len(serve_results[key]), (
            f'Result count mismatch for "{key}": '
            f'{len(model_results[key])} local vs '
            f'{len(serve_results[key])} served')
        for model_result, serve_result in zip(model_results[key],
                                              serve_results[key]):
            if isinstance(model_result[0], (int, float)):
                assert np.allclose(model_result, serve_result)
            elif isinstance(model_result[0], str):
                assert model_result == serve_result
            else:
                raise TypeError(
                    f'Unsupported result element type for "{key}": '
                    f'{type(model_result[0]).__name__}')
コード例 #2
0
ファイル: webcam_demo.py プロジェクト: xyzhu8/mmocr
def main():
    """Run the model on frames from a camera and display the results live.

    Frames are read from the camera selected by ``args.camera_id`` until the
    user presses Esc, ``q`` or ``Q``, or until a frame can no longer be read.
    The camera handle is released on exit, including when inference raises.
    """
    args = parse_args()

    device = torch.device(args.device)

    model = init_detector(args.config, args.checkpoint, device=device)
    # For a ConcatDataset test config, reuse the pipeline of its first dataset.
    if model.cfg.data.test['type'] == 'ConcatDataset':
        model.cfg.data.test.pipeline = model.cfg.data.test['datasets'][
            0].pipeline

    camera = cv2.VideoCapture(args.camera_id)

    print('Press "Esc", "q" or "Q" to exit.')
    try:
        while True:
            ret_val, img = camera.read()
            if not ret_val:
                # No frame was grabbed (camera missing, unplugged or stream
                # ended); there is no valid image to run inference on.
                print('Failed to read a frame from the camera, exiting.')
                break
            result = model_inference(model, img)

            ch = cv2.waitKey(1)
            if ch == 27 or ch == ord('q') or ch == ord('Q'):
                break

            model.show_result(img,
                              result,
                              score_thr=args.score_thr,
                              wait_time=1,
                              show=True)
    finally:
        # Always give the capture device back to the system.
        camera.release()
コード例 #3
0
ファイル: det_test_imgs.py プロジェクト: Pandinosaurus/mmocr
def main():
    """Run inference on every image in a list file and save the results.

    For each entry of ``img_list`` (a path relative to ``img_root``) the
    results are written via ``save_results`` into ``<out-dir>/out_txt_dir``
    and a visualization is written into ``<out-dir>/out_vis_dir``.

    Raises:
        FileNotFoundError: If a listed image does not exist.
        SystemExit: Via ``parser.error`` if ``--score-thr`` is not strictly
            between 0 and 1.
    """
    parser = ArgumentParser()
    parser.add_argument('img_root', type=str, help='Image root path')
    parser.add_argument('img_list', type=str, help='Image path list file')
    parser.add_argument('config', type=str, help='Config file')
    parser.add_argument('checkpoint', type=str, help='Checkpoint file')
    parser.add_argument('--score-thr',
                        type=float,
                        default=0.5,
                        help='Bbox score threshold')
    parser.add_argument('--out-dir',
                        type=str,
                        default='./results',
                        help='Dir to save '
                        'visualize images '
                        'and bbox')
    parser.add_argument('--device',
                        default='cuda:0',
                        help='Device used for inference.')
    args = parser.parse_args()

    # Validate explicitly: an assert would be stripped when running with -O.
    if not 0 < args.score_thr < 1:
        parser.error('--score-thr must be strictly between 0 and 1, '
                     f'got {args.score_thr}')

    # build the model from a config file and a checkpoint file
    model = init_detector(args.config, args.checkpoint, device=args.device)
    # Use the wrapped model when the returned object exposes ``module``.
    if hasattr(model, 'module'):
        model = model.module
    # For a ConcatDataset test config, reuse the pipeline of its first dataset.
    if model.cfg.data.test['type'] == 'ConcatDataset':
        model.cfg.data.test.pipeline = model.cfg.data.test['datasets'][
            0].pipeline

    # Start Inference
    out_vis_dir = osp.join(args.out_dir, 'out_vis_dir')
    mmcv.mkdir_or_exist(out_vis_dir)
    out_txt_dir = osp.join(args.out_dir, 'out_txt_dir')
    mmcv.mkdir_or_exist(out_txt_dir)

    lines = list_from_file(args.img_list)
    progressbar = ProgressBar(task_num=len(lines))
    for line in lines:
        progressbar.update()
        img_path = osp.join(args.img_root, line.strip())
        if not osp.exists(img_path):
            raise FileNotFoundError(img_path)
        # Test a single image
        result = model_inference(model, img_path)
        img_name = osp.basename(img_path)
        # save result
        save_results(result, out_txt_dir, img_name, score_thr=args.score_thr)
        # save the visualization without opening a window
        out_file = osp.join(out_vis_dir, img_name)
        model.show_result(img_path,
                          result,
                          score_thr=args.score_thr,
                          show=False,
                          out_file=out_file)

    print(f'\nInference done, and results saved in {args.out_dir}\n')
コード例 #4
0
ファイル: recog_test_imgs.py プロジェクト: open-mmlab/mmocr
def main():
    """Run text recognition on a list of images and optionally evaluate it.

    Each line of ``img_list`` holds an image path relative to
    ``img_root_path``, optionally followed by a ground-truth label. Every
    image is either shown (``--show``) or saved under
    ``<out-dir>/out_vis_dir``. Saved images that have a ground-truth label
    are also copied into ``<out-dir>/correct`` or ``<out-dir>/wrong``.
    Metrics are logged only when every line carries a ground-truth label.

    Raises:
        FileNotFoundError: If a listed image does not exist.
    """
    parser = ArgumentParser()
    parser.add_argument('img_root_path', type=str, help='Image root path')
    parser.add_argument('img_list', type=str, help='Image path list file')
    parser.add_argument('config', type=str, help='Config file')
    parser.add_argument('checkpoint', type=str, help='Checkpoint file')
    parser.add_argument(
        '--out-dir', type=str, default='./results', help='Dir to save results')
    parser.add_argument(
        '--show', action='store_true', help='show image or save')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference.')
    args = parser.parse_args()

    # The log file is placed inside out_dir, so make sure the directory
    # exists before the logger is asked to write there.
    mmcv.mkdir_or_exist(args.out_dir)

    # init the logger before other steps
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(args.out_dir, f'{timestamp}.log')
    logger = get_root_logger(log_file=log_file, log_level='INFO')

    # build the model from a config file and a checkpoint file
    model = init_detector(args.config, args.checkpoint, device=args.device)
    # Use the wrapped model when the returned object exposes ``module``.
    if hasattr(model, 'module'):
        model = model.module

    # Start Inference
    out_vis_dir = osp.join(args.out_dir, 'out_vis_dir')
    mmcv.mkdir_or_exist(out_vis_dir)
    correct_vis_dir = osp.join(args.out_dir, 'correct')
    mmcv.mkdir_or_exist(correct_vis_dir)
    wrong_vis_dir = osp.join(args.out_dir, 'wrong')
    mmcv.mkdir_or_exist(wrong_vis_dir)
    img_paths, pred_labels, gt_labels = [], [], []

    lines = list_from_file(args.img_list)
    progressbar = ProgressBar(task_num=len(lines))
    num_gt_label = 0
    for line in lines:
        progressbar.update()
        item_list = line.strip().split()
        img_file = item_list[0]
        # The ground-truth label is optional; '' marks "no label".
        gt_label = ''
        if len(item_list) >= 2:
            gt_label = item_list[1]
            num_gt_label += 1
        img_path = osp.join(args.img_root_path, img_file)
        if not osp.exists(img_path):
            raise FileNotFoundError(img_path)
        # Test a single image
        result = model_inference(model, img_path)
        pred_label = result['text']

        # Flatten the relative path into one file name.
        out_img_name = '_'.join(img_file.split('/'))
        out_file = osp.join(out_vis_dir, out_img_name)
        kwargs_dict = {
            'gt_label': gt_label,
            'show': args.show,
            'out_file': '' if args.show else out_file
        }
        model.show_result(img_path, result, **kwargs_dict)
        # With --show an empty out_file is passed, so presumably no image is
        # written to out_file; only sort into correct/wrong when it was saved.
        if gt_label != '' and not args.show:
            if gt_label == pred_label:
                dst_file = osp.join(correct_vis_dir, out_img_name)
            else:
                dst_file = osp.join(wrong_vis_dir, out_img_name)
            shutil.copy(out_file, dst_file)
        img_paths.append(img_path)
        gt_labels.append(gt_label)
        pred_labels.append(pred_label)

    # Save results
    save_results(img_paths, pred_labels, gt_labels, args.out_dir)

    # Evaluate only when every image came with a ground-truth label.
    if num_gt_label == len(pred_labels):
        # eval
        eval_results = eval_ocr_metric(pred_labels, gt_labels)
        logger.info('\n' + '-' * 100)
        info = ('eval on testset with img_root_path '
                f'{args.img_root_path} and img_list {args.img_list}\n')
        logger.info(info)
        logger.info(eval_results)

    print(f'\nInference done, and results saved in {args.out_dir}\n')
コード例 #5
0
ファイル: mmocr_handler.py プロジェクト: Pandinosaurus/mmocr
    def inference(self, data, *args, **kwargs):
        """Run ``self.model`` on ``data`` and return the result unchanged.

        Extra positional and keyword arguments are accepted but not used.
        """
        return model_inference(self.model, data)