Beispiel #1
0
def main():
    """Run text recognition on every image under ``args.img_path`` and save
    the visualized predictions to ``args.output``.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)
    # Weights come from the inferer, not from a pretrained URL.
    cfg.model.pretrained = None
    model = RecInfer(cfg, args)
    img_list = load_data(args.img_path)

    font_ttf = "test/STKAITI.TTF"  # font file used for visualization
    font = ImageFont.truetype(font_ttf, 20)  # typeface and font size

    # NOTE: the original assigned `output = args.output` twice; once is enough.
    output = args.output
    if not os.path.exists(output):
        os.makedirs(output)

    for file in img_list:
        img = cv2.imread(file)
        base_name = os.path.basename(file)
        # OpenCV loads BGR; the model expects RGB.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        rec_result = model.predict(img)
        rec_str, rec_score, rec_score_list = rec_result[0]
        out_path = os.path.join(output, base_name)
        img_dst = visual_result(img, rec_str, label=None, font=font)
        cv2.imwrite(out_path, img_dst)
Beispiel #2
0
def main():
    """Run text detection on every image under ``args.img_path``.

    Images with at least one detected box get the boxes drawn in red
    (BGR ``(0, 0, 255)``); images with no detections are copied through
    to the output directory unchanged.
    """
    import cv2
    from torchocr.utils.vis import draw_bbox
    # NOTE: the original also imported matplotlib.pyplot but never used it.

    args = parse_args()
    cfg = Config.fromfile(args.config)
    # Weights come from the inferer, not from a pretrained URL.
    cfg.model.pretrained = None
    model = DetInfer(cfg, args)

    img_list = load_data(args.img_path)

    output = args.output
    if not os.path.exists(output):
        os.makedirs(output)

    for file in tqdm(img_list):
        ori_img = cv2.imread(file)
        base_name = os.path.basename(file)
        # OpenCV loads BGR; the model expects RGB.
        img = cv2.cvtColor(ori_img, cv2.COLOR_BGR2RGB)
        box_list, score_list = model.predict(img)
        rec_path = os.path.join(output, base_name)
        if len(box_list) > 0:
            res_img = draw_bbox(ori_img,
                                box_list,
                                color=(0, 0, 255),
                                thickness=2)
            cv2.imwrite(rec_path, res_img)
        else:
            # No detections: keep the original image in the output dir.
            shutil.copy(file, rec_path)
Beispiel #3
0
def main():
    """Re-parameterize a trained RepVGG-backbone DBNet into its deploy form
    and save the fused checkpoint.
    """
    import sys, os
    sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../'))
    from torchocr.models import build_model
    from torchocr.utils.config_util import Config
    from torchocr.utils.checkpoints import load_checkpoint, save_checkpoint

    cfg_path = 'config/det/dbnet/61_hw_repb2_dbnet.py'
    model_path = 'work_dirs/61_hw_repb2_dbnet/det.pth'

    cfg = Config.fromfile(cfg_path)
    cfg.model.pretrained = None

    use_gpu = torch.cuda.is_available()
    device = torch.device('cuda:0' if use_gpu else 'cpu')

    # Build the training-time model and load its trained weights.
    src_model = build_model(cfg.model)
    load_checkpoint(src_model, model_path, map_location=device)
    src_model = src_model.to(device)

    # Build the inference-time (deploy) variant of the same architecture.
    cfg.model.backbone.is_deploy = True
    dst_model = build_model(cfg.model).to(device)

    # Fuse the training branches into the deploy model and persist it.
    converted = whole_model_convert(src_model, dst_model)
    save_checkpoint(converted, filepath='db_repvgg.pth')
Beispiel #4
0
def main():
    """Evaluate a model checkpoint on the configured test dataset and print
    the resulting metrics.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)

    # Shared/global options from the config.
    global_config = cfg.options

    # Build the model and move it onto the selected device.
    model = build_model(cfg.model)
    device, gpu_ids = select_device(global_config.gpu_ids)
    load_checkpoint(model, args.model_path, map_location=device)
    model = model.to(device)
    model.device = device

    # Test-set loader.
    eval_dataset = build_dataset(cfg.test_data.dataset)
    eval_loader = build_dataloader(eval_dataset,
                                   loader_cfg=cfg.test_data.loader)

    # Post-processing and metric.
    postprocess = build_postprocess(cfg.postprocess)
    metric = build_metrics(cfg.metric)

    # NOTE: `eval` here is the project's evaluation routine, which shadows
    # the Python builtin of the same name.
    result_metric = eval(model, eval_loader, postprocess, metric)
    print(result_metric)
