Exemple #1
0
def test_kie_dataset():
    """Exercise ``KIEDataset`` construction, pre-pipeline, parsing and evaluation.

    The annotation and dictionary files are written into a temporary
    directory that is removed as soon as the dataset has been built, so the
    remaining checks run against the in-memory dataset only.
    """
    # The context manager removes the directory even if one of the fixture
    # helpers or the dataset constructor raises.
    with tempfile.TemporaryDirectory() as tmp_dir:
        # create dummy data
        ann_file = osp.join(tmp_dir, 'fake_data.txt')
        ann_info1 = _create_dummy_ann_file(ann_file)

        dict_file = osp.join(tmp_dir, 'fake_dict.txt')
        _create_dummy_dict_file(dict_file)

        # test initialization
        loader = _create_dummy_loader()
        dataset = KIEDataset(ann_file, loader, dict_file, pipeline=[])

    # test pre_pipeline
    img_info = dataset.data_infos[0]
    results = dict(img_info=img_info)
    dataset.pre_pipeline(results)
    assert results['img_prefix'] == dataset.img_prefix

    # test _parse_anno_info
    annos = ann_info1['annotations']
    # A single annotation dict (instead of the list of annotations) must be
    # rejected.
    with pytest.raises(AssertionError):
        dataset._parse_anno_info(annos[0])
    # An annotation carrying only 'text' and 'box' must be rejected as well.
    tmp_annos = [{
        'text': 'store',
        'box': [11.0, 0.0, 22.0, 0.0, 12.0, 12.0, 0.0, 12.0]
    }]
    with pytest.raises(AssertionError):
        dataset._parse_anno_info(tmp_annos)

    return_anno = dataset._parse_anno_info(annos)
    for key in ('bboxes', 'relations', 'texts', 'labels'):
        assert key in return_anno

    # test evaluation
    # Every row of the 5x5 score tensor is 1 everywhere except column 1,
    # which is set to 100.
    result = {}
    result['nodes'] = torch.full((5, 5), 1, dtype=torch.float)
    result['nodes'][:, 1] = 100.
    # The same result dict is reused for all five entries.
    results = [result for _ in range(5)]

    eval_res = dataset.evaluate(results)
    assert math.isclose(eval_res['macro_f1'], 0.2, abs_tol=1e-4)
Exemple #2
0
 def kiedataset_with_test_dict(**kwargs):
     """Build a ``KIEDataset`` whose ``dict_file`` is pinned to a fixed path.

     Any ``dict_file`` passed by the caller is overridden; every other
     keyword argument is forwarded to ``KIEDataset`` unchanged.
     """
     fixed_dict_file = 'tests/data/kie_toy_dataset/dict.txt'
     return KIEDataset(**{**kwargs, 'dict_file': fixed_dict_file})
