示例#1
0
 def __init__(
     self,
     config,
 ) -> None:
     """Create the recognizer from the base config overlaid with ``config``.

     Args:
         config: Value handed to ``Cfg.load_config_from_file`` to load the
             overriding settings (presumably a config file path -- confirm
             against callers). It is also stored as ``self.config``.
     """
     super(ocr, self).__init__()
     self.config = config
     # Load the shared defaults first, then overlay the caller's settings.
     merged = Cfg.load_config_from_file("config/base.yml")
     merged.update(Cfg.load_config_from_file(self.config))
     # ``character`` is defined elsewhere in this module.
     merged['vocab'] = character
     self.text_r = Predictor(merged)
示例#2
0
def main():
    """Train a model from the ``--config`` file plus script-level overrides.

    The config is loaded with ``download_base=False``; its ``monitor``,
    ``aug``, ``trainer`` and ``dataset`` sections are then overridden with the
    values hard-coded below, the resulting config is printed, and
    ``Trainer.train`` is run.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        '--config',
        type=str,
        default="config/vgg-seq2seq.yml",
        help='config path ',
    )
    args = arg_parser.parse_args()

    logger = logging.getLogger(__name__)

    config = Cfg.load_config_from_file(args.config, download_base=False)
    logger.info("Loaded config from {}".format(args.config))

    config['monitor']['log_dir'] = './logs/hw_word_seq2seq_finetuning_240k'
    config['aug']['masked_language_model'] = False

    # Validation runs once every five print intervals.
    print_every = 200
    # To continue an earlier run, set 'resume_from' to a checkpoint path
    # (e.g. './logs/hw_small_finetuning/last.pt') instead of None.
    # Optimizer settings can be overridden the same way, e.g.
    # config['optimizer'].update({'max_lr': 0.00001}).
    config['trainer'].update({
        'batch_size': 32,
        'print_every': print_every,
        'valid_every': 5 * print_every,
        'iters': 150000,
        'metrics': 5000,
        'pretrained': './logs/hw_word_seq2seq_finetuning_170k_v2/best.pt',
        'resume_from': None,
        'is_finetuning': False,
    })

    config['dataset'].update({
        'name': 'hw_word',
        'data_root': './DATA',
        'is_padding': True,
        'image_max_width': 100,
        'train_lmdb': [
            'train_hw_word',
            'hw_word_9k_good',
            'hw_word_50k_dict_3k',
            'valid_hw_word',
            'hw_word_70k_dict_full_filter',
        ],
        'valid_lmdb': 'test_hw_word',
    })

    print(config.pretty_text())

    # trainer.visualize_dataset() can be called here to inspect the data
    # before training starts.
    trainer = Trainer(config, pretrained=False)
    trainer.train()
示例#3
0
def load_recognition_model():
    """Prepare and return the OCR prediction model.

    Loads './vietocr/config.yml', points it at the local weights file,
    disables the CNN ``pretrained`` flag and beam search, selects the
    'cuda:0' device, and wraps the result in a ``Predictor``.

    Returns:
        The ``Predictor`` built from that configuration.
    """
    cfg = Cfg.load_config_from_file('./vietocr/config.yml')

    # Top-level overrides.
    cfg['weights'] = "./models/transformerocr.pth"
    cfg['cnn']['pretrained'] = False
    cfg['device'] = 'cuda:0'
    cfg['predictor']['beamsearch'] = False

    return Predictor(cfg)
示例#4
0
def main():
    """Run the predictor on one image and print what it returns.

    Command-line options (both required):
        --img: image file opened with ``Image.open``.
        --config: config file whose settings are overlaid on
            "config/base.yml".
    """
    cli = argparse.ArgumentParser()
    for flag in ('--img', '--config'):
        cli.add_argument(flag, required=True, help='foo help')
    opts = cli.parse_args()

    # Load the shared defaults first, then overlay the user's settings.
    settings = Cfg.load_config_from_file("config/base.yml")
    settings.update(Cfg.load_config_from_file(opts.config))
    # ``character`` is defined elsewhere in this module.
    settings['vocab'] = character

    recognizer = Predictor(settings)
    result = recognizer.predict(Image.open(opts.img))
    print(result)
示例#5
0
文件: train.py 项目: lzmisscc/vietocr
def main():
    """Train a model from "config/base.yml" overlaid with a user config.

    Command-line arguments:
        config: positional; config file whose settings are applied on top of
            the base config.
        --checkpoint: optional; accepted by the parser but not used below.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('config', help='see example at ')
    cli.add_argument('--checkpoint', help='your checkpoint')
    opts = cli.parse_args()

    # Load the shared defaults first, then overlay the user's settings.
    settings = Cfg.load_config_from_file("config/base.yml")
    settings.update(Cfg.load_config_from_file(opts.config))
    # ``character`` is defined elsewhere in this module.
    settings['vocab'] = character

    # NOTE: checkpoint loading (trainer.load_checkpoint(opts.checkpoint)) is
    # disabled in this script, so --checkpoint currently has no effect.
    Trainer(settings, pretrained=False).train()
