def _handle_district(address: str, city=None): if '区' not in address and '县' not in address: return address if '区' in address: key_words = '区' elif '县' in address: key_words = '县' district_info = address.split(key_words)[0] + key_words other_info = address.split(key_words)[1] if not city: # 暂时先不做处理 return address else: district_list = district_info_dict.get(city, None) if district_list: district_tree = bk_tree.BKTree() for district in district_list: district_tree.insert_node(district) best_district = district_tree.search_one(district_info, search_dist=2, min_len=1) if best_district: if best_district != district_info: logger.debug('replace {} to {} by district process'.format( district_info, best_district)) address = best_district + other_info return address
def search_and_fill(pair, structure_items, node_items):
    """Locate a pay-info background label and fill the matching structure item.

    :param pair: (list of label regexes, structure item key)
    :param structure_items: output items; structure_items[key].content is written
    :param node_items: OCR nodes to search
    """
    match_bg_regex, key = pair
    medical_bg = None
    # Find a node whose text matches one of the label regexes.
    # NOTE(review): the break only exits the regex loop, so a later node can
    # overwrite an earlier match — confirm last-match-wins is intended.
    for node in node_items.values():
        for rl in match_bg_regex:
            if re.search(rl, node.text):
                medical_bg = node
                break
    # 找到其右侧的一块区域 (search a region to the right of the matched label)
    if medical_bg:
        filter_rule = lambda x: len(str_util.only_keep_money_char(x)) > 0
        node_in_region = NodeItemGroup.find_node_in_region(medical_bg.bbox, node_items, filter_rule=filter_rule, xoffset=(-1, 3), yoffset=(-2, 2))
        if not node_in_region:
            # Label exists but no money-like value found next to it.
            structure_items[key].content = ''
            return
        elif len(node_in_region) == 1:
            structure_items[key].content = str_util.money_char_clean(node_in_region[0].text)
        else:
            # Several candidates: take the one vertically closest to the label.
            fg_node = min(node_in_region, key=lambda x: abs(x.bbox.cy - medical_bg.bbox.cy))
            structure_items[key].content = str_util.money_char_clean(fg_node.text)
        logger.debug('find payinfo {} , value is {}'.format(key, structure_items[key]))
    else:
        # No label at all: assume a zero amount.
        structure_items[key].content = '0.00'
def process_single(self, item_name: str): logger.info(f'processing {item_name}') # 开始前清理变量 debugger.variables.clean() config = self.config if config.preload_tpl: debugger.variables.resultsTemplate.append({ "name": "tplImage", "text": "模板图", "image": { "data": "#{tplImage}" } }) rec_data_dir = os.path.join(config.work_dir, config.recognition_data_dirname) rec_img_dir = os.path.join(config.work_dir, config.recognition_img_dirname) rec_data_path = os.path.join(rec_data_dir, item_name) + '.json' rec_img_path = os.path.join(rec_img_dir, item_name) + '.jpg' # 读取识别结果数据 with open(rec_data_path, mode='r', encoding='utf-8') as f: rec_data = json.load(f) if not rec_data: logger.warning(f'raw_data is not present. path={rec_data_path}') return rec_img = cv2.imread(rec_img_path) if isinstance(rec_data, list): rec_data = [_convert_ai_rec_data_item(item) for item in rec_data] start_time = time.time() structure_result = self.session.process( rec_data, rec_img, class_name=self.config.class_name, primary_class=self.config.primary_class, secondary_class=self.config.secondary_class, ltrb=False) else: # 新rawdata start_time = time.time() rec_img = cv2.imread(rec_img_path) request, rpc_name = _convert_request(rec_data, rec_img, self.config) # 开始结构化 structure_result = self.request_processor.process( request, rpc_name, self.config.preload_tpl, item_name=item_name) process_duration = time.time() - start_time logger.debug(f'耗时{process_duration}') debugger.variables.structuringDuration = process_duration # 收集结构化结果 self._pack_debug_data(structure_result) self._dump_debug_data(item_name)
def above_offset_by_anchor(self, node_items: Dict[str, TpNodeItem]):
    """Estimate the above-layer offset by pairing anchor items with matched nodes.

    :param node_items: OCR nodes to match the anchor (above) items against.
    :return: an (x, y) offset; (0, 0) when no anchor matched.
    """
    pairs = []
    for anchor in self.above_items.values():
        if anchor.is_ban_offset:
            continue
        hit = anchor.match_node(node_items)
        if hit is None:
            continue
        logger.debug(f'above item {hit}')
        hit.is_above_item = True
        pairs.append((anchor.bbox, hit.trans_bbox))
    if not pairs:
        return 0, 0
    # The search window is bounded by the largest matched node box.
    widest = max(node_bbox.width for _, node_bbox in pairs)
    tallest = max(node_bbox.height for _, node_bbox in pairs)
    return self.search_offset(pairs, int(tallest), int(widest))
