Esempio n. 1
0
def _write_pick_list_csv(csv_path, file_names):
    """Write one ``<1-based index>,receipts,<file name>`` row per file to csv_path (UTF-8)."""
    lines = [','.join([str(row_idx), 'receipts', name]) + '\n'
             for row_idx, name in enumerate(file_names, start=1)]
    with open(csv_path, mode='w', encoding='utf-8') as f:
        f.writelines(lines)


def create_data_pick_csv_train_val(train_dir, train_ratio=0.92):
    """Split the images in ``train_dir/images`` into PICK train/val CSV lists.

    The file list is shuffled in place; the first ``int(total * train_ratio)``
    names are written to ``train_dir/train_list.csv`` and all remaining names
    to ``train_dir/val_list.csv``. Each row has the form
    ``<1-based index>,receipts,<file name>``.

    Args:
        train_dir: dataset root that contains an ``images`` sub-folder.
        train_ratio: fraction of files assigned to the training split.
    """
    list_files = get_list_file_in_folder(os.path.join(train_dir, 'images'))
    num_train = int(len(list_files) * train_ratio)

    random.shuffle(list_files)
    # Split exactly at num_train so every file ends up in one of the two
    # lists (slicing from num_train + 1 used to drop one image entirely).
    list_train = list_files[:num_train]
    list_val = list_files[num_train:]

    _write_pick_list_csv(os.path.join(train_dir, 'train_list.csv'), list_train)
    _write_pick_list_csv(os.path.join(train_dir, 'val_list.csv'), list_val)
    print('Done')
Esempio n. 2
0
def main():
    """Deskew every image in ``img_dir`` using its ICDAR boxes and save the results.

    Relies on module-level configuration defined elsewhere in the file:
    ``img_dir``, ``anno_dir``, ``weight_path``, ``rot_drop_thresh``, the
    ``output_txt_dir`` / ``output_viz_dir`` / ``output_rotated_img_dir`` paths
    and the ``write_rotated_img`` / ``write_file`` / ``visualize`` flags.

    For each image (sorted by name):
      1. load its boxes from ``anno_dir/<name>.txt`` and drop boxes via
         ``drop_box(..., drop_gap=rot_drop_thresh)``;
      2. rotate image and boxes by ``get_mean_horizontal_angle``;
      3. rotate again by the page orientation predicted by the rectify model;
      4. apply ``filter_90_box`` and optionally write the rotated image, the
         box file and a visualization.
    Sets the module-level ``anno_path`` to the last annotation path processed.
    """
    begin_init = time.time()
    global anno_path

    for out_dir in (output_txt_dir, output_viz_dir, output_rotated_img_dir):
        os.makedirs(out_dir, exist_ok=True)

    box_rectify = init_box_rectify_model(weight_path)
    end_init = time.time()
    print('Init models time:', end_init - begin_init, 'seconds')
    begin = time.time()

    list_img_path = sorted(get_list_file_in_folder(img_dir))
    for idx, img_name in enumerate(list_img_path):
        print('\n', idx, 'Inference', img_name)
        test_img = cv2.imread(os.path.join(img_dir, img_name))
        if test_img is None:
            # cv2.imread returns None (it does not raise) for unreadable files.
            print('Cannot read image, skipping:', img_name)
            continue
        begin_detector = time.time()
        anno_path = os.path.join(anno_dir, img_name.replace('.jpg', '.txt'))
        boxes_list = get_list_boxes_from_icdar(anno_path)
        boxes_list = drop_box(boxes_list, drop_gap=rot_drop_thresh)

        # First correct the small skew, then the coarse page orientation.
        rotation = get_mean_horizontal_angle(boxes_list, False)
        img_rotated, boxes_list = rotate_image_bbox_angle(
            test_img, boxes_list, rotation)
        degre = calculate_page_orient(box_rectify, img_rotated, boxes_list)
        img_rotated, boxes_list = rotate_image_bbox_angle(
            img_rotated, boxes_list, degre)
        boxes_list = filter_90_box(boxes_list)
        end_detector = time.time()
        print('get boxes from icdar time:', end_detector - begin_detector,
              'seconds')

        base_name = os.path.basename(img_name)
        # splitext only removes the real extension, so names containing
        # extra dots no longer collide on the same output file.
        output_txt_path = os.path.join(
            output_txt_dir, os.path.splitext(base_name)[0] + '.txt')
        output_viz_path = os.path.join(output_viz_dir, base_name)
        output_rotated_img_path = os.path.join(output_rotated_img_dir,
                                               base_name)
        if write_rotated_img:
            cv2.imwrite(output_rotated_img_path, img_rotated)
        if write_file:
            write_output(boxes_list, output_txt_path)
        if visualize:
            viz_icdar(img_rotated, output_txt_path, output_viz_path)
            end_visualize = time.time()
            print('Visualize time:', end_visualize - end_detector, 'seconds')

    end = time.time()
    # Guard against an empty input folder (previously ZeroDivisionError).
    speed = (end - begin) / max(len(list_img_path), 1)
    print('Processing time:', end - begin, 'seconds. Speed:', round(speed, 4),
          'second/image')
    print('Done')
