コード例 #1
 def __init__(self, work_dir, config, rank=0):
     """Initialize a (optionally distributed) training run.

     Sets up RNG seeding, data paths, bookkeeping counters, logging,
     the device, the dataset, the model, the loss, the optimizer and
     the LR scheduler — in that order, since later steps depend on
     earlier ones (e.g. the model is built after the vocab is read and
     moved to the device before DDP wrapping).

     Args:
         work_dir: Root directory; checkpoints, logs and validation
             results are created under it.
         config: Dict-like configuration, read via ``.get`` throughout.
         rank: Distributed process rank; 0 is the main process.
     """
     self.rank = rank
     self.config = config
     # Seed once, from the main process only.
     if self.rank == 0:
         seed_worker(self.config.get('seed'))
     self.model_path = config.get('model_path')
     self.train_paths = config.get('train_paths')
     self.val_path = config.get('val_path')
     self.vocab = read_vocab(config.get('vocab_path'))
     self.train_transforms_list = config.get('train_transforms_list')
     self.val_transforms_list = config.get('val_transforms_list')
     self.loss_type = config.get('loss_type', 'NLL')
     self.total_epochs = config.get('epochs', 30)
     # `learing_rate` (sic) is preserved so existing external readers keep
     # working; `learning_rate` is the correctly spelled alias.
     self.learing_rate = self.learning_rate = config.get('learning_rate', 1e-3)
     self.clip = config.get('clip_grad', 5.0)
     self.work_dir = os.path.abspath(work_dir)
     self.save_dir = os.path.join(self.work_dir, config.get('save_dir', 'model_checkpoints'))
     self.val_results_path = os.path.join(self.work_dir, 'val_results')
     # Step/epoch counters and best-metric trackers.
     self.step = 0
     self.global_step = 0
     self._test_steps = config.get('_test_steps', 1e18)
     self.epoch = 1
     self.best_val_loss = 1e18
     self.best_val_accuracy = 0.0
     self.best_val_loss_test = 1e18
     self.best_val_accuracy_test = 0.0
     self.print_freq = config.get('print_freq', 16)
     self.save_freq = config.get('save_freq', 2000)
     self.val_freq = config.get('val_freq', 5000)
     self.logs_path = os.path.join(self.work_dir, config.get('log_path', 'logs'))
     # TensorBoard writer exists only on the main process; other ranks
     # must not touch self.writer.
     if self.rank == 0:
         self.writer = SummaryWriter(self.logs_path)
         self.writer.add_text('General info', pformat(config))
     self.device = config.get('device', 'cpu')
     self.multi_gpu = config.get('multi_gpu')
     if self.multi_gpu:
         # NCCL backend; assumes the env:// rendezvous variables
         # (MASTER_ADDR, MASTER_PORT, WORLD_SIZE, RANK) are set.
         torch.distributed.init_process_group("nccl", init_method="env://")
         self.device = torch.device(f'cuda:{self.rank}')
     self.create_dirs()
     self.load_dataset()
     # CTC reserves class 0 for the blank token, hence blank=0 and the
     # +1 on the output size below.
     self.loss = torch.nn.CTCLoss(blank=0, zero_infinity=self.config.get(
         'CTCLossZeroInf', False)) if self.loss_type == 'CTC' else None
     self.out_size = len(self.vocab) + 1 if self.loss_type == 'CTC' else len(self.vocab)
     self.model = TextRecognitionModel(config.get('backbone_config'), self.out_size,
                                       config.get('head', {}), config.get('transformation', {}))
     print(self.model)
     if self.model_path is not None:
         self.model.load_weights(self.model_path, map_location=self.device)
     self.model = self.model.to(self.device)
     # Wrap in DDP only after the model is on its device.
     if self.multi_gpu and torch.cuda.device_count() > 1:
         self.model = torch.nn.parallel.DistributedDataParallel(
             self.model, device_ids=[self.rank], output_device=self.rank)
     self.optimizer = getattr(optim, config.get('optimizer', 'Adam'))(self.model.parameters(), self.learning_rate)
     self.lr_scheduler = getattr(optim.lr_scheduler, self.config.get('scheduler', 'ReduceLROnPlateau'))(
         self.optimizer, **self.config.get('scheduler_params', {}))
     self.time = get_timestamp()
     # Default 'head' to {} so a config without a head section does not
     # crash here (model construction above already defaults it to {}).
     self.use_lang_model = self.config.get("head", {}).get("use_semantics")
     if self.use_lang_model:
         self.fasttext_model = fasttext.load_model(self.config.get("language_model_path"))
