Example #1
def _handle_district(address: str, city=None):
    if '区' not in address and '县' not in address:
        return address

    if '区' in address:
        key_words = '区'
    elif '县' in address:
        key_words = '县'

    district_info, other_info = address.split(key_words, 1)
    district_info += key_words
    if not city:
        # no city context; leave the address unchanged for now
        return address
    else:
        district_list = district_info_dict.get(city, None)
        if district_list:
            district_tree = bk_tree.BKTree()
            for district in district_list:
                district_tree.insert_node(district)
            best_district = district_tree.search_one(district_info,
                                                     search_dist=2,
                                                     min_len=1)
            if best_district:
                if best_district != district_info:
                    logger.debug('replace {} to {} by district process'.format(
                        district_info, best_district))
                    address = best_district + other_info
    return address
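
The lookup above relies on the project's `bk_tree.BKTree`, which is not part of this example. As a point of reference, here is a minimal, self-contained sketch of the same idea (a metric tree over edit distance exposing `insert_node` / `search_one`); treating `min_len` as a minimum query length is an assumption based on how it is called above.

def levenshtein(a: str, b: str) -> int:
    """Classic dynamic-programming edit distance."""
    if len(a) < len(b):
        a, b = b, a
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        cur = [i]
        for j, cb in enumerate(b, 1):
            cur.append(min(prev[j] + 1,                 # deletion
                           cur[j - 1] + 1,              # insertion
                           prev[j - 1] + (ca != cb)))   # substitution
        prev = cur
    return prev[-1]


class BKTree:
    def __init__(self):
        self.root = None  # each node is a (word, {distance: child}) tuple

    def insert_node(self, word: str):
        if self.root is None:
            self.root = (word, {})
            return
        node = self.root
        while True:
            d = levenshtein(word, node[0])
            if d == 0:
                return  # already stored
            child = node[1].get(d)
            if child is None:
                node[1][d] = (word, {})
                return
            node = child

    def search_one(self, term: str, search_dist: int = 2, min_len: int = 1):
        """Return the closest stored word within search_dist, or None."""
        if self.root is None or len(term) < min_len:
            return None
        best, best_d = None, search_dist + 1
        stack = [self.root]
        while stack:
            word, children = stack.pop()
            d = levenshtein(term, word)
            if d < best_d:
                best, best_d = word, d
            # triangle inequality: only descend into children whose edge
            # distance lies within [d - search_dist, d + search_dist]
            for cd, child in children.items():
                if d - search_dist <= cd <= d + search_dist:
                    stack.append(child)
        return best

Populated with a city's district list, `search_one` returns the closest known district within the given edit distance, which is what lets `_handle_district` repair small OCR slips in the district name.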
Example #2
        def search_and_fill(pair, structure_items, node_items):
            match_bg_regex, key = pair
            medical_bg = None
            for node in node_items.values():
                for rl in match_bg_regex:
                    if re.search(rl, node.text):
                        medical_bg = node
                        break
            # Find the region to its right that holds the amount

            if medical_bg:

                filter_rule = lambda x: len(str_util.only_keep_money_char(x)) > 0
                node_in_region = NodeItemGroup.find_node_in_region(medical_bg.bbox,
                                                                   node_items,
                                                                   filter_rule=filter_rule,
                                                                   xoffset=(-1, 3),
                                                                   yoffset=(-2, 2))
                if not node_in_region:
                    structure_items[key].content = ''
                    return
                elif len(node_in_region) == 1:
                    structure_items[key].content = str_util.money_char_clean(node_in_region[0].text)
                else:
                    fg_node = min(node_in_region, key=lambda x: abs(x.bbox.cy - medical_bg.bbox.cy))
                    structure_items[key].content = str_util.money_char_clean(fg_node.text)
                logger.debug('find payinfo {} , value is {}'.format(key, structure_items[key]))
            else:
                structure_items[key].content = '0.00'
Example #3
    def process_single(self, item_name: str):
        logger.info(f'processing {item_name}')
        # Clean debug variables before starting
        debugger.variables.clean()
        config = self.config

        if config.preload_tpl:
            debugger.variables.resultsTemplate.append({
                "name": "tplImage",
                "text": "模板图",
                "image": {
                    "data": "#{tplImage}"
                }
            })

        rec_data_dir = os.path.join(config.work_dir,
                                    config.recognition_data_dirname)
        rec_img_dir = os.path.join(config.work_dir,
                                   config.recognition_img_dirname)
        rec_data_path = os.path.join(rec_data_dir, item_name) + '.json'
        rec_img_path = os.path.join(rec_img_dir, item_name) + '.jpg'

        # Read the recognition result data
        with open(rec_data_path, mode='r', encoding='utf-8') as f:
            rec_data = json.load(f)
        if not rec_data:
            logger.warning(f'raw_data is not present. path={rec_data_path}')
            return
        rec_img = cv2.imread(rec_img_path)

        if isinstance(rec_data, list):
            rec_data = [_convert_ai_rec_data_item(item) for item in rec_data]
            start_time = time.time()
            structure_result = self.session.process(
                rec_data,
                rec_img,
                class_name=self.config.class_name,
                primary_class=self.config.primary_class,
                secondary_class=self.config.secondary_class,
                ltrb=False)
        else:  # new raw data format
            start_time = time.time()
            request, rpc_name = _convert_request(rec_data, rec_img,
                                                 self.config)
            # Start structuring

            structure_result = self.request_processor.process(
                request,
                rpc_name,
                self.config.preload_tpl,
                item_name=item_name)

        process_duration = time.time() - start_time
        logger.debug(f'elapsed: {process_duration}')
        debugger.variables.structuringDuration = process_duration

        # Collect the structuring results
        self._pack_debug_data(structure_result)
        self._dump_debug_data(item_name)
Example #4
    def above_offset_by_anchor(self, node_items: Dict[str, TpNodeItem]):
        match_pairs = []
        for above_item in self.above_items.values():
            if above_item.is_ban_offset:
                continue

            matched_node_item = above_item.match_node(node_items)
            if matched_node_item is None:
                continue

            logger.debug(f'above item {matched_node_item}')
            matched_node_item.is_above_item = True
            match_pairs.append((above_item.bbox, matched_node_item.trans_bbox))

        if len(match_pairs) == 0:
            return 0, 0

        max_width, max_height = 0, 0
        for _, node_bbox in match_pairs:
            if node_bbox.width > max_width:
                max_width = node_bbox.width
            if node_bbox.height > max_height:
                max_height = node_bbox.height

        offset = self.search_offset(match_pairs, int(max_height),
                                    int(max_width))

        return offset
Example #5
    def _pre_func_medical_money(self, item_name: str,
                                passed_nodes: Dict[str, TpNodeItem],
                                node_items: Dict[str, TpNodeItem],
                                img: np.ndarray):
        def has3number(text):
            """
            Check whether the text contains at least three consecutive digit-like characters (digits, commas or dots).
            :param text:
            :return:
            """
            if not text:
                return False
            res = re.search('[0-9,,,\.]{3,}', text)
            if res:
                return True
            else:
                return False

        # TODO: data containing lines like "退现金... / 退支票..." (cash/check refunds) needs special handling, e.g. sample 0001648365
        for node in passed_nodes.values():
            if has3number(node.text):
                roi = node.bbox
                crnn_res, crnn_scores = crnn_util.run_number_amount(img, roi)
                dis = lcs.llcs(crnn_res, node.text)
                if crnn_res and dis / len(crnn_res) > 0.6:
                    # treat crnn_res as valid
                    if str_util.only_keep_money_char(crnn_res.replace(',', '')) != \
                            str_util.only_keep_money_char(node.text.replace(',', '')):
                        logger.debug('re recog {} to crnn , from {} to {}'.format(item_name, node.text, crnn_res))
                    node.text = crnn_res
                    node.scores = crnn_scores

        return self._pre_func_money(item_name, passed_nodes, node_items, img)
Example #6
 def _post_regex_filter_amountinwords(self, item_name: str,
                                      node_items: Dict[str, TpNodeItem],
                                      img: ndarray):
     for node in node_items.values():
         if re.search('[零壹贰叁肆伍陆柒捌玖拾佰仟万亿圆元角分整]+', node.text):
             logger.debug('recover amount in words {} ...'.format(
                 node.text))
             node.is_filtered_by_content = False
             node.is_filtered_by_area = False
             node.is_filtered_by_regex = False
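
For reference, a standalone check of the same character class used above for Chinese capitalized amounts (the sample strings are made up):

import re

CN_AMOUNT_RE = re.compile('[零壹贰叁肆伍陆柒捌玖拾佰仟万亿圆元角分整]+')

# matches the capitalized-amount portion of a receipt line
assert CN_AMOUNT_RE.search('人民币壹佰贰拾叁元肆角整') is not None
assert CN_AMOUNT_RE.search('2019-05-01') is None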
Example #7
def load_from_source_file(source_filepath):
    bk_tree = BKTree()
    with open(source_filepath, encoding='utf-8', mode='r') as f:
        logger.debug('build bktree...')
        for line in f.readlines():
            word = line.strip()
            word = str_util.str_sbc_2_dbc(word)
            bk_tree.insert_node(word)
        logger.debug('finish build bktree...')
    return bk_tree
Example #8
    def __init__(self, torch_script_model, chars_file):
        self.converter = LabelConverter(chars_file=chars_file,
                                        ctc_invalid_index=0)
        logger.debug('Load torch script model: %s' % torch_script_model)
        self.model = torch.jit.load(torch_script_model)

        self.model.eval()
        self.input_channel = self.query_input_channel()

        self.to_tensor = transforms.ToTensor()
        self.norm = transforms.Normalize([0.5], [0.5])
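
For context, a hedged sketch of how a TorchScript recognizer like this is typically invoked; the resize height, the single-channel input, and the model's output layout are assumptions rather than the project's actual pipeline:

import cv2
import torch
from torchvision import transforms


def recognize_crop(model, img_bgr, height=32):
    """Run one grayscale text crop through a TorchScript CRNN-style model."""
    gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
    scale = height / gray.shape[0]
    width = max(1, int(round(gray.shape[1] * scale)))
    resized = cv2.resize(gray, (width, height))

    tensor = transforms.ToTensor()(resized)             # (1, H, W), float in [0, 1]
    tensor = transforms.Normalize([0.5], [0.5])(tensor)

    with torch.no_grad():
        logits = model(tensor.unsqueeze(0))             # assumed shape (1, T, num_classes)
    return logits.argmax(-1).squeeze(0)                 # per-timestep class indices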
Example #9
    def search_pay_info(self, structure_items, node_items, img):
        def search_and_fill(pair, structure_items, node_items):
            match_bg_regex, key = pair
            medical_bg = None
            for node in node_items.values():
                for rl in match_bg_regex:
                    if re.search(rl, node.text):
                        medical_bg = node
                        break
            # Find the region to its right that holds the amount

            if medical_bg:

                filter_rule = lambda x: len(str_util.only_keep_money_char(x)) > 0
                node_in_region = NodeItemGroup.find_node_in_region(medical_bg.bbox,
                                                                   node_items,
                                                                   filter_rule=filter_rule,
                                                                   xoffset=(-1, 3),
                                                                   yoffset=(-2, 2))
                if not node_in_region:
                    structure_items[key].content = ''
                    return
                elif len(node_in_region) == 1:
                    structure_items[key].content = str_util.money_char_clean(node_in_region[0].text)
                else:
                    fg_node = min(node_in_region, key=lambda x: abs(x.bbox.cy - medical_bg.bbox.cy))
                    structure_items[key].content = str_util.money_char_clean(fg_node.text)
                logger.debug('find payinfo {} , value is {}'.format(key, structure_items[key]))
            else:
                structure_items[key].content = '0.00'

        fill_pair = [
            (['^基金支付'], 'medicalpaymoney'),
            (['^个人账户支付'], 'personal_account_pay_money'),
            # (['个人支付金额'], 'personpaymoney'),

        ]

        for pair in fill_pair:
            search_and_fill(pair, structure_items, node_items)
        # Special case: if both medicalpaymoney and personal_account_pay_money are 0, set personpaymoney to the amount-in-words content
        if structure_items['medicalpaymoney'].content and structure_items['personal_account_pay_money'].content:
            if math.isclose(float(structure_items['medicalpaymoney'].content), 0) and math.isclose(
                    float(structure_items['personal_account_pay_money'].content), 0):
                structure_items['personpaymoney'].content = structure_items['amountinwords'].content
                logger.debug(
                    'infer payinfo {} , value is {}'.format('personpaymoney',
                                                            structure_items['personpaymoney'].content))
        else:
            search_and_fill(
                (['个人支付金额'], 'personpaymoney'), structure_items, node_items
            )
Example #10
 def _handle_address(self, structure_items, fg_items, key='address'):
     """
     Post-process the address
     :param structure_items:
     :param fg_items:
     :return:
     """
     address = structure_items[key].content
     new_address = format_address(address)
     if new_address != address:
         logger.debug('change province {} to {}'.format(
             address, new_address))
         structure_items[key].content = new_address
Example #11
 def _pre_func_institution_type(self,
                                item_name: str,
                                passed_nodes: Dict[str, TpNodeItem],
                                node_items: Dict[str, TpNodeItem],
                                img: np.ndarray):
     for node in passed_nodes.values():
         text = node.text
         redundant_words = ['^医疗机构类型:', '^医疗.*构类型', '^医疗机构类', '^.{2,}机构类型', '^医疗机构类', '^医疗机']
         for rule in redundant_words:
             if re.search(rule, text):
                 text = re.sub(rule, '', text)
         text = str_util.remove_symbols(text)
         if node.text != text:
             logger.debug('institution text cleaned from {} to {}'.format(node.text, text))
             node.text = text
Example #12
    def __init__(self, chars_file, ctc_invalid_index=None):
        self.chars = ''.join(self.load_chars(chars_file))

        # char_set_length + ctc_blank
        self.num_classes = len(self.chars) + 1
        if ctc_invalid_index is None:
            self.ctc_invalid_index = len(self.chars)
        else:
            self.ctc_invalid_index = ctc_invalid_index

        self.encode_maps = {}
        self.decode_maps = {}

        self.create_encode_decode_maps(self.chars)

        logger.debug('Load chars file: %s num_classes: %d + 1(CTC Blank)' %
                     (chars_file, self.num_classes - 1))
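
A hedged sketch of what `create_encode_decode_maps` and a greedy CTC decode could look like for this converter; the index layout around the blank is an assumption (elsewhere the converter is constructed with `ctc_invalid_index=0`):

def create_encode_decode_maps(chars, blank_index):
    # assumption: a blank at index 0 shifts every char id up by one,
    # a blank at the end keeps the chars 0-based
    offset = 1 if blank_index == 0 else 0
    encode_maps = {c: i + offset for i, c in enumerate(chars)}
    decode_maps = {i + offset: c for i, c in enumerate(chars)}
    return encode_maps, decode_maps


def ctc_greedy_decode(indices, blank_index, decode_maps):
    """Collapse repeats and drop blanks from per-timestep argmax indices."""
    out, prev = [], None
    for idx in indices:
        if idx != blank_index and idx != prev:
            out.append(decode_maps[idx])
        prev = idx
    return ''.join(out)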
Example #13
    def process(self):
        config = self.config
        # Read the template file
        # If the class is template-based and debug is enabled, read the template image
        if config.preload_tpl and debugger.enabled:
            debugger.commonVariables.tplImage = _read_tpl_as_base64(
                config.class_name)

        # In multiprocessing mode, torch has issues with fork, see: https://github.com/pytorch/pytorch/issues/17199
        # The set_num_threads workaround mentioned in the issue does not work on its own; it only helps combined with os.environ['OMP_NUM_THREADS'] = '1'
        mp.set_start_method('spawn', force=True)
        # Remove the old result directory
        result_dir = os.path.join(config.work_dir, config.result_dirname)
        if os.path.exists(result_dir):
            shutil.rmtree(result_dir)
        os.makedirs(result_dir)

        rec_data_dir = os.path.join(config.work_dir,
                                    config.recognition_data_dirname)
        all_items = []
        for rec_data_filename in os.listdir(rec_data_dir):
            if not rec_data_filename.endswith('.json'):
                continue
            item_name, _ = os.path.splitext(rec_data_filename)
            all_items.append(item_name)

        if len(all_items) == 0:
            logger.warning('raw data not found')
            return

        # process_count == 0 means the count is derived from the number of CPUs
        # The detected number usually equals the machine's CPU count, but when deployed in docker
        # the actually usable CPU count may be lower, so we query the real CPU count in code
        # instead of passing None to Pool and letting it auto-detect
        process_count = config.process_count or _get_available_cpu_count()
        logger.debug(f'process count is {process_count}')
        if process_count == 1 or len(all_items) == 1:
            for item_name in tqdm(all_items, file=sys.stdout):
                self.process_single(item_name)
        else:
            with mp.Pool(process_count,
                         initializer=_process_pool_initializer,
                         initargs=(debugger.enabled, )) as pool:
                list(
                    tqdm(pool.imap_unordered(self.process_single, all_items),
                         total=len(all_items),
                         file=sys.stdout))
Example #14
File: main.py  Project: imfifc/myocr
    def process(
        self,
        node_items: Dict[str, TpNodeItem],
        img: np.ndarray,
        class_name: str,
        debug_data: DebugData = None,
    ):
        """
        :param node_items:
        :param img: numpy BGR image
        :param class_name: template name
        :return:
            dict.
            key: item_name
            value: the result of StructureItem.to_dict
        """
        # Background scaling and foreground offset
        before_count = len(node_items)
        # if class_name in ['shanghai_menzhen', 'shanghai_zhuyuan']:
        #     above_offset_method = above_offset.ABOVE_OFFSET_METHOD_ANCHOR
        # else:
        #     above_offset_method = above_offset.ABOVE_OFFSET_METHOD_IOU
        self.matchers[class_name].process(node_items,
                                          img,
                                          debug_data=debug_data)

        after_count = len(node_items)
        if before_count != after_count:
            logger.debug(
                f"node_items count change after matcher.process(): {before_count} -> {after_count}"
            )

        # Template matching: run filter_area, filter_regex and the other filters

        # raw_data = []
        # for node in node_items.values():
        #     if getattr(node, 'been_merged', False):
        #         continue
        #     raw_data.append(node.raw_node[0:9])
        # variables.add_group('detection', 'detection', raw_data)

        result = self.parsers[class_name].parse_template(node_items,
                                                         img,
                                                         debug_data=debug_data)
        return result
Example #15
    def _bk_tree_medical_insurance_type(self, structure_items, fg_items):
        medical_insurance_type_item = structure_items.get('medical_insurance_type', None)
        if medical_insurance_type_item is None:
            return

        if medical_insurance_type_item.content:
            text = medical_insurance_type_item.content.replace('医保类型:', '')
            text = text.replace('医保类型', '')
            text = text.replace('医保类', '')
            text = text.replace('医疗机构', '')
            res = bk_tree.medical_insurance_type().search_one(text,
                                                              search_dist=1,
                                                              min_len=2)
            # For long texts, a search distance of 1 sometimes fails to return a good match:
            if res is None and len(text) > 8:
                # retry with a larger search distance
                res = bk_tree.medical_insurance_type().search_one(text,
                                                                  search_dist=2,
                                                                  min_len=2)
            if res is None:
                # Fall back to an unstructured search
                node_items = fg_items['medical_insurance_type'].node_items_backup
                config = {'left': ['医保类型', '医疗机构类型'],
                          'right': ['社会保障卡号'],
                          'up': None,
                          'down': None
                          }
                search_res = NodeItemGroup.get_possible_node(node_items, config, thresh_x=8, thresh_y=4, match_count=2)
                if search_res:
                    search_res = [res[0] for res in search_res]
                    bk_tree_search_res = []
                    for idx in range(len(search_res)):
                        new_res = bk_tree.medical_insurance_type().search_one(search_res[idx], search_dist=2, min_len=2)
                        if new_res:
                            bk_tree_search_res.append(new_res)

                    if bk_tree_search_res:
                        search_res = max(bk_tree_search_res, key=lambda x: len(x))
                        medical_insurance_type_item.content = search_res
                        medical_insurance_type_item.scores = [1]
                else:
                    return

            else:
                medical_insurance_type_item.content = res
                medical_insurance_type_item.scores = [1]
                if text != res:
                    logger.debug('medical_insurance_type bk_tree:')
                    logger.debug('\tOrigin: {}'.format(text))
                    logger.debug('\tItem in tree: {}'.format(res))
Example #16
    def _horizontal_merge_match(
            self, node_items: Dict[str, TpNodeItem]) -> List[TpNodeItem]:
        norm_match_res = self._norm_match(node_items)
        if len(norm_match_res) != 0:
            return norm_match_res

        # If the content carries an edit-distance setting, only its text part is used for merging
        if not isinstance(self.content, str):
            bg_content = self.content["text"]
        else:
            bg_content = self.content

        candidate_node_items = {}
        candidate_chars_count = 0

        for node_item in node_items.values():
            if node_item.cn_text in bg_content:
                candidate_node_items[node_item.uid] = node_item
                candidate_chars_count += len(node_item.cn_text)

        # If the total character count of the candidate nodes is shorter than the bg content, return early
        if candidate_chars_count < len(bg_content):
            return []

        line_groups = NodeItemGroup.find_row_lines(candidate_node_items)

        grouped_segs: List[List[NodeItemGroup]] = []
        for group in line_groups:
            segs = group.find_x_segs()
            grouped_segs.append(segs)

        out = []
        for segs in grouped_segs:
            for seg in segs:
                if seg.content() == bg_content:
                    new_node = TpNodeItem(seg.gen_raw_node())
                    out.append(new_node)
                    logger.debug(
                        f"Merge mode bg item match success: {new_node}")

        return out
Example #17
    def _bk_tree_beijing_hospital_code_in_serialnumber(self, structure_items):
        serialnumber_item = structure_items.get('serialnumber', None)
        if serialnumber_item is None:
            return

        serialnumber = serialnumber_item.content

        if serialnumber is None:
            return

        old_hospital_code = serialnumber[:8]

        correct_word = bk_tree.beijing_hospital_code().search_one(old_hospital_code,
                                                                  search_dist=1,
                                                                  min_len=7)

        if correct_word is None:
            return

        if old_hospital_code != correct_word:
            logger.debug('Serial number hospital code bk tree look up:')
            logger.debug('\tOrigin: {}'.format(old_hospital_code))
            logger.debug('\tItem in tree: {}'.format(correct_word))

        new_serial_num = correct_word + serialnumber[8:]
        serialnumber_item.content = new_serial_num
Example #18
def _handle_province(address: str) -> str:
    if '省' in address:
        province_info = address.split('省')[0]
        other_info = ''.join(address.split('省')[1:])

        best_province = bk_tree.province().search_one(province_info,
                                                      search_dist=2,
                                                      min_len=2)
        # new_other_info = _handle_city(other_info, best_province)
        new_other_info = other_info
        if new_other_info != other_info:
            logger.debug(
                'change city info {} to {} in province process'.format(
                    other_info, new_other_info))
            other_info = new_other_info
        if best_province:
            address = best_province + other_info
        else:
            logger.debug(
                'search bk tree failed , org text is {}'.format(province_info))

    return address
Example #19
    def eval(
        self,
        node_items: Dict[str, TpNodeItem],
        img_height: int,
        img_width: int,
        debug_data: DebugData = None,
    ):
        # Match background items
        bg_match_pairs = []
        bg_match_all = {}
        new_node_items = {}

        bg_nodes_count = 0
        for bg_item in self.bg_items.values():
            matched_node_items = bg_item.match_node(node_items)

            if len(matched_node_items) == 0:
                continue
            elif len(matched_node_items) == 1:
                matched_node_item = matched_node_items[0]
                matched_node_item.is_bg_item = True

                # merge/split mode may produce new node_items; add them to the input
                if matched_node_item.uid not in node_items:
                    new_node_items[matched_node_item.uid] = matched_node_item

                bg_match_pairs.append((bg_item.uid, matched_node_item.uid))

            if len(matched_node_items) > 0:
                bg_nodes_count += 1
                bg_match_all[bg_item.uid] = matched_node_items

        if (bg_nodes_count <= 1 and self.template_info is not None
                and abs(img_height / img_width - self.template_ratio) < 0.2):
            return self.resize_by_2d_scale(node_items, img_height, img_width)

        node_items.update(new_node_items)

        # Background scaling
        # self.scale_by_best_match(node_items, bg_match_pairs, bg_match_all)
        if self.check_do_perspective(node_items, bg_match_pairs, img_height,
                                     img_width):
            H, mask = self.scale_by_perspective(node_items, bg_match_pairs)

            status = self.check_if_is_normal(node_items)
            if not status:
                logger.debug("use perspective fail , use normal method")
                self.scale_by_best_match(node_items, bg_match_pairs,
                                         bg_match_all)

            if H is not None:
                logger.debug("Do bg scale by perspective")
                if debug_data:
                    debug_data.set_H(H)
                debugger.variables.set_H(H)
        else:
            logger.debug("Do bg scale by best match")
            self.scale_by_best_match(node_items, bg_match_pairs, bg_match_all)
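
`scale_by_perspective` is not shown in this example; presumably it estimates a homography H from the matched bg/node pairs. A minimal sketch of that step with OpenCV, where pairing by bbox centers is an assumption:

import numpy as np
import cv2


def estimate_homography(node_centers, bg_centers):
    """Estimate H mapping detected node centers onto template bg centers."""
    src = np.asarray(node_centers, dtype=np.float32).reshape(-1, 1, 2)
    dst = np.asarray(bg_centers, dtype=np.float32).reshape(-1, 1, 2)
    if len(src) < 4:  # cv2.findHomography needs at least 4 correspondences
        return None, None
    H, mask = cv2.findHomography(src, dst, cv2.RANSAC, 5.0)
    return H, mask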
Example #20
def load_from_disk(working_dir, tree_name, force=False):
    """
    Load a BK-tree from disk
        1. First try to load a previously persisted tree file from the local .tree directory
        2. If no tree was persisted before (or the MD5 of the source file changed), build the tree from the source file and persist it
        3. If force is True, always rebuild the tree from the source file
    :param working_dir: working directory
    :param tree_name:
    :param force: force rebuilding the tree
    :return:
    """
    logger.debug('Load BK tree: {}'.format(tree_name))
    source_dir = os.path.join(working_dir, SOURCE_DIR_NAME)
    tree_dir = os.path.join(working_dir, TREE_DIR_NAME)
    source_file_path = os.path.join(source_dir, tree_name)

    if not os.path.exists(source_file_path):
        logger.error(
            'BK-Tree source file not exist: {}'.format(source_file_path))
        raise AssertionError(
            'BK-Tree source file not exist: {}'.format(source_file_path))

    if not os.path.exists(tree_dir):
        os.makedirs(tree_dir)

    file_md5 = utils.md5(source_file_path)
    tree_file_path = os.path.join(tree_dir, file_md5 + '.json')

    if force:
        logger.debug('Force create a new tree')
        bk_tree = load_from_source_file(source_file_path)
        save_to_tree_file(bk_tree, tree_file_path)
    elif not os.path.exists(tree_file_path):  # no tree file means the tree was never persisted
        logger.debug('Create a new tree since there is no cache')
        bk_tree = load_from_source_file(source_file_path)
        save_to_tree_file(bk_tree, tree_file_path)
    else:
        logger.debug('Load tree from tree file: %s' % tree_file_path)
        bk_tree = load_from_tree_file(tree_file_path)

    return bk_tree
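
Hypothetical usage of the loader; the working directory and the 'province' tree name are placeholders for whatever SOURCE_DIR_NAME / TREE_DIR_NAME resolve to in the project, and the lookup string is an invented OCR error:

# The first call builds <working_dir>/<tree_dir>/<md5>.json from the source file;
# later calls with an unchanged source file load the cached tree instead.
tree = load_from_disk('/data/bk_trees', 'province')
best = tree.search_one('胡南省', search_dist=2, min_len=2)  # fuzzy lookup of a garbled province name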
Example #21
    def _pre_func_crnn_num_bigeng(
        self,
        item_name: str,
        passed_nodes: Dict[str, TpNodeItem],
        node_items: Dict[str, TpNodeItem],
        img: np.ndarray,
    ):

        for node in passed_nodes.values():
            roi = node.bbox
            crnn_res, scores = self._crnn_util.run_number_capital_eng(
                img, roi.rect)
            if crnn_res != node.text:
                logger.debug(
                    "item_name: {} crnn_num_bigeng:".format(item_name))
                logger.debug("\tOrigin: {}".format(node))
                logger.debug("\tCRNN result: {}".format(crnn_res))
            if crnn_res is not None:
                node.text = crnn_res
                node.scores = scores
Example #22
    def _pre_func_shanghai_paymoney_crnn(self, item_name: str,
                                         passed_nodes: Dict[str, TpNodeItem],
                                         node_items: Dict[str, TpNodeItem],
                                         img: np.ndarray):
        for node in passed_nodes.values():
            if not str_util.contain_continue_nums(node.text, 1):
                continue

            if node.bbox.height > node.bbox.width:
                continue

            crnn_res, scores = crnn_util.run_shanghai_paymoney(img, node.bbox)

            if crnn_res != node.text:
                logger.debug('item_name: {}:'.format(item_name))
                logger.debug('\tOrigin: {}'.format(node.text))
                logger.debug('\tCRNN: {}'.format(crnn_res))

            if crnn_res is not None:
                node.text = crnn_res
                node.scores = scores
Example #23
    def eval(self,
             node_items: Dict[str, TpNodeItem],
             img_height: int,
             img_width: int,
             offset_method: str = ABOVE_OFFSET_METHOD_IOU):
        """
        Adjust the trans_box coordinates of the node_items
        :param node_items:
        :param offset_method:
            - anchor
            - iou
        :return:
        """
        if len(self.above_items) == 0:
            return

        ban_offset_uids = self.sign_ban_offset(node_items, self.above_items)

        if offset_method == ABOVE_OFFSET_METHOD_ANCHOR:
            new_offset = self.above_offset_by_anchor(node_items)
            logger.debug('Do above_offset by ANCHOR')
        elif offset_method == ABOVE_OFFSET_METHOD_IOU:
            content_similarity = self.cal_content_similarity(node_items)
            new_offset = self.iou_block_search(content_similarity, node_items,
                                               ban_offset_uids)
            logger.debug('Do above_offset by IOU')
        else:
            raise NotImplementedError(
                f'above_offset offset_method [{offset_method}] is not implemented'
            )

        logger.debug(f'above_offset value: {new_offset}')

        for uid, node_item in node_items.items():
            if uid in ban_offset_uids:
                continue

            bb = node_item.trans_bbox.transform(
                (-new_offset[0], -new_offset[1]))
            node_item.above_offset_bbox = bb
Example #24
    def _get_beijing_hospital_name_by_hospital_code(self, structure_items, fg_items):
        serialnumber_fg_item = fg_items.get('serialnumber', None)
        if serialnumber_fg_item is None:
            return

        serialnumber_item = structure_items.get('serialnumber', None)
        hospital_name_item = structure_items.get('hospital_name', None)
        if serialnumber_item is None or hospital_name_item is None:
            return

        serialnumber_labels = []

        regex_failed_serialnumbers_tp_rects = serialnumber_fg_item.regex_failed_tp_rects
        for rect_data in regex_failed_serialnumbers_tp_rects:
            serialnumber_labels.append((rect_data.text, rect_data.scores))

        serialnumber = serialnumber_item.content
        if serialnumber is not None:
            serialnumber_labels.insert(0, (serialnumber, serialnumber_item.scores))

        for serial_num, scores in serialnumber_labels:
            serial_num = str_util.filter_num(serial_num)
            if len(serial_num) < 8:
                continue
            hospital_code = serial_num[:8]

            hospital_name = beijing_hospital_code_name_map.get(hospital_code, None)

            if hospital_name is None:
                continue

            logger.debug('Look up beijing hospital name:')
            logger.debug('\tCode: {}'.format(hospital_code))
            logger.debug('\tName: {}'.format(hospital_name))

            hospital_name_item.content = hospital_name
            hospital_name_item.scores = serialnumber_item.scores[:8]
            break
Example #25
    def _bk_tree_shanghai_medical_institution_type(self, structure_items, fg_items):
        medical_institution_type_item = structure_items.get('medical_institution_type')
        if medical_institution_type_item is None:
            return

        text = medical_institution_type_item.content

        length = len(text)
        if length >= 5:
            search_dist = 2
        else:
            search_dist = 1

        bk_res = bk_tree.medical_institution_type().search_one(text,
                                                               search_dist=search_dist,
                                                               min_len=2)

        if bk_res != medical_institution_type_item.content:
            logger.debug('medical_institution_type bk_tree:')
            logger.debug('\tOrigin: {}'.format(text))
            logger.debug('\tItem in tree: {}'.format(bk_res))

            medical_institution_type_item.content = bk_res
Example #26
    def _bk_tree_beijing_medical_institution_type(self, structure_items, fg_items):
        medical_institution_type_item = structure_items.get('medical_institution_type', None)
        if medical_institution_type_item is None:
            return

        if "专科医院" in medical_institution_type_item.content:
            structure_items['medical_institution_type'].content = "专科医院"
            return

        if medical_institution_type_item.content.startswith('专科') or \
                medical_institution_type_item.content.startswith('专利'):
            structure_items['medical_institution_type'].content = "专科医院"
            return

        if not medical_institution_type_item.content:

            # medical_institution_type could not be obtained here, possibly because of a problem
            # in the structuring step; try an unstructured approach instead
            node_items = fg_items['medical_institution_type'].node_items_backup
            bg_node_regex = ['^医疗结构类型$', '^医疗.*构类型', "疗机构类型"]

            bg_node = None
            for node in node_items.values():
                is_bg_node = False
                for rl in bg_node_regex:
                    if re.search(rl, node.text):
                        is_bg_node = True
                if is_bg_node:
                    bg_node = node
                    break
            # Walk all nodes: if a node sits near the bg node and a similar entry exists in the bk_tree, use that result
            # The fg region is an area to the right of bg_rect
            if bg_node:

                bg_xmin, bg_ymin, bg_xmax, bg_ymax = bg_node.bbox.rect
                bg_height = (bg_ymax - bg_ymin)
                bg_width = (bg_xmax - bg_xmin)
                fg_xmin = bg_xmax
                fg_xmax = fg_xmin + bg_width * 3
                fg_ymin = bg_ymin - bg_height
                fg_ymax = bg_ymax + bg_height
                fg_rect = [int(fg_xmin), int(fg_ymin), int(fg_xmax), int(fg_ymax)]
                prob_res = []
                for node in node_items.values():
                    if node.bbox.is_center_in(fg_rect):
                        if len(node.text) <= 1:
                            continue
                        if str_util.keep_num_char(node.text) == node.text and len(node.text) <= 4:
                            continue
                        if node.text == '年' or node.text == '日':
                            continue
                        prob_res.append(node.text)

                for idx, res in enumerate(prob_res):
                    bk_res = bk_tree.medical_institution_type().search_one(res,
                                                                           search_dist=2,
                                                                           min_len=2)
                    prob_res[idx] = bk_res
                filter_prob_res = list(filter(lambda x: x is not None, prob_res))
                if not filter_prob_res:
                    return
                if len(filter_prob_res) == 1:
                    medical_institution_type_item.content = filter_prob_res[0]
                    return
                else:
                    medical_institution_type_item.content = max(filter_prob_res, key=lambda x: len(x))
                    return
            else:
                # Also try to match common institution types directly, in case the bg node was not found
                fg_node_regex = {
                    "综合医院": ["^综合医院$", "类综合医院", "类型综合医院"],
                    "中医医院": ["医疗.{1,2}中医医院", "类中医医院"],
                    "社区卫生服务中心": ["社区卫生服务中心"],
                    "中西医结合医院": ["中西医结合医院"],
                    "非营利综合医院": ["非营利综合医院"],
                    "对外中医": ['对外中医$']
                }
                for node in node_items.values():
                    for institution_type in fg_node_regex:
                        for rl in fg_node_regex[institution_type]:
                            if re.search(rl, node.text):
                                medical_institution_type_item.content = institution_type
                                return
            return

        if medical_institution_type_item.content:

            text = medical_institution_type_item.content

            if text.startswith('型'):
                medical_institution_type_item.content = medical_institution_type_item.content[1:]
                return
            redundant_words = ['^医疗机构类型:', '^医疗.*构类型', '^医疗机构类', '^.{2,}机构类型', ]
            for rule in redundant_words:
                if re.search(rule, text):
                    logger.debug('institution text cleaned from {} to {}'.format(text, re.sub(rule, '', text)))
                    text = re.sub(rule, '', text)

            text = str_util.remove_redundant_patthern(text, redundant_words)
            res = bk_tree.medical_institution_type().search_one(text,
                                                                search_dist=1,
                                                                min_len=2)
            res_2 = None
            if res is None:
                # Try to rescue the value:
                # map a few special fragments that can only appear in specific phrases, e.g. '外中'
                special_map = {'外中': '对外中医', '外综': '对外综合', '利中': '非营利中医医院',
                               '性医院': '综合性医院', '性医疗': '非营利性医疗机构'
                               }
                text_copy = text
                for sp in special_map:
                    if sp in text_copy:
                        text_copy = special_map[sp]
                        break
                if text_copy != text:
                    res_2 = bk_tree.medical_institution_type().search_one(text_copy, search_dist=1, min_len=2)

            if res is None and res_2 is None:
                medical_institution_type_item.content = text
                medical_institution_type_item.scores = [0]
            else:
                if res is not None:
                    medical_institution_type_item.content = res
                elif res_2 is not None:
                    medical_institution_type_item.content = res_2

                medical_institution_type_item.scores = [1]
                if text != res:
                    logger.debug('medical_institution_type bk_tree:')
                    logger.debug('\tOrigin: {}'.format(text))
                    logger.debug('\tItem in tree: {}'.format(res))
Example #27
 def print_info():
     logger.debug('Get receiptno from barcode:')
     logger.debug('\torigin receiptno: {}'.format(origin_receiptno))
     logger.debug('\tbarcode: {}'.format(barcode))
     logger.debug('\treceiptno from barcode: {}'.format(receiptno_from_barcode))
Example #28
    def _post_func_crnn_date(
        self,
        item_name: str,
        passed_nodes: Dict[str, TpNodeItem],
        node_items: Dict[str, TpNodeItem],
        img: np.ndarray,
    ):
        """
        Merge the candidate boxes of the date field, then re-recognize them with CRNN
        """
        # TODO: viz is not supported for now
        # DEBUG_VIZ = False
        # Handle cases like '2017-11-28_9:20:38日'
        origin_text = ""
        for node in passed_nodes.values():
            origin_text += node.text

        # First, clean up anything that might contain hour/minute/second info, e.g. '2017-11-28_9:20:38'
        changed_res = date_util.filter_year_month_day(passed_nodes)
        if changed_res:
            for row in changed_res:
                logger.debug(
                    f"[{item_name}] change {row[0]} to {row[1]} due to it has sec and hour info"
                )

        # Filter out nodes whose numeric runs are too long; these are clearly serial numbers or similar
        # This also drops cases such as '2017._00___82___84付___2017'
        passed_nodes = OrderedDict(
            filter(
                lambda x: str_util.count_max_continous_num(x[1].text) < 8,
                passed_nodes.items(),
            ))

        if len(passed_nodes) == 0:
            return

        # Group the bboxes that lie on the same row
        nodes_in_line = NodeItemGroup.find_row_lines(passed_nodes,
                                                     y_thresh=0.6)
        # Keep only the lines that contain a number (or 年/日)
        nodes_in_line = list(
            filter(
                lambda x: date_util.contain_number(x.content()) or "年" in x.
                content() or "日" in x.content(),
                nodes_in_line,
            ))

        # CRNN recognition region: the y range follows the part that contains digits, the x range spans the min/max x over all regions
        if len(nodes_in_line) == 0:
            return

        crnn_xmin = min([node.bbox.rect[0] for node in nodes_in_line])
        crnn_xmax = max([node.bbox.rect[2] for node in nodes_in_line])

        if len(nodes_in_line) == 1:
            nodes_in_line = nodes_in_line[0]
        else:
            # Pick the widest merged line that contains a number as nodes_in_line
            nodes_in_line = list(
                filter(lambda x: date_util.contain_number(x.content()),
                       nodes_in_line))
            nodes_in_line = max(nodes_in_line, key=lambda x: x.bbox.width)
        crnn_ymin = nodes_in_line.bbox.rect[1]
        crnn_ymax = nodes_in_line.bbox.rect[3]
        # The rect for re-recognition must use original-image coordinates

        crnn_res, scores = self._crnn_util.run_number_space(
            img, [crnn_xmin, crnn_ymin, crnn_xmax, crnn_ymax])

        crnn_res, scores = self.check_crnn_res_and_rerecognize(
            crnn_res, scores, img, crnn_xmin, crnn_ymin, crnn_xmax, crnn_ymax)

        date_info = [split for split in crnn_res.split("_") if split]
        useful_info = date_util.get_useful_info(crnn_res.replace("_", " "))
        if len(date_info) >= 3:
            # If the node forms a clean date format: a year recognized with four digits counts as 1, with two digits as 1/2
            split_res = date_util.parse_date(crnn_res,
                                             min_year=2000,
                                             max_year=2040)
            if split_res.most_possible_item:
                res = split_res.most_possible_item.to_string()
                logger.debug(
                    f"{item_name} res {crnn_res} is formatted to {res}")
            else:
                res = None

            crnn_scores = str_util.date_format_scores(
                *date_info[:3])  # only the first three parts are handled for now

        if len(date_info) < 3 or res is None:
            logger.debug(
                f"{item_name}  res {crnn_res} can not format , use normal method"
            )
            res = date_util.get_format_data_from_crnn_num_model_res(
                crnn_res.replace("_", ""))
            crnn_scores = str_util.date_format_scores(
                useful_info["useful_year"],
                useful_info["useful_month"],
                useful_info["useful_day"],
            )

        # If the CRNN result failed, run the original method once more
        if res is None or not date_util.is_legal_format_date_str(res):
            nodes = NodeItemGroup.recover_node_item_dict(
                nodes_in_line.node_items)
            split_res = self._post_func_date(item_name, nodes, node_items, img)

            if split_res is not None:
                # recover year of split_res
                split_res_date = datetime.strptime(split_res[0], "%Y-%m-%d")
                if not date_util.is_legal_year(split_res_date.year):
                    predict_year = NodeItemGroup.get_year_predict(node_items)
                    predict_date = str(
                        split_res_date.replace(year=predict_year).date())
                    split_res = (predict_date, split_res[1])
                recover_from_crnn = date_util.recover_info_from_useful_info(
                    split_res, useful_info)
                logger.debug(
                    "recover from crnn , org is {} , recover res is {}".format(
                        split_res, recover_from_crnn))
                split_res = recover_from_crnn
            else:
                if (useful_info.get("useful_month", None) is not None
                        and useful_info.get("useful_day", None) is not None):
                    # Try to find year candidates from the last three years among all node_items:
                    max_count_year = None
                    max_count = -1
                    cur_year = datetime.now().year
                    for year in range(cur_year - 3, cur_year + 1):
                        count_of_year = NodeItemGroup.regex_filter_count(
                            node_items, str(year))
                        if count_of_year > max_count:
                            max_count = count_of_year
                            max_count_year = year
                    if max_count != -1:
                        # find a year
                        predict_date = dt.date(
                            year=max_count_year,
                            month=int(useful_info["useful_month"]),
                            day=int(useful_info["useful_day"]),
                        )
                        predict_date = str(predict_date)
                        # Try to recover part of the data from useful_info
                        split_res = (predict_date, [1])
                        logger.debug("[{}] predict date from crnn , {}".format(
                            item_name, predict_date))
            logger.debug("[{}] Origin texts: {},"
                         "CRNN result: {},"
                         "Merge result: {},"
                         "Split result: {}".format(item_name, origin_text,
                                                   crnn_res, res, split_res))

            # name = item_name + str(self.global_test)
            # viz_crnn_debug(passed_nodes, nodes_in_line, img,name,origin_text)
            # self.global_test+=1

            return split_res

        # TODO: get_format_data_from_crnn_num_model_res should also return which digit positions were used
        mean_score = NodeItemGroup.cal_mean_score(nodes_in_line)
        # return res, [mean_score]
        return res, crnn_scores
Example #29
    def _init_fg_items(self, class_name: str, conf: Dict) -> Dict[str, FGItem]:
        """
        Handler loading priority:
        1. If the config file sets item_pre_func/item_post_func, look up the specified function
        2. If the config file does not set item_pre_func/item_post_func, follow the naming convention: _pre_func_{item_name} / _post_func_{item_name}
        - Pre-processing: if the function does not exist, there is no pre-processing step
        - Post-processing: if the function does not exist, _post_func_max_w_regex is used

        Args:
            class_name:
            conf:

        Returns:

        """
        if "fg_items" not in conf:
            raise ConfigException(f"fg_items not exist in {class_name}.yml")

        res = {}
        for item in conf["fg_items"]:
            item_name = item.get("item_name", None)
            item_show_name = item.get("show_name", None)

            if item_name is None:
                raise ConfigException(
                    f"[{class_name}] FG item [{item_name}] miss [item_name] key"
                )

            item_pre_func = self._get_pre_post_func(item, "item_pre_func",
                                                    class_name)
            item_post_func = self._get_pre_post_func(item, "item_post_func",
                                                     class_name)

            # If the template does not specify a post-processing function, default to _post_func_max_w_regex
            if item_post_func is None:
                item_post_func = self._get_item_func("_post_func_max_w_regex")

            # Build the filters
            if self.is_tp_conf:
                filter_areas = []
                if "area" in item:
                    filter_areas.append({
                        "area": item["area"],
                        "w": 1,
                        "ioo_thresh": item.get("ioo_thresh", 0),
                    })

                if len(filter_areas) == 0:
                    filter_areas = None

                filter_regexs = item.get("filter_regexs", [])
                for r in filter_regexs:
                    r["w"] = 1

                filter_confs = {
                    "filter_areas": filter_areas,
                    "filter_contents": item.get("filter_contents", None),
                    "filter_regexs": filter_regexs,
                }
            else:
                filter_confs = {
                    "filter_areas": item.get("filter_areas", None),
                    "filter_contents": item.get("filter_contents", None),
                    "filter_regexs": item.get("filter_regexs", None),
                }

            # Read the search_strategy parameters
            search_strategy = None
            search_strategy_item = item.get("search_strategy", None)
            if search_strategy_item is not None:
                search_mode = search_strategy_item.get("mode", None)
                if search_mode is None:
                    s = f"[{class_name}] FG item [{item_name}] set search_strategy, but not set [mode] key"
                    raise ConfigException(s)

                w_pad = search_strategy_item.get("w_pad", 0)
                h_pad = search_strategy_item.get("h_pad", 0)
                logger.debug(
                    f"Enlarge {class_name} [{item_name}] search area: w_pad {w_pad} h_pad {h_pad}"
                )
                search_strategy = EnlargeSearchStrategy(w_pad, h_pad)

            post_regex_filter_func_name = item.get(
                "item_post_regex_filter_func")
            post_regex_filter_func = self._get_item_func(
                post_regex_filter_func_name)

            res[item_name] = FGItem(
                item_name,
                item_show_name,
                filter_confs,
                item_pre_func,
                item_post_func,
                post_regex_filter_func,
                should_output=item.get("should_output", True),
                search_strategy=search_strategy,
            )

        return res
Example #30
    def iou_block_search(self, content_similarity: Dict[str, Dict[str, float]],
                         node_items: Dict[str,
                                          TpNodeItem], ban_offset_uids: Set):
        w, h = 0, 0
        for above_item in self.above_items.values():
            w = max(w, above_item.bbox.right)
            h = max(h, above_item.bbox.bottom)

        for node_item in node_items.values():
            w = max(w, node_item.trans_bbox.right)
            h = max(h, node_item.trans_bbox.bottom)
        # w, h are the maximum extents over all boxes and bound the search range; search by iterative grid refinement
        now_offset_center = [0, 0]
        max_search_dis = max(w, h) / 6.0
        split_size = 10

        record_of_all_best = None
        while max_search_dis > 1:
            block_size = max_search_dis / split_size
            left_up_point = [
                v - (max_search_dis / 2.0) for v in now_offset_center
            ]
            new_center = [0, 0]
            result_offset_center = [v for v in now_offset_center]
            best_similarity_value = 0

            for i in range(split_size):
                for j in range(split_size):
                    new_center[0] = left_up_point[0] + i * block_size
                    new_center[1] = left_up_point[1] + j * block_size
                    similarity_value, hit_num, hit_can_not_miss = self.cal_match_similarity_value(
                        content_similarity, node_items, new_center,
                        ban_offset_uids)
                    offset = [v for v in new_center]
                    if not record_of_all_best:
                        record_of_all_best = (hit_can_not_miss, hit_num,
                                              similarity_value, offset)
                    else:
                        # record_of_all_best already holds a value
                        current = (hit_can_not_miss, hit_num, similarity_value,
                                   offset)
                        if self.get_score_of_each_offest(
                                current) > self.get_score_of_each_offest(
                                    record_of_all_best):
                            record_of_all_best = current

                    # previous handling logic:
                    # if similarity_value > best_similarity_value:
                    #     best_similarity_value = similarity_value
                    #     result_offset_center = [v for v in new_center]

            #            now_offset_center = [v for v in result_offset_center]
            max_search_dis = block_size * 2

        # if self.exp_data:
        #     self.exp_data.set_above_matched_idxes(best_label_node_idx_4_debug)
        logger.debug(
            'above offset with hit_can_not_miss :{} , with hit {}'.format(
                record_of_all_best[0], record_of_all_best[1]))

        return record_of_all_best[3]
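
The loop above follows a coarse-to-fine grid search: score a split_size by split_size grid around the current center, then shrink the search window and repeat. A generic standalone sketch of that scheme, with a placeholder scoring function and with the recentering step that the original keeps commented out:

def coarse_to_fine_search(score_fn, span, splits=10, min_span=1.0):
    """Return the 2D offset that maximizes score_fn, refining the grid each round."""
    center = [0.0, 0.0]
    best_score, best_offset = float('-inf'), list(center)
    while span > min_span:
        step = span / splits
        origin = [c - span / 2.0 for c in center]
        for i in range(splits):
            for j in range(splits):
                cand = [origin[0] + i * step, origin[1] + j * step]
                score = score_fn(cand)
                if score > best_score:
                    best_score, best_offset = score, list(cand)
        center = list(best_offset)   # recenter on the best candidate so far
        span = step * 2              # shrink the window, as in the loop above
    return best_offset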