def main(args):
    """Run table recognition over a shard of images and export each result to .xlsx.

    Args:
        args: parsed CLI namespace; uses image_dir, process_id,
            total_process_num, and output.
    """
    image_file_list = get_image_file_list(args.image_dir)
    # Shard the list so several processes can split the workload.
    image_file_list = image_file_list[args.process_id::args.total_process_num]
    os.makedirs(args.output, exist_ok=True)
    text_sys = TableSystem(args)
    img_num = len(image_file_list)
    for i, image_file in enumerate(image_file_list):
        logger.info("[{}/{}] {}".format(i, img_num, image_file))
        img, flag = check_and_read_gif(image_file)
        # Fix: use splitext so filenames containing extra dots keep their full
        # stem (the old split('.')[0] truncated "a.b.png" to "a").
        excel_path = os.path.join(
            args.output,
            os.path.splitext(os.path.basename(image_file))[0] + '.xlsx')
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.error("error in loading image:{}".format(image_file))
            continue
        starttime = time.time()
        pred_html = text_sys(img)
        to_excel(pred_html, excel_path)
        logger.info('excel saved to {}'.format(excel_path))
        logger.info(pred_html)
        elapse = time.time() - starttime
        logger.info("Predict time : {:.3f}s".format(elapse))
def get_data(self, label_infor, is_aug=False, is_crop=False, is_resize=False,
             is_shrink=False, is_border=False):
    """Load one labelled image and apply the selected preprocessing steps.

    Args:
        label_infor: raw label line, parsed by self.convert_label_infor.
        is_aug: apply the AugmentData step.
        is_crop: random-crop to self.image_shape[1:].
        is_resize: resize to self.image_shape[1:].
        is_shrink: apply the MakeShrinkMap step.
        is_border: apply the MakeBorderMap step.

    Returns:
        The populated data dict, or None when the image cannot be read.
    """
    img_path, gt_label = self.convert_label_infor(label_infor)
    # GIFs need a dedicated reader; flag reports whether that path succeeded.
    imgvalue, flag = check_and_read_gif(img_path)
    if not flag:
        imgvalue = cv2.imread(img_path)
    if imgvalue is None:
        logger.info("{} does not exist!".format(img_path))
        return None
    # Promote grayscale / single-channel images to 3-channel BGR so the
    # downstream steps see a uniform layout.
    if len(list(imgvalue.shape)) == 2 or imgvalue.shape[2] == 1:
        imgvalue = cv2.cvtColor(imgvalue, cv2.COLOR_GRAY2BGR)
    data = self.make_data_dict(imgvalue, gt_label)
    if is_aug:
        data = AugmentData(data)
    if is_crop:
        data = RandomCropData(data, self.image_shape[1:])
    if is_resize:
        data = ResizeData(data, self.image_shape[1:])
    if is_shrink:
        data = MakeShrinkMap(data)
    if is_border:
        data = MakeBorderMap(data)
    # data = self.NormalizeImage(data)
    # data = self.FilterKeys(data)
    return data
def main(args):
    """Batch text-recognition driver.

    NOTE(review): args.image_dir, args.rec_char_dict_path and
    args.rec_model_dir are overwritten below with hard-coded local debug
    paths — confirm whether these overrides are still intended.
    """
    args.image_dir = '/home/duycuong/PycharmProjects/research_py3/MC_OCR/mc_ocr/text_detector/PaddleOCR/doc/imgs_words_en/word_10.png'
    args.rec_char_dict_path = '/home/duycuong/PycharmProjects/research_py3/MC_OCR/mc_ocr/text_detector/PaddleOCR/ppocr/utils/dict/japan_dict.txt'
    args.rec_model_dir = '/home/duycuong/PycharmProjects/research_py3/MC_OCR/mc_ocr/text_detector/PaddleOCR/inference/japan_mobile_v2.0_rec_infer'
    image_file_list = get_image_file_list(args.image_dir)
    text_recognizer = TextRecognizer(args)
    valid_image_file_list = []
    img_list = []
    for image_file in image_file_list:
        img, flag = check_and_read_gif(image_file)
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue
        valid_image_file_list.append(image_file)
        img_list.append(img)
    try:
        rec_res, predict_time = text_recognizer(img_list)
    # Fix: was a bare `except:`, which also swallowed SystemExit and
    # KeyboardInterrupt; narrowed to Exception.
    except Exception:
        logger.info(traceback.format_exc())
        logger.info(
            "ERROR!!!! \n"
            "Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n"
            "If your model has tps module: "
            "TPS does not support variable shape.\n"
            "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' "
        )
        exit()
    for ino in range(len(img_list)):
        logger.info("Predicts of {}:{}".format(valid_image_file_list[ino],
                                               rec_res[ino]))
    logger.info("Total predict time for {} images, cost: {:.3f}".format(
        len(img_list), predict_time))
def main(args):
    """Run the structure pipeline on a sharded image list and save per-image
    results plus a rendered visualization.

    Fixes: removed the no-op `image_file_list = image_file_list`
    self-assignment; use splitext so dotted filenames keep their full stem.
    """
    image_file_list = get_image_file_list(args.image_dir)
    # Shard the list across processes.
    image_file_list = image_file_list[args.process_id::args.total_process_num]
    save_folder = args.output
    os.makedirs(save_folder, exist_ok=True)
    structure_sys = OCRSystem(args)
    img_num = len(image_file_list)
    for i, image_file in enumerate(image_file_list):
        logger.info("[{}/{}] {}".format(i, img_num, image_file))
        img, flag = check_and_read_gif(image_file)
        img_name = os.path.splitext(os.path.basename(image_file))[0]
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.error("error in loading image:{}".format(image_file))
            continue
        starttime = time.time()
        res = structure_sys(img)
        save_structure_res(res, save_folder, img_name)
        draw_img = draw_structure_result(img, res, args.vis_font_path)
        cv2.imwrite(os.path.join(save_folder, img_name, 'show.jpg'), draw_img)
        logger.info('result save to {}'.format(
            os.path.join(save_folder, img_name)))
        elapse = time.time() - starttime
        logger.info("Predict time : {:.3f}s".format(elapse))
