Пример #1
0
    def __init__(self, config, pretrain=True):
        """Set up the epoch-based trainer: model, optimizer, loss and data generators.

        Args:
            config: nested configuration mapping; this method reads the
                top-level ``device``, ``quiet``, ``pretrain``, ``transformer``,
                ``optimizer`` and ``trainer`` sections.
            pretrain: when true, call ``download_weights`` with
                ``config['pretrain']`` and then load the state dict stored at
                ``config['pretrain']['cached']`` onto ``config['device']``.
        """
        self.config = config
        self.model, self.vocab = build_model(config)

        self.device = config['device']
        trainer_cfg = config['trainer']
        self.num_epochs = trainer_cfg['epochs']
        self.data_root = trainer_cfg['data_root']
        self.train_annotation = trainer_cfg['train_annotation']
        self.valid_annotation = trainer_cfg['valid_annotation']
        self.batch_size = trainer_cfg['batch_size']
        self.print_every = trainer_cfg['print_every']
        self.valid_every = trainer_cfg['valid_every']
        self.checkpoint = trainer_cfg['checkpoint']
        self.export_weights = trainer_cfg['export']
        self.metrics = trainer_cfg['metrics']

        if pretrain:
            pretrain_cfg = config['pretrain']
            download_weights(**pretrain_cfg, quiet=config['quiet'])
            pretrained_state = torch.load(pretrain_cfg['cached'],
                                          map_location=torch.device(self.device))
            self.model.load_state_dict(pretrained_state)

        # Training progress counters.
        self.epoch = 0
        self.iter = 0

        # Adam wrapped in ScheduledOptim; the positional 0.2 is presumably an
        # LR multiplier -- TODO confirm against ScheduledOptim's signature.
        base_optimizer = Adam(self.model.parameters(), betas=(0.9, 0.98), eps=1e-09)
        self.optimizer = ScheduledOptim(
            base_optimizer,
            0.2,
            config['transformer']['d_model'],
            config['optimizer']['n_warmup_steps'],
        )

        # Target index 0 is excluded from the loss.
        self.criterion = nn.CrossEntropyLoss(ignore_index=0)

        self.train_gen = DataGen(self.data_root, self.train_annotation,
                                 self.vocab, self.device)
        # A validation generator is only created when an annotation is configured.
        if self.valid_annotation:
            self.valid_gen = DataGen(self.data_root, self.valid_annotation,
                                     self.vocab, self.device)

        self.train_losses = []
Пример #2
0
    def __init__(self, config, pretrained=True):
        """Set up an iteration-based trainer using AdamW with a OneCycle schedule.

        Args:
            config: nested configuration mapping; reads ``device``, ``quiet``,
                ``pretrain``, ``optimizer``, ``predictor``, ``dataset`` and
                ``trainer`` sections.
            pretrained: when true, fetch weights with ``download_weights`` and
                load the returned file through ``self.load_weights``.
        """
        self.config = config
        self.model, self.vocab = build_model(config)

        self.device = config['device']
        trainer_cfg = config['trainer']
        self.num_iters = trainer_cfg['iters']
        self.beamsearch = config['predictor']['beamsearch']

        dataset_cfg = config['dataset']
        self.data_root = dataset_cfg['data_root']
        self.train_annotation = dataset_cfg['train_annotation']
        self.valid_annotation = dataset_cfg['valid_annotation']
        self.dataset_name = dataset_cfg['name']

        self.batch_size = trainer_cfg['batch_size']
        self.print_every = trainer_cfg['print_every']
        self.valid_every = trainer_cfg['valid_every']

        self.checkpoint = trainer_cfg['checkpoint']
        self.export_weights = trainer_cfg['export']
        self.metrics = trainer_cfg['metrics']

        # self.logger is only created when a log target is configured.
        log_target = trainer_cfg['log']
        if log_target:
            self.logger = Logger(log_target)

        if pretrained:
            weight_file = download_weights(**config['pretrain'],
                                           quiet=config['quiet'])
            self.load_weights(weight_file)

        self.iter = 0

        self.optimizer = AdamW(self.model.parameters(), betas=(0.9, 0.98), eps=1e-09)
        # Remaining OneCycleLR arguments come straight from config['optimizer'].
        self.scheduler = OneCycleLR(self.optimizer, **config['optimizer'])

        self.criterion = LabelSmoothingLoss(len(self.vocab),
                                            padding_idx=self.vocab.pad,
                                            smoothing=0.1)

        # Augmentation is applied to the training split only.
        train_transform = ImgAugTransform()
        self.train_gen = self.data_gen(f'train_{self.dataset_name}',
                                       self.data_root,
                                       self.train_annotation,
                                       transform=train_transform)
        if self.valid_annotation:
            self.valid_gen = self.data_gen(f'valid_{self.dataset_name}',
                                           self.data_root,
                                           self.valid_annotation)

        self.train_losses = []