Esempio n. 3
0
def create_data_pick_boxes_and_transcripts(icdar_dir, output_dir):
    """Convert each ICDAR ``.txt`` annotation in icdar_dir into a PICK ``.tsv``.

    Every input line is kept unchanged but prefixed with its 1-based line
    number and a comma. The output goes to output_dir under the same file
    name with ``.txt`` replaced by ``.tsv``.
    """
    annotation_files = get_list_file_in_folder(icdar_dir, ext=['txt'])
    for file_no, anno_name in enumerate(annotation_files):
        print(file_no, anno_name)
        src_path = os.path.join(icdar_dir, anno_name)
        with open(src_path, mode='r', encoding='utf-8') as src:
            raw_lines = src.readlines()

        numbered_lines = ['{},{}'.format(line_no, text)
                          for line_no, text in enumerate(raw_lines, start=1)]

        dst_path = os.path.join(output_dir, anno_name.replace('.txt', '.tsv'))
        with open(dst_path, mode='wt', encoding='utf-8') as dst:
            dst.writelines(numbered_lines)
Esempio n. 4
0
def viz_icdar_multi(img_dir, anno_dir, save_viz_dir, extract_kie_type=False, ignor_type=None):
    """Visualize the ICDAR annotation of every image in img_dir.

    For each image ``<name>.jpg`` the annotation ``anno_dir/<name>.txt`` is
    drawn via ``viz_icdar`` and saved to ``save_viz_dir/<name>.jpg``.

    Args:
        img_dir: folder containing the images.
        anno_dir: folder containing the matching ICDAR ``.txt`` files.
        save_viz_dir: destination folder for the rendered images.
        extract_kie_type: forwarded unchanged to ``viz_icdar``.
        ignor_type: forwarded to ``viz_icdar``; defaults to ``[1]``.
    """
    # A fresh list per call: a list default would be shared across calls.
    if ignor_type is None:
        ignor_type = [1]
    list_files = get_list_file_in_folder(img_dir)
    for idx, file in enumerate(list_files):
        print(idx, file)
        img_path = os.path.join(img_dir, file)
        anno_path = os.path.join(anno_dir, file.replace('.jpg', '.txt'))
        save_img_path = os.path.join(save_viz_dir, file)
        viz_icdar(img_path, anno_path, save_img_path, extract_kie_type, ignor_type)
Esempio n. 5
0
def filter_space_OCR(ocr_dir, dry_run=False):
    """Strip trailing spaces from every line of the OCR ``.txt`` files in ocr_dir.

    Each line that had trailing spaces is reported with ``print``. Any file
    with at least one such line is rewritten in place as UTF-8, with lines
    joined by ``'\\n'`` and no trailing newline. Previously the cleaned text
    was built and then discarded, so nothing was ever fixed.

    Args:
        ocr_dir: folder containing the OCR result ``.txt`` files.
        dry_run: if True, only report the affected lines and do not rewrite
            any file.
    """
    list_files = get_list_file_in_folder(ocr_dir, ext=['txt'])
    for idx, file in enumerate(list_files):
        file_path = os.path.join(ocr_dir, file)
        with open(file_path, 'r', encoding='utf-8') as f:
            anno_txt = f.readlines()

        fixed_lines = []
        changed = False
        for anno in anno_txt:
            line = anno.rstrip('\n')
            fix_anno = line.rstrip(' ')  # only spaces; tabs are left untouched
            if fix_anno != line:
                print(idx, file, '. Fix space')
                changed = True
            fixed_lines.append(fix_anno)

        # Rewrite only files that actually changed, to avoid touching others.
        if changed and not dry_run:
            with open(file_path, 'w', encoding='utf-8') as f:
                f.write('\n'.join(fixed_lines).rstrip('\n'))
