def test_kie_dataset():
    """Check KIEDataset construction, pre_pipeline, parsing and evaluation."""
    # The context manager guarantees the temporary directory is removed even
    # when creating the dummy files or building the dataset raises.
    with tempfile.TemporaryDirectory() as tmp_dir:
        # create dummy data
        ann_file = osp.join(tmp_dir, 'fake_data.txt')
        ann_info1 = _create_dummy_ann_file(ann_file)

        dict_file = osp.join(tmp_dir, 'fake_dict.txt')
        _create_dummy_dict_file(dict_file)

        # test initialization
        loader = _create_dummy_loader()
        dataset = KIEDataset(ann_file, loader, dict_file, pipeline=[])

    # test pre_pipeline
    img_info = dataset.data_infos[0]
    results = dict(img_info=img_info)
    dataset.pre_pipeline(results)
    assert results['img_prefix'] == dataset.img_prefix

    # test _parse_anno_info
    annos = ann_info1['annotations']
    # A single annotation entry (instead of the full collection) is rejected.
    with pytest.raises(AssertionError):
        dataset._parse_anno_info(annos[0])

    # This entry only carries 'text' and 'box' and is rejected as well
    # (presumably because it has no 'label' key -- confirm against
    # KIEDataset._parse_anno_info).
    tmp_annos = [{
        'text': 'store',
        'box': [11.0, 0.0, 22.0, 0.0, 12.0, 12.0, 0.0, 12.0]
    }]
    with pytest.raises(AssertionError):
        dataset._parse_anno_info(tmp_annos)

    return_anno = dataset._parse_anno_info(annos)
    for key in ('bboxes', 'relations', 'texts', 'labels'):
        assert key in return_anno

    # test evaluation
    # Every row scores column 1 far above all other columns.
    result = {}
    result['nodes'] = torch.full((5, 5), 1, dtype=torch.float)
    result['nodes'][:, 1] = 100.
    results = [result for _ in range(5)]

    eval_res = dataset.evaluate(results)
    assert math.isclose(eval_res['macro_f1'], 0.2, abs_tol=1e-4)
def kiedataset_with_test_dict(**kwargs):
    """Build a KIEDataset that always uses the toy-dataset dictionary file.

    Any ``dict_file`` supplied by the caller is overridden; every other
    keyword argument is forwarded to ``KIEDataset`` unchanged.
    """
    toy_dict_path = 'tests/data/kie_toy_dataset/dict.txt'
    dataset_kwargs = {**kwargs, 'dict_file': toy_dict_path}
    return KIEDataset(**dataset_kwargs)