def _pre_func_medical_money(self, item_name: str, passed_nodes: Dict[str, TpNodeItem], node_items: Dict[str, TpNodeItem], img: np.ndarray): def has3number(text): """ 检查一个text是否有三个连续的text :param text: :return: """ if not text: return False res = re.search('[0-9,,,\.]{3,}', text) if res: return True else: return False # TODO 对 那种退现金...退支票...的数据要做特殊处理0001648365 for node in passed_nodes.values(): if has3number(node.text): roi = node.bbox crnn_res, crnn_scores = crnn_util.run_number_amount(img, roi) dis = lcs.llcs(crnn_res, node.text) if dis / len(crnn_res) > 0.6: # 认为crnn_res是有效的 if str_util.only_keep_money_char(crnn_res.replace(',', '')) != \ str_util.only_keep_money_char(node.text.replace(',', '')): logger.debug('re recog {} to crnn , from {} to {}'.format(item_name, node.text, crnn_res)) node.text = crnn_res node.scores = crnn_scores return self._pre_func_money(item_name, passed_nodes, node_items, img)
def _post_regex_filter_amountinwords(self, item_name: str, node_items: Dict[str, TpNodeItem], img: ndarray): for node in node_items.values(): if re.search('[零壹贰叁肆伍陆柒捌玖拾佰仟万亿圆元角分整]+', node.text): logger.debug('recover amount in words {} ...'.format( node.text)) node.is_filtered_by_content = False node.is_filtered_by_area = False node.is_filtered_by_regex = False
def load_from_source_file(source_filepath):
    """Build a BKTree from a one-word-per-line text file.

    Each word is stripped and full-width characters are normalized to
    half-width before insertion.
    """
    tree = BKTree()
    with open(source_filepath, encoding='utf-8', mode='r') as f:
        logger.debug('build bktree...')
        for raw_line in f:
            normalized = str_util.str_sbc_2_dbc(raw_line.strip())
            tree.insert_node(normalized)
    logger.debug('finish build bktree...')
    return tree
def __init__(self, torch_script_model, chars_file):
    """Load a TorchScript recognition model and its CTC label converter.

    :param torch_script_model: path to the serialized TorchScript model
    :param chars_file: charset file used to decode the CTC output
    """
    # Index 0 is reserved for the CTC blank.
    self.converter = LabelConverter(chars_file=chars_file, ctc_invalid_index=0)
    logger.debug('Load torch script model: %s' % torch_script_model)
    self.model = torch.jit.load(torch_script_model)
    self.model.eval()
    # Probe the loaded model for its expected input channel count.
    self.input_channel = self.query_input_channel()
    self.to_tensor = transforms.ToTensor()
    # Normalize inputs to roughly [-1, 1].
    self.norm = transforms.Normalize([0.5], [0.5])
def search_pay_info(self, structure_items, node_items, img):
    """Fill the pay-info money fields by locating their background labels.

    Special rule: when medicalpaymoney and personal_account_pay_money are both
    present and close to zero, personpaymoney is copied from amountinwords
    instead of being searched spatially.
    """
    def search_and_fill(pair, structure_items, node_items):
        # (regex list, item key): find the label node, then the money value to its right.
        match_bg_regex, key = pair
        medical_bg = None
        for node in node_items.values():
            for rl in match_bg_regex:
                if re.search(rl, node.text):
                    medical_bg = node
                    break
        # 找到其右侧的一块区域 (search the region to the right of the matched label)
        if medical_bg:
            filter_rule = lambda x: len(str_util.only_keep_money_char(x)) > 0
            node_in_region = NodeItemGroup.find_node_in_region(medical_bg.bbox, node_items, filter_rule=filter_rule, xoffset=(-1, 3), yoffset=(-2, 2))
            if not node_in_region:
                structure_items[key].content = ''
                return
            elif len(node_in_region) == 1:
                structure_items[key].content = str_util.money_char_clean(node_in_region[0].text)
            else:
                # Several candidates: take the one vertically closest to the label.
                fg_node = min(node_in_region, key=lambda x: abs(x.bbox.cy - medical_bg.bbox.cy))
                structure_items[key].content = str_util.money_char_clean(fg_node.text)
            logger.debug('find payinfo {} , value is {}'.format(key, structure_items[key]))
        else:
            structure_items[key].content = '0.00'

    fill_pair = [
        (['^基金支付'], 'medicalpaymoney'),
        (['^个人账户支付'], 'personal_account_pay_money'),
        # (['个人支付金额'], 'personpaymoney'),
    ]
    for pair in fill_pair:
        search_and_fill(pair, structure_items, node_items)
    # Special handling: if medicalpaymoney and personal_account_pay_money are
    # both 0, equate personpaymoney with the capitalized amount (amountinwords);
    # otherwise search for personpaymoney normally.
    if structure_items['medicalpaymoney'].content and structure_items['personal_account_pay_money'].content:
        if math.isclose(float(structure_items['medicalpaymoney'].content), 0) and math.isclose(
                float(structure_items['personal_account_pay_money'].content), 0):
            structure_items['personpaymoney'].content = structure_items['amountinwords'].content
            logger.debug(
                'infer payinfo {} , value is {}'.format('personpaymoney', structure_items['personpaymoney'].content))
        else:
            search_and_fill(
                (['个人支付金额'], 'personpaymoney'),
                structure_items, node_items
            )
def _handle_address(self, structure_items, fg_items, key='address'):
    """
    Normalize the address field in place via format_address.
    :param structure_items:
    :param fg_items:
    :param key: structure item key holding the address
    :return:
    """
    original = structure_items[key].content
    normalized = format_address(original)
    if normalized == original:
        return
    logger.debug('change province {} to {}'.format(original, normalized))
    structure_items[key].content = normalized