Beispiel #5
0
def main():
    """End-to-end OCR demo: detect text boxes, crop each box, recognize the
    text, and save a visualization per image.
    """
    args = parse_args()

    font_ttf = "test/STKAITI.TTF"  # font file used for visualization
    font = ImageFont.truetype(font_ttf, 20)  # typeface and font size

    output = args.output
    det_cfg = Config.fromfile(args.det_config)
    rec_cfg = Config.fromfile(args.rec_config)
    # Drop pretrained URLs; weights are loaded from the checkpoints below.
    det_cfg.model.pretrained = None
    # BUGFIX: the original set det_cfg.model.pretrained twice (copy-paste);
    # the second assignment was clearly meant for the recognition config.
    rec_cfg.model.pretrained = None

    rec_model = RecInfer(rec_cfg, args.rec_weights)
    det_model = DetInfer(det_cfg, args.det_weights)

    if not os.path.exists(output):
        os.makedirs(output)

    img_list = load_data(args.img_path)

    for file in tqdm(img_list):
        ori_img = cv2.imread(file)
        base_name = os.path.basename(file)
        # OpenCV loads BGR; the detector expects RGB.
        img = cv2.cvtColor(ori_img, cv2.COLOR_BGR2RGB)
        boxes_list, scores_list = det_model.predict(img)

        # Convert detected polygons to axis-aligned boxes for cropping.
        boxes_list = transform_polys_to_bboxs(boxes_list)

        output_path = os.path.join(output, base_name)
        res_img = ori_img.copy()

        ocr_dst = []
        for bbox in boxes_list:
            x1, y1, x2, y2 = bbox
            crop_img = ori_img[y1:y2, x1:x2]
            # Each crop is assumed to contain a single text line.
            rec_dst = rec_model.predict(crop_img)[0]
            pred_str, pred_score, pred_score_char = rec_dst
            ocr_dst.append((bbox, pred_str))

        res_img = vis_ocr_result(res_img, ocr_dst, font=font)
        cv2.imwrite(output_path, res_img)
