Example #1
0
 def test_normalize_scores(self, scores, make_positive):
     """Verify that normalized scores have absolute values summing to one.

     When `make_positive` is false, additionally verify that the smallest
     normalized score is no larger than max(0, smallest original score).
     """
     lowest_before = np.min(scores)
     normalized = utils.normalize_scores(scores, make_positive=make_positive)

     # The absolute values along the last axis must add up to 1.
     self.assertAllClose(1.0, np.abs(normalized).sum(-1))

     if make_positive:
         return

     # Originally negative values are expected to keep their sign.
     lowest_after = np.min(normalized)
     upper_bound = np.max([0.0, lowest_before])
     self.assertLessEqual(lowest_after, upper_bound)
Example #2
0
    def run(self,
            inputs: List[JsonDict],
            model: lit_model.Model,
            dataset: lit_dataset.Dataset,
            model_outputs: Optional[List[JsonDict]] = None,
            config: Optional[JsonDict] = None) -> Optional[List[JsonDict]]:
        """Run LEMON attribution for the first input against the other inputs.

        For every text-segment field in the model's input spec, the value of
        that field in `inputs[0]` is explained, and the values of the same
        field in the remaining inputs are used as its counterfactuals. A
        sentence -> prediction lookup built from `model_outputs` is handed to
        `make_predict_fn` to obtain the prediction function for
        `lemon.explain`.

        Args:
          inputs: Examples; index 0 is explained, the rest are counterfactuals.
          model: Model whose input spec determines which text fields to explain.
          dataset: Not used by this method.
          model_outputs: Outputs paired position-by-position with `inputs`.
            Although annotated as optional, it is iterated unconditionally, so
            it must be provided.
          config: Must contain 'pred_key', 'class_to_explain' and
            'lowercase_tokens'. Although annotated as optional, it is indexed
            unconditionally, so it must be provided.

        Returns:
          None if `inputs` is empty or the model has no text-segment inputs;
          otherwise a one-element list holding a dict that maps each text field
          name to a `dtypes.TokenSalience`.
        """
        if not inputs:
            return None

        # Find keys of input (text) segments to explain.
        # Search in the input spec, since it's only useful to look at ones that
        # are used by the model.
        text_keys = utils.find_spec_keys(model.input_spec(), types.TextSegment)
        if not text_keys:
            logging.warning('LEMON requires text inputs.')
            return None
        logging.info('Found text fields for LEMON attribution: %s', text_keys)

        pred_key = config['pred_key']
        output_probs = np.array([output[pred_key] for output in model_outputs])

        # Dict[field name -> interpretations]
        result = {}

        # Explain each text segment in the input, keeping the others constant.
        for text_key in text_keys:
            sentences = [item[text_key] for item in inputs]
            input_to_prediction = dict(zip(sentences, output_probs))

            input_string, *counterfactuals = sentences

            # Remove duplicate counterfactuals while keeping first-seen order.
            # A plain set() would yield an arbitrary order that can change
            # between runs for strings (hash randomization), which would make
            # the input to the explainer non-reproducible.
            counterfactuals = list(dict.fromkeys(counterfactuals))

            logging.info('Explaining: %s', input_string)

            predict_proba = make_predict_fn(input_to_prediction)

            # Perturbs the input string, gets model predictions, fits linear
            # model.
            explanation = lemon.explain(
                input_string,
                counterfactuals,
                predict_proba,
                class_to_explain=config['class_to_explain'],
                lowercase_tokens=config['lowercase_tokens'])

            # Normalize feature values.
            scores = citrus_utils.normalize_scores(
                np.array(explanation.feature_importance))

            result[text_key] = dtypes.TokenSalience(input_string.split(),
                                                    scores)

        return [result]