def _pre_func_institution_type(self, item_name: str, passed_nodes: Dict[str, TpNodeItem], node_items: Dict[str, TpNodeItem], img: np.ndarray): for node in passed_nodes.values(): text = node.text redundant_words = ['^医疗机构类型:', '^医疗.*构类型', '^医疗机构类', '^.{2,}机构类型', '^医疗机构类', '^医疗机'] for rule in redundant_words: if re.search(rule, text): text = re.sub(rule, '', text) text = str_util.remove_symbols(text) if node.text != text: logger.debug('institution {} is subtracted from {}'.format(re.sub(rule, '', text), text)) node.text = text
def __init__(self, chars_file, ctc_invalid_index=None): self.chars = ''.join(self.load_chars(chars_file)) # char_set_length + ctc_blank self.num_classes = len(self.chars) + 1 if ctc_invalid_index is None: self.ctc_invalid_index = len(self.chars) else: self.ctc_invalid_index = ctc_invalid_index self.encode_maps = {} self.decode_maps = {} self.create_encode_decode_maps(self.chars) logger.debug('Load chars file: %s num_classes: %d + 1(CTC Black)' % (chars_file, self.num_classes - 1))
def process(self): config = self.config # 读取模板文件 # 如果是基于模板的且开启debug,则读取模板图片 if config.preload_tpl and debugger.enabled: debugger.commonVariables.tplImage = _read_tpl_as_base64( config.class_name) # torch 在多进程模式下,fork 会有问题,见:https://github.com/pytorch/pytorch/issues/17199 # issue 中提到的 set_num_threads 的方法不生效,需要结合 os.environ['OMP_NUM_THREADS'] = '1' 才有用 mp.set_start_method('spawn', force=True) # 删除旧结果目录 result_dir = os.path.join(config.work_dir, config.result_dirname) if os.path.exists(result_dir): shutil.rmtree(result_dir) os.makedirs(result_dir) rec_data_dir = os.path.join(config.work_dir, config.recognition_data_dirname) all_items = [] for rec_data_filename in os.listdir(rec_data_dir): if not rec_data_filename.endswith('.json'): continue item_name, _ = os.path.splitext(rec_data_filename) all_items.append(item_name) if len(all_items) == 0: logger.warning('raw data not found') return # process_count为0表示根据CPU个数确定 # 一般获取到的数字==所在机器的CPU个数,但部署在docker中后,实际可用的CPU个数可能比机器的CPU个数少 # 所以这里使用代码获取真实的CPU个数,而不是给Pool传递None让其自动获取 process_count = config.process_count or _get_available_cpu_count() logger.debug(f'process count is {process_count}') if process_count == 1 or len(all_items) == 1: for item_name in tqdm(all_items, file=sys.stdout): self.process_single(item_name) else: with mp.Pool(process_count, initializer=_process_pool_initializer, initargs=(debugger.enabled, )) as pool: list( tqdm(pool.imap_unordered(self.process_single, all_items), total=len(all_items), file=sys.stdout))
def process(
    self,
    node_items: Dict[str, TpNodeItem],
    img: np.ndarray,
    class_name: str,
    debug_data: DebugData = None,
):
    """Run template matching, then parse the template.

    :param node_items:
    :param img: numpy BGR image
    :param class_name: 模板名称
    :return: dict. key: item_name value: StructureItem.to_dict 的结果
    """
    # Background scaling and foreground offset happen inside the matcher; it
    # may add or drop nodes, so track the count across the call.
    count_before = len(node_items)
    self.matchers[class_name].process(node_items, img, debug_data=debug_data)
    count_after = len(node_items)
    if count_before != count_after:
        logger.debug(
            f"node_items count change after matcher.process(): {count_before} -> {count_after}"
        )
    # Template parse: runs filter_area, filter_regex and the other filters.
    return self.parsers[class_name].parse_template(node_items,
                                                   img,
                                                   debug_data=debug_data)
def _bk_tree_medical_insurance_type(self, structure_items, fg_items): medical_insurance_type_item = structure_items.get('medical_insurance_type', None) if medical_insurance_type_item is None: return if medical_insurance_type_item.content: text = medical_insurance_type_item.content.replace('医保类型:', '') text = text.replace('医保类型', '') text = text.replace('医保类', '') text = text.replace('医疗机构', '') res = bk_tree.medical_insurance_type().search_one(text, search_dist=1, min_len=2) # 对长文本,使用 搜索距离为1有时候难以获得好的搜索结果: if res is None and len(text) > 8: # 再捞一下 res = bk_tree.medical_insurance_type().search_one(text, search_dist=2, min_len=2) if res is None: # 尝试使用非结构化的方法进行搜寻 node_items = fg_items['medical_insurance_type'].node_items_backup config = {'left': ['医保类型', '医疗机构类型'], 'right': ['社会保障卡号'], 'up': None, 'down': None } search_res = NodeItemGroup.get_possible_node(node_items, config, thresh_x=8, thresh_y=4, match_count=2) if search_res: search_res = [res[0] for res in search_res] bk_tree_search_res = [] for idx in range(len(search_res)): new_res = bk_tree.medical_insurance_type().search_one(search_res[idx], search_dist=2, min_len=2) if new_res: bk_tree_search_res.append(new_res) if bk_tree_search_res: search_res = max(bk_tree_search_res, key=lambda x: len(x)) medical_insurance_type_item.content = search_res medical_insurance_type_item.scores = [1] else: return else: medical_insurance_type_item.content = res medical_insurance_type_item.scores = [1] if text != res: logger.debug('medical_insurance_type bk_tree:') logger.debug('\tOrigin: {}'.format(text)) logger.debug('\tItem in tree: {}'.format(res)) logger.debug('\tItem in tree: {}'.format(res))
def _horizontal_merge_match(
        self, node_items: Dict[str, TpNodeItem]) -> List[TpNodeItem]:
    """Match the bg item content by merging horizontally adjacent nodes.

    Merge mode runs only when normal matching finds nothing: candidate nodes
    whose text occurs inside the bg content are grouped into rows, each row is
    split into x-segments, and every segment whose merged content equals the
    bg content yields a new merged TpNodeItem.
    """
    norm_match_res = self._norm_match(node_items)
    if len(norm_match_res) != 0:
        return norm_match_res
    # When an edit-distance config is present, use only the text for merging.
    if not isinstance(self.content, str):
        bg_content = self.content["text"]
    else:
        bg_content = self.content
    candidate_node_items = {}
    candidate_chars_count = 0
    for node_item in node_items.values():
        if node_item.cn_text in bg_content:
            candidate_node_items[node_item.uid] = node_item
            candidate_chars_count += len(node_item.cn_text)
    # Total candidate length shorter than the bg content: merging cannot succeed.
    if candidate_chars_count < len(bg_content):
        return []
    line_groups = NodeItemGroup.find_row_lines(candidate_node_items)
    grouped_segs: List[List[NodeItemGroup]] = []
    for group in line_groups:
        segs = group.find_x_segs()
        grouped_segs.append(segs)
    out = []
    for segs in grouped_segs:
        for seg in segs:
            if seg.content() == bg_content:
                new_node = TpNodeItem(seg.gen_raw_node())
                out.append(new_node)
                logger.debug(
                    f"Merge mode bg item match success: {new_node}")
    return out