Beispiel #6
0
def main():
    """Evaluate a checkpoint either as a PyTorch model (``--mode torch``) or
    a TensorRT engine (``--mode engine``) and print the metrics.

    Raises:
        ValueError: if ``args.mode`` is neither ``'torch'`` nor ``'engine'``.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)

    # Weights come from args.model_path / args.engine_path, not URLs.
    cfg.model.pretrained = None

    # Build postprocess first: for recognition, the head output size is
    # derived from the postprocess character set.
    postprocess = build_postprocess(cfg.postprocess)
    if hasattr(postprocess, 'character'):
        char_num = len(getattr(postprocess, 'character'))
        cfg.model.head.n_class = char_num

    eval_dataset = build_dataset(cfg.test_data.dataset)
    eval_loader = build_dataloader(eval_dataset,
                                   loader_cfg=cfg.test_data.loader)
    metric = build_metrics(cfg.metric)

    mode = args.mode
    if mode == 'torch':
        model = build_model(cfg.model)
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        load_checkpoint(model, args.model_path, map_location=device)
        if args.simple:
            # Re-save a stripped ("simple") copy next to the original file.
            (filepath, filename) = os.path.split(args.model_path)
            simple_model_path = os.path.join(filepath,
                                             'sim_{}'.format(filename))
            save_checkpoint(model, simple_model_path)

        model = model.to(device)
        model.device = device
        result_metric = eval(model, eval_loader, postprocess, metric)
    elif mode == 'engine':
        engine_path = args.engine_path
        model = TRTModel(engine_path)
        result_metric = engine_eval(model, eval_loader, postprocess, metric)
    else:
        # BUGFIX: an unknown mode previously fell through to the print below
        # and crashed with a NameError on the result variable.
        raise ValueError('unsupported mode: {}'.format(mode))

    print(result_metric)
Beispiel #7
0
def test_vis():
    """Recognize labelled images and sort the visualizations into
    ``true_img`` / ``error_img`` subdirectories of the output folder.

    Paths of mis-recognized images are also appended to ``error.txt``.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)
    model = RecInfer(cfg, args)
    output = args.output

    font_ttf = "test/STKAITI.TTF"  # font file used for visualization
    font = ImageFont.truetype(font_ttf, 20)  # typeface and font size

    data_infos = load_data_by_txt(args.img_path)

    if not os.path.exists(output):
        os.makedirs(output)

    true_path = os.path.join(output, 'true_img')
    false_path = os.path.join(output, 'error_img')
    for sub_dir in (true_path, false_path):
        if not os.path.exists(sub_dir):
            os.makedirs(sub_dir)

    for idx, data in enumerate(tqdm(data_infos)):
        file = data['img_path']
        label = data['label']
        base_name = os.path.basename(file)
        # OpenCV loads BGR; the model expects RGB.
        img = cv2.cvtColor(cv2.imread(file), cv2.COLOR_BGR2RGB)

        rec_str, rec_score, rec_score_list = model.predict(img)[0]
        img_dst = visual_result(img, rec_str, label=label, font=font)

        if label == rec_str:
            out_path = os.path.join(true_path, base_name)
        else:
            out_path = os.path.join(false_path, base_name)
            # Keep a running list of files the model got wrong.
            with open(os.path.join(output, 'error.txt'),
                      'a+',
                      encoding='utf-8') as fw:
                fw.write(file + '\n')
        cv2.imwrite(out_path, img_dst)
Beispiel #8
0
def test_vis():
    """Detect text in labelled images and draw predictions next to ground
    truth for visual comparison.

    Predicted boxes are drawn in BGR ``(0, 0, 255)`` (red) and the
    ground-truth boxes in BGR ``(255, 0, 0)`` (blue).
    """
    import cv2
    from torchocr.utils.vis import draw_bbox
    # NOTE: the original also imported matplotlib.pyplot but never used it.

    args = parse_args()
    cfg = Config.fromfile(args.config)
    # Weights come from the inferer, not from a pretrained URL.
    cfg.model.pretrained = None
    model = DetInfer(cfg, args)

    data_infos = load_data_by_txt(args.img_path)
    output = args.output

    if not os.path.exists(output):
        os.makedirs(output)

    for idx, data in enumerate(tqdm(data_infos)):
        file = data['img_path']
        label = data['label']
        ori_img = cv2.imread(file)
        base_name = os.path.basename(file)
        # OpenCV loads BGR; the model expects RGB.
        img = cv2.cvtColor(ori_img, cv2.COLOR_BGR2RGB)
        box_list, score_list = model.predict(img)
        rec_path = os.path.join(output, base_name)

        if len(box_list) > 0:
            res_img = draw_bbox(ori_img,
                                box_list,
                                color=(0, 0, 255),
                                thickness=2)
            res_img = draw_bbox(res_img, label, color=(255, 0, 0), thickness=3)
        else:
            # No predictions: still draw the ground truth for reference.
            res_img = draw_bbox(ori_img, label, color=(255, 0, 0), thickness=3)

        cv2.imwrite(rec_path, res_img)