Пример #3
0
    def process(self, craft, seq, key, sub_img):
        """Detect text boxes in ``sub_img`` with CRAFT and recognise each line.

        Args:
            craft: text-detector module; called on a single-image batch tensor
                built as [b, c, h, w] below and expected to return a pair whose
                first item is indexed ``y[0, :, :, k]`` for the text (k=0) and
                link (k=1) score maps.
            seq: mutable sequence that is resized in place to one ``None`` slot
                per horizontal box; ``self.predict`` receives it together with
                the box index (presumably to fill ``seq[i]`` -- confirm in
                ``predict``).
            key: passed through unchanged to ``self.predict``.
            sub_img: image array indexed ``[y, x, channel]`` in BGR order (it is
                converted with ``COLOR_BGR2RGB`` before recognition).
        """
        img_resized, target_ratio, _ = resize_aspect_ratio(
            sub_img, 2560, interpolation=cv2.INTER_LINEAR, mag_ratio=1.)
        ratio_h = ratio_w = 1 / target_ratio

        x = normalizeMeanVariance(img_resized)
        x = torch.from_numpy(x).permute(2, 0, 1)  # [h, w, c] to [c, h, w]
        x = x.unsqueeze(0)  # [c, h, w] to [b, c, h, w]
        x = x.to(self.device)
        # Pure inference: skip building an autograd graph for the detector.
        with torch.no_grad():
            y, _ = craft(x)
        score_text = y[0, :, :, 0].cpu().numpy()
        score_link = y[0, :, :, 1].cpu().numpy()

        boxes, polys = getDetBoxes(score_text,
                                   score_link,
                                   text_threshold=0.7,
                                   link_threshold=0.4,
                                   low_text=0.4,
                                   poly=False)
        boxes = adjustResultCoordinates(boxes, ratio_w, ratio_h)
        polys = adjustResultCoordinates(polys, ratio_w, ratio_h)
        # Fall back to the box wherever no polygon was produced.
        polys = [box if poly is None else poly for poly, box in zip(polys, boxes)]
        result = [np.array(poly).astype(np.int32).reshape(-1) for poly in polys]

        horizontal_list, free_list = group_text_box(result,
                                                    slope_ths=0.8,
                                                    ycenter_ths=0.5,
                                                    height_ths=1,
                                                    width_ths=1,
                                                    add_margin=0.1)
        min_size = 20
        # NOTE(review): horizontal boxes use a fixed threshold of 10 while
        # free-form boxes use min_size (20); confirm this asymmetry is intended.
        horizontal_list = [
            b for b in horizontal_list
            if max(b[1] - b[0], b[3] - b[2]) > 10
        ]
        free_list = [
            b for b in free_list
            if max(diff([c[0] for c in b]), diff([c[1] for c in b])) > min_size
        ]

        seq[:] = [None] * len(horizontal_list)
        # NOTE(review): the reader model is rebuilt and reloaded on every call.
        model, vocab = build_model(self.config)
        model.load_state_dict(
            torch.load(self.weights, map_location=torch.device('cpu')))

        for idx, box in enumerate(horizontal_list):
            # Clamp negative coordinates (box margins can extend past the edge).
            box = [0 if v < 0 else v for v in box]
            crop = sub_img[box[2]:box[3], box[0]:box[1], :]
            crop = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
            crop = Image.fromarray(crop.astype(np.uint8))
            # NOTE(review): each thread is joined immediately, so recognition
            # runs sequentially; kept to preserve predict's error isolation.
            worker = threading.Thread(target=self.predict,
                                      args=(model, vocab, seq, key, idx, crop))
            worker.start()
            worker.join()