コード例 #2
 def __init__(self, config, runner_type=RunnerType.PyTorch):
     """Build an evaluator: copy the config, create the runner, then load
     the dataset, the model weights and the expected outputs."""
     # Deep-copy so later mutation of the caller's config cannot affect us.
     self.config = deepcopy(config)
     self.runner = create_runner(self.config, runner_type)
     cfg = self.config
     self.vocab = read_vocab(cfg.get('vocab_path'))
     self.render = cfg.get('render')
     # Data first, then model weights, then the reference outputs.
     self.load_dataset()
     self.runner.load_model()
     self.read_expected_outputs()
コード例 #3
 def __init__(self, work_dir, config):
     """Initialize a single-device training run.

     Sets up data paths, bookkeeping counters, TensorBoard logging,
     the dataset, the model, the loss, the optimizer and the LR
     scheduler, in dependency order.

     Args:
         work_dir: Root directory; checkpoints, logs and validation
             results are created under it.
         config: Dict-like configuration, read via ``.get`` throughout.
     """
     self.config = config
     self.model_path = config.get('model_path')
     self.train_paths = config.get('train_paths')
     self.val_path = config.get('val_path')
     self.vocab = read_vocab(config.get('vocab_path'))
     self.train_transforms_list = config.get('train_transforms_list')
     self.val_transforms_list = config.get('val_transforms_list')
     self.loss_type = config.get('loss_type', 'NLL')
     self.total_epochs = config.get('epochs', 30)
     # `learing_rate` (sic) is preserved so existing external readers keep
     # working; `learning_rate` is the correctly spelled alias.
     self.learing_rate = self.learning_rate = config.get('learning_rate', 1e-3)
     self.clip = config.get('clip_grad', 5.0)
     self.work_dir = os.path.abspath(work_dir)
     self.save_dir = os.path.join(
         self.work_dir, config.get('save_dir', 'model_checkpoints'))
     self.val_results_path = os.path.join(self.work_dir, 'val_results')
     # Step/epoch counters and best-metric trackers.
     self.step = 0
     self.global_step = 0
     self._test_steps = config.get('_test_steps', 1e18)
     self.epoch = 1
     self.best_val_loss = 1e18
     self.best_val_accuracy = 0.0
     self.best_val_loss_test = 1e18
     self.best_val_accuracy_test = 0.0
     self.print_freq = config.get('print_freq', 16)
     self.save_freq = config.get('save_freq', 2000)
     self.val_freq = config.get('val_freq', 5000)
     self.logs_path = os.path.join(self.work_dir,
                                   config.get('log_path', 'logs'))
     self.writer = SummaryWriter(self.logs_path)
     self.device = config.get('device', 'cpu')
     self.writer.add_text('General info', pformat(config))
     self.create_dirs()
     self.load_dataset()
     # CTC reserves class 0 for the blank token, hence blank=0 and the
     # +1 on the output size below.
     self.loss = torch.nn.CTCLoss(
         blank=0, zero_infinity=self.config.get(
             'CTCLossZeroInf', False)) if self.loss_type == 'CTC' else None
     self.out_size = len(
         self.vocab) + 1 if self.loss_type == 'CTC' else len(self.vocab)
     self.model = TextRecognitionModel(config.get('backbone_config'),
                                       self.out_size,
                                       config.get('head', {}))
     print(self.model)
     if self.model_path is not None:
         self.model.load_weights(self.model_path, map_location=self.device)
     self.model = self.model.to(self.device)
     self.optimizer = getattr(optim, config.get('optimizer', 'Adam'))(
         self.model.parameters(), self.learning_rate)
     self.lr_scheduler = getattr(
         optim.lr_scheduler,
         self.config.get('scheduler', 'ReduceLROnPlateau'))(
             self.optimizer, **self.config.get('scheduler_params', {}))
     self.time = get_timestamp()