Example #3
0
    def _interpret(self, grads: np.ndarray, embs: np.ndarray):
        """Turn gradients and embeddings into normalized scores.

        Takes the dot product of each gradient vector with its matching
        embedding vector (the sum over the last axis of their elementwise
        product) and passes the result to `citrus_utils.normalize_scores`.

        Args:
          grads: Gradient array; must have the same shape as `embs`.
          embs: Embedding array.

        Returns:
          The output of `citrus_utils.normalize_scores` for the dot products.
        """
        assert grads.shape == embs.shape

        # Gradient-dot-input, reduced over the last axis.
        # <float32>[num_tokens]
        elementwise = grads * embs
        token_dots = elementwise.sum(axis=-1)
        return citrus_utils.normalize_scores(token_dots)
Example #4
0
    def get_salience_result(self, model_input: JsonDict,
                            model: lit_model.Model, model_output: JsonDict,
                            grad_fields: List[Text]):
        """Compute integrated-gradients token salience for each gradient field.

        For every field in `grad_fields`, the embeddings named by that field's
        `grad_for` are interpolated between a baseline (`self.get_baseline`)
        and the original embeddings in `self.interpolation_steps` steps. The
        model is run on all interpolated inputs; the returned gradients are
        integrated along the path (`self.estimate_integral`), multiplied by
        (embeddings - baseline), summed over the last axis and normalized
        jointly across all fields.

        Args:
          model_input: The original example; shallow-copied for each
            interpolation step.
          model: Model providing `output_spec()` and `predict()`.
          model_output: Output of the model for `model_input`; supplies the
            embeddings, the target class and the tokens.
          grad_fields: Names of the gradient fields in the output spec. Must be
            non-empty (the first one determines the gradient target key).

        Returns:
          Dict mapping each gradient field name to a `dtypes.SalienceMap` built
          from the tokens in the field's `align` output and their scores.
        """
        output_spec = model.output_spec()
        # Resolve each gradient field's spec once; it is needed for the
        # embedding field, the gradient target and the aligned token field.
        grad_specs = [
            cast(types.TokenGradients, output_spec[grad_field])
            for grad_field in grad_fields
        ]

        # We ensure that the embedding and gradient class fields are present in
        # the model's input spec in find_fields().
        embeddings_fields = [spec.grad_for for spec in grad_specs]

        # The gradient class input is used to specify the target class of the
        # gradient calculation (if unspecified, this option defaults to the
        # argmax, which could flip between interpolated inputs).
        grad_class_key = grad_specs[0].grad_target
        # TODO(b/168042999): Add option to specify the class to explain in the UI.
        grad_class = model_output[grad_class_key]

        interpolated_inputs = {}
        all_embeddings = []
        all_baselines = []
        for embed_field in embeddings_fields:
            # <float32>[num_tokens, emb_size]
            embeddings = np.array(model_output[embed_field])
            all_embeddings.append(embeddings)

            # Starts with baseline of zeros. <float32>[num_tokens, emb_size]
            baseline = self.get_baseline(embeddings)
            all_baselines.append(baseline)

            # Get interpolated inputs from baseline to original embedding.
            # <float32>[interpolation_steps, num_tokens, emb_size]
            interpolated_inputs[embed_field] = self.get_interpolated_inputs(
                baseline, embeddings, self.interpolation_steps)

        # Create model inputs and populate embedding field(s).
        inputs_with_embeds = []
        for step in range(self.interpolation_steps):
            input_copy = model_input.copy()
            # Interpolates embeddings for all inputs simultaneously.
            for embed_field in embeddings_fields:
                # <float32>[num_tokens, emb_size]
                input_copy[embed_field] = interpolated_inputs[embed_field][step]
            # The target class is the same for every embedding field, so it is
            # set once per interpolated input rather than once per field.
            input_copy[grad_class_key] = grad_class
            inputs_with_embeds.append(input_copy)
        embed_outputs = model.predict(inputs_with_embeds)

        # Concatenate the gradients of all fields for each interpolated input.
        # Each entry: <float32>[total_num_tokens, emb_size]
        gradients = [
            np.concatenate([output[field] for field in grad_fields])
            for output in embed_outputs
        ]
        # <float32>[interpolation_steps, total_num_tokens, emb_size]
        path_gradients = np.stack(gradients, axis=0)

        # Calculate integral
        # <float32>[total_num_tokens, emb_size]
        integral = self.estimate_integral(path_gradients)

        # np.concatenate already returns fresh arrays, so no extra copy is
        # needed before the subtraction.
        # <float32>[total_num_tokens, emb_size]
        concat_embeddings = np.concatenate(all_embeddings)
        # <float32>[total_num_tokens, emb_size]
        concat_baseline = np.concatenate(all_baselines)

        # <float32>[total_num_tokens, emb_size]
        integrated_gradients = integral * (concat_embeddings - concat_baseline)
        # Dot product of integral values and (embeddings - baseline).
        # <float32>[total_num_tokens]
        attributions = np.sum(integrated_gradients, axis=-1)

        # TODO(b/168042999): Make normalization customizable in the UI.
        # <float32>[total_num_tokens]
        scores = citrus_utils.normalize_scores(attributions)

        result = {}
        # The gradients for all input embeddings were concatenated in the order
        # of the grad fields, so each field's scores are the next
        # len(tokens) entries starting at the running offset.
        offset = 0
        for grad_field, spec in zip(grad_fields, grad_specs):
            # Format as salience map result.
            tokens = model_output[spec.align]

            end = offset + len(tokens)
            sliced_scores = scores[offset:end]  # <float32>[num_tokens in field]
            offset = end

            assert len(tokens) == len(sliced_scores)
            result[grad_field] = dtypes.SalienceMap(tokens, sliced_scores)
        return result