Пример #4
0
 def __init__(self, config):
     """Build the recognition model and load its weights from ``config['weights']``.

     If that path does not exist, ``download_model("reader")`` is called first
     (presumably it writes the file to that path -- confirm in download_model).
     """
     device_name = config['device']
     model, vocab = build_model(config)
     weights_path = config['weights']

     if not os.path.exists(weights_path):
         download_model("reader")
     loaded_state = torch.load(weights_path, map_location=torch.device(device_name))
     model.load_state_dict(loaded_state)

     self.config = config
     self.model = model
     self.vocab = vocab
Пример #5
0
    def __init__(self, config):
        """Build the OCR model and load the state dict at ``config['weights']``.

        Tensors are mapped onto ``config['device']`` while loading.
        """
        target_device = config['device']
        model, vocab = build_model(config)
        checkpoint_path = config['weights']

        state_dict = torch.load(checkpoint_path,
                                map_location=torch.device(target_device))
        model.load_state_dict(state_dict)

        self.config, self.model, self.vocab = config, model, vocab
Пример #6
0
    def __init__(self, config):
        """Build the OCR model and load its weights.

        ``config['weights']`` is either a URL (anything starting with
        ``'http'``), which is fetched via ``download_weights`` and the returned
        path loaded, or a local file path loaded directly. Tensors are mapped
        onto ``config['device']``.
        """
        device = config['device']

        model, vocab = build_model(config)

        if config['weights'].startswith('http'):
            weights = download_weights(config['weights'])
        else:
            weights = config['weights']

        model.load_state_dict(
            torch.load(weights, map_location=torch.device(device)))

        self.config = config
        self.model = model
        self.vocab = vocab
Пример #7
0
    def __init__(self, config, quanti=False):
        """Build the OCR model and load its weights.

        Args:
            config: mapping with at least ``device`` and ``weights``;
                ``weights`` starting with ``'http'`` is fetched through
                ``download_weights`` and the returned path loaded, otherwise it
                is treated as a local file path.
            quanti: accepted for interface compatibility; not read here.
        """
        device = config['device']

        model, vocab = build_model(config)

        if config['weights'].startswith('http'):
            weights = download_weights(config['weights'])
        else:
            weights = config['weights']

        model.load_state_dict(
            torch.load(weights, map_location=torch.device(device)))

        self.config = config
        self.model = model
        self.vocab = vocab
Пример #8
0
    def __init__(self, config):
        """Build the OCR model and load weights from a raw or wrapped checkpoint.

        ``config['weights']`` starting with ``'http'`` is fetched via
        ``download_weights``; otherwise it is a local path. The file may hold
        either a bare state dict or a dict that wraps it under
        ``'state_dict'``; both forms are accepted.
        """
        device = config['device']

        model, vocab = build_model(config)

        if config['weights'].startswith('http'):
            weights = download_weights(config['weights'])
        else:
            weights = config['weights']

        # Load once and unwrap explicitly instead of catching every exception
        # (a bare except also hid real errors and KeyboardInterrupt).
        checkpoint = torch.load(weights, map_location=torch.device(device))
        if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            checkpoint = checkpoint['state_dict']
        model.load_state_dict(checkpoint)

        self.config = config
        self.model = model
        self.vocab = vocab
