예제 #1
0
def main(args):
    """Classify every readable image found under ``args.image_dir``.

    Images that cannot be loaded are logged and skipped. The remaining images
    are passed to the classifier in one call; one result line per image and
    the total prediction time are then logged.

    Args:
        args: Parsed options. ``image_dir`` is read here and the whole object
            is forwarded to ``TextClassifier``.
    """
    image_file_list = get_image_file_list(args.image_dir)
    text_classifier = TextClassifier(args)
    valid_image_file_list = []
    img_list = []
    for image_file in image_file_list:
        img, flag = check_and_read_gif(image_file)
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue
        valid_image_file_list.append(image_file)
        img_list.append(img)
    try:
        img_list, cls_res, predict_time = text_classifier(img_list)
    except Exception:
        # ``Exception`` instead of a bare ``except`` so that Ctrl-C and
        # SystemExit are not reported as a model failure.
        logger.info(traceback.format_exc())
        logger.info(
            "ERROR!!!! \n"
            "Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n"
            "If your model has tps module:  "
            "TPS does not support variable shape.\n"
            "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ")
        exit()
    for image_file, result in zip(valid_image_file_list, cls_res):
        logger.info("Predicts of {}:{}".format(image_file, result))
    logger.info("Total predict time for {} images, cost: {:.3f}".format(
        len(img_list), predict_time))
예제 #2
0
def main(args):
    """Recognize text in each loadable image under ``args.image_dir``.

    Unreadable images are logged and left out. Results and the total
    prediction time are printed to stdout; a recognizer failure prints the
    exception, logs a troubleshooting hint and terminates the process.

    Args:
        args: Parsed options. ``image_dir`` is read here and the whole object
            is forwarded to ``TextRecognizer``.
    """
    candidate_paths = get_image_file_list(args.image_dir)
    recognizer = TextRecognizer(args)
    loaded_paths, loaded_images = [], []
    for path in candidate_paths:
        frame, is_gif = check_and_read_gif(path)
        # Fall back to OpenCV when the GIF helper did not handle the file.
        frame = frame if is_gif else cv2.imread(path)
        if frame is None:
            logger.info("error in loading image:{}".format(path))
        else:
            loaded_paths.append(path)
            loaded_images.append(frame)

    try:
        rec_res, predict_time = recognizer(loaded_images)
    except Exception as err:
        print(err)
        hint = ("ERROR!!!! \n"
                "Please read the FAQ: https://github.com/PaddlePaddle/PaddleOCR#faq \n"
                "If your model has tps module:  "
                "TPS does not support variable shape.\n"
                "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ")
        logger.info(hint)
        exit()

    for position, path in enumerate(loaded_paths):
        print("Predicts of %s:%s" % (path, rec_res[position]))
    image_count = len(loaded_images)
    print("Total predict time for %d images:%.3f" % (image_count,
                                                     predict_time))
예제 #3
0
def main(config, device, logger, vdl_writer):
    """Run table-structure inference on every image in ``Global.infer_img``.

    Builds the post-processor, model and preprocessing operators from
    ``config``, then logs the predicted HTML code and pixel-space cell boxes
    for each image.

    Args:
        config: Nested configuration dict; the ``Global``, ``PostProcess``,
            ``Architecture`` and ``Eval`` sections are read. Some entries are
            updated in place (head ``out_channels``, ``KeepKeys`` keys,
            ``Global.infer_mode``).
        device: Not used in this function.
        logger: Logger used for progress and result messages.
        vdl_writer: Not used in this function.
    """
    global_config = config['Global']

    # build post process
    post_process_class = build_post_process(config['PostProcess'],
                                            global_config)

    # build model: size the head from the post-processor's character set
    # when it exposes one
    if hasattr(post_process_class, 'character'):
        config['Architecture']["Head"]['out_channels'] = len(
            getattr(post_process_class, 'character'))

    model = build_model(config['Architecture'])

    init_model(config, model, logger)

    # create data ops: label ops are dropped, and only the image is kept
    transforms = []
    for op in config['Eval']['dataset']['transforms']:
        op_name = list(op)[0]
        if 'Label' in op_name:
            continue
        if op_name == 'KeepKeys':
            op[op_name]['keep_keys'] = ['image']
        transforms.append(op)

    global_config['infer_mode'] = True
    ops = create_operators(transforms, global_config)

    model.eval()
    for file in get_image_file_list(config['Global']['infer_img']):
        logger.info("infer_img: {}".format(file))
        with open(file, 'rb') as f:
            data = {'image': f.read()}
        batch = transform(data, ops)
        images = np.expand_dims(batch[0], axis=0)
        images = paddle.to_tensor(images)
        preds = model(images)
        post_result = post_process_class(preds)
        res_html_code = post_result['res_html_code']
        res_loc = post_result['res_loc']

        img = cv2.imread(file)
        if img is None:
            # cv2.imread signals failure with None; previously this crashed
            # on ``img.shape`` below.
            logger.info("error in loading image:{}".format(file))
            continue
        imgh, imgw = img.shape[0:2]
        res_loc_final = []
        for x0, y0, x1, y1 in res_loc[0]:
            # Scale by the image size and clamp to the image bounds.
            left = max(int(imgw * x0), 0)
            top = max(int(imgh * y0), 0)
            right = min(int(imgw * x1), imgw - 1)
            bottom = min(int(imgh * y1), imgh - 1)
            # NOTE(review): the annotated image is not saved or returned here.
            cv2.rectangle(img, (left, top), (right, bottom), (0, 0, 255), 2)
            res_loc_final.append([left, top, right, bottom])
        logger.info("result: {}, {}".format(res_html_code, res_loc_final))
    logger.info("success!")
예제 #4
0
def main(args):
    """Convert each table image under ``args.image_dir`` to an Excel file.

    The image list is sharded with ``args.process_id`` /
    ``args.total_process_num``. For every readable image the predicted HTML
    is written to ``<args.output>/<image stem>.xlsx`` and logged together
    with the elapsed time.

    Args:
        args: Parsed options. ``image_dir``, ``process_id``,
            ``total_process_num`` and ``output`` are read here; the whole
            object is forwarded to ``TableSystem``.
    """
    image_file_list = get_image_file_list(args.image_dir)
    image_file_list = image_file_list[args.process_id::args.total_process_num]
    os.makedirs(args.output, exist_ok=True)

    text_sys = TableSystem(args)
    img_num = len(image_file_list)
    for i, image_file in enumerate(image_file_list):
        logger.info("[{}/{}] {}".format(i, img_num, image_file))
        img, flag = check_and_read_gif(image_file)
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.error("error in loading image:{}".format(image_file))
            continue

        # splitext strips only the final extension, so names such as
        # "report.v1.png" and "report.v2.png" no longer collapse onto the
        # same "report.xlsx".
        stem = os.path.splitext(os.path.basename(image_file))[0]
        excel_path = os.path.join(args.output, stem + '.xlsx')

        starttime = time.time()
        pred_html = text_sys(img)

        to_excel(pred_html, excel_path)
        logger.info('excel saved to {}'.format(excel_path))
        logger.info(pred_html)
        elapse = time.time() - starttime
        logger.info("Predict time : {:.3f}s".format(elapse))