def _bk_tree_beijing_hospital_code_in_serialnumber(self, structure_items): serialnumber_item = structure_items.get('serialnumber', None) if serialnumber_item is None: return serialnumber = serialnumber_item.content if serialnumber is None: return old_hospital_code = serialnumber[:8] correct_word = bk_tree.beijing_hospital_code().search_one(old_hospital_code, search_dist=1, min_len=7) if correct_word is None: return if old_hospital_code != correct_word: logger.debug('Serial number hospital code bk tree look up:') logger.debug('\tOrigin: {}'.format(old_hospital_code)) logger.debug('\tItem in tree: {}'.format(correct_word)) new_serial_num = correct_word + serialnumber[8:] serialnumber_item.content = new_serial_num
def _handle_province(address: str) -> str: if '省' in address: province_info = address.split('省')[0] other_info = ''.join(address.split('省')[1:]) best_province = bk_tree.province().search_one(province_info, search_dist=2, min_len=2) # new_other_info = _handle_city(other_info, best_province) new_other_info = other_info if new_other_info != other_info: logger.debug( 'change city info {} to {} in province process'.format( other_info, new_other_info)) other_info = new_other_info if best_province: address = best_province + other_info else: logger.debug( 'search bk tree failed , org text is {}'.format(province_info)) return address
def eval(
    self,
    node_items: Dict[str, TpNodeItem],
    img_height: int,
    img_width: int,
    debug_data: DebugData = None,
):
    """Match template background items against OCR nodes and rescale the nodes.

    Uses a perspective transform when enough reliable bg matches exist,
    otherwise falls back to best-match scaling; when almost nothing matched
    but the aspect ratio resembles the template, falls back to pure 2D scaling.
    """
    # Background element matching
    bg_match_pairs = []   # (bg_item uid, node uid) for uniquely matched items
    bg_match_all = {}     # bg_item uid -> every matched node
    new_node_items = {}   # nodes newly created by merge/split matching
    bg_nodes_count = 0
    for bg_item in self.bg_items.values():
        matched_node_items = bg_item.match_node(node_items)
        if len(matched_node_items) == 0:
            continue
        elif len(matched_node_items) == 1:
            matched_node_item = matched_node_items[0]
            matched_node_item.is_bg_item = True
            # merge/split mode may produce new node_items; add them to the input
            if matched_node_item.uid not in node_items:
                new_node_items[matched_node_item.uid] = matched_node_item
            bg_match_pairs.append((bg_item.uid, matched_node_item.uid))
        if len(matched_node_items) > 0:
            bg_nodes_count += 1
            bg_match_all[bg_item.uid] = matched_node_items
    # Almost nothing matched but the image ratio is close to the template's:
    # rescale with a plain 2D scale instead.
    if (bg_nodes_count <= 1 and self.template_info is not None
            and abs(img_height / img_width - self.template_ratio) < 0.2):
        return self.resize_by_2d_scale(node_items, img_height, img_width)
    node_items.update(new_node_items)
    # Background scaling
    # self.scale_by_best_match(node_items, bg_match_pairs, bg_match_all)
    if self.check_do_perspective(node_items, bg_match_pairs, img_height, img_width):
        H, mask = self.scale_by_perspective(node_items, bg_match_pairs)
        status = self.check_if_is_normal(node_items)
        if not status:
            # Perspective produced an abnormal layout; fall back to best match.
            logger.debug("use perspective fail , use normal method")
            self.scale_by_best_match(node_items, bg_match_pairs, bg_match_all)
        if H is not None:
            logger.debug("Do bg scale by perspective")
            if debug_data:
                debug_data.set_H(H)
            debugger.variables.set_H(H)
    else:
        logger.debug("Do bg scale by best match")
        self.scale_by_best_match(node_items, bg_match_pairs, bg_match_all)
def load_from_disk(working_dir, tree_name, force=False):
    """Load a BK-tree, using an on-disk cache keyed by the source file's MD5.

    1. Reuse a previously persisted tree under the .tree directory when present.
    2. Otherwise (or when the source MD5 changed) rebuild from source and persist.
    3. force=True always rebuilds from source.

    :param working_dir: working directory containing the source/tree dirs
    :param tree_name: name of the tree source file
    :param force: rebuild unconditionally when True
    :return: the loaded BKTree
    """
    logger.debug('Load BK tree: {}'.format(tree_name))
    source_file_path = os.path.join(working_dir, SOURCE_DIR_NAME, tree_name)
    tree_dir = os.path.join(working_dir, TREE_DIR_NAME)
    if not os.path.exists(source_file_path):
        logger.error(
            'BK-Tree source file not exist: {}'.format(source_file_path))
        raise AssertionError(
            'BK-Tree source file not exist: {}'.format(source_file_path))
    if not os.path.exists(tree_dir):
        os.makedirs(tree_dir)
    # Cache file name is the MD5 of the source, so source edits invalidate it.
    file_md5 = utils.md5(source_file_path)
    tree_file_path = os.path.join(tree_dir, file_md5 + '.json')
    if force:
        logger.debug('Force create a new tree')
        tree = load_from_source_file(source_file_path)
        save_to_tree_file(tree, tree_file_path)
    elif not os.path.exists(tree_file_path):
        # A missing tree file means the tree was never persisted.
        logger.debug('Create a new tree since there is no cache')
        tree = load_from_source_file(source_file_path)
        save_to_tree_file(tree, tree_file_path)
    else:
        logger.debug('Load tree from tree file: %s' % tree_file_path)
        tree = load_from_tree_file(tree_file_path)
    return tree
