def from_directory(cls, directory, device_id=None):
    """Restore a trainer (model, optimizer, and trainer state) from a directory.

    Args:
        directory: Path to a directory written by a previous training run;
            expected to contain the model, optimizer, and trainer state files
            named by the ``const`` module.
        device_id: Optional device to move the restored model to
            (e.g. a GPU id). Left on its saved device when ``None``.

    Returns:
        A trainer instance (of ``cls``) with its saved state restored.
    """
    logger.info('Loading training state from {}'.format(directory))
    root_path = Path(directory)

    # Rebuild the model first: the optimizer needs its parameters.
    model = Model.create_from_file(root_path / const.MODEL_FILE)
    if device_id is not None:
        model.to(device_id)

    # Recreate the optimizer by its saved name; the lr passed here is a
    # placeholder — load_state_dict restores the real hyperparameters.
    optimizer_dict = load_torch_file(str(root_path / const.OPTIMIZER))
    optimizer = optimizer_class(optimizer_dict['name'])(
        model.parameters(), lr=0.0
    )
    optimizer.load_state_dict(optimizer_dict['state_dict'])

    # Build the trainer, then overwrite its attributes with the saved ones.
    trainer = cls(model, optimizer, checkpointer=None)
    state = load_torch_file(str(root_path / const.TRAINER))
    trainer.__dict__.update(state)
    return trainer
def retrieve_trainer(ModelClass, pipeline_options, model_options, vocabs, output_dir, device_id):
    """Create a Trainer object with an associated model.

    This object encapsulates the logic behind training the model and
    checkpointing. This method uses the received pipeline options to
    instantiate a Trainer object with the requested model and
    hyperparameters.

    Args:
        ModelClass: Class used to build a fresh model when no pre-trained
            model is loaded.
        pipeline_options (Namespace): Generic training options:
            resume (bool): Set to true if resuming an existing run.
            load_model (str): Directory containing model.torch for loading
                a pre-created model.
            checkpoint_save (bool): Whether snapshots should be saved after
                validation runs. Warning: if false, will never save the
                model.
            checkpoint_keep_only_best (int): Indicates kiwi to keep the
                best `n` models.
            checkpoint_early_stop_patience (int): Stops training if metrics
                don't improve after `n` validation runs.
            checkpoint_validation_steps (int): Perform validation every `n`
                training steps.
            optimizer (str): The optimizer to be used in training.
            learning_rate (float): Starting learning rate.
            learning_rate_decay (float): Factor of learning rate decay.
            learning_rate_decay_start (int): Start decay after epoch `x`.
            log_interval (int): Log after `k` batches.
        model_options (Namespace): Model specific options.
        vocabs (dict): Vocab dictionary.
        output_dir (str or Path): Output directory for models and stats
            concerning training.
        device_id (int): The gpu id to be used in training. Set to negative
            to use cpu.

    Returns:
        Trainer
    """
    # Resuming restores model, optimizer, and trainer state in one shot.
    if pipeline_options.resume:
        return Trainer.resume(local_path=output_dir, device_id=device_id)

    if pipeline_options.load_model:
        model = Model.create_from_file(pipeline_options.load_model)
    else:
        model = ModelClass.from_options(vocabs=vocabs, opts=model_options)

    checkpointer = Checkpoint(
        output_dir,
        pipeline_options.checkpoint_save,
        pipeline_options.checkpoint_keep_only_best,
        pipeline_options.checkpoint_early_stop_patience,
        pipeline_options.checkpoint_validation_steps,
    )

    # Linear models use their own trainer with a dedicated algorithm.
    if isinstance(model, LinearWordQEClassifier):
        return LinearWordQETrainer(
            model,
            model_options.training_algorithm,
            model_options.regularization_constant,
            checkpointer,
        )

    # Set GPU or CPU; has to be before instantiating the optimizer.
    model.to(device_id)

    # Optimizer
    optimizer = optimizer_class(pipeline_options.optimizer)(
        model.parameters(), lr=pipeline_options.learning_rate
    )

    scheduler = None
    decay = pipeline_options.learning_rate_decay
    if 0.0 < decay < 1.0:
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            factor=decay,
            patience=pipeline_options.learning_rate_decay_start,
            verbose=True,
            mode="max",
        )

    return Trainer(
        model,
        optimizer,
        checkpointer,
        log_interval=pipeline_options.log_interval,
        scheduler=scheduler,
    )