예제 #5
0
def main(args):
    """Recognize text in the images under ``args.image_dir`` and log results.

    Optionally runs two warm-up passes on a random image first, and reports
    the recognizer's autolog statistics when ``args.benchmark`` is set.

    Args:
        args: Parsed options. ``image_dir``, ``warmup``, ``rec_batch_num`` and
            ``benchmark`` are read here; the whole object is forwarded to
            ``TextRecognizer``.
    """
    candidate_paths = get_image_file_list(args.image_dir)
    text_recognizer = TextRecognizer(args)
    kept_paths = []
    kept_images = []

    # warmup 2 times
    if args.warmup:
        dummy = np.random.uniform(0, 255, [32, 320, 3]).astype(np.uint8)
        for _ in range(2):
            text_recognizer([dummy] * int(args.rec_batch_num))

    for path in candidate_paths:
        picture, is_gif = check_and_read_gif(path)
        if not is_gif:
            picture = cv2.imread(path)
        if picture is None:
            logger.info("error in loading image:{}".format(path))
            continue
        kept_paths.append(path)
        kept_images.append(picture)

    try:
        rec_res, _ = text_recognizer(kept_images)
    except Exception as E:
        logger.info(traceback.format_exc())
        logger.info(E)
        exit()

    for position, path in enumerate(kept_paths):
        logger.info("Predicts of {}:{}".format(path, rec_res[position]))
    if args.benchmark:
        text_recognizer.autolog.report()
예제 #6
0
 def sample_iter_reader():
     """Yield processed samples for the enclosing reader.

     In ``'test'`` mode every image listed for ``self.infer_img`` is read,
     converted to BGR if it is grayscale, and yielded after
     ``process_image``. In any other mode the label file is read, shuffled,
     and this worker's share (every ``self.num_workers``-th entry starting
     at ``process_id``) is yielded; unreadable images and samples for which
     ``process_image`` returns ``None`` are skipped.
     """
     if self.mode == 'test':
         image_file_list = get_image_file_list(self.infer_img)
         for single_img in image_file_list:
             # NOTE: an unreadable file makes cv2.imread return None and the
             # shape access below raise; left as-is so no sample is dropped.
             img = cv2.imread(single_img)
             if img.shape[-1] == 1 or len(list(img.shape)) == 2:
                 img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
             norm_img = process_image(img, self.image_shape)
             yield norm_img
     else:
         with open(self.label_file_path, "rb") as fin:
             label_infor_list = fin.readlines()
         img_num = len(label_infor_list)
         img_id_list = list(range(img_num))
         random.shuffle(img_id_list)
         for img_id in range(process_id, img_num, self.num_workers):
             label_infor = label_infor_list[img_id_list[img_id]]
             substr = label_infor.decode('utf-8').strip("\n").split(
                 "\t")
             img_path = self.img_set_dir + "/" + substr[0]
             img = cv2.imread(img_path)
             # The None check must precede any ``img.shape`` access;
             # previously it came afterwards and could never be reached.
             if img is None:
                 logger.info("{} does not exist!".format(img_path))
                 continue
             if img.shape[-1] == 1 or len(list(img.shape)) == 2:
                 img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
             label = substr[1]
             outs = process_image(img, self.image_shape, label,
                                  self.char_ops, self.loss_type,
                                  self.max_text_length)
             if outs is None:
                 continue
             yield outs
예제 #7
0
def main(args):
    """Extract tables either from an image directory or interactively.

    With ``args.clipboard`` set, the user is prompted repeatedly and
    ``call_model(args)`` runs after each prompt until ``x`` is entered.
    Otherwise every readable image under ``args.image_dir`` is run through
    ``TextSystem`` and its result is passed to ``out_table``.

    Args:
        args: Parsed options. ``clipboard``, ``image_dir``, ``vis_font_path``
            and ``drop_score`` are read here.
    """
    if args.clipboard:
        while True:
            reply = input(
                'Extract Table From Image ("?"/"h" for help,"x" for exit).')
            if reply.strip().lower() == 'x':
                break
            try:
                call_model(args)
            except KeyboardInterrupt:
                # Ctrl-C aborts the current extraction only; keep prompting.
                pass
        return

    image_paths = get_image_file_list(args.image_dir)
    text_sys = TextSystem(args)
    # These options are read but not used further in this function.
    font_path = args.vis_font_path
    drop_score = args.drop_score
    for path in image_paths:
        picture, is_gif = check_and_read_gif(path)
        if not is_gif:
            picture = cv2.imread(path)
        if picture is None:
            logger.info("error in loading image:{}".format(path))
            continue
        began = time.time()
        dt_boxes, rec_res = text_sys(picture)
        elapse = time.time() - began
        logger.info("Predict time of %s: %.3fs" % (path, elapse))

        out_table(dt_boxes, rec_res)
예제 #8
0
def main(args):
    """Run structure analysis on each image and save the results.

    The image list is sharded with ``args.process_id`` /
    ``args.total_process_num``. For every readable image the structure
    result is saved via ``save_structure_res`` and a visualisation is
    written to ``<args.output>/<image stem>/show.jpg``.

    Args:
        args: Parsed options. ``image_dir``, ``process_id``,
            ``total_process_num``, ``output`` and ``vis_font_path`` are read
            here; the whole object is forwarded to ``OCRSystem``.
    """
    image_file_list = get_image_file_list(args.image_dir)
    image_file_list = image_file_list[args.process_id::args.total_process_num]
    save_folder = args.output
    os.makedirs(save_folder, exist_ok=True)

    structure_sys = OCRSystem(args)
    img_num = len(image_file_list)
    for i, image_file in enumerate(image_file_list):
        logger.info("[{}/{}] {}".format(i, img_num, image_file))
        img, flag = check_and_read_gif(image_file)
        # splitext strips only the final extension, so files whose names
        # differ after the first dot no longer share one result folder.
        img_name = os.path.splitext(os.path.basename(image_file))[0]

        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.error("error in loading image:{}".format(image_file))
            continue
        starttime = time.time()
        res = structure_sys(img)
        # presumably creates <save_folder>/<img_name>; verify in the helper
        save_structure_res(res, save_folder, img_name)
        draw_img = draw_structure_result(img, res, args.vis_font_path)
        result_dir = os.path.join(save_folder, img_name)
        cv2.imwrite(os.path.join(result_dir, 'show.jpg'), draw_img)
        logger.info('result save to {}'.format(result_dir))
        elapse = time.time() - starttime
        logger.info("Predict time : {:.3f}s".format(elapse))
