Example #1
0
def test_sequence_classification_explain_on_cls_name_not_in_dict():
    """Explain with a ``class_name`` that is not one of the model's labels.

    Passing ``"UNKNOWN"`` is expected to leave both the selected index and
    the predicted class index at 1.
    """
    text = "I love you , I like you"
    explainer = SequenceClassificationExplainer(
        DISTILBERT_MODEL, DISTILBERT_TOKENIZER
    )

    explainer._run(text, class_name="UNKNOWN")

    assert explainer.selected_index == 1
    assert explainer.predicted_class_index == 1
Example #2
0
def test_sequence_classification_predicted_class_name_no_id2label_defaults_idx(
):
    """Without a usable ``id2label`` entry, the class name is the raw index.

    The label map is replaced by one keyed on the string ``"test"``, so the
    predicted index cannot be looked up in it. The test expects
    ``predicted_class_name`` to come back as the bare index (1).
    """
    text = "I love you , I like you"
    explainer = SequenceClassificationExplainer(
        DISTILBERT_MODEL, DISTILBERT_TOKENIZER
    )
    # Deliberately break the index -> label mapping.
    explainer.id2label = {"test": "value"}

    explainer._run(text)

    assert explainer.predicted_class_name == 1
Example #3
0
def test_sequence_classification_explain_on_cls_name_with_custom_labels():
    """Explain a non-predicted class when custom label names are supplied.

    With ``custom_labels=["sad", "happy"]`` the model is expected to predict
    index 1 ("happy"). Requesting attributions for "sad" must select a
    different index without changing the reported prediction.
    """
    text = "I love you , I like you"
    labels = ["sad", "happy"]
    explainer = SequenceClassificationExplainer(
        DISTILBERT_MODEL, DISTILBERT_TOKENIZER, custom_labels=labels
    )

    explainer._run(text, class_name="sad")

    # The prediction is unaffected by which class we asked to explain.
    assert explainer.predicted_class_index == 1
    assert explainer.predicted_class_index != explainer.selected_index

    # The label of the explained (selected) class is not the predicted one.
    assert explainer.predicted_class_name != explainer.id2label[
        explainer.selected_index]
    assert explainer.predicted_class_name != "sad"
    assert explainer.predicted_class_name == "happy"
Example #4
0
def test_sequence_classification_explain_on_cls_name():
    """Explain an explicitly named class that differs from the prediction.

    The model is expected to predict index 1 ("POSITIVE"). Asking for
    "NEGATIVE" must make ``selected_index`` differ from the predicted index
    while ``predicted_class_name`` still reports "POSITIVE".
    """
    text = "I love you , I like you"
    explainer = SequenceClassificationExplainer(
        DISTILBERT_MODEL, DISTILBERT_TOKENIZER
    )

    explainer._run(text, class_name="NEGATIVE")

    assert explainer.predicted_class_index == 1
    assert explainer.predicted_class_index != explainer.selected_index

    assert explainer.predicted_class_name != explainer.id2label[
        explainer.selected_index]
    assert explainer.predicted_class_name != "NEGATIVE"
    assert explainer.predicted_class_name == "POSITIVE"
def test_sequence_classification_explain_callable():
    """Calling the explainer object must agree with calling ``_run()``.

    The explainer is run once via ``_run()`` and once via ``__call__``. The
    predicted class index observed after each must match.
    """
    text = "I love you , I like you"
    explainer = SequenceClassificationExplainer(
        text, DISTILBERT_MODEL, DISTILBERT_TOKENIZER
    )

    explainer._run()
    index_after_run = explainer.predicted_class_index

    explainer()
    index_after_call = explainer.predicted_class_index

    assert index_after_call == index_after_run
Example #6
0
def test_sequence_classification_run_text_given():
    """``_run(text)`` returns a list of (token, score) pairs for that text.

    The expected tokens are lowercased, the comma is split into its own
    token, and the sequence is bracketed by ``[CLS]`` / ``[SEP]``.
    """
    explainer = SequenceClassificationExplainer(
        DISTILBERT_MODEL, DISTILBERT_TOKENIZER
    )
    pairs = explainer._run("I love you, I just love you")
    assert isinstance(pairs, list)

    tokens = [tok for tok, _score in pairs]
    expected = "[CLS] i love you , i just love you [SEP]".split()
    assert tokens == expected
def test_sequence_classification_no_text_given():
    """With no argument, ``_run()`` explains the text given at construction.

    The result must be a ``LIGAttributions`` whose ``word_attributions``
    tokens match the constructor text, lowercased and wrapped in
    ``[CLS]`` / ``[SEP]``.
    """
    text = "I love you , I hate you"
    explainer = SequenceClassificationExplainer(
        text, DISTILBERT_MODEL, DISTILBERT_TOKENIZER
    )
    result = explainer._run()
    assert isinstance(result, LIGAttributions)

    tokens = [tok for tok, _score in result.word_attributions]
    expected = "[CLS] i love you , i hate you [SEP]".split()
    assert tokens == expected
def test_sequence_classification_run_text_given_bert():
    """Text passed to ``_run`` overrides the constructor text (BERT model).

    The explainer is built with ``"I love you , I hate you"``, but ``_run``
    is called with a different sentence. The resulting tokens must come from
    the ``_run`` argument (note ``"just"`` and the absence of ``"hate"``).
    The test also checks that the explainer reports ``accepts_position_ids``
    for the BERT model.
    """
    constructor_text = "I love you , I hate you"
    seq_explainer = SequenceClassificationExplainer(constructor_text,
                                                    BERT_MODEL, BERT_TOKENIZER)
    attributions = seq_explainer._run("I love you, I just love you")
    assert isinstance(attributions, LIGAttributions)
    # Don't compare booleans to True with ``==`` (PEP 8 / ruff E712).
    assert seq_explainer.accepts_position_ids

    actual_tokens = [token for token, _ in attributions.word_attributions]
    expected_tokens = [
        "[CLS]",
        "i",
        "love",
        "you",
        ",",
        "i",
        "just",
        "love",
        "you",
        "[SEP]",
    ]
    assert actual_tokens == expected_tokens