Exemple #1
0
def test_zero_shot_explainer_visualize():
    """Smoke-test that ``visualize()`` runs after attributions are computed.

    The explainer is called once on a short support-ticket sentence with a
    set of candidate labels. Then ``visualize()`` is invoked with no
    arguments. The test passes as long as no exception is raised.
    """
    zero_shot_explainer = ZeroShotClassificationExplainer(
        DISTILBERT_MNLI_MODEL,
        DISTILBERT_MNLI_TOKENIZER,
    )

    # Each entry is one candidate label. "not urgent" is a single multi-word
    # label, and labels are kept unique with no stray whitespace.
    zero_shot_explainer(
        "I have a problem with my iphone that needs to be resolved asap!!",
        labels=["urgent", "not urgent", "phone", "tablet", "computer"],
    )
    zero_shot_explainer.visualize()
Exemple #2
0
def test_zero_shot_explainer_visualize_save():
    """Check that ``visualize(path)`` writes an HTML file to *path*.

    Any stale file at *path* is removed first, so the existence check cannot
    pass because of a leftover from an earlier aborted run. The file is
    removed afterwards whenever it exists, even if ``visualize`` raises after
    creating it.
    """
    zero_shot_explainer = ZeroShotClassificationExplainer(
        DISTILBERT_MNLI_MODEL,
        DISTILBERT_MNLI_TOKENIZER,
    )

    # Each entry is one candidate label. "not urgent" is a single multi-word
    # label, and labels are kept unique with no stray whitespace.
    zero_shot_explainer(
        "I have a problem with my iphone that needs to be resolved asap!!",
        labels=["urgent", "not urgent", "phone", "tablet", "computer"],
    )
    # NOTE(review): the path is relative to the working directory and assumes
    # a ``test/`` directory exists there. Confirm this against how pytest is run.
    html_filename = "./test/zero_test.html"
    if os.path.exists(html_filename):
        os.remove(html_filename)
    try:
        zero_shot_explainer.visualize(html_filename)
        assert os.path.exists(html_filename)
    finally:
        if os.path.exists(html_filename):
            os.remove(html_filename)
Exemple #3
0
def test_zero_shot_explainer_init_attribution_type_error():
    """An unknown ``attribution_type`` is rejected when the explainer is built."""
    unsupported_type = "UNSUPPORTED"
    with pytest.raises(AttributionTypeNotSupportedError):
        ZeroShotClassificationExplainer(
            DISTILBERT_MNLI_MODEL,
            DISTILBERT_MNLI_TOKENIZER,
            attribution_type=unsupported_type,
        )
Exemple #4
0
def test_zero_shot_explainer_call_word_attributions_early_raises_error():
    """Reading ``word_attributions`` before calling the explainer raises ValueError.

    Only the attribute access sits inside ``pytest.raises``. A ValueError from
    the constructor therefore surfaces as a test error instead of letting the
    test pass by accident.
    """
    zero_shot_explainer = ZeroShotClassificationExplainer(
        DISTILBERT_MNLI_MODEL,
        DISTILBERT_MNLI_TOKENIZER,
    )

    with pytest.raises(ValueError):
        # Bind to ``_`` to mark the bare attribute access as intentional.
        _ = zero_shot_explainer.word_attributions
Exemple #5
0
def test_zero_shot_model_lowercase_entailment():
    """The explainer accepts a model config whose entailment label is lowercase.

    ``model.config.label2id`` is temporarily replaced with a mapping whose
    only entailment-like key is ``"entailment"``. The explainer must build
    successfully and record that key.
    """
    lowercase_label2id = {
        "entailment": 0,
        "l2": 1,
        "l3": 2,
    }
    with patch.object(DISTILBERT_MNLI_MODEL.config, "label2id", lowercase_label2id):
        zero_shot_explainer = ZeroShotClassificationExplainer(
            DISTILBERT_MNLI_MODEL,
            DISTILBERT_MNLI_TOKENIZER,
        )

        # Construction alone is not enough: confirm the lowercase key was the
        # one detected.
        assert zero_shot_explainer.label_exists is True
        assert zero_shot_explainer.entailment_key == "entailment"
Exemple #6
0
def test_zero_shot_explainer_init_distilbert():
    """Default state of a freshly constructed DistilBERT MNLI explainer."""
    explainer = ZeroShotClassificationExplainer(
        DISTILBERT_MNLI_MODEL,
        DISTILBERT_MNLI_TOKENIZER,
    )

    # The explainer has not been called yet, so no attributions are stored.
    assert explainer.attribution_type == "lig"
    assert explainer.attributions == []
    assert explainer.label_exists is True
    assert explainer.entailment_key == "ENTAILMENT"
Exemple #7
0
def test_zero_shot_model_does_not_have_entailment_label():
    """Construction fails with ValueError when ``label2id`` lacks an entailment key."""
    label2id_without_entailment = {"l1": 0, "l2": 1, "l3": 2}
    # The patch is entered first, then the raises-check, which is the same
    # order as nesting the two ``with`` blocks.
    with patch.object(
        DISTILBERT_MNLI_MODEL.config, "label2id", label2id_without_entailment
    ), pytest.raises(ValueError):
        ZeroShotClassificationExplainer(
            DISTILBERT_MNLI_MODEL,
            DISTILBERT_MNLI_TOKENIZER,
        )
Exemple #8
0
def test_zero_shot_explainer_word_attributions():
    """Calling the explainer returns a dict containing every candidate label."""
    explainer = ZeroShotClassificationExplainer(
        DISTILBERT_MNLI_MODEL,
        DISTILBERT_MNLI_TOKENIZER,
    )
    candidate_labels = ["urgent", "phone", "tablet", "computer"]
    text = "I have a problem with my iphone that needs to be resolved asap!!"

    result = explainer(text, labels=candidate_labels)

    assert isinstance(result, dict)
    for candidate in candidate_labels:
        # Dict membership already tests keys, so ``.keys()`` is unnecessary.
        assert candidate in result
Exemple #9
0
def test_zero_shot_explainer_word_attributions_include_hypothesis():
    """Including the hypothesis yields more attributed entries for every label."""
    explainer = ZeroShotClassificationExplainer(
        DISTILBERT_MNLI_MODEL,
        DISTILBERT_MNLI_TOKENIZER,
    )
    candidate_labels = ["urgent", "phone", "tablet", "computer"]
    text = "I have a problem with my iphone that needs to be resolved asap!!"

    with_hyp = explainer(text, labels=candidate_labels, include_hypothesis=True)
    without_hyp = explainer(text, labels=candidate_labels, include_hypothesis=False)

    for candidate in candidate_labels:
        assert len(with_hyp[candidate]) > len(without_hyp[candidate])
Exemple #10
0
def test_zero_shot_explainer_no_entailment_label(mock_method):
    """Construction raises ValueError when no entailment label is found.

    NOTE(review): ``mock_method`` is unused in the body. It presumably comes
    from a ``patch``/``patch.object`` decorator that is not visible in this
    excerpt; confirm it is still applied. Without it, pytest will look for a
    fixture named ``mock_method``.
    """
    model, tokenizer = DISTILBERT_MNLI_MODEL, DISTILBERT_MNLI_TOKENIZER
    with pytest.raises(ValueError):
        ZeroShotClassificationExplainer(model, tokenizer)