Example #1
    def __init__(
        self,
        config,
    ) -> None:
        super(ocr, self).__init__()
        self.config = config
        # merge the model-specific config over the base config
        config_base = Cfg.load_config_from_file("config/base.yml")
        config = Cfg.load_config_from_file(self.config)
        config_base.update(config)
        config = config_base
        config['vocab'] = character  # `character` is a vocab string defined elsewhere
        self.text_r = Predictor(config)
Example #2
    def __init__(self, weights):
        # list of provinces & cities, with diacritics
        self.tinh_list = [
            'An Giang', 'Bà Rịa - Vũng Tàu', 'Bắc Giang', 'Bắc Kạn',
            'Bạc Liêu', 'Bắc Ninh', 'Bến Tre', 'Bình Định', 'Bình Dương',
            'Bình Phước', 'Bình Thuận', 'Cà Mau', 'Cao Bằng', 'Đắk Lắk',
            'Đắk Nông', 'Điện Biên', 'Đồng Nai', 'Đồng Tháp', 'Gia Lai',
            'Hà Giang', 'Hà Nam', 'Hà Tĩnh', 'Hải Dương', 'Hậu Giang',
            'Hòa Bình', 'Hưng Yên', 'Khánh Hòa', 'Kiên Giang', 'Kon Tum',
            'Lai Châu', 'Lâm Đồng', 'Lạng Sơn', 'Lào Cai', 'Long An',
            'Nam Định', 'Nghệ An', 'Ninh Bình', 'Ninh Thuận', 'Phú Thọ',
            'Quảng Bình', 'Quảng Nam', 'Quảng Ngãi', 'Quảng Ninh', 'Quảng Trị',
            'Sóc Trăng', 'Sơn La', 'Tây Ninh', 'Thái Bình', 'Thái Nguyên',
            'Thanh Hóa', 'Thừa Thiên Huế', 'Tiền Giang', 'Trà Vinh',
            'Tuyên Quang', 'Vĩnh Long', 'Vĩnh Phúc', 'Yên Bái', 'Phú Yên',
            'Cần Thơ', 'Đà Nẵng', 'Hải Phòng', 'Hà Nội', 'TP Hồ Chí Minh'
        ]
        # list of provinces & cities, without diacritics
        self.provinces = [
            self.remove_accent(tinh).lower() for tinh in self.tinh_list
        ]

        self.config = Cfg.load_config_from_name('vgg_transformer')
        self.config['weights'] = weights
        self.config['cnn']['pretrained'] = False
        self.config['device'] = 'cpu'
        self.config['predictor']['beamsearch'] = False

        self.reader = Predictor(self.config)
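
Example #2 calls `self.remove_accent`, which is not shown above. A minimal sketch of such a helper, assuming it strips Vietnamese diacritics via Unicode decomposition (the `đ`/`Đ` replacement is needed because those letters carry no combining mark):

import unicodedata

def remove_accent(text):
    # đ/Đ do not decompose under NFD, so replace them explicitly
    text = text.replace('đ', 'd').replace('Đ', 'D')
    # NFD splits accented letters into a base letter plus combining marks ('Mn')
    return ''.join(c for c in unicodedata.normalize('NFD', text)
                   if unicodedata.category(c) != 'Mn')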
Example #3
    def __init__(self):
        super().__init__()
        manager = Manager()
        self.send = manager.list()
        self.date = manager.list()
        self.quote = manager.list()
        self.number = manager.list()
        self.header = manager.list()
        self.sign = manager.list()
        self.device = torch.device('cpu')
        # load CRAFT weights; map_location (assumed fix) avoids errors on
        # machines without a GPU
        state_dict = torch.load(
            '/home/dung/Project/Python/ocr/craft_mlt_25k.pth',
            map_location=self.device)
        # checkpoints saved from nn.DataParallel prefix every key with
        # "module."; strip that prefix so the bare model can load them
        if list(state_dict.keys())[0].startswith("module"):
            start_idx = 1
        else:
            start_idx = 0
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = ".".join(k.split(".")[start_idx:])
            new_state_dict[name] = v

        self.craft = CRAFT()
        self.craft.load_state_dict(new_state_dict)
        self.craft.to(self.device)
        self.craft.eval()
        self.craft.share_memory()
        self.config = Cfg.load_config_from_name('vgg_transformer')
        self.config[
            'weights'] = 'https://drive.google.com/uc?id=13327Y1tz1ohsm5YZMyXVMPIOjoOA0OaA'
        self.config['device'] = 'cpu'
        self.config['predictor']['beamsearch'] = False
        self.weights = '/home/dung/Documents/transformerocr.pth'
Example #4
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config',
                        type=str,
                        default="config/vgg-seq2seq.yml",
                        help='config path')
    # parser.add_argument('--checkpoint', required=False, help='your checkpoint')

    args = parser.parse_args()
    logger = logging.getLogger(__name__)

    config = Cfg.load_config_from_file(args.config, download_base=False)
    logger.info("Loaded config from {}".format(args.config))
    # print('-- CONFIG --')
    dataset_params = {
        'name': 'hw_word',
        'data_root': './DATA',
        'is_padding': True,
        'image_max_width': 100,
        'train_lmdb': [
            'train_hw_word', 'hw_word_9k_good', 'hw_word_50k_dict_3k',
            'valid_hw_word', 'hw_word_70k_dict_full_filter'
        ],
        'valid_lmdb': 'test_hw_word'
    }
    config['monitor']['log_dir'] = './logs/hw_word_seq2seq_finetuning_240k'

    trainer_params = {
        'batch_size': 32,
        'print_every': 200,
        'valid_every': 5 * 200,
        'iters': 150000,
        'metrics': 5000,
        'pretrained': './logs/hw_word_seq2seq_finetuning_170k_v2/best.pt',
        'resume_from': None,
        'is_finetuning': False
    }

    config['aug']['masked_language_model'] = False

    # optim_params = {
    #     'max_lr': 0.00001
    # }
    # config['optimizer'].update(optim_params)

    config['trainer'].update(trainer_params)
    # config['trainer']['resume_from'] = './logs/hw_small_finetuning/last.pt'
    config['dataset'].update(dataset_params)

    print(config.pretty_text())
    # print(config)
    trainer = Trainer(config, pretrained=False)
    # trainer.visualize_dataset()
    trainer.train()
Example #5
    def __init__(self):
        config = Cfg.load_config_from_name('vgg_transformer')
        config['weights'] = './model/transformerocr.pth'
        # config['weights'] = 'https://drive.google.com/uc?id=13327Y1tz1ohsm5YZMyXVMPIOjoOA0OaA'
        # config['device'] = ''
        config['device'] = 'cuda'
        config['predictor']['beamsearch'] = False
        self.detector = Predictor(config)
Example #6
def train_customize():
    # 1. Load your config
    # 2. Train the model using your dataset above

    # Load the default config; we adopt VGG for image feature extraction

    # * *data_root*: the folder containing all your images
    # * *train_annotation*: path to the train annotation
    # * *valid_annotation*: path to the valid annotation
    # * *print_every*: show train loss every n steps
    # * *valid_every*: run validation every n steps
    # * *iters*: number of iterations to train your model
    # * *export*: export weights to a folder you can use for inference
    # * *metrics*: number of samples from the validation annotation used to compute full_sequence_accuracy; for a large dataset this can take too long, so you can reduce this number
    #

    config = Cfg.load_config_from_name('vgg_transformer')

    # config['vocab'] = 'aAàÀảẢãÃáÁạẠăĂằẰẳẲẵẴắẮặẶâÂầẦẩẨẫẪấẤậẬbBcCdDđĐeEèÈẻẺẽẼéÉẹẸêÊềỀểỂễỄếẾệỆfFgGhHiIìÌỉỈĩĨíÍịỊjJkKlLmMnNoOòÒỏỎõÕóÓọỌôÔồỒổỔỗỖốỐộỘơƠờỜởỞỡỠớỚợỢpPqQrRsStTuUùÙủỦũŨúÚụỤưƯừỪửỬữỮứỨựỰvVwWxXyYỳỲỷỶỹỸýÝỵỴzZ0123456789!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ '

    dataset_params = {
        'name': 'hw',
        'data_root': './data_line/',
        'train_annotation': 'train_line_annotation.txt',
        'valid_annotation': 'test_line_annotation.txt'
    }

    params = {
        'print_every': 200,
        'valid_every': 15 * 200,
        'iters': 20000,
        'checkpoint': './checkpoint/transformerocr_checkpoint.pth',
        'export': './weights/transformerocr.pth',
        'metrics': 10000
    }

    config['trainer'].update(params)
    config['dataset'].update(dataset_params)
    config['device'] = 'cuda:0'

    # you can change any of these params; see the full parameter list in the config
    trainer = Trainer(config, pretrained=True)

    # Save the model configuration for inference; reload it later with load_config_from_file
    trainer.config.save('config.yml')

    # Visualize your dataset to check that the data augmentation is appropriate
    trainer.visualize_dataset()

    # Train now
    trainer.train()

    # Visualize predictions from the trained model
    trainer.visualize_prediction()

    # Compute full-sequence accuracy on the full validation dataset
    trainer.precision()
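
A minimal inference sketch using the artifacts saved by `train_customize` above; the config and weight paths are the ones set in `params`, while the test image name is a placeholder:

from PIL import Image
from vietocr.tool.config import Cfg
from vietocr.tool.predictor import Predictor

config = Cfg.load_config_from_file('config.yml')    # saved by trainer.config.save
config['weights'] = './weights/transformerocr.pth'  # the exported weights
detector = Predictor(config)
print(detector.predict(Image.open('sample_line.png')))  # placeholder image path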
Example #7
def load_recognition_model():
    # prepare the OCR prediction model
    config = Cfg.load_config_from_file('./vietocr/config.yml')
    config['weights'] = "./models/transformerocr.pth"
    config['cnn']['pretrained'] = False
    config['device'] = 'cuda:0'
    config['predictor']['beamsearch'] = False
    recognizer = Predictor(config)
    return recognizer
Example #8
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--img', required=True, help='path to the input image')
    parser.add_argument('--config', required=True, help='path to the config file')

    args = parser.parse_args()
    config_base = Cfg.load_config_from_file("config/base.yml")
    config = Cfg.load_config_from_file(args.config)
    config_base.update(config)
    config = config_base

    config['vocab'] = character

    detector = Predictor(config)

    img = Image.open(args.img)
    s = detector.predict(img)

    print(s)
Example #9
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('config', help='see example at ')
    parser.add_argument('--checkpoint', help='your checkpoint')

    args = parser.parse_args()
    config_base = Cfg.load_config_from_file("config/base.yml")
    config = Cfg.load_config_from_file(args.config)
    config_base.update(config)
    config = config_base

    config['vocab'] = character
    trainer = Trainer(config, pretrained=False)

    # args.checkpoint = config.trainer["checkpoint"]
    # if args.checkpoint:
    #    trainer.load_checkpoint(args.checkpoint)
    #    logging.info(f"Load checkpoint form {args.checkpoint}....")

    trainer.train()
Example #10
    def __init__(self, reg_model='seq2seq'):
        print("Loading TEXT_MODEL...")
        if reg_model == "seq2seq":
            config = Cfg.load_config_from_name('vgg_seq2seq')
            config['weights'] = 'weights/vgg-seq2seq.pth'
        else:
            # without this guard, `config` would be undefined below
            raise ValueError("unsupported reg_model: {}".format(reg_model))

        self.model_box = BOX_MODEL()
        config['device'] = 'cpu'
        config['predictor']['beamsearch'] = False

        self.model_reg = Predictor(config)
        self.craft_model = CraftDetection()
Example #11
    def __init__(self):
        self.yolo = YOLOv4()
        self.yolo.classes = './coco.names'
        self.yolo.make_model()
        self.yolo.load_weights("./model/yolov4-custom_last.weights",
                               weights_type="yolo")
        self.config = Cfg.load_config()
        self.config['weights'] = './model/transformerocr.pth'
        self.config['predictor']['beamsearch'] = False
        self.config['device'] = 'cpu'
        self.detector = Predictor(self.config)
        self.classes = ['id', 'name', 'dmy', 'add1', 'add2']
        self.res = dict.fromkeys(self.classes, '')
Example #12
    def __init__(self, ckpt_path=None, gpu='0'):
        print('Classifier_Vietocr. Init')
        self.config = Cfg.load_config(cls_base_config_path, cls_config_path)

        if ckpt_path is not None:
            self.config['weights'] = ckpt_path
        self.config['cnn']['pretrained'] = False
        if gpu is not None:
            self.config['device'] = 'cuda:' + str(gpu)
        else:
            self.config['device'] = 'cpu'
        self.config['predictor']['beamsearch'] = False
        self.model = Predictor(self.config)
Example #13
def create_text_annotation_ocr(imgs, dest):
    config = Cfg.load_config_from_name('vgg_transformer')
    config['export'] = 'transformerocr_checkpoint.pth'
    config['device'] = 'cuda'
    config['predictor']['beamsearch'] = False
    detector = Predictor(config)
    f = io.open(os.path.join(dest, "annotation.txt"), "a", encoding="utf-8")
    for idx, image in enumerate(imgs):
        text = detector.predict(image)
        if idx + 1 == len(imgs):
            f.write('crop_img/{:06d}.jpg\t{}'.format(idx + 1, text))
        else:
            f.write('crop_img/{:06d}.jpg\t{}\n'.format(idx+1, text))
    f.close()
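
The annotation file written above uses vietocr's tab-separated `image_path<TAB>label` format. A small sketch for reading it back (assuming UTF-8 and one record per line):

import io

def read_annotation(path):
    # each line: relative image path, a tab, then the label text
    with io.open(path, encoding="utf-8") as f:
        return [line.rstrip("\n").split("\t", 1) for line in f if line.strip()]

pairs = read_annotation("annotation.txt")  # e.g. [['crop_img/000001.jpg', 'recognized text'], ...]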
Example #14
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', required=True, help='see example at ')
    parser.add_argument('--checkpoint', required=False, help='your checkpoint')

    args = parser.parse_args()
    config = Cfg.load_config_from_file(args.config)

    trainer = Trainer(config)

    if args.checkpoint:
        trainer.load_checkpoint(args.checkpoint)

    trainer.train()
Example #15
def load_config():
    """
    =========================
    == Reader model
    =========================
    """
    config = Cfg.load_config_from_name('vgg_transformer')
    config['weights'] = './models/reader/transformerocr.pth'
    # config['weights'] = 'https://drive.google.com/uc?export=download&id=1-olev206xLgXYf7rnwHrcZLxxLg5rs0p'
    # config['weights'] = 'https://drive.google.com/uc?id=13327Y1tz1ohsm5YZMyXVMPIOjoOA0OaA'
    config['device'] = 'cpu'
    config['predictor']['beamsearch'] = False
    return config
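
A short usage sketch for the loader above; `line.png` is a placeholder for a cropped text-line image:

from PIL import Image
from vietocr.tool.predictor import Predictor

reader = Predictor(load_config())
print(reader.predict(Image.open('line.png')))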
Example #16
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--img', required=True, help='path to the input image')
    parser.add_argument('--config', required=True, help='path to the config file')

    args = parser.parse_args()
    config = Cfg.load_config_from_file(args.config)

    detector = TextDetector(config)

    img = Image.open(args.img)
    s = detector.predict(img)

    print(s)
Example #17
def predict_file():
    config_path = './logs/hw_word_seq2seq/config.yml'
    config = Cfg.load_config_from_file(config_path, download_base=False)

    config['weights'] = './logs/hw_word_seq2seq_finetuning/best.pt'

    print(config.pretty_text())

    detector = Predictor(config)

    detector.gen_annotations(
        './DATA/data_verifier/hw_word_15k_labels.txt',
        './DATA/data_verifier/hw_word_15k_labels_preds.txt',
        data_root='./DATA/data_verifier')
Example #18
    def __init__(self,
                 ckpt_path=None,
                 gpu='0',
                 config_name='vgg_seq2seq',
                 write_file=False,
                 debug=False):
        print('Classifier_Vietocr. Init')
        self.config = Cfg.load_config_from_name(config_name)

        # config['weights'] = './weights/transformerocr.pth'
        if ckpt_path is not None:
            self.config['weights'] = ckpt_path
        self.config['cnn']['pretrained'] = False
        if gpu is not None:
            self.config['device'] = 'cuda:' + str(gpu)
        self.config['predictor']['beamsearch'] = False
        self.model = Predictor(self.config)
Example #19
def load_model():

    config = program.load_config('./configs/det/det_r18_vd_db_v1.1.yml')

    # check if set use_gpu=True in paddlepaddle cpu version
    use_gpu = config['Global']['use_gpu']
    program.check_gpu(use_gpu)

    place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)

    det_model = create_module(
        config['Architecture']['function'])(params=config)

    startup_prog = fluid.Program()
    eval_prog = fluid.Program()
    with fluid.program_guard(eval_prog, startup_prog):
        with fluid.unique_name.guard():
            _, eval_outputs = det_model(mode="test")
            fetch_name_list = list(eval_outputs.keys())
            eval_fetch_list = [eval_outputs[v].name for v in fetch_name_list]

    eval_prog = eval_prog.clone(for_test=True)
    exe.run(startup_prog)

    # load checkpoints
    checkpoints = config['Global'].get('checkpoints')
    if checkpoints:
        path = checkpoints
        fluid.load(eval_prog, path, exe)
        logger.info("Finished initializing model from {}".format(path))
    else:
        raise Exception("checkpoint path {} does not exist!".format(checkpoints))

    config_ocr = Cfg.load_config_from_name('vgg_seq2seq')
    config_ocr['weights'] = './my_weights/transformer.pth'
    config_ocr['cnn']['pretrained'] = False
    config_ocr['device'] = 'cpu'
    config_ocr['predictor']['beamsearch'] = False

    detector = Predictor(config_ocr)

    return detector, exe, config, eval_prog, eval_fetch_list
Example #20
    def __init__(self, model_path):

        #Load the pretrained PhoBERT Model
        print("Loading Classification...")
        self.config = RobertaConfig.from_pretrained(
            model_path + 'PhoBERT/config.json',
            from_tf=False,
            num_labels=5,
            output_hidden_states=False,
        )
        self.phoBERT_cls = RobertaForSequenceClassification.from_pretrained(
            model_path + 'PhoBERT/model.bin', config=self.config)
        device = "cuda:0"
        self.phoBERT_cls = self.phoBERT_cls.to(device)
        self.phoBERT_cls.eval()
        print("Loading pre-trained model...")
        self.phoBERT_cls.load_state_dict(
            torch.load(
                model_path +
                'roberta_state_dict_9bfb8319-01b2-4301-aa5a-756d390a98e1.pth'))
        print("Finished loading PhoBERT Classification model.")

        #Load the BPE and Vocabulary Dictionary
        print("Loading BPE and vocab dict ...")

        class BPE():
            bpe_codes = model_path + 'PhoBERT/bpe.codes'

        args = BPE()
        self.bpe = fastBPE(args)
        self.vocab = Dictionary()
        self.vocab.add_from_file(model_path + "PhoBERT/dict.txt")
        print("Finished loading BPE and vocab dict.")

        #Load the Text Recognizer
        config = Cfg.load_config_from_name('vgg_transformer')
        config['weights'] = 'weights/transformerocr.pth'
        config['cnn']['pretrained'] = False
        config['device'] = 'cuda:0'
        config['predictor']['beamsearch'] = False
        self.text_recognizer = Predictor(config)
Example #21
def img_to_text(list_img):
    results = []
    # use the model's default config
    config = Cfg.load_config_from_name('vgg_transformer')
    # path to trained weights; comment out to use the default pretrained model
    config['weights'] = 'checkpoints/transformerocr.pth'
    config['device'] = 'cpu'  # device to run on: 'cuda:0', 'cuda:1', 'cpu'

    # build the predictor once, outside the loop
    detector = Predictor(config)

    for img in list_img:
        img = Image.fromarray(img.astype(np.uint8))
        # img = Image.fromarray((img * 255).astype(np.uint8))
        # img.show()

        # predict; set return_prob=True to also get the sentence probability
        text = detector.predict(img)

        if len(text) > 0:
            results.append(text)
    return results
Example #22
def main():

    parser = argparse.ArgumentParser()
    parser.add_argument('--config',
                        type=str,
                        default='./logs/hw_word_seq2seq/config.yml')
    parser.add_argument('--weight',
                        type=str,
                        default='./logs/hw_word_seq2seq/best.pt')
    parser.add_argument('--img', type=str, default=None, required=True)
    args = parser.parse_args()

    config = Cfg.load_config_from_file(args.config, download_base=False)

    config['weights'] = args.weight

    print(config.pretty_text())

    detector = Predictor(config)
    if os.path.isdir(args.img):
        img_paths = os.listdir(args.img)
        for img_path in img_paths:
            try:
                img = Image.open(os.path.join(args.img, img_path))
            except OSError:
                continue
            t1 = time.time()
            s, prob = detector.predict(img, return_prob=True)
            print('Text in {} is:\t {} | prob: {:.2f} | time: {:.2f}s'.format(
                img_path, s, prob,
                time.time() - t1))
    else:
        t1 = time.time()
        img = Image.open(args.img)
        s, prob = detector.predict(img, return_prob=True)
        print('Text in {} is:\t {} | prob: {:.2f} | time: {:.2f}s'.format(
            args.img, s, prob,
            time.time() - t1))
Example #23
def img_to_text(list_img):
    results = []
    # use the model's default config
    config = Cfg.load_config_from_name("vgg_transformer")
    # path to trained weights; comment out to use the default pretrained model
    config["weights"] = "https://drive.google.com/uc?id=13327Y1tz1ohsm5YZMyXVMPIOjoOA0OaA"
    # config['weights'] = 'transformerocr.pth'
    config["device"] = "cpu"  # device to run on: 'cuda:0', 'cuda:1', 'cpu'

    detector = Predictor(config)
    # skip the first image, then recognize the rest
    for img_arr in list_img[1:]:
        img = Image.fromarray(img_arr.astype(np.uint8))

        # predict; set return_prob=True to also get the sentence probability
        text = detector.predict(img)

        if len(text) > 0:
            results.append(text)
    return results
Example #24
    def get_ocr_model(self):
        config = Cfg.load_config_from_name('vgg_seq2seq')
        config['cnn']['pretrained'] = False
        config['device'] = 'cpu'
        config['predictor']['beamsearch'] = False
        return Predictor(config)
Example #25
        y_dist_limit: 10 (maximum distance along the y coordinate to merge two boxes)
        x_dist_limit: 40 (maximum distance along the x coordinate to merge two boxes)
        iou_limit: 0.001

        '''

        need_merging = True
        while need_merging:
            need_merging, texts, bboxes_xxyy = merge_box_by_iou(
                texts, bboxes_xxyy)

        need_merging = True
        while need_merging:
            need_merging, texts, bboxes_xxyy = merge_box_by_distance(
                texts, bboxes_xxyy)

        return texts


if __name__ == "__main__":
    config = Cfg.load_config_from_name('vgg_transformer')
    config['weights'] = 'weights/transformerocr.pth'
    config['cnn']['pretrained'] = False
    config['device'] = 'cuda:0'
    config['predictor']['beamsearch'] = False
    text_recognizer = Predictor(config)
    test_image_path = "test_data/Công văn 641_UBND-NC PDF.pdf.jpg"
    image = cv2.imread(test_image_path)
    detected_texts = export_text(image, text_recognizer)
    print(detected_texts)
Example #26
import torch
from torch.autograd import Variable
import glob
import os
import csv
import cv2
import numpy as np
import argparse

from PIL import Image
import time

from vietocr.tool.predictor import Predictor
from vietocr.tool.config import Cfg

config = Cfg.load_config_from_name('vgg_seq2seq')

# config['weights'] = './transformerocr.pth'
# config['weights'] = 'https://drive.google.com/uc?id=13327Y1tz1ohsm5YZMyXVMPIOjoOA0OaA'
config['device'] = 'cuda:0'
config['predictor']['beamsearch'] = False
# config['trainer']['checkpoint'] = '/dataset/Students/thuyentd/VietOcr/vgg_seq2seq_receipt_31122020checkpoint.pth'

detector = Predictor(config)


def predict_this_box(config, img):
    s, pros = detector.predict(img, return_prob=True)
    return s, pros

Example #27
from vietocr.tool.predictor import Predictor
from vietocr.tool.config import Cfg
import Levenshtein


def cer_loss_one_image(sim_pred, label):
    # character error rate: Levenshtein distance normalized by the longer string length
    if max(len(sim_pred), len(label)) > 0:
        return Levenshtein.distance(sim_pred, label) / max(len(sim_pred), len(label))
    return 0

debug = False
eval = True
config = Cfg.load_config_from_name('vgg_seq2seq') #vgg_transformer, vgg_seq2seq, vgg_convseq2seq

# config['weights'] = './weights/transformerocr.pth'
# config['weights'] = 'https://drive.google.com/uc?id=13327Y1tz1ohsm5YZMyXVMPIOjoOA0OaA'
config['cnn']['pretrained'] = False
config['device'] = 'cuda:0'
config['predictor']['beamsearch'] = False
# config['dataset']['image_min_width'] = 64
# config['dataset']['image_max_width'] = 256
detector = Predictor(config)
src_dir = '/data20.04/data/data_Korea/Korea_test_Vietnamese_1106/crop'
#src_dir = '/data20.04/data/data_Korea/Cello_Vietnamese/crnn_extend_True_y_ratio_0.05_min_y_4_min_x_2'

img_path = 'test_hw.jpg'
#img_path = ''
if img_path == '':
Example #28
import requests
import pdb

# import colabcode
from colabcode import ColabCode

# import mmdetection
from mmdet.apis import init_detector, inference_detector

# import vietocr
from vietocr.tool.predictor import Predictor
from vietocr.tool.config import Cfg

# import model vietocr

config_seller = Cfg.load_config_from_name('vgg_transformer')
config_seller['weights'] = './models/OCR/seller.pth'
config_seller['cnn']['pretrained'] = False
config_seller['device'] = 'cuda:0'
config_seller['predictor']['beamsearch'] = False
detector_seller = Predictor(config_seller)

config_address = Cfg.load_config_from_name('vgg_transformer')
config_address['weights'] = './models/OCR/address.pth'
config_address['cnn']['pretrained'] = False
config_address['device'] = 'cuda:0'
config_address['predictor']['beamsearch'] = False
detector_address = Predictor(config_address)

config_timestamp = Cfg.load_config_from_name('vgg_transformer')
config_timestamp['weights'] = './models/OCR/timestamp.pth'
Example #29
    def __init__(self, config, pretrained=True, augmentor=ImgAugTransform()):

        self.config = config
        self.model, self.vocab = build_model(config)

        self.device = config['device']
        self.num_iters = config['trainer']['iters']
        self.beamsearch = config['predictor']['beamsearch']

        self.data_root = config['dataset']['data_root']
        self.train_annotation = config['dataset']['train_annotation']
        self.valid_annotation = config['dataset']['valid_annotation']
        self.train_lmdb = config['dataset']['train_lmdb']
        self.valid_lmdb = config['dataset']['valid_lmdb']
        self.dataset_name = config['dataset']['name']

        self.batch_size = config['trainer']['batch_size']
        self.print_every = config['trainer']['print_every']
        self.valid_every = config['trainer']['valid_every']

        self.image_aug = config['aug']['image_aug']
        self.masked_language_model = config['aug']['masked_language_model']
        self.metrics = config['trainer']['metrics']
        self.is_padding = config['dataset']['is_padding']

        self.tensorboard_dir = config['monitor']['log_dir']
        os.makedirs(self.tensorboard_dir, exist_ok=True)  # no-op if it already exists
        self.writer = SummaryWriter(self.tensorboard_dir)

        # LOGGER
        self.logger = Logger(config['monitor']['log_dir'])
        self.logger.info(config)

        self.iter = 0
        self.best_acc = 0
        self.scheduler = None
        self.is_finetuning = config['trainer']['is_finetuning']

        if self.is_finetuning:
            self.logger.info("Finetuning model ---->")
            if self.model.seq_modeling == 'crnn':
                self.optimizer = Adam(lr=0.0001,
                                      params=self.model.parameters(),
                                      betas=(0.5, 0.999))
            else:
                self.optimizer = AdamW(lr=0.0001,
                                       params=self.model.parameters(),
                                       betas=(0.9, 0.98),
                                       eps=1e-09)

        else:

            self.optimizer = AdamW(self.model.parameters(),
                                   betas=(0.9, 0.98),
                                   eps=1e-09)
            self.scheduler = OneCycleLR(self.optimizer,
                                        total_steps=self.num_iters,
                                        **config['optimizer'])

        if self.model.seq_modeling == 'crnn':
            self.criterion = torch.nn.CTCLoss(self.vocab.pad,
                                              zero_infinity=True)
        else:
            self.criterion = LabelSmoothingLoss(len(self.vocab),
                                                padding_idx=self.vocab.pad,
                                                smoothing=0.1)

        # Pretrained model
        if config['trainer']['pretrained']:
            self.load_weights(config['trainer']['pretrained'])
            self.logger.info("Loaded trained model from: {}".format(
                config['trainer']['pretrained']))

        # Resume
        elif config['trainer']['resume_from']:
            self.load_checkpoint(config['trainer']['resume_from'])
            for state in self.optimizer.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = v.to(torch.device(self.device))

            self.logger.info("Resume training from {}".format(
                config['trainer']['resume_from']))

        # DATASET
        transforms = None
        if self.image_aug:
            transforms = augmentor

        train_lmdb_paths = [
            os.path.join(self.data_root, lmdb_path)
            for lmdb_path in self.train_lmdb
        ]

        self.train_gen = self.data_gen(
            lmdb_paths=train_lmdb_paths,
            data_root=self.data_root,
            annotation=self.train_annotation,
            masked_language_model=self.masked_language_model,
            transform=transforms,
            is_train=True)

        if self.valid_annotation:
            self.valid_gen = self.data_gen(
                lmdb_paths=[os.path.join(self.data_root, self.valid_lmdb)],
                data_root=self.data_root,
                annotation=self.valid_annotation,
                masked_language_model=False)

        self.train_losses = []
        self.logger.info("Number of training batches: %d" %
                         len(self.train_gen))
        # guard: self.valid_gen only exists when a valid annotation was given
        if self.valid_annotation:
            self.logger.info("Number of validation batches: %d" %
                             len(self.valid_gen))

        config_savepath = os.path.join(self.tensorboard_dir, "config.yml")
        if not os.path.exists(config_savepath):
            self.logger.info("Saving config file at: %s" % config_savepath)
            Cfg(config).save(config_savepath)