Example #1
# Imports as used by the MMOCR 0.x demo this snippet comes from; the helper
# det_and_recog_inference() is defined elsewhere in the same script, and the
# exact module paths are assumptions tied to that version.
from argparse import ArgumentParser

import mmcv
from mmdet.apis import init_detector

from mmocr.core.visualize import det_recog_show_result
from mmocr.utils import stitch_boxes_into_lines


def main():
    parser = ArgumentParser()
    parser.add_argument('img', type=str, help='Input image file.')
    parser.add_argument('out_file',
                        type=str,
                        help='Output file name of the visualized image.')
    parser.add_argument('--det-config',
                        type=str,
                        default='./configs/textdet/psenet/'
                        'psenet_r50_fpnf_600e_icdar2015.py',
                        help='Text detection config file.')
    parser.add_argument('--det-ckpt',
                        type=str,
                        default='https://download.openmmlab.com/'
                        'mmocr/textdet/psenet/'
                        'psenet_r50_fpnf_600e_icdar2015_pretrain-eefd8fe6.pth',
                        help='Text detection checkpoint file (local path or URL).')
    parser.add_argument('--recog-config',
                        type=str,
                        default='./configs/textrecog/sar/'
                        'sar_r31_parallel_decoder_academic.py',
                        help='Text recognition config file.')
    parser.add_argument('--recog-ckpt',
                        type=str,
                        default='https://download.openmmlab.com/'
                        'mmocr/textrecog/sar/'
                        'sar_r31_parallel_decoder_academic-dba3a4a3.pth',
                        help='Text recognition checkpoint file (local path or URL).')
    parser.add_argument('--batch-mode',
                        action='store_true',
                        help='Whether to use batch mode for text recognition.')
    parser.add_argument('--batch-size',
                        type=int,
                        default=4,
                        help='Batch size for text recognition inference '
                        'if --batch-mode is set.')
    parser.add_argument('--device',
                        default='cuda:0',
                        help='Device used for inference.')
    parser.add_argument('--imshow',
                        action='store_true',
                        help='Whether to show the image with OpenCV.')
    parser.add_argument('--ocr-in-lines',
                        action='store_true',
                        help='Whether to group OCR results into lines.')
    args = parser.parse_args()

    # passing device=None to init_detector keeps the models on the CPU
    if args.device == 'cpu':
        args.device = None
    # build the text detection model from config + checkpoint
    detect_model = init_detector(args.det_config,
                                 args.det_ckpt,
                                 device=args.device)
    if hasattr(detect_model, 'module'):
        detect_model = detect_model.module
    # for ConcatDataset test sets, run inference with the first
    # sub-dataset's pipeline
    if detect_model.cfg.data.test['type'] == 'ConcatDataset':
        detect_model.cfg.data.test.pipeline = \
            detect_model.cfg.data.test['datasets'][0].pipeline

    # build the text recognition model from config + checkpoint
    recog_model = init_detector(args.recog_config,
                                args.recog_ckpt,
                                device=args.device)
    if hasattr(recog_model, 'module'):
        recog_model = recog_model.module
    # same ConcatDataset pipeline fix as for the detection model
    if recog_model.cfg.data.test['type'] == 'ConcatDataset':
        recog_model.cfg.data.test.pipeline = \
            recog_model.cfg.data.test['datasets'][0].pipeline

    det_recog_result = det_and_recog_inference(args, detect_model, recog_model)
    print(f'result: {det_recog_result}')
    mmcv.dump(det_recog_result,
              args.out_file + '.json',
              ensure_ascii=False,
              indent=4)

    if args.ocr_in_lines:
        res = det_recog_result['result']
        # merge neighboring boxes into text lines
        # (max horizontal gap 10 px, min vertical overlap ratio 0.5)
        res = stitch_boxes_into_lines(res, 10, 0.5)
        det_recog_result['result'] = res
        mmcv.dump(det_recog_result,
                  args.out_file + '.line.json',
                  ensure_ascii=False,
                  indent=4)

    img = det_recog_show_result(args.img, det_recog_result)
    mmcv.imwrite(img, args.out_file)
    if args.imshow:
        mmcv.imshow(img, 'predicted results')
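
For completeness, the snippet above omits its script entry point; a minimal sketch follows (the invocation is illustrative and the file name an assumption):

# Hypothetical entry point for the demo above, e.g.:
#   python ocr_image_demo.py demo.jpg out.jpg --ocr-in-lines --imshow
if __name__ == '__main__':
    main()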
Example #2
    # Method of MMOCR's end-to-end utility class (mmocr/utils/ocr.py in
    # MMOCR 0.x). Module-level imports assumed by this snippet:
    #   import copy
    #   from mmocr.apis import model_inference
    #   from mmocr.datasets.kie_dataset import KIEDataset
    #   from mmocr.datasets.pipelines.crop import crop_img
    #   from mmocr.utils import stitch_boxes_into_lines
    def det_recog_kie_inference(self, det_model, recog_model, kie_model=None):
        end2end_res = []
        # Find bounding boxes in the images (text detection)
        det_result = self.single_inference(det_model, self.args.arrays,
                                           self.args.batch_mode,
                                           self.args.det_batch_size)
        bboxes_list = [res['boundary_result'] for res in det_result]

        if kie_model:
            kie_dataset = KIEDataset(
                dict_file=kie_model.cfg.data.test.dict_file)

        # For each bounding box, the image is cropped and
        # sent to the recognition model either one by one
        # or all together depending on the batch_mode
        for filename, arr, bboxes, out_file in zip(self.args.filenames,
                                                   self.args.arrays,
                                                   bboxes_list,
                                                   self.args.output):
            img_e2e_res = {}
            img_e2e_res['filename'] = filename
            img_e2e_res['result'] = []
            box_imgs = []
            for bbox in bboxes:
                box_res = {}
                # all but the last element are polygon coordinates;
                # the last element is the detection confidence
                box_res['box'] = [round(x) for x in bbox[:-1]]
                box_res['box_score'] = float(bbox[-1])
                box = bbox[:8]
                # polygons with more than 4 corner points (8 coords + score)
                # are collapsed to an axis-aligned bounding quadrilateral
                if len(bbox) > 9:
                    min_x = min(bbox[0:-1:2])
                    min_y = min(bbox[1:-1:2])
                    max_x = max(bbox[0:-1:2])
                    max_y = max(bbox[1:-1:2])
                    box = [
                        min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y
                    ]
                box_img = crop_img(arr, box)
                if self.args.batch_mode:
                    box_imgs.append(box_img)
                else:
                    # 'Tesseract_recog' is the sentinel value used when the
                    # Tesseract backend stands in for a recognition model
                    if recog_model == 'Tesseract_recog':
                        recog_result = self.single_inference(recog_model,
                                                             box_img,
                                                             batch_mode=True)
                    else:
                        recog_result = model_inference(recog_model, box_img)
                    text = recog_result['text']
                    text_score = recog_result['score']
                    if isinstance(text_score, list):
                        # average per-character confidences into one score
                        text_score = sum(text_score) / max(1, len(text))
                    box_res['text'] = text
                    box_res['text_score'] = text_score
                img_e2e_res['result'].append(box_res)

            if self.args.batch_mode:
                recog_results = self.single_inference(
                    recog_model, box_imgs, True, self.args.recog_batch_size)
                for i, recog_result in enumerate(recog_results):
                    text = recog_result['text']
                    text_score = recog_result['score']
                    if isinstance(text_score, (list, tuple)):
                        text_score = sum(text_score) / max(1, len(text))
                    img_e2e_res['result'][i]['text'] = text
                    img_e2e_res['result'][i]['text_score'] = text_score

            if self.args.merge:
                img_e2e_res['result'] = stitch_boxes_into_lines(
                    img_e2e_res['result'], self.args.merge_xdist, 0.5)

            if kie_model:
                annotations = copy.deepcopy(img_e2e_res['result'])
                # Customized for kie_dataset, which
                # assumes that boxes are represented by only 4 points
                for i, ann in enumerate(annotations):
                    min_x = min(ann['box'][::2])
                    min_y = min(ann['box'][1::2])
                    max_x = max(ann['box'][::2])
                    max_y = max(ann['box'][1::2])
                    annotations[i]['box'] = [
                        min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y
                    ]
                ann_info = kie_dataset._parse_anno_info(annotations)
                ann_info['ori_bboxes'] = ann_info.get('ori_bboxes',
                                                      ann_info['bboxes'])
                ann_info['gt_bboxes'] = ann_info.get('gt_bboxes',
                                                     ann_info['bboxes'])
                kie_result, data = model_inference(
                    kie_model,
                    arr,
                    ann=ann_info,
                    return_data=True,
                    batch_mode=self.args.batch_mode)
                # visualize KIE results
                self.visualize_kie_output(kie_model,
                                          data,
                                          kie_result,
                                          out_file=out_file,
                                          show=self.args.imshow)
                gt_bboxes = data['gt_bboxes'].data.numpy().tolist()
                labels = self.generate_kie_labels(kie_result, gt_bboxes,
                                                  kie_model.class_list)
                for i in range(len(gt_bboxes)):
                    img_e2e_res['result'][i]['label'] = labels[i][0]
                    img_e2e_res['result'][i]['label_score'] = labels[i][1]

            end2end_res.append(img_e2e_res)
        return end2end_res
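
For context, this method belongs to MMOCR's end-to-end convenience class in mmocr/utils/ocr.py (MMOCR 0.x). A minimal usage sketch follows; the model aliases 'PS_IC15' and 'SAR' are assumptions that must match the installed version's model zoo:

# A hedged sketch, assuming MMOCR 0.x: readtext() drives the
# det_recog_kie_inference() pipeline shown above.
from mmocr.utils.ocr import MMOCR

ocr = MMOCR(det='PS_IC15', recog='SAR')  # aliases are assumptions
results = ocr.readtext('demo.jpg', output='out.jpg', batch_mode=False)
# each entry: {'filename': ..., 'result': [{'box', 'box_score',
#                                           'text', 'text_score'}, ...]}
print(results)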