Example #1
0
    def test_explain_returns_explanation_with_intercept(self):
        """Checks that the returned explanation carries a nonzero intercept."""
        def _random_classifier(sentences):
            # Two-class random probabilities, one row per input sentence.
            batch_size = len(list(sentences))
            return np.random.uniform(0., 1., [batch_size, 2])

        explanation = limsse.explain('Test sentence',
                                     _random_classifier,
                                     1,
                                     num_samples=5)
        self.assertNotEqual(explanation.intercept, 0.)
Example #2
0
    def test_explain_returns_explanation_with_prediction(self):
        """Checks that a prediction is attached when return_prediction is set."""
        def _random_classifier(sentences):
            # Two-class random probabilities, one row per input sentence.
            batch_size = len(list(sentences))
            return np.random.uniform(0., 1., [batch_size, 2])

        explanation = limsse.explain('Test sentence',
                                     _random_classifier,
                                     class_to_explain=1,
                                     num_samples=5,
                                     return_prediction=True)
        self.assertIsNotNone(explanation.prediction)
Example #3
0
    def test_explain(self, sentence, num_samples, positive_token,
                     negative_token, ngram_min_length, ngram_max_length,
                     num_classes, class_to_explain):
        """Tests explaining a binary classifier with scalar output."""
        def _predict_fn(sentences):
            """Mock prediction function."""
            rows = []
            # `sample` (not `sentence`) avoids shadowing the outer parameter.
            for sample in sentences:
                logits = np.random.uniform(0., 1., num_classes)
                # Shift the explained class so LIMSSE can recover the
                # positive/negative token correlations.
                if negative_token in sample:
                    logits[class_to_explain] -= 1.
                if positive_token in sample:
                    logits[class_to_explain] += 1.
                rows.append(logits)

            stacked = np.stack(rows, axis=0)
            if num_classes == 1:
                return np.squeeze(special.expit(stacked), -1)
            return special.softmax(stacked, axis=-1)

        explanation = limsse.explain(sentence,
                                     _predict_fn,
                                     class_to_explain,
                                     ngram_min_length=ngram_min_length,
                                     ngram_max_length=ngram_max_length,
                                     num_samples=num_samples,
                                     tokenizer=str.split)

        tokens = sentence.split()
        self.assertLen(explanation.feature_importance, len(tokens))

        # The positive word should have the highest attribution score.
        self.assertEqual(tokens.index(positive_token),
                         np.argmax(explanation.feature_importance))

        # The negative word should have the lowest attribution score.
        self.assertEqual(tokens.index(negative_token),
                         np.argmin(explanation.feature_importance))