Example #1
File: lime_test.py  Project: oceanfly/lit
    def test_explain_regression(self, sentence, num_samples, positive_token,
                                negative_token):
        """Tests explaining text classifiers with various output dimensions."""
        def _predict_fn(sentences):
            """Mock prediction function."""
            rs = np.random.RandomState(seed=0)
            predictions = []
            for sentence in sentences:
                output = rs.uniform(-2., 2.)
                # To check if LIME finds the right positive/negative correlations.
                if negative_token in sentence:
                    output -= rs.uniform(0., 2.)
                if positive_token in sentence:
                    output += rs.uniform(0., 2.)
                predictions.append(output)

            predictions = np.stack(predictions, axis=0)
            return predictions

        explanation = lime.explain(sentence,
                                   _predict_fn,
                                   num_samples=num_samples,
                                   tokenizer=str.split)

        self.assertLen(explanation.feature_importance, len(sentence.split()))

        # The positive word should have the highest attribution score.
        positive_token_idx = sentence.split().index(positive_token)
        self.assertEqual(positive_token_idx,
                         np.argmax(explanation.feature_importance))

        # The negative word should have the lowest attribution score.
        negative_token_idx = sentence.split().index(negative_token)
        self.assertEqual(negative_token_idx,
                         np.argmin(explanation.feature_importance))
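
The sentence, num_samples, positive_token, and negative_token arguments above come from a parameterized-test decorator that the snippet does not show. A minimal sketch of how such cases could be supplied with absl's parameterized test support (the test-case values below are illustrative, not taken from the project):

from absl.testing import parameterized


class LimeRegressionTest(parameterized.TestCase):

    @parameterized.named_parameters(
        # Hypothetical case: the tokens and sample count are made up for illustration.
        dict(testcase_name='finds_correlations',
             sentence='it was a great film with a terrible soundtrack',
             num_samples=1000,
             positive_token='great',
             negative_token='terrible'),
    )
    def test_explain_regression(self, sentence, num_samples, positive_token,
                                negative_token):
        ...  # body as shown in Example #1 above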
Example #2
File: lime_test.py  Project: oceanfly/lit
    def test_explain_returns_explanation_with_intercept(self):
        """Tests if the explanation contains an intercept value."""
        def _predict_fn(sentences):
            return np.random.uniform(0., 1., [len(list(sentences)), 2])

        explanation = lime.explain('Test sentence',
                                   _predict_fn,
                                   class_to_explain=1,
                                   num_samples=5)
        self.assertNotEqual(explanation.intercept, 0.)
Example #3
File: lime_test.py  Project: oceanfly/lit
    def test_explain_returns_explanation_with_prediction(self):
        """Tests if the explanation contains a prediction."""
        def _predict_fn(sentences):
            return np.random.uniform(0., 1., [len(list(sentences)), 2])

        explanation = lime.explain('Test sentence',
                                   _predict_fn,
                                   class_to_explain=1,
                                   num_samples=5,
                                   return_prediction=True)
        self.assertIsNotNone(explanation.prediction)
Example #4
File: lime_test.py  Project: oceanfly/lit
    def test_explain(self, sentence, num_samples, positive_token,
                     negative_token, num_classes, class_to_explain):
        """Tests explaining text classifiers with various output dimensions."""
        def _predict_fn(sentences):
            """Mock prediction function."""
            rs = np.random.RandomState(seed=0)
            predictions = []
            for sentence in sentences:
                probs = rs.uniform(0., 1., num_classes)
                # To check if LIME finds the right positive/negative correlations.
                if negative_token in sentence:
                    probs[class_to_explain] = probs[class_to_explain] - 1.
                if positive_token in sentence:
                    probs[class_to_explain] = probs[class_to_explain] + 1.
                predictions.append(probs)

            predictions = np.stack(predictions, axis=0)
            if num_classes == 1:
                return np.squeeze(special.expit(predictions), -1)
            else:
                return special.softmax(predictions, axis=-1)

        explanation = lime.explain(sentence,
                                   _predict_fn,
                                   class_to_explain=class_to_explain,
                                   num_samples=num_samples,
                                   tokenizer=str.split)

        self.assertLen(explanation.feature_importance, len(sentence.split()))

        # The positive word should have the highest attribution score.
        positive_token_idx = sentence.split().index(positive_token)
        self.assertEqual(positive_token_idx,
                         np.argmax(explanation.feature_importance))

        # The negative word should have the lowest attribution score.
        negative_token_idx = sentence.split().index(negative_token)
        self.assertEqual(negative_token_idx,
                         np.argmin(explanation.feature_importance))
Example #5
    def run(
        self,
        inputs: List[JsonDict],
        model: lit_model.Model,
        dataset: lit_dataset.Dataset,
        model_outputs: Optional[List[JsonDict]] = None,
        config: Optional[JsonDict] = None,
    ) -> Optional[List[JsonDict]]:
        """Run this component, given a model and input(s)."""

        class_to_explain = int(config[CLASS_KEY]) if config else CLASS_DEFAULT
        kernel_width = int(
            config[KERNEL_WIDTH_KEY]) if config else KERNEL_WIDTH_DEFAULT
        mask_string = config[MASK_KEY] if config else MASK_DEFAULT
        num_samples = int(
            config[NUM_SAMPLES_KEY]) if config else NUM_SAMPLES_DEFAULT
        seed = config[SEED_KEY] if config else SEED_DEFAULT

        # Find keys of input (text) segments to explain.
        # Search in the input spec, since it's only useful to look at ones that are
        # used by the model.
        text_keys = utils.find_spec_keys(model.input_spec(), types.TextSegment)
        if not text_keys:
            logging.warning('LIME requires text inputs.')
            return None
        logging.info('Found text fields for LIME attribution: %s',
                     str(text_keys))

        # Find the key of output probabilities field(s).
        pred_keys = utils.find_spec_keys(
            model.output_spec(),
            (types.MulticlassPreds, types.RegressionScore))
        if not pred_keys:
            logging.warning('LIME did not find any supported output fields.')
            return None

        pred_key = pred_keys[0]  # TODO(lit-dev): configure which prob field to use.
        all_results = []

        # Explain each input.
        for input_ in inputs:
            # Dict[field name -> interpretations]
            result = {}
            predict_fn = functools.partial(_predict_fn,
                                           model=model,
                                           original_example=input_,
                                           pred_key=pred_key)

            # Explain each text segment in the input, keeping the others constant.
            for text_key in text_keys:
                input_string = input_[text_key]
                if not input_string:
                    logging.info('Could not explain empty string for %s',
                                 text_key)
                    continue
                logging.info('Explaining: %s', input_string)

                # Perturbs the input string, gets model predictions, fits linear model.
                explanation = lime.explain(
                    sentence=input_string,
                    predict_fn=functools.partial(predict_fn,
                                                 text_key=text_key),
                    # `class_to_explain` is ignored when predict_fn output is a scalar.
                    class_to_explain=class_to_explain,  # Index of the class to explain.
                    num_samples=num_samples,
                    tokenizer=str.split,
                    mask_token=mask_string,
                    kernel=functools.partial(lime.exponential_kernel,
                                             kernel_width=kernel_width),
                    seed=seed)

                # Turn the LIME explanation into a list following original word order.
                scores = explanation.feature_importance
                # TODO(lit-dev): Move score normalization to the UI.
                scores = citrus_util.normalize_scores(scores)
                result[text_key] = dtypes.SalienceMap(input_string.split(),
                                                      scores)

            all_results.append(result)

        return all_results
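
The module-level `_predict_fn` that this run method partially applies (with model, original_example, pred_key, and later text_key) is not shown. A rough sketch of what such a helper might look like, assuming it substitutes each perturbed string into a copy of the original example, batch-predicts, and extracts the field named by pred_key:

import numpy as np


def _predict_fn(strings, model=None, original_example=None, pred_key=None,
                text_key=None):
    """Hypothetical sketch of the prediction wrapper used above."""
    # Substitute each perturbed string into a copy of the original example.
    perturbed = [dict(original_example, **{text_key: s}) for s in strings]
    # Batch-predict and pull out the field LIME should explain.
    outputs = list(model.predict(perturbed))
    return np.stack([np.array(output[pred_key]) for output in outputs])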
Example #6
    def run(
        self,
        inputs: List[JsonDict],
        model: lit_model.Model,
        dataset: lit_dataset.Dataset,
        model_outputs: Optional[List[JsonDict]] = None,
        config: Optional[JsonDict] = None,
    ) -> Optional[List[JsonDict]]:
        """Run this component, given a model and input(s)."""
        config_defaults = {k: v.default for k, v in self.config_spec().items()}
        config = dict(config_defaults, **(config or {}))  # update and return

        provided_class_to_explain = int(config[CLASS_KEY])
        kernel_width = int(config[KERNEL_WIDTH_KEY])
        num_samples = int(config[NUM_SAMPLES_KEY])
        mask_string = config[MASK_KEY]
        # pylint: disable=g-explicit-bool-comparison
        seed = int(config[SEED_KEY]) if config[SEED_KEY] != '' else None
        # pylint: enable=g-explicit-bool-comparison

        # Find keys of input (text) segments to explain.
        # Search in the input spec, since it's only useful to look at ones that are
        # used by the model.
        text_keys = utils.find_spec_keys(model.input_spec(), types.TextSegment)
        if not text_keys:
            logging.warning('LIME requires text inputs.')
            return None
        logging.info('Found text fields for LIME attribution: %s',
                     str(text_keys))

        # Find the key of output probabilities field(s).
        pred_keys = utils.find_spec_keys(
            model.output_spec(), (types.MulticlassPreds, types.RegressionScore,
                                  types.SparseMultilabelPreds))
        if not pred_keys:
            logging.warning('LIME did not find any supported output fields.')
            return None
        pred_key = config[TARGET_HEAD_KEY] or pred_keys[0]

        all_results = []

        # Explain each input.
        for input_ in inputs:
            # Dict[field name -> interpretations]
            result = {}
            predict_fn = functools.partial(
                _predict_fn,
                model=model,
                original_example=input_,
                pred_key=pred_key,
                pred_type_info=model.output_spec()[pred_key])

            class_to_explain = get_class_to_explain(provided_class_to_explain,
                                                    model, pred_key, input_)

            # Explain each text segment in the input, keeping the others constant.
            for text_key in text_keys:
                input_string = input_[text_key]
                if not input_string:
                    logging.info('Could not explain empty string for %s',
                                 text_key)
                    continue
                logging.info('Explaining: %s', input_string)

                # Perturbs the input string, gets model predictions, fits linear model.
                explanation = lime.explain(
                    sentence=input_string,
                    predict_fn=functools.partial(predict_fn,
                                                 text_key=text_key),
                    # `class_to_explain` is ignored when predict_fn output is a scalar.
                    class_to_explain=class_to_explain,
                    num_samples=num_samples,
                    tokenizer=str.split,
                    mask_token=mask_string,
                    kernel=functools.partial(lime.exponential_kernel,
                                             kernel_width=kernel_width),
                    seed=seed)

                # Turn the LIME explanation into a list following original word order.
                scores = explanation.feature_importance
                # TODO(lit-dev): Move score normalization to the UI.
                scores = citrus_util.normalize_scores(scores)
                result[text_key] = dtypes.TokenSalience(
                    input_string.split(), scores)

            all_results.append(result)

        return all_results
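
`get_class_to_explain` is also defined elsewhere in the module. One plausible sketch, assuming the configured index may be a sentinel (for example, a negative value meaning "explain the model's predicted class"); the project's actual resolution logic, especially for SparseMultilabelPreds outputs, may differ:

import numpy as np


def get_class_to_explain(provided_class_to_explain, model, pred_key, example):
    """Hypothetical sketch: resolve a sentinel class index to a concrete one."""
    if provided_class_to_explain >= 0:
        return provided_class_to_explain
    # Assumed fallback: explain the model's top-scoring class for this example.
    output = list(model.predict([example]))[0]
    return int(np.argmax(output[pred_key]))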
Example #7
File: lime_test.py  Project: oceanfly/lit
    def test_explain_matches_original_lime(self, sentence, num_samples,
                                           num_classes, class_to_explain):
        """Tests if Citrus LIME matches the original implementation."""
        # Assign a random weight to each token in the sentence, so that each
        # token contributes positively or negatively to the prediction.
        rs = np.random.RandomState(seed=0)
        token_weights = {token: rs.normal() for token in sentence.split()}
        token_weights[lime.DEFAULT_MASK_TOKEN] = 0.

        def _predict_fn(sentences):
            """Mock prediction function."""
            rs = np.random.RandomState(seed=0)
            predictions = []
            for sentence in sentences:
                probs = rs.normal(0., 0.1, size=num_classes)
                # To check if LIME finds the right positive/negative correlations.
                for token in sentence.split():
                    probs[class_to_explain] += token_weights[token]
                predictions.append(probs)
            return np.stack(predictions, axis=0)

        # Explain the prediction using Citrus LIME.
        explanation = lime.explain(sentence,
                                   _predict_fn,
                                   class_to_explain=class_to_explain,
                                   num_samples=num_samples,
                                   tokenizer=str.split,
                                   mask_token=lime.DEFAULT_MASK_TOKEN,
                                   kernel=functools.partial(
                                       lime.exponential_kernel,
                                       kernel_width=lime.DEFAULT_KERNEL_WIDTH))
        scores = explanation.feature_importance  # <float32>[seq_len]
        scores = utils.normalize_scores(scores, make_positive=False)

        # Explain the prediction using original LIME.
        original_lime_explainer = original_lime.LimeTextExplainer(
            class_names=map(str, np.arange(num_classes)),
            mask_string=lime.DEFAULT_MASK_TOKEN,
            kernel_width=lime.DEFAULT_KERNEL_WIDTH,
            split_expression=str.split,
            bow=False)
        num_features = len(sentence.split())
        original_explanation = original_lime_explainer.explain_instance(
            sentence,
            _predict_fn,
            labels=(class_to_explain, ),
            num_features=num_features,
            num_samples=num_samples)

        # original_explanation.local_exp is a dict that has a key class_to_explain,
        # which gives a sequence of (index, score) pairs.
        # We convert it to an array <float32>[seq_len] with a score per position.
        original_scores = np.zeros(num_features)
        for index, score in original_explanation.local_exp[class_to_explain]:
            original_scores[index] = score
        original_scores = utils.normalize_scores(original_scores,
                                                 make_positive=False)

        # Test that Citrus LIME and original LIME match.
        np.testing.assert_allclose(scores, original_scores, atol=0.01)
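
The test snippets above rely on module-level imports that are not shown. Judging from the names used (np, special, functools, lime, original_lime, utils), they are probably along these lines; the exact module paths and the alias for the reference LIME package are assumptions:

import functools

import numpy as np
from scipy import special

# Reference implementation from the `lime` pip package.
from lime import lime_text as original_lime

# Citrus LIME and its helpers inside LIT; exact paths are assumptions.
from lit_nlp.components.citrus import lime
from lit_nlp.components.citrus import utils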