Esempio n. 6
0
def compare_result(first_res_dir, second_res_dir, ext='txt'):
    """Print the lines that differ between two directories of result files.

    For every file with extension ext in first_res_dir, the same-named file in
    second_res_dir is read. Lines present in one file but not the other are
    printed, in file order. When a file differs, ``cv2.waitKey(0)`` is called.

    Files that are missing from second_res_dir are reported and skipped.
    """
    list_files = get_list_file_in_folder(first_res_dir, ext=[ext])
    for idx, file in enumerate(list_files):
        print(idx, file)
        second_path = os.path.join(second_res_dir, file)
        if not os.path.isfile(second_path):
            print('-------------------- missing in second dir', second_path)
            continue
        with open(os.path.join(first_res_dir, file), mode='r',
                  encoding='utf-8') as f:
            first_txt = f.readlines()
        with open(second_path, mode='r', encoding='utf-8') as f:
            second_txt = f.readlines()

        # Sets give O(1) membership checks; iterating the original lists
        # keeps the printed order identical to the files.
        first_set = set(first_txt)
        second_set = set(second_txt)
        diff = False
        for line in first_txt:
            if line not in second_set:
                print('-------------------- not in second', line)
                diff = True
        for line in second_txt:
            if line not in first_set:
                print('-------------------- not in first', line)
                diff = True
        if diff:
            # NOTE(review): the image windows this pause was meant for were
            # commented out; per the OpenCV docs, waitKey only blocks while a
            # HighGUI window exists, so this may return immediately.
            cv2.waitKey(0)
Esempio n. 7
0
def viz_output_of_pick(img_dir, output_txt_dir, output_viz_dir):
    """Render PICK KIE output files onto their source images.

    Each ``.txt`` in output_txt_dir (processed in sorted order) holds lines of
    the form ``<coordinates>\\t<type>\\t<text>``. Each line becomes a ``poly``
    whose type is mapped through ``inv_type_map``. The polygons are drawn with
    ``viz_poly`` on ``img_dir/<name>.jpg`` (converted BGR -> RGB) and saved to
    ``output_viz_dir/<name>.jpg``.
    """
    list_output_txt = sorted(get_list_file_in_folder(output_txt_dir, ext=['txt']))
    for n, file in enumerate(list_output_txt):
        print(n, file)
        with open(os.path.join(output_txt_dir, file), mode='r', encoding='utf-8') as f:
            output_txt = f.readlines()

        list_poly = []
        for line in output_txt:
            line = line.replace('\n', '')
            if not line.strip():
                # Blank (e.g. trailing) lines would break the 3-field unpack.
                continue
            # maxsplit=2 keeps any tab inside the recognised text in `text`.
            coordinates, kie_type, text = line.split('\t', 2)
            list_poly.append(poly(coordinates, type=inv_type_map[kie_type], value=text))

        img_name = file.replace('.txt', '.jpg')
        image = cv2.imread(os.path.join(img_dir, img_name))
        if image is None:
            # cv2.imread returns None rather than raising on missing files.
            print('Cannot read image, skipping:', img_name)
            continue
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        viz_poly(img=image,
                 list_poly=list_poly,
                 save_viz_path=os.path.join(output_viz_dir, img_name))
Esempio n. 8
0
def viz_same_img_in_different_dirs(first_dir, second_dir, list_path=None, resize_ratio=1.0,
                                   second_suffix='_ model_epoch_639_minibatch_243000_ 0.1.jpg'):
    """Show an image from first_dir next to its counterpart in second_dir.

    For each file ``<name>.jpg`` in first_dir, the counterpart is
    ``second_dir/<name><second_suffix>``. Both are resized by resize_ratio
    and shown in the windows ``first_img`` / ``second_img``, and the function
    waits for a key press.

    Args:
        first_dir: folder with the reference images.
        second_dir: folder with the images to compare against.
        list_path: optional path passed to ``get_list_error_imgs``; when
            given, only files whose name without ``.jpg`` is in that list
            are shown.
        resize_ratio: scale factor applied to both images before display.
        second_suffix: suffix appended to ``<name>`` to build the counterpart
            file name. The default keeps the previously hard-coded value.
    """
    def _resize(img):
        return cv2.resize(img, (int(resize_ratio * img.shape[1]), int(resize_ratio * img.shape[0])))

    list_files = get_list_file_in_folder(first_dir)
    list_err_images = None
    if list_path is not None:
        list_err_images = get_list_error_imgs(list_path)
    for n, file in enumerate(list_files):
        base_name = file.replace('.jpg', '')
        if list_err_images is not None and base_name not in list_err_images:
            continue
        print(n, file)

        first_img_path = os.path.join(first_dir, file)
        second_img_path = os.path.join(second_dir, base_name + second_suffix)
        first_img = cv2.imread(first_img_path)
        second_img = cv2.imread(second_img_path)
        # cv2.imread returns None for missing/unreadable files; previously
        # this crashed on `.shape`.
        if first_img is None or second_img is None:
            print('Cannot read image pair, skipping:', first_img_path, second_img_path)
            continue

        cv2.imshow('first_img', _resize(first_img))
        cv2.imshow('second_img', _resize(second_img))
        cv2.waitKey(0)