예제 #9
0
def main(args):
    """Recognize text in the configured image(s) and log the predictions.

    Args:
        args: Parsed options, forwarded to ``TextRecognizer``. Three of its
            attributes are overwritten below before use.
    """
    # NOTE(review): these absolute, machine-specific paths unconditionally
    # replace whatever the caller passed in ``args``. They look like local
    # debugging overrides - confirm whether they should be removed.
    args.image_dir = '/home/duycuong/PycharmProjects/research_py3/MC_OCR/mc_ocr/text_detector/PaddleOCR/doc/imgs_words_en/word_10.png'
    args.rec_char_dict_path = '/home/duycuong/PycharmProjects/research_py3/MC_OCR/mc_ocr/text_detector/PaddleOCR/ppocr/utils/dict/japan_dict.txt'
    args.rec_model_dir = '/home/duycuong/PycharmProjects/research_py3/MC_OCR/mc_ocr/text_detector/PaddleOCR/inference/japan_mobile_v2.0_rec_infer'
    image_file_list = get_image_file_list(args.image_dir)
    text_recognizer = TextRecognizer(args)
    valid_image_file_list = []
    img_list = []
    for image_file in image_file_list:
        img, flag = check_and_read_gif(image_file)
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue
        valid_image_file_list.append(image_file)
        img_list.append(img)

    try:
        rec_res, predict_time = text_recognizer(img_list)
    except Exception:
        # ``Exception`` instead of a bare ``except`` so that Ctrl-C and
        # SystemExit are not reported as a model failure.
        logger.info(traceback.format_exc())
        logger.info(
            "ERROR!!!! \n"
            "Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n"
            "If your model has tps module:  "
            "TPS does not support variable shape.\n"
            "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' "
        )
        exit()
    for image_file, result in zip(valid_image_file_list, rec_res):
        logger.info("Predicts of {}:{}".format(image_file, result))
    logger.info("Total predict time for {} images, cost: {:.3f}".format(
        len(img_list), predict_time))
예제 #10
0
def main():
    """Run recognition inference with a fluid program and export the model.

    Loads the configuration named by ``FLAGS.config``, builds the test
    program, logs scores and labels for every image in ``Global.infer_img``,
    and finally saves an inference model to ``./output``.
    """
    config = program.load_config(FLAGS.config)
    program.merge_config(FLAGS.opt)
    logger.info(config)

    # Pick the execution device from the configuration.
    use_gpu = config['Global']['use_gpu']
    if use_gpu:
        place = fluid.CUDAPlace(0)
    else:
        place = fluid.CPUPlace()
    exe = fluid.Executor(place)

    model_factory = create_module(config['Architecture']['function'])
    rec_model = model_factory(params=config)
    startup_prog = fluid.Program()
    eval_prog = fluid.Program()
    with fluid.program_guard(eval_prog, startup_prog):
        with fluid.unique_name.guard():
            _, outputs = rec_model(mode="test")
            fetch_varname_list = [
                outputs[key].name for key in list(outputs.keys())
            ]
    eval_prog = eval_prog.clone(for_test=True)
    exe.run(startup_prog)

    init_model(config, eval_prog, exe)

    blobs = reader_main(config, 'test')()
    infer_list = get_image_file_list(config['Global']['infer_img'])
    if len(infer_list) == 0:
        logger.info("Can not find img in infer_img dir.")
    for img_path in infer_list:
        logger.info("infer_img:%s" % img_path)
        # The reader is consumed in step with the file list.
        img = next(blobs)
        predict = exe.run(program=eval_prog,
                          feed={"image": img},
                          fetch_list=fetch_varname_list,
                          return_numpy=False)
        scores = np.array(predict[0])
        label = np.array(predict[1])
        # Swap when the second fetched array is not one-dimensional.
        if len(label.shape) != 1:
            label, scores = scores, label
        logger.info('\t scores: {}'.format(scores))
        logger.info('\t label: {}'.format(label))

    # save for inference model
    target_var = [variable for _, variable in outputs.items()]
    fluid.io.save_inference_model("./output",
                                  feeded_var_names=['image'],
                                  target_vars=target_var,
                                  executor=exe,
                                  main_program=eval_prog,
                                  model_filename="model",
                                  params_filename="params")
        def sample_iter_reader():
            """Yield processed samples for inference or for training/eval.

            When the mode is not ``'train'`` and ``self.infer_img`` is set,
            each listed image is read, converted to BGR if grayscale and
            yielded after ``process_image(..., infer_mode=True)``. Otherwise
            the label file is shuffled and this worker's share is yielded,
            skipping unreadable images and rejected samples.
            """
            inference = self.mode != 'train' and self.infer_img is not None
            if inference:
                for image_path in get_image_file_list(self.infer_img):
                    picture = cv2.imread(image_path)
                    if picture.shape[-1] == 1 or len(list(
                            picture.shape)) == 2:
                        picture = cv2.cvtColor(picture, cv2.COLOR_GRAY2BGR)
                    yield process_image(img=picture,
                                        image_shape=self.image_shape,
                                        char_ops=self.char_ops,
                                        tps=self.use_tps,
                                        infer_mode=True)
                return

            with open(self.label_file_path, "rb") as fin:
                label_lines = fin.readlines()
            img_num = len(label_lines)
            order = list(range(img_num))
            random.shuffle(order)
            if sys.platform == "win32" and self.num_workers != 1:
                print("multiprocess is not fully compatible with Windows."
                      "num_workers will be 1.")
                self.num_workers = 1
            # Every device/worker pair needs at least one full batch.
            if self.batch_size * get_device_num() * self.num_workers > img_num:
                raise Exception(
                    "The number of the whole data ({}) is smaller than the batch_size * devices_num * num_workers ({})"
                    .format(
                        img_num,
                        self.batch_size * get_device_num() *
                        self.num_workers))
            for slot in range(process_id, img_num, self.num_workers):
                raw_line = label_lines[order[slot]]
                fields = raw_line.decode('utf-8').strip("\n").split("\t")
                img_path = self.img_set_dir + "/" + fields[0]
                picture = cv2.imread(img_path)
                if picture is None:
                    logger.info("{} does not exist!".format(img_path))
                    continue
                if picture.shape[-1] == 1 or len(list(picture.shape)) == 2:
                    picture = cv2.cvtColor(picture, cv2.COLOR_GRAY2BGR)

                label = fields[1]
                outs = process_image(img=picture,
                                     image_shape=self.image_shape,
                                     label=label,
                                     char_ops=self.char_ops,
                                     loss_type=self.loss_type,
                                     max_text_length=self.max_text_length,
                                     distort=self.use_distort)
                if outs is not None:
                    yield outs
