    @classmethod
    def from_directory(cls, directory, device_id=None):
        """Load a training state saved in `directory` and rebuild the trainer."""
        logger.info('Loading training state from {}'.format(directory))
        root_path = Path(directory)

        model_path = root_path / const.MODEL_FILE
        model = Model.create_from_file(model_path)

        if device_id is not None:
            model.to(device_id)

        optimizer_path = root_path / const.OPTIMIZER
        optimizer_dict = load_torch_file(str(optimizer_path))

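        # lr=0.0 is only a placeholder; the actual learning rate is restored
        # below by load_state_dict, which includes the saved param groups.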
        optimizer = optimizer_class(optimizer_dict['name'])(model.parameters(),
                                                            lr=0.0)
        optimizer.load_state_dict(optimizer_dict['state_dict'])

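        # Rebuild the trainer without a checkpointer, then overwrite its
        # attributes with those saved in the trainer state file.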
        trainer = cls(model, optimizer, checkpointer=None)
        trainer_path = root_path / const.TRAINER
        state = load_torch_file(str(trainer_path))
        trainer.__dict__.update(state)
        return trainer


def retrieve_trainer(ModelClass, pipeline_options, model_options, vocabs,
                     output_dir, device_id):
    """
    Creates a Trainer object with an associated model.

    This object encapsulates the logic behind training the model and
    checkpointing. This method uses the received pipeline options to
    instantiate a Trainer object with the requested model and
    hyperparameters. A usage sketch is given in the comment block after
    this function.

    Args:
        ModelClass: Class of the model to build via its `from_options`
            method when no pre-trained model is loaded.
        pipeline_options (Namespace): Generic training options
            resume (bool): Set to true if resuming an existing run.
            load_model (str): Directory containing model.torch for loading
                pre-created model.
            checkpoint_save (bool): Whether snapshots should be saved after
                validation runs. Warning: if false, the model will never be
                saved.
            checkpoint_keep_only_best (int): Keep only the best `n` saved
                models.
            checkpoint_early_stop_patience (int): Stops training if metrics
                don't improve after `n` validation runs.
            checkpoint_validation_steps (int): Perform validation every `n`
                training steps.
            optimizer (str): Name of the optimizer to use in training.
            learning_rate (float): Starting learning rate.
            learning_rate_decay (float): Factor of learning rate decay.
            learning_rate_decay_start (int): Start decay after epoch `x`.
            log_interval (int): Log after `k` batches.
        model_options (Namespace): Model specific options.
        vocabs (dict): Vocab dictionary.
        output_dir (str or Path): Output directory for models and stats
            concerning training.
        device_id (int): The GPU id to use for training. Set to a negative
            value to use the CPU.
    Returns:
        Trainer

    """

    if pipeline_options.resume:
        return Trainer.resume(local_path=output_dir, device_id=device_id)

    if pipeline_options.load_model:
        model = Model.create_from_file(pipeline_options.load_model)
    else:
        model = ModelClass.from_options(vocabs=vocabs, opts=model_options)

    checkpointer = Checkpoint(
        output_dir,
        pipeline_options.checkpoint_save,
        pipeline_options.checkpoint_keep_only_best,
        pipeline_options.checkpoint_early_stop_patience,
        pipeline_options.checkpoint_validation_steps,
    )

    if isinstance(model, LinearWordQEClassifier):
        trainer = LinearWordQETrainer(
            model,
            model_options.training_algorithm,
            model_options.regularization_constant,
            checkpointer,
        )
    else:
        # Set GPU or CPU; has to be before instantiating the optimizer
        model.to(device_id)

        # Optimizer
        OptimizerClass = optimizer_class(pipeline_options.optimizer)
        optimizer = OptimizerClass(model.parameters(),
                                   lr=pipeline_options.learning_rate)
        scheduler = None
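        # ReduceLROnPlateau multiplies the learning rate by `factor` when the
        # tracked validation metric (higher is better, hence mode="max") stops
        # improving for `patience` consecutive evaluations.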
        if 0.0 < pipeline_options.learning_rate_decay < 1.0:
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                factor=pipeline_options.learning_rate_decay,
                patience=pipeline_options.learning_rate_decay_start,
                verbose=True,
                mode="max",
            )

        trainer = Trainer(
            model,
            optimizer,
            checkpointer,
            log_interval=pipeline_options.log_interval,
            scheduler=scheduler,
        )
    return trainer
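
# Hedged usage sketch: one way `retrieve_trainer` might be called from a
# pipeline script. The option values, the `SomeModelClass` placeholder and the
# 'adam' optimizer name are illustrative assumptions, not values taken from
# this module; `model_options` and `vocabs` are assumed to come from the
# model's own option parsing and data loading.
#
#     from argparse import Namespace
#
#     pipeline_options = Namespace(
#         resume=False,
#         load_model=None,
#         checkpoint_save=True,
#         checkpoint_keep_only_best=1,
#         checkpoint_early_stop_patience=0,
#         checkpoint_validation_steps=100,
#         optimizer='adam',
#         learning_rate=1e-3,
#         learning_rate_decay=0.5,
#         learning_rate_decay_start=2,
#         log_interval=100,
#     )
#     trainer = retrieve_trainer(SomeModelClass, pipeline_options,
#                                model_options, vocabs,
#                                output_dir='runs/0', device_id=-1)
#     # device_id=-1 selects the CPU, per the docstring above.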