def _pre_func_crnn_num_bigeng( self, item_name: str, passed_nodes: Dict[str, TpNodeItem], node_items: Dict[str, TpNodeItem], img: np.ndarray, ): for node in passed_nodes.values(): roi = node.bbox crnn_res, scores = self._crnn_util.run_number_capital_eng( img, roi.rect) if crnn_res != node.text: logger.debug( "item_name: {} crnn_num_bigeng:".format(item_name)) logger.debug("\tOrigin: {}".format(node)) logger.debug("\tCRNN result: {}".format(crnn_res)) if crnn_res is not None: node.text = crnn_res node.scores = scores
def _pre_func_shanghai_paymoney_crnn(self, item_name: str, passed_nodes: Dict[str, TpNodeItem], node_items: Dict[str, TpNodeItem], img: np.ndarray): for node in passed_nodes.values(): if not str_util.contain_continue_nums(node.text, 1): continue if node.bbox.height > node.bbox.width: continue crnn_res, scores = crnn_util.run_shanghai_paymoney(img, node.bbox) if crnn_res != node.text: logger.debug('item_name: {}:'.format(item_name)) logger.debug('\tOrigin: {}'.format(node.text)) logger.debug('\tCRNN: {}'.format(crnn_res)) if crnn_res is not None: node.text = crnn_res node.scores = scores
def eval(self,
         node_items: Dict[str, TpNodeItem],
         img_height: int,
         img_width: int,
         offset_method: str = ABOVE_OFFSET_METHOD_IOU):
    """
    Apply the above-layer offset: each node's trans_bbox is shifted and the
    result stored in above_offset_bbox.
    :param node_items:
    :param offset_method:
        - anchor
        - iou
    :return:
    """
    if not self.above_items:
        return
    ban_offset_uids = self.sign_ban_offset(node_items, self.above_items)
    if offset_method == ABOVE_OFFSET_METHOD_ANCHOR:
        new_offset = self.above_offset_by_anchor(node_items)
        logger.debug('Do above_offset by ANCHOR')
    elif offset_method == ABOVE_OFFSET_METHOD_IOU:
        similarity = self.cal_content_similarity(node_items)
        new_offset = self.iou_block_search(similarity, node_items,
                                           ban_offset_uids)
        logger.debug('Do above_offset by IOU')
    else:
        raise NotImplementedError(
            f'above_offset offset_method [{offset_method}] is not implemented'
        )
    logger.debug(f'above_offset value: {new_offset}')
    dx, dy = new_offset[0], new_offset[1]
    for uid, node_item in node_items.items():
        if uid in ban_offset_uids:
            continue
        node_item.above_offset_bbox = node_item.trans_bbox.transform((-dx, -dy))