예제 #12
0
def main():
    """Run end-to-end text spotting and write the results to a file.

    Reads the module-level ``config``. For every image in
    ``Global.infer_img`` one line ``<path>\\t<json>`` is appended to
    ``Global.save_res_path`` and ``draw_e2e_res`` is called with the
    detected polygons and texts.
    """
    global_config = config['Global']

    # build model
    model = build_model(config['Architecture'])

    init_model(config, model, logger)

    # build post process
    post_process_class = build_post_process(config['PostProcess'],
                                            global_config)

    # create data ops: label ops are dropped; image and shape are kept
    transforms = []
    for op in config['Eval']['dataset']['transforms']:
        op_name = list(op)[0]
        if 'Label' in op_name:
            continue
        elif op_name == 'KeepKeys':
            op[op_name]['keep_keys'] = ['image', 'shape']
        transforms.append(op)

    ops = create_operators(transforms, global_config)

    save_res_path = config['Global']['save_res_path']
    save_dir = os.path.dirname(save_res_path)
    # A bare file name has an empty dirname, and os.makedirs('') raises;
    # exist_ok also removes the check-then-create race.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    model.eval()
    with open(save_res_path, "wb") as fout:
        for file in get_image_file_list(config['Global']['infer_img']):
            logger.info("infer_img: {}".format(file))
            with open(file, 'rb') as f:
                data = {'image': f.read()}
            batch = transform(data, ops)
            images = np.expand_dims(batch[0], axis=0)
            shape_list = np.expand_dims(batch[1], axis=0)
            images = paddle.to_tensor(images)
            preds = model(images)
            post_result = post_process_class(preds, shape_list)
            points, strs = post_result['points'], post_result['texts']
            # write result (loop variable renamed so it no longer shadows
            # the builtin ``str``)
            dt_boxes_json = []
            for poly, text in zip(points, strs):
                tmp_json = {"transcription": text}
                tmp_json['points'] = poly.tolist()
                dt_boxes_json.append(tmp_json)
            otstr = file + "\t" + json.dumps(dt_boxes_json) + "\n"
            fout.write(otstr.encode())
            src_img = cv2.imread(file)
            draw_e2e_res(points, strs, config, src_img, file)
    logger.info("success!")
예제 #13
0
def main(args):
    """Detect and recognize text in each image, print and visualise results.

    For every readable image under ``args.image_dir`` the prediction time
    and all texts scoring at least 0.5 are printed, and an annotated copy
    of the image is written to ``./results``.

    Args:
        args: Parsed options. ``image_dir`` and ``vis_font_path`` are read
            here; the whole object is forwarded to ``TextSystem``.
    """
    image_file_list = get_image_file_list(args.image_dir)
    text_sys = TextSystem(args)
    is_visualize = True
    font_path = args.vis_font_path
    drop_score = 0.5
    draw_img_save = "./results"

    for image_file in image_file_list:
        img, flag = check_and_read_gif(image_file)
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue
        starttime = time.time()
        dt_boxes, rec_res = text_sys(img)
        elapse = time.time() - starttime
        print("Predict time of %s: %.3fs" % (image_file, elapse))

        # Only the results that pair up with a detected box are printed.
        for text, score in rec_res[:len(dt_boxes)]:
            if score >= drop_score:
                print("%s, %.3f" % (text, score))

        if is_visualize:
            image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            txts = [res[0] for res in rec_res]
            scores = [res[1] for res in rec_res]

            draw_img = draw_ocr_box_txt(image,
                                        dt_boxes,
                                        txts,
                                        scores,
                                        drop_score=drop_score,
                                        font_path=font_path)
            # exist_ok avoids the check-then-create race of the old code.
            os.makedirs(draw_img_save, exist_ok=True)
            out_path = os.path.join(draw_img_save,
                                    os.path.basename(image_file))
            # Reverse the channel axis before handing the array to OpenCV.
            cv2.imwrite(out_path, draw_img[:, :, ::-1])
            print("The visualized image saved in {}".format(out_path))
예제 #14
0
def main():
    """Command-line entry point: run OCR on every image and print results.

    Parses the command line, lists the images under ``args.image_dir`` and
    prints each path followed by its OCR result lines. The engine is
    constructed with its default settings; only ``det`` and ``rec`` are
    forwarded per call.
    """
    # for com
    args = parse_args()
    image_file_list = get_image_file_list(args.image_dir)
    if not image_file_list:
        logger.error('no images find in {}'.format(args.image_dir))
        return
    ocr_engine = PaddleOCR()
    for img_path in image_file_list:
        print(img_path)
        result = ocr_engine.ocr(img_path, det=args.det, rec=args.rec)
        # Guard against a missing result instead of failing on iteration.
        if result is None:
            continue
        for line in result:
            print(line)
