コード例 #1
0
def rotate_90_infer():
    """Detect text on a 90-degree-rotated copy of each image.

    Every image in ``args.img_path`` is rotated 90 degrees counter-clockwise,
    run through the detector, and the resulting boxes are rotated clockwise
    back into the original image's coordinate frame before being drawn and
    saved under ``test_result/``.
    """
    shutil.rmtree("test_result", ignore_errors=True)
    args = init_args()

    model = DetInfer(args.model_path)
    names = os.listdir(args.img_path)
    os.makedirs('test_result', exist_ok=True)  # create once, not per image
    tic = time.time()
    for name in tqdm(names):
        img_path = os.path.join(args.img_path, name)
        img = cv2.imread(img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # Rotate 90 degrees counter-clockwise before detection.
        img1 = np.ascontiguousarray(np.rot90(img))
        box_list2, score_list2 = model.predict(img1, is_output_polygon=False)
        # Rotate each detected box clockwise back to the original frame.
        box2 = [rotation_point(img1, -90, point=box)[1] for box in box_list2]
        img = draw_bbox(img, box2)

        # draw_bbox worked in RGB; flip channels back to BGR for cv2.imwrite.
        img = img[:, :, ::-1]
        cv2.imwrite(filename=f'test_result/result_{name}', img=img)
    if names:  # avoid ZeroDivisionError on an empty image directory
        print(f'avg infer image in {(time.time() - tic) / len(names):.4f}s')
コード例 #2
0
ファイル: demowithUI.py プロジェクト: ttthomaschan/PytorchOCR
    def recognition_clicked(self):
        """Run text detection + recognition on ``self.imgPath``.

        Saves an annotated image (``<name>_bbox.jpg``) and a text dump of
        box corners + recognized strings (``<name>_bbox.txt``) under
        ``self.resPath``, and returns ``(box_list, rec_texts)``.
        """
        # Prepare models
        modeldet_path = '/home/elimen/Data/dbnet_pytorch/checkpoints/ch_det_server_db_res18.pth'
        modelrec_path = '/home/elimen/Data/dbnet_pytorch/checkpoints/ch_rec_server_crnn_res34.pth'
        modeldet = DetInfer(modeldet_path)
        modelrec = RecInfer(modelrec_path)
        self.srcImg = cv2.imread(self.imgPath)
        img_bak = self.srcImg.copy()
        ## Detection
        box_list, score_list = modeldet.predict(self.srcImg,
                                                is_output_polygon=False)
        self.srcImg = cv2.cvtColor(self.srcImg, cv2.COLOR_BGR2RGB)
        self.boxImg = draw_bbox(self.srcImg, box_list)

        # NOTE(review): split('.')[1] keeps the token AFTER the first dot —
        # for a simple name like 'mt03.jpg' that is the extension, not the
        # basename. Looks like split('.')[0] was intended; confirm against
        # the naming scheme actually used before changing it.
        res_name = self.imgPath.split('/')[-1].split('.')[1]
        cv2.imwrite(self.resPath + res_name + '_bbox.jpg', self.boxImg)

        ## Recognition: crop each detected box out of the untouched BGR copy
        ## and run it through the recognizer.
        imgcroplist = []
        bbox_cornerlist = []
        txt_file = os.path.join(self.resPath, res_name + '_bbox.txt')
        # Context manager so the file is closed even if recognition raises.
        with open(txt_file, 'w') as txt_f:
            for box in box_list:
                pt0, pt1, pt2, pt3 = box

                # Clamp the 4-pixel margin at 0: a box touching the top/left
                # edge would otherwise produce a negative slice index, which
                # wraps around in numpy and yields an empty crop.
                y0 = max(0, int(min(pt0[1], pt1[1])) - 4)
                y1 = int(max(pt2[1], pt3[1])) + 4
                x0 = max(0, int(min(pt0[0], pt3[0])) - 4)
                x1 = int(max(pt1[0], pt2[0])) + 4
                imgcroplist.append(img_bak[y0:y1, x0:x1])

                bbox_cornerlist.append(
                    [int(pt0[0]), int(pt0[1]), int(pt2[0]), int(pt2[1])])
            bbox_cornerlist.reverse()

            # NOTE(review): bbox_cornerlist was reversed above but is indexed
            # with the same reversed loop index i below, so each txt line
            # pairs crop i's text with crop (n-1-i)'s corners. That looks
            # like an ordering mismatch — confirm intent before changing.
            self.rec_cont = []
            for i in range(len(imgcroplist) - 1, -1, -1):
                out = modelrec.predict(imgcroplist[i])
                self.rec_cont.append(out[0][0])

                txt_f.write(str(bbox_cornerlist[i]))
                txt_f.write(out[0][0] + '\n')
        # Recognition and generation of table (temporarily disabled):
        # tab_rec = TabRecognition(img_bak)
        # crop_list,height_list, width_list= tab_rec.detnrec()
        # self.resultname = res_name + '.xlsx'
        # generateExcelFile(self.resPath,self.resultname,bbox_cornerlist,self.rec_cont,crop_list,height_list,width_list)

        return box_list, self.rec_cont
コード例 #3
0
def multi_scale_infer():
    """Multi-orientation detection.

    Runs the detector on each image plus its +/-90-degree rotations, maps
    all boxes back into the original frame, merges them with polygon NMS,
    then writes a result txt and an annotated image to ``test_result/``.
    """
    shutil.rmtree("test_result", ignore_errors=True)
    args = init_args()

    model = DetInfer(args.model_path)
    names = os.listdir(args.img_path)
    os.makedirs('test_result', exist_ok=True)  # create once, not per image
    tic = time.time()
    for name in tqdm(names):
        img_path = os.path.join(args.img_path, name)
        img = cv2.imread(img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        box_list1, score_list1 = model.predict(img, is_output_polygon=False)

        # Rotate 90 degrees counter-clockwise, detect, rotate boxes back.
        img1 = np.ascontiguousarray(np.rot90(img))
        box_list2, score_list2 = model.predict(img1, is_output_polygon=False)
        box2 = [rotation_point(img1, -90, point=p)[1] for p in box_list2]

        # Rotate 90 degrees clockwise, detect, rotate boxes back.
        img2 = np.ascontiguousarray(np.rot90(img, -1))
        box_list3, score_list3 = model.predict(img2, is_output_polygon=False)
        # Fix: boxes from img2 must be mapped back using img2's frame.
        # The original passed img1 here; both rotations happen to share the
        # same shape so the result was unchanged, but the intent is img2.
        box3 = [rotation_point(img2, 90, point=p)[1] for p in box_list3]

        # Merge all orientations and suppress duplicates with polygon NMS.
        box_list = box_list1 + box2 + box3
        score_list = score_list1 + score_list2 + score_list3
        keep = py_cpu_pnms(np.array(box_list), np.array(score_list), thresh=0.1)
        box_list = [box_list[i] for i in keep]
        score_list = [score_list[i] for i in keep]
        write_txt_file(box_list, score_list,
                       img_idx=name.split('.')[0].split('_')[-1])
        img = draw_bbox(img, box_list)

        img = img[:, :, ::-1]  # RGB -> BGR for cv2.imwrite
        cv2.imwrite(filename=f'test_result/result_{name}', img=img)
    if names:  # avoid ZeroDivisionError on an empty image directory
        print(f'avg infer image in {(time.time() - tic) / len(names):.4f}s')
コード例 #4
0
def single_scale_infer():
    """Run detection once per image in ``args.img_path`` and save an
    annotated copy of each under ``test_result/``."""
    shutil.rmtree("test_result", ignore_errors=True)
    args = init_args()
    model = DetInfer(args.model_path)

    start = time.time()
    for image_name in tqdm(os.listdir(args.img_path)):
        frame = cv2.imread(os.path.join(args.img_path, image_name))
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        boxes, scores = model.predict(frame, is_output_polygon=False)
        frame = draw_bbox(frame, boxes)

        os.makedirs('test_result', exist_ok=True)
        # draw_bbox worked in RGB; flip channels back to BGR for imwrite.
        cv2.imwrite(filename=f'test_result/result_{image_name}',
                    img=frame[:, :, ::-1])
    elapsed = time.time() - start
    print(f'avg infer image in {elapsed / len(os.listdir(args.img_path)):.4f}s')
コード例 #5
0
ファイル: det_infer.py プロジェクト: Bourne-M/PytorchOCR

def init_args():
    """Build and parse the command-line arguments for detector inference."""
    import argparse

    parser = argparse.ArgumentParser(description='PytorchOCR infer')
    parser.add_argument('--model_path', required=True, type=str,
                        help='rec model path')
    parser.add_argument('--img_path', required=True, type=str,
                        help='img dir for predict')
    return parser.parse_args()


if __name__ == '__main__':
    import cv2
    import time
    from matplotlib import pyplot as plt
    from torchocr.utils import draw_bbox

    args = init_args()

    model = DetInfer(args.model_path)
    # Only plain files directly inside img_path (no subdirectories).
    names = next(os.walk(args.img_path))[2]
    # cv2.imwrite silently returns False when the target directory does not
    # exist, so create the 'res' output directory up front.
    os.makedirs(os.path.join(args.img_path, 'res'), exist_ok=True)
    st = time.time()
    for name in names:
        path = os.path.join(args.img_path, name)
        img = cv2.imread(path)
        box_list, score_list = model.predict(img)
        out_path = os.path.join(args.img_path, 'res', name)
        img = draw_bbox(img, box_list)
        # splitext handles any extension length; the original sliced off a
        # fixed 4 characters, which assumed a 3-letter extension.
        cv2.imwrite(os.path.splitext(out_path)[0] + '_res.jpg', img)
    if names:  # avoid ZeroDivisionError on an empty directory
        print((time.time() - st) / len(names))
コード例 #6
0
if __name__ == '__main__':
    import torch
    from torch.utils.data import DataLoader
    from config.det_train_db_config import config
    from torchocr.utils import show_img, draw_bbox

    from matplotlib import pyplot as plt

    # Visual sanity check: iterate the training set and display each image
    # alongside its shrink map, threshold map, and ground-truth polygons.
    dataset = JsonDataset(config.dataset.train.dataset)
    train_loader = DataLoader(dataset=dataset,
                              batch_size=1,
                              shuffle=True,
                              num_workers=0)
    for i, data in enumerate(tqdm(train_loader)):
        img = data['img']
        shrink_label = data['shrink_map']
        threshold_label = data['threshold_map']

        # Fix: the original printed threshold_label.shape twice; the first
        # entry was presumably meant to be shrink_label.shape.
        print(shrink_label.shape, threshold_label.shape, img.shape)
        show_img(img[0].numpy().transpose(1, 2, 0), title='img')
        show_img((shrink_label[0].to(torch.float)).numpy(),
                 title='shrink_label')
        show_img((threshold_label[0].to(torch.float)).numpy(),
                 title='threshold_label')
        img = draw_bbox(img[0].numpy().transpose(1, 2, 0),
                        np.array(data['text_polys']))
        show_img(img, title='draw_bbox')
        plt.show()