Example #1
def main(args):
    image_file_list = get_image_file_list(args.image_dir)
    image_file_list = image_file_list[args.process_id::args.total_process_num]
    os.makedirs(args.output, exist_ok=True)

    text_sys = TableSystem(args)
    img_num = len(image_file_list)
    for i, image_file in enumerate(image_file_list):
        logger.info("[{}/{}] {}".format(i, img_num, image_file))
        img, flag = check_and_read_gif(image_file)
        excel_path = os.path.join(
            args.output,
            os.path.basename(image_file).split('.')[0] + '.xlsx')
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.error("error in loading image:{}".format(image_file))
            continue
        starttime = time.time()
        pred_html = text_sys(img)

        to_excel(pred_html, excel_path)
        logger.info('excel saved to {}'.format(excel_path))
        logger.info(pred_html)
        elapse = time.time() - starttime
        logger.info("Predict time : {:.3f}s".format(elapse))
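Every example in this collection shares the same loading idiom: call check_and_read_gif first (it returns a decoded frame and True for GIF inputs, and None and False otherwise) and fall back to cv2.imread for ordinary image files. Below is a minimal sketch of that idiom factored into a standalone helper; the helper name load_image is illustrative, and the import path assumes the usual PaddleOCR layout where check_and_read_gif lives in ppocr.utils.utility:

import cv2
from ppocr.utils.utility import check_and_read_gif

def load_image(image_file):
    # GIFs are decoded by check_and_read_gif, which returns (frame, True) on success
    img, is_gif = check_and_read_gif(image_file)
    if not is_gif:
        # non-GIF files go through the regular OpenCV decoder
        img = cv2.imread(image_file)
    return img  # None if the file could not be decoded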
Example #2
    def get_data(self,
                 label_infor,
                 is_aug=False,
                 is_crop=False,
                 is_resize=False,
                 is_shrink=False,
                 is_border=False):
        img_path, gt_label = self.convert_label_infor(label_infor)
        imgvalue, flag = check_and_read_gif(img_path)
        if not flag:
            imgvalue = cv2.imread(img_path)
        if imgvalue is None:
            logger.info("{} does not exist!".format(img_path))
            return None
        if len(list(imgvalue.shape)) == 2 or imgvalue.shape[2] == 1:
            imgvalue = cv2.cvtColor(imgvalue, cv2.COLOR_GRAY2BGR)
        data = self.make_data_dict(imgvalue, gt_label)
        if is_aug: data = AugmentData(data)
        if is_crop: data = RandomCropData(data, self.image_shape[1:])
        if is_resize: data = ResizeData(data, self.image_shape[1:])
        if is_shrink: data = MakeShrinkMap(data)
        if is_border: data = MakeBorderMap(data)
        # data = self.NormalizeImage(data)
        # data = self.FilterKeys(data)

        return data
Example #3
def main(args):
    args.image_dir = '/home/duycuong/PycharmProjects/research_py3/MC_OCR/mc_ocr/text_detector/PaddleOCR/doc/imgs_words_en/word_10.png'
    args.rec_char_dict_path = '/home/duycuong/PycharmProjects/research_py3/MC_OCR/mc_ocr/text_detector/PaddleOCR/ppocr/utils/dict/japan_dict.txt'
    args.rec_model_dir = '/home/duycuong/PycharmProjects/research_py3/MC_OCR/mc_ocr/text_detector/PaddleOCR/inference/japan_mobile_v2.0_rec_infer'
    image_file_list = get_image_file_list(args.image_dir)
    text_recognizer = TextRecognizer(args)
    valid_image_file_list = []
    img_list = []
    for image_file in image_file_list:
        img, flag = check_and_read_gif(image_file)
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue
        valid_image_file_list.append(image_file)
        img_list.append(img)

    try:
        rec_res, predict_time = text_recognizer(img_list)
    except Exception:
        logger.info(traceback.format_exc())
        logger.info(
            "ERROR!!!! \n"
            "Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n"
            "If your model has tps module:  "
            "TPS does not support variable shape.\n"
            "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' "
        )
        exit()
    for ino in range(len(img_list)):
        logger.info("Predicts of {}:{}".format(valid_image_file_list[ino],
                                               rec_res[ino]))
    logger.info("Total predict time for {} images, cost: {:.3f}".format(
        len(img_list), predict_time))
Example #4
def main(args):
    image_file_list = get_image_file_list(args.image_dir)
    image_file_list = image_file_list[args.process_id::args.total_process_num]
    save_folder = args.output
    os.makedirs(save_folder, exist_ok=True)

    structure_sys = OCRSystem(args)
    img_num = len(image_file_list)
    for i, image_file in enumerate(image_file_list):
        logger.info("[{}/{}] {}".format(i, img_num, image_file))
        img, flag = check_and_read_gif(image_file)
        img_name = os.path.basename(image_file).split('.')[0]

        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.error("error in loading image:{}".format(image_file))
            continue
        starttime = time.time()
        res = structure_sys(img)
        save_structure_res(res, save_folder, img_name)
        draw_img = draw_structure_result(img, res, args.vis_font_path)
        cv2.imwrite(os.path.join(save_folder, img_name, 'show.jpg'), draw_img)
        logger.info('result save to {}'.format(
            os.path.join(save_folder, img_name)))
        elapse = time.time() - starttime
        logger.info("Predict time : {:.3f}s".format(elapse))
Example #5
def main(args):
    image_file_list = get_image_file_list(args.image_dir)
    text_recognizer = TextRecognizer(args)
    valid_image_file_list = []
    img_list = []

    # warmup 2 times
    if args.warmup:
        img = np.random.uniform(0, 255, [32, 320, 3]).astype(np.uint8)
        for i in range(2):
            res = text_recognizer([img] * int(args.rec_batch_num))

    for image_file in image_file_list:
        img, flag = check_and_read_gif(image_file)
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue
        valid_image_file_list.append(image_file)
        img_list.append(img)
    try:
        rec_res, _ = text_recognizer(img_list)

    except Exception as E:
        logger.info(traceback.format_exc())
        logger.info(E)
        exit()
    for ino in range(len(img_list)):
        logger.info("Predicts of {}:{}".format(valid_image_file_list[ino],
                                               rec_res[ino]))
    if args.benchmark:
        text_recognizer.autolog.report()
Example #6
def main(args):
    image_file_list = get_image_file_list(args.image_dir)
    text_classifier = TextClassifier(args)
    valid_image_file_list = []
    img_list = []
    for image_file in image_file_list:
        img, flag = check_and_read_gif(image_file)
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue
        valid_image_file_list.append(image_file)
        img_list.append(img)
    try:
        img_list, cls_res, predict_time = text_classifier(img_list)
    except Exception:
        logger.info(traceback.format_exc())
        logger.info(
            "ERROR!!!! \n"
            "Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n"
            "If your model has tps module:  "
            "TPS does not support variable shape.\n"
            "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ")
        exit()
    for ino in range(len(img_list)):
        logger.info("Predicts of {}:{}".format(valid_image_file_list[ino],
                                               cls_res[ino]))
    logger.info("Total predict time for {} images, cost: {:.3f}".format(
        len(img_list), predict_time))
Example #7
def main(args):

    if not args.clipboard:
        image_file_list = get_image_file_list(args.image_dir)
        text_sys = TextSystem(args)
        is_visualize = True
        font_path = args.vis_font_path
        drop_score = args.drop_score
        for image_file in image_file_list:
            img, flag = check_and_read_gif(image_file)
            if not flag:
                img = cv2.imread(image_file)
            if img is None:
                logger.info("error in loading image:{}".format(image_file))
                continue
            starttime = time.time()
            dt_boxes, rec_res = text_sys(img)
            elapse = time.time() - starttime
            logger.info("Predict time of %s: %.3fs" % (image_file, elapse))

            out_table(dt_boxes, rec_res)
    else:
        while True:
            instructions = input(
                'Extract Table From Image ("?"/"h" for help,"x" for exit).')
            ins = instructions.strip().lower()
            if ins == 'x':
                break
            try:
                call_model(args)
            except KeyboardInterrupt:
                pass
Example #8
 def ocr(self, img, det=True, rec=True):
     """
     ocr with paddleocr
     args:
         img: image for ocr; supports an ndarray, an image path, or a list of ndarrays
         det: use text detection or not; if False, only recognition is run. Default is True
         rec: use text recognition or not; if False, only detection is run. Default is True
     """
     assert isinstance(img, (np.ndarray, list, str))
     if isinstance(img, str):
         image_file = img
         img, flag = check_and_read_gif(image_file)
         if not flag:
             img = cv2.imread(image_file)
         if img is None:
             logger.error("error in loading image:{}".format(image_file))
             return None
     if det and rec:
         dt_boxes, rec_res = self.__call__(img)
         return [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
     elif det and not rec:
         dt_boxes, elapse = self.text_detector(img)
         if dt_boxes is None:
             return None
         return [box.tolist() for box in dt_boxes]
     else:
         if not isinstance(img, list):
             img = [img]
         rec_res, elapse = self.text_recognizer(img)
         return rec_res
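A hedged usage sketch for the ocr method above, assuming an already constructed engine instance; the variable name engine and the sample paths are illustrative and not part of the example:

# detection + recognition (default): a list of [box, (text, score)] entries
result = engine.ocr('doc/imgs/sample.jpg')

# detection only: a list of quadrilateral text boxes
boxes = engine.ocr('doc/imgs/sample.jpg', rec=False)

# recognition only, e.g. on an already cropped text line
texts = engine.ocr('doc/imgs/word.jpg', det=False)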
Example #9
def main(args):
    image_file_list = get_image_file_list(args.image_dir)
    text_recognizer = TextRecognizer(args)
    valid_image_file_list = []
    img_list = []
    for image_file in image_file_list:
        img, flag = check_and_read_gif(image_file)
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue
        valid_image_file_list.append(image_file)
        img_list.append(img)

    try:
        rec_res, predict_time = text_recognizer(img_list)
    except Exception as e:
        print(e)
        logger.info(
            "ERROR!!!! \n"
            "Please read the FAQ: https://github.com/PaddlePaddle/PaddleOCR#faq \n"
            "If your model has tps module:  "
            "TPS does not support variable shape.\n"
            "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' "
        )
        exit()
    for ino in range(len(img_list)):
        print("Predicts of %s:%s" % (valid_image_file_list[ino], rec_res[ino]))
    print("Total predict time for %d images:%.3f" %
          (len(img_list), predict_time))
Example #10
def main(args):
    image_file_list = get_image_file_list(args.image_dir)
    text_sys = TextSystem(args)
    is_visualize = True
    font_path = args.vis_font_path

    for image_file in image_file_list:
        img, flag = check_and_read_gif(image_file)
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue
        starttime = time.time()
        dt_boxes, rec_res = text_sys(img)
        elapse = time.time() - starttime
        print("Predict time of %s: %.3fs" % (image_file, elapse))

        drop_score = 0.5
        dt_num = len(dt_boxes)
        for dno in range(dt_num):
            text, score = rec_res[dno]
            if score >= drop_score:
                text_str = "%s, %.3f" % (text, score)
                print(text_str)

        if is_visualize:
            image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            boxes = dt_boxes
            txts = [rec_res[i][0] for i in range(len(rec_res))]
            scores = [rec_res[i][1] for i in range(len(rec_res))]

            draw_img = draw_ocr_box_txt(image,
                                        boxes,
                                        txts,
                                        scores,
                                        drop_score=drop_score,
                                        font_path=font_path)
            draw_img_save = "./results"
            if not os.path.exists(draw_img_save):
                os.makedirs(draw_img_save)
            cv2.imwrite(
                os.path.join(draw_img_save, os.path.basename(image_file)),
                draw_img[:, :, ::-1])
            print("The visualized image saved in {}".format(
                os.path.join(draw_img_save, os.path.basename(image_file))))
Example #11
    def ocr(self, img, det=True, rec=True, cls=False):
        """
        ocr with paddleocr
        args:
            img: image for ocr; supports an ndarray, an image path, or a list of ndarrays
            det: use text detection or not; if False, only recognition is run. Default is True
            rec: use text recognition or not; if False, only detection is run. Default is True
            cls: use the angle classifier or not; only effective when it was initialized. Default is False
        """
        assert isinstance(img, (np.ndarray, list, str))
        if isinstance(img, list) and det:
            logger.error('When the input is a list of images, det must be False')
            exit(0)
        if not cls:
            self.use_angle_cls = False
        elif cls and not self.use_angle_cls:
            logger.warning(
                'Since the angle classifier is not initialized, it will not be used during the forward process'
            )

        if isinstance(img, str):
            # download net image
            if img.startswith('http'):
                download_with_progressbar(img, 'tmp.jpg')
                img = 'tmp.jpg'
            image_file = img
            img, flag = check_and_read_gif(image_file)
            if not flag:
                with open(image_file, 'rb') as f:
                    np_arr = np.frombuffer(f.read(), dtype=np.uint8)
                    img = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
            if img is None:
                logger.error("error in loading image:{}".format(image_file))
                return None
        if isinstance(img, np.ndarray) and len(img.shape) == 2:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        if det and rec:
            dt_boxes, rec_res = self.__call__(img)
            return [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
        elif det and not rec:
            dt_boxes, elapse = self.text_detector(img)
            if dt_boxes is None:
                return None
            return [box.tolist() for box in dt_boxes]
        else:
            if not isinstance(img, list):
                img = [img]
            if self.use_angle_cls:
                img, cls_res, elapse = self.text_classifier(img)
                if not rec:
                    return cls_res
            rec_res, elapse = self.text_recognizer(img)
            return rec_res
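In the variant above, cls only takes effect when the instance was created with the angle classifier enabled (use_angle_cls); otherwise the warning branch fires and recognition runs without rotation correction. A hedged call sketch, with ocr_engine and the path as illustrative names:

# rotate 180-degree text lines before recognition, when the classifier is available
result = ocr_engine.ocr('doc/imgs/rotated.jpg', det=True, rec=True, cls=True)
for box, (text, score) in result:
    print(box, text, score)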
Example #12
 def __call__(self, label_infor):
     img_path, gt_label = self.convert_label_infor(label_infor)
     imgvalue, flag = check_and_read_gif(img_path)
     if not flag:
         imgvalue = cv2.imread(img_path)
     if imgvalue is None:
         logger.info("{} does not exist!".format(img_path))
         return None
     if len(list(imgvalue.shape)) == 2 or imgvalue.shape[2] == 1:
         imgvalue = cv2.cvtColor(imgvalue, cv2.COLOR_GRAY2BGR)
     data = self.make_data_dict(imgvalue, gt_label)
     data = AugmentData(data)
     data = RandomCropData(data, self.image_shape[1:])
     data = MakeShrinkMap(data)
     data = MakeBorderMap(data)
     data = self.NormalizeImage(data)
     data = self.FilterKeys(data)
     return data['image'], data['shrink_map'], data['shrink_mask'], data[
         'threshold_map'], data['threshold_mask']
Example #13
def main(args):
    image_file_list = get_image_file_list(args.image_dir)
    text_sys = TextSystem(args)
    is_visualize = True
    font_path = args.vis_font_path
    drop_score = args.drop_score
    for image_file in image_file_list:
        img, flag = check_and_read_gif(image_file)
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue
        starttime = time.time()
        dt_boxes, rec_res = text_sys(img)
        elapse = time.time() - starttime
        logger.info("Predict time of %s: %.3fs" % (image_file, elapse))

        for text, score in rec_res:
            logger.info("{}, {:.3f}".format(text, score))

        if is_visualize:
            image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            boxes = dt_boxes
            txts = [rec_res[i][0] for i in range(len(rec_res))]
            scores = [rec_res[i][1] for i in range(len(rec_res))]

            draw_img = draw_ocr_box_txt(image,
                                        boxes,
                                        txts,
                                        scores,
                                        drop_score=drop_score,
                                        font_path=font_path)
            draw_img_save = "./inference_results/"
            if not os.path.exists(draw_img_save):
                os.makedirs(draw_img_save)
            if flag:
                image_file = image_file[:-3] + "png"
            cv2.imwrite(
                os.path.join(draw_img_save, os.path.basename(image_file)),
                draw_img[:, :, ::-1])
            logger.info("The visualized image saved in {}".format(
                os.path.join(draw_img_save, os.path.basename(image_file))))
Example #14
    def __call__(self, img):
        if isinstance(img, str):
            # download net image
            if img.startswith('http'):
                download_with_progressbar(img, 'tmp.jpg')
                img = 'tmp.jpg'
            image_file = img
            img, flag = check_and_read_gif(image_file)
            if not flag:
                with open(image_file, 'rb') as f:
                    np_arr = np.frombuffer(f.read(), dtype=np.uint8)
                    img = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
            if img is None:
                logger.error("error in loading image:{}".format(image_file))
                return None
        if isinstance(img, np.ndarray) and len(img.shape) == 2:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

        res = super().__call__(img)
        return res
Example #15
def main(args):
    image_file_list = get_image_file_list(args.image_dir)
    table_structurer = TableStructurer(args)
    count = 0
    total_time = 0
    for image_file in image_file_list:
        img, flag = check_and_read_gif(image_file)
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue
        structure_res, elapse = table_structurer(img)

        logger.info("result: {}".format(structure_res))

        if count > 0:
            total_time += elapse
        count += 1
        logger.info("Predict time of {}: {}".format(image_file, elapse))
Example #16
def main(args):
    image_file_list = get_image_file_list(args.image_dir)
    text_recognizer = TextRecognizer(args)
    total_run_time = 0.0
    total_images_num = 0
    valid_image_file_list = []
    img_list = []
    for idx, image_file in enumerate(image_file_list):
        img, flag = check_and_read_gif(image_file)
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue
        valid_image_file_list.append(image_file)
        img_list.append(img)
        if len(img_list) >= args.rec_batch_num or idx == len(
                image_file_list) - 1:
            try:
                rec_res, predict_time = text_recognizer(img_list)
                total_run_time += predict_time
            except Exception:
                logger.info(traceback.format_exc())
                logger.info(
                    "ERROR!!!! \n"
                    "Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n"
                    "If your model has tps module:  "
                    "TPS does not support variable shape.\n"
                    "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' "
                )
                exit()
            for ino in range(len(img_list)):
                logger.info("Predicts of {}:{}".format(valid_image_file_list[
                    ino], rec_res[ino]))
            total_images_num += len(valid_image_file_list)
            valid_image_file_list = []
            img_list = []
    logger.info("Total predict time for {} images, cost: {:.3f}".format(
        total_images_num, total_run_time))
Example #17
def main(args):
    image_file_list = get_image_file_list(args.image_dir)
    text_classifier = TextClassifier(args)
    valid_image_file_list = []
    img_list = []
    for image_file in image_file_list:
        img, flag = check_and_read_gif(image_file)
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue
        valid_image_file_list.append(image_file)
        img_list.append(img)
    try:
        img_list, cls_res, predict_time = text_classifier(img_list)
    except Exception as E:
        logger.info(traceback.format_exc())
        logger.info(E)
        exit()
    for ino in range(len(img_list)):
        logger.info("Predicts of {}:{}".format(valid_image_file_list[ino],
                                               cls_res[ino]))
Example #18
def main(args):
    image_file_list = get_image_file_list(args.image_dir)
    text_classifier = TextClassifier(args)
    valid_image_file_list = []
    img_list = []
    for image_file in image_file_list[:10]:
        img, flag = check_and_read_gif(image_file)
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue
        valid_image_file_list.append(image_file)
        img_list.append(img)
    try:
        img_list, cls_res, predict_time = text_classifier(img_list)
    except Exception as e:
        print(e)
        exit()
    for ino in range(len(img_list)):
        print("Predicts of %s:%s" % (valid_image_file_list[ino], cls_res[ino]))
    print("Total predict time for %d images:%.3f" %
          (len(img_list), predict_time))
Example #19
def main(args):
    image_file_list = get_image_file_list(args.image_dir)
    templates = get_templates(args.template_dir)
    text_sys = TextSystem(args)
    is_visualize = True
    tackle_img_num = 0
    for image_file in image_file_list:
        img, flag = check_and_read_gif(image_file)
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue
        starttime = time.time()
        tackle_img_num += 1
        if not args.use_gpu and args.enable_mkldnn and tackle_img_num % 30 == 0:
            text_sys = TextSystem(args)
        dt_boxes, rec_res = text_sys(img)
        elapse = time.time() - starttime
        print("Predict time of %s: %.3fs" % (image_file, elapse))
        drop_score = 0.5

        match_results = text_match(templates,
                                   dt_boxes,
                                   rec_res,
                                   drop_score=drop_score)
        remove_invalid_results(templates,
                               match_results,
                               dt_boxes,
                               rec_res,
                               drop_score=drop_score)

        ocr_results = []

        for match_result in match_results:
            ocr_result = {}

            for partition_type, match_dict in match_result.items():
                ocr_dict = {}

                if partition_type == 'head':
                    continue

                if partition_type == 'image_type':
                    ocr_dict[partition_type] = match_dict
                    continue

                for partition_id, results in match_dict.items():
                    for result in results:
                        ocr_dict.setdefault(partition_id, [])
                        ocr_partition_dict = {
                            'text': result['text'],
                            # 'points': result['points'].tolist(),
                            'confidence': result['score']
                        }

                        ocr_dict[partition_id].append(ocr_partition_dict)

                ocr_result[partition_type] = ocr_dict

            ocr_results.append(ocr_result)

        print('ocr_results:', ocr_results)

        # dt_num = len(dt_boxes)
        # for dno in range(dt_num):
        #     dt_box = dt_boxes[dno]
        #     text, score = rec_res[dno]
        #     if score >= drop_score:
        #         text_str = "%s, %.3f" % (text, score)
        #         print(text_str, dt_box)

        #         post_process(templates, text, dt_box)

        if is_visualize:
            image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            boxes = dt_boxes
            txts = [rec_res[i][0] for i in range(len(rec_res))]
            scores = [rec_res[i][1] for i in range(len(rec_res))]

            draw_img = draw_ocr(image,
                                boxes,
                                txts,
                                scores,
                                draw_txt=True,
                                drop_score=drop_score)
            draw_img_save = "./inference_results/"
            if not os.path.exists(draw_img_save):
                os.makedirs(draw_img_save)
            cv2.imwrite(
                os.path.join(draw_img_save, os.path.basename(image_file)),
                draw_img[:, :, ::-1])
            print("The visualized image saved in {}".format(
                os.path.join(draw_img_save, os.path.basename(image_file))))
Example #20
def main(args):
    image_file_list = get_image_file_list(args.image_dir)
    image_file_list = image_file_list[args.process_id::args.total_process_num]
    text_sys = TextSystem(args)
    is_visualize = False
    font_path = args.vis_font_path
    drop_score = args.drop_score
    num = 1
    loop_count = 20
    selected_imgs = random.sample(image_file_list, k=20)
    for image_file in selected_imgs:
        img, flag = check_and_read_gif(image_file)
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            # logger.info("error in loading image:{}".format(image_file))
            continue
        starttime = time.time()
        dt_boxes, rec_res = text_sys(img)
        elapse = time.time() - starttime
        # logger.info("Predict time of %s: %.3fs" % (image_file, elapse))

        for text, score in rec_res:
            logger.info("{}, {:.3f}".format(text, score))

        if args.is_save:
            dataset_dir = './final_results/20/'
            txts = [rec_res[i][0] for i in range(len(rec_res))]
            # write the recognized text lines to a .txt file alongside the results
            print(image_file)
            path = os.path.join(
                dataset_dir,
                os.path.splitext(os.path.basename(image_file))[0])
            if os.path.exists(path + '.txt'):
                continue
            with open(path + '.txt', 'w', encoding="utf-8") as res_txt:
                for item in txts:
                    res_txt.write(item + '\n')

        if is_visualize:
            image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            boxes = dt_boxes
            txts = [rec_res[i][0] for i in range(len(rec_res))]
            scores = [rec_res[i][1] for i in range(len(rec_res))]

            draw_img = draw_ocr_box_txt(image,
                                        boxes,
                                        txts,
                                        scores,
                                        drop_score=drop_score,
                                        font_path=font_path)
            draw_img_save = "./inference_results/"
            if not os.path.exists(draw_img_save):
                os.makedirs(draw_img_save)
            cv2.imwrite(
                os.path.join(draw_img_save, os.path.basename(image_file)),
                draw_img[:, :, ::-1])
            logger.info("The visualized image saved in {}".format(
                os.path.join(draw_img_save, os.path.basename(image_file))))

        num = num + 1
        if num > loop_count:
            break
Example #21
def main(args):
    image_file_list = get_image_file_list(args.image_dir)
    image_file_list = image_file_list[args.process_id::args.total_process_num]
    text_sys = TextSystem(args)
    is_visualize = True
    font_path = args.vis_font_path
    drop_score = args.drop_score

    # warm up 10 times
    if args.warmup:
        img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8)
        for i in range(10):
            res = text_sys(img)

    total_time = 0
    cpu_mem, gpu_mem, gpu_util = 0, 0, 0
    _st = time.time()
    count = 0
    for idx, image_file in enumerate(image_file_list):

        img, flag = check_and_read_gif(image_file)
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.debug("error in loading image:{}".format(image_file))
            continue
        starttime = time.time()
        dt_boxes, rec_res = text_sys(img)
        elapse = time.time() - starttime
        total_time += elapse

        logger.debug(
            str(idx) + "  Predict time of %s: %.3fs" % (image_file, elapse))
        for text, score in rec_res:
            logger.debug("{}, {:.3f}".format(text, score))

        if is_visualize:
            image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            boxes = dt_boxes
            txts = [rec_res[i][0] for i in range(len(rec_res))]
            scores = [rec_res[i][1] for i in range(len(rec_res))]

            draw_img = draw_ocr_box_txt(image,
                                        boxes,
                                        txts,
                                        scores,
                                        drop_score=drop_score,
                                        font_path=font_path)
            draw_img_save_dir = args.draw_img_save_dir
            os.makedirs(draw_img_save_dir, exist_ok=True)
            if flag:
                image_file = image_file[:-3] + "png"
            cv2.imwrite(
                os.path.join(draw_img_save_dir, os.path.basename(image_file)),
                draw_img[:, :, ::-1])
            logger.debug("The visualized image saved in {}".format(
                os.path.join(draw_img_save_dir, os.path.basename(image_file))))

    logger.info("The predict total time is {}".format(time.time() - _st))
    if args.benchmark:
        text_sys.text_detector.autolog.report()
        text_sys.text_recognizer.autolog.report()
Example #22
def main(args):
    image_file_list = get_image_file_list(args.image_dir)
    text_sys = TextSystem(args)
    is_visualize = True
    font_path = args.vis_font_path
    drop_score = args.drop_score
    for image_file in image_file_list:
        img, flag = check_and_read_gif(image_file)
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue
        starttime = time.time()
        dt_boxes, rec_res = text_sys(img)
        elapse = time.time() - starttime
        logger.info("Predict time of %s: %.3fs" % (image_file, elapse))

        for text, score in rec_res:
            logger.info("{}, {:.3f}".format(text, score))

        if is_visualize:
            image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            boxes = dt_boxes

            # for box in boxes:
            #     xy_sum = np.sum(box, axis=0) / 4.0
            #     cx = xy_sum[0]
            #     cy = xy_sum[1]
            #     degree = np.arcsin((box[1][1] - box[0][1]) / (box[1][0] - box[0][0]))
            #     w = abs(box[0][0] - box[1][0])
            #     h = abs(box[0][1] - box[3][1])
            #     x1, y1, x2, y2, x3, y3, x4, y4 = xy_rotate_box(cx, cy, w, h, degree / 180 * np.pi)
            #     box[0][0] = x1
            #     box[0][1] = y1
            #     box[1][0] = x2
            #     box[1][1] = y2
            #     box[2][0] = x3
            #     box[2][1] = y3
            #     box[3][0] = x4
            #     box[3][1] = y4

            txts = [rec_res[i][0] for i in range(len(rec_res))]
            scores = [rec_res[i][1] for i in range(len(rec_res))]

            assorted_results = {"text_boxes":
                                    [{'id': i + 1,
                                      'bbox': [float(dt_boxes[i][0][0]), float(dt_boxes[i][0][1]), float(dt_boxes[i][2][0]), float(dt_boxes[i][2][1])],
                                      'text': rec_res[i][0]} for i in range(len(rec_res))],
                                "fields":
                                    [{"field_name": "customer_number",
                                      "value_id": [],
                                      "value_text": [],
                                      "key_id": [],
                                      "key_text": []},
                                     {"field_name": "name",
                                      "value_id": [],
                                      "value_text": [],
                                      "key_id": [],
                                      "key_text": []},
                                     {"field_name": "address",
                                      "value_id": [],
                                      "value_text": [],
                                      "key_id": [],
                                      "key_text": []},
                                     {"field_name": "amount",
                                      "value_id": [],
                                      "value_text": [],
                                      "key_id": [],
                                      "key_text": []},
                                     {"field_name": "date",
                                      "value_id": [],
                                      "value_text": [],
                                      "key_id": [],
                                      "key_text": []},
                                     {"field_name": "content",
                                      "value_id": [],
                                      "value_text": [],
                                      "key_id": [],
                                      "key_text": []}],
                                "global_attributes": {
                                    "file_id": image_file.split('/')[-1]}
                                }
            with open(image_file + '.json', 'w', encoding='utf-8') as outfile:
                json.dump(assorted_results, outfile)
            #res = trainTicket.trainTicket(assorted_results, img=image)
            #res = res.res
            ##compare_img = clip_ground_truth_and_draw_txt(cv2.cvtColor(img, cv2.COLOR_BGR2GRAY), boxes, txts, scores, font_path=font_path)

            draw_img = draw_ocr_box_txt(
                image,
                boxes,
                txts,
                scores,
                drop_score=drop_score,
                font_path=font_path)
            draw_img_save = "./inference_results/"
            if not os.path.exists(draw_img_save):
                os.makedirs(draw_img_save)
            cv2.imwrite(
                os.path.join(draw_img_save, os.path.basename(image_file)),
                draw_img[:, :, ::-1])
            # cv2.imwrite(
            #     os.path.join(draw_img_save, os.path.basename(image_file)),
            #     compare_img)
            logger.info("The visualized image saved in {}".format(
                os.path.join(draw_img_save, os.path.basename(image_file))))
Example #23
def main(args):
    image_file_list = get_image_file_list(args.image_dir)
    # print(image_file_list)
    text_sys = TextSystem(args)
    is_visualize = True
    font_path = args.vis_font_path
    message = """"""
    textList = []
    height_sum = 0
    blank = 20  # vertical pixel gap between stitched images
    if not os.path.exists('/usr/web/outputs'):
        os.makedirs('/usr/web/outputs')
    for image_file in image_file_list:
        img, flag = check_and_read_gif(image_file)
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue
        starttime = time.time()

        dt_boxes, rec_res = text_sys(img)

        elapse = time.time() - starttime
        print("Predict time of %s: %.3fs" % (image_file, elapse))

        # print("dt_boxes:", dt_boxes)
        # print()
        # print("rec_res:", rec_res)

        # drop_score = 0.5
        # dt_num = len(dt_boxes)
        # for dno in range(dt_num):
        #     text, score = rec_res[dno]
        #     if score >= drop_score:
        #         text_str = "%s, %.3f" % (text, score)
        #         print(text_str)

        if is_visualize:
            image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            w = image.width
            h = image.height

            image = image.resize((w, h))
            image = np.array(image)

            boxes = dt_boxes  # box coordinates (four corner points each)
            txts = [rec_res[i][0] for i in range(len(rec_res))]  # recognized text

            print(boxes)
            print(txts)
            # scores = [rec_res[i][1] for i in range(len(rec_res))]

            print(image.shape)
            for t in range(len(boxes)):
                image[int(round(boxes[t][0][1])):int(round(boxes[t][3][1])),
                int(round(boxes[t][0][0])):int(round(boxes[t][1][0])), :] = \
                    image[int(round(boxes[t][1][1]))][int(round(boxes[t][1][0]))]

            img = Image.fromarray(image)
            savename = os.path.join("/usr/web/modify_input", image_file.split("/")[-1])
            img.save(savename)
            webname = os.path.join("./modify_input", image_file.split("/")[-1]) 
            textList.append(str(height_sum))
            textList.append(str(webname))
            message = message + """
        <div style="position:absolute; left:0px; top:%spx;">
            <img src="%s"/>
        </div>"""
            for x in range(len(txts)):
                textList.append(str((boxes[x][2][1] - boxes[x][0][1]) * 0.8))  # font-size correction
                textList.append(str(boxes[x][0][0]))
                textList.append(str(boxes[x][0][1] + height_sum))
                textList.append(txts[x])
                message = message + """
        <p style="font-size:%spx">
            <a style="position:absolute; left:%spx; top:%spx;">%s</a>
        </p>"""

            height_sum = height_sum + h + blank

        start = """<!DOCTYPE html>
<head>
    <meta charset="UTF-8">
</head>
<body>"""
    end = """
</body>
</html>
"""
    message = (start + message + end) % tuple(textList)

    GEN_HTML = "/usr/web/test.html"
    # write the assembled HTML page to disk
    with open(GEN_HTML, 'w') as f:
        f.write(message)

    # generate a PDF from the rendered HTML for the response
    confg = pdfkit.configuration(wkhtmltopdf='/root/ai_competition/wkhtmltopdf/usr/local/bin/wkhtmltopdf')
    # the path to the wkhtmltopdf binary must be given explicitly here
    pdfkit.from_url('39.100.154.163/test.html', '/usr/web/test.pdf', configuration=confg)
Example #24
def main(args):
    for root, dirs, files in os.walk(args.image_dir):
        for file in files:
            with open(root + "/" + file, "r") as f:
                imgpath = root + "/" + file
                image_file_list = get_image_file_list(imgpath)
                text_sys = TextSystem(args)
                is_visualize = True
                font_path = args.vis_font_path
                for image_file in image_file_list:
                    img, flag = check_and_read_gif(image_file)
                    if not flag:
                        img = cv2.imread(image_file)
                    if img is None:
                        logger.info(
                            "error in loading image:{}".format(image_file))
                        continue
                    starttime = time.time()
                    # img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
                    # img_filter = cv2.bilateralFilter(img_gray, 5, 100, 100)
                    # img_filter = cv2.blur(img_gray, (5, 5))
                    # img_filter = cv2.GaussianBlur(img_gray, (5, 5), 0)
                    # img_filter = cv2.medianBlur(img_gray, 5)
                    # img_filter = cv2.cvtColor(img_filter, cv2.COLOR_GRAY2BGR)
                    # dt_boxes, rec_res = text_sys(img_filter)
                    dt_boxes, rec_res = text_sys(img)
                    roi = ()
                    name_box = np.empty(shape=(4, 2))
                    for i, val in enumerate(rec_res):
                        if "姓名" in val[0]:
                            if len(val[0]) < 3:
                                name_box = dt_boxes[i + 1]
                                rec_res.pop(i + 1)
                                dt_boxes.pop(i + 1)
                                break
                            else:
                                name_box = dt_boxes[i]
                                rec_res.pop(i)
                                dt_boxes.pop(i)
                                break
                        if "名" in val[0]:
                            if len(val[0]) <= 2:
                                name_box = dt_boxes[i + 1]
                                rec_res.pop(i + 1)
                                dt_boxes.pop(i + 1)
                                break
                            elif len(val[0]) <= 6:
                                name_box = dt_boxes[i]
                                rec_res.pop(i)
                                dt_boxes.pop(i)
                                break
                    roi = (int(name_box[0][0]), int(name_box[0][1]),
                           int(name_box[2][0]), int(name_box[2][1]))
                    elapse = time.time() - starttime
                    print("Predict time of %s: %.3fs" % (image_file, elapse))

                    drop_score = 0.5

                    json_img_save = "./inference_results_json/" + \
                        root.replace(args.image_dir+"\\", "")
                    if not os.path.exists(json_img_save):
                        os.makedirs(json_img_save)
                    with open(json_img_save + "/" + file.replace(".jpg", "") +
                              ".json",
                              'w',
                              encoding='utf-8') as file_obj:
                        ans_json = {'data': [{'str': i[0]} for i in rec_res]}
                        json.dump(ans_json,
                                  file_obj,
                                  indent=4,
                                  ensure_ascii=False)

                    if is_visualize:
                        image = Image.fromarray(
                            cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
                        boxes = dt_boxes
                        txts = [rec_res[i][0] for i in range(len(rec_res))]
                        scores = [rec_res[i][1] for i in range(len(rec_res))]

                        draw_img = draw_ocr_box_txt(image,
                                                    boxes,
                                                    txts,
                                                    scores,
                                                    drop_score=args.drop_score,
                                                    font_path=font_path)
                        draw_img.paste((0, 0, 0), roi)
                        draw_img = np.array(draw_img)
                        draw_img_save = "./inference_results/" + \
                            root.replace(args.image_dir+"\\", "")
                        if not os.path.exists(draw_img_save):
                            os.makedirs(draw_img_save)
                        cv2.imwrite(
                            os.path.join(draw_img_save,
                                         os.path.basename(image_file)),
                            draw_img[:, :, ::-1])
                        print("The visualized image saved in {}".format(
                            os.path.join(draw_img_save,
                                         os.path.basename(image_file))))
Example #25
        dt_boxes = self.filter_tag_det_res_only_clip(points, ori_im.shape)
        elapse = time.time() - starttime
        return dt_boxes, strs, elapse


if __name__ == "__main__":
    args = utility.parse_args()
    image_file_list = get_image_file_list(args.image_dir)
    text_detector = TextE2E(args)
    count = 0
    total_time = 0
    draw_img_save = "./inference_results"
    if not os.path.exists(draw_img_save):
        os.makedirs(draw_img_save)
    for image_file in image_file_list:
        img, flag = check_and_read_gif(image_file)
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue
        points, strs, elapse = text_detector(img)
        if count > 0:
            total_time += elapse
        count += 1
        logger.info("Predict time of {}: {}".format(image_file, elapse))
        src_im = utility.draw_e2e_res(points, strs, image_file)
        img_name_pure = os.path.split(image_file)[-1]
        img_path = os.path.join(draw_img_save,
                                "e2e_res_{}".format(img_name_pure))
        cv2.imwrite(img_path, src_im)
Example #26
def main(args):
    image_file_list = get_image_file_list(args.image_dir)
    image_file_list = image_file_list[args.process_id::args.total_process_num]
    text_sys = TextSystem(args)
    is_visualize = args.is_visualize
    font_path = args.vis_font_path
    drop_score = args.drop_score

    # warm up 10 times
    if args.warmup:
        img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8)
        for i in range(10):
            res = text_sys(img)

    total_time = 0
    cpu_mem, gpu_mem, gpu_util = 0, 0, 0
    _st = time.time()
    count = 0
    save_res = []
    for idx, image_file in enumerate(image_file_list):

        img, flag = check_and_read_gif(image_file)
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.debug("error in loading image:{}".format(image_file))
            continue
        starttime = time.time()
        dt_boxes, rec_res = text_sys(img)
        elapse = time.time() - starttime
        total_time += elapse

        # save results
        preds = []
        dt_num = len(dt_boxes)
        for dno in range(dt_num):
            text, score = rec_res[dno]
            if score >= drop_score:
                preds.append({
                    "transcription": text,
                    "points": np.array(dt_boxes[dno]).tolist()
                })
                text_str = "%s, %.3f" % (text, score)
        save_res.append(image_file + '\t' +
                        json.dumps(preds, ensure_ascii=False) + '\n')

        # print predicted results
        logger.debug(
            str(idx) + "  Predict time of %s: %.3fs" % (image_file, elapse))
        for text, score in rec_res:
            logger.debug("{}, {:.3f}".format(text, score))

        if is_visualize:
            image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            boxes = dt_boxes
            txts = [rec_res[i][0] for i in range(len(rec_res))]
            scores = [rec_res[i][1] for i in range(len(rec_res))]

            draw_img = draw_ocr_box_txt(image,
                                        boxes,
                                        txts,
                                        scores,
                                        drop_score=drop_score,
                                        font_path=font_path)
            draw_img_save_dir = args.draw_img_save_dir
            os.makedirs(draw_img_save_dir, exist_ok=True)
            if flag:
                image_file = image_file[:-3] + "png"
            cv2.imwrite(
                os.path.join(draw_img_save_dir, os.path.basename(image_file)),
                draw_img[:, :, ::-1])
            logger.debug("The visualized image saved in {}".format(
                os.path.join(draw_img_save_dir, os.path.basename(image_file))))

    # The predicted results will be saved in os.path.join(args.draw_img_save_dir, "results.txt")
    save_results_to_txt(save_res, args.draw_img_save_dir)

    logger.info("The predict total time is {}".format(time.time() - _st))
    if args.benchmark:
        text_sys.text_detector.autolog.report()
        text_sys.text_recognizer.autolog.report()
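Example #26 relies on a save_results_to_txt helper that is not shown in the snippet. Given that every entry in save_res already ends with a newline and the comment above says the output goes to results.txt under args.draw_img_save_dir, here is a minimal sketch of what such a helper could look like; this is an assumption, and the project's actual helper may differ:

import os

def save_results_to_txt(save_res, save_dir):
    # save_res is a list of "<image path>\t<json predictions>\n" strings
    os.makedirs(save_dir, exist_ok=True)
    with open(os.path.join(save_dir, "results.txt"), "w", encoding="utf-8") as f:
        f.writelines(save_res)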