예제 #15
0
def main(args):
    """Run detection + recognition on each image and print a label mapping.

    For every readable image the texts scoring at least 0.5 are printed,
    a ``{text: box}`` dict of those results is printed, and an annotated
    image is written to ``./inference_results/``.

    Args:
        args: Parsed options. ``image_dir``, ``use_gpu`` and
            ``enable_mkldnn`` are read here; the whole object is forwarded
            to ``TextSystem``.
    """
    image_paths = get_image_file_list(args.image_dir)
    text_sys = TextSystem(args)
    is_visualize = True
    handled = 0
    for path in image_paths:
        img = cv2.imread(path)
        if img is None:
            logger.info("error in loading image:{}".format(path))
            continue
        began = time.time()
        handled += 1
        # On CPU with MKL-DNN enabled the system is rebuilt every 30 images.
        if not args.use_gpu and args.enable_mkldnn and handled % 30 == 0:
            text_sys = TextSystem(args)
        dt_boxes, rec_res = text_sys(img)
        elapse = time.time() - began
        print("Predict time of %s: %.3fs" % (path, elapse))

        kept_boxes = []
        kept_texts = []
        for position in range(len(dt_boxes)):
            text, score = rec_res[position]
            if score >= 0.5:
                print("%s, %.3f" % (text, score))
                kept_boxes.append(dt_boxes[position])
                kept_texts.append(text)
        label_dic = dict(zip(kept_texts, kept_boxes))
        print(label_dic)

        if not is_visualize:
            continue
        image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        txts = [entry[0] for entry in rec_res]
        scores = [entry[1] for entry in rec_res]

        draw_img = draw_ocr(image,
                            dt_boxes,
                            txts,
                            scores,
                            draw_txt=True,
                            drop_score=0.5)
        draw_img_save = "./inference_results/"
        if not os.path.exists(draw_img_save):
            os.makedirs(draw_img_save)
        out_path = os.path.join(draw_img_save, os.path.basename(path))
        cv2.imwrite(out_path, draw_img[:, :, ::-1])
        print("The visualized image saved in {}".format(out_path))
예제 #16
0
        def sample_iter_reader():
            """Yield processed samples for inference or from LMDB datasets.

            When the mode is not ``'train'`` and ``self.infer_img`` is set,
            each listed image is read, converted to BGR if grayscale and
            yielded after ``process_image(..., infer_mode=True)``.
            Otherwise the LMDB datasets are read round-robin, this worker
            taking every ``self.num_workers``-th index starting at
            ``1 + process_id``, until every dataset is exhausted. The
            datasets are closed when iteration ends for any reason.
            """
            if self.mode != 'train' and self.infer_img is not None:
                image_file_list = get_image_file_list(self.infer_img)
                for single_img in image_file_list:
                    img = cv2.imread(single_img)
                    if img.shape[-1] == 1 or len(list(img.shape)) == 2:
                        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
                    norm_img = process_image(
                        img=img,
                        image_shape=self.image_shape,
                        char_ops=self.char_ops,
                        tps=self.use_tps,
                        infer_mode=True)
                    yield norm_img
            else:
                lmdb_sets = self.load_hierarchical_lmdb_dataset()
                # try/finally so the datasets are also released when the
                # consumer stops early or an exception escapes the loop.
                try:
                    if process_id == 0:
                        self.print_lmdb_sets_info(lmdb_sets)
                    cur_index_sets = [1 + process_id] * len(lmdb_sets)
                    while True:
                        finish_read_num = 0
                        for dataset_idx in range(len(lmdb_sets)):
                            cur_index = cur_index_sets[dataset_idx]
                            if cur_index > lmdb_sets[dataset_idx][
                                    'num_samples']:
                                finish_read_num += 1
                                continue
                            sample_info = self.get_lmdb_sample_info(
                                lmdb_sets[dataset_idx]['txn'], cur_index)
                            cur_index_sets[dataset_idx] += self.num_workers
                            if sample_info is None:
                                continue
                            img, label = sample_info
                            outs = process_image(
                                img=img,
                                image_shape=self.image_shape,
                                label=label,
                                char_ops=self.char_ops,
                                loss_type=self.loss_type,
                                max_text_length=self.max_text_length,
                                distort=self.use_distort)
                            if outs is None:
                                continue
                            yield outs

                        # Stop once every dataset reported exhaustion in
                        # the same pass.
                        if finish_read_num == len(lmdb_sets):
                            break
                finally:
                    self.close_lmdb_dataset(lmdb_sets)
예제 #17
0
def main():
    """Run recognition inference on every image in ``Global.infer_img``.

    Reads the module-level ``config``, builds the post-processor, model and
    preprocessing operators, then logs each post-processed result.
    """
    global_config = config['Global']

    # build post process
    post_process_class = build_post_process(config['PostProcess'],
                                            global_config)

    # build model: size the head from the post-processor's character set
    # when it exposes one
    if hasattr(post_process_class, 'character'):
        charset = getattr(post_process_class, 'character')
        config['Architecture']["Head"]['out_channels'] = len(charset)

    model = build_model(config['Architecture'])
    init_model(config, model, logger)

    # create data ops: drop label ops and adjust the remaining ones
    transforms = []
    for op in config['Eval']['dataset']['transforms']:
        op_name = list(op)[0]
        if 'Label' in op_name:
            continue
        if op_name in ['RecResizeImg']:
            op[op_name]['infer_mode'] = True
        elif op_name == 'KeepKeys':
            op[op_name]['keep_keys'] = ['image']
        transforms.append(op)
    global_config['infer_mode'] = True
    ops = create_operators(transforms, global_config)

    model.eval()
    for file in get_image_file_list(config['Global']['infer_img']):
        logger.info("infer_img: {}".format(file))
        with open(file, 'rb') as f:
            data = {'image': f.read()}
        batch = transform(data, ops)

        # Add a leading axis of size 1 before converting to a tensor.
        images = paddle.to_tensor(np.expand_dims(batch[0], axis=0))
        post_result = post_process_class(model(images))
        for rec_result in post_result:
            logger.info('\t result: {}'.format(rec_result))
    logger.info("success!")
예제 #18
0
def main():
    """Command-line entry point: run OCR on every image and print results.

    Parses the command line, builds the engine from all parsed options and
    prints each image path followed by its result lines, if any.
    """
    # for com
    args = parse_args(mMain=True)
    image_file_list = get_image_file_list(args.image_dir)
    image_count = len(image_file_list)
    if image_count == 0:
        logger.error('no images find in {}'.format(args.image_dir))
        return

    engine_options = vars(args)
    ocr_engine = PaddleOCR(**engine_options)
    for img_path in image_file_list:
        print(img_path)
        result = ocr_engine.ocr(img_path,
                                det=args.det,
                                rec=args.rec,
                                cls=args.use_angle_cls)
        if result is None:
            continue
        for line in result:
            print(line)
