def __init__(self, config) -> None:
    """Build the text recognizer from the base config plus an override file.

    Args:
        config: Value handed to ``Cfg.load_config_from_file``; the loaded
            settings are merged on top of ``config/base.yml`` via ``update``.
    """
    super(ocr, self).__init__()
    self.config = config

    # Load the base settings first, then overlay the caller's config.
    merged = Cfg.load_config_from_file("config/base.yml")
    merged.update(Cfg.load_config_from_file(self.config))
    merged['vocab'] = character

    self.text_r = Predictor(merged)
def main():
    """Launch a training run on the ``hw_word`` datasets.

    The only command-line option is ``--config`` (default
    ``config/vgg-seq2seq.yml``). The monitor, augmentation, trainer and
    dataset settings are hard-coded below and are written into the loaded
    config before the ``Trainer`` is built.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('--config',
                     type=str,
                     default="config/vgg-seq2seq.yml",
                     help='config path ')
    args = cli.parse_args()

    logger = logging.getLogger(__name__)
    config = Cfg.load_config_from_file(args.config, download_base=False)
    logger.info("Loaded config from {}".format(args.config))

    config['monitor']['log_dir'] = './logs/hw_word_seq2seq_finetuning_240k'
    config['aug']['masked_language_model'] = False

    # Validation runs once every five print intervals.
    print_every = 200
    config['trainer'].update({
        'batch_size': 32,
        'print_every': print_every,
        'valid_every': 5 * print_every,
        'iters': 150000,
        'metrics': 5000,
        'pretrained': './logs/hw_word_seq2seq_finetuning_170k_v2/best.pt',
        'resume_from': None,
        'is_finetuning': False,
    })

    config['dataset'].update({
        'name': 'hw_word',
        'data_root': './DATA',
        'is_padding': True,
        'image_max_width': 100,
        'train_lmdb': [
            'train_hw_word',
            'hw_word_9k_good',
            'hw_word_50k_dict_3k',
            'valid_hw_word',
            'hw_word_70k_dict_full_filter',
        ],
        'valid_lmdb': 'test_hw_word',
    })

    # Show the effective configuration before training starts.
    print(config.pretty_text())

    Trainer(config, pretrained=False).train()
def load_recognition_model():
    """Prepare the OCR prediction model.

    Loads ``./vietocr/config.yml``, then overrides the weights path, the
    device, the CNN ``pretrained`` flag and the predictor ``beamsearch``
    flag before building the predictor.

    Returns:
        The ``Predictor`` built from the adjusted config.
    """
    cfg = Cfg.load_config_from_file('./vietocr/config.yml')

    # Top-level overrides.
    cfg['weights'] = "./models/transformerocr.pth"
    cfg['device'] = 'cuda:0'

    # Nested overrides.
    cfg['cnn']['pretrained'] = False
    cfg['predictor']['beamsearch'] = False

    return Predictor(cfg)
def main():
    """Recognize the text in a single image and print the result.

    Command-line arguments:
        --img: path of the image to open.
        --config: config file whose settings are merged on top of
            ``config/base.yml``.
    """
    parser = argparse.ArgumentParser()
    # The previous help text was the 'foo help' placeholder; describe the options.
    parser.add_argument('--img', required=True,
                        help='path to the input image')
    parser.add_argument('--config', required=True,
                        help='path to the config file merged over config/base.yml')
    args = parser.parse_args()

    # Load the base settings first, then overlay the user's config.
    config = Cfg.load_config_from_file("config/base.yml")
    config.update(Cfg.load_config_from_file(args.config))
    config['vocab'] = character

    detector = Predictor(config)

    # Use the image as a context manager so its file handle is released.
    with Image.open(args.img) as img:
        s = detector.predict(img)
        print(s)
def main():
    """Train a model from the config file named on the command line.

    The positional ``config`` argument is loaded and merged on top of
    ``config/base.yml``. ``--checkpoint`` is accepted by the parser, but the
    active code never reads it.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('config', help='see example at ')
    cli.add_argument('--checkpoint', help='your checkpoint')
    args = cli.parse_args()

    # Load the base settings first, then overlay the user's config.
    settings = Cfg.load_config_from_file("config/base.yml")
    settings.update(Cfg.load_config_from_file(args.config))
    settings['vocab'] = character

    Trainer(settings, pretrained=False).train()
