def test_token_classification_predicted_class_names():
    """Predicted class names for a simple sentence match the expected NER tags.

    Eight tags are expected (one per token, presumably including the [CLS] and
    [SEP] special tokens). Only the fourth position ("Paris") is tagged B-LOC.
    """
    ner_explainer = TokenClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
    ner_explainer._run("We visited Paris during the weekend")

    expected_tags = ["O", "O", "O", "B-LOC", "O", "O", "O", "O"]
    # A single list comparison checks both the length and every position.
    assert list(ner_explainer.predicted_class_names) == expected_tags
def test_token_classification_predicted_class_names_no_id2label_defaults_idx():
    """Without a usable id2label mapping, predicted class names are raw indices.

    The explainer's ``id2label`` is overwritten with a mapping that has no
    integer keys. The test then expects every predicted class name to be an
    integer index in ``range(9)``, per the test name ("defaults idx").
    """
    ner_explainer = TokenClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
    # Mapping with no integer class ids, so no label lookup can succeed.
    ner_explainer.id2label = {"test": "value"}
    ner_explainer._run("We visited Paris during the weekend")

    predicted = ner_explainer.predicted_class_names
    # 8 = number of tokens for this sentence (presumably incl. [CLS]/[SEP]).
    assert len(predicted) == 8
    # NOTE(review): 9 presumably equals the model's number of labels; confirm.
    valid_indices = range(9)
    for name in predicted:
        assert name in valid_indices
def test_token_classification_run_text_given():
    """``_run`` returns a dict whose keys are the input's tokens in order.

    Each word of this sentence is expected to appear as a single token,
    bracketed by the ``[CLS]`` and ``[SEP]`` special tokens.
    """
    text = "We visited Paris during the weekend"
    explainer = TokenClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
    attributions = explainer._run(text)

    assert isinstance(attributions, dict)
    expected = ["[CLS]", *text.split(), "[SEP]"]
    # Iterating a dict yields its keys in insertion order.
    assert list(attributions) == expected