예제 #19
0
def main(args):
    """Detect and recognize text in each image, log and visualise results.

    For every readable image under ``args.image_dir`` the prediction time
    and every recognized text with its score are logged, and an annotated
    image is written to ``./inference_results/``. Images read by the GIF
    helper are saved with the last three characters of their name replaced
    by ``png``.

    Args:
        args: Parsed options. ``image_dir``, ``vis_font_path`` and
            ``drop_score`` are read here; the whole object is forwarded to
            ``TextSystem``.
    """
    image_paths = get_image_file_list(args.image_dir)
    text_sys = TextSystem(args)
    is_visualize = True
    font_path = args.vis_font_path
    drop_score = args.drop_score
    for path in image_paths:
        img, flag = check_and_read_gif(path)
        if not flag:
            img = cv2.imread(path)
        if img is None:
            logger.info("error in loading image:{}".format(path))
            continue
        began = time.time()
        dt_boxes, rec_res = text_sys(img)
        elapse = time.time() - began
        logger.info("Predict time of %s: %.3fs" % (path, elapse))

        for text, score in rec_res:
            logger.info("{}, {:.3f}".format(text, score))

        if not is_visualize:
            continue
        image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        txts = [entry[0] for entry in rec_res]
        scores = [entry[1] for entry in rec_res]

        draw_img = draw_ocr_box_txt(image,
                                    dt_boxes,
                                    txts,
                                    scores,
                                    drop_score=drop_score,
                                    font_path=font_path)
        draw_img_save = "./inference_results/"
        if not os.path.exists(draw_img_save):
            os.makedirs(draw_img_save)
        # Swap the trailing three characters for "png" when the GIF helper
        # supplied the image.
        save_name = path[:-3] + "png" if flag else path
        out_path = os.path.join(draw_img_save, os.path.basename(save_name))
        cv2.imwrite(out_path, draw_img[:, :, ::-1])
        logger.info("The visualized image saved in {}".format(out_path))
예제 #20
0
def main(args):
    """Run the table structurer on each image and log result and timing.

    Args:
        args: Parsed options. ``image_dir`` is read here; the whole object
            is forwarded to ``TableStructurer``.
    """
    image_paths = get_image_file_list(args.image_dir)
    table_structurer = TableStructurer(args)
    processed = 0
    # Accumulated over all images except the first processed one; it is
    # not reported anywhere in this function.
    total_time = 0
    for path in image_paths:
        picture, is_gif = check_and_read_gif(path)
        if not is_gif:
            picture = cv2.imread(path)
        if picture is None:
            logger.info("error in loading image:{}".format(path))
            continue
        structure_res, elapse = table_structurer(picture)

        logger.info("result: {}".format(structure_res))

        if processed > 0:
            total_time += elapse
        processed += 1
        logger.info("Predict time of {}: {}".format(path, elapse))
예제 #21
0
def main(args):
    """Recognize text in every loadable image of ``args.image_dir`` in batches.

    Images are accumulated until ``args.rec_batch_num`` of them are pending
    (or the end of the file list is reached), then passed to the recognizer
    in one call. Each prediction is logged next to its file name, and the
    summed prediction time is logged at the end.

    Args:
        args: Parsed options; ``args.image_dir`` and ``args.rec_batch_num``
            are read here and the whole object is handed to
            ``TextRecognizer``.
    """
    image_file_list = get_image_file_list(args.image_dir)
    text_recognizer = TextRecognizer(args)
    total_run_time = 0.0
    total_images_num = 0
    valid_image_file_list = []
    img_list = []
    last_idx = len(image_file_list) - 1
    for idx, image_file in enumerate(image_file_list):
        img, flag = check_and_read_gif(image_file)
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
        else:
            valid_image_file_list.append(image_file)
            img_list.append(img)

        # The flush check must also run when the current image failed to
        # load; otherwise an unreadable *last* file would leave the pending
        # batch unprocessed.
        batch_ready = (len(img_list) >= args.rec_batch_num or
                       idx == last_idx)
        if not img_list or not batch_ready:
            continue

        try:
            rec_res, predict_time = text_recognizer(img_list)
            total_run_time += predict_time
        except Exception:
            logger.info(traceback.format_exc())
            logger.info(
                "ERROR!!!! \n"
                "Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n"
                "If your model has tps module:  "
                "TPS does not support variable shape.\n"
                "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' "
            )
            exit()
        for ino in range(len(img_list)):
            logger.info("Predicts of {}:{}".format(valid_image_file_list[
                ino], rec_res[ino]))
        total_images_num += len(valid_image_file_list)
        # Start a fresh batch.
        valid_image_file_list = []
        img_list = []
    logger.info("Total predict time for {} images, cost: {:.3f}".format(
        total_images_num, total_run_time))
    def __call__(self, mode):
        """Build a generator function that yields batches of processed images.

        In ``"test"`` mode the image paths come from
        ``get_image_file_list(self.params['infer_img'])``. In any other mode
        they are built by joining ``self.params['img_set_dir']`` with the
        first tab-separated field of each line of
        ``self.params['label_file_path']``.

        Args:
            mode: Reader mode; only compared against ``"test"``.

        Returns:
            A zero-argument generator function. Each yielded batch is a list
            of at most ``test_batch_size_per_card`` entries, where an entry
            is the output of the process function with the image path
            appended.
        """
        process_function = create_module(self.params['process_function'])(
            self.params)
        batch_size = self.params['test_batch_size_per_card']

        if mode == "test":
            img_list = get_image_file_list(self.params['infer_img'])
        else:
            img_set_dir = self.params['img_set_dir']
            label_file = self.params['label_file_path']
            with open(label_file, "rb") as fin:
                img_list = [
                    os.path.join(
                        img_set_dir,
                        raw_line.decode().strip("\n").split("\t")[0])
                    for raw_line in fin.readlines()
                ]

        def batch_iter_reader():
            pending = []
            for img_path in img_list:
                img = cv2.imread(img_path)
                if img is None:
                    logger.info("{} does not exist!".format(img_path))
                    continue
                # Expand single-channel images to three channels.
                if len(img.shape) == 2 or img.shape[2] == 1:
                    img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
                outs = process_function(img)
                outs.append(img_path)
                pending.append(outs)
                if len(pending) == batch_size:
                    yield pending
                    pending = []
            # Emit the trailing partial batch, if any.
            if pending:
                yield pending

        return batch_iter_reader