def main(args):
    """Recognize text in every readable image under args.image_dir and log
    per-image predictions; optionally warm the model up and report benchmarks."""
    files = get_image_file_list(args.image_dir)
    text_recognizer = TextRecognizer(args)
    valid_image_file_list = []
    img_list = []
    # warmup 2 times
    if args.warmup:
        dummy = np.random.uniform(0, 255, [32, 320, 3]).astype(np.uint8)
        for _ in range(2):
            text_recognizer([dummy] * int(args.rec_batch_num))
    for path in files:
        frame, was_gif = check_and_read_gif(path)
        if not was_gif:
            frame = cv2.imread(path)
        if frame is None:
            logger.info("error in loading image:{}".format(path))
            continue
        valid_image_file_list.append(path)
        img_list.append(frame)
    try:
        rec_res, _ = text_recognizer(img_list)
    except Exception as E:
        logger.info(traceback.format_exc())
        logger.info(E)
        exit()
    for path, prediction in zip(valid_image_file_list, rec_res):
        logger.info("Predicts of {}:{}".format(path, prediction))
    if args.benchmark:
        text_recognizer.autolog.report()
def main(args):
    """Classify text direction for every readable image and log the results.

    Fix: narrowed the bare `except:` to `except Exception:` so SystemExit and
    KeyboardInterrupt are no longer swallowed.
    """
    image_file_list = get_image_file_list(args.image_dir)
    text_classifier = TextClassifier(args)
    valid_image_file_list = []
    img_list = []
    for image_file in image_file_list:
        img, flag = check_and_read_gif(image_file)
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue
        valid_image_file_list.append(image_file)
        img_list.append(img)
    try:
        img_list, cls_res, predict_time = text_classifier(img_list)
    except Exception:
        logger.info(traceback.format_exc())
        logger.info(
            "ERROR!!!! \n"
            "Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n"
            "If your model has tps module: "
            "TPS does not support variable shape.\n"
            "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ")
        exit()
    for ino in range(len(img_list)):
        logger.info("Predicts of {}:{}".format(valid_image_file_list[ino],
                                               cls_res[ino]))
    logger.info("Total predict time for {} images, cost: {:.3f}".format(
        len(img_list), predict_time))
def main(args):
    """Table-extraction driver with an optional interactive (clipboard) mode.

    Fix: removed the unused locals is_visualize, font_path, drop_score.
    """
    if not args.clipboard:
        image_file_list = get_image_file_list(args.image_dir)
        text_sys = TextSystem(args)
        for image_file in image_file_list:
            img, flag = check_and_read_gif(image_file)
            if not flag:
                img = cv2.imread(image_file)
            if img is None:
                logger.info("error in loading image:{}".format(image_file))
                continue
            starttime = time.time()
            dt_boxes, rec_res = text_sys(img)
            elapse = time.time() - starttime
            logger.info("Predict time of %s: %.3fs" % (image_file, elapse))
            out_table(dt_boxes, rec_res)
    else:
        while True:
            instructions = input(
                'Extract Table From Image ("?"/"h" for help,"x" for exit).')
            ins = instructions.strip().lower()
            if ins == 'x':
                break
            # NOTE(review): the prompt advertises "?"/"h" help, but no help
            # branch exists — any input other than 'x' runs the model.
            try:
                call_model(args)
            except KeyboardInterrupt:
                pass