def main():
    """Run the text detector on a single image and print the result.

    Command-line arguments:
        --img: path of the image to open.
        --config: config file passed to ``Cfg.load_config_from_file``.
    """
    parser = argparse.ArgumentParser()
    # The previous help text was the 'foo help' placeholder; describe the options.
    parser.add_argument('--img', required=True,
                        help='path to the input image')
    parser.add_argument('--config', required=True,
                        help='path to the config file')
    args = parser.parse_args()

    config = Cfg.load_config_from_file(args.config)
    detector = TextDetector(config)

    # Use the image as a context manager so its file handle is released.
    with Image.open(args.img) as img:
        s = detector.predict(img)
        print(s)
def predict_file():
    """Write predictions for the ``hw_word_15k`` verifier label file.

    Loads the ``hw_word_seq2seq`` config, points it at the fine-tuned
    ``best.pt`` weights, prints the config, and calls
    ``Predictor.gen_annotations`` with the hard-coded label and output paths.
    """
    cfg = Cfg.load_config_from_file(
        './logs/hw_word_seq2seq/config.yml',
        download_base=False,
    )
    cfg['weights'] = './logs/hw_word_seq2seq_finetuning/best.pt'
    print(cfg.pretty_text())

    predictor = Predictor(cfg)
    predictor.gen_annotations(
        './DATA/data_verifier/hw_word_15k_labels.txt',
        './DATA/data_verifier/hw_word_15k_labels_preds.txt',
        data_root='./DATA/data_verifier',
    )
def main():
    """Train a model, optionally loading a checkpoint first.

    Command-line arguments:
        --config: required; passed to ``Cfg.load_config_from_file``.
        --checkpoint: optional; when given, it is loaded into the trainer
            before training starts.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('--config', required=True, help='see example at ')
    cli.add_argument('--checkpoint', required=False, help='your checkpoint')
    options = cli.parse_args()

    trainer = Trainer(Cfg.load_config_from_file(options.config))

    if options.checkpoint:
        trainer.load_checkpoint(options.checkpoint)

    trainer.train()
def main():
    """Predict the text in one image, or in every image of a directory.

    Command-line arguments:
        --config: config file to load (default
            ``./logs/hw_word_seq2seq/config.yml``).
        --weight: weights file stored under ``config['weights']`` (default
            ``./logs/hw_word_seq2seq/best.pt``).
        --img: required; an image file or a directory of images.

    For each image, one line is printed with the predicted text, its
    probability and the elapsed time. In directory mode, entries that cannot
    be opened as images are skipped and a notice is written to stderr.
    """
    import sys  # local import: only needed for the skip notice below

    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str,
                        default='./logs/hw_word_seq2seq/config.yml')
    parser.add_argument('--weight', type=str,
                        default='./logs/hw_word_seq2seq/best.pt')
    parser.add_argument('--img', type=str, default=None, required=True)
    args = parser.parse_args()

    config = Cfg.load_config_from_file(args.config, download_base=False)
    config['weights'] = args.weight
    print(config.pretty_text())

    detector = Predictor(config)
    report = 'Text in {} is:\t {} | prob: {:.2f} | times: {:.2f}'

    if os.path.isdir(args.img):
        for img_name in os.listdir(args.img):
            try:
                img = Image.open(os.path.join(args.img, img_name))
            except Exception as err:
                # Still best-effort, but no longer swallows KeyboardInterrupt
                # or SystemExit, and the skipped entry is reported.
                print('Skipping {}: {}'.format(img_name, err),
                      file=sys.stderr)
                continue

            # Close each image once its prediction has been printed.
            with img:
                t1 = time.time()
                s, prob = detector.predict(img, return_prob=True)
                print(report.format(img_name, s, prob, time.time() - t1))
    else:
        # Here the timer starts before the file is opened, as it did originally.
        t1 = time.time()
        with Image.open(args.img) as img:
            s, prob = detector.predict(img, return_prob=True)
            print(report.format(args.img, s, prob, time.time() - t1))