예제 #23
0
def main(image_path):
    """Send every image under ``image_path`` to the local OCR HTTP service.

    Each file is read as raw bytes, base64-encoded with ``cv2_to_base64``
    and POSTed as JSON to ``http://127.0.0.1:9292/ocr/prediction``. The
    ``result`` field of the response and the request latency are logged,
    and a visualisation is written to ``./server_results/`` when
    ``draw_server_result`` returns an image. The average request time is
    logged at the end.

    Args:
        image_path: Path passed to ``get_image_file_list`` to collect the
            image files.
    """
    image_file_list = get_image_file_list(image_path)
    is_visualize = True
    headers = {"Content-type": "application/json"}
    url = "http://127.0.0.1:9292/ocr/prediction"
    cnt = 0
    total_time = 0
    for image_file in image_file_list:
        # Read through a context manager so the handle is always closed.
        # read() never returns None, so unreadable and empty files are
        # detected explicitly here.
        try:
            with open(image_file, 'rb') as image_fp:
                img = image_fp.read()
        except (IOError, OSError):
            img = None
        if not img:
            logger.info("error in loading image:{}".format(image_file))
            continue

        # Send the HTTP request
        starttime = time.time()
        data = {"feed": [{"image": cv2_to_base64(img)}], "fetch": ["res"]}
        r = requests.post(url=url, headers=headers, data=json.dumps(data))
        elapse = time.time() - starttime
        total_time += elapse
        logger.info("Predict time of %s: %.3fs" % (image_file, elapse))
        res = r.json()['result']
        logger.info(res)

        if is_visualize:
            draw_img = draw_server_result(image_file, res)
            if draw_img is not None:
                draw_img_save = "./server_results/"
                if not os.path.exists(draw_img_save):
                    os.makedirs(draw_img_save)
                save_path = os.path.join(draw_img_save,
                                         os.path.basename(image_file))
                # Reverse the channel axis before handing the array to OpenCV.
                cv2.imwrite(save_path, draw_img[:, :, ::-1])
                logger.info("The visualized image saved in {}".format(
                    save_path))
        cnt += 1
        if cnt % 100 == 0:
            logger.info("{} processed".format(cnt))
    # Avoid ZeroDivisionError when no image was processed.
    if cnt > 0:
        logger.info("avg time cost: {}".format(float(total_time) / cnt))
    else:
        logger.info("no image was processed, avg time cost is unavailable")
예제 #24
0
def main(args):
    """Classify the first ten images found in ``args.image_dir``.

    Unreadable images are logged and skipped. The classifier is called once
    on all loaded images; each prediction and the total prediction time are
    printed to stdout. Any exception raised by the classifier is printed
    and the process exits.

    Args:
        args: Parsed options; ``args.image_dir`` is read here and the whole
            object is handed to ``TextClassifier``.
    """
    all_files = get_image_file_list(args.image_dir)
    text_classifier = TextClassifier(args)
    loaded_names = []
    loaded_imgs = []
    # Only the first ten files are considered.
    for path in all_files[:10]:
        picture, is_gif = check_and_read_gif(path)
        if not is_gif:
            picture = cv2.imread(path)
        if picture is None:
            logger.info("error in loading image:{}".format(path))
            continue
        loaded_names.append(path)
        loaded_imgs.append(picture)

    try:
        loaded_imgs, cls_res, predict_time = text_classifier(loaded_imgs)
    except Exception as err:
        print(err)
        exit()

    num_imgs = len(loaded_imgs)
    for pos in range(num_imgs):
        print("Predicts of %s:%s" % (loaded_names[pos], cls_res[pos]))
    print("Total predict time for %d images:%.3f" % (num_imgs, predict_time))
예제 #25
0
def main(args):
    """Classify every loadable image in ``args.image_dir`` and log the results.

    Unreadable images are logged and skipped. The classifier is called once
    on all loaded images. If it raises, the traceback and the exception are
    logged and the process exits.

    Args:
        args: Parsed options; ``args.image_dir`` is read here and the whole
            object is handed to ``TextClassifier``.
    """
    candidate_files = get_image_file_list(args.image_dir)
    text_classifier = TextClassifier(args)
    kept_files = []
    kept_imgs = []
    for candidate in candidate_files:
        picture, is_gif = check_and_read_gif(candidate)
        if not is_gif:
            picture = cv2.imread(candidate)
        if picture is None:
            logger.info("error in loading image:{}".format(candidate))
            continue
        kept_files.append(candidate)
        kept_imgs.append(picture)

    try:
        # The returned time is unpacked but not used in this function.
        kept_imgs, cls_res, predict_time = text_classifier(kept_imgs)
    except Exception as err:
        logger.info(traceback.format_exc())
        logger.info(err)
        exit()

    for pos in range(len(kept_imgs)):
        message = "Predicts of {}:{}".format(kept_files[pos], cls_res[pos])
        logger.info(message)
예제 #26
0
def main():
    """Command-line entry point: run OCR on a URL or on local image files.

    When ``args.image_dir`` starts with ``'http'`` it is downloaded to
    ``tmp.jpg``; otherwise it is expanded with ``get_image_file_list``.
    Every result line returned by the engine is logged.
    """
    # for cmd
    args = parse_args(mMain=True)
    source = args.image_dir
    if not source.startswith('http'):
        image_paths = get_image_file_list(args.image_dir)
    else:
        download_with_progressbar(source, 'tmp.jpg')
        image_paths = ['tmp.jpg']

    if len(image_paths) == 0:
        logger.error('no images find in {}'.format(args.image_dir))
        return

    ocr_engine = PaddleOCR(**(args.__dict__))
    banner = '*' * 10
    for img_path in image_paths:
        logger.info('{}{}{}'.format(banner, img_path, banner))
        result = ocr_engine.ocr(
            img_path, det=args.det, rec=args.rec, cls=args.use_angle_cls)
        if result is None:
            continue
        for line in result:
            logger.info(line)
예제 #27
0
def main():
    """Command-line entry point for the OCR and structure pipelines.

    ``args.image_dir`` is downloaded to ``tmp.jpg`` when ``is_link`` accepts
    it, otherwise expanded with ``get_image_file_list``. ``args.type``
    selects the engine: ``'ocr'`` logs each result line, ``'structure'``
    saves the result via ``save_structure_res`` and logs each item without
    its ``'img'`` entry.

    Raises:
        NotImplementedError: If ``args.type`` is neither ``'ocr'`` nor
            ``'structure'``.
    """
    # for cmd
    args = parse_args(mMain=True)
    source = args.image_dir
    if not is_link(source):
        image_paths = get_image_file_list(args.image_dir)
    else:
        download_with_progressbar(source, 'tmp.jpg')
        image_paths = ['tmp.jpg']

    if len(image_paths) == 0:
        logger.error('no images find in {}'.format(args.image_dir))
        return

    if args.type == 'ocr':
        engine = PaddleOCR(**(args.__dict__))
    elif args.type == 'structure':
        engine = PPStructure(**(args.__dict__))
    else:
        raise NotImplementedError

    banner = '*' * 10
    for img_path in image_paths:
        # File name up to its first dot.
        base_name = os.path.basename(img_path)
        img_name = base_name.split('.')[0]
        logger.info('{}{}{}'.format(banner, img_path, banner))

        if args.type == 'ocr':
            result = engine.ocr(
                img_path, det=args.det, rec=args.rec, cls=args.use_angle_cls)
            if result is None:
                continue
            for line in result:
                logger.info(line)
        elif args.type == 'structure':
            result = engine(img_path)
            save_structure_res(result, args.output, img_name)

            for item in result:
                # Drop the 'img' entry before logging the item.
                item.pop('img')
                logger.info(item)
