예제 #1
0
def _closest_dictionary_value(query, entries, field):
    """Find the dictionary entry whose ``field`` text is closest to ``query``.

    Each entry's ``entry[field]`` (a list of text lines) is joined with spaces
    and scored against ``query`` with ``cer_loss_one_image``. Only scores
    strictly below 1 are considered, and the first entry reaching the lowest
    score wins.

    Returns:
        ``(best_score, best_value)`` where ``best_value`` is the winning
        ``entry[field]`` list itself (not a copy), or ``(1, None)`` when no
        entry scores below 1.
    """
    best_score = 1
    best_value = None
    for entry in entries:
        score = cer_loss_one_image(query, ' '.join(entry[field]))
        if score < best_score:
            best_score, best_value = score, entry[field]
    return best_score, best_value


def fix_result_by_dictionary(result_dict, dictionary):
    """Snap the SELLER and ADDRESS predictions to their closest known value.

    For each of the two fields, the predicted lines are joined into a single
    string and compared with every candidate of the dictionary. When the best
    score lies strictly between 0 and 0.3 the prediction is replaced, in
    place, by the candidate's list of lines. A score of exactly 0 (already
    identical) or of 0.3 and above leaves the prediction untouched.

    Args:
        result_dict: mapping with ``result_dict[FIELD]['value']`` holding a
            list of strings for ``'SELLER'`` and ``'ADDRESS'``. Mutated in
            place.
        dictionary: mapping with a ``'seller'`` list of ``{'SELLER': [...]}``
            entries and an ``'address'`` list of ``{'ADDRESS': [...]}``
            entries.

    Returns:
        None. The result is communicated through ``result_dict``.
    """
    # Both dictionary lists are looked up before any field is processed.
    passes = (
        ('SELLER', dictionary['seller']),
        ('ADDRESS', dictionary['address']),
    )
    for field, candidates in passes:
        predicted = ' '.join(result_dict[field]['value'])
        score, replacement = _closest_dictionary_value(
            predicted, candidates, field)
        if 0 < score < 0.3:
            # The dictionary's own list object is stored, exactly as matched.
            result_dict[field]['value'] = replacement
예제 #2
0
def validate_TOTAL_COST_keys(input_str, cer_thres=0.2):
    """Tell whether ``input_str`` is close to one of the ``TOTAL_COST_keys``.

    The comparison is case-insensitive: both the input and every key are
    lower-cased before being scored with ``cer_loss_one_image``.

    Args:
        input_str: text to test.
        cer_thres: the best score must be strictly below this value.

    Returns:
        True when the lowest score over all keys (capped at 1) is strictly
        below ``cer_thres``, False otherwise.
    """
    candidate = input_str.lower()
    best_score = 1  # scores of 1 or more never improve on the starting value
    for keyword in TOTAL_COST_keys:
        best_score = min(best_score,
                         cer_loss_one_image(candidate, keyword.lower()))
    if best_score < cer_thres:
        return True
    return False