Beispiel #9
0
def main():
    """Training entry point.

    Parses the config, sets up logging / seeding / device (optionally
    distributed), builds dataset, model, optimizer, scheduler and loss,
    then hands everything to ``TrainRunner``.
    """
    args = parse_args()
    cfg_path = args.config
    cfg = Config.fromfile(cfg_path)

    global_config = cfg.options  # shared/global options
    # local_rank == 0 is the rank that writes env/config info to the log.
    global_config['local_rank'] = args.local_rank
    # Mixed-precision (AMP) training flag.
    if args.amp:
        global_config['is_amp'] = True
    else:
        global_config['is_amp'] = False

    # Exponential-moving-average (EMA) training flag.
    if args.ema:
        global_config['is_ema'] = True
    else:
        global_config['is_ema'] = False

    # Enable cudnn benchmark; speeds up training when input sizes are fixed.
    if global_config.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    if global_config.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        global_config.work_dir = osp.join(
            './work_dirs',
            osp.splitext(osp.basename(args.config))[0])

    # create work_dir
    file_util.mkdir_or_exist(global_config.work_dir)
    # init the logger before other steps
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(global_config.work_dir, '{}.log'.format(timestamp))
    logger = get_logger(name='ocr', log_file=log_file)

    # log env info (only on the logging rank)
    if args.local_rank == 0:
        env_info_dict = collect_env()
        env_info = '\n'.join([('{}: {}'.format(k, v))
                              for k, v in env_info_dict.items()])
        dash_line = '-' * 60 + '\n'
        logger.info('Environment info:\n' + dash_line + env_info + '\n' +
                    dash_line)
        # log some basic info
        logger.info('Config:\n{}'.format(cfg.text))
        # set random seeds
        logger.info('Set random seed to {}, deterministic: {}'.format(
            global_config.seed, args.deterministic))

    # set random seed
    set_random_seed(global_config.seed, deterministic=args.deterministic)

    # select device; initialize the process group when running distributed
    if torch.cuda.device_count() > 1 and args.distributed:
        device = init_dist(launcher='pytorch',
                           backend='nccl',
                           rank=args.local_rank)
        global_config['distributed'] = True
    else:
        device, gpu_ids = select_device(global_config.gpu_ids)
        global_config.gpu_ids = gpu_ids
        global_config['distributed'] = False

    # build train dataset
    train_dataset = build_dataset(cfg.train_data.dataset)
    train_loader = build_dataloader(train_dataset,
                                    loader_cfg=cfg.train_data.loader,
                                    distributed=global_config['distributed'])

    # if is eval, build eval dataloader, postprocess and metric.
    # Built before the model because the rec head output size is derived
    # from the postprocess character set (see n_class below).
    if global_config.is_eval:
        eval_dataset = build_dataset(cfg.test_data.dataset)
        eval_loader = build_dataloader(
            eval_dataset,
            loader_cfg=cfg.test_data.loader,
            distributed=global_config['distributed'])
        # build postprocess
        postprocess = build_postprocess(cfg.postprocess)
        # build metric
        metric = build_metrics(cfg.metric)
    else:
        eval_loader = None
        postprocess = None
        metric = None

    # for rec: derive head class count from the postprocess character set
    # (hasattr(None, ...) is False, so this is skipped when not evaluating)
    if hasattr(postprocess, 'character'):
        char_num = len(getattr(postprocess, 'character'))
        cfg.model.head.n_class = char_num

    # build model
    model = build_model(cfg.model)
    model = model.to(device)

    # wrap model for multi-GPU: DDP when distributed, DataParallel otherwise
    if device.type != 'cpu' and torch.cuda.device_count(
    ) > 1 and global_config['distributed'] == True:
        model = DDP(model,
                    device_ids=[args.local_rank],
                    output_device=args.local_rank)
        device = torch.device('cuda', args.local_rank)
        is_cuda = True
    elif device.type != 'cpu' and global_config[
            'distributed'] == False and len(gpu_ids) >= 1:
        # gpu_ids is only bound in the non-distributed branch above; the
        # `distributed == False` check guarantees it is defined here.
        model = nn.DataParallel(model, device_ids=global_config.gpu_ids)
        model.gpu_ids = gpu_ids
        is_cuda = True
    else:
        is_cuda = False

    global_config['is_cuda'] = is_cuda

    model.device = device

    # build optimizer
    optimizer = build_optimizer(cfg.optimizer, model)
    # build lr_scheduler
    lr_scheduler = build_lr_scheduler(cfg.lr_scheduler, optimizer)
    # build loss
    criterion = build_loss(cfg.loss).to(device)

    runner = TrainRunner(global_config, model, optimizer, lr_scheduler,
                         postprocess, criterion, train_loader, eval_loader,
                         metric, logger)

    # Resume full training state (weights + optimizer/epoch) if requested.
    if global_config.resume_from is not None and args.resume:
        runner.resume(global_config.resume_from, map_location=device)

    # Load model weights only (no training state).
    if global_config.load_from is not None:
        runner.load_checkpoint(global_config.load_from, map_location=device)

    runner.run()