예제 #28
0
    def __call__(self, mode):
        """Build a generator function that yields batches of processed images.

        In ``"test"`` mode the image paths come from
        ``get_image_file_list(self.params['single_img_path'])``. In any other
        mode they are built from ``self.params['img_set_dir']`` and the
        first tab-separated field of each line of
        ``self.params['label_file_path']``.

        Args:
            mode: Reader mode; only compared against ``"test"``.

        Returns:
            A zero-argument generator function. Each yielded batch is a list
            of at most ``test_batch_size_per_card`` entries, where an entry
            is the output of the process function with the image path
            appended.
        """
        process_function = create_module(self.params['process_function'])(
            self.params)
        batch_size = self.params['test_batch_size_per_card']

        img_list = []
        if mode != "test":
            img_set_dir = self.params['img_set_dir']
            img_name_list_path = self.params['label_file_path']
            with open(img_name_list_path, "rb") as fin:
                lines = fin.readlines()
                for line in lines:
                    img_name = line.decode().strip("\n").split("\t")[0]
                    img_path = img_set_dir + "/" + img_name
                    # Store the path string only. The reader below passes
                    # each entry straight to cv2.imread and concatenates it
                    # into a log message, so a [path, name] pair would fail.
                    img_list.append(img_path)
        else:
            img_path = self.params['single_img_path']
            img_list = get_image_file_list(img_path)

        def batch_iter_reader():
            batch_outs = []
            for img_path in img_list:
                img = cv2.imread(img_path)
                if img is None:
                    logger.info("load image error:" + img_path)
                    continue
                outs = process_function(img)
                outs.append(img_path)
                batch_outs.append(outs)
                if len(batch_outs) == batch_size:
                    yield batch_outs
                    batch_outs = []
            # Emit the trailing partial batch, if any.
            if len(batch_outs) != 0:
                yield batch_outs

        return batch_iter_reader
예제 #29
0
            preds['f_border'] = outputs[0]
            preds['f_char'] = outputs[1]
            preds['f_direction'] = outputs[2]
            preds['f_score'] = outputs[3]
        else:
            raise NotImplementedError
        post_result = self.postprocess_op(preds, shape_list)
        points, strs = post_result['points'], post_result['strs']
        dt_boxes = self.filter_tag_det_res_only_clip(points, ori_im.shape)
        elapse = time.time() - starttime
        return dt_boxes, strs, elapse


if __name__ == "__main__":
    args = utility.parse_args()
    image_file_list = get_image_file_list(args.image_dir)
    text_detector = TextE2E(args)
    count = 0
    total_time = 0
    draw_img_save = "./inference_results"
    if not os.path.exists(draw_img_save):
        os.makedirs(draw_img_save)
    for image_file in image_file_list:
        img, flag = check_and_read_gif(image_file)
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue
        points, strs, elapse = text_detector(img)
        if count > 0:
예제 #30
0
def main(args):
    """Run the OCR system on a random sample of this process's image shard.

    The file list from ``args.image_dir`` is sharded with
    ``args.process_id`` / ``args.total_process_num``, and up to
    ``loop_count`` images are sampled from the shard. Recognized texts and
    scores are logged. When ``args.is_save`` is set, the texts of each image
    are written to ``./final_results/20/<image stem>.txt``; images whose
    text file already exists are skipped.

    Args:
        args: Parsed options; reads ``image_dir``, ``process_id``,
            ``total_process_num``, ``vis_font_path``, ``drop_score`` and
            ``is_save``, and is handed to ``TextSystem``.
    """
    image_file_list = get_image_file_list(args.image_dir)
    image_file_list = image_file_list[args.process_id::args.total_process_num]
    text_sys = TextSystem(args)
    is_visualize = False
    font_path = args.vis_font_path
    drop_score = args.drop_score
    num = 1
    loop_count = 20
    # random.sample raises ValueError when k exceeds the population size,
    # so cap the sample at the number of available images.
    sample_size = min(loop_count, len(image_file_list))
    selected_imgs = random.sample(image_file_list, k=sample_size)
    for image_file in selected_imgs:
        img, flag = check_and_read_gif(image_file)
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            # logger.info("error in loading image:{}".format(image_file))
            continue
        starttime = time.time()
        dt_boxes, rec_res = text_sys(img)
        elapse = time.time() - starttime
        # logger.info("Predict time of %s: %.3fs" % (image_file, elapse))

        for text, score in rec_res:
            logger.info("{}, {:.3f}".format(text, score))

        if args.is_save:
            dataset_dir = './final_results/20/'
            txts = [text for text, _ in rec_res]
            # Write the recognized texts into the per-image result file.
            print(image_file)
            path = os.path.join(
                dataset_dir,
                os.path.splitext(os.path.basename(image_file))[0])
            if os.path.exists(path + '.txt'):
                continue
            # Make sure the output directory exists before opening the file.
            os.makedirs(dataset_dir, exist_ok=True)
            # The context manager guarantees the file is flushed and closed.
            with open(path + '.txt', 'w', encoding="utf-8") as res_txt:
                for item in txts:
                    res_txt.write(item + '\n')

        if is_visualize:
            image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            boxes = dt_boxes
            txts = [text for text, _ in rec_res]
            scores = [score for _, score in rec_res]

            draw_img = draw_ocr_box_txt(image,
                                        boxes,
                                        txts,
                                        scores,
                                        drop_score=drop_score,
                                        font_path=font_path)
            draw_img_save = "./inference_results/"
            if not os.path.exists(draw_img_save):
                os.makedirs(draw_img_save)
            save_path = os.path.join(draw_img_save,
                                     os.path.basename(image_file))
            cv2.imwrite(save_path, draw_img[:, :, ::-1])
            logger.info("The visualized image saved in {}".format(save_path))

        num = num + 1
        if num > loop_count:
            break