Пример #9
0
    def __init__(self,
                 craft_weights='/home/dung/Project/Python/ocr/craft_mlt_25k.pth',
                 reader_weights='/home/dung/Documents/transformerocr.pth'):
        """Load the CRAFT detector and the VietOCR reader on the CPU.

        Both models are switched to eval mode and moved to shared memory. Six
        ``Manager``-backed lists are also created (their role -- presumably
        per-field result buffers -- is defined by the rest of the class).

        Args:
            craft_weights: path to the CRAFT state dict. Defaults to the path
                that was previously hard-coded.
            reader_weights: path to the recognition-model state dict; stored as
                ``self.weights``. Defaults to the path previously hard-coded.
        """
        super().__init__()
        manager = Manager()
        self.send = manager.list()
        self.date = manager.list()
        self.quote = manager.list()
        self.number = manager.list()
        self.header = manager.list()
        self.sign = manager.list()
        self.device = torch.device('cpu')

        # map_location lets checkpoints saved with CUDA tensors load on CPU-only hosts.
        state_dict = torch.load(craft_weights, map_location=self.device)
        # Drop a leading "module." component (typically left by nn.DataParallel).
        first_key = next(iter(state_dict), '')
        start_idx = 1 if first_key.startswith("module") else 0
        new_state_dict = OrderedDict(
            (".".join(k.split(".")[start_idx:]), v)
            for k, v in state_dict.items())

        self.craft = CRAFT()
        self.craft.load_state_dict(new_state_dict)
        self.craft.to(self.device)
        self.craft.eval()
        self.craft.share_memory()

        self.config = Cfg.load_config_from_name('vgg_transformer')
        # NOTE: this URL is stored in the config but the reader below is
        # loaded from self.weights.
        self.config[
            'weights'] = 'https://drive.google.com/uc?id=13327Y1tz1ohsm5YZMyXVMPIOjoOA0OaA'
        self.config['device'] = 'cpu'
        self.config['predictor']['beamsearch'] = False
        self.weights = reader_weights
        self.model, self.vocab = build_model(self.config)
        self.model.load_state_dict(
            torch.load(self.weights, map_location=self.device))
        self.model.to(self.device)
        self.model.eval()
        self.model.share_memory()