예제 #3
0
def fix_totalcost(list_totalcost, output_ocr_path=None):
    """Clean up the list of strings predicted for the TOTAL_COST field.

    The behaviour depends on how many strings were predicted:

    * 2 items: the pair is swapped in place when the first item passes
      ``validate_TOTAL_COST_amount`` or the second passes
      ``validate_TOTAL_COST_keys`` (presumably so that the keyword ends up
      before the amount). The same list object is returned.
    * 3 or more items: a new two-element list is returned, made of the item
      closest to any entry of ``TOTAL_COST_keys`` and the longest item that
      passes ``validate_TOTAL_COST_amount``. Either slot is ``''`` when
      nothing qualifies.
    * 1 item: if an OCR result file is given, every polygon read from it whose
      text equals that item and passes ``validate_TOTAL_COST_keys`` is handed
      to ``find_TOTALCOST_val_poly``; each non-None result is appended to
      ``list_totalcost`` in place.
    * 0 items: the list is returned unchanged.

    Args:
        list_totalcost: list of predicted strings. May be mutated in place.
        output_ocr_path: path of the OCR result file, only used in the
            single-item case. When it is None that case is skipped instead of
            passing None on to ``get_list_icdar_poly``.

    Returns:
        The cleaned list (see above for when it is the input object itself).
    """
    num_items = len(list_totalcost)

    if num_items == 2:
        if validate_TOTAL_COST_amount(list_totalcost[0]) or validate_TOTAL_COST_keys(list_totalcost[1]):
            list_totalcost[0], list_totalcost[1] = list_totalcost[1], list_totalcost[0]
        return list_totalcost

    if num_items >= 3:
        # Imported lazily, and only in the branch that needs it.
        from mc_ocr.utils.common import cer_loss_one_image

        # Item closest to any known keyword (first best match wins).
        min_cer = 1
        min_str = ''
        for totalcost in list_totalcost:
            lower_totalcost = totalcost.lower()
            for k in TOTAL_COST_keys:
                cer = cer_loss_one_image(lower_totalcost, k.lower())
                if cer < min_cer:
                    min_cer = cer
                    min_str = totalcost

        # Longest item accepted as an amount (first longest wins).
        max_len = 0
        max_str = ''
        for totalcost in list_totalcost:
            if validate_TOTAL_COST_amount(totalcost.lower()) and len(totalcost) > max_len:
                max_len = len(totalcost)
                max_str = totalcost

        return [min_str, max_str]

    if num_items == 1 and output_ocr_path is not None:
        # Imported lazily, and only in the branch that needs them.
        from mc_ocr.utils.common import get_list_icdar_poly
        from mc_ocr.key_info_extraction.create_train_data import find_TOTALCOST_val_poly

        list_icdar_poly = get_list_icdar_poly(output_ocr_path)
        for icdar_pol in list_icdar_poly:
            if icdar_pol.value == list_totalcost[0] and validate_TOTAL_COST_keys(icdar_pol.value, cer_thres=0.2):
                TOTAL_COST_value = find_TOTALCOST_val_poly(keys_poly=icdar_pol,
                                                           list_poly=list_icdar_poly,
                                                           expand_ratio=0.2)
                if TOTAL_COST_value is not None:
                    list_totalcost.append(TOTAL_COST_value)

    return list_totalcost
예제 #4
0
def _bbox_center(bbox_str):
    """Return the centre of a box given as eight comma-separated integers.

    The first component averages the values at even positions and the second
    the values at odd positions (presumably x and y of the four corners).
    """
    coords = [int(coor) for coor in bbox_str.split(',')]
    return [
        (coords[0] + coords[2] + coords[4] + coords[6]) / 4,
        (coords[1] + coords[3] + coords[5] + coords[7]) / 4,
    ]


