def test_postprocess_xlnet_answer(qa_test_data, tmp):
    """Smoke-test XLNet answer postprocessing with and without null answers.

    Preprocesses ``qa_test_data["test_dataset"]`` with an XLNet
    ``QAProcessor``, predicts with an XLNet ``AnswerExtractor``, then calls
    ``postprocess`` twice on the same predictions: once with its default
    options and once with ``unanswerable_exists=True`` and
    ``verbose_logging=True``. There are no explicit assertions; the test
    passes when none of the calls raises.

    Args:
        qa_test_data: Fixture mapping; only the ``"test_dataset"`` entry is
            read here.
        tmp: Directory used as the feature cache and model cache; the cached
            example/feature files are read back from it.
    """
    model_name = "xlnet-base-cased"

    processor = QAProcessor(model_name=model_name)
    features = processor.preprocess(
        qa_test_data["test_dataset"],
        is_training=False,
        max_question_length=16,
        max_seq_length=64,
        doc_stride=32,
        feature_cache_dir=tmp,
    )

    extractor = AnswerExtractor(model_name=model_name, cache_dir=tmp)

    # Both postprocess calls consume the same predictions and cached files.
    shared_kwargs = {
        "results": extractor.predict(features),
        "examples_file": os.path.join(tmp, CACHED_EXAMPLES_TEST_FILE),
        "features_file": os.path.join(tmp, CACHED_FEATURES_TEST_FILE),
    }

    # First pass: default postprocessing options.
    processor.postprocess(**shared_kwargs)

    # Second pass: allow unanswerable questions and enable verbose logging.
    processor.postprocess(
        **shared_kwargs,
        unanswerable_exists=True,
        verbose_logging=True,
    )
def test_postprocess_xlnet_answer(qa_test_data, tmp_module):
    """Smoke-test XLNet answer postprocessing with explicit output files.

    Preprocesses ``qa_test_data["test_dataset"]`` with an XLNet
    ``QAProcessor``, wraps the features with ``dataloader_from_dataset``
    (``shuffle=False``), predicts with an XLNet ``AnswerExtractor``, then
    calls ``postprocess`` twice on the same predictions: once with default
    options and once with ``unanswerable_exists=True`` and
    ``verbose_logging=True``. Both calls are pointed at the same three
    output paths under ``tmp_module``. There are no explicit assertions;
    the test passes when none of the calls raises.

    Args:
        qa_test_data: Fixture mapping; only the ``"test_dataset"`` entry is
            read here.
        tmp_module: Directory used for the model cache, the feature cache
            and every postprocess output path.
    """
    model_name = "xlnet-base-cased"

    processor = QAProcessor(model_name=model_name, cache_dir=tmp_module)
    features = processor.preprocess(
        qa_test_data["test_dataset"],
        is_training=False,
        max_question_length=16,
        max_seq_length=64,
        doc_stride=32,
        feature_cache_dir=tmp_module,
    )
    loader = dataloader_from_dataset(features, shuffle=False)

    extractor = AnswerExtractor(model_name=model_name, cache_dir=tmp_module)

    # Inputs shared by both postprocess calls.
    input_kwargs = {
        "results": extractor.predict(loader),
        "examples_file": os.path.join(tmp_module, CACHED_EXAMPLES_TEST_FILE),
        "features_file": os.path.join(tmp_module, CACHED_FEATURES_TEST_FILE),
    }
    # Output locations shared by both postprocess calls.
    output_kwargs = {
        "output_prediction_file": os.path.join(tmp_module, "qa_predictions.json"),
        "output_nbest_file": os.path.join(tmp_module, "nbest_predictions.json"),
        "output_null_log_odds_file": os.path.join(tmp_module, "null_odds.json"),
    }

    # First pass: default postprocessing options.
    processor.postprocess(**input_kwargs, **output_kwargs)

    # Second pass: allow unanswerable questions and enable verbose logging.
    processor.postprocess(
        **input_kwargs,
        unanswerable_exists=True,
        verbose_logging=True,
        **output_kwargs,
    )
def test_AnswerExtractor(qa_test_data, tmp):
    """Exercise ``AnswerExtractor`` fit/predict for BERT, XLNet and DistilBERT.

    The default extractor (the original comment labels it "bert") is fitted
    with ``cache_model=True``; the test then asserts that
    ``pytorch_model.bin`` and ``config.json`` exist under
    ``<tmp>/fine_tuned``, reloads an extractor from that directory and runs
    prediction with it. XLNet and DistilBERT extractors are then fitted with
    ``cache_model=False`` and used for prediction.

    Args:
        qa_test_data: Fixture mapping providing the
            ``train_features_*`` / ``test_features_*`` entries for each model.
        tmp: Directory used as the model cache and as the parent of the
            ``fine_tuned`` output directory.
    """
    # BERT: fit with caching so the fine-tuned model is written to disk.
    bert_extractor = AnswerExtractor(cache_dir=tmp)
    bert_extractor.fit(qa_test_data["train_features_bert"], cache_model=True)

    # The fine-tuned model artifacts must have been saved.
    fine_tuned_dir = os.path.join(tmp, "fine_tuned")
    for artifact in ("pytorch_model.bin", "config.json"):
        assert os.path.exists(os.path.join(fine_tuned_dir, artifact))

    # An extractor restored from the saved directory must be able to predict.
    restored_extractor = AnswerExtractor(
        cache_dir=tmp, load_model_from_dir=fine_tuned_dir
    )
    restored_extractor.predict(qa_test_data["test_features_bert"])

    # Remaining models: fit without caching, then predict.
    uncached_models = (
        ("xlnet-base-cased", "train_features_xlnet", "test_features_xlnet"),
        (
            "distilbert-base-uncased",
            "train_features_distilbert",
            "test_features_distilbert",
        ),
    )
    for model_name, train_key, test_key in uncached_models:
        extractor = AnswerExtractor(model_name=model_name, cache_dir=tmp)
        extractor.fit(qa_test_data[train_key], cache_model=False)
        extractor.predict(qa_test_data[test_key])