def main(model_name, checkpoint_name, url):
    """Fine-tune a VietOCR model selected by `model_name`, reading annotations
    from `url`/train.txt and `url`/valid.txt, and writing the checkpoint and
    exported weights to ./checkpoint/{checkpoint_name}.pth."""
    config = Cfg.load_config_from_name(model_name)

    dataset_params = {
        'name': 'hw',
        'data_root': '/home/longhn',
        # 'train_root': '/home/fdm/Desktop/chungnph/vietocr/Annotation_2505',
        # 'val_root': '/home/fdm/Desktop/chungnph/vietocr/Annotation_2505',
        'train_annotation': f'{url}/train.txt',
        'valid_annotation': f'{url}/valid.txt'
    }

    params = {
        'print_every': 200,
        'valid_every': 10 * 200,
        'iters': 30000,
        'checkpoint': f'./checkpoint/{checkpoint_name}.pth',
        'export': f'./checkpoint/{checkpoint_name}.pth',
        'metrics': 15000,
        'batch_size': 32
    }

    dataloader_params = {'num_workers': 1}

    # config['pretrain']['cached'] = 'checkpoint/ngaycap_0204.pth'
    config['trainer'].update(params)
    config['dataset'].update(dataset_params)
    config['dataloader'].update(dataloader_params)
    config['device'] = 'cuda'
    config['vocab'] = '''aAàÀảẢãÃáÁạẠăĂằẰẳẲẵẴắẮặẶâÂầẦẩẨẫẪấẤậẬbBcCdDđĐeEèÈẻẺẽẼéÉẹẸêÊềỀểỂễỄếẾệỆfFgGhHiIìÌỉỈĩĨíÍịỊjJkKlLmMnNoOòÒỏỎõÕóÓọỌôÔồỒổỔỗỖốỐộỘơƠờỜởỞỡỠớỚợỢpPqQrRsStTuUùÙủỦũŨúÚụỤưƯừỪửỬữỮứỨựỰvVwWxXyYỳỲỷỶỹỸýÝỵỴzZ0123456789!"#$%&'()*+,-./:;<=>?@[\]^_`{|}~ '''
    # config['weights'] = 'checkpoint/ngaycap_0204.pth'

    print(config)
    trainer = Trainer(config, pretrained=True)
    trainer.config.save(f'train_config/{checkpoint_name}.yml')
    trainer.train()
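# Hedged usage sketch for the function above -- the arguments are illustrative
# (model name from the stock VietOCR configs, annotation folder borrowed from
# the other training script in this file set), not the author's actual run:
if __name__ == '__main__':
    main(model_name='vgg_seq2seq',          # assumed stock VietOCR config name
         checkpoint_name='hw_seq2seq',      # hypothetical checkpoint name
         url='/home/longhn/Anotation')      # folder containing train.txt / valid.txt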
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', required=True, help='see example at ')
    parser.add_argument('--checkpoint', required=False, help='your checkpoint')

    args = parser.parse_args()
    config = Cfg.load_config_from_file(args.config)

    trainer = Trainer(config)

    if args.checkpoint:
        trainer.load_checkpoint(args.checkpoint)

    trainer.train()
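# Hedged CLI sketch for the training entry point above (script and file names
# are illustrative, not taken from this repo):
#   python train.py --config config/my_train_config.yml --checkpoint checkpoint/transformerocr_checkpoint.pth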
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--img', required=True, help='path to the input image')
    parser.add_argument('--config', required=True, help='path to the model config (.yml)')

    args = parser.parse_args()
    config = Cfg.load_config_from_file(args.config)
    detector = Predictor(config)

    img = Image.open(args.img)
    s = detector.predict(img)

    print(s)
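# Hedged CLI sketch for the prediction entry point above (script and file names
# are illustrative):
#   python predict.py --img sample.png --config config/my_train_config.yml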
"""
Created on Fri May 21 09:07:18 2021

@author: fdm
"""
import cv2
import numpy as np
import glob
import ntpath
import os
import re
from PIL import Image
import time

from tool.predictor import Predictor
from tool.config import Cfg

config_all = Cfg.load_config_from_file(
    './train_config/seq2seq_handwriting_0307_pretrain_32_2k.yml')
config_all['weights'] = './checkpoint/seq2seq_handwriting_0307_pretrain_32_2k.pth'
config_all['cnn']['pretrained'] = False
config_all['device'] = 'cuda:0'
# config_all['device'] = 'cuda:1'
config_all['predictor']['beamsearch'] = False
# config_all['vocab'] = '''aAàÀảẢãÃáÁạẠăĂằẰẳẲẵẴắẮặẶâÂầẦẩẨẫẪấẤậẬbBcCdDđĐeEèÈẻẺẽẼéÉẹẸêÊềỀểỂễỄếẾệỆfFgGhHiIìÌỉỈĩĨíÍịỊjJkKlLmMnNoOòÒỏỎõÕóÓọỌôÔồỒổỔỗỖốỐộỘơƠờỜởỞỡỠớỚợỢpPqQrRsStTuUùÙủỦũŨúÚụỤưƯừỪửỬữỮứỨựỰvVwWxXyYỳỲỷỶỹỸýÝỵỴzZ0123456789!"#$%&'()*+,-./:;<=>?@[\]^_`{|}~ '''

detector_old = Predictor(config_all)


def no_accent_vietnamese(s):
    # Strip Vietnamese diacritics, one vowel group at a time.
    s = re.sub(r'[àáạảãâầấậẩẫăằắặẳẵ]', 'a', s)
    s = re.sub(r'[ÀÁẠẢÃĂẰẮẶẲẴÂẦẤẬẨẪ]', 'A', s)
    s = re.sub(r'[èéẹẻẽêềếệểễ]', 'e', s)
    s = re.sub(r'[ÈÉẸẺẼÊỀẾỆỂỄ]', 'E', s)
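    # NOTE (assumed continuation): the snippet is truncated here. The same
    # re.sub pattern presumably continues for the remaining vowel groups and
    # 'đ', ending with the de-accented string being returned, e.g.:
    #   s = re.sub(r'[ìíịỉĩ]', 'i', s)
    #   s = re.sub(r'[ÌÍỊỈĨ]', 'I', s)
    #   s = re.sub(r'[òóọỏõôồốộổỗơờớợởỡ]', 'o', s)
    #   s = re.sub(r'[ÒÓỌỎÕÔỒỐỘỔỖƠỜỚỢỞỠ]', 'O', s)
    #   s = re.sub(r'[ùúụủũưừứựửữ]', 'u', s)
    #   s = re.sub(r'[ÙÚỤỦŨƯỪỨỰỬỮ]', 'U', s)
    #   s = re.sub(r'[ỳýỵỷỹ]', 'y', s)
    #   s = re.sub(r'[ỲÝỴỶỸ]', 'Y', s)
    #   s = re.sub(r'[đ]', 'd', s)
    #   s = re.sub(r'[Đ]', 'D', s)
    #   return s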
import argparse
from PIL import Image
import glob
from tool.predictor import Predictor
from model.trainer import Trainer
from tool.config import Cfg
import cv2

# config = Cfg.load_config_from_file('config/vgg_transformer.yml')
config = Cfg.load_config_from_file('config_seq2seq.yml')

dataset_params = {
    'name': 'hw',
    'data_root': '/home/longhn',
    'train_annotation': '/home/longhn/Anotation/train.txt',
    'valid_annotation': '/home/longhn/Anotation/valid.txt'
}

params = {
    'print_every': 200,
    'valid_every': 10 * 200,
    'iters': 200000,
    'checkpoint': 'checkpoint/smartdoc_seq2seq.pth',
    'export': 'checkpoint/smartdoc_seq2seq.pth',
    'metrics': 10000000,
    'batch_size': 32
}

dataloader_params = {'num_workers': 1}

config['trainer'].update(params)
config['dataset'].update(dataset_params)
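# Assumed continuation (the snippet stops after the dataset update); a minimal
# sketch following the same pattern as the training function earlier in this
# file set:
#   config['dataloader'].update(dataloader_params)
#   config['device'] = 'cuda'
#   trainer = Trainer(config, pretrained=True)
#   trainer.train()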
# coding: utf-8
import uvicorn
import os
from os import path

from fastapi import FastAPI, File, UploadFile
from PIL import Image

from tool.predictor import Predictor
from tool.config import Cfg

# Hide all GPUs so the API server runs on CPU only.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

app = FastAPI()

config = Cfg.load_config_from_file("config/vgg_transformer.yml")
config['weights'] = f'{os.getenv("MODEL_PATH", "../../models")}/vietocr/transformerocr.pth'
config['cnn']['pretrained'] = False
config['device'] = 'cpu'

detector = Predictor(config)


@app.post('/recognize')
async def recognize(file: UploadFile = File(...)):
    result = detector.predict_bytes(await file.read())
    return {'result': result}


if __name__ == "__main__":
    import sys
    sys.path.append(path.join(path.dirname(__file__), '..'))
    uvicorn.run(app, host="0.0.0.0", port=8002)
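# --- Hedged client sketch (separate script, not part of the server above) ---
# Assumes the server is running on localhost:8002 and that a file named
# 'sample.png' exists; 'requests' is an extra dependency not used elsewhere here.
import requests

with open('sample.png', 'rb') as f:
    resp = requests.post('http://localhost:8002/recognize',
                         files={'file': ('sample.png', f, 'image/png')})
print(resp.json()['result'])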