Пример #10
0
    def __init__(self, config, pretrained=True):
        """Set up a trainer that can partially load pretrained weights.

        When ``pretrained`` is true, the cached pretrained checkpoint is loaded
        with every entry whose shape differs from the model's (or that the
        model does not have) left out. Entries missing from the checkpoint keep
        their initial values (``strict=False``).

        Args:
            config: nested configuration mapping (``device``, ``quiet``,
                ``pretrain``, ``transformer``, ``optimizer``, ``predictor``,
                ``dataset``, ``trainer`` sections).
            pretrained: whether to download and load the pretrained weights.
        """
        self.config = config
        self.model, self.vocab = build_model(config)

        self.device = config['device']
        self.num_iters = config['trainer']['iters']
        self.beamsearch = config['predictor']['beamsearch']

        self.data_root = config['dataset']['data_root']
        self.train_annotation = config['dataset']['train_annotation']
        self.valid_annotation = config['dataset']['valid_annotation']
        self.dataset_name = config['dataset']['name']

        self.batch_size = config['trainer']['batch_size']
        self.print_every = config['trainer']['print_every']
        self.valid_every = config['trainer']['valid_every']

        self.checkpoint = config['trainer']['checkpoint']
        self.export_weights = config['trainer']['export']
        self.metrics = config['trainer']['metrics']
        logger = config['trainer']['log']

        if logger:
            self.logger = Logger(logger)

        if pretrained:
            download_weights(**config['pretrain'], quiet=config['quiet'])
            state_dict = torch.load(config['pretrain']['cached'],
                                    map_location=torch.device(self.device))

            # Check parameters AND buffers: load_state_dict raises on a size
            # mismatch even with strict=False. Names absent from the checkpoint
            # are skipped here (indexing them used to raise KeyError).
            for name, tensor in self.model.state_dict().items():
                if name in state_dict and state_dict[name].shape != tensor.shape:
                    print('{} missmatching shape'.format(name))
                    del state_dict[name]

            self.model.load_state_dict(state_dict, strict=False)

        self.iter = 0

        self.optimizer = ScheduledOptim(
            Adam(self.model.parameters(), betas=(0.9, 0.98), eps=1e-09),
            config['transformer']['d_model'], **config['optimizer'])

        self.criterion = LabelSmoothingLoss(len(self.vocab),
                                            padding_idx=self.vocab.pad,
                                            smoothing=0.1)

        # Light photometric jitter plus random rescaling for training images.
        transforms = torchvision.transforms.Compose([
            torchvision.transforms.ColorJitter(brightness=.1,
                                               contrast=.1,
                                               hue=.1,
                                               saturation=.1),
            torchvision.transforms.RandomAffine(degrees=0,
                                                scale=(3 / 4, 4 / 3))
        ])

        self.train_gen = self.data_gen('train_{}'.format(self.dataset_name),
                                       self.data_root,
                                       self.train_annotation,
                                       transform=transforms)
        if self.valid_annotation:
            self.valid_gen = self.data_gen(
                'valid_{}'.format(self.dataset_name), self.data_root,
                self.valid_annotation)

        self.train_losses = []
