def test_explain_regression(self, sentence, num_samples, positive_token,
                            negative_token):
    """Tests explaining a mock regression model with scalar outputs.

    The mock model returns one score per input. The score is lowered when
    `negative_token` occurs in the input and raised when `positive_token`
    occurs in it. The explanation must have one importance value per
    whitespace-separated token of `sentence`, with the positive token scoring
    highest and the negative token scoring lowest.

    Args:
        sentence: The sentence to explain.
        num_samples: Forwarded to `lime.explain` as `num_samples`.
        positive_token: Token that increases the mock model's output.
        negative_token: Token that decreases the mock model's output.
    """

    def _predict_fn(texts):
        """Mock regressor: returns a 1-D array with one score per text."""
        # Re-seeded on every call, so the mock is fully deterministic.
        rng = np.random.RandomState(seed=0)
        outputs = []
        for text in texts:
            value = rng.uniform(-2., 2.)
            # To check if LIME finds the right positive/negative correlations.
            # The negative shift is drawn before the positive one; the order
            # matters because both come from the same seeded stream.
            if negative_token in text:
                value -= rng.uniform(0., 2.)
            if positive_token in text:
                value += rng.uniform(0., 2.)
            outputs.append(value)
        return np.stack(outputs, axis=0)

    explanation = lime.explain(
        sentence,
        _predict_fn,
        num_samples=num_samples,
        tokenizer=str.split)
    tokens = sentence.split()
    importance = explanation.feature_importance

    # One attribution score per token.
    self.assertLen(importance, len(tokens))

    # The positive word should have the highest attribution score.
    self.assertEqual(tokens.index(positive_token), np.argmax(importance))

    # The negative word should have the lowest attribution score.
    self.assertEqual(tokens.index(negative_token), np.argmin(importance))
def test_explain_returns_explanation_with_intercept(self):
    """Tests if the explanation contains a non-zero intercept value."""

    def _predict_fn(sentences):
        """Mock prediction function returning two scores per sentence."""
        # Use a locally seeded generator rather than the global np.random
        # state, so the test is reproducible from run to run.
        rs = np.random.RandomState(seed=0)
        return rs.uniform(0., 1., [len(list(sentences)), 2])

    # NOTE(review): the third positional argument (1) is presumably the class
    # to explain -- confirm against the signature of lime.explain.
    explanation = lime.explain('Test sentence', _predict_fn, 1, num_samples=5)
    self.assertNotEqual(explanation.intercept, 0.)
def test_explain_returns_explanation_with_prediction(self):
    """Tests if the explanation contains a prediction when one is requested."""

    def _predict_fn(sentences):
        """Mock prediction function returning two scores per sentence."""
        # Use a locally seeded generator rather than the global np.random
        # state, so the test is reproducible from run to run.
        rs = np.random.RandomState(seed=0)
        return rs.uniform(0., 1., [len(list(sentences)), 2])

    explanation = lime.explain(
        'Test sentence',
        _predict_fn,
        class_to_explain=1,
        num_samples=5,
        return_prediction=True)
    self.assertIsNotNone(explanation.prediction)
