示例#1
0
def create_trainer_for_finding_lr(
    pipeline: Pipeline,
    trainer_config: TrainerConfiguration,
    training_data: InstancesDataset,
) -> GradientDescentTrainer:
    """Returns an AllenNLP Trainer used for the learning rate scan.

    Parameters
    ----------
    pipeline
        The pipeline with the model
    trainer_config
        A trainer configuration
    training_data
        The training data

    Returns
    -------
    trainer
        A trainer built via ``Trainer.from_params`` from the pipeline's model,
        a data loader over ``training_data`` and the sanitized trainer params,
        with ``serialization_dir=None``.
    """
    from typing import cast

    # NOTE(review): an empty Params presumably makes AllenNLP fall back to its
    # default environment settings (e.g. random seeds) — confirm.
    prepare_environment(Params({}))

    # Only datasets exposing `index_with` can be bound to a vocabulary; plain
    # collections of instances are passed through untouched.
    if hasattr(training_data, "index_with"):
        training_data.index_with(pipeline.backbone.vocab)

    trainer_params = Params(
        helpers.sanitize_for_params(trainer_config.to_allennlp_trainer())
    )

    training_data_loader = create_dataloader(
        training_data,
        trainer_config.batch_size,
        trainer_config.data_bucketing,
    )

    # `from_params` is invoked on the base `Trainer` class, so type checkers
    # see a `Trainer`; narrow it to the declared return type (no runtime effect).
    # NOTE(review): relies on the private `pipeline._model` attribute.
    return cast(
        "GradientDescentTrainer",
        Trainer.from_params(
            model=pipeline._model,
            data_loader=training_data_loader,
            params=trainer_params,
            serialization_dir=None,
        ),
    )
示例#2
0
def create_trainer_for_finding_lr(
    model: PipelineModel,
    trainer_config: TrainerConfiguration,
    training_data: InstancesDataset,
) -> GradientDescentTrainer:
    """Returns an AllenNLP Trainer used for the learning rate scan.

    Parameters
    ----------
    model
        The underlying model
    trainer_config
        A trainer configuration
    training_data
        The training data. If it exposes ``index_with``, it is bound to
        ``model.vocab`` before the data loader is created.

    Returns
    -------
    trainer
        A trainer built via ``Trainer.from_params`` from ``model``, a data
        loader over ``training_data`` and the sanitized trainer params, with
        ``serialization_dir=None``.
    """
    # NOTE(review): an empty Params presumably makes AllenNLP fall back to its
    # default environment settings (e.g. random seeds) — confirm.
    prepare_environment(Params({}))

    # Bind the dataset to the model vocabulary when it supports it, so the
    # data loader does not receive unindexed instances. Plain collections of
    # instances (no `index_with`) are passed through untouched.
    if hasattr(training_data, "index_with"):
        training_data.index_with(model.vocab)

    trainer_params = Params(
        helpers.sanitize_for_params(trainer_config.to_allennlp_trainer())
    )

    training_data_loader = create_dataloader(
        training_data,
        trainer_config.batch_size,
        trainer_config.data_bucketing,
    )

    # `from_params` is invoked on the base `Trainer` class, so type checkers
    # see a `Trainer`; narrow it to the declared return type (no runtime effect).
    return cast(
        "GradientDescentTrainer",
        Trainer.from_params(
            model=model,
            data_loader=training_data_loader,
            params=trainer_params,
            serialization_dir=None,
        ),
    )