Example #1
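These examples are test functions excerpted from a larger suite, so imports and pytest fixtures (pipeline_test, datasource_test, TestHead) are not shown. Assuming they come from the biome.text test suite, a minimal import header would look roughly like the sketch below; the exact module paths are an assumption and may differ between library versions.

import logging
import os

import torch

# Assumed import locations; adjust to the installed biome.text version.
from biome.text import Pipeline, TrainerConfiguration, VocabularyConfiguration
from biome.text.data import DataSource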
def test_training_from_pretrained_with_head_replace(
    pipeline_test: Pipeline, datasource_test: DataSource, tmp_path: str
):
    training = pipeline_test.create_dataset(datasource_test)
    pipeline_test.create_vocabulary(VocabularyConfiguration(sources=[training]))
    configuration = TrainerConfiguration(
        data_bucketing=True, batch_size=2, num_epochs=5
    )
    output_dir = os.path.join(tmp_path, "output")
    results = pipeline_test.train(
        output=output_dir, trainer=configuration, training=training, quiet=True
    )

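    # Reload the trained pipeline from disk and replace its head before copying it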
    trained = Pipeline.from_pretrained(results.model_path)
    trained.set_head(TestHead)
    trained.config.tokenizer.max_nr_of_sentences = 3
    copied = trained._make_copy()
    assert isinstance(copied.head, TestHead)
    assert copied.num_parameters == trained.num_parameters
    assert copied.num_trainable_parameters == trained.num_trainable_parameters
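    # Only the head was replaced, so the backbone weights must be unchanged in the copy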
    copied_model_state = copied._model.state_dict()
    original_model_state = trained._model.state_dict()
    for key, value in copied_model_state.items():
        if "backbone" in key:
            assert torch.all(torch.eq(value, original_model_state[key]))
    assert copied.backbone.featurizer.tokenizer.max_nr_of_sentences == 3
Example #2
def test_training_with_logging(
    pipeline_test: Pipeline, datasource_test: DataSource, tmp_path: str
):
    training = pipeline_test.create_dataset(datasource_test)
    pipeline_test.create_vocabulary(VocabularyConfiguration(sources=[training]))

    configuration = TrainerConfiguration(
        data_bucketing=True, batch_size=2, num_epochs=5
    )
    output_dir = os.path.join(tmp_path, "output")
    pipeline_test.train(
        output=output_dir, trainer=configuration, training=training, quiet=True
    )

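    # Training should leave a train.log in the output directory containing only allennlp records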
    assert os.path.exists(os.path.join(output_dir, "train.log"))
    with open(os.path.join(output_dir, "train.log")) as train_log:
        for line in train_log.readlines():
            assert "allennlp" in line

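    # The allennlp logger is kept at ERROR while the biome logger stays at INFO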
    assert logging.getLogger("allennlp").level == logging.ERROR
    assert logging.getLogger("biome").level == logging.INFO
Example #3
def test_training_with_data_bucketing(
    pipeline_test: Pipeline, datasource_test: DataSource, tmp_path: str
):
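    # Build the same dataset both lazily and eagerly to exercise both code paths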
    lazy_ds = pipeline_test.create_dataset(datasource_test, lazy=True)
    non_lazy_ds = pipeline_test.create_dataset(datasource_test)

    pipeline_test.create_vocabulary(VocabularyConfiguration(sources=[lazy_ds]))

    configuration = TrainerConfiguration(
        data_bucketing=True, batch_size=2, num_epochs=5
    )
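    # Train with data bucketing twice, swapping the lazy and non-lazy datasets between training and validation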
    pipeline_test.train(
        output=os.path.join(tmp_path, "output"),
        trainer=configuration,
        training=lazy_ds,
        validation=non_lazy_ds,
    )

    pipeline_test.train(
        output=os.path.join(tmp_path, "output"),
        trainer=configuration,
        training=non_lazy_ds,
        validation=lazy_ds,
    )