def test_explain(self, sentence, num_samples, positive_token, negative_token,
                 num_classes, class_to_explain):
    """Tests explaining text classifiers with various output dimensions.

    The mock classifier lowers the score of `class_to_explain` by 1 when
    `negative_token` occurs in the input and raises it by 1 when
    `positive_token` occurs in it. The explanation must have one importance
    value per whitespace-separated token of `sentence`, with the positive
    token scoring highest and the negative token scoring lowest.

    Args:
        sentence: The sentence to explain.
        num_samples: Forwarded to `lime.explain` as `num_samples`.
        positive_token: Token that raises the score of the explained class.
        negative_token: Token that lowers the score of the explained class.
        num_classes: Number of scores the mock classifier draws per input.
        class_to_explain: Index of the class whose score is explained.
    """

    def _predict_fn(texts):
        """Mock classifier: sigmoid output for one class, softmax otherwise."""
        # Re-seeded on every call, so the mock is fully deterministic.
        rng = np.random.RandomState(seed=0)
        rows = []
        for text in texts:
            raw_scores = rng.uniform(0., 1., num_classes)
            # To check if LIME finds the right positive/negative correlations.
            if negative_token in text:
                raw_scores[class_to_explain] -= 1.
            if positive_token in text:
                raw_scores[class_to_explain] += 1.
            rows.append(raw_scores)
        stacked = np.stack(rows, axis=0)

        if num_classes != 1:
            return special.softmax(stacked, axis=-1)
        # Single output: squash with a sigmoid and drop the trailing axis.
        return np.squeeze(special.expit(stacked), -1)

    explanation = lime.explain(
        sentence,
        _predict_fn,
        class_to_explain=class_to_explain,
        num_samples=num_samples,
        tokenizer=str.split)
    tokens = sentence.split()
    importance = explanation.feature_importance

    # One attribution score per token.
    self.assertLen(importance, len(tokens))

    # The positive word should have the highest attribution score.
    self.assertEqual(tokens.index(positive_token), np.argmax(importance))

    # The negative word should have the lowest attribution score.
    self.assertEqual(tokens.index(negative_token), np.argmin(importance))
def test_explain_matches_original_lime(self, sentence, num_samples,
                                       num_classes, class_to_explain):
    """Tests if Citrus LIME matches the original LIME implementation.

    Both explainers are run on the same sentence with the same mock model,
    mask token and kernel width. Their normalized per-token scores must agree
    to within an absolute tolerance of 0.01.

    Args:
        sentence: The sentence to explain.
        num_samples: Number of samples passed to both explainers.
        num_classes: Number of scores the mock model returns per input.
        class_to_explain: Index of the class whose score is explained.
    """
    tokens = sentence.split()

    # Give every token of the sentence a fixed random weight, so that each
    # token contributes positively/negatively to the prediction. The mask
    # token contributes nothing.
    weight_rng = np.random.RandomState(seed=0)
    token_weights = {token: weight_rng.normal() for token in tokens}
    token_weights[lime.DEFAULT_MASK_TOKEN] = 0.

    def _predict_fn(texts):
        """Mock model: small noise plus summed token weights on one class."""
        # Re-seeded on every call, so the mock is fully deterministic.
        noise_rng = np.random.RandomState(seed=0)
        rows = []
        for text in texts:
            row = noise_rng.normal(0., 0.1, size=num_classes)
            # To check if LIME finds the right positive/negative correlations.
            for token in text.split():
                row[class_to_explain] += token_weights[token]
            rows.append(row)
        return np.stack(rows, axis=0)

    # Explain the prediction using Citrus LIME.
    kernel = functools.partial(
        lime.exponential_kernel, kernel_width=lime.DEFAULT_KERNEL_WIDTH)
    explanation = lime.explain(
        sentence,
        _predict_fn,
        class_to_explain=class_to_explain,
        num_samples=num_samples,
        tokenizer=str.split,
        mask_token=lime.DEFAULT_MASK_TOKEN,
        kernel=kernel)
    scores = explanation.feature_importance  # <float32>[seq_len]
    scores = utils.normalize_scores(scores, make_positive=False)

    # Explain the prediction using original LIME.
    original_lime_explainer = lime_text.LimeTextExplainer(
        class_names=map(str, np.arange(num_classes)),
        mask_string=lime.DEFAULT_MASK_TOKEN,
        kernel_width=lime.DEFAULT_KERNEL_WIDTH,
        split_expression=str.split,
        bow=False)
    num_features = len(tokens)
    original_explanation = original_lime_explainer.explain_instance(
        sentence,
        _predict_fn,
        labels=(class_to_explain,),
        num_features=num_features,
        num_samples=num_samples)

    # original_explanation.local_exp is a dict keyed by class index; the entry
    # for class_to_explain is a sequence of (index, score) pairs. Scatter them
    # into an array <float32>[seq_len] with a score per position.
    original_scores = np.zeros(num_features)
    for position, value in original_explanation.local_exp[class_to_explain]:
        original_scores[position] = value
    original_scores = utils.normalize_scores(
        original_scores, make_positive=False)

    # Test that Citrus LIME and original LIME match.
    np.testing.assert_allclose(scores, original_scores, atol=0.01)