Beispiel #1
0
 def __init__(self, config, runner_type=RunnerType.PyTorch):
     """Prepare the evaluator: runner, vocabulary, dataset, model, references.

     Args:
         config: Object exposing ``get``; the ``vocab_path`` entry is read
             here and the whole config is handed to ``create_runner``.
         runner_type: Forwarded unchanged to ``create_runner``.
     """
     self.config = config
     self.runner = create_runner(config, runner_type)

     vocab_path = config.get("vocab_path")
     self.vocab = read_vocab(vocab_path)

     # The remaining steps run in the original order; later steps may rely
     # on state set up by earlier ones -- TODO confirm before reordering.
     self.load_dataset()
     self.runner.load_model()
     self.read_expected_outputs()
Beispiel #2
0
 def load_model(self):
     """Build the Im2latex model, load its weights and put it in eval mode.

     Everything is driven by ``self.config``: the vocabulary size comes from
     the file at ``vocab_path``, the architecture from ``backbone_type``
     (default ``'resnet'``), ``backbone_config`` and ``head`` (default
     ``{}``), the weights from ``model_path`` and the target device from
     ``device`` (default ``'cpu'``).
     """
     cfg = self.config

     # Only the vocabulary size is needed to build the model.
     vocab_size = len(read_vocab(cfg.get('vocab_path')))
     backbone_type = cfg.get('backbone_type', 'resnet')
     backbone_config = cfg.get('backbone_config')
     head_config = cfg.get('head', {})
     self.model = Im2latexModel(backbone_type, backbone_config,
                                vocab_size, head_config)

     self.device = cfg.get('device', 'cpu')
     weights_path = cfg.get("model_path")
     self.model.load_weights(weights_path, map_location=self.device)

     self.model = self.model.to(self.device)
     self.model.eval()