Example #5
0
    def run(
        self,
        inputs: List[JsonDict],
        model: lit_model.Model,
        dataset: lit_dataset.Dataset,
        model_outputs: Optional[List[JsonDict]] = None,
        config: Optional[JsonDict] = None,
    ) -> Optional[List[JsonDict]]:
        """Run LIME on every text segment of every input.

        Args:
          inputs: Examples to explain.
          model: Model whose input spec supplies the text fields and whose
            output spec supplies the prediction field.
          dataset: Not used by this method.
          model_outputs: Not used by this method.
          config: Optional settings keyed by CLASS_KEY, KERNEL_WIDTH_KEY,
            MASK_KEY, NUM_SAMPLES_KEY and SEED_KEY. Any key that is absent
            falls back to its module-level default, so a partial config is
            accepted.

        Returns:
          None if the model has no text-segment inputs or no supported output
          field; otherwise one dict per input mapping each non-empty text
          field name to a `dtypes.SalienceMap`.
        """
        options = config or {}

        def _option(key, default, convert=None):
            """Return options[key], converted if requested, else `default`."""
            if key not in options:
                return default
            value = options[key]
            return value if convert is None else convert(value)

        # Defaults are applied per key rather than only when the whole config
        # is missing, so a config that sets a subset of keys does not raise
        # KeyError.
        class_to_explain = _option(CLASS_KEY, CLASS_DEFAULT, int)
        kernel_width = _option(KERNEL_WIDTH_KEY, KERNEL_WIDTH_DEFAULT, int)
        mask_string = _option(MASK_KEY, MASK_DEFAULT)
        num_samples = _option(NUM_SAMPLES_KEY, NUM_SAMPLES_DEFAULT, int)
        seed = _option(SEED_KEY, SEED_DEFAULT)

        # Find keys of input (text) segments to explain.
        # Search in the input spec, since it's only useful to look at ones that
        # are used by the model.
        text_keys = utils.find_spec_keys(model.input_spec(), types.TextSegment)
        if not text_keys:
            logging.warning('LIME requires text inputs.')
            return None
        logging.info('Found text fields for LIME attribution: %s', text_keys)

        # Find the key of output probabilities field(s).
        pred_keys = utils.find_spec_keys(
            model.output_spec(),
            (types.MulticlassPreds, types.RegressionScore))
        if not pred_keys:
            logging.warning('LIME did not find any supported output fields.')
            return None

        # TODO(lit-dev): configure which prob field to use.
        pred_key = pred_keys[0]
        all_results = []

        # Explain each input.
        for input_ in inputs:
            # Dict[field name -> interpretations]
            result = {}
            predict_fn = functools.partial(_predict_fn,
                                           model=model,
                                           original_example=input_,
                                           pred_key=pred_key)

            # Explain each text segment in the input, keeping the others
            # constant.
            for text_key in text_keys:
                input_string = input_[text_key]
                if not input_string:
                    logging.info('Could not explain empty string for %s',
                                 text_key)
                    continue
                logging.info('Explaining: %s', input_string)

                # Perturbs the input string, gets model predictions, fits
                # linear model.
                explanation = lime.explain(
                    sentence=input_string,
                    predict_fn=functools.partial(predict_fn,
                                                 text_key=text_key),
                    # `class_to_explain` is ignored when predict_fn output is
                    # a scalar.
                    class_to_explain=class_to_explain,
                    num_samples=num_samples,
                    tokenizer=str.split,
                    mask_token=mask_string,
                    kernel=functools.partial(lime.exponential_kernel,
                                             kernel_width=kernel_width),
                    seed=seed)

                # Turn the LIME explanation into a list following original
                # word order.
                scores = explanation.feature_importance
                # TODO(lit-dev): Move score normalization to the UI.
                scores = citrus_util.normalize_scores(scores)
                result[text_key] = dtypes.SalienceMap(input_string.split(),
                                                      scores)

            all_results.append(result)

        return all_results