def fix_result_by_rule_based(result_dict, ocr_path, dictionary=None):
    """Apply hand-written correction rules to an extraction result, in place.

    Four rules run in sequence:

    1. If ADDRESS is empty but SELLER is not, look through
       ``dictionary['store']`` for stores seen more than 10 times whose seller
       text scores below 0.1 against the predicted seller
       (``cer_loss_one_image``). When exactly one store qualifies, its address
       is used and the file name plus address are printed. The rule is skipped
       when no dictionary is supplied.
    2. Every TIMESTAMP string longer than 30 characters is passed through
       ``fix_datetime``.
    3. When there are exactly two TIMESTAMP entries and the second box's
       centre has a smaller first coordinate than the first box's centre, the
       two values and their boxes are swapped.
    4. TOTAL_COST is rewritten with ``fix_totalcost``.

    Args:
        result_dict: mapping with ``'ADDRESS'``, ``'SELLER'``, ``'TIMESTAMP'``
            and ``'TOTAL_COST'`` entries, each holding a ``'value'`` list
            (``'TIMESTAMP'`` also a ``'bboxes'`` list of comma-separated
            coordinate strings). Mutated in place.
        ocr_path: path of the OCR result file for this image.
        dictionary: optional mapping with a ``'store'`` list of
            ``{'SELLER': [...], 'ADDRESS': [...], 'count': int}`` entries.

    Returns:
        None.
    """
    # Rule 1: recover an address that could not be read.
    address_to_fix = ''
    num_the_same_seller = 0
    if len(result_dict['ADDRESS']['value']) == 0 \
            and len(result_dict['SELLER']['value']) > 0 \
            and dictionary is not None:  # default None used to raise TypeError
        full_seller = ' '.join(result_dict['SELLER']['value'])
        for store in dictionary['store']:
            if store['count'] > 10:
                full_store_seller = ' '.join(store['SELLER'])
                if cer_loss_one_image(full_seller, full_store_seller) < 0.1:
                    num_the_same_seller += 1
                    address_to_fix = store['ADDRESS']
    # Only trust the dictionary when the seller maps to a single store.
    if num_the_same_seller == 1:
        print(os.path.basename(ocr_path), address_to_fix)
        result_dict['ADDRESS']['value'] = address_to_fix

    # Rule 2: clean up suspiciously long date/time strings.
    timestamp = result_dict['TIMESTAMP']
    for idx, stamp in enumerate(timestamp['value']):
        if len(stamp) > 30:
            timestamp['value'][idx] = fix_datetime(stamp)

    # Rule 3: order a pair of timestamps by the first centre coordinate.
    if len(timestamp['value']) == 2:
        first_center = _bbox_center(timestamp['bboxes'][0])
        second_center = _bbox_center(timestamp['bboxes'][1])
        if second_center[0] < first_center[0]:
            timestamp['value'][0], timestamp['value'][1] = \
                timestamp['value'][1], timestamp['value'][0]
            timestamp['bboxes'][0], timestamp['bboxes'][1] = \
                timestamp['bboxes'][1], timestamp['bboxes'][0]

    # Rule 4: normalise the total-cost strings.
    result_dict['TOTAL_COST']['value'] = fix_totalcost(
        result_dict['TOTAL_COST']['value'], output_ocr_path=ocr_path)
예제 #5
0
def validate_ADDRESS(list_address, input_str, cer_thres=0.2):
    """Tell whether ``input_str`` resembles a line of a known address.

    Inputs shorter than 10 characters are rejected outright. Otherwise the
    lower-cased input is scored with ``cer_loss_one_image`` against every
    lower-cased line of every entry whose ``'count'`` is greater than 1.
    On success the rounded score, the lower-cased input and the matching line
    are printed.

    Args:
        list_address: list of ``{'ADDRESS': [line, ...], 'count': int}``.
        input_str: text to test.
        cer_thres: the best score must be strictly below this value.

    Returns:
        True when the lowest score (capped at 1) is strictly below
        ``cer_thres``, False otherwise.
    """
    if len(input_str) < 10:
        return False

    query = input_str.lower()
    best_score, best_line = 1, ''
    # Entries seen only once are not considered.
    repeated_entries = (entry for entry in list_address if entry['count'] > 1)
    for entry in repeated_entries:
        for raw_line in entry['ADDRESS']:
            candidate = raw_line.lower()
            score = cer_loss_one_image(candidate, query)
            if score < best_score:
                best_score, best_line = score, candidate

    if best_score < cer_thres:
        print(round(best_score, 2), query, '-----', best_line)
        return True
    return False