def det_recog_kie_inference(self, det_model, recog_model, kie_model=None):
    """Run text detection, text recognition and optional KIE on all inputs.

    Inputs and options are read from ``self.args`` (``arrays``,
    ``filenames``, ``output``, ``batch_mode``, ``det_batch_size``,
    ``recog_batch_size``, ``merge``, ``merge_xdist``, ``imshow``).

    Args:
        det_model: Detection model passed to ``self.single_inference``.
        recog_model: Recognition model, or the string ``'Tesseract_recog'``.
        kie_model: Optional KIE model. When truthy, each image is also run
            through KIE and every result entry gains ``label`` and
            ``label_score``.

    Returns:
        list[dict]: One dict per image with keys ``filename`` and ``result``.
        Without merging, ``result`` is a list of dicts holding ``box``,
        ``box_score``, ``text`` and ``text_score``; with ``self.args.merge``
        it is whatever ``stitch_boxes_into_lines`` returns.
    """

    def _mean_score(score, text):
        # Recognizers may return one score per character; collapse those to
        # a single value. Lists and tuples are treated alike so that the
        # one-by-one and batched paths agree.
        if isinstance(score, (list, tuple)):
            return sum(score) / max(1, len(text))
        return score

    end2end_res = []
    # Find bounding boxes in the images (text detection)
    det_result = self.single_inference(det_model, self.args.arrays,
                                       self.args.batch_mode,
                                       self.args.det_batch_size)
    bboxes_list = [res['boundary_result'] for res in det_result]

    if kie_model:
        kie_dataset = KIEDataset(
            dict_file=kie_model.cfg.data.test.dict_file)

    # For each bounding box, the image is cropped and
    # sent to the recognition model either one by one
    # or all together depending on the batch_mode
    for filename, arr, bboxes, out_file in zip(self.args.filenames,
                                               self.args.arrays,
                                               bboxes_list,
                                               self.args.output):
        img_e2e_res = {'filename': filename, 'result': []}
        box_imgs = []
        for bbox in bboxes:
            # The last element of a detection is its score; everything
            # before it is the flattened x/y coordinate list.
            box_res = {
                'box': [round(x) for x in bbox[:-1]],
                'box_score': float(bbox[-1]),
            }
            box = bbox[:8]
            if len(bbox) > 9:
                # More than 4 points: crop the axis-aligned enclosing
                # rectangle instead.
                xs = bbox[0:-1:2]
                ys = bbox[1:-1:2]
                min_x, max_x = min(xs), max(xs)
                min_y, max_y = min(ys), max(ys)
                box = [
                    min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y
                ]
            box_img = crop_img(arr, box)
            if self.args.batch_mode:
                box_imgs.append(box_img)
            else:
                if recog_model == 'Tesseract_recog':
                    recog_result = self.single_inference(
                        recog_model, box_img, batch_mode=True)
                else:
                    recog_result = model_inference(recog_model, box_img)
                text = recog_result['text']
                box_res['text'] = text
                box_res['text_score'] = _mean_score(recog_result['score'],
                                                    text)
            img_e2e_res['result'].append(box_res)

        # Skip the batched recognition call when nothing was detected:
        # there is nothing to fill in, and an empty batch is assumed to be
        # of no use to the recognizer.
        if self.args.batch_mode and box_imgs:
            recog_results = self.single_inference(
                recog_model, box_imgs, True, self.args.recog_batch_size)
            for i, recog_result in enumerate(recog_results):
                text = recog_result['text']
                img_e2e_res['result'][i]['text'] = text
                img_e2e_res['result'][i]['text_score'] = _mean_score(
                    recog_result['score'], text)

        if self.args.merge:
            img_e2e_res['result'] = stitch_boxes_into_lines(
                img_e2e_res['result'], self.args.merge_xdist, 0.5)

        if kie_model:
            # Work on a copy so the boxes reported to the caller keep their
            # original points.
            annotations = copy.deepcopy(img_e2e_res['result'])
            # Customized for kie_dataset, which
            # assumes that boxes are represented by only 4 points
            for ann in annotations:
                xs = ann['box'][::2]
                ys = ann['box'][1::2]
                min_x, max_x = min(xs), max(xs)
                min_y, max_y = min(ys), max(ys)
                ann['box'] = [
                    min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y
                ]
            ann_info = kie_dataset._parse_anno_info(annotations)
            ann_info['ori_bboxes'] = ann_info.get('ori_bboxes',
                                                  ann_info['bboxes'])
            ann_info['gt_bboxes'] = ann_info.get('gt_bboxes',
                                                 ann_info['bboxes'])
            kie_result, data = model_inference(
                kie_model,
                arr,
                ann=ann_info,
                return_data=True,
                batch_mode=self.args.batch_mode)
            # visualize KIE results
            self.visualize_kie_output(
                kie_model,
                data,
                kie_result,
                out_file=out_file,
                show=self.args.imshow)
            gt_bboxes = data['gt_bboxes'].data.numpy().tolist()
            labels = self.generate_kie_labels(kie_result, gt_bboxes,
                                              kie_model.class_list)
            # Indexed on purpose: a length mismatch between boxes, labels
            # and results should raise rather than be silently truncated.
            for i in range(len(gt_bboxes)):
                img_e2e_res['result'][i]['label'] = labels[i][0]
                img_e2e_res['result'][i]['label_score'] = labels[i][1]

        end2end_res.append(img_e2e_res)
    return end2end_res