Example #6
0
    def run(
        self,
        inputs: List[JsonDict],
        model: lit_model.Model,
        dataset: lit_dataset.Dataset,
        model_outputs: Optional[List[JsonDict]] = None,
        config: Optional[JsonDict] = None,
    ) -> Optional[List[JsonDict]]:
        """Explain every text segment of every input with LIME.

        The supplied `config` is layered over the defaults taken from
        `self.config_spec()`. `dataset` and `model_outputs` are not used here.

        Returns:
          None when the model has no text-segment inputs or no supported
          output field; otherwise one dict per input, mapping each non-empty
          text field name to a `dtypes.TokenSalience`.
        """
        # Start from the spec defaults and overlay whatever the caller passed.
        spec_defaults = {
            name: entry.default
            for name, entry in self.config_spec().items()
        }
        config = dict(spec_defaults, **(config or {}))

        provided_class_to_explain = int(config[CLASS_KEY])
        kernel_width = int(config[KERNEL_WIDTH_KEY])
        num_samples = int(config[NUM_SAMPLES_KEY])
        mask_string = config[MASK_KEY]
        # An empty-string seed means "no seed".
        raw_seed = config[SEED_KEY]
        # pylint: disable=g-explicit-bool-comparison
        if raw_seed != '':
            seed = int(raw_seed)
        else:
            seed = None
        # pylint: enable=g-explicit-bool-comparison

        # Only text segments from the input spec are worth explaining, since
        # those are the ones the model actually consumes.
        text_keys = utils.find_spec_keys(model.input_spec(), types.TextSegment)
        if not text_keys:
            logging.warning('LIME requires text inputs.')
            return None
        logging.info('Found text fields for LIME attribution: %s',
                     str(text_keys))

        # Candidate output fields that can be explained.
        supported_types = (types.MulticlassPreds, types.RegressionScore,
                           types.SparseMultilabelPreds)
        pred_keys = utils.find_spec_keys(model.output_spec(), supported_types)
        if not pred_keys:
            logging.warning('LIME did not find any supported output fields.')
            return None
        # Use the configured head if set, else the first supported field.
        pred_key = config[TARGET_HEAD_KEY] or pred_keys[0]

        all_results = []

        for example in inputs:
            example_predict_fn = functools.partial(
                _predict_fn,
                model=model,
                original_example=example,
                pred_key=pred_key,
                pred_type_info=model.output_spec()[pred_key])

            class_to_explain = get_class_to_explain(provided_class_to_explain,
                                                    model, pred_key, example)

            # Dict[field name -> interpretations]
            salience_by_field = {}

            # One explanation per text segment; the other segments of the
            # example are held constant.
            for text_key in text_keys:
                text = example[text_key]
                if not text:
                    logging.info('Could not explain empty string for %s',
                                 text_key)
                    continue
                logging.info('Explaining: %s', text)

                segment_predict_fn = functools.partial(example_predict_fn,
                                                       text_key=text_key)
                kernel_fn = functools.partial(lime.exponential_kernel,
                                              kernel_width=kernel_width)

                # Perturbs the input string, gets model predictions, fits
                # linear model. `class_to_explain` is ignored when the
                # predict_fn output is a scalar.
                explanation = lime.explain(
                    sentence=text,
                    predict_fn=segment_predict_fn,
                    class_to_explain=class_to_explain,
                    num_samples=num_samples,
                    tokenizer=str.split,
                    mask_token=mask_string,
                    kernel=kernel_fn,
                    seed=seed)

                # TODO(lit-dev): Move score normalization to the UI.
                normalized = citrus_util.normalize_scores(
                    explanation.feature_importance)
                salience_by_field[text_key] = dtypes.TokenSalience(
                    text.split(), normalized)

            all_results.append(salience_by_field)

        return all_results