Esempio n. 9
0
# Keyword arguments for get_boxes_data, one entry per OCR pass:
# 1) extend x only, 2) also extend y by 10%, 3) also extend y by 20%.
_MULTISCALE_OCR_CONFIGS = (
    dict(extend_box=True, min_extend_y=0, extend_y_ratio=0),
    dict(extend_box=True, min_extend_y=2, extend_y_ratio=0.1),
    dict(extend_box=True, min_extend_y=4, extend_y_ratio=0.2),
)


def _combine_multiscale_predictions(list_values, list_probs, total_boxes):
    """For each box, keep the value/prob of the pass with the highest prob.

    On ties the earliest pass wins, because a later pass must be strictly
    greater to replace it. Returns ``(final_values, final_probs)``.
    """
    final_values = []
    final_probs = []
    for box_idx in range(total_boxes):
        best_pass = 0
        for pass_idx in range(1, len(list_probs)):
            if list_probs[pass_idx][box_idx] > list_probs[best_pass][box_idx]:
                best_pass = pass_idx
        final_values.append(list_values[best_pass][box_idx])
        final_probs.append(list_probs[best_pass][box_idx])
    return final_values, final_probs


def main():
    """Run multiscale OCR on the ICDAR boxes of one image or a whole folder.

    Relies on module-level configuration defined elsewhere in the file:
    ``gpu``, ``img_path``, ``img_dir``, ``anno_dir``, ``anno_path``,
    ``cls_out_txt_dir``, ``cls_out_viz_dir``, ``cls_ocr_thres`` and the
    ``write_file`` / ``cls_visualize`` flags.

    If ``img_path`` is non-empty, only that image is processed and the
    module-level ``anno_path`` is used as its annotation. Otherwise every
    image in ``img_dir`` is processed with ``anno_dir/<name>.txt``. Each box
    is recognised once per entry in ``_MULTISCALE_OCR_CONFIGS``, and the
    highest-probability result is kept.
    """
    begin_init = time.time()
    global anno_path
    classifier = init_models(gpu=gpu)
    end_init = time.time()
    print('Init models time:', end_init - begin_init, 'seconds')
    begin = time.time()
    if img_path != '':
        list_img_path = [img_path]
    else:
        list_img_path = get_list_file_in_folder(img_dir)
    list_img_path = sorted(list_img_path)
    for idx, img_name in enumerate(list_img_path):
        print('\n', idx, 'Inference', img_name)

        test_img = cv2.imread(os.path.join(img_dir, img_name))
        if test_img is None:
            # cv2.imread returns None (it does not raise) for unreadable files.
            print('Cannot read image, skipping:', img_name)
            continue
        begin_detector = time.time()
        if img_path == '':
            anno_path = os.path.join(anno_dir,
                                     img_name.replace('.jpg', '.txt'))
        boxes_list = get_list_boxes_from_icdar(anno_path)

        end_detector = time.time()
        print('get boxes from icdar time:', end_detector - begin_detector,
              'seconds')

        # Multiscale OCR: one classifier pass per crop configuration.
        list_values = []
        list_probs = []
        for crop_cfg in _MULTISCALE_OCR_CONFIGS:
            boxes_data = get_boxes_data(test_img, boxes_list, **crop_cfg)
            values, probs = classifier.inference(boxes_data, debug=False)
            list_values.append(values)
            list_probs.append(probs)
        final_values, final_probs = _combine_multiscale_predictions(
            list_values, list_probs, len(boxes_list))

        end_classifier = time.time()
        print('Multiscale OCR time:', end_classifier - end_detector, 'seconds')
        print('Total predict time:', end_classifier - begin_detector,
              'seconds')
        base_name = os.path.basename(img_name)
        # splitext only removes the real extension, so names containing
        # extra dots no longer collide on the same output file.
        output_txt_path = os.path.join(
            cls_out_txt_dir, os.path.splitext(base_name)[0] + '.txt')
        output_viz_path = os.path.join(cls_out_viz_dir, base_name)
        if write_file:
            write_output(boxes_list,
                         final_values,
                         final_probs,
                         output_txt_path,
                         prob_thres=cls_ocr_thres)

        if cls_visualize:
            viz_icdar(os.path.join(img_dir, img_name),
                      output_txt_path,
                      output_viz_path,
                      ignor_type=[])
            end_visualize = time.time()
            print('Visualize time:', end_visualize - end_classifier, 'seconds')

    end = time.time()
    # Guard against an empty input list (previously ZeroDivisionError).
    speed = (end - begin) / max(len(list_img_path), 1)
    print('\nTotal processing time:', end - begin, 'seconds. Speed:',
          round(speed, 4), 'second/image')
