Beispiel #1
0
def test_embeddings(
    model_cls,
    dataset_cls,
    model_kwargs,
    train_kwargs,
    embed_kwargs,
    model_gpu_config,
    pooling,
    gobbli_dir,
    request,
):
    """
    Check that an embedding model produces valid embeddings for a dataset,
    both before training (where supported) and from a trained checkpoint
    (for the models that can be trained).
    """
    needs_pooling = model_cls in (USE, FastText, TfidfEmbedder)
    if needs_pooling and pooling == gobbli.io.EmbedPooling.NONE:
        pytest.xfail(f"pooling is required for {model_cls.__name__}")

    # Memory-hungry model/dataset pairings are skipped on low-resource runs
    is_heavy_model = model_cls in (BERT, Transformer)
    if is_heavy_model and dataset_cls in (NewsgroupsDataset,):
        skip_if_low_resource(request.config)

    model = model_cls(
        data_dir=model_test_dir(model_cls),
        load_existing=True,
        **model_gpu_config,
        **model_kwargs,
    )
    model.build()

    dataset = dataset_cls.load()
    embed_input = dataset.embed_input(limit=50, pooling=pooling, **embed_kwargs)

    # Model-specific expectations forwarded to the output check
    extra_checks = {}
    if isinstance(model, RandomEmbedder):
        extra_checks["expected_dimensionality"] = RandomEmbedder.DIMENSIONALITY
    if isinstance(model, Transformer):
        extra_checks["max_seq_length"] = model_kwargs.get("max_seq_length")

    def embed_and_check():
        # Embed the (possibly checkpoint-updated) input and validate the result
        output = model.embed(embed_input)
        check_embed_output(embed_input, output, **extra_checks)

    # Every model except fastText is expected to embed without a trained
    # checkpoint
    if model_cls not in (FastText,):
        embed_and_check()

    # Nothing further to verify for models that can't be trained for embeddings
    if model_cls not in (BERT, FastText, Transformer):
        return

    train_input = dataset.train_input(limit=50, **train_kwargs)
    train_output = model.train(train_input)
    assert train_output.valid_loss is not None
    assert train_output.train_loss is not None
    assert 0 <= train_output.valid_accuracy <= 1

    validate_checkpoint(model_cls, train_output.checkpoint)

    # Repeat the embedding run, this time from the trained checkpoint
    embed_input.checkpoint = train_output.checkpoint
    embed_and_check()
Beispiel #2
0
def test_classifier(
    model_cls,
    dataset_cls,
    model_kwargs,
    train_kwargs,
    predict_kwargs,
    model_gpu_config,
    gobbli_dir,
    request,
):
    """
    Check that a classifier trains on a sample dataset and that its
    predictions pass validation, with and without the trained checkpoint.
    """
    # Memory-hungry model/dataset pairings are skipped on low-resource runs
    is_heavy_model = model_cls in (BERT, MTDNN, Transformer)
    if is_heavy_model and dataset_cls in (NewsgroupsDataset, MovieSummaryDataset):
        skip_if_low_resource(request.config)

    model = model_cls(
        data_dir=model_test_dir(model_cls),
        load_existing=True,
        **model_gpu_config,
        **model_kwargs,
    )
    model.build()

    dataset = dataset_cls.load()
    train_input = dataset.train_input(limit=50, **train_kwargs)

    lacks_multilabel_support = model_cls in (BERT, MTDNN)
    if train_input.multilabel and lacks_multilabel_support:
        pytest.xfail(
            f"model {model_cls.__name__} doesn't support multilabel classification"
        )

    # Training must complete and report sensible metrics
    train_output = model.train(train_input)
    assert train_output.valid_loss is not None
    assert train_output.train_loss is not None
    assert 0 <= train_output.valid_accuracy <= 1

    validate_checkpoint(model.__class__, train_output.checkpoint)

    predict_input = dataset.predict_input(limit=50, **predict_kwargs)

    def predict_and_check():
        # Predict on the (possibly checkpoint-updated) input and validate
        output = model.predict(predict_input)
        check_predict_output(train_output, predict_input, output)

    # fastText requires a trained checkpoint for prediction, so the
    # checkpoint-free run only applies to the other models
    if not isinstance(model, FastText):
        predict_and_check()

    # Prediction from the trained checkpoint must work for every model
    predict_input.checkpoint = train_output.checkpoint
    predict_and_check()
Beispiel #3
0
def test_bertmaskedlm_augment(model_gpu_config, gobbli_dir):
    """
    Check that BERTMaskedLM returns one augmented text per requested
    repetition for a single input document.
    """
    augmenter = BERTMaskedLM(
        data_dir=model_test_dir(BERTMaskedLM),
        load_existing=True,
        **model_gpu_config,
    )
    augmenter.build()

    num_augmentations = 5
    augmented = augmenter.augment(["This is a test."], times=num_augmentations)
    assert len(augmented) == num_augmentations
Beispiel #4
0
def test_marianmt_augment(model_gpu_config, gobbli_dir):
    """
    Check that MarianMT rejects more augmentations than it has target
    languages and returns the requested number of texts otherwise.
    """
    # Keep the language list short: every language needs its own model
    # download (a few hundred MB each)
    languages = ["russian", "french"]
    augmenter = MarianMT(
        data_dir=model_test_dir(MarianMT),
        load_existing=True,
        target_languages=languages,
        **model_gpu_config,
    )
    augmenter.build()

    sample_texts = ["This is a test."]
    max_times = len(languages)

    # Asking for more augmentations than target languages must fail
    with pytest.raises(ValueError):
        augmenter.augment(sample_texts, times=max_times + 1)

    # Exactly one augmentation per target language is allowed
    augmented = augmenter.augment(sample_texts, times=max_times)
    assert len(augmented) == max_times