Example #7
0
    def test_explain_matches_original_lime(self, sentence, num_samples,
                                           num_classes, class_to_explain):
        """Tests if Citrus LIME matches the original implementation.

        A mock prediction function adds a fixed random weight per token to the
        score of `class_to_explain`. The sentence is explained with both
        Citrus LIME and the original LIME, and the normalized per-token scores
        must agree within an absolute tolerance of 0.01.
        """
        tokens = sentence.split()

        # Assign some weight to each token of the sentence.
        # Each token contributes positively/negatively to the prediction.
        rs = np.random.RandomState(seed=0)
        token_weights = {token: rs.normal() for token in tokens}
        token_weights[lime.DEFAULT_MASK_TOKEN] = 0.

        def _predict_fn(sentences):
            """Mock prediction function."""
            # Re-seeded on every call so the noise is reproducible per call.
            noise_rs = np.random.RandomState(seed=0)
            predictions = []
            for candidate in sentences:
                probs = noise_rs.normal(0., 0.1, size=num_classes)
                # To check if LIME finds the right positive/negative
                # correlations.
                for token in candidate.split():
                    probs[class_to_explain] += token_weights[token]
                predictions.append(probs)
            return np.stack(predictions, axis=0)

        # Explain the prediction using Citrus LIME.
        explanation = lime.explain(sentence,
                                   _predict_fn,
                                   class_to_explain=class_to_explain,
                                   num_samples=num_samples,
                                   tokenizer=str.split,
                                   mask_token=lime.DEFAULT_MASK_TOKEN,
                                   kernel=functools.partial(
                                       lime.exponential_kernel,
                                       kernel_width=lime.DEFAULT_KERNEL_WIDTH))
        scores = explanation.feature_importance  # <float32>[seq_len]
        scores = utils.normalize_scores(scores, make_positive=False)

        # Explain the prediction using original LIME.
        # class_names is materialized as a list: a map object is a one-shot
        # iterator and would be empty after its first traversal.
        original_lime_explainer = original_lime.LimeTextExplainer(
            class_names=[str(index) for index in range(num_classes)],
            mask_string=lime.DEFAULT_MASK_TOKEN,
            kernel_width=lime.DEFAULT_KERNEL_WIDTH,
            split_expression=str.split,
            bow=False)
        num_features = len(tokens)
        original_explanation = original_lime_explainer.explain_instance(
            sentence,
            _predict_fn,
            labels=(class_to_explain, ),
            num_features=num_features,
            num_samples=num_samples)

        # original_explanation.local_exp is a dict that has a key
        # class_to_explain, which gives a sequence of (index, score) pairs.
        # We convert it to an array <float32>[seq_len] with a score per
        # position.
        original_scores = np.zeros(num_features)
        for index, score in original_explanation.local_exp[class_to_explain]:
            original_scores[index] = score
        original_scores = utils.normalize_scores(original_scores,
                                                 make_positive=False)

        # Test that Citrus LIME and original LIME match.
        np.testing.assert_allclose(scores, original_scores, atol=0.01)