def rotate_90_infer():
    """Detect text after rotating each image 90 deg counter-clockwise,
    map the boxes back to the original orientation, and save the
    visualizations under ``test_result/``.

    Reads ``--model_path`` / ``--img_path`` via :func:`init_args`.
    """
    shutil.rmtree("test_result", ignore_errors=True)
    args = init_args()
    model = DetInfer(args.model_path)
    # Create the output directory once, not on every loop iteration.
    os.makedirs('test_result', exist_ok=True)
    names = os.listdir(args.img_path)
    tic = time.time()
    for name in tqdm(names):
        img_path = os.path.join(args.img_path, name)
        img = cv2.imread(img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # Rotate 90 deg counter-clockwise before detection.
        img1 = np.ascontiguousarray(np.rot90(img))
        box_list2, score_list2 = model.predict(img1, is_output_polygon=False)
        # Rotate each detected box back (clockwise) into the original frame.
        box2 = [rotation_point(img1, -90, point=b)[1] for b in box_list2]
        img = draw_bbox(img, box2)
        img = img[:, :, ::-1]  # RGB -> BGR so cv2.imwrite stores correct colors
        cv2.imwrite(filename=f'test_result/result_{name}', img=img)
    # names was captured once above instead of re-listing the directory here.
    print(f'avg infer image in {(time.time() - tic) / len(names):.4f}s')
def recognition_clicked(self):
    """Run text detection + recognition on ``self.imgPath``.

    Saves a bbox visualization image and a ``<name>_bbox.txt`` file (one
    line per box: corner coords followed by the recognized text, bottom
    box first) under ``self.resPath``.

    Returns:
        tuple: ``(box_list, self.rec_cont)`` — detected boxes in detector
        order and recognized strings in reversed (bottom-to-top) order.
    """
    # Prepare models.
    # NOTE(review): hard-coded checkpoint paths — consider making these configurable.
    modeldet_path = '/home/elimen/Data/dbnet_pytorch/checkpoints/ch_det_server_db_res18.pth'
    modelrec_path = '/home/elimen/Data/dbnet_pytorch/checkpoints/ch_rec_server_crnn_res34.pth'
    modeldet = DetInfer(modeldet_path)
    modelrec = RecInfer(modelrec_path)
    self.srcImg = cv2.imread(self.imgPath)
    img_bak = self.srcImg.copy()
    ## Detection
    box_list, score_list = modeldet.predict(self.srcImg, is_output_polygon=False)
    self.srcImg = cv2.cvtColor(self.srcImg, cv2.COLOR_BGR2RGB)
    self.boxImg = draw_bbox(self.srcImg, box_list)
    # BUG FIX: use the base filename (index 0), not the extension (index 1),
    # so outputs are named e.g. 'mt03_bbox.jpg' rather than 'jpg_bbox.jpg'.
    res_name = self.imgPath.split('/')[-1].split('.')[0]
    cv2.imwrite(self.resPath + res_name + '_bbox.jpg', self.boxImg)
    ## Recognition
    # Crop each detected box with a 4px margin, clamped at 0 so a negative
    # index never wraps the slice to the other side of the image.
    imgcroplist = []
    bbox_cornerlist = []
    for box in box_list:
        pt0, pt1, pt2, pt3 = box
        top = max(0, int(min(pt0[1], pt1[1])) - 4)
        bottom = int(max(pt2[1], pt3[1])) + 4
        left = max(0, int(min(pt0[0], pt3[0])) - 4)
        right = int(max(pt1[0], pt2[0])) + 4
        imgcroplist.append(img_bak[top:bottom, left:right])
        bbox_cornerlist.append([int(pt0[0]), int(pt0[1]), int(pt2[0]), int(pt2[1])])
    # Reversed so bbox_cornerlist lines up with self.rec_cont (filled below
    # from the last crop to the first).
    bbox_cornerlist.reverse()
    self.rec_cont = []
    txt_file = os.path.join(self.resPath, res_name + '_bbox.txt')
    # Context manager guarantees the file is closed even if predict() raises.
    with open(txt_file, 'w') as txt_f:
        for j, crop in enumerate(reversed(imgcroplist)):
            out = modelrec.predict(crop)
            self.rec_cont.append(out[0][0])
            # BUG FIX: bbox_cornerlist was reversed above, so index j matches
            # this (reversed) crop; the original indexed the reversed corner
            # list with the forward crop index, pairing texts with the wrong
            # box coordinates in the txt output.
            txt_f.write(str(bbox_cornerlist[j]))
            txt_f.write(out[0][0] + '\n')
    '''Recognition and Generation of table'''
    ## Temporarily comment
    # tab_rec = TabRecognition(img_bak)
    # crop_list,height_list, width_list= tab_rec.detnrec()
    # self.resultname = res_name + '.xlsx'
    # generateExcelFile(self.resPath,self.resultname,bbox_cornerlist,self.rec_cont,crop_list,height_list,width_list)
    return box_list, self.rec_cont
def multi_scale_infer():
    """Detect text at three orientations (original, 90 deg CCW, 90 deg CW),
    map all boxes back to the original frame, merge them with polygon NMS,
    dump boxes/scores to txt, and save visualizations under ``test_result/``.
    """
    shutil.rmtree("test_result", ignore_errors=True)
    args = init_args()
    model = DetInfer(args.model_path)
    # Create the output directory once, not on every loop iteration.
    os.makedirs('test_result', exist_ok=True)
    names = os.listdir(args.img_path)
    tic = time.time()
    for name in tqdm(names):
        img_path = os.path.join(args.img_path, name)
        img = cv2.imread(img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # Pass 1: original orientation.
        box_list1, score_list1 = model.predict(img, is_output_polygon=False)
        # Pass 2: 90 deg counter-clockwise, then rotate boxes back clockwise.
        img1 = np.ascontiguousarray(np.rot90(img))
        box_list2, score_list2 = model.predict(img1, is_output_polygon=False)
        box2 = [rotation_point(img1, -90, point=p)[1] for p in box_list2]
        # Pass 3: 90 deg clockwise, then rotate boxes back counter-clockwise.
        img2 = np.ascontiguousarray(np.rot90(img, -1))
        box_list3, score_list3 = model.predict(img2, is_output_polygon=False)
        # BUG FIX: box_list3 lives in img2's coordinate frame, so the rotation
        # back must reference img2 — the original passed img1 here (the two
        # frames happen to share a shape, which masked the mistake).
        box3 = [rotation_point(img2, 90, point=p)[1] for p in box_list3]
        box_list = box_list1 + box2 + box3
        score_list = score_list1 + score_list2 + score_list3
        # Polygon NMS across all three passes; keep the surviving indices.
        keep = py_cpu_pnms(np.array(box_list), np.array(score_list), thresh=0.1)
        box_list = [box_list[i] for i in keep]
        score_list = [score_list[i] for i in keep]
        write_txt_file(box_list, score_list,
                       img_idx=name.split('.')[0].split('_')[-1])
        img = draw_bbox(img, box_list)
        img = img[:, :, ::-1]  # RGB -> BGR so cv2.imwrite stores correct colors
        cv2.imwrite(filename=f'test_result/result_{name}', img=img)
    print(f'avg infer image in {(time.time() - tic) / len(names):.4f}s')
def single_scale_infer():
    """Run detection on every image in ``--img_path`` at its native scale
    and write bbox visualizations to ``test_result/``."""
    shutil.rmtree("test_result", ignore_errors=True)
    args = init_args()
    model = DetInfer(args.model_path)
    # Create the output directory once, not on every loop iteration.
    os.makedirs('test_result', exist_ok=True)
    names = os.listdir(args.img_path)
    tic = time.time()
    for name in tqdm(names):
        img_path = os.path.join(args.img_path, name)
        img = cv2.imread(img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        box_list, score_list = model.predict(img, is_output_polygon=False)
        img = draw_bbox(img, box_list)
        img = img[:, :, ::-1]  # RGB -> BGR so cv2.imwrite stores correct colors
        cv2.imwrite(filename=f'test_result/result_{name}', img=img)
    # names was captured once above instead of re-listing the directory here.
    print(f'avg infer image in {(time.time() - tic) / len(names):.4f}s')
def init_args():
    """Parse command-line arguments for detection inference.

    Returns:
        argparse.Namespace: with ``model_path`` (detection checkpoint file)
        and ``img_path`` (directory of images to run on).
    """
    import argparse
    parser = argparse.ArgumentParser(description='PytorchOCR infer')
    # BUG FIX: this script loads a *detection* model (DetInfer), so the help
    # text said 'rec model path' incorrectly.
    parser.add_argument('--model_path', required=True, type=str, help='det model path')
    parser.add_argument('--img_path', required=True, type=str, help='img dir for predict')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    import cv2
    import time
    from matplotlib import pyplot as plt
    from torchocr.utils import draw_bbox

    args = init_args()
    model = DetInfer(args.model_path)
    names = next(os.walk(args.img_path))[2]
    # BUG FIX: the 'res' output directory was never created, so cv2.imwrite
    # silently failed when it did not already exist.
    res_dir = os.path.join(args.img_path, 'res')
    os.makedirs(res_dir, exist_ok=True)
    st = time.time()
    for name in names:
        path = os.path.join(args.img_path, name)
        img = cv2.imread(path)
        box_list, score_list = model.predict(img)
        out_path = os.path.join(res_dir, name)
        img = draw_bbox(img, box_list)
        cv2.imwrite(out_path[:-4] + '_res.jpg', img)
    print((time.time() - st) / len(names))
if __name__ == '__main__':
    # Visual sanity check for the DBNet training dataset: iterate the
    # train loader and display the image, shrink map, threshold map and
    # ground-truth polygons for each sample.
    import torch
    from torch.utils.data import DataLoader
    from config.det_train_db_config import config
    from torchocr.utils import show_img, draw_bbox
    from matplotlib import pyplot as plt

    dataset = JsonDataset(config.dataset.train.dataset)
    train_loader = DataLoader(dataset=dataset, batch_size=1, shuffle=True, num_workers=0)
    for i, data in enumerate(tqdm(train_loader)):
        img = data['img']
        shrink_label = data['shrink_map']
        threshold_label = data['threshold_map']
        # BUG FIX: print both label shapes — the original printed
        # threshold_label.shape twice and never shrink_label.shape.
        print(shrink_label.shape, threshold_label.shape, img.shape)
        show_img(img[0].numpy().transpose(1, 2, 0), title='img')
        show_img((shrink_label[0].to(torch.float)).numpy(), title='shrink_label')
        show_img((threshold_label[0].to(torch.float)).numpy(), title='threshold_label')
        img = draw_bbox(img[0].numpy().transpose(1, 2, 0), np.array(data['text_polys']))
        show_img(img, title='draw_bbox')
        plt.show()