def test_compare_test_accuracy(
    training_configuration: _MentorTrainAndTestConfiguration,
    compare_configuration: _MentorTrainAndTestConfiguration,
    tmpdir,
    shared_root: str,
    example: str,
    test_set_file: str,
):
    """Check that ``compare_configuration`` is no worse than ``training_configuration``.

    One classifier is trained per configuration against a mocked GraphQL
    endpoint serving the training mentor's ``data.csv``. The compare
    ("candidate") classifier must then match or beat the baseline on:

    1. training accuracy reported by ``train()``;
    2. accuracy over the mentor's test set, ignoring confidence;
    3. the highest confidence it assigns when evaluating ``example``.

    ``test_set_file`` falls back to ``"test.csv"`` when it is empty or None.
    """

    def _test_set_accuracy(test_results) -> float:
        # Fraction of test-set entries that passed. Guard explicitly so an
        # empty test set fails with a readable message instead of raising
        # ZeroDivisionError.
        total = len(test_results.results)
        assert total > 0, "test set produced no results"
        return test_results.passing_tests / total

    mentor = load_mentor_csv(
        fixture_mentor_data(training_configuration.mentor_id, "data.csv"))
    test_set = load_test_csv(
        fixture_mentor_data(training_configuration.mentor_id, test_set_file
                            or "test.csv"))
    # The mocked endpoint is registered with the training mentor's data only.
    # NOTE(review): this assumes compare_configuration.mentor_id refers to
    # the same mentor — confirm against the parametrization.
    data = {"data": {"mentor": mentor.to_dict()}}
    responses.add(responses.POST,
                  "http://graphql/graphql",
                  json=data,
                  status=200)

    baseline_train = ClassifierFactory().new_training(
        mentor=training_configuration.mentor_id,
        shared_root=shared_root,
        data_path=tmpdir,
        arch=training_configuration.arch,
    ).train(shared_root)
    candidate_train = ClassifierFactory().new_training(
        mentor=compare_configuration.mentor_id,
        shared_root=shared_root,
        data_path=tmpdir,
        arch=compare_configuration.arch,
    ).train(shared_root)
    assert candidate_train.accuracy >= baseline_train.accuracy, (
        f"training accuracy: {compare_configuration.arch}="
        f"{candidate_train.accuracy} < {training_configuration.arch}="
        f"{baseline_train.accuracy}")

    candidate_classifier = ClassifierFactory().new_prediction(
        mentor=compare_configuration.mentor_id,
        shared_root=shared_root,
        data_path=tmpdir,
        arch=compare_configuration.arch,
    )
    candidate_test_accuracy = _test_set_accuracy(
        run_model_against_testset_ignore_confidence(candidate_classifier,
                                                    test_set, shared_root))
    baseline_classifier = ClassifierFactory().new_prediction(
        mentor=training_configuration.mentor_id,
        shared_root=shared_root,
        data_path=tmpdir,
        arch=training_configuration.arch,
    )
    baseline_test_accuracy = _test_set_accuracy(
        run_model_against_testset_ignore_confidence(baseline_classifier,
                                                    test_set, shared_root))
    assert baseline_test_accuracy <= candidate_test_accuracy, (
        f"test-set accuracy: {compare_configuration.arch}="
        f"{candidate_test_accuracy} < {training_configuration.arch}="
        f"{baseline_test_accuracy}")

    candidate_result = candidate_classifier.evaluate(example, shared_root)
    baseline_result = baseline_classifier.evaluate(example, shared_root)
    assert (candidate_result.highest_confidence
            >= baseline_result.highest_confidence), (
        f"confidence for {example!r}: {compare_configuration.arch}="
        f"{candidate_result.highest_confidence} < "
        f"{training_configuration.arch}="
        f"{baseline_result.highest_confidence}")
# Ejemplo n.º 2
# 0
def _test_gets_off_topic(
    monkeypatch,
    data_root: str,
    shared_root: str,
    mentor_id: str,
    question: str,
    expected_answer_id: str,
    expected_answer: str,
    expected_media: List[Media],
):
    """Check the answer returned when a question scores below the off-topic threshold.

    Sets ``OFF_TOPIC_THRESHOLD`` to ``"1.0"`` via ``monkeypatch``. The mocked
    GraphQL endpoint serves the mentor's ``graphql/<mentor_id>.json`` fixture,
    and the mentor is trained if needed. The test then asserts that evaluating
    ``question`` yields:

    - a confidence below the threshold;
    - the expected answer id, text and media;
    - a feedback id.

    NOTE(review): the leading underscore means pytest's default ``test*``
    discovery does not collect this function — presumably disabled on
    purpose; confirm before renaming.
    """
    monkeypatch.setenv("OFF_TOPIC_THRESHOLD", "1.0")  # everything is offtopic
    # Decode the JSON fixture as UTF-8 explicitly instead of relying on the
    # platform's locale-dependent default encoding.
    with open(fixture_path("graphql/{}.json".format(mentor_id)),
              encoding="utf-8") as f:
        data = json.load(f)
    responses.add(responses.POST,
                  "http://graphql/graphql",
                  json=data,
                  status=200)
    _ensure_trained(mentor_id, shared_root, data_root)
    classifier = ClassifierFactory().new_prediction(mentor=mentor_id,
                                                    shared_root=shared_root,
                                                    data_path=data_root)
    result = classifier.evaluate(question, shared_root)
    assert result.highest_confidence < get_off_topic_threshold()
    assert result.answer_id == expected_answer_id
    assert result.answer_text == expected_answer
    assert result.answer_media == expected_media
    assert result.feedback_id is not None
# Ejemplo n.º 3
# 0
def test_gets_answer_for_exact_match_and_paraphrases(
    data_root: str,
    shared_root: str,
    mentor_id: str,
    question: str,
    expected_answer_id: str,
    expected_answer: str,
    expected_media: List[Media],
):
    """Check that ``question`` maps to the expected answer with full confidence.

    The mocked GraphQL endpoint serves the mentor's
    ``graphql/<mentor_id>.json`` fixture, and the mentor is trained if needed.
    The test then asserts that evaluating ``question`` yields:

    - the expected answer id, text and media;
    - a ``highest_confidence`` exactly equal to 1;
    - a feedback id.
    """
    # Decode the JSON fixture as UTF-8 explicitly instead of relying on the
    # platform's locale-dependent default encoding.
    with open(fixture_path("graphql/{}.json".format(mentor_id)),
              encoding="utf-8") as f:
        data = json.load(f)
    responses.add(responses.POST,
                  "http://graphql/graphql",
                  json=data,
                  status=200)
    _ensure_trained(mentor_id, shared_root, data_root)
    classifier = ClassifierFactory().new_prediction(mentor_id, shared_root,
                                                    data_root)
    result = classifier.evaluate(question, shared_root)
    assert result.answer_id == expected_answer_id
    assert result.answer_text == expected_answer
    assert result.answer_media == expected_media
    # Exact comparison: the test expects a confidence of exactly 1.
    assert result.highest_confidence == 1
    assert result.feedback_id is not None