def _get_beijing_hospital_name_by_hospital_code(self, structure_items, fg_items): serialnumber_fg_item = fg_items.get('serialnumber', None) if serialnumber_fg_item is None: return serialnumber_item = structure_items.get('serialnumber', None) hospital_name_item = structure_items.get('hospital_name', None) if serialnumber_item is None or hospital_name_item is None: return serialnumber_labels = [] regex_failed_serialnumbers_tp_rects = serialnumber_fg_item.regex_failed_tp_rects for rect_data in regex_failed_serialnumbers_tp_rects: serialnumber_labels.append((rect_data.text, rect_data.scores)) serialnumber = serialnumber_item.content if serialnumber is not None: serialnumber_labels.insert(0, (serialnumber, serialnumber_item.scores)) for serial_num, scores in serialnumber_labels: serial_num = str_util.filter_num(serial_num) if len(serial_num) < 8: continue hospital_code = serial_num[:8] hospital_name = beijing_hospital_code_name_map.get(hospital_code, None) if hospital_name is None: continue logger.debug('Look up beijing hospital name:') logger.debug('\tCode: {}'.format(hospital_code)) logger.debug('\tName: {}'.format(hospital_name)) hospital_name_item.content = hospital_name hospital_name_item.scores = serialnumber_item.scores[:8] break
def _bk_tree_shanghai_medical_institution_type(self, structure_items, fg_items): medical_institution_type_item = structure_items.get('medical_institution_type') if medical_institution_type_item is None: return text = medical_institution_type_item.content length = len(text) if length >= 5: search_dist = 2 else: search_dist = 1 bk_res = bk_tree.medical_institution_type().search_one(text, search_dist=search_dist, min_len=2) if bk_res != medical_institution_type_item.content: logger.debug('medical_institution_type bk_tree:') logger.debug('\tOrigin: {}'.format(text)) logger.debug('\tItem in tree: {}'.format(bk_res)) medical_institution_type_item.content = bk_res
def _bk_tree_beijing_medical_institution_type(self, structure_items, fg_items): medical_institution_type_item = structure_items.get('medical_institution_type', None) if "专科医院" in medical_institution_type_item.content: structure_items['medical_institution_type'].content = "专科医院" return if medical_institution_type_item.content.startswith('专科') or \ medical_institution_type_item.content.startswith('专利'): structure_items['medical_institution_type'].content = "专科医院" return if not medical_institution_type_item.content: # 说明现在拿不到medical_insurance_type , 考虑可能是结构化过程中的问题 # ,尝试非结构化方法 node_items = fg_items['medical_institution_type'].node_items_backup bg_node_regex = ['^医疗结构类型$', '^医疗.*构类型', "疗机构类型"] bg_node = None for node in node_items.values(): is_bg_node = False for rl in bg_node_regex: if re.search(rl, node.text): is_bg_node = True if is_bg_node: bg_node = node break # 遍历所有的node,如果这个node在medical周围而且能够在bk_tree中找到类似的结果,则把这个结果搞进去 # fg 为 bg_rect 的右侧的一个区域 if bg_node: bg_xmin, bg_ymin, bg_xmax, bg_ymax = bg_node.bbox.rect bg_height = (bg_ymax - bg_ymin) bg_width = (bg_xmax - bg_xmin) fg_xmin = bg_xmax fg_xmax = fg_xmin + bg_width * 3 fg_ymin = bg_ymin - bg_height fg_ymax = bg_ymax + bg_height fg_rect = [int(fg_xmin), int(fg_ymin), int(fg_xmax), int(fg_ymax)] prob_res = [] for node in node_items.values(): if node.bbox.is_center_in(fg_rect): if len(node.text) <= 1: continue if str_util.keep_num_char(node.text) == node.text and len(node.text) <= 4: continue if node.text == '年' or node.text == '日': continue prob_res.append(node.text) for idx, res in enumerate(prob_res): bk_res = bk_tree.medical_institution_type().search_one(res, search_dist=2, min_len=2) prob_res[idx] = bk_res filter_prob_res = list(filter(lambda x: x is not None, prob_res)) if not filter_prob_res: return if len(filter_prob_res) == 1: medical_institution_type_item.content = filter_prob_res[0] return else: medical_institution_type_item.content = max(filter_prob_res, key=lambda x: len(x)) return else: # 
顺便尝试搜索一下常见的医疗机构类型,避免bg没有被找到的情况 fg_node_regex = { "综合医院": ["^综合医院$", "类综合医院", "类型综合医院"], "中医医院": ["医疗.{1,2}中医医院", "类中医医院"], "社区卫生服务中心": ["社区卫生服务中心"], "中西医结合医院": ["中西医结合医院"], "非营利综合医院": ["非营利综合医院"], "对外中医": ['对外中医$'] } for node in node_items.values(): for institution_type in fg_node_regex: for rl in fg_node_regex[institution_type]: if re.search(rl, node.text): medical_institution_type_item.content = institution_type return return if medical_institution_type_item.content: text = medical_institution_type_item.content if text.startswith('型'): medical_institution_type_item.content = medical_institution_type_item.content[1:] return redundant_words = ['^医疗机构类型:', '^医疗.*构类型', '^医疗机构类', '^.{2,}机构类型', ] for rule in redundant_words: if re.search(rule, text): logger.debug('institution {} is subtracted from {}'.format(re.sub(rule, '', text), text)) text = re.sub(rule, '', text) text = str_util.remove_redundant_patthern(text, redundant_words) res = bk_tree.medical_institution_type().search_one(text, search_dist=1, min_len=2) res_2 = None if res is None: # 尝试救一下: # 做一些特殊的对应,因为有的字样只有可能出现在某些特殊的内容中,比如 '外中' special_map = {'外中': '对外中医', '外综': '对外综合', '利中': '非营利中医医院', '性医院': '综合性医院', '性医疗': '非营利性医疗机构' } text_copy = text for sp in special_map: if sp in text_copy: text_copy = special_map[sp] break if text_copy != text: res_2 = bk_tree.medical_institution_type().search_one(text_copy, search_dist=1, min_len=2) if res is None and res_2 is None: medical_institution_type_item.content = text medical_institution_type_item.scores = [0] else: if res is not None: medical_institution_type_item.content = res elif res_2 is not None: medical_institution_type_item.content = res_2 medical_institution_type_item.scores = [1] if text != res: logger.debug('medical_institution_type bk_tree:') logger.debug('\tOrigin: {}'.format(text)) logger.debug('\tItem in tree: {}'.format(res))
def print_info():
    """Log how the receipt number was recovered from the barcode.

    NOTE(review): relies on closure variables (origin_receiptno, barcode,
    receiptno_from_barcode) defined in the enclosing scope.
    """
    logger.debug('Get receiptno from barcode:')
    logger.debug('\torigin receiptno: {}'.format(origin_receiptno))
    logger.debug('\tbarcode: {}'.format(barcode))
    logger.debug('\treceiptno from barcode: {}'.format(receiptno_from_barcode))
