Example #1
def test_token_classification_predicted_class_names():
    explainer_string = "We visited Paris during the weekend"
    ner_explainer = TokenClassificationExplainer(DISTILBERT_MODEL,
                                                 DISTILBERT_TOKENIZER)
    ner_explainer._run(explainer_string)
    ground_truths = ["O", "O", "O", "B-LOC", "O", "O", "O", "O"]

    assert len(ground_truths) == len(ner_explainer.predicted_class_names)

    for i, class_id in enumerate(ner_explainer.predicted_class_names):
        assert ground_truths[i] == class_id
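These tests reference TokenClassificationExplainer, DISTILBERT_MODEL, and DISTILBERT_TOKENIZER without defining them; they are module-level fixtures in the surrounding test file. A minimal sketch of that assumed setup follows, taking the explainer from the transformers-interpret package and loading a Hugging Face token-classification checkpoint. The checkpoint name is an assumption, not taken from the examples; it is chosen as a cased CoNLL-03 NER model consistent with the "B-LOC" label and cased tokens in the assertions.

# Assumed module-level fixtures for these tests (sketch, not the original file).
from transformers import AutoModelForTokenClassification, AutoTokenizer
from transformers_interpret import TokenClassificationExplainer

MODEL_NAME = "elastic/distilbert-base-cased-finetuned-conll03-english"  # assumed checkpoint
DISTILBERT_MODEL = AutoModelForTokenClassification.from_pretrained(MODEL_NAME)
DISTILBERT_TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME)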
Example #2
def test_token_classification_predicted_class_names_no_id2label_defaults_idx():
    explainer_string = "We visited Paris during the weekend"
    ner_explainer = TokenClassificationExplainer(DISTILBERT_MODEL,
                                                 DISTILBERT_TOKENIZER)
    ner_explainer.id2label = {"test": "value"}
    ner_explainer._run(explainer_string)
    class_labels = list(range(9))

    assert len(ner_explainer.predicted_class_names) == 8

    for class_name in ner_explainer.predicted_class_names:
        assert class_name in class_labels
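Example #2 relies on the explainer falling back to raw class indices when its id2label mapping has no usable entries. In a Hugging Face model config, id2label maps integer class indices to label strings; the sketch below shows what a well-formed mapping for a 9-label CoNLL-03 style head looks like (label order is illustrative), which is why the test checks membership in list(range(9)).

# Illustrative only: a typical id2label for a 9-class NER head.
# Overriding it with {"test": "value"} (as in Example #2) leaves no usable
# integer -> label entries, so predicted_class_names fall back to the indices.
id2label = {
    0: "O",
    1: "B-PER", 2: "I-PER",
    3: "B-ORG", 4: "I-ORG",
    5: "B-LOC", 6: "I-LOC",
    7: "B-MISC", 8: "I-MISC",
}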
Example #3
def test_token_classification_run_text_given():
    ner_explainer = TokenClassificationExplainer(DISTILBERT_MODEL,
                                                 DISTILBERT_TOKENIZER)
    word_attributions = ner_explainer._run(
        "We visited Paris during the weekend")
    assert isinstance(word_attributions, dict)

    actual_tokens = list(word_attributions.keys())
    expected_tokens = [
        "[CLS]",
        "We",
        "visited",
        "Paris",
        "during",
        "the",
        "weekend",
        "[SEP]",
    ]
    assert actual_tokens == expected_tokens
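The [CLS] and [SEP] entries in expected_tokens come from the tokenizer, which wraps every encoded sequence in its special tokens; the explainer keys its word attributions by exactly these tokens. A quick way to reproduce the same token sequence directly from the tokenizer is sketched below, assuming the cased checkpoint from the fixture sketch above (expected output shown as a comment).

# Inspect the tokens the model actually sees for the test sentence.
encoded = DISTILBERT_TOKENIZER("We visited Paris during the weekend")
print(DISTILBERT_TOKENIZER.convert_ids_to_tokens(encoded["input_ids"]))
# Expected (with a cased vocabulary):
# ['[CLS]', 'We', 'visited', 'Paris', 'during', 'the', 'weekend', '[SEP]']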