def test_sequence_classification_explain_on_cls_name_not_in_dict():
    """Run the explainer with a ``class_name`` that is not a known label.

    Both the selected index and the predicted class index are expected to be 1.
    """
    text = "I love you , I like you"
    explainer = SequenceClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
    explainer._run(text, class_name="UNKNOWN")

    expected_index = 1
    assert explainer.selected_index == expected_index
    assert explainer.predicted_class_index == expected_index
def test_sequence_classification_predicted_class_name_no_id2label_defaults_idx():
    """Check the predicted class name when ``id2label`` has no entry for it.

    The test expects the name to fall back to the raw index (1).
    """
    explainer = SequenceClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
    # Swap in a mapping with no integer keys, so no label can be looked up.
    explainer.id2label = {"test": "value"}
    explainer._run("I love you , I like you")
    assert explainer.predicted_class_name == 1
def test_sequence_classification_explain_on_cls_name_with_custom_labels():
    """Explain for a custom label other than the predicted one.

    With ``custom_labels=["sad", "happy"]``, asking for ``"sad"`` selects a
    different index than the prediction. The predicted name stays ``"happy"``.
    """
    text = "I love you , I like you"
    labels = ["sad", "happy"]
    explainer = SequenceClassificationExplainer(
        DISTILBERT_MODEL,
        DISTILBERT_TOKENIZER,
        custom_labels=labels,
    )
    explainer._run(text, class_name="sad")

    assert explainer.predicted_class_index == 1
    # The requested (selected) class must differ from the model's prediction.
    assert explainer.predicted_class_index != explainer.selected_index
    assert (
        explainer.predicted_class_name
        != explainer.id2label[explainer.selected_index]
    )
    assert explainer.predicted_class_name != "sad"
    assert explainer.predicted_class_name == "happy"
def test_sequence_classification_explain_on_cls_name():
    """Explain for the ``"NEGATIVE"`` label while the model predicts ``"POSITIVE"``.

    The selected index comes from ``class_name``. The predicted index and
    name must stay at the model's own prediction.
    """
    text = "I love you , I like you"
    explainer = SequenceClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
    explainer._run(text, class_name="NEGATIVE")

    assert explainer.predicted_class_index == 1
    # Requesting a non-predicted class must not alter the prediction itself.
    assert explainer.predicted_class_index != explainer.selected_index
    assert (
        explainer.predicted_class_name
        != explainer.id2label[explainer.selected_index]
    )
    assert explainer.predicted_class_name != "NEGATIVE"
    assert explainer.predicted_class_name == "POSITIVE"
def test_sequence_classification_explain_callable():
    """Calling the explainer must predict the same class index as ``_run``.

    The explainer is built from ``(model, tokenizer)`` only, and the text is
    passed on each call. Every other test here uses the constructor this way.
    """
    explainer_string = "I love you , I like you"
    seq_explainer = SequenceClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)

    seq_explainer._run(explainer_string)
    run_method_predicted_index = seq_explainer.predicted_class_index

    # __call__ needs the text too; the explainer itself does not hold any.
    seq_explainer(explainer_string)
    call_method_predicted_index = seq_explainer.predicted_class_index

    assert call_method_predicted_index == run_method_predicted_index
def test_sequence_classification_run_text_given():
    """``_run`` on explicit text should return a list of (token, attribution) pairs.

    The tokens are checked against the expected DistilBERT word pieces.
    """
    explainer = SequenceClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
    attributions = explainer._run("I love you, I just love you")
    assert isinstance(attributions, list)

    expected = "[CLS] i love you , i just love you [SEP]".split()
    assert [tok for tok, _score in attributions] == expected
def test_sequence_classification_no_text_given():
    """Check the tokens ``_run`` attributes for a sentence given at call time.

    The explainer is built without any text (model and tokenizer only). The
    sentence goes to ``_run``, which returns a list of (token, attribution)
    pairs. The old version used a text-first constructor and expected an
    ``LIGAttributions`` result, which the rest of this module no longer uses.
    """
    explainer_string = "I love you , I hate you"
    seq_explainer = SequenceClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
    word_attributions = seq_explainer._run(explainer_string)
    assert isinstance(word_attributions, list)
    actual_tokens = [token for token, _ in word_attributions]
    expected_tokens = [
        "[CLS]",
        "i",
        "love",
        "you",
        ",",
        "i",
        "hate",
        "you",
        "[SEP]",
    ]
    assert actual_tokens == expected_tokens
def test_sequence_classification_run_text_given_bert():
    """BERT variant: ``_run`` returns token attributions, and ``position_ids`` are accepted.

    The explainer is built from ``(model, tokenizer)`` and the text goes to
    ``_run``, as in the DistilBERT tests.
    """
    seq_explainer = SequenceClassificationExplainer(BERT_MODEL, BERT_TOKENIZER)
    word_attributions = seq_explainer._run("I love you, I just love you")
    assert isinstance(word_attributions, list)
    # The BERT model used here is expected to take position_ids in its forward pass.
    assert seq_explainer.accepts_position_ids is True
    actual_tokens = [token for token, _ in word_attributions]
    expected_tokens = [
        "[CLS]",
        "i",
        "love",
        "you",
        ",",
        "i",
        "just",
        "love",
        "you",
        "[SEP]",
    ]
    assert actual_tokens == expected_tokens