コード例 #4
 def load_model(self):
     """Construct the recognition model, load its weights onto the
     configured device and switch it to eval mode."""
     cfg = self.config
     self.vocab_len = len(read_vocab(cfg.get('vocab_path')))
     self.use_ctc = cfg.get('use_ctc')
     # One extra output class for the CTC blank token when CTC is used.
     out_size = self.vocab_len + (1 if self.use_ctc else 0)
     self.model = TextRecognitionModel(
         cfg.get('backbone_config'), out_size,
         cfg.get('head', {}), cfg.get('transformation', {}))
     self.device = cfg.get('device', 'cpu')
     self.model.load_weights(cfg.get('model_path'),
                             map_location=self.device)
     self.model = self.model.to(self.device)
     self.model.eval()
コード例 #5
 def __init__(self, config):
     """Set up an inference helper: vocab, input transforms and a
     trained model in eval mode on the configured device.

     Args:
         config: Dict-like configuration with model/vocab paths,
             transforms and CTC options.
     """
     self.config = config
     self.model_path = config.get('model_path')
     self.vocab = read_vocab(config.get('vocab_path'))
     self.transform = create_list_of_transforms(config.get('transforms_list'))
     self.use_ctc = self.config.get('use_ctc')
     # Fix: `use_ctc` was read but not applied to the output size; the
     # sibling trainer/runner/exporter classes all add one extra class
     # for the CTC blank token, and a CTC checkpoint's head would not
     # match a len(vocab)-sized output layer otherwise.
     self.out_size = len(self.vocab) + 1 if self.use_ctc else len(self.vocab)
     self.model = TextRecognitionModel(config.get('backbone_config'), self.out_size,
                                       config.get('head', {}), config.get('transformation', {}))
     if self.model_path is not None:
         self.model.load_weights(self.model_path, map_location=config.get('map_location', 'cpu'))
     self.model.eval()
     self.device = config.get('device', 'cpu')
     self.model = self.model.to(self.device)
コード例 #6
 def __init__(self, config):
     """Prepare a trained model for export, plus encoder/decoder
     wrappers for non-CTC (attention-style) heads."""
     self.config = config
     self.model_path = config.get('model_path')
     self.vocab = read_vocab(config.get('vocab_path'))
     self.use_ctc = self.config.get('use_ctc')
     # CTC reserves one extra output class for the blank token.
     self.out_size = len(self.vocab) + (1 if self.use_ctc else 0)
     self.model = TextRecognitionModel(config.get('backbone_config'),
                                       self.out_size, config.get('head', {}))
     self.model.eval()
     if self.model_path is not None:
         self.model.load_weights(self.model_path)
     # Dummy tensor used as the export/tracing input; the encoder-specific
     # shape takes priority over the generic one when both are present.
     export_shape = self.config.get('input_shape_encoder',
                                    self.config.get('input_shape'))
     self.img_for_export = torch.rand(export_shape)
     if self.use_ctc:
         return
     # Non-CTC heads are exported as separate encoder/decoder parts.
     self.encoder = self.model.get_encoder_wrapper(self.model)
     self.encoder.eval()
     self.decoder = self.model.get_decoder_wrapper(self.model)
     self.decoder.eval()