예제 #6
0
def parse_anno_from_csv_to_icdar_result(csv_file, icdar_dir, output_dir, img_dir=None, debug=False):
    """Transfer ground-truth labels from a CSV onto detected ICDAR polygons.

    The first CSV row is treated as a header. For every following row, the
    ground-truth polygons (``get_list_gt_poly(row)``) are compared with the
    polygons read from ``<icdar_dir>/<image name with .jpg -> .txt>``. Every
    detected polygon whose ``IoU`` with a ground-truth polygon exceeds 0.3
    receives that polygon's ``type``. The detected polygons are then written,
    one ``to_icdar_line`` per line, to the same file name under
    ``output_dir``. Match details and running mismatch statistics are printed
    along the way.

    Args:
        csv_file: path of the annotation CSV (read as UTF-8); column 0 holds
            the image file name.
        icdar_dir: directory with one detection ``.txt`` file per image.
        output_dir: directory receiving the labelled ``.txt`` files.
        img_dir: directory of the source images. Only needed when ``debug``
            is True; the images are not read otherwise.
        debug: when True, each unmatched ground-truth box is drawn on the
            image and shown in a blocking OpenCV window.

    Returns:
        None.
    """
    total_boxes_not_match = 0
    total_boxes = 0
    # newline='' is what the csv module expects for files it parses itself.
    with open(csv_file, encoding='utf-8', newline='') as csv_handle:
        csv_reader = csv.reader(csv_handle, delimiter=',')
        next(csv_reader, None)  # skip the header row
        # Start at 1 so the printed index is still the row's position in the file.
        for n, row in enumerate(csv_reader, start=1):
            img_name = row[0]
            print('\n' + str(n), img_name)

            # The image is only used for debug drawing, so skip the read otherwise.
            src_img = cv2.imread(os.path.join(img_dir, img_name)) if debug else None

            # Read all poly from training data
            list_gt_poly = get_list_gt_poly(row)

            # Read all poly from icdar
            icdar_path = os.path.join(icdar_dir, img_name.replace('.jpg', '.txt'))
            list_icdar_poly = get_list_icdar_poly(icdar_path)

            # Compare iou and parse text from training data
            for pol in list_gt_poly:
                total_boxes += 1
                match = False
                max_iou = 0
                if debug:
                    gt_img = src_img.copy()
                    gt_box = np.array(pol.list_pts).astype(np.int32)
                    cv2.polylines(gt_img, [gt_box], True, color=color_map[pol.type], thickness=2)

                for icdar_pol in list_icdar_poly:
                    iou = IoU(pol, icdar_pol, False)
                    if iou > max_iou:
                        max_iou = iou
                    if iou > 0.3:
                        match = True
                        # The score is only reported, so compute it for matches only.
                        cer = cer_loss_one_image(pol.value, icdar_pol.value)
                        print('gt  :', pol.value)
                        print('pred:', icdar_pol.value)
                        print('cer', round(cer, 3), ',iou', iou)
                        icdar_pol.type = pol.type

                if not match:
                    total_boxes_not_match += 1
                    print(' not match gt  :', pol.value)
                    print('Max_iou', max_iou)
                    if debug:
                        gt_img_res = cv2.resize(gt_img, (int(gt_img.shape[1]/2), int(gt_img.shape[0]/2)))
                        cv2.imshow('not match gt box', gt_img_res)
                        cv2.waitKey(0)

            # save to output file
            output_icdar_path = os.path.join(output_dir, img_name.replace('.jpg', '.txt'))
            output_icdar_txt = '\n'.join(
                icdar_pol.to_icdar_line(map_type=type_map) for icdar_pol in list_icdar_poly
            ).rstrip('\n')
            with open(output_icdar_path, 'w', encoding='utf-8') as f:
                f.write(output_icdar_txt)

            # Running statistics, printed after every image.
            if total_boxes > 0:
                print('Total not match', total_boxes_not_match, 'total boxes', total_boxes, 'not match ratio',
                      round(total_boxes_not_match / total_boxes, 3))
예제 #7
0
def _suppress_near_duplicates(entries, fields, cer_thres):
    """Zero the ``'count'`` of the rarer entry in every near-duplicate pair.

    Two entries are near duplicates when, for every name in ``fields``, the
    space-joined texts score strictly below ``cer_thres`` with
    ``cer_loss_one_image``. The entry with the smaller count gets count 0 and
    the other one gets +1; pairs with equal counts are left alone. Entries
    whose count is already 0 when their turn comes are skipped.

    NOTE(review): as in the original code, an entry zeroed part-way through
    its own pass keeps being compared with the remaining entries and can
    still bump their counts; confirm whether that is intended.
    """
    for i, first in enumerate(entries):
        if first['count'] == 0:
            continue
        first_texts = [' '.join(first[field]) for field in fields]
        for second in entries[i + 1:]:
            if second['count'] == 0:
                continue
            is_duplicate = all(
                cer_loss_one_image(text, ' '.join(second[field])) < cer_thres
                for text, field in zip(first_texts, fields)
            )
            if not is_duplicate:
                continue
            if first['count'] < second['count']:
                first['count'] = 0
                second['count'] += 1
            if second['count'] < first['count']:
                second['count'] = 0
                first['count'] += 1