def main():
    """Run ``TextDetector`` on one image and print what it returns.

    Command-line options (both required):
        --img: image file opened with ``Image.open``.
        --config: config file loaded with ``Cfg.load_config_from_file``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('--img', required=True, help='foo help')
    cli.add_argument('--config', required=True, help='foo help')
    opts = cli.parse_args()

    text_detector = TextDetector(Cfg.load_config_from_file(opts.config))
    result = text_detector.predict(Image.open(opts.img))
    print(result)
示例#7
0
def predict_file():
    """Run ``Predictor.gen_annotations`` on the data_verifier label file.

    Loads the hw_word_seq2seq config (``download_base=False``), points it at
    the fine-tuned weights, prints the config, and calls ``gen_annotations``
    with the two label paths below and ``data_root='./DATA/data_verifier'``.
    """
    settings = Cfg.load_config_from_file(
        './logs/hw_word_seq2seq/config.yml', download_base=False)
    settings['weights'] = './logs/hw_word_seq2seq_finetuning/best.pt'
    print(settings.pretty_text())

    # Presumably the first path is the input label list and the second is
    # where predictions are written -- confirm against gen_annotations.
    labels_file = './DATA/data_verifier/hw_word_15k_labels.txt'
    preds_file = './DATA/data_verifier/hw_word_15k_labels_preds.txt'

    predictor = Predictor(settings)
    predictor.gen_annotations(
        labels_file, preds_file, data_root='./DATA/data_verifier')
示例#8
0
def main():
    """Train a model from ``--config``, optionally loading a checkpoint first.

    Command-line options:
        --config: required; config file loaded with
            ``Cfg.load_config_from_file``.
        --checkpoint: optional; when given (and non-empty) it is passed to
            ``trainer.load_checkpoint`` before training starts.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('--config', required=True, help='see example at ')
    cli.add_argument('--checkpoint', required=False, help='your checkpoint')
    opts = cli.parse_args()

    trainer = Trainer(Cfg.load_config_from_file(opts.config))

    checkpoint = opts.checkpoint
    if checkpoint:
        trainer.load_checkpoint(checkpoint)

    trainer.train()
示例#9
0
def main():
    """Run the predictor on one image, or on every openable file in a folder.

    Command-line options:
        --config: config file (default './logs/hw_word_seq2seq/config.yml'),
            loaded with ``download_base=False``.
        --weight: weights path stored into ``config['weights']``
            (default './logs/hw_word_seq2seq/best.pt').
        --img: required; either an image file or a directory of files.

    For each image one line is printed with the prediction, its probability
    and the elapsed time. In directory mode, entries that ``Image.open``
    cannot open are skipped silently; in single-file mode an open failure
    propagates to the caller.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config',
                        type=str,
                        default='./logs/hw_word_seq2seq/config.yml')
    parser.add_argument('--weight',
                        type=str,
                        default='./logs/hw_word_seq2seq/best.pt')
    parser.add_argument('--img', type=str, default=None, required=True)
    args = parser.parse_args()

    config = Cfg.load_config_from_file(args.config, download_base=False)
    config['weights'] = args.weight
    print(config.pretty_text())

    detector = Predictor(config)
    report = 'Text in {} is:\t {} | prob: {:.2f} | times: {:.2f}'

    if os.path.isdir(args.img):
        for img_name in os.listdir(args.img):
            try:
                img = Image.open(os.path.join(args.img, img_name))
            except Exception:
                # Best effort: the folder may contain non-image entries.
                # Catching Exception (not a bare except) keeps Ctrl-C and
                # SystemExit working while still skipping unreadable files.
                continue
            # Close each image when done so a large folder does not
            # accumulate open file handles.
            with img:
                t1 = time.time()
                s, prob = detector.predict(img, return_prob=True)
                print(report.format(img_name, s, prob, time.time() - t1))
    else:
        # As in the original, the timer here also covers opening the file.
        t1 = time.time()
        with Image.open(args.img) as img:
            s, prob = detector.predict(img, return_prob=True)
            print(report.format(args.img, s, prob, time.time() - t1))