Beispiel #3
0
 def __init__(self, config):
     """Create an inference-ready Im2latex model from ``config``.

     Args:
         config: Object exposing ``get``. Keys read here: ``model_path``,
             ``vocab_path``, ``transforms_list``, ``backbone_type``
             (default ``'resnet'``), ``backbone_config``, ``head``
             (default ``{}``), ``map_location`` (default ``'cpu'``) and
             ``device`` (default ``'cpu'``).
     """
     self.config = config
     self.model_path = config.get('model_path')
     self.vocab = read_vocab(config.get('vocab_path'))

     transforms_spec = config.get('transforms_list')
     self.transform = create_list_of_transforms(transforms_spec)

     backbone_type = config.get('backbone_type', 'resnet')
     backbone_config = config.get('backbone_config')
     vocab_size = len(self.vocab)
     head_config = config.get('head', {})
     self.model = Im2latexModel(backbone_type, backbone_config,
                                vocab_size, head_config)

     # Weights are optional: without a model_path the model is left as built.
     weights_path = self.model_path
     if weights_path is not None:
         load_location = config.get('map_location', 'cpu')
         self.model.load_weights(weights_path, map_location=load_location)

     self.model.eval()
     self.device = config.get('device', 'cpu')
     self.model = self.model.to(self.device)
 def __init__(self, config):
     """Prepare the model, a random sample input and encoder/decoder wrappers.

     Args:
         config: Object exposing ``get``. Keys read here: ``model_path``,
             ``vocab_path``, ``backbone_type`` (default ``'resnet'``),
             ``backbone_config``, ``head`` (default ``{}``) and
             ``input_shape_decoder``.
     """
     self.config = config
     self.model_path = config.get('model_path')
     self.vocab = read_vocab(config.get('vocab_path'))

     backbone_type = config.get('backbone_type', 'resnet')
     backbone_config = config.get('backbone_config')
     vocab_size = len(self.vocab)
     head_config = config.get('head', {})
     self.model = Im2latexModel(backbone_type, backbone_config,
                                vocab_size, head_config)
     self.model.eval()

     # Weights are optional: without a model_path the model is left as built.
     weights_path = self.model_path
     if weights_path is not None:
         self.model.load_weights(weights_path)

     # Random tensor whose size is taken from the "input_shape_decoder" entry.
     sample_shape = self.config.get("input_shape_decoder")
     self.img_for_export = torch.rand(sample_shape)

     # Both wrappers are obtained from the model and switched to eval mode.
     encoder = self.model.get_encoder_wrapper(self.model)
     self.encoder = encoder
     encoder.eval()

     decoder = self.model.get_decoder_wrapper(self.model)
     self.decoder = decoder
     decoder.eval()
    def __init__(self, work_dir, config):
        """Set up training state, logging, data, model and optimizer.

        Args:
            work_dir: Directory under which checkpoints, logs and validation
                results are placed; stored as an absolute path.
            config: Object exposing ``get``. Keys read here (defaults in
                parentheses): ``model_path``, ``train_paths``, ``val_path``,
                ``vocab_path``, ``train_transforms_list``,
                ``val_transforms_list``, ``epochs`` (30), ``learning_rate``
                (1e-3), ``clip_grad`` (5.0), ``save_dir``
                ("model_checkpoints"), ``_test_steps`` (1e18),
                ``print_freq`` (16), ``save_freq`` (2000), ``val_freq``
                (5000), ``log_path`` ("logs"), ``device`` ('cpu'),
                ``backbone_type`` ('resnet'), ``backbone_config``, ``head``
                ({}) and ``optimizer`` ("Adam").

        Raises:
            AttributeError: If the ``optimizer`` entry does not name an
                attribute of ``optim``.
        """
        # --- Plain configuration values -------------------------------------
        self.config = config
        self.model_path = config.get('model_path')
        self.train_paths = config.get('train_paths')
        self.val_path = config.get('val_path')
        self.vocab = read_vocab(config.get('vocab_path'))
        self.train_transforms_list = config.get('train_transforms_list')
        self.val_transforms_list = config.get('val_transforms_list')
        self.total_epochs = config.get('epochs', 30)
        # NOTE: the attribute name is misspelled ("learing"); it is kept as-is
        # because code outside this method may read it under this name.
        self.learing_rate = config.get('learning_rate', 1e-3)
        self.clip = config.get('clip_grad', 5.0)

        # --- Output locations, all rooted at the absolute work_dir ----------
        self.work_dir = os.path.abspath(work_dir)
        self.save_dir = os.path.join(
            self.work_dir, config.get("save_dir", "model_checkpoints"))
        self.val_results_path = os.path.join(self.work_dir, "val_results")

        # --- Progress counters and best-so-far metrics ----------------------
        self.step = 0
        self.global_step = 0
        self._test_steps = config.get('_test_steps', 1e18)
        self.epoch = 1
        # 1e18 serves as the initial "no best loss yet" value.
        self.best_val_loss = 1e18
        self.best_val_accuracy = 0.0
        self.best_val_loss_test = 1e18
        self.best_val_accuracy_test = 0.0

        # --- Reporting / checkpoint / validation frequencies ----------------
        self.print_freq = config.get('print_freq', 16)
        self.save_freq = config.get('save_freq', 2000)
        self.val_freq = config.get('val_freq', 5000)

        # --- Logging ---------------------------------------------------------
        self.logs_path = os.path.join(self.work_dir,
                                      config.get("log_path", "logs"))
        self.writer = SummaryWriter(self.logs_path)
        self.device = config.get('device', 'cpu')
        self.writer.add_text("General info", pformat(config))

        self.create_dirs()
        self.load_dataset()

        # --- Model -----------------------------------------------------------
        self.model = Im2latexModel(config.get('backbone_type', 'resnet'),
                                   config.get('backbone_config'),
                                   len(self.vocab), config.get('head', {}))
        if self.model_path is not None:
            self.model.load_weights(self.model_path, map_location=self.device)

        # Move the model to its target device BEFORE building the optimizer,
        # as the torch.optim documentation advises, so the optimizer is
        # constructed from the parameters that live on the training device.
        self.model = self.model.to(self.device)

        # --- Optimizer and LR schedule --------------------------------------
        # The optimizer class is looked up by name on the ``optim`` module.
        optimizer_cls = getattr(optim, config.get('optimizer', "Adam"))
        self.optimizer = optimizer_cls(self.model.parameters(),
                                       self.learing_rate)
        self.lr_scheduler = ReduceLROnPlateau(self.optimizer)

        self.time = get_timestamp()