def test_embeddings(
    model_cls,
    dataset_cls,
    model_kwargs,
    train_kwargs,
    embed_kwargs,
    model_gpu_config,
    pooling,
    gobbli_dir,
    request,
):
    """
    Ensure embedding models train and generate embeddings appropriately
    across a few example datasets.
    """
    # Some models can only emit pooled vectors, so raw (unpooled)
    # embeddings are an expected failure for them.
    if pooling == gobbli.io.EmbedPooling.NONE and model_cls in (
        USE,
        FastText,
        TfidfEmbedder,
    ):
        pytest.xfail(f"pooling is required for {model_cls.__name__}")

    # These combinations of model and dataset require a lot of memory
    if model_cls in (BERT, Transformer) and dataset_cls in (NewsgroupsDataset,):
        skip_if_low_resource(request.config)

    model = model_cls(
        data_dir=model_test_dir(model_cls),
        load_existing=True,
        **model_gpu_config,
        **model_kwargs,
    )
    model.build()

    ds = dataset_cls.load()
    embed_input = ds.embed_input(limit=50, pooling=pooling, **embed_kwargs)

    # Extra expectations that only apply to particular model types.
    check_kwargs = {}
    if isinstance(model, RandomEmbedder):
        check_kwargs["expected_dimensionality"] = RandomEmbedder.DIMENSIONALITY
    if isinstance(model, Transformer):
        check_kwargs["max_seq_length"] = model_kwargs.get("max_seq_length")

    # For models which support generating embeddings without training
    if model_cls not in (FastText,):
        # Verify we can generate embeddings without a trained checkpoint
        check_embed_output(embed_input, model.embed(embed_input), **check_kwargs)

    # Only these models support training for embeddings
    if model_cls in (BERT, FastText, Transformer):
        # Verify embedding runs with a trained checkpoint
        train_output = model.train(ds.train_input(limit=50, **train_kwargs))
        assert train_output.valid_loss is not None
        assert train_output.train_loss is not None
        assert 0 <= train_output.valid_accuracy <= 1
        validate_checkpoint(model_cls, train_output.checkpoint)

        embed_input.checkpoint = train_output.checkpoint
        check_embed_output(embed_input, model.embed(embed_input), **check_kwargs)
def test_classifier(
    model_cls,
    dataset_cls,
    model_kwargs,
    train_kwargs,
    predict_kwargs,
    model_gpu_config,
    gobbli_dir,
    request,
):
    """
    Ensure classifiers train and predict appropriately across a few
    example datasets.
    """
    # These combinations of model and dataset require a lot of memory
    if model_cls in (BERT, MTDNN, Transformer) and dataset_cls in (
        NewsgroupsDataset,
        MovieSummaryDataset,
    ):
        skip_if_low_resource(request.config)

    model = model_cls(
        data_dir=model_test_dir(model_cls),
        load_existing=True,
        **model_gpu_config,
        **model_kwargs,
    )
    model.build()

    ds = dataset_cls.load()
    train_input = ds.train_input(limit=50, **train_kwargs)

    if train_input.multilabel and model_cls in (BERT, MTDNN):
        pytest.xfail(
            f"model {model_cls.__name__} doesn't support multilabel classification"
        )

    # Verify training runs, results are sensible
    train_output = model.train(train_input)
    assert train_output.valid_loss is not None
    assert train_output.train_loss is not None
    assert 0 <= train_output.valid_accuracy <= 1
    # model_cls is the same object as model.__class__; use it directly for
    # consistency with the embedding test.
    validate_checkpoint(model_cls, train_output.checkpoint)

    predict_input = ds.predict_input(limit=50, **predict_kwargs)

    # fastText requires a trained checkpoint for prediction, so skip the
    # checkpoint-free prediction check for it.
    if not isinstance(model, FastText):
        # Verify prediction runs without a trained checkpoint
        predict_output = model.predict(predict_input)
        check_predict_output(train_output, predict_input, predict_output)

    # Verify prediction runs with a trained checkpoint
    predict_input.checkpoint = train_output.checkpoint
    predict_output = model.predict(predict_input)
    check_predict_output(train_output, predict_input, predict_output)
def test_bertmaskedlm_augment(model_gpu_config, gobbli_dir):
    """
    Ensure the BERT masked LM generates the requested number of
    augmented texts.
    """
    model = BERTMaskedLM(
        data_dir=model_test_dir(BERTMaskedLM),
        load_existing=True,
        **model_gpu_config,
    )
    model.build()

    num_augmentations = 5
    augmented = model.augment(["This is a test."], times=num_augmentations)
    assert len(augmented) == num_augmentations
def test_marianmt_augment(model_gpu_config, gobbli_dir):
    """
    Ensure MarianMT back-translation augmentation works and rejects
    requests for more augmentations than configured languages.
    """
    # Don't go overboard with the languages here, since each
    # one requires a separate model (few hundred MB) to be downloaded
    target_languages = ["russian", "french"]

    model = MarianMT(
        data_dir=model_test_dir(MarianMT),
        load_existing=True,
        target_languages=target_languages,
        **model_gpu_config,
    )
    model.build()

    # Can't augment more times than target languages
    with pytest.raises(ValueError):
        model.augment(["This is a test."], times=len(target_languages) + 1)

    num_times = len(target_languages)
    augmented = model.augment(["This is a test."], times=num_times)
    assert len(augmented) == num_times