Code Example #1
File: trainer.py    Project: tranquansp/vietocr
    def __init__(self, config, pretrained=True):

        self.config = config
        self.model, self.vocab = build_model(config)

        self.device = config['device']
        self.num_iters = config['trainer']['iters']
        self.beamsearch = config['predictor']['beamsearch']

        self.data_root = config['dataset']['data_root']
        self.train_annotation = config['dataset']['train_annotation']
        self.valid_annotation = config['dataset']['valid_annotation']
        self.dataset_name = config['dataset']['name']

        self.batch_size = config['trainer']['batch_size']
        self.print_every = config['trainer']['print_every']
        self.valid_every = config['trainer']['valid_every']

        self.checkpoint = config['trainer']['checkpoint']
        self.export_weights = config['trainer']['export']
        self.metrics = config['trainer']['metrics']
        logger = config['trainer']['log']

        if logger:
            self.logger = Logger(logger)

        if pretrained:
            weight_file = download_weights(**config['pretrain'],
                                           quiet=config['quiet'])
            self.load_weights(weight_file)

        self.iter = 0

        self.optimizer = AdamW(self.model.parameters(),
                               betas=(0.9, 0.98),
                               eps=1e-09)
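        # OneCycleLR's remaining required arguments (e.g. max_lr, total_steps)
        # are presumably supplied through config['optimizer'].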
        self.scheduler = OneCycleLR(self.optimizer, **config['optimizer'])

        self.criterion = LabelSmoothingLoss(len(self.vocab),
                                            padding_idx=self.vocab.pad,
                                            smoothing=0.1)

        transforms = ImgAugTransform()

        self.train_gen = self.data_gen('train_{}'.format(self.dataset_name),
                                       self.data_root,
                                       self.train_annotation,
                                       transform=transforms)
        if self.valid_annotation:
            self.valid_gen = self.data_gen(
                'valid_{}'.format(self.dataset_name), self.data_root,
                self.valid_annotation)

        self.train_losses = []
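
Usage note: a minimal sketch of how a Trainer like this is typically driven, following the upstream vietocr README (the config name 'vgg_transformer' and the device string are illustrative and may differ in this fork):

    from vietocr.tool.config import Cfg
    from vietocr.model.trainer import Trainer

    config = Cfg.load_config_from_name('vgg_transformer')
    config['device'] = 'cuda:0'  # or 'cpu'

    trainer = Trainer(config, pretrained=True)
    trainer.train()
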
Code Example #2
File: trainer.py    Project: hisiter97/vietocr_hst
    def __init__(self, config, pretrained=True, augmentor=ImgAugTransform()):

        self.config = config
        self.model, self.vocab = build_model(config)

        self.device = config['device']
        self.num_iters = config['trainer']['iters']
        self.beamsearch = config['predictor']['beamsearch']

        self.data_root = config['dataset']['data_root']
        self.train_annotation = config['dataset']['train_annotation']
        self.valid_annotation = config['dataset']['valid_annotation']
        self.train_lmdb = config['dataset']['train_lmdb']
        self.valid_lmdb = config['dataset']['valid_lmdb']
        self.dataset_name = config['dataset']['name']

        self.batch_size = config['trainer']['batch_size']
        self.print_every = config['trainer']['print_every']
        self.valid_every = config['trainer']['valid_every']

        self.image_aug = config['aug']['image_aug']
        self.masked_language_model = config['aug']['masked_language_model']
        self.metrics = config['trainer']['metrics']
        self.is_padding = config['dataset']['is_padding']

        self.tensorboard_dir = config['monitor']['log_dir']
        os.makedirs(self.tensorboard_dir, exist_ok=True)
        self.writer = SummaryWriter(self.tensorboard_dir)

        # LOGGER
        self.logger = Logger(config['monitor']['log_dir'])
        self.logger.info(config)

        self.iter = 0
        self.best_acc = 0
        self.scheduler = None
        self.is_finetuning = config['trainer']['is_finetuning']

        if self.is_finetuning:
            self.logger.info("Finetuning model ---->")
            if self.model.seq_modeling == 'crnn':
                self.optimizer = Adam(lr=0.0001,
                                      params=self.model.parameters(),
                                      betas=(0.5, 0.999))
            else:
                self.optimizer = AdamW(lr=0.0001,
                                       params=self.model.parameters(),
                                       betas=(0.9, 0.98),
                                       eps=1e-09)

        else:

            self.optimizer = AdamW(self.model.parameters(),
                                   betas=(0.9, 0.98),
                                   eps=1e-09)
            self.scheduler = OneCycleLR(self.optimizer,
                                        total_steps=self.num_iters,
                                        **config['optimizer'])

        if self.model.seq_modeling == 'crnn':
            self.criterion = torch.nn.CTCLoss(blank=self.vocab.pad,
                                              zero_infinity=True)
        else:
            self.criterion = LabelSmoothingLoss(len(self.vocab),
                                                padding_idx=self.vocab.pad,
                                                smoothing=0.1)

        # Pretrained model
        if config['trainer']['pretrained']:
            self.load_weights(config['trainer']['pretrained'])
            self.logger.info("Loaded trained model from: {}".format(
                config['trainer']['pretrained']))

        # Resume
        elif config['trainer']['resume_from']:
            self.load_checkpoint(config['trainer']['resume_from'])
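            # The checkpoint's optimizer state may have been restored on CPU;
            # move its tensors back to the training device before stepping.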
            for state in self.optimizer.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = v.to(torch.device(self.device))

            self.logger.info("Resume training from {}".format(
                config['trainer']['resume_from']))

        # DATASET
        transforms = None
        if self.image_aug:
            transforms = augmentor

        train_lmdb_paths = [
            os.path.join(self.data_root, lmdb_path)
            for lmdb_path in self.train_lmdb
        ]

        self.train_gen = self.data_gen(
            lmdb_paths=train_lmdb_paths,
            data_root=self.data_root,
            annotation=self.train_annotation,
            masked_language_model=self.masked_language_model,
            transform=transforms,
            is_train=True)

        if self.valid_annotation:
            self.valid_gen = self.data_gen(
                lmdb_paths=[os.path.join(self.data_root, self.valid_lmdb)],
                data_root=self.data_root,
                annotation=self.valid_annotation,
                masked_language_model=False)

        self.train_losses = []
        self.logger.info("Number batch samples of training: %d" %
                         len(self.train_gen))
        self.logger.info("Number batch samples of valid: %d" %
                         len(self.valid_gen))

        config_savepath = os.path.join(self.tensorboard_dir, "config.yml")
        if not os.path.exists(config_savepath):
            self.logger.info("Saving config file at: %s" % config_savepath)
            Cfg(config).save(config_savepath)
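
Side note: when seq_modeling == 'crnn', Example #2 swaps LabelSmoothingLoss for torch.nn.CTCLoss, which consumes per-timestep log-probabilities of shape (T, N, C) together with the unpadded input and target lengths. A self-contained sketch (all tensor names and sizes are illustrative, not taken from the project):

    import torch
    import torch.nn as nn

    T, N, C = 32, 4, 233  # timesteps, batch size, vocab size; index 0 is the blank
    criterion = nn.CTCLoss(blank=0, zero_infinity=True)

    log_probs = torch.randn(T, N, C).log_softmax(2)            # (T, N, C) log-probs
    targets = torch.randint(1, C, (N, 10), dtype=torch.long)   # padded label batch
    input_lengths = torch.full((N,), T, dtype=torch.long)
    target_lengths = torch.randint(5, 11, (N,), dtype=torch.long)

    loss = criterion(log_probs, targets, input_lengths, target_lengths)
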
Code Example #3
File: trainer.py    Project: linhduongtuan/vietocr
    def __init__(self, config, pretrained=True):

        self.config = config
        self.model, self.vocab = build_model(config)

        self.device = config['device']
        self.num_iters = config['trainer']['iters']
        self.beamsearch = config['predictor']['beamsearch']

        self.data_root = config['dataset']['data_root']
        self.train_annotation = config['dataset']['train_annotation']
        self.valid_annotation = config['dataset']['valid_annotation']
        self.dataset_name = config['dataset']['name']

        self.batch_size = config['trainer']['batch_size']
        self.print_every = config['trainer']['print_every']
        self.valid_every = config['trainer']['valid_every']

        self.checkpoint = config['trainer']['checkpoint']
        self.export_weights = config['trainer']['export']
        self.metrics = config['trainer']['metrics']
        logger = config['trainer']['log']

        if logger:
            self.logger = Logger(logger)

        if pretrained:
            download_weights(**config['pretrain'], quiet=config['quiet'])
            state_dict = torch.load(config['pretrain']['cached'],
                                    map_location=torch.device(self.device))

            for name, param in self.model.named_parameters():
                if name in state_dict and state_dict[name].shape != param.shape:
                    print('{} mismatching shape'.format(name))
                    del state_dict[name]

            self.model.load_state_dict(state_dict, strict=False)

        self.iter = 0

        self.optimizer = ScheduledOptim(
            Adam(self.model.parameters(), betas=(0.9, 0.98), eps=1e-09),
            config['transformer']['d_model'], **config['optimizer'])
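        # ScheduledOptim appears to implement the Transformer warmup schedule
        # from "Attention Is All You Need", scaling the learning rate by
        # d_model ** -0.5.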

        self.criterion = LabelSmoothingLoss(len(self.vocab),
                                            padding_idx=self.vocab.pad,
                                            smoothing=0.1)

        transforms = ImgAugTransform()

        self.train_gen = self.data_gen('train_{}'.format(self.dataset_name),
                                       self.data_root,
                                       self.train_annotation,
                                       transform=transforms)
        if self.valid_annotation:
            self.valid_gen = self.data_gen(
                'valid_{}'.format(self.dataset_name), self.data_root,
                self.valid_annotation)

        self.train_losses = []
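
All four examples train with LabelSmoothingLoss. Its implementation is not shown on this page; the sketch below follows the standard Annotated-Transformer formulation (smoothed one-hot targets compared against the model's log-probabilities via KL divergence) and may differ in detail from vietocr's version:

    import torch
    import torch.nn as nn

    class LabelSmoothingLoss(nn.Module):
        # Minimal sketch: spread `smoothing` probability mass over the
        # non-target, non-padding classes and compare with KL divergence.
        def __init__(self, num_classes, padding_idx, smoothing=0.1):
            super().__init__()
            self.criterion = nn.KLDivLoss(reduction='sum')
            self.num_classes = num_classes
            self.padding_idx = padding_idx
            self.confidence = 1.0 - smoothing
            self.smoothing = smoothing

        def forward(self, logits, target):
            # logits: (batch, num_classes) raw scores; target: (batch,) indices
            log_probs = logits.log_softmax(dim=-1)
            true_dist = torch.full_like(
                log_probs, self.smoothing / (self.num_classes - 2))
            true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
            true_dist[:, self.padding_idx] = 0          # never predict padding
            true_dist[target == self.padding_idx] = 0   # ignore padded positions
            return self.criterion(log_probs, true_dist)
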
Code Example #4
File: trainer.py    Project: lzmisscc/vietocr
    def __init__(self, config, pretrained=False, augmentor=ImgAugTransform()):
        self.config = config
        self.model, self.vocab = build_model(config)
        self.device = config['device']
        self.num_iters = config['trainer']['iters']
        self.beamsearch = config['predictor']['beamsearch']

        self.data_root = config['dataset']['data_root']
        self.train_annotation = config['dataset']['train_annotation']
        self.valid_annotation = config['dataset']['valid_annotation']
        self.dataset_name = config['dataset']['name']

        self.batch_size = config['trainer']['batch_size']
        self.print_every = config['trainer']['print_every']
        self.valid_every = config['trainer']['valid_every']

        self.image_aug = config['aug']['image_aug']
        self.masked_language_model = config['aug']['masked_language_model']

        self.export_weights = config['trainer']['export']
        self.metrics = config['trainer']['metrics']
        logger = config['trainer']['log']

        if logger:
            self.logger = Logger(logger)

        if pretrained:
            logging.info(f"Load weights from {config['weights']}....")
            # weight_file = download_weights(**config['pretrain'], quiet=config['quiet'])
            self.load_weights(config['weights'])

        self.iter = 0

        self.optimizer = AdamW(self.model.parameters(),
                               betas=(0.9, 0.98),
                               eps=1e-09)
        self.scheduler = OneCycleLR(self.optimizer,
                                    total_steps=self.num_iters,
                                    **config['optimizer'])

        self.criterion = LabelSmoothingLoss(len(self.vocab),
                                            padding_idx=self.vocab.pad,
                                            smoothing=0.1)

        transforms = None
        if self.image_aug:
            transforms = augmentor

        self.train_gen = self.data_gen('train_{}'.format(self.dataset_name),
                                       self.data_root,
                                       self.train_annotation,
                                       self.masked_language_model,
                                       transform=transforms)
        if self.valid_annotation:
            self.valid_gen = self.data_gen('valid_{}'.format(
                self.dataset_name),
                                           self.data_root,
                                           self.valid_annotation,
                                           masked_language_model=False)

        self.train_losses = []
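
One detail worth flagging in Examples #2 and #4: augmentor=ImgAugTransform() is a mutable default argument. Python evaluates default values once, at definition time, so every Trainer created without an explicit augmentor shares a single ImgAugTransform instance. If the transform is stateless this is harmless; otherwise the conventional idiom is a None sentinel, sketched here:

    def __init__(self, config, pretrained=False, augmentor=None):
        # Build a fresh augmentor per Trainer instead of sharing one
        # instance created when the class was defined.
        augmentor = augmentor if augmentor is not None else ImgAugTransform()
        ...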