Example #1
    def run(self,
            inputs: List[JsonDict],
            model: lit_model.Model,
            dataset: lit_dataset.Dataset,
            model_outputs: Optional[List[JsonDict]] = None,
            config: Optional[JsonDict] = None) -> Optional[List[JsonDict]]:
        """Run this component, given a model and input(s)."""
        # Find gradient fields to interpret
        output_spec = model.output_spec()
        grad_fields = self.find_fields(output_spec)
        logging.info('Found fields for gradient attribution: %s',
                     str(grad_fields))
        if len(grad_fields) == 0:  # pylint: disable=g-explicit-length-test
            return None

        # Run model, if needed.
        if model_outputs is None:
            model_outputs = list(model.predict(inputs))
        assert len(model_outputs) == len(inputs)

        all_results = []
        for o in model_outputs:
            # Dict[field name -> interpretations]
            result = {}
            for grad_field in grad_fields:
                token_field = cast(types.TokenGradients,
                                   output_spec[grad_field]).align
                tokens = o[token_field]
                scores = self._interpret(o[grad_field], tokens)
                result[grad_field] = dtypes.SalienceMap(tokens, scores)
            all_results.append(result)

        return all_results
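
Note: `find_fields` and `_interpret` are defined elsewhere in this component. As a point of reference, here is a minimal `_interpret` sketch, assuming gradient-norm salience over `grads` of shape <float32>[num_tokens, emb_size] (hypothetical; the actual component's scoring may differ):

    def _interpret(self, grads: np.ndarray, tokens: np.ndarray):
        """Hypothetical sketch: per-token salience as the L2 norm of its gradient."""
        assert grads.shape[0] == len(tokens)
        # Norm of d(output) / d(embeddings) for each token.
        grad_norm = np.linalg.norm(grads, axis=1)
        # Normalize so the scores sum to 1.
        grad_norm /= np.sum(grad_norm)
        # <float32>[num_tokens]
        return grad_norm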
Example #2
    def run(self,
            inputs: List[JsonDict],
            model: lit_model.Model,
            dataset: lit_dataset.Dataset,
            model_outputs: Optional[List[JsonDict]] = None,
            config: Optional[JsonDict] = None) -> Optional[List[JsonDict]]:
        """Run this component, given a model and input(s)."""
        if not inputs:
            return None

        # Find keys of input (text) segments to explain.
        # Search in the input spec, since it's only useful to look at ones that are
        # used by the model.
        text_keys = utils.find_spec_keys(model.input_spec(), types.TextSegment)
        if not text_keys:
            logging.warning('LEMON requires text inputs.')
            return None
        logging.info('Found text fields for LEMON attribution: %s',
                     str(text_keys))

        # Run the model, if needed; model_outputs is accessed unconditionally below.
        if model_outputs is None:
            model_outputs = list(model.predict(inputs))
        assert len(model_outputs) == len(inputs)

        pred_key = config['pred_key']
        output_probs = np.array([output[pred_key] for output in model_outputs])

        # Explain the input given counterfactuals.

        # Dict[field name -> interpretations]
        result = {}

        # Explain each text segment in the input, keeping the others constant.
        for text_key in text_keys:
            sentences = [item[text_key] for item in inputs]
            input_to_prediction = dict(zip(sentences, output_probs))

            input_string = sentences[0]
            counterfactuals = sentences[1:]

            # Remove duplicate counterfactuals.
            counterfactuals = list(set(counterfactuals))

            logging.info('Explaining: %s', input_string)

            predict_proba = make_predict_fn(input_to_prediction)

            # Perturbs the input string, gets model predictions, fits linear model.
            explanation = lemon.explain(
                input_string,
                counterfactuals,
                predict_proba,
                class_to_explain=config['class_to_explain'],
                lowercase_tokens=config['lowercase_tokens'])

            scores = np.array(explanation.feature_importance)

            # Normalize feature values.
            scores = citrus_utils.normalize_scores(scores)

            result[text_key] = dtypes.SalienceMap(input_string.split(), scores)

        return [result]
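
Note: `make_predict_fn` is not shown above. A minimal sketch, assuming it simply looks up precomputed probabilities for each (possibly perturbed) sentence in the `input_to_prediction` cache:

def make_predict_fn(input_to_prediction):
    """Hypothetical sketch: predict by looking up cached model outputs."""
    def predict(sentences):
        # <float32>[len(sentences), num_labels]
        return np.stack([input_to_prediction[s] for s in sentences])
    return predict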
Example #3
    def run(
            self,
            inputs: List[JsonDict],
            model: lit_model.Model,
            dataset: lit_dataset.Dataset,
            model_outputs: Optional[List[JsonDict]] = None,
            config: Optional[JsonDict] = None,
            kernel_width: int = 25,  # TODO(lit-dev): make configurable in UI.
            mask_string: str = '[MASK]',  # TODO(lit-dev): make configurable in UI.
            num_samples: int = 256,  # TODO(lit-dev): make configurable in UI.
    ) -> Optional[List[JsonDict]]:
        """Run this component, given a model and input(s)."""

        # Find keys of input (text) segments to explain.
        # Search in the input spec, since it's only useful to look at ones that are
        # used by the model.
        text_keys = utils.find_spec_keys(model.input_spec(), types.TextSegment)
        if not text_keys:
            logging.warning('LIME requires text inputs.')
            return None
        logging.info('Found text fields for LIME attribution: %s',
                     str(text_keys))

        # Find the key of output probabilities field(s).
        pred_keys = utils.find_spec_keys(model.output_spec(),
                                         types.MulticlassPreds)
        if not pred_keys:
            logging.warning('LIME did not find a multi-class predictions field.')
            return None

        pred_key = pred_keys[0]  # TODO(lit-dev): configure which prob field to use.
        pred_spec = cast(types.MulticlassPreds, model.output_spec()[pred_key])
        label_names = pred_spec.vocab

        # Create a LIME text explainer instance.
        explainer = lime_text.LimeTextExplainer(
            class_names=label_names,
            split_expression=str.split,
            kernel_width=kernel_width,
            mask_string=mask_string,  # The string used to mask words.
            bow=False)  # bow=False masks tokens instead of deleting them entirely.

        all_results = []

        # Explain each input.
        for input_ in inputs:
            # Dict[field name -> interpretations]
            result = {}

            # Explain each text segment in the input, keeping the others constant.
            for text_key in text_keys:
                input_string = input_[text_key]
                logging.info('Explaining: %s', input_string)

                # Use the number of words as the number of features.
                num_features = len(input_string.split())

                def _predict_proba(strings: List[Text]):
                    """Given raw strings, return probabilities. Used by `explainer`."""
                    input_examples = [new_example(input_, text_key, s) for s in strings]
                    model_outputs = model.predict(input_examples)
                    probs = np.array([output[pred_key] for output in model_outputs])
                    return probs  # <float32>[len(strings), num_labels]

                # Perturbs the input string, gets model predictions, fits linear model.
                explanation = explainer.explain_instance(
                    input_string,
                    _predict_proba,
                    num_features=num_features,
                    num_samples=num_samples)

                # Turn the LIME explanation into a list following original word order.
                scores = explanation_to_array(explanation)
                result[text_key] = dtypes.SalienceMap(input_string.split(),
                                                      scores)

            all_results.append(result)

        return all_results
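
Note: the helpers `new_example` and `explanation_to_array` are defined elsewhere. Minimal sketches under stated assumptions (the second assumes lime's standard `as_map()` API and that label 1 is the explained class):

def new_example(original_example, field, new_value):
    """Hypothetical sketch: copy an example, overwriting one text field."""
    example = dict(original_example)
    example[field] = new_value
    return example


def explanation_to_array(explanation):
    """Hypothetical sketch: scores in original word order, 0.0 where absent."""
    # as_map() returns {label: [(word position, weight), ...]}.
    position_to_score = dict(explanation.as_map()[1])
    return np.array([position_to_score.get(i, 0.0)
                     for i in range(max(position_to_score) + 1)])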
Example #4
    def get_salience_result(self, model_input: JsonDict,
                            model: lit_model.Model, model_output: JsonDict,
                            grad_fields: List[Text]):
        result = {}

        output_spec = model.output_spec()
        # We ensure that the embedding and gradient class fields are present in the
        # model's input spec in find_fields().
        embeddings_fields = [
            cast(types.TokenGradients, output_spec[grad_field]).grad_for
            for grad_field in grad_fields
        ]

        # The gradient class input is used to specify the target class of the
        # gradient calculation (if unspecified, this option defaults to the argmax,
        # which could flip between interpolated inputs).
        grad_class_key = cast(types.TokenGradients,
                              output_spec[grad_fields[0]]).grad_target
        # TODO(b/168042999): Add option to specify the class to explain in the UI.
        grad_class = model_output[grad_class_key]

        interpolated_inputs = {}
        all_embeddings = []
        all_baselines = []
        for embed_field in embeddings_fields:
            # <float32>[num_tokens, emb_size]
            embeddings = np.array(model_output[embed_field])
            all_embeddings.append(embeddings)

            # Starts with baseline of zeros. <float32>[num_tokens, emb_size]
            baseline = self.get_baseline(embeddings)
            all_baselines.append(baseline)

            # Get interpolated inputs from baseline to original embedding.
            # <float32>[interpolation_steps, num_tokens, emb_size]
            interpolated_inputs[embed_field] = self.get_interpolated_inputs(
                baseline, embeddings, self.interpolation_steps)

        # Create model inputs and populate embedding field(s).
        inputs_with_embeds = []
        for i in range(self.interpolation_steps):
            input_copy = model_input.copy()
            # Interpolates embeddings for all inputs simultaneously.
            for embed_field in embeddings_fields:
                # <float32>[num_tokens, emb_size]
                input_copy[embed_field] = interpolated_inputs[embed_field][i]
                input_copy[grad_class_key] = grad_class

            inputs_with_embeds.append(input_copy)
        embed_outputs = model.predict(inputs_with_embeds)

        # Create a list of concatenated gradients for each interpolated input.
        gradients = []
        for o in embed_outputs:
            # <float32>[total_num_tokens, emb_size]
            interp_gradients = np.concatenate(
                [o[field] for field in grad_fields])
            gradients.append(interp_gradients)
        # <float32>[interpolation_steps, total_num_tokens, emb_size]
        path_gradients = np.stack(gradients, axis=0)

        # Calculate integral
        # <float32>[total_num_tokens, emb_size]
        integral = self.estimate_integral(path_gradients)

        # <float32>[total_num_tokens, emb_size]
        concat_embeddings = np.concatenate(all_embeddings)

        # <float32>[total_num_tokens, emb_size]
        concat_baseline = np.concatenate(all_baselines)

        # Element-wise product of the integral and (embeddings - baseline).
        # <float32>[total_num_tokens, emb_size]
        integrated_gradients = integral * (concat_embeddings - concat_baseline)

        # Sum over the embedding dimension to get one attribution per token.
        # <float32>[total_num_tokens]
        attributions = np.sum(integrated_gradients, axis=-1)

        # TODO(b/168042999): Make normalization customizable in the UI.
        # <float32>[total_num_tokens]
        scores = citrus_utils.normalize_scores(attributions)

        for grad_field in grad_fields:
            # Format as salience map result.
            token_field = cast(types.TokenGradients,
                               output_spec[grad_field]).align
            tokens = model_output[token_field]

            # Only use the scores that correspond to the tokens in this grad_field.
            # The gradients for all input embeddings were concatenated in the order
            # of the grad fields, so they can be sliced out in the same order.
            sliced_scores = scores[:len(tokens)]  # <float32>[num_tokens in field]
            scores = scores[len(tokens):]  # <float32>[num_remaining_tokens]

            assert len(tokens) == len(sliced_scores)
            result[grad_field] = dtypes.SalienceMap(tokens, sliced_scores)
        return result
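
Note: `get_baseline`, `get_interpolated_inputs`, and `estimate_integral` belong to the same class. Minimal sketches consistent with the shape comments above (an all-zeros baseline, a linear interpolation path, and a simple Riemann-sum integral estimate; the actual component may use a trapezoidal rule):

    def get_baseline(self, embeddings):
        """Hypothetical sketch: all-zeros baseline, as the comment above assumes."""
        # <float32>[num_tokens, emb_size]
        return np.zeros_like(embeddings)

    def get_interpolated_inputs(self, baseline, target, num_steps):
        """Hypothetical sketch: linear path from baseline to target embeddings."""
        # <float32>[num_steps, 1, 1], scaling from 0 (baseline) to 1 (target).
        alphas = np.linspace(0.0, 1.0, num_steps).reshape(-1, 1, 1)
        # <float32>[num_steps, num_tokens, emb_size]
        return baseline[np.newaxis, ...] + alphas * (target - baseline)[np.newaxis, ...]

    def estimate_integral(self, path_gradients):
        """Hypothetical sketch: average the gradients along the interpolation path."""
        # <float32>[total_num_tokens, emb_size]
        return np.mean(path_gradients, axis=0)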
Example #5
    def run(
        self,
        inputs: List[JsonDict],
        model: lit_model.Model,
        dataset: lit_dataset.Dataset,
        model_outputs: Optional[List[JsonDict]] = None,
        config: Optional[JsonDict] = None,
    ) -> Optional[List[JsonDict]]:
        """Run this component, given a model and input(s)."""

        # Fall back to the default for any option not present in config.
        config = config or {}
        class_to_explain = int(config.get(CLASS_KEY, CLASS_DEFAULT))
        kernel_width = int(config.get(KERNEL_WIDTH_KEY, KERNEL_WIDTH_DEFAULT))
        mask_string = config.get(MASK_KEY, MASK_DEFAULT)
        num_samples = int(config.get(NUM_SAMPLES_KEY, NUM_SAMPLES_DEFAULT))
        seed = config.get(SEED_KEY, SEED_DEFAULT)

        # Find keys of input (text) segments to explain.
        # Search in the input spec, since it's only useful to look at ones that are
        # used by the model.
        text_keys = utils.find_spec_keys(model.input_spec(), types.TextSegment)
        if not text_keys:
            logging.warning('LIME requires text inputs.')
            return None
        logging.info('Found text fields for LIME attribution: %s',
                     str(text_keys))

        # Find the key of output probabilities field(s).
        pred_keys = utils.find_spec_keys(
            model.output_spec(),
            (types.MulticlassPreds, types.RegressionScore))
        if not pred_keys:
            logging.warning('LIME did not find any supported output fields.')
            return None

        pred_key = pred_keys[0]  # TODO(lit-dev): configure which prob field to use.
        all_results = []

        # Explain each input.
        for input_ in inputs:
            # Dict[field name -> interpretations]
            result = {}
            predict_fn = functools.partial(_predict_fn,
                                           model=model,
                                           original_example=input_,
                                           pred_key=pred_key)

            # Explain each text segment in the input, keeping the others constant.
            for text_key in text_keys:
                input_string = input_[text_key]
                if not input_string:
                    logging.info('Could not explain empty string for %s',
                                 text_key)
                    continue
                logging.info('Explaining: %s', input_string)

                # Perturbs the input string, gets model predictions, fits linear model.
                explanation = lime.explain(
                    sentence=input_string,
                    predict_fn=functools.partial(predict_fn, text_key=text_key),
                    # `class_to_explain` is ignored when predict_fn output is a scalar.
                    class_to_explain=class_to_explain,  # Index of the class to explain.
                    num_samples=num_samples,
                    tokenizer=str.split,
                    mask_token=mask_string,
                    kernel=functools.partial(lime.exponential_kernel,
                                             kernel_width=kernel_width),
                    seed=seed)

                # Turn the LIME explanation into a list following original word order.
                scores = explanation.feature_importance
                # TODO(lit-dev): Move score normalization to the UI.
                scores = citrus_util.normalize_scores(scores)
                result[text_key] = dtypes.SalienceMap(input_string.split(),
                                                      scores)

            all_results.append(result)

        return all_results
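
Note: the module-level `_predict_fn` bound with `functools.partial` above is not shown. A minimal sketch consistent with how it is called here (everything beyond the call-site argument names is an assumption):

def _predict_fn(strings, model, original_example, pred_key, text_key):
    """Hypothetical sketch: predict perturbed copies of one example."""
    # Substitute each perturbed string into a copy of the original example.
    perturbed = [dict(original_example, **{text_key: s}) for s in strings]
    outputs = list(model.predict(perturbed))
    # <float32>[len(strings)] for regression, or [len(strings), num_labels].
    return np.array([output[pred_key] for output in outputs])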