def ocr(self, img, det=True, rec=True):
    """
    ocr with paddleocr
    args:
        img: img for ocr, support ndarray, img_path and list or ndarray
        det: use text detection or not, if false, only rec will be exec. default is True
        rec: use text recognition or not, if false, only det will be exec. default is True

    Returns:
        det+rec: list of [box_points, rec_result] pairs;
        det only: list of box point-lists (or None when no boxes found);
        rec only: raw recognizer output;
        None when a path input cannot be loaded.
    """
    assert isinstance(img, (np.ndarray, list, str))
    if isinstance(img, str):
        image_file = img
        img, flag = check_and_read_gif(image_file)
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.error("error in loading image:{}".format(image_file))
            return None
    if det and rec:
        dt_boxes, rec_res = self.__call__(img)
        return [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
    elif det and not rec:
        dt_boxes, elapse = self.text_detector(img)
        if dt_boxes is None:
            return None
        return [box.tolist() for box in dt_boxes]
    else:
        # Recognition-only path: the recognizer expects a batch (list).
        if not isinstance(img, list):
            img = [img]
        rec_res, elapse = self.text_recognizer(img)
        return rec_res
def main(args):
    """Recognize text for all readable images and report per-image results.

    Fixes (consistency with sibling drivers in this file): route output
    through the module logger instead of bare print, and log the full
    traceback on failure rather than only str(e).
    """
    image_file_list = get_image_file_list(args.image_dir)
    text_recognizer = TextRecognizer(args)
    valid_image_file_list = []
    img_list = []
    for image_file in image_file_list:
        img, flag = check_and_read_gif(image_file)
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue
        valid_image_file_list.append(image_file)
        img_list.append(img)
    try:
        rec_res, predict_time = text_recognizer(img_list)
    except Exception:
        logger.info(traceback.format_exc())
        logger.info(
            "ERROR!!!! \n"
            "Please read the FAQ: https://github.com/PaddlePaddle/PaddleOCR#faq \n"
            "If your model has tps module: "
            "TPS does not support variable shape.\n"
            "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' "
        )
        exit()
    for ino in range(len(img_list)):
        logger.info("Predicts of %s:%s" %
                    (valid_image_file_list[ino], rec_res[ino]))
    logger.info("Total predict time for %d images:%.3f" %
                (len(img_list), predict_time))
def main(args):
    """Det+rec each image, print scored text lines, and save visualizations
    under ./results.

    Fix: removed leftover numeric debug prints (`print(1)`, `print(111111)`,
    a bare `print(image_file)`, and the commented-out `print(1111)` lines).
    """
    image_file_list = get_image_file_list(args.image_dir)
    text_sys = TextSystem(args)
    is_visualize = True
    font_path = args.vis_font_path
    for image_file in image_file_list:
        img, flag = check_and_read_gif(image_file)
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue
        starttime = time.time()
        dt_boxes, rec_res = text_sys(img)
        elapse = time.time() - starttime
        print("Predict time of %s: %.3fs" % (image_file, elapse))
        drop_score = 0.5
        dt_num = len(dt_boxes)
        for dno in range(dt_num):
            text, score = rec_res[dno]
            if score >= drop_score:
                text_str = "%s, %.3f" % (text, score)
                print(text_str)
        if is_visualize:
            image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            boxes = dt_boxes
            txts = [rec_res[i][0] for i in range(len(rec_res))]
            scores = [rec_res[i][1] for i in range(len(rec_res))]
            draw_img = draw_ocr_box_txt(image, boxes, txts, scores,
                                        drop_score=drop_score,
                                        font_path=font_path)
            draw_img_save = "./results"
            if not os.path.exists(draw_img_save):
                os.makedirs(draw_img_save)
            cv2.imwrite(
                os.path.join(draw_img_save, os.path.basename(image_file)),
                draw_img[:, :, ::-1])
            print("The visualized image saved in {}".format(
                os.path.join(draw_img_save, os.path.basename(image_file))))
def ocr(self, img, det=True, rec=True, cls=False):
    """
    ocr with paddleocr
    args:
        img: img for ocr, support ndarray, img_path and list or ndarray
        det: use text detection or not, if false, only rec will be exec. default is True
        rec: use text recognition or not, if false, only det will be exec. default is True
        cls: run the angle classifier before recognition (only effective when
             the classifier was initialized on this instance). default is False

    Fix: corrected the "uesd" typo in the warning message below.
    """
    assert isinstance(img, (np.ndarray, list, str))
    if isinstance(img, list) and det == True:
        logger.error('When input a list of images, det must be false')
        exit(0)
    # NOTE: this mutates instance state — cls=False disables the classifier
    # for subsequent calls on the same instance as well.
    if cls == False:
        self.use_angle_cls = False
    elif cls == True and self.use_angle_cls == False:
        logger.warning(
            'Since the angle classifier is not initialized, the angle classifier will not be used during the forward process'
        )
    if isinstance(img, str):
        # download net image
        if img.startswith('http'):
            download_with_progressbar(img, 'tmp.jpg')
            img = 'tmp.jpg'
        image_file = img
        img, flag = check_and_read_gif(image_file)
        if not flag:
            # imdecode (vs imread) tolerates non-ASCII paths
            with open(image_file, 'rb') as f:
                np_arr = np.frombuffer(f.read(), dtype=np.uint8)
                img = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
        if img is None:
            logger.error("error in loading image:{}".format(image_file))
            return None
    # Promote grayscale input to 3-channel BGR.
    if isinstance(img, np.ndarray) and len(img.shape) == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    if det and rec:
        dt_boxes, rec_res = self.__call__(img)
        return [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
    elif det and not rec:
        dt_boxes, elapse = self.text_detector(img)
        if dt_boxes is None:
            return None
        return [box.tolist() for box in dt_boxes]
    else:
        if not isinstance(img, list):
            img = [img]
        if self.use_angle_cls:
            img, cls_res, elapse = self.text_classifier(img)
            if not rec:
                return cls_res
        rec_res, elapse = self.text_recognizer(img)
        return rec_res
def __call__(self, label_infor):
    """Build one training sample from a label line.

    Runs the fixed preprocessing chain (augment -> random crop -> shrink map
    -> border map -> normalize -> key filtering) and returns the tensors the
    trainer consumes.

    Returns:
        (image, shrink_map, shrink_mask, threshold_map, threshold_mask),
        or None when the image cannot be read.
    """
    img_path, gt_label = self.convert_label_infor(label_infor)
    imgvalue, flag = check_and_read_gif(img_path)
    if not flag:
        imgvalue = cv2.imread(img_path)
    if imgvalue is None:
        logger.info("{} does not exist!".format(img_path))
        return None
    # Promote grayscale / single-channel images to 3-channel BGR.
    if len(list(imgvalue.shape)) == 2 or imgvalue.shape[2] == 1:
        imgvalue = cv2.cvtColor(imgvalue, cv2.COLOR_GRAY2BGR)
    data = self.make_data_dict(imgvalue, gt_label)
    data = AugmentData(data)
    data = RandomCropData(data, self.image_shape[1:])
    data = MakeShrinkMap(data)
    data = MakeBorderMap(data)
    data = self.NormalizeImage(data)
    data = self.FilterKeys(data)
    return data['image'], data['shrink_map'], data['shrink_mask'], data[
        'threshold_map'], data['threshold_mask']
def main(args):
    """Det+rec each image, log scored text lines, and save a box+text
    visualization under ./inference_results/."""
    image_file_list = get_image_file_list(args.image_dir)
    text_sys = TextSystem(args)
    is_visualize = True
    font_path = args.vis_font_path
    drop_score = args.drop_score
    for image_file in image_file_list:
        img, flag = check_and_read_gif(image_file)
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue
        starttime = time.time()
        dt_boxes, rec_res = text_sys(img)
        elapse = time.time() - starttime
        logger.info("Predict time of %s: %.3fs" % (image_file, elapse))
        for text, score in rec_res:
            logger.info("{}, {:.3f}".format(text, score))
        if is_visualize:
            image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            boxes = dt_boxes
            txts = [rec_res[i][0] for i in range(len(rec_res))]
            scores = [rec_res[i][1] for i in range(len(rec_res))]
            draw_img = draw_ocr_box_txt(image, boxes, txts, scores,
                                        drop_score=drop_score,
                                        font_path=font_path)
            draw_img_save = "./inference_results/"
            if not os.path.exists(draw_img_save):
                os.makedirs(draw_img_save)
            # GIF inputs are saved with a .png extension instead.
            if flag:
                image_file = image_file[:-3] + "png"
            cv2.imwrite(
                os.path.join(draw_img_save, os.path.basename(image_file)),
                draw_img[:, :, ::-1])
            logger.info("The visualized image saved in {}".format(
                os.path.join(draw_img_save, os.path.basename(image_file))))
def __call__(self, img):
    """Accept an ndarray or a path/URL string, load the image, normalize
    grayscale input to BGR, and delegate to the parent implementation."""
    if isinstance(img, str):
        # Remote images are fetched to a local temporary file first.
        if img.startswith('http'):
            download_with_progressbar(img, 'tmp.jpg')
            img = 'tmp.jpg'
        source = img
        img, is_gif = check_and_read_gif(source)
        if not is_gif:
            with open(source, 'rb') as handle:
                raw = np.frombuffer(handle.read(), dtype=np.uint8)
                img = cv2.imdecode(raw, cv2.IMREAD_COLOR)
        if img is None:
            logger.error("error in loading image:{}".format(source))
            return None
    if isinstance(img, np.ndarray) and len(img.shape) == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    return super().__call__(img)
def main(args):
    """Run table-structure prediction on every readable image, logging the
    result and per-image latency."""
    paths = get_image_file_list(args.image_dir)
    table_structurer = TableStructurer(args)
    processed = 0
    accumulated = 0
    for path in paths:
        frame, was_gif = check_and_read_gif(path)
        if not was_gif:
            frame = cv2.imread(path)
        if frame is None:
            logger.info("error in loading image:{}".format(path))
            continue
        structure_res, elapse = table_structurer(frame)
        logger.info("result: {}".format(structure_res))
        # The first image's time is excluded from the running total
        # (presumably treated as warm-up).
        if processed > 0:
            accumulated += elapse
        processed += 1
        logger.info("Predict time of {}: {}".format(path, elapse))
def main(args):
    """Batched text recognition: accumulate images up to rec_batch_num, run
    the recognizer per batch, and log per-image predictions plus the total
    runtime.

    Fix: narrowed the bare `except:` to `except Exception:` so SystemExit and
    KeyboardInterrupt are no longer swallowed.
    """
    image_file_list = get_image_file_list(args.image_dir)
    text_recognizer = TextRecognizer(args)
    total_run_time = 0.0
    total_images_num = 0
    valid_image_file_list = []
    img_list = []
    for idx, image_file in enumerate(image_file_list):
        img, flag = check_and_read_gif(image_file)
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue
        valid_image_file_list.append(image_file)
        img_list.append(img)
        # Flush a batch when it is full, or when the last file is reached.
        if len(img_list) >= args.rec_batch_num or idx == len(
                image_file_list) - 1:
            try:
                rec_res, predict_time = text_recognizer(img_list)
                total_run_time += predict_time
            except Exception:
                logger.info(traceback.format_exc())
                logger.info(
                    "ERROR!!!! \n"
                    "Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n"
                    "If your model has tps module: "
                    "TPS does not support variable shape.\n"
                    "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' "
                )
                exit()
            for ino in range(len(img_list)):
                logger.info("Predicts of {}:{}".format(valid_image_file_list[
                    ino], rec_res[ino]))
            total_images_num += len(valid_image_file_list)
            valid_image_file_list = []
            img_list = []
    logger.info("Total predict time for {} images, cost: {:.3f}".format(
        total_images_num, total_run_time))
def main(args):
    """Classify text direction for every readable image and log per-image
    predictions (no total-time summary in this variant)."""
    candidates = get_image_file_list(args.image_dir)
    text_classifier = TextClassifier(args)
    valid_image_file_list = []
    img_list = []
    for path in candidates:
        frame, was_gif = check_and_read_gif(path)
        if not was_gif:
            frame = cv2.imread(path)
        if frame is None:
            logger.info("error in loading image:{}".format(path))
            continue
        valid_image_file_list.append(path)
        img_list.append(frame)
    try:
        img_list, cls_res, predict_time = text_classifier(img_list)
    except Exception as E:
        logger.info(traceback.format_exc())
        logger.info(E)
        exit()
    for path, verdict in zip(valid_image_file_list, cls_res):
        logger.info("Predicts of {}:{}".format(path, verdict))
def main(args):
    """Classify text direction for (at most) the first 10 images under
    args.image_dir and print the results."""
    all_files = get_image_file_list(args.image_dir)
    text_classifier = TextClassifier(args)
    valid_image_file_list = []
    img_list = []
    # Deliberately capped at the first 10 files — presumably a debug-sized run.
    for path in all_files[:10]:
        frame, was_gif = check_and_read_gif(path)
        if not was_gif:
            frame = cv2.imread(path)
        if frame is None:
            logger.info("error in loading image:{}".format(path))
            continue
        valid_image_file_list.append(path)
        img_list.append(frame)
    try:
        img_list, cls_res, predict_time = text_classifier(img_list)
    except Exception as e:
        print(e)
        exit()
    for path, verdict in zip(valid_image_file_list, cls_res):
        print("Predicts of %s:%s" % (path, verdict))
    print("Total predict time for %d images:%.3f" %
          (len(img_list), predict_time))
def main(args):
    """Template-guided OCR: detect and recognize text, match the results
    against field templates, print structured per-image results, and
    optionally save a visualization."""
    image_file_list = get_image_file_list(args.image_dir)
    templates = get_templates(args.template_dir)
    text_sys = TextSystem(args)
    is_visualize = True
    tackle_img_num = 0
    for image_file in image_file_list:
        img, flag = check_and_read_gif(image_file)
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue
        starttime = time.time()
        tackle_img_num += 1
        # Periodically rebuild the pipeline on CPU+MKLDNN — presumably a
        # memory-growth workaround; confirm before removing.
        if not args.use_gpu and args.enable_mkldnn and tackle_img_num % 30 == 0:
            text_sys = TextSystem(args)
        dt_boxes, rec_res = text_sys(img)
        elapse = time.time() - starttime
        print("Predict time of %s: %.3fs" % (image_file, elapse))
        drop_score = 0.5
        match_results = text_match(templates, dt_boxes, rec_res,
                                   drop_score=drop_score)
        remove_invalid_results(templates, match_results, dt_boxes, rec_res,
                               drop_score=drop_score)
        # Re-shape match results into {partition_type: {partition_id: [...]}}.
        ocr_results = []
        for match_result in match_results:
            ocr_result = {}
            for partition_type, match_dict in match_result.items():
                ocr_dict = {}
                if partition_type == 'head':
                    continue
                if partition_type == 'image_type':
                    # NOTE(review): this `continue` skips the
                    # `ocr_result[partition_type] = ocr_dict` assignment below,
                    # so the image_type entry never reaches ocr_result —
                    # confirm whether that is intended.
                    ocr_dict[partition_type] = match_dict
                    continue
                for partition_id, results in match_dict.items():
                    for result in results:
                        ocr_dict.setdefault(partition_id, [])
                        ocr_patition_dict = {
                            'text': result['text'],
                            # 'points': result['points'].tolist(),
                            'confidence': result['score']
                        }
                        ocr_dict[partition_id].append(ocr_patition_dict)
                ocr_result[partition_type] = ocr_dict
            ocr_results.append(ocr_result)
        print('ocr_results:', ocr_results)
        # dt_num = len(dt_boxes)
        # for dno in range(dt_num):
        #     dt_box = dt_boxes[dno]
        #     text, score = rec_res[dno]
        #     if score >= drop_score:
        #         text_str = "%s, %.3f" % (text, score)
        #         print(text_str, dt_box)
        #         post_process(templates, text, dt_box)
        if is_visualize:
            image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            boxes = dt_boxes
            txts = [rec_res[i][0] for i in range(len(rec_res))]
            scores = [rec_res[i][1] for i in range(len(rec_res))]
            draw_img = draw_ocr(image, boxes, txts, scores, draw_txt=True,
                                drop_score=drop_score)
            draw_img_save = "./inference_results/"
            if not os.path.exists(draw_img_save):
                os.makedirs(draw_img_save)
            cv2.imwrite(
                os.path.join(draw_img_save, os.path.basename(image_file)),
                draw_img[:, :, ::-1])
            print("The visualized image saved in {}".format(
                os.path.join(draw_img_save, os.path.basename(image_file))))
def main(args):
    """OCR a random sample of 20 images from a sharded file list, log the
    recognized lines, and optionally dump the text per image under
    ./final_results/20/.

    Fix: the per-image result file was opened with `open()` and never closed
    (leaked handle, possibly unflushed buffer) — now uses `with open(...)`.
    """
    image_file_list = get_image_file_list(args.image_dir)
    image_file_list = image_file_list[args.process_id::args.total_process_num]
    text_sys = TextSystem(args)
    is_visualize = False
    font_path = args.vis_font_path
    drop_score = args.drop_score
    num = 1
    loop_count = 20
    selected_imgs = random.sample(image_file_list, k=20)
    for image_file in selected_imgs:
        img, flag = check_and_read_gif(image_file)
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            # logger.info("error in loading image:{}".format(image_file))
            continue
        starttime = time.time()
        dt_boxes, rec_res = text_sys(img)
        elapse = time.time() - starttime
        # logger.info("Predict time of %s: %.3fs" % (image_file, elapse))
        for text, score in rec_res:
            logger.info("{}, {:.3f}".format(text, score))
        if args.is_save:
            dataset_dir = './final_results/20/'
            txts = [rec_res[i][0] for i in range(len(rec_res))]
            # Write the recognized lines to <stem>.txt; skip if it exists.
            print(image_file)
            path = os.path.join(
                dataset_dir,
                os.path.splitext(os.path.basename(image_file))[0])
            if os.path.exists(path + '.txt'):
                continue
            with open(path + '.txt', 'w', encoding="utf-8") as res_txt:
                for item in txts:
                    res_txt.write(item + '\n')
        if is_visualize:
            image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            boxes = dt_boxes
            txts = [rec_res[i][0] for i in range(len(rec_res))]
            scores = [rec_res[i][1] for i in range(len(rec_res))]
            draw_img = draw_ocr_box_txt(image, boxes, txts, scores,
                                        drop_score=drop_score,
                                        font_path=font_path)
            draw_img_save = "./inference_results/"
            if not os.path.exists(draw_img_save):
                os.makedirs(draw_img_save)
            cv2.imwrite(
                os.path.join(draw_img_save, os.path.basename(image_file)),
                draw_img[:, :, ::-1])
            logger.info("The visualized image saved in {}".format(
                os.path.join(draw_img_save, os.path.basename(image_file))))
        num = num + 1
        if num > loop_count:
            break
def main(args):
    """Full det+rec pipeline over a sharded file list with optional warmup,
    per-image visualization, total-time logging and benchmark reports."""
    image_file_list = get_image_file_list(args.image_dir)
    image_file_list = image_file_list[args.process_id::args.total_process_num]
    text_sys = TextSystem(args)
    is_visualize = True
    font_path = args.vis_font_path
    drop_score = args.drop_score
    # warm up 10 times
    if args.warmup:
        img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8)
        for i in range(10):
            res = text_sys(img)
    total_time = 0
    cpu_mem, gpu_mem, gpu_util = 0, 0, 0
    _st = time.time()
    count = 0
    for idx, image_file in enumerate(image_file_list):
        img, flag = check_and_read_gif(image_file)
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.debug("error in loading image:{}".format(image_file))
            continue
        starttime = time.time()
        dt_boxes, rec_res = text_sys(img)
        elapse = time.time() - starttime
        total_time += elapse
        logger.debug(
            str(idx) + " Predict time of %s: %.3fs" % (image_file, elapse))
        for text, score in rec_res:
            logger.debug("{}, {:.3f}".format(text, score))
        if is_visualize:
            image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            boxes = dt_boxes
            txts = [rec_res[i][0] for i in range(len(rec_res))]
            scores = [rec_res[i][1] for i in range(len(rec_res))]
            draw_img = draw_ocr_box_txt(image, boxes, txts, scores,
                                        drop_score=drop_score,
                                        font_path=font_path)
            draw_img_save_dir = args.draw_img_save_dir
            os.makedirs(draw_img_save_dir, exist_ok=True)
            # GIF inputs are saved with a .png extension instead.
            if flag:
                image_file = image_file[:-3] + "png"
            cv2.imwrite(
                os.path.join(draw_img_save_dir,
                             os.path.basename(image_file)),
                draw_img[:, :, ::-1])
            logger.debug("The visualized image saved in {}".format(
                os.path.join(draw_img_save_dir,
                             os.path.basename(image_file))))
    logger.info("The predict total time is {}".format(time.time() - _st))
    if args.benchmark:
        text_sys.text_detector.autolog.report()
        text_sys.text_recognizer.autolog.report()
def main(args):
    """Det+rec each image, dump an `assorted_results` JSON next to the input
    file (text boxes plus an empty fields skeleton for downstream key/value
    labelling), and save a visualization under ./inference_results/."""
    image_file_list = get_image_file_list(args.image_dir)
    text_sys = TextSystem(args)
    is_visualize = True
    font_path = args.vis_font_path
    drop_score = args.drop_score
    for image_file in image_file_list:
        img, flag = check_and_read_gif(image_file)
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue
        starttime = time.time()
        dt_boxes, rec_res = text_sys(img)
        elapse = time.time() - starttime
        logger.info("Predict time of %s: %.3fs" % (image_file, elapse))
        for text, score in rec_res:
            logger.info("{}, {:.3f}".format(text, score))
        if is_visualize:
            image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            boxes = dt_boxes
            # for box in boxes:
            #     xy_sum = np.sum(box, axis=0) / 4.0
            #     cx = xy_sum[0]
            #     cy = xy_sum[1]
            #     degree = np.arcsin((box[1][1] - box[0][1]) / (box[1][0] - box[0][0]))
            #     w = abs(box[0][0] - box[1][0])
            #     h = abs(box[0][1] - box[3][1])
            #     x1, y1, x2, y2, x3, y3, x4, y4 = xy_rotate_box(cx, cy, w, h, degree / 180 * np.pi)
            #     box[0][0] = x1
            #     box[0][1] = y1
            #     box[1][0] = x2
            #     box[1][1] = y2
            #     box[2][0] = x3
            #     box[2][1] = y3
            #     box[3][0] = x4
            #     box[3][1] = y4
            txts = [rec_res[i][0] for i in range(len(rec_res))]
            scores = [rec_res[i][1] for i in range(len(rec_res))]
            # bbox is [x1, y1, x3, y3] — the axis-aligned corners taken from
            # the quad's first and third points.
            assorted_results = {
                "text_boxes": [{'id': i + 1,
                                'bbox': [float(dt_boxes[i][0][0]),
                                         float(dt_boxes[i][0][1]),
                                         float(dt_boxes[i][2][0]),
                                         float(dt_boxes[i][2][1])],
                                'text': rec_res[i][0]}
                               for i in range(len(rec_res))],
                "fields": [{"field_name": "customer_number", "value_id": [],
                            "value_text": [], "key_id": [], "key_text": []},
                           {"field_name": "name", "value_id": [],
                            "value_text": [], "key_id": [], "key_text": []},
                           {"field_name": "address", "value_id": [],
                            "value_text": [], "key_id": [], "key_text": []},
                           {"field_name": "amount", "value_id": [],
                            "value_text": [], "key_id": [], "key_text": []},
                           {"field_name": "date", "value_id": [],
                            "value_text": [], "key_id": [], "key_text": []},
                           {"field_name": "content", "value_id": [],
                            "value_text": [], "key_id": [], "key_text": []}],
                "global_attributes": {
                    "file_id": image_file.split('/')[-1]}
            }
            with open(image_file + '.json', 'w', encoding='utf-8') as outfile:
                json.dump(assorted_results, outfile)
            #res = trainTicket.trainTicket(assorted_results, img=image)
            #res = res.res
            ##compare_img = clip_ground_truth_and_draw_txt(cv2.cvtColor(img, cv2.COLOR_BGR2GRAY), boxes, txts, scores, font_path=font_path)
            draw_img = draw_ocr_box_txt(
                image,
                boxes,
                txts,
                scores,
                drop_score=drop_score,
                font_path=font_path)
            draw_img_save = "./inference_results/"
            if not os.path.exists(draw_img_save):
                os.makedirs(draw_img_save)
            cv2.imwrite(
                os.path.join(draw_img_save, os.path.basename(image_file)),
                draw_img[:, :, ::-1])
            # cv2.imwrite(
            #     os.path.join(draw_img_save, os.path.basename(image_file)),
            #     compare_img)
            logger.info("The visualized image saved in {}".format(
                os.path.join(draw_img_save, os.path.basename(image_file))))
def main(args):
    """OCR each image, blank out the detected text regions, stitch the edited
    images plus recognized text into an HTML page, and render it to PDF via
    wkhtmltopdf.

    NOTE(review): hard-coded /usr/web paths, a fixed wkhtmltopdf binary path
    and a fixed public URL are embedded below — confirm the deployment
    environment before reuse.
    """
    image_file_list = get_image_file_list(args.image_dir)
    # print(image_file_list)
    text_sys = TextSystem(args)
    is_visualize = True
    font_path = args.vis_font_path
    message = """"""
    textList = []
    height_sum = 0
    blank = 20  # pixel gap between consecutive images in the output page
    if not os.path.exists('/usr/web/outputs'):
        os.makedirs('/usr/web/outputs')
    for image_file in image_file_list:
        img, flag = check_and_read_gif(image_file)
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue
        starttime = time.time()
        dt_boxes, rec_res = text_sys(img)
        elapse = time.time() - starttime
        print("Predict time of %s: %.3fs" % (image_file, elapse))
        # print("dt_boxes:", dt_boxes)
        # print()
        # print("rec_res:", rec_res)
        # drop_score = 0.5
        # dt_num = len(dt_boxes)
        # for dno in range(dt_num):
        #     text, score = rec_res[dno]
        #     if score >= drop_score:
        #         text_str = "%s, %.3f" % (text, score)
        #         print(text_str)
        if is_visualize:
            image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            w = image.width
            h = image.height
            image = image.resize((w, h))
            image = np.array(image)
            boxes = dt_boxes  # detected boxes (four corner points each)
            txts = [rec_res[i][0] for i in range(len(rec_res))]  # recognized strings
            print(boxes)
            print(txts)
            # scores = [rec_res[i][1] for i in range(len(rec_res))]
            print(image.shape)
            # Fill each detected text rectangle with the pixel just right of
            # its top-right corner, effectively erasing the original text.
            for t in range(len(boxes)):
                image[int(round(boxes[t][0][1])):int(round(boxes[t][3][1])),
                      int(round(boxes[t][0][0])):int(round(boxes[t][1][0])), :] = \
                    image[int(round(boxes[t][1][1]))][int(round(boxes[t][1][0]))]
            img = Image.fromarray(image)
            savename = os.path.join("/usr/web/modify_input",
                                    image_file.split("/")[-1])
            img.save(savename)
            webname = os.path.join("./modify_input", image_file.split("/")[-1])
            # Collect %s placeholders (filled at the end via one big format).
            textList.append(str(height_sum))
            textList.append(str(webname))
            message = message + """ <div style="position:absolute; left:0px; top:%spx;"> <img src="%s"/> </div>"""
            for x in range(len(txts)):
                textList.append(str((boxes[x][2][1] - boxes[x][0][1]) * 0.8))  # font-size correction factor
                textList.append(str(boxes[x][0][0]))
                textList.append(str(boxes[x][0][1] + height_sum))
                textList.append(txts[x])
                message = message + """ <p style="font-size:%spx"> <a style="position:absolute; left:%spx; top:%spx;">%s</a> </p>"""
            height_sum = height_sum + h + blank
    start = """<!DOCTYPE html> <head> <meta charset="UTF-8"> </head> <body>"""
    end = """ </body> </html> """
    message = (start + message + end) % tuple(textList)
    GEN_HTML = "/usr/web/test.html"
    # Open the output file for writing
    f = open(GEN_HTML, 'w')
    # Write the assembled page
    f.write(message)
    # Close the file
    f.close()
    # Render the page to PDF for the response
    confg = pdfkit.configuration(
        wkhtmltopdf='/root/ai_competition/wkhtmltopdf/usr/local/bin/wkhtmltopdf')
    # Path to the wkhtmltopdf binary must match the deployment host
    pdfkit.from_url('39.100.154.163/test.html', '/usr/web/test.pdf',
                    configuration=confg)
def main(args):
    """Walk args.image_dir recursively, OCR every file, locate the "姓名"
    (name) field heuristically, dump per-file JSON transcripts, and save
    visualizations with the name region blacked out."""
    for root, dirs, files in os.walk(args.image_dir):
        for file in files:
            # NOTE(review): `f` is opened but never used below — only the
            # path string matters; confirm whether this open is needed.
            with open(root + "/" + file, "r") as f:
                imgpath = root + "/" + file
                image_file_list = get_image_file_list(imgpath)
                # NOTE(review): a fresh TextSystem is constructed per file —
                # expensive; confirm whether it can be hoisted out.
                text_sys = TextSystem(args)
                is_visualize = True
                font_path = args.vis_font_path
                for image_file in image_file_list:
                    img, flag = check_and_read_gif(image_file)
                    if not flag:
                        img = cv2.imread(image_file)
                    if img is None:
                        logger.info(
                            "error in loading image:{}".format(image_file))
                        continue
                    starttime = time.time()
                    # Alternative pre-filters kept for experimentation:
                    # img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
                    # img_filter = cv2.bilateralFilter(img_gray, 5, 100, 100)
                    # img_filter = cv2.blur(img_gray, (5, 5))
                    # img_filter = cv2.GaussianBlur(img_gray, (5, 5), 0)
                    # img_filter = cv2.medianBlur(img_gray, 5)
                    # img_filter = cv2.cvtColor(img_filter, cv2.COLOR_GRAY2BGR)
                    # dt_boxes, rec_res = text_sys(img_filter)
                    dt_boxes, rec_res = text_sys(img)
                    roi = ()
                    name_box = np.empty(shape=(4, 2))
                    # Heuristic: find the box holding the name. When the
                    # label ("姓名"/"名") sits in its own short box, take the
                    # NEXT box as the value; otherwise take the label box
                    # itself. The chosen entry is removed from the results.
                    for i, val in enumerate(rec_res):
                        if "姓名" in val[0]:
                            if len(val[0]) < 3:
                                name_box = dt_boxes[i + 1]
                                rec_res.pop(i + 1)
                                dt_boxes.pop(i + 1)
                                break
                            else:
                                name_box = dt_boxes[i]
                                rec_res.pop(i)
                                dt_boxes.pop(i)
                                break
                        if "名" in val[0]:
                            if len(val[0]) <= 2:
                                name_box = dt_boxes[i + 1]
                                rec_res.pop(i + 1)
                                dt_boxes.pop(i + 1)
                                break
                            elif len(val[0]) <= 6:
                                name_box = dt_boxes[i]
                                rec_res.pop(i)
                                dt_boxes.pop(i)
                                break
                    # ROI as (left, top, right, bottom) from opposite corners.
                    # NOTE(review): when no name is found this uses the
                    # uninitialized np.empty default — confirm intended.
                    roi = (int(name_box[0][0]), int(name_box[0][1]),
                           int(name_box[2][0]), int(name_box[2][1]))
                    elapse = time.time() - starttime
                    print("Predict time of %s: %.3fs" % (image_file, elapse))
                    drop_score = 0.5
                    json_img_save = "./inference_results_json/" + \
                        root.replace(args.image_dir+"\\", "")
                    if not os.path.exists(json_img_save):
                        os.makedirs(json_img_save)
                    with open(json_img_save + "/" + file.replace(".jpg", "")
                              + ".json", 'w', encoding='utf-8') as file_obj:
                        ans_json = {'data': [{'str': i[0]} for i in rec_res]}
                        json.dump(ans_json, file_obj, indent=4,
                                  ensure_ascii=False)
                    if is_visualize:
                        image = Image.fromarray(
                            cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
                        boxes = dt_boxes
                        txts = [rec_res[i][0] for i in range(len(rec_res))]
                        scores = [rec_res[i][1] for i in range(len(rec_res))]
                        draw_img = draw_ocr_box_txt(image, boxes, txts,
                                                    scores,
                                                    drop_score=args.drop_score,
                                                    font_path=font_path)
                        # Black out the detected name region before saving.
                        draw_img.paste((0, 0, 0), roi)
                        draw_img = np.array(draw_img)
                        draw_img_save = "./inference_results/" + \
                            root.replace(args.image_dir+"\\", "")
                        if not os.path.exists(draw_img_save):
                            os.makedirs(draw_img_save)
                        cv2.imwrite(
                            os.path.join(draw_img_save,
                                         os.path.basename(image_file)),
                            draw_img[:, :, ::-1])
                        print("The visualized image saved in {}".format(
                            os.path.join(draw_img_save,
                                         os.path.basename(image_file))))
# NOTE(review): the three statements below are the tail of a detector method
# whose `def` lies outside this chunk — indentation is reconstructed; confirm
# the enclosing signature against the full file.
        dt_boxes = self.filter_tag_det_res_only_clip(points, ori_im.shape)
        elapse = time.time() - starttime
        return dt_boxes, strs, elapse


# Script entry point: run end-to-end text spotting on each image and save
# "e2e_res_*" visualizations under ./inference_results.
if __name__ == "__main__":
    args = utility.parse_args()
    image_file_list = get_image_file_list(args.image_dir)
    text_detector = TextE2E(args)
    count = 0
    total_time = 0
    draw_img_save = "./inference_results"
    if not os.path.exists(draw_img_save):
        os.makedirs(draw_img_save)
    for image_file in image_file_list:
        img, flag = check_and_read_gif(image_file)
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue
        points, strs, elapse = text_detector(img)
        # The first image's time is excluded from the running total
        # (presumably treated as warm-up).
        if count > 0:
            total_time += elapse
        count += 1
        logger.info("Predict time of {}: {}".format(image_file, elapse))
        src_im = utility.draw_e2e_res(points, strs, image_file)
        img_name_pure = os.path.split(image_file)[-1]
        img_path = os.path.join(draw_img_save,
                                "e2e_res_{}".format(img_name_pure))
        cv2.imwrite(img_path, src_im)
def main(args):
    """Full det+rec pipeline with optional warmup and visualization; collects
    all above-threshold predictions and saves them via save_results_to_txt."""
    image_file_list = get_image_file_list(args.image_dir)
    image_file_list = image_file_list[args.process_id::args.total_process_num]
    text_sys = TextSystem(args)
    is_visualize = args.is_visualize
    font_path = args.vis_font_path
    drop_score = args.drop_score
    # warm up 10 times
    if args.warmup:
        img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8)
        for i in range(10):
            res = text_sys(img)
    total_time = 0
    cpu_mem, gpu_mem, gpu_util = 0, 0, 0
    _st = time.time()
    count = 0
    save_res = []
    for idx, image_file in enumerate(image_file_list):
        img, flag = check_and_read_gif(image_file)
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.debug("error in loading image:{}".format(image_file))
            continue
        starttime = time.time()
        dt_boxes, rec_res = text_sys(img)
        elapse = time.time() - starttime
        total_time += elapse
        # save results
        preds = []
        dt_num = len(dt_boxes)
        for dno in range(dt_num):
            text, score = rec_res[dno]
            if score >= drop_score:
                preds.append({
                    "transcription": text,
                    "points": np.array(dt_boxes[dno]).tolist()
                })
                # NOTE(review): text_str is computed but never used.
                text_str = "%s, %.3f" % (text, score)
        save_res.append(image_file + '\t' + json.dumps(
            preds, ensure_ascii=False) + '\n')
        # print predicted results
        logger.debug(
            str(idx) + " Predict time of %s: %.3fs" % (image_file, elapse))
        for text, score in rec_res:
            logger.debug("{}, {:.3f}".format(text, score))
        if is_visualize:
            image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            boxes = dt_boxes
            txts = [rec_res[i][0] for i in range(len(rec_res))]
            scores = [rec_res[i][1] for i in range(len(rec_res))]
            draw_img = draw_ocr_box_txt(image, boxes, txts, scores,
                                        drop_score=drop_score,
                                        font_path=font_path)
            draw_img_save_dir = args.draw_img_save_dir
            os.makedirs(draw_img_save_dir, exist_ok=True)
            # GIF inputs are saved with a .png extension instead.
            if flag:
                image_file = image_file[:-3] + "png"
            cv2.imwrite(
                os.path.join(draw_img_save_dir,
                             os.path.basename(image_file)),
                draw_img[:, :, ::-1])
            logger.debug("The visualized image saved in {}".format(
                os.path.join(draw_img_save_dir,
                             os.path.basename(image_file))))
    # The predicted results will be saved in
    # os.path.join(args.draw_img_save_dir, "results.txt")
    save_results_to_txt(save_res, args.draw_img_save_dir)
    logger.info("The predict total time is {}".format(time.time() - _st))
    if args.benchmark:
        text_sys.text_detector.autolog.report()
        text_sys.text_recognizer.autolog.report()