Esempio n. 10
0
def modify_kie_training_data_by_rules(txt_dir, json_data_path, debug=False):
    """Re-label KIE annotation files in txt_dir using rule-based validators.

    Known sellers and addresses are read from the JSON file at json_data_path
    (keys ``'seller'`` and ``'address'``). For every annotation file, each
    polygon's ``type`` is corrected in this order:
      * 15 (SELLER) if ``validate_SELLER`` accepts its text;
      * 16 (ADDRESS) if ``validate_ADDRESS`` accepts its text;
      * 1 if it is now 15/16 but its text validates as a total-cost amount;
      * 17 (TIMESTAMP) if ``validate_TIMESTAMP`` accepts its text;
      * for a type-18 (TOTALCOST) polygon whose text validates as a
        total-cost key: if no type-18 polygon in the file has an amount-like
        or empty text, ``find_TOTALCOST_val_poly`` is called to locate one.
    Files with at least one change are rewritten in place (UTF-8, one
    ``to_icdar_line(type_map)`` per line, no trailing newline) and printed.

    Args:
        txt_dir: folder with the ICDAR-style KIE annotation files.
        json_data_path: JSON file with ``'seller'`` and ``'address'`` lists.
        debug: currently unused.
    """
    # NOTE(review): other callers in this file pass ext=['txt'] (no dot);
    # confirm get_list_file_in_folder accepts both forms.
    list_files = get_list_file_in_folder(txt_dir, ext=['.txt'])

    # Explicit UTF-8: the seller/address data is non-ASCII and the platform
    # default encoding (e.g. cp1252 on Windows) would mis-decode it.
    with open(json_data_path, encoding='utf-8') as json_file:
        data = json.load(json_file)
    list_seller = data['seller']
    list_address = data['address']

    for idx, file in enumerate(list_files):
        list_icdar_poly = get_list_icdar_poly(os.path.join(txt_dir, file), ignore_kie_type=True)

        modify = False
        # Sticky per file: once an amount-like TOTALCOST value is seen it is
        # never searched for again in this file.
        has_TOTALCOST_val = False
        for icdar_pol in list_icdar_poly:
            # fix wrong SELLER
            if icdar_pol.type != 15 and validate_SELLER(list_seller, icdar_pol.value):
                icdar_pol.type = 15
                modify = True

            # fix wrong ADDRESS
            if icdar_pol.type != 16 and validate_ADDRESS(list_address, icdar_pol.value):
                icdar_pol.type = 16
                modify = True

            # fix a number wrongly labelled as ADDRESS or SELLER
            if icdar_pol.type in [15, 16] and validate_TOTAL_COST_amount(icdar_pol.value):
                icdar_pol.type = 1
                modify = True

            # fix TIMESTAMP
            if icdar_pol.type != 17 and validate_TIMESTAMP(icdar_pol.value):
                icdar_pol.type = 17
                modify = True

            # fix TOTALCOST: a key was found, make sure a value exists too
            if icdar_pol.type == 18 and validate_TOTAL_COST_keys(icdar_pol.value, cer_thres=0.2):
                for icdar_pol2 in list_icdar_poly:
                    if icdar_pol2.type == 18:
                        if validate_TOTAL_COST_amount(icdar_pol2.value, thres=0.7) or icdar_pol2.value == '':
                            has_TOTALCOST_val = True
                if not has_TOTALCOST_val:
                    modify = True
                    find_TOTALCOST_val_poly(keys_poly=icdar_pol,
                                            list_poly=list_icdar_poly)

        if modify:
            modify_icdar = '\n'.join(
                icdar_pol.to_icdar_line(type_map) for icdar_pol in list_icdar_poly
            ).rstrip('\n')
            print(idx, file)
            with open(os.path.join(txt_dir, file), mode='w', encoding='utf-8') as f:
                f.write(modify_icdar)