Example #1
    def test_explain_regression(self, sentence, num_samples, positive_token,
                                negative_token):
        """Tests explaining text classifiers with various output dimensions."""
        def _predict_fn(sentences):
            """Mock prediction function."""
            rs = np.random.RandomState(seed=0)
            predictions = []
            for sentence in sentences:
                output = rs.uniform(-2., 2.)
                # To check if LIME finds the right positive/negative correlations.
                if negative_token in sentence:
                    output -= rs.uniform(0., 2.)
                if positive_token in sentence:
                    output += rs.uniform(0., 2.)
                predictions.append(output)

            predictions = np.stack(predictions, axis=0)
            return predictions

        explanation = lime.explain(sentence,
                                   _predict_fn,
                                   num_samples=num_samples,
                                   tokenizer=str.split)

        self.assertLen(explanation.feature_importance, len(sentence.split()))

        # The positive word should have the highest attribution score.
        positive_token_idx = sentence.split().index(positive_token)
        self.assertEqual(positive_token_idx,
                         np.argmax(explanation.feature_importance))

        # The negative word should have the lowest attribution score.
        negative_token_idx = sentence.split().index(negative_token)
        self.assertEqual(negative_token_idx,
                         np.argmin(explanation.feature_importance))
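
For reference, here is a minimal, self-contained sketch of the call pattern this test exercises, runnable outside the test class. The import path for the citrus lime module is an assumption; adjust it to wherever the package lives in your tree.

    import numpy as np
    from citrus import lime  # assumed import path

    def predict_fn(sentences):
        # Toy regressor: the score is the token count, standing in for a model.
        return np.array([len(s.split()) for s in sentences], dtype=np.float32)

    explanation = lime.explain('a simple test sentence',
                               predict_fn,
                               num_samples=256,
                               tokenizer=str.split)
    print(explanation.feature_importance)  # one attribution score per token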
Example #2
    def test_explain_returns_explanation_with_intercept(self):
        """Tests if the explanation contains an intercept value."""
        def _predict_fn(sentences):
            return np.random.uniform(0., 1., [len(list(sentences)), 2])

        explanation = lime.explain('Test sentence',
                                   _predict_fn,
                                   class_to_explain=1,
                                   num_samples=5)
        self.assertNotEqual(explanation.intercept, 0.)
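
The intercept asserted above is the bias term of the weighted linear surrogate that LIME fits to the perturbed samples. A hedged sketch of that fit, assuming standard LIME internals (citrus may use a different solver or regularization):

    import numpy as np
    from sklearn.linear_model import Ridge

    def fit_surrogate(masks, outputs, weights):
        """Fits the local linear model LIME reads its scores from.

        masks: <int>[num_samples, seq_len], 1 where a token was kept.
        outputs: <float>[num_samples], model outputs on the perturbed inputs.
        weights: <float>[num_samples], kernel weight per perturbed sample.
        """
        model = Ridge(alpha=1.0)
        model.fit(masks, outputs, sample_weight=weights)
        # The coefficients become feature_importance; the bias is the intercept.
        return model.coef_, model.intercept_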
Example #3
    def test_explain_returns_explanation_with_prediction(self):
        """Tests if the explanation contains a prediction."""
        def _predict_fn(sentences):
            return np.random.uniform(0., 1., [len(list(sentences)), 2])

        explanation = lime.explain('Test sentence',
                                   _predict_fn,
                                   class_to_explain=1,
                                   num_samples=5,
                                   return_prediction=True)
        self.assertIsNotNone(explanation.prediction)
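
Taken together with Example #2, this shows the optional fields on the returned explanation: the surrogate's intercept is always populated, while prediction is only filled in when return_prediction=True is passed.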
Example #4
    def test_explain(self, sentence, num_samples, positive_token,
                     negative_token, num_classes, class_to_explain):
        """Tests explaining text classifiers with various output dimensions."""
        def _predict_fn(sentences):
            """Mock prediction function."""
            rs = np.random.RandomState(seed=0)
            predictions = []
            for sentence in sentences:
                probs = rs.uniform(0., 1., num_classes)
                # To check if LIME finds the right positive/negative correlations.
                if negative_token in sentence:
                    probs[class_to_explain] -= 1.
                if positive_token in sentence:
                    probs[class_to_explain] += 1.
                predictions.append(probs)

            predictions = np.stack(predictions, axis=0)
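            # A single-output model yields sigmoid scores of shape [batch];
            # a multi-class model yields a softmax distribution per sentence.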
            if num_classes == 1:
                return np.squeeze(special.expit(predictions), -1)
            else:
                return special.softmax(predictions, axis=-1)

        explanation = lime.explain(sentence,
                                   _predict_fn,
                                   class_to_explain=class_to_explain,
                                   num_samples=num_samples,
                                   tokenizer=str.split)

        self.assertLen(explanation.feature_importance, len(sentence.split()))

        # The positive word should have the highest attribution score.
        positive_token_idx = sentence.split().index(positive_token)
        self.assertEqual(positive_token_idx,
                         np.argmax(explanation.feature_importance))

        # The negative word should have the lowest attribution score.
        negative_token_idx = sentence.split().index(negative_token)
        self.assertEqual(negative_token_idx,
                         np.argmin(explanation.feature_importance))
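
All of these tests rely on LIME's perturbation step: each sample keeps a random subset of tokens and replaces the rest with a mask token before calling the model. A minimal sketch of that step under the mask-token mechanics implied by lime.DEFAULT_MASK_TOKEN in the next example (the actual sampling scheme in citrus may differ):

    import numpy as np

    def perturb(sentence, num_samples, mask_token='<unk>', seed=0):
        """Returns binary keep-masks and the corresponding masked sentences."""
        tokens = sentence.split()
        rs = np.random.RandomState(seed)
        masks = rs.randint(0, 2, size=(num_samples, len(tokens))).astype(bool)
        sentences = [' '.join(tok if keep else mask_token
                              for tok, keep in zip(tokens, mask))
                     for mask in masks]
        return masks, sentences

Note that '<unk>' is a placeholder default here; prefer lime.DEFAULT_MASK_TOKEN.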
Example #5
  def test_explain_matches_original_lime(self, sentence, num_samples,
                                         num_classes, class_to_explain):
    """Tests if Citrus LIME matches the original implementation."""

    # Assign a weight to each token in the sentence.
    # Each token contributes positively/negatively to the prediction.
    rs = np.random.RandomState(seed=0)
    token_weights = {token: rs.normal() for token in sentence.split()}
    token_weights[lime.DEFAULT_MASK_TOKEN] = 0.

    def _predict_fn(sentences):
      """Mock prediction function."""
      rs = np.random.RandomState(seed=0)
      predictions = []
      for sentence in sentences:
        probs = rs.normal(0., 0.1, size=num_classes)
        # To check if LIME finds the right positive/negative correlations.
        for token in sentence.split():
          probs[class_to_explain] += token_weights[token]
        predictions.append(probs)
      return np.stack(predictions, axis=0)

    # Explain the prediction using Citrus LIME.
    explanation = lime.explain(
        sentence,
        _predict_fn,
        class_to_explain=class_to_explain,
        num_samples=num_samples,
        tokenizer=str.split,
        mask_token=lime.DEFAULT_MASK_TOKEN,
        kernel=functools.partial(
            lime.exponential_kernel, kernel_width=lime.DEFAULT_KERNEL_WIDTH))
    scores = explanation.feature_importance  # <float32>[seq_len]
    scores = utils.normalize_scores(scores, make_positive=False)

    # Explain the prediction using original LIME.
    original_lime_explainer = lime_text.LimeTextExplainer(
        class_names=list(map(str, np.arange(num_classes))),
        mask_string=lime.DEFAULT_MASK_TOKEN,
        kernel_width=lime.DEFAULT_KERNEL_WIDTH,
        split_expression=str.split,
        bow=False)
    num_features = len(sentence.split())
    original_explanation = original_lime_explainer.explain_instance(
        sentence,
        _predict_fn,
        labels=(class_to_explain,),
        num_features=num_features,
        num_samples=num_samples)

    # original_explanation.local_exp is a dict that has a key class_to_explain,
    # which gives a sequence of (index, score) pairs.
    # We convert it to an array <float32>[seq_len] with a score per position.
    original_scores = np.zeros(num_features)
    for index, score in original_explanation.local_exp[class_to_explain]:
      original_scores[index] = score
    original_scores = utils.normalize_scores(
        original_scores, make_positive=False)

    # Test that Citrus LIME and original LIME match.
    np.testing.assert_allclose(scores, original_scores, atol=0.01)
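
For reference, the kernel passed to both implementations above maps each sample's distance from the original sentence to a regression weight. A sketch of the classic LIME exponential kernel; the width of 25 matches the original lime package default, and it is an assumption that lime.DEFAULT_KERNEL_WIDTH has the same value:

    import numpy as np

    def exponential_kernel(distance, kernel_width=25.):
        # weight = sqrt(exp(-d^2 / width^2)): nearby samples dominate the fit.
        return np.sqrt(np.exp(-distance**2 / kernel_width**2))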