Exemple #3
0
    def det_recog_kie_inference(self, det_model, recog_model, kie_model=None):
        """Run detection, recognition and (optionally) KIE on every image.

        Args:
            det_model: Detection model passed to ``self.single_inference``.
            recog_model: Recognition model. When it equals the string
                ``'Tesseract_recog'`` the per-box path goes through
                ``self.single_inference``; otherwise ``model_inference`` is
                called directly.
            kie_model: Optional KIE model. When truthy, a ``KIEDataset`` is
                built from ``kie_model.cfg.data.test.dict_file`` and each
                image's boxes are additionally labelled.

        Returns:
            list[dict]: One dict per image with keys ``'filename'`` and
            ``'result'``. ``'result'`` is a list of per-box dicts holding
            ``'box'``, ``'box_score'``, ``'text'`` and ``'text_score'``,
            plus ``'label'`` and ``'label_score'`` when ``kie_model`` is
            given.
        """

        def _mean_score(text, score):
            # A sequence of scores is collapsed to a single value. The
            # divisor is the text length, floored at 1 so that an empty
            # text cannot cause a division by zero. Used by both the
            # per-box and the batched path so they treat lists and tuples
            # the same way.
            if isinstance(score, (list, tuple)):
                return sum(score) / max(1, len(text))
            return score

        end2end_res = []
        # Find bounding boxes in the images (text detection)
        det_result = self.single_inference(det_model, self.args.arrays,
                                           self.args.batch_mode,
                                           self.args.det_batch_size)
        bboxes_list = [res['boundary_result'] for res in det_result]

        if kie_model:
            kie_dataset = KIEDataset(
                dict_file=kie_model.cfg.data.test.dict_file)

        # For each bounding box, the image is cropped and
        # sent to the recognition model either one by one
        # or all together depending on the batch_mode
        for filename, arr, bboxes, out_file in zip(self.args.filenames,
                                                   self.args.arrays,
                                                   bboxes_list,
                                                   self.args.output):
            img_e2e_res = {'filename': filename, 'result': []}
            box_imgs = []
            for bbox in bboxes:
                # The last element of ``bbox`` is the detection score; the
                # elements before it are the box coordinates.
                box_res = {
                    'box': [round(x) for x in bbox[:-1]],
                    'box_score': float(bbox[-1]),
                }
                if len(bbox) > 9:
                    # More than four points: crop the axis-aligned rectangle
                    # spanned by the min/max of the x (even index) and
                    # y (odd index) coordinates.
                    min_x = min(bbox[0:-1:2])
                    min_y = min(bbox[1:-1:2])
                    max_x = max(bbox[0:-1:2])
                    max_y = max(bbox[1:-1:2])
                    box = [
                        min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y
                    ]
                else:
                    box = bbox[:8]
                box_img = crop_img(arr, box)
                if self.args.batch_mode:
                    # Recognition is deferred until all crops are collected.
                    box_imgs.append(box_img)
                else:
                    if recog_model == 'Tesseract_recog':
                        recog_result = self.single_inference(recog_model,
                                                             box_img,
                                                             batch_mode=True)
                    else:
                        recog_result = model_inference(recog_model, box_img)
                    text = recog_result['text']
                    box_res['text'] = text
                    box_res['text_score'] = _mean_score(
                        text, recog_result['score'])
                img_e2e_res['result'].append(box_res)

            if self.args.batch_mode:
                recog_results = self.single_inference(
                    recog_model, box_imgs, True, self.args.recog_batch_size)
                # Results are matched to boxes by position.
                for i, recog_result in enumerate(recog_results):
                    text = recog_result['text']
                    img_e2e_res['result'][i]['text'] = text
                    img_e2e_res['result'][i]['text_score'] = _mean_score(
                        text, recog_result['score'])

            if self.args.merge:
                img_e2e_res['result'] = stitch_boxes_into_lines(
                    img_e2e_res['result'], self.args.merge_xdist, 0.5)

            if kie_model:
                # Work on a copy so the boxes reported in the final result
                # are not rewritten by the rectangle conversion below.
                annotations = copy.deepcopy(img_e2e_res['result'])
                # Customized for kie_dataset, which
                # assumes that boxes are represented by only 4 points
                for ann in annotations:
                    min_x = min(ann['box'][::2])
                    min_y = min(ann['box'][1::2])
                    max_x = max(ann['box'][::2])
                    max_y = max(ann['box'][1::2])
                    ann['box'] = [
                        min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y
                    ]
                ann_info = kie_dataset._parse_anno_info(annotations)
                # Fall back to the parsed 'bboxes' when these keys are absent.
                ann_info['ori_bboxes'] = ann_info.get('ori_bboxes',
                                                      ann_info['bboxes'])
                ann_info['gt_bboxes'] = ann_info.get('gt_bboxes',
                                                     ann_info['bboxes'])
                kie_result, data = model_inference(
                    kie_model,
                    arr,
                    ann=ann_info,
                    return_data=True,
                    batch_mode=self.args.batch_mode)
                # visualize KIE results
                self.visualize_kie_output(kie_model,
                                          data,
                                          kie_result,
                                          out_file=out_file,
                                          show=self.args.imshow)
                gt_bboxes = data['gt_bboxes'].data.numpy().tolist()
                labels = self.generate_kie_labels(kie_result, gt_bboxes,
                                                  kie_model.class_list)
                # Each entry of ``labels`` is indexed as (label, score).
                for i in range(len(gt_bboxes)):
                    img_e2e_res['result'][i]['label'] = labels[i][0]
                    img_e2e_res['result'][i]['label_score'] = labels[i][1]

            end2end_res.append(img_e2e_res)
        return end2end_res