def test_question_answering_visualize():
    """Explain a short question/context pair, then render it with ``visualize()``."""
    explainer = QuestionAnsweringExplainer(
        DISTILBERT_QA_MODEL, DISTILBERT_QA_TOKENIZER
    )
    question, context = "what is his name ?", "his name is Bob"
    explainer(question, context)
    # Called with no filename; this test only checks that rendering does not raise.
    explainer.visualize()
def test_question_answering_encode():
    """``encode`` returns a list of ids not bracketed by the CLS/SEP token ids."""
    explainer = QuestionAnsweringExplainer(
        DISTILBERT_QA_MODEL, DISTILBERT_QA_TOKENIZER
    )

    sample_text = "this is a sample of text to be encoded"
    ids = explainer.encode(sample_text)

    assert isinstance(ids, list)
    # The first id must not be CLS and the last must not be SEP.
    assert ids[0] != explainer.cls_token_id
    assert ids[-1] != explainer.sep_token_id
    # Expect at least one id per space-separated word of the input.
    word_count = len(sample_text.split(" "))
    assert len(ids) >= word_count
# Example #3
def test_question_answering_visualize_save_append_html_file_ending():
    """Saving to a path without an extension should produce ``<path>.html``."""
    explainer = QuestionAnsweringExplainer(
        DISTILBERT_QA_MODEL, DISTILBERT_QA_TOKENIZER
    )
    explainer("what is his name ?", "his name is Bob")

    base_name = "./test/qa_test"
    expected_file = base_name + ".html"

    # Pass the bare name; the file is expected at base_name + ".html".
    explainer.visualize(base_name)
    assert os.path.exists(expected_file)
    # Remove the generated artifact so repeated runs start clean.
    os.remove(expected_file)
def test_question_answering_decode():
    """Decode the input ids built for a question/context pair back to tokens."""
    explainer = QuestionAnsweringExplainer(
        DISTILBERT_QA_MODEL, DISTILBERT_QA_TOKENIZER
    )
    question = "what is his name ?"
    context = "his name is bob"

    # Only the input ids are needed; the other two returned values are ignored.
    ids, _, _ = explainer._make_input_reference_pair(question, context)
    tokens = explainer.decode(ids)

    tok = explainer.tokenizer
    assert tokens[0] == tok.cls_token
    assert tokens[-1] == tok.sep_token
    # Between the outer CLS/SEP: lowercased question, a SEP, lowercased context.
    expected_body = f"{question.lower()} [SEP] {context.lower()}"
    assert " ".join(tokens[1:-1]) == expected_body
def test_question_answering_explainer_init_attribution_type_error():
    """An unrecognised ``attribution_type`` is rejected by the constructor."""
    bad_type = "UNSUPPORTED"
    with pytest.raises(AttributionTypeNotSupportedError):
        QuestionAnsweringExplainer(
            DISTILBERT_QA_MODEL, DISTILBERT_QA_TOKENIZER, attribution_type=bad_type
        )
# Example #6
def test_question_answering_word_attributions_input_ids_not_calculated():
    """Reading ``word_attributions`` before the explainer has run raises ``ValueError``."""
    explainer = QuestionAnsweringExplainer(DISTILBERT_QA_MODEL, DISTILBERT_QA_TOKENIZER)

    with pytest.raises(ValueError):
        _ = explainer.word_attributions
# Example #7
def test_question_answering_explainer_init_distilbert():
    """A freshly built explainer has the expected initial attribute values."""
    explainer = QuestionAnsweringExplainer(DISTILBERT_QA_MODEL, DISTILBERT_QA_TOKENIZER)
    # State before the explainer has been called on any input.
    assert explainer.attribution_type == "lig"
    assert explainer.attributions is None
    assert explainer.position == 0
def test_question_answering_end_pos():
    """After explaining the sample pair, ``end_pos`` equals 10."""
    explainer = QuestionAnsweringExplainer(
        DISTILBERT_QA_MODEL, DISTILBERT_QA_TOKENIZER
    )
    question, context = "what is his name ?", "his name is Bob"
    explainer(question, context)
    assert explainer.end_pos == 10
def test_question_answering_predicted_answer():
    """The predicted answer for the sample pair is the lowercased name."""
    explainer = QuestionAnsweringExplainer(
        DISTILBERT_QA_MODEL, DISTILBERT_QA_TOKENIZER
    )
    question, context = "what is his name ?", "his name is Bob"
    explainer(question, context)
    # The context says "Bob", but the expected answer is lowercase "bob".
    assert explainer.predicted_answer == "bob"
def test_question_answering_word_attributions():
    """Calling the explainer returns a dict with equal-length ``start``/``end`` entries."""
    explainer = QuestionAnsweringExplainer(
        DISTILBERT_QA_MODEL, DISTILBERT_QA_TOKENIZER
    )
    result = explainer("what is his name ?", "his name is bob")

    assert isinstance(result, dict)
    for key in ("start", "end"):
        assert key in result
    assert len(result["start"]) == len(result["end"])
def test_question_answering_start_pos_input_ids_not_calculated():
    """``start_pos`` is unavailable until the explainer has been run on input."""
    explainer = QuestionAnsweringExplainer(DISTILBERT_QA_MODEL, DISTILBERT_QA_TOKENIZER)
    with pytest.raises(InputIdsNotCalculatedError):
        _ = explainer.start_pos
# Example #12
def xtest_question_answering_custom_internal_batch_size():
    """Smoke test: explain the sample pair with ``internal_batch_size=1``.

    The ``xtest_`` prefix means pytest's default ``test*`` discovery skips this
    function; presumably it was disabled on purpose.
    """
    explainer = QuestionAnsweringExplainer(
        DISTILBERT_QA_MODEL, DISTILBERT_QA_TOKENIZER
    )
    question, context = "what is his name ?", "his name is Bob"
    explainer(question, context, internal_batch_size=1)
# Example #13
def xtest_question_answering_custom_steps():
    """Smoke test: explain the sample pair with ``n_steps=1``.

    The ``xtest_`` prefix means pytest's default ``test*`` discovery skips this
    function; presumably it was disabled on purpose.
    """
    explainer = QuestionAnsweringExplainer(
        DISTILBERT_QA_MODEL, DISTILBERT_QA_TOKENIZER
    )
    question, context = "what is his name ?", "his name is Bob"
    explainer(question, context, n_steps=1)