Пример #1
0
 def det_recog_pp(self, result):
     """Post-process end-to-end detection + recognition results.

     Iterates ``self.args.arrays``, ``self.args.output``,
     ``self.args.export`` and ``result`` in lockstep (``zip`` stops at the
     shortest of them). For every image it optionally renders and shows the
     visualisation, reduces the result to its filename and texts unless
     ``args.details`` is set, optionally dumps it with ``mmcv.dump`` and
     optionally prints it.

     Args:
         result: Per-image result dicts. When ``args.details`` is falsy each
             one is read for its ``'filename'`` key and for the ``'text'``
             entry of every item under its ``'result'`` key.

     Returns:
         list: One entry per processed image; either the untouched result
         dict (``args.details`` truthy) or a
         ``{'filename': ..., 'text': [...]}`` summary.
     """
     cfg = self.args
     collected = []
     per_image = zip(cfg.arrays, cfg.output, cfg.export, result)
     for image, out_path, export_path, raw in per_image:
         # Visualisation is only produced when it will be saved or shown.
         if out_path or cfg.imshow:
             if self.kie_model:
                 # With a KIE model the image is rendered without an
                 # out_file and is never displayed from here.
                 det_recog_show_result(image, raw)
             else:
                 rendered = det_recog_show_result(image,
                                                  raw,
                                                  out_file=out_path)
                 if cfg.imshow:
                     mmcv.imshow(rendered, 'inference results')

         if cfg.details:
             summary = raw
         else:
             # Keep only the filename and the recognised strings.
             summary = {
                 'filename': raw['filename'],
                 'text': [entry['text'] for entry in raw['result']],
             }

         if export_path:
             mmcv.dump(summary, export_path, indent=4)
         if cfg.print_result:
             print(summary, end='\n\n')
         collected.append(summary)
     return collected
Пример #2
0
def _build_arg_parser():
    """Create the command-line parser for the OCR demo script."""
    parser = ArgumentParser()
    parser.add_argument('img', type=str, help='Input Image file.')
    parser.add_argument('out_file',
                        type=str,
                        help='Output file name of the visualized image.')
    parser.add_argument('--det-config',
                        type=str,
                        default='./configs/textdet/psenet/'
                        'psenet_r50_fpnf_600e_icdar2015.py',
                        help='Text detection config file.')
    parser.add_argument('--det-ckpt',
                        type=str,
                        default='https://download.openmmlab.com/'
                        'mmocr/textdet/psenet/'
                        'psenet_r50_fpnf_600e_icdar2015_pretrain-eefd8fe6.pth',
                        help='Text detection checkpoint file (local or url).')
    parser.add_argument('--recog-config',
                        type=str,
                        default='./configs/textrecog/sar/'
                        'sar_r31_parallel_decoder_academic.py',
                        help='Text recognition config file.')
    parser.add_argument('--recog-ckpt',
                        type=str,
                        default='https://download.openmmlab.com/'
                        'mmocr/textrecog/sar/'
                        'sar_r31_parallel_decoder_academic-dba3a4a3.pth',
                        help='Text recognition checkpoint file '
                        '(local or url).')
    parser.add_argument('--batch-mode',
                        action='store_true',
                        help='Whether use batch mode for text recognition.')
    parser.add_argument('--batch-size',
                        type=int,
                        default=4,
                        help='Batch size for text recognition inference '
                        'if batch_mode is True above.')
    parser.add_argument('--device',
                        default='cuda:0',
                        help='Device used for inference.')
    parser.add_argument('--imshow',
                        action='store_true',
                        help='Whether show image with OpenCV.')
    parser.add_argument('--ocr-in-lines',
                        action='store_true',
                        help='Whether group ocr results in lines.')
    return parser


def _load_inference_model(config, checkpoint, device):
    """Initialise a model via ``init_detector`` and ready it for the demo.

    Args:
        config: Config file path handed to ``init_detector``.
        checkpoint: Checkpoint path or URL handed to ``init_detector``.
        device: Value forwarded as ``init_detector(..., device=device)``.

    Returns:
        The initialised model, unwrapped from its ``.module`` attribute when
        it has one.
    """
    model = init_detector(config, checkpoint, device=device)
    # Unwrap a wrapper object that exposes the real model as ``.module``.
    if hasattr(model, 'module'):
        model = model.module
    # A ConcatDataset test config carries its pipeline on the member
    # datasets; copy the first member's pipeline up to the test config.
    if model.cfg.data.test['type'] == 'ConcatDataset':
        model.cfg.data.test.pipeline = \
            model.cfg.data.test['datasets'][0].pipeline
    return model


def main():
    """Run text detection + recognition on one image from the command line.

    Parses the CLI arguments, builds the detection and recognition models,
    runs ``det_and_recog_inference`` and then:

    * prints the result and dumps it to ``<out_file>.json``;
    * with ``--ocr-in-lines``, replaces ``result['result']`` with the output
      of ``stitch_boxes_into_lines`` and dumps it to ``<out_file>.line.json``;
    * renders the (possibly stitched) result onto the input image, writes it
      to ``out_file`` and, with ``--imshow``, displays it.
    """
    args = _build_arg_parser().parse_args()

    # NOTE(review): 'cpu' is replaced by None before being handed to
    # init_detector -- presumably so the model stays on its default device;
    # confirm against init_detector's handling of device=None.
    if args.device == 'cpu':
        args.device = None

    # build detect model
    detect_model = _load_inference_model(args.det_config, args.det_ckpt,
                                         args.device)
    # build recog model
    recog_model = _load_inference_model(args.recog_config, args.recog_ckpt,
                                        args.device)

    det_recog_result = det_and_recog_inference(args, detect_model, recog_model)
    print(f'result: {det_recog_result}')
    mmcv.dump(det_recog_result,
              args.out_file + '.json',
              ensure_ascii=False,
              indent=4)

    if args.ocr_in_lines:
        # The stitched lines overwrite the per-box results, so the
        # visualisation below is drawn from the line-level result.
        # 10 and 0.5 are passed positionally; see stitch_boxes_into_lines
        # for their meaning.
        det_recog_result['result'] = stitch_boxes_into_lines(
            det_recog_result['result'], 10, 0.5)
        mmcv.dump(det_recog_result,
                  args.out_file + '.line.json',
                  ensure_ascii=False,
                  indent=4)

    img = det_recog_show_result(args.img, det_recog_result)
    mmcv.imwrite(img, args.out_file)
    if args.imshow:
        mmcv.imshow(img, 'predicted results')