def test_postprocess_xlnet_answer(qa_test_data, tmp_module):
    """Smoke-test the XLNet preprocess -> predict -> postprocess pipeline.

    All caches and output files are placed under ``tmp_module``. The test
    contains no explicit assertions: it passes when both ``postprocess``
    invocations (default options, then ``unanswerable_exists=True`` with
    ``verbose_logging=True``) return without raising.
    """
    # NOTE(review): this module defines a second function with the same name
    # further down; the later definition rebinds the name, so only one of the
    # two would be collected by pytest. Confirm which one is intended.
    model_name = "xlnet-base-cased"

    def in_tmp(filename):
        # Resolve a file name relative to the temporary directory.
        return os.path.join(tmp_module, filename)

    processor = QAProcessor(model_name=model_name, cache_dir=tmp_module)
    features = processor.preprocess(
        qa_test_data["test_dataset"],
        is_training=False,
        max_question_length=16,
        max_seq_length=64,
        doc_stride=32,
        feature_cache_dir=tmp_module,
    )
    # Evaluation loader: shuffling is explicitly disabled.
    loader = dataloader_from_dataset(features, shuffle=False)

    extractor = AnswerExtractor(model_name=model_name, cache_dir=tmp_module)
    predictions = extractor.predict(loader)

    # Arguments shared by both postprocess calls; the second call writes to
    # the same output files as the first.
    shared_kwargs = dict(
        results=predictions,
        examples_file=in_tmp(CACHED_EXAMPLES_TEST_FILE),
        features_file=in_tmp(CACHED_FEATURES_TEST_FILE),
        output_prediction_file=in_tmp("qa_predictions.json"),
        output_nbest_file=in_tmp("nbest_predictions.json"),
        output_null_log_odds_file=in_tmp("null_odds.json"),
    )

    # First pass: default post-processing options.
    processor.postprocess(**shared_kwargs)

    # Second pass: unanswerable-question handling plus verbose logging.
    processor.postprocess(
        unanswerable_exists=True,
        verbose_logging=True,
        **shared_kwargs,
    )
def test_postprocess_xlnet_answer(qa_test_data, tmp):
    """Smoke-test the XLNet preprocess -> predict -> postprocess pipeline.

    Mirrors the conventions used by the other tests in this module:

    * the processor and the extractor both receive ``cache_dir=tmp``;
    * ``predict`` is fed a loader built with
      ``dataloader_from_dataset(..., shuffle=False)`` rather than the raw
      feature dataset;
    * every ``postprocess`` output file is directed into ``tmp`` so the test
      does not rely on the default output locations.

    There are no explicit assertions; the test passes when both
    ``postprocess`` calls complete without raising.
    """
    # NOTE(review): this module contains an earlier function with the same
    # name; this later definition rebinds it, so pytest would collect only
    # one of them. Confirm which variant should survive (or rename one).
    model_name = "xlnet-base-cased"

    qa_processor = QAProcessor(model_name=model_name, cache_dir=tmp)
    test_features = qa_processor.preprocess(
        qa_test_data["test_dataset"],
        is_training=False,
        max_question_length=16,
        max_seq_length=64,
        doc_stride=32,
        feature_cache_dir=tmp,
    )
    # Wrap the features in a non-shuffling loader, as the sibling tests do
    # before calling AnswerExtractor.predict.
    test_loader = dataloader_from_dataset(test_features, shuffle=False)

    qa_extractor = AnswerExtractor(model_name=model_name, cache_dir=tmp)
    predictions = qa_extractor.predict(test_loader)

    # Keyword arguments common to both postprocess invocations; all paths
    # live under the temporary directory.
    common_kwargs = dict(
        results=predictions,
        examples_file=os.path.join(tmp, CACHED_EXAMPLES_TEST_FILE),
        features_file=os.path.join(tmp, CACHED_FEATURES_TEST_FILE),
        output_prediction_file=os.path.join(tmp, "qa_predictions.json"),
        output_nbest_file=os.path.join(tmp, "nbest_predictions.json"),
        output_null_log_odds_file=os.path.join(tmp, "null_odds.json"),
    )

    # Default post-processing options.
    qa_processor.postprocess(**common_kwargs)

    # Unanswerable-question handling with verbose logging.
    qa_processor.postprocess(
        unanswerable_exists=True,
        verbose_logging=True,
        **common_kwargs,
    )
def test_AnswerExtractor(qa_test_data, tmp_module):
    """Smoke-test ``AnswerExtractor`` fit/predict for three model variants.

    * BERT case (extractor built without ``model_name``): fit with
      ``cache_model=True``, assert that ``pytorch_model.bin`` and
      ``config.json`` exist under ``<tmp_module>/fine_tuned``, then build a
      second extractor with ``load_model_from_dir`` and run ``predict``.
    * ``xlnet-base-cased`` and ``distilbert-base-uncased``: fit with
      ``cache_model=False`` followed by ``predict``.
    """

    def make_loaders(train_key, test_key):
        # The training loader uses dataloader_from_dataset's defaults; the
        # evaluation loader is created with shuffling disabled.
        train_loader = dataloader_from_dataset(qa_test_data[train_key])
        eval_loader = dataloader_from_dataset(
            qa_test_data[test_key], shuffle=False)
        return train_loader, eval_loader

    # bert
    bert_extractor = AnswerExtractor(cache_dir=tmp_module)
    bert_train, bert_eval = make_loaders(
        "train_features_bert", "test_features_bert")
    bert_extractor.fit(bert_train, verbose=False, cache_model=True)

    # test saving fine-tuned model
    model_output_dir = os.path.join(tmp_module, "fine_tuned")
    for artifact in ("pytorch_model.bin", "config.json"):
        assert os.path.exists(os.path.join(model_output_dir, artifact))

    # Reload from the directory checked above and make sure prediction runs.
    reloaded_extractor = AnswerExtractor(
        cache_dir=tmp_module, load_model_from_dir=model_output_dir)
    reloaded_extractor.predict(bert_eval, verbose=False)

    # xlnet
    xlnet_train, xlnet_eval = make_loaders(
        "train_features_xlnet", "test_features_xlnet")
    xlnet_extractor = AnswerExtractor(
        model_name="xlnet-base-cased", cache_dir=tmp_module)
    xlnet_extractor.fit(xlnet_train, verbose=False, cache_model=False)
    xlnet_extractor.predict(xlnet_eval, verbose=False)

    # distilbert
    distilbert_train, distilbert_eval = make_loaders(
        "train_features_distilbert", "test_features_distilbert")
    distilbert_extractor = AnswerExtractor(
        model_name="distilbert-base-uncased", cache_dir=tmp_module)
    distilbert_extractor.fit(
        distilbert_train, verbose=False, cache_model=False)
    distilbert_extractor.predict(distilbert_eval, verbose=False)