Beispiel #10
0
                        required=True,
                        type=str,
                        help='config of model')
    parser.add_argument('--model_path',
                        required=True,
                        type=str,
                        help='rec model path')
    parser.add_argument('--img_path',
                        required=True,
                        type=str,
                        help='img path for predict')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    import cv2

    # Quick manual check: recognize a single image given on the command line.
    args = parse_args()
    cfg = Config.fromfile(args.config)
    global_config = cfg.options  # shared/global options

    image = cv2.imread(args.img_path)
    # OpenCV loads BGR; the model expects RGB.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    model = RecInfer(cfg, args.model_path)
    rec_result = model.predict(image)
    print(rec_result)
Beispiel #11
0
def main():
    """Export a trained OCR checkpoint to ONNX and sanity-check the result.

    Raises:
        ValueError: if ``--is_dynamic`` is set and ``args.mode`` is neither
            ``'rec'`` nor ``'det'``.
    """
    import ast

    args = parse_args()
    cfg = Config.fromfile(args.config)
    # Weights come from args.weights, not from a pretrained URL.
    cfg.model.pretrained = None

    # Build postprocess first: for recognition, the head output size is
    # derived from the postprocess character set.
    postprocess = build_postprocess(cfg.postprocess)
    if hasattr(postprocess, 'character'):
        char_num = len(getattr(postprocess, 'character'))
        cfg.model.head.n_class = char_num

    # use config build model
    model = build_model(cfg.model)

    # set weights to model and set model to device/eval()
    model_path = args.weights
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    load_checkpoint(model, model_path, map_location=device)
    model = model.to(device)
    model.eval()

    onnx_output = args.onnx_output
    # SECURITY FIX: parse e.g. "(1, 3, 32, 320)" with ast.literal_eval;
    # the original used eval(), which executes arbitrary code from the CLI.
    input_shape = ast.literal_eval(args.input_shape)
    mode = args.mode

    # transform torch model to onnx model
    input_names = ['input']
    output_names = ['output']

    # Dummy input used to trace the graph.
    input_data = torch.randn(input_shape).to(device)

    if args.is_dynamic:
        if mode == 'rec':
            # Recognition: only the batch dimension is dynamic.
            dynamic_axes = {"input": {0: "batch_size"}, "output": {0: "batch_size"}}
        elif mode == 'det':
            # Detection: batch, height and width are all dynamic.
            dynamic_axes = {"input": {0: "batch_size", 2: 'height', 3: 'width'},
                            "output": {0: "batch_size", 2: 'height', 3: 'width'}}
        else:
            # BUGFIX: an unknown mode previously left dynamic_axes unbound
            # and crashed with a NameError at the torch2onnx call below.
            raise ValueError('unsupported mode: {}'.format(mode))
    else:
        dynamic_axes = None

    onnx_model_name = torch2onnx(
        model=model,
        dummy_input=input_data,
        onnx_model_name=onnx_output,
        input_names=input_names,
        output_names=output_names,
        opset_version=12,
        is_dynamic=args.is_dynamic,
        dynamic_axes=dynamic_axes
    )

    onnx_model = onnx.load(onnx_model_name)
    # check that the model converted fine
    onnx.checker.check_model(onnx_model)
    onnx.helper.printable_graph(onnx_model.graph)
    print("Model was successfully converted to ONNX format.")
    print("It was saved to", onnx_model_name)