def test_explain_returns_explanation_with_intercept(self):
    """Tests that the explanation carries a non-zero intercept value.

    The mock classifier returns uniformly random two-column scores, so an
    intercept of exactly 0. is not expected.
    """

    def _predict_fn(sentences):
        # One random row of 2 scores per input. list() makes this work for
        # one-shot iterables (e.g. generators) as well as sequences.
        return np.random.uniform(0., 1., [len(list(sentences)), 2])

    # Pass class_to_explain by keyword (as the prediction test does) so the
    # bare `1` is self-describing rather than a magic positional argument.
    explanation = limsse.explain(
        'Test sentence', _predict_fn, class_to_explain=1, num_samples=5)
    self.assertNotEqual(explanation.intercept, 0.)
def test_explain_returns_explanation_with_prediction(self):
    """Checks that requesting the prediction yields a non-None value."""

    def _predict_fn(sentences):
        # Random two-column scores, one row per incoming sentence.
        num_rows = len(list(sentences))
        return np.random.uniform(0., 1., (num_rows, 2))

    explanation = limsse.explain(
        'Test sentence',
        _predict_fn,
        class_to_explain=1,
        num_samples=5,
        return_prediction=True,
    )
    prediction = explanation.prediction
    self.assertIsNotNone(prediction)
def test_explain(self, sentence, num_samples, positive_token, negative_token,
                 ngram_min_length, ngram_max_length, num_classes,
                 class_to_explain):
    """Tests that LIMSSE attributes the planted positive/negative tokens.

    A mock classifier draws uniform noise for each of `num_classes` outputs.
    It then adds +1 to entry `class_to_explain` when `positive_token` occurs
    in an input, and -1 when `negative_token` occurs. The scores are passed
    through a sigmoid and squeezed to a scalar per input when
    `num_classes == 1`; otherwise a softmax is applied over the classes.
    LIMSSE should then give the positive token the highest importance and
    the negative token the lowest.

    Args:
      sentence: Whitespace-tokenized sentence to explain.
      num_samples: Number of perturbation samples passed to LIMSSE.
      positive_token: Token whose presence raises the explained class score.
      negative_token: Token whose presence lowers the explained class score.
      ngram_min_length: Minimum n-gram length passed to LIMSSE.
      ngram_max_length: Maximum n-gram length passed to LIMSSE.
      num_classes: Number of classifier outputs.
      class_to_explain: Output index being explained.
    """
    tokens = sentence.split()

    def _predict_fn(perturbed_sentences):
        """Mock prediction function with planted token correlations."""
        scores = []
        for perturbed in perturbed_sentences:
            # NOTE(review): unseeded noise; very small num_samples may make
            # this test flaky.
            row = np.random.uniform(0., 1., num_classes)
            # NOTE(review): `in` is a substring test if `perturbed` is a
            # string, so a token embedded in a longer word also matches —
            # confirm the test sentences avoid this.
            if negative_token in perturbed:
                row[class_to_explain] -= 1.
            if positive_token in perturbed:
                row[class_to_explain] += 1.
            scores.append(row)
        scores = np.stack(scores, axis=0)
        if num_classes == 1:
            # Scalar output: one sigmoid probability per input.
            return np.squeeze(special.expit(scores), -1)
        return special.softmax(scores, axis=-1)

    explanation = limsse.explain(
        sentence,
        _predict_fn,
        class_to_explain,
        ngram_min_length=ngram_min_length,
        ngram_max_length=ngram_max_length,
        num_samples=num_samples,
        tokenizer=str.split)

    self.assertLen(explanation.feature_importance, len(tokens))

    # The positive word should have the highest attribution score.
    positive_token_idx = tokens.index(positive_token)
    self.assertEqual(positive_token_idx,
                     np.argmax(explanation.feature_importance))

    # The negative word should have the lowest attribution score.
    negative_token_idx = tokens.index(negative_token)
    self.assertEqual(negative_token_idx,
                     np.argmin(explanation.feature_importance))