def _post_func_crnn_date(
    self,
    item_name: str,
    passed_nodes: Dict[str, TpNodeItem],
    node_items: Dict[str, TpNodeItem],
    img: np.ndarray,
):
    """Merge the date-field candidate boxes and re-recognize them with a CRNN.

    Returns a (date_string, scores) tuple when a date could be formed,
    otherwise None (when no usable candidates remain).
    """
    # TO-DO,viz now is not supported
    # DEBUG_VIZ = False
    # Handle inputs like '2017-11-28_9:20:38日'.
    origin_text = ""
    for node in passed_nodes.values():
        origin_text += node.text
    # First, clean away hour/minute/second tails such as '2017-11-28_9:20:38'.
    changed_res = date_util.filter_year_month_day(passed_nodes)
    if changed_res:
        for row in changed_res:
            logger.debug(
                f"[{item_name}] change {row[0]} to {row[1]} due to it has sec and hour info"
            )
    # Drop nodes with very long digit runs (serial numbers), e.g.
    # '2017._00___82___84付___2017'.
    passed_nodes = OrderedDict(
        filter(
            lambda x: str_util.count_max_continous_num(x[1].text) < 8,
            passed_nodes.items(),
        ))
    if len(passed_nodes) == 0:
        return
    # Group the boxes that sit on the same text row.
    nodes_in_line = NodeItemGroup.find_row_lines(passed_nodes, y_thresh=0.6)
    # Keep only rows containing a digit or a 年/日 marker.
    nodes_in_line = list(
        filter(
            lambda x: date_util.contain_number(x.content()) or "年" in x.
            content() or "日" in x.content(),
            nodes_in_line,
        ))
    # CRNN region: y-range from the digit-bearing row, x-range spanning all rows.
    if len(nodes_in_line) == 0:
        return
    crnn_xmin = min([node.bbox.rect[0] for node in nodes_in_line])
    crnn_xmax = max([node.bbox.rect[2] for node in nodes_in_line])
    if len(nodes_in_line) == 1:
        nodes_in_line = nodes_in_line[0]
    else:
        # Choose the widest digit-bearing row as the representative line.
        nodes_in_line = list(
            filter(lambda x: date_util.contain_number(x.content()),
                   nodes_in_line))
        nodes_in_line = max(nodes_in_line, key=lambda x: x.bbox.width)
    crnn_ymin = nodes_in_line.bbox.rect[1]
    crnn_ymax = nodes_in_line.bbox.rect[3]
    # Re-recognition uses original-image coordinates.
    crnn_res, scores = self._crnn_util.run_number_space(
        img, [crnn_xmin, crnn_ymin, crnn_xmax, crnn_ymax])
    crnn_res, scores = self.check_crnn_res_and_rerecognize(
        crnn_res, scores, img, crnn_xmin, crnn_ymin, crnn_xmax, crnn_ymax)
    date_info = [split for split in crnn_res.split("_") if split]
    useful_info = date_util.get_useful_info(crnn_res.replace("_", " "))
    if len(date_info) >= 3:
        # Three or more segments: try to parse a well-formed date directly.
        split_res = date_util.parse_date(crnn_res, min_year=2000, max_year=2040)
        if split_res.most_possible_item:
            res = split_res.most_possible_item.to_string()
            logger.debug(
                f"{item_name} res {crnn_res} is formated to {res}")
        else:
            res = None
        # Only the first three segments are scored for now.
        crnn_scores = str_util.date_format_scores(*date_info[:3])
    if len(date_info) < 3 or res is None:
        logger.debug(
            f"{item_name} res {crnn_res} can not format , use normal method"
        )
        res = date_util.get_format_data_from_crnn_num_model_res(
            crnn_res.replace("_", ""))
        crnn_scores = str_util.date_format_scores(
            useful_info["useful_year"],
            useful_info["useful_month"],
            useful_info["useful_day"],
        )
    # If the CRNN result failed, fall back to the original (non-CRNN) method.
    if res is None or not date_util.is_legal_format_date_str(res):
        nodes = NodeItemGroup.recover_node_item_dict(
            nodes_in_line.node_items)
        split_res = self._post_func_date(item_name, nodes, node_items, img)
        if split_res is not None:
            # Recover the year of split_res when it is out of the legal range.
            split_res_date = datetime.strptime(split_res[0], "%Y-%m-%d")
            if not date_util.is_legal_year(split_res_date.year):
                predict_year = NodeItemGroup.get_year_predict(node_items)
                predict_date = str(
                    split_res_date.replace(year=predict_year).date())
                split_res = (predict_date, split_res[1])
            recover_from_crnn = date_util.recover_info_from_useful_info(
                split_res, useful_info)
            logger.debug(
                "recover from crnn , org is {} , recover res is {}".format(
                    split_res, recover_from_crnn))
            split_res = recover_from_crnn
        else:
            if (useful_info.get("useful_month", None) is not None
                    and useful_info.get("useful_day", None) is not None):
                # Month/day known: pick the year among the last 3 years that
                # occurs most often across all node_items.
                max_count_year = None
                max_count = -1
                cur_year = datetime.now().year
                for year in range(cur_year - 3, cur_year + 1):
                    count_of_year = NodeItemGroup.regex_filter_count(
                        node_items, str(year))
                    if count_of_year > max_count:
                        max_count = count_of_year
                        max_count_year = year
                if max_count != -1:
                    # Found a year: recover a full date from useful_info.
                    predict_date = dt.date(
                        year=max_count_year,
                        month=int(useful_info["useful_month"]),
                        day=int(useful_info["useful_day"]),
                    )
                    predict_date = str(predict_date)
                    split_res = (predict_date, [1])
                    logger.debug("[{}] predict date from crnn , {}".format(
                        item_name, predict_date))
        # NOTE(review): when len(date_info) < 3 and both fallbacks fail,
        # split_res may be unbound here — confirm against callers.
        logger.debug("[{}] Origin texts: {},"
                     "CRNN result: {},"
                     "Merge result: {},"
                     "Split result: {}".format(item_name, origin_text,
                                               crnn_res, res, split_res))
        # name = item_name + str(self.global_test)
        # viz_crnn_debug(passed_nodes, nodes_in_line, img,name,origin_text)
        # self.global_test+=1
        return split_res
    # TODO: get_format_data_from_crnn_num_model_res 中需要返回到底用了哪几个位置的数字
    mean_score = NodeItemGroup.cal_mean_score(nodes_in_line)
    # return res, [mean_score]
    return res, crnn_scores