def _keep_and_print(entries, keep):
    """Return the entries accepted by ``keep``, printing each with its rank."""
    kept = []
    for entry in entries:
        if keep(entry):
            kept.append(entry)
            print(len(kept), entry)
    return kept


def get_store_from_csv_to_json(csv_file, output_json_file):
    """Build the store / seller / address dictionary from an annotation CSV.

    The first CSV row is treated as a header. In each following row, column 3
    holds ``|||``-separated labels and column 2 the matching texts. The texts
    labelled ``SELLER`` and ``ADDRESS`` are collected into a store candidate
    (both fields), a seller candidate and an address candidate, which are
    registered through ``check_existed_store`` / ``check_existed_seller`` /
    ``check_existed_address`` (defined elsewhere; presumably they add the
    candidate or increase the count of an existing entry).

    Near duplicates are then suppressed (score below 0.4 for stores and
    sellers, below 0.6 for addresses) and the lists are filtered: stores need
    a count above 1, sellers and addresses a count above 0 and a joined text
    longer than 2 characters. Kept entries are printed.

    Args:
        csv_file: path of the annotation CSV (read as UTF-8).
        output_json_file: path where the dictionary is written as JSON.

    Returns:
        ``{'store': [...], 'seller': [...], 'address': [...]}``, the same
        object that was serialised.
    """
    list_store = []
    list_seller = []
    list_address = []

    # newline='' is what the csv module expects for files it parses itself.
    with open(csv_file, encoding='utf-8', newline='') as csv_handle:
        csv_reader = csv.reader(csv_handle, delimiter=',')
        next(csv_reader, None)  # skip the header row
        # Start at 1 so the printed index is still the row's position in the file.
        for n, row in enumerate(csv_reader, start=1):
            img_name = row[0]
            print(n, img_name)
            key, value = row[3].split('|||'), row[2].split('|||')
            store = {'SELLER': [], 'ADDRESS': [], 'count': 0}
            seller = {'SELLER': [], 'count': 0}
            address = {'ADDRESS': [], 'count': 0}
            for idx, k in enumerate(key):
                if k == 'count':
                    # 'count' holds an int, not a list of texts.
                    continue
                for candidate in (store, seller, address):
                    if k in candidate:
                        candidate[k].append(value[idx])
            check_existed_store(store_candidate=store, list_store=list_store)
            check_existed_seller(seller_candidate=seller,
                                 list_seller=list_seller)
            check_existed_address(address_candidate=address,
                                  list_address=list_address)

    # Check duplicate store
    _suppress_near_duplicates(list_store, ('SELLER', 'ADDRESS'), 0.4)
    final_list_store = _keep_and_print(
        list_store, lambda store: store['count'] > 1)

    print('List seller>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>')
    # Check duplicate seller
    _suppress_near_duplicates(list_seller, ('SELLER',), 0.4)
    final_list_seller = _keep_and_print(
        list_seller,
        lambda seller: seller['count'] > 0 and len(' '.join(seller['SELLER'])) > 2)

    print('List address>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>')
    # Check duplicate address
    _suppress_near_duplicates(list_address, ('ADDRESS',), 0.6)
    final_list_address = _keep_and_print(
        list_address,
        lambda address: address['count'] > 0 and len(' '.join(address['ADDRESS'])) > 2)

    final_dict = {
        'store': final_list_store,
        'seller': final_list_seller,
        'address': final_list_address,
    }
    with open(output_json_file, 'w', encoding='utf-8') as outfile:
        json.dump(final_dict, outfile)

    return final_dict