def test_sequence_classification_explain_callable():
    """Calling the explainer directly must predict the same class as ``run()``.

    The explainer is exercised twice on the same instance: first through
    ``run()``, then through ``__call__``. The ``predicted_class_index`` recorded
    after each invocation must match.
    """
    text = "I love you , I like you"
    explainer = SequenceClassificationExplainer(text, MODEL, TOKENIZER)

    explainer.run()
    index_from_run = explainer.predicted_class_index

    explainer()
    index_from_call = explainer.predicted_class_index

    assert index_from_call == index_from_run
def test_sequence_classification_explain_on_cls_name():
    """Explaining a non-predicted class by name must not alter the prediction.

    ``run(class_name="NEGATIVE")`` requests attributions for "NEGATIVE". The
    asserts check three things: the model's prediction is still index 1
    ("POSITIVE"), the index chosen for attribution differs from the predicted
    index, and the label of the selected index differs from the predicted label.
    """
    explainer_string = "I love you , I like you"
    seq_explainer = SequenceClassificationExplainer(explainer_string, MODEL, TOKENIZER)
    # Only the side effects on the explainer's state are inspected below, so
    # the returned attributions are intentionally not bound to a name.
    seq_explainer.run(class_name="NEGATIVE")

    assert seq_explainer.predicted_class_index == 1
    assert seq_explainer.predicted_class_index != seq_explainer.selected_index
    assert (
        seq_explainer.predicted_class_name
        != seq_explainer.id2label[seq_explainer.selected_index]
    )
    assert seq_explainer.predicted_class_name != "NEGATIVE"
    assert seq_explainer.predicted_class_name == "POSITIVE"
def test_sequence_classification_no_text_given():
    """``run()`` with no arguments returns word attributions for the constructor text.

    The result must be an ``LIGAttributions`` instance. Its word tokens must be
    the whitespace-separated words of the input, wrapped in the "BOS_TOKEN" and
    "EOS_TOKEN" markers.
    """
    sentence = "I love you , I hate you"
    explainer = SequenceClassificationExplainer(sentence, MODEL, TOKENIZER)

    result = explainer.run()
    assert isinstance(result, LIGAttributions)

    # Each word_attributions entry is a (token, score) pair; only tokens are compared.
    observed = [word for word, _score in result.word_attributions]
    # Equivalent to the literal list
    # ["BOS_TOKEN", "I", "love", "you", ",", "I", "hate", "you", "EOS_TOKEN"].
    expected = ["BOS_TOKEN", *sentence.split(), "EOS_TOKEN"]
    assert observed == expected
def test_sequence_classification_explain_on_cls_name_not_in_dict():
    """Check the selected index when ``class_name`` is not a known label.

    After ``run(class_name="UNKNOWN")``, both ``selected_index`` and
    ``predicted_class_index`` are expected to be 1.
    """
    explainer_string = "I love you , I like you"
    seq_explainer = SequenceClassificationExplainer(explainer_string, MODEL, TOKENIZER)
    # Only the explainer's resulting state is checked; the returned
    # attributions are not needed.
    seq_explainer.run(class_name="UNKNOWN")

    assert seq_explainer.selected_index == 1
    assert seq_explainer.predicted_class_index == 1