def _init_fg_items(self, class_name: str, conf: Dict) -> Dict[str, FGItem]:
    """Build FGItem objects from a template config.

    Handler-function loading priority:
      1. When item_pre_func/item_post_func is configured, look up that function.
      2. Otherwise fall back to the convention _pre_func_{item_name} /
         _post_func_{item_name}:
         - pre-processing: absent means no pre-processing function;
         - post-processing: absent means _post_func_max_w_regex is used.

    :param class_name: template class name (used in error messages)
    :param conf: parsed template configuration
    :return: dict mapping item_name -> FGItem
    :raises ConfigException: on missing fg_items/item_name/search mode keys
    """
    if "fg_items" not in conf:
        raise ConfigException(f"fg_items not exist in {class_name}.yml")
    res = {}
    for item in conf["fg_items"]:
        item_name = item.get("item_name", None)
        item_show_name = item.get("show_name", None)
        if item_name is None:
            raise ConfigException(
                f"[{class_name}] FG item [{item_name}] miss [item_name] key"
            )
        item_pre_func = self._get_pre_post_func(item, "item_pre_func", class_name)
        item_post_func = self._get_pre_post_func(item, "item_post_func", class_name)
        # No post-func configured: default to _post_func_max_w_regex.
        if item_post_func is None:
            item_post_func = self._get_item_func("_post_func_max_w_regex")
        # Build the filter configuration (two config schemas exist).
        if self.is_tp_conf:
            filter_areas = []
            if "area" in item:
                filter_areas.append({
                    "area": item["area"],
                    "w": 1,
                    "ioo_thresh": item.get("ioo_thresh", 0),
                })
            if len(filter_areas) == 0:
                filter_areas = None
            filter_regexs = item.get("filter_regexs", [])
            # The tp schema has no per-regex weight; default each to 1.
            for r in filter_regexs:
                r["w"] = 1
            filter_confs = {
                "filter_areas": filter_areas,
                "filter_contents": item.get("filter_contents", None),
                "filter_regexs": filter_regexs,
            }
        else:
            filter_confs = {
                "filter_areas": item.get("filter_areas", None),
                "filter_contents": item.get("filter_contents", None),
                "filter_regexs": item.get("filter_regexs", None),
            }
        # Optional search_strategy parameters.
        search_strategy = None
        search_strategy_item = item.get("search_strategy", None)
        if search_strategy_item is not None:
            search_mode = search_strategy_item.get("mode", None)
            if search_mode is None:
                s = f"[{class_name}] FG item [{item_name}] set search_strategy, but not set [mode] key"
                raise ConfigException(s)
            w_pad = search_strategy_item.get("w_pad", 0)
            h_pad = search_strategy_item.get("h_pad", 0)
            logger.debug(
                f"Enlarge {class_name} [{item_name}] search area: w_pad {w_pad} h_pad {h_pad}"
            )
            search_strategy = EnlargeSearchStrategy(w_pad, h_pad)
        post_regex_filter_func_name = item.get(
            "item_post_regex_filter_func")
        post_regex_filter_func = self._get_item_func(
            post_regex_filter_func_name)
        res[item_name] = FGItem(
            item_name,
            item_show_name,
            filter_confs,
            item_pre_func,
            item_post_func,
            post_regex_filter_func,
            should_output=item.get("should_output", True),
            search_strategy=search_strategy,
        )
    return res
def iou_block_search(self, content_similarity: Dict[str, Dict[str, float]],
                     node_items: Dict[str, TpNodeItem],
                     ban_offset_uids: Set):
    """Grid-search the (x, y) offset that maximizes the match score between
    above items and nodes, refining the search radius each pass.

    :param content_similarity: per-uid similarity scores used by
        cal_match_similarity_value
    :param node_items: OCR nodes
    :param ban_offset_uids: uids excluded from offset computation
    :return: the best [x, y] offset found
    """
    # w, h bound all boxes; they cap the search radius for the coarse-to-fine loop.
    w, h = 0, 0
    for above_item in self.above_items.values():
        w = max(w, above_item.bbox.right)
        h = max(h, above_item.bbox.bottom)
    for node_item in node_items.values():
        w = max(w, node_item.trans_bbox.right)
        h = max(h, node_item.trans_bbox.bottom)
    now_offset_center = [0, 0]
    max_search_dis = max(w, h) / 6.0
    split_size = 10
    record_of_all_best = None
    while max_search_dis > 1:
        block_size = max_search_dis / split_size
        left_up_point = [
            v - (max_search_dis / 2.0) for v in now_offset_center
        ]
        new_center = [0, 0]
        result_offset_center = [v for v in now_offset_center]
        best_similarity_value = 0
        # Evaluate a split_size x split_size grid of candidate offsets.
        for i in range(split_size):
            for j in range(split_size):
                new_center[0] = left_up_point[0] + i * block_size
                new_center[1] = left_up_point[1] + j * block_size
                similarity_value, hit_num, hit_can_not_miss = self.cal_match_similarity_value(
                    content_similarity, node_items, new_center,
                    ban_offset_uids)
                offset = [v for v in new_center]
                if not record_of_all_best:
                    record_of_all_best = (hit_can_not_miss, hit_num,
                                          similarity_value, offset)
                else:
                    # Keep the globally best-scoring candidate across all passes.
                    current = (hit_can_not_miss, hit_num, similarity_value,
                               offset)
                    if self.get_score_of_each_offest(
                            current) > self.get_score_of_each_offest(
                                record_of_all_best):
                        record_of_all_best = current
                # Previous handling logic (disabled):
                # if similarity_value > best_similarity_value:
                #     best_similarity_value = similarity_value
                #     result_offset_center = [v for v in new_center]
        # NOTE(review): re-centering is disabled, so every pass searches
        # around the origin with a shrinking radius.
        # now_offset_center = [v for v in result_offset_center]
        max_search_dis = block_size * 2
    # if self.exp_data:
    #     self.exp_data.set_above_matched_idxes(best_label_node_idx_4_debug)
    # NOTE(review): record_of_all_best stays None when the while loop never
    # runs (max(w, h) <= 6) — this line would then raise; confirm inputs.
    logger.debug(
        'above offset with hit_can_not_miss :{} , with hit {}'.format(
            record_of_all_best[0], record_of_all_best[1]))
    return record_of_all_best[3]