Пример #11
0
    def __init__(self, config, pretrained=True, augmentor=ImgAugTransform()):
        """Set up an LMDB-backed trainer with TensorBoard and file logging.

        Depending on ``config['trainer']``, this either finetunes with a fixed
        learning rate or trains with AdamW + OneCycleLR. It then loads
        pretrained weights or resumes from a checkpoint, if configured.

        Args:
            config: nested configuration mapping (``device``, ``trainer``,
                ``predictor``, ``dataset``, ``aug``, ``monitor``, ``optimizer``).
            pretrained: accepted for interface compatibility; not read here.
                Pretrained loading is driven by
                ``config['trainer']['pretrained']``.
            augmentor: transform applied to training images when
                ``config['aug']['image_aug']`` is true. NOTE: the default
                instance is created once at definition time and shared by every
                trainer that does not pass its own.
        """
        self.config = config
        self.model, self.vocab = build_model(config)

        self.device = config['device']
        self.num_iters = config['trainer']['iters']
        self.beamsearch = config['predictor']['beamsearch']

        self.data_root = config['dataset']['data_root']
        self.train_annotation = config['dataset']['train_annotation']
        self.valid_annotation = config['dataset']['valid_annotation']
        self.train_lmdb = config['dataset']['train_lmdb']
        self.valid_lmdb = config['dataset']['valid_lmdb']
        self.dataset_name = config['dataset']['name']

        self.batch_size = config['trainer']['batch_size']
        self.print_every = config['trainer']['print_every']
        self.valid_every = config['trainer']['valid_every']

        self.image_aug = config['aug']['image_aug']
        self.masked_language_model = config['aug']['masked_language_model']
        self.metrics = config['trainer']['metrics']
        self.is_padding = config['dataset']['is_padding']

        # exist_ok=True already handles an existing directory.
        self.tensorboard_dir = config['monitor']['log_dir']
        os.makedirs(self.tensorboard_dir, exist_ok=True)
        self.writer = SummaryWriter(self.tensorboard_dir)

        # LOGGER
        self.logger = Logger(config['monitor']['log_dir'])
        self.logger.info(config)

        self.iter = 0
        self.best_acc = 0
        self.scheduler = None
        self.is_finetuning = config['trainer']['is_finetuning']

        if self.is_finetuning:
            # Finetuning: fixed lr=0.0001 and no scheduler.
            self.logger.info("Finetuning model ---->")
            if self.model.seq_modeling == 'crnn':
                self.optimizer = Adam(lr=0.0001,
                                      params=self.model.parameters(),
                                      betas=(0.5, 0.999))
            else:
                self.optimizer = AdamW(lr=0.0001,
                                       params=self.model.parameters(),
                                       betas=(0.9, 0.98),
                                       eps=1e-09)
        else:
            self.optimizer = AdamW(self.model.parameters(),
                                   betas=(0.9, 0.98),
                                   eps=1e-09)
            self.scheduler = OneCycleLR(self.optimizer,
                                        total_steps=self.num_iters,
                                        **config['optimizer'])

        # CTC loss for the CRNN head, label-smoothed loss otherwise.
        if self.model.seq_modeling == 'crnn':
            self.criterion = torch.nn.CTCLoss(self.vocab.pad,
                                              zero_infinity=True)
        else:
            self.criterion = LabelSmoothingLoss(len(self.vocab),
                                                padding_idx=self.vocab.pad,
                                                smoothing=0.1)

        # Pretrained model
        if config['trainer']['pretrained']:
            self.load_weights(config['trainer']['pretrained'])
            self.logger.info("Loaded trained model from: {}".format(
                config['trainer']['pretrained']))

        # Resume
        elif config['trainer']['resume_from']:
            self.load_checkpoint(config['trainer']['resume_from'])
            # Move restored optimizer state tensors onto the training device.
            for state in self.optimizer.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = v.to(torch.device(self.device))

            self.logger.info("Resume training from {}".format(
                config['trainer']['resume_from']))

        # DATASET
        transforms = augmentor if self.image_aug else None

        train_lmdb_paths = [
            os.path.join(self.data_root, lmdb_path)
            for lmdb_path in self.train_lmdb
        ]

        self.train_gen = self.data_gen(
            lmdb_paths=train_lmdb_paths,
            data_root=self.data_root,
            annotation=self.train_annotation,
            masked_language_model=self.masked_language_model,
            transform=transforms,
            is_train=True)

        if self.valid_annotation:
            self.valid_gen = self.data_gen(
                lmdb_paths=[os.path.join(self.data_root, self.valid_lmdb)],
                data_root=self.data_root,
                annotation=self.valid_annotation,
                masked_language_model=False)

        self.train_losses = []
        self.logger.info("Number batch samples of training: %d" %
                         len(self.train_gen))
        # valid_gen exists only when a validation annotation is configured;
        # logging it unconditionally raised AttributeError otherwise.
        if self.valid_annotation:
            self.logger.info("Number batch samples of valid: %d" %
                             len(self.valid_gen))

        config_savepath = os.path.join(self.tensorboard_dir, "config.yml")
        if not os.path.exists(config_savepath):
            self.logger.info("Saving config file at: %s" % config_savepath)
            Cfg(config).save(config_savepath)
Пример #12
0
from vietocr.tool.translate import build_model, translate_beam_search, process_input, predict
from vietocr.tool.config import Cfg
import cv2
from PIL import Image
import numpy as np
import torch
# Ad-hoc CPU inference check: build the vgg_transformer VietOCR model, load
# local weights, preprocess one image and run the CNN + transformer encoder.
config = Cfg.load_config_from_name('vgg_transformer')
# NOTE(review): this URL is stored in the config but never used below; the
# weights are loaded from the local file named by `weights` instead.
config[
    'weights'] = 'https://drive.google.com/uc?id=13327Y1tz1ohsm5YZMyXVMPIOjoOA0OaA'
config['device'] = 'cpu'
config['predictor']['beamsearch'] = False
model, vocab = build_model(config)
weights = 'transformerocr.pth'
# NOTE(review): `device` is not used; the load below builds its own CPU device.
device = torch.device('cpu')
# if config['weights'].startswith('http'):
#     weights = download_weights(config['weights'])
# else:
#     weights = config['weights']
model.load_state_dict(torch.load(weights, map_location=torch.device('cpu')))
# cv2.imread returns None (rather than raising) if '5.png' cannot be read, in
# which case cvtColor below fails.
sub_img = cv2.imread('5.png')
# cv2.imshow('aa',sub_img)
# cv2.waitKey(0)
# OpenCV loads BGR; convert to RGB before building the PIL image.
img = cv2.cvtColor(sub_img, cv2.COLOR_BGR2RGB)
img = Image.fromarray(img.astype(np.uint8))
# Resize/normalise to the height and width bounds configured for the dataset.
img = process_input(img, config['dataset']['image_height'],
                    config['dataset']['image_min_width'],
                    config['dataset']['image_max_width'])
img = img.to(config['device'])
# Encoder-only forward pass; decoding of `memory` is not shown in this snippet.
with torch.no_grad():
    src = model.cnn(img)
    memory = model.transformer.forward_encoder(src)
Пример #13
0
    def __init__(self, config, pretrained=False, augmentor=ImgAugTransform()):
        """Set up an iteration-based trainer with optional image augmentation
        and masked-language-model training data.

        Args:
            config: nested configuration mapping (``device``, ``weights``,
                ``trainer``, ``predictor``, ``dataset``, ``aug``,
                ``optimizer``).
            pretrained: when true, load weights from ``config['weights']`` via
                ``self.load_weights``.
            augmentor: training-image transform used when
                ``config['aug']['image_aug']`` is true. The default instance is
                created once at definition time and shared.
        """
        self.config = config
        self.model, self.vocab = build_model(config)

        self.device = config['device']
        trainer_cfg = config['trainer']
        self.num_iters = trainer_cfg['iters']
        self.beamsearch = config['predictor']['beamsearch']

        dataset_cfg = config['dataset']
        self.data_root = dataset_cfg['data_root']
        self.train_annotation = dataset_cfg['train_annotation']
        self.valid_annotation = dataset_cfg['valid_annotation']
        self.dataset_name = dataset_cfg['name']

        self.batch_size = trainer_cfg['batch_size']
        self.print_every = trainer_cfg['print_every']
        self.valid_every = trainer_cfg['valid_every']

        aug_cfg = config['aug']
        self.image_aug = aug_cfg['image_aug']
        self.masked_language_model = aug_cfg['masked_language_model']

        self.export_weights = trainer_cfg['export']
        self.metrics = trainer_cfg['metrics']

        # self.logger is only created when a log target is configured.
        log_target = trainer_cfg['log']
        if log_target:
            self.logger = Logger(log_target)

        if pretrained:
            logging.info(f"Load weights from {config['weights']}....")
            self.load_weights(config['weights'])

        self.iter = 0

        self.optimizer = AdamW(self.model.parameters(), betas=(0.9, 0.98), eps=1e-09)
        self.scheduler = OneCycleLR(self.optimizer,
                                    total_steps=self.num_iters,
                                    **config['optimizer'])

        self.criterion = LabelSmoothingLoss(len(self.vocab),
                                            padding_idx=self.vocab.pad,
                                            smoothing=0.1)

        train_transform = augmentor if self.image_aug else None

        train_split = 'train_{}'.format(self.dataset_name)
        self.train_gen = self.data_gen(train_split,
                                       self.data_root,
                                       self.train_annotation,
                                       self.masked_language_model,
                                       transform=train_transform)
        if self.valid_annotation:
            valid_split = 'valid_{}'.format(self.dataset_name)
            self.valid_gen = self.data_gen(valid_split,
                                           self.data_root,
                                           self.valid_annotation,
                                           masked_language_model=False)

        self.train_losses = []