Example #1
    def run(self,
            inputs: List[JsonDict],
            model: lit_model.Model,
            dataset: lit_dataset.Dataset,
            model_outputs: Optional[List[JsonDict]] = None,
            config: Optional[JsonDict] = None) -> Optional[List[JsonDict]]:
        """Run this component, given a model and input(s)."""
        # Find gradient fields to interpret
        output_spec = model.output_spec()
        grad_fields = self.find_fields(output_spec)
        logging.info('Found fields for gradient attribution: %s',
                     str(grad_fields))
        if len(grad_fields) == 0:  # pylint: disable=g-explicit-length-test
            return None

        # Run model, if needed.
        if model_outputs is None:
            model_outputs = list(model.predict(inputs))
        assert len(model_outputs) == len(inputs)

        all_results = []
        for o in model_outputs:
            # Dict[field name -> interpretations]
            result = {}
            for grad_field in grad_fields:
                token_field = cast(types.TokenGradients,
                                   output_spec[grad_field]).align
                tokens = o[token_field]
                scores = self._interpret(o[grad_field], tokens)
                result[grad_field] = dtypes.SalienceMap(tokens, scores)
            all_results.append(result)

        return all_results
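
The `_interpret` helper is not shown in this snippet; a minimal sketch of a gradient-norm variant (hypothetical, assuming `grads` is a `[num_tokens, emb_size]` array aligned with `tokens`) might look like this:

import numpy as np

def _interpret(grads, tokens):
    """Hypothetical salience scores: L1-normalized L2 norm of each token's gradient."""
    grads = np.asarray(grads)               # <float32>[num_tokens, emb_size]
    norms = np.linalg.norm(grads, axis=1)   # one magnitude per token
    scores = norms / np.sum(norms)          # normalize so the scores sum to 1
    assert len(scores) == len(tokens)
    return scores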
Example #2
def validate_t5_model(model: lit_model.Model) -> lit_model.Model:
    """Validate that a given model looks like a T5 model.

  This checks the model spec at runtime; it is intended to be used before server
  start, such as in the __init__() method of a wrapper class.

  Args:
    model: a LIT model

  Returns:
    model: the same model

  Raises:
    AssertionError: if the model's spec does not match that expected for a T5
    model.
  """
    # Check inputs
    ispec = model.input_spec()
    assert "input_text" in ispec
    assert isinstance(ispec["input_text"], lit_types.TextSegment)
    if "target_text" in ispec:
        assert isinstance(ispec["target_text"], lit_types.TextSegment)

    # Check outputs
    ospec = model.output_spec()
    assert "output_text" in ospec
    assert isinstance(
        ospec["output_text"],
        (lit_types.GeneratedText, lit_types.GeneratedTextCandidates))
    assert ospec["output_text"].parent == "target_text"

    return model
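
Per the docstring, validation is meant to run before the server starts; a hedged usage sketch (the wrapper class and `raw_model` argument are hypothetical, not part of LIT) could be:

class MyT5Wrapper:  # hypothetical wrapper, not a LIT class
    """Refuses to construct around a model that does not look like T5."""

    def __init__(self, raw_model: lit_model.Model):
        # Raises AssertionError at startup if the spec does not match.
        self.wrapped = validate_t5_model(raw_model)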
Example #3
    def run(self,
            inputs: List[JsonDict],
            model: lit_model.Model,
            dataset: lit_dataset.Dataset,
            model_outputs: Optional[List[JsonDict]] = None,
            config: Optional[JsonDict] = None) -> Optional[List[JsonDict]]:
        """Run this component, given a model and input(s)."""
        # Find gradient fields to interpret
        input_spec = model.input_spec()
        output_spec = model.output_spec()
        grad_fields = self.find_fields(input_spec, output_spec)
        logging.info('Found fields for integrated gradients: %s',
                     str(grad_fields))
        if len(grad_fields) == 0:  # pylint: disable=g-explicit-length-test
            return None

        # Run model, if needed.
        if model_outputs is None:
            model_outputs = list(model.predict(inputs))

        all_results = []
        for model_output, model_input in zip(model_outputs, inputs):
            result = self.get_salience_result(model_input, model, model_output,
                                              grad_fields)
            all_results.append(result)
        return all_results
Example #4
    def run_with_metadata(
            self,
            indexed_inputs: Sequence[IndexedInput],
            model: lit_model.Model,
            dataset: lit_dataset.IndexedDataset,
            model_outputs: Optional[List[JsonDict]] = None,
            config: Optional[JsonDict] = None) -> Optional[JsonDict]:
        """Run this component, given a model and input(s).

    Args:
      indexed_inputs: Inputs to cluster.
      model: Model that provides salience maps.
      dataset: Dataset to compute salience maps for.
      model_outputs: Precomputed model outputs.
      config: Config for clustering and salience computation.

    Returns:
      Dict with 2 keys:
        `CLUSTER_ID_KEY`: Contains the cluster assignments. One cluster id per
          dataset example.
        `REPRESENTATION_KEY`: Contains the representations of all examples in
          the dataset that were used in the clustering.
    """
        config = config or {}
        # Find gradient fields to interpret
        grad_fields = self.find_fields(model.output_spec())
        token_saliencies = self.salience_mappers[
            config['salience_mapper']].run_with_metadata(
                indexed_inputs, model, dataset, model_outputs, config)

        if not token_saliencies:
            return None

        vocab = self._build_vocab(token_saliencies)
        representations = self._compute_fixed_length_representation(
            token_saliencies, vocab)

        cluster_ids = {}
        grad_field_to_representations = {}

        for grad_field in grad_fields:
            weight_matrix = np.vstack([representation[grad_field]
                                       for representation in representations])
            self.kmeans[grad_field] = cluster.KMeans(n_clusters=config.get(
                N_CLUSTERS_KEY,
                self.config_spec()[N_CLUSTERS_KEY].default))
            cluster_ids[grad_field] = self.kmeans[grad_field].fit_predict(
                weight_matrix).tolist()
            grad_field_to_representations[grad_field] = weight_matrix

        return {
            CLUSTER_ID_KEY: cluster_ids,
            REPRESENTATION_KEY: grad_field_to_representations
        }
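
As a standalone illustration of the clustering step (a sketch, not the LIT component: the toy salience maps, vocabulary construction, and k=2 are assumptions), each example becomes a fixed-length vector over a shared token vocabulary before k-means assigns cluster ids:

import numpy as np
from sklearn import cluster

# Toy per-example token saliencies standing in for the salience mapper output.
token_saliencies = [
    {"the": 0.1, "movie": 0.7, "rocks": 0.9},
    {"the": 0.2, "movie": 0.6, "stinks": 0.8},
]
vocab = sorted({tok for sal in token_saliencies for tok in sal})
# Fixed-length representation: the token's salience, or 0 if the token is absent.
representations = [[sal.get(tok, 0.0) for tok in vocab] for sal in token_saliencies]
weight_matrix = np.vstack(representations)
kmeans = cluster.KMeans(n_clusters=2, n_init=10)
cluster_ids = kmeans.fit_predict(weight_matrix).tolist()  # one cluster id per example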
Example #5
    def _filter_ds_examples(
            self,
            dataset: lit_dataset.IndexedDataset,
            dataset_name: Text,
            model: lit_model.Model,
            reference_output: JsonDict,
            pred_key: Text,
            regression_thresh: Optional[float] = None) -> List[JsonDict]:
        """Reads all dataset examples and returns only those that are flips."""
        if not isinstance(dataset, lit_dataset.IndexedDataset):
            raise ValueError(
                'Only indexed datasets are currently supported by the '
                'TabularMTC generator.')

        indexed_examples = list(dataset.indexed_examples)
        filtered_examples = []
        preds = model.predict_with_metadata(indexed_examples,
                                            dataset_name=dataset_name)

        # Find all DS examples that are flips with respect to the reference example.
        for indexed_example, pred in zip(indexed_examples, preds):
            flip = cf_utils.is_prediction_flip(
                cf_output=pred,
                orig_output=reference_output,
                output_spec=model.output_spec(),
                pred_key=pred_key,
                regression_thresh=regression_thresh)
            if flip:
                candidate_example = indexed_example['data'].copy()
                self._find_dataset_parent_and_set(
                    model_output_spec=model.output_spec(),
                    pred_key=pred_key,
                    dataset_spec=dataset.spec(),
                    example=candidate_example,
                    predicted_value=pred[pred_key])
                filtered_examples.append(candidate_example)
        return filtered_examples
Example #6
    def _is_flip(
            self,
            model: lit_model.Model,
            cf_example: JsonDict,
            orig_output: JsonDict,
            pred_key: Text,
            regression_thresh: Optional[float] = None) -> Tuple[bool, Any]:

        cf_output = list(model.predict([cf_example]))[0]
        feature_predicted_value = cf_output[pred_key]
        return cf_utils.is_prediction_flip(
            cf_output=cf_output,
            orig_output=orig_output,
            output_spec=model.output_spec(),
            pred_key=pred_key,
            regression_thresh=regression_thresh), feature_predicted_value
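
The flip test itself is delegated to `cf_utils.is_prediction_flip`; a conceptual sketch of that check (an assumption about its behavior, not the actual implementation): a regression output flips when it crosses the threshold, a classification output flips when its argmax changes.

import numpy as np

def is_prediction_flip_sketch(cf_score, orig_score, regression_thresh=None):
    """Illustrative only: approximates the flip check used above."""
    if regression_thresh is not None:  # regression: did the score cross the threshold?
        return (cf_score <= regression_thresh) != (orig_score <= regression_thresh)
    return int(np.argmax(cf_score)) != int(np.argmax(orig_score))  # classification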
Example #7
    def run(self,
            inputs: List[JsonDict],
            model: lit_model.Model,
            dataset: lit_dataset.Dataset,
            model_outputs: Optional[List[JsonDict]] = None,
            config: Optional[JsonDict] = None) -> Optional[List[JsonDict]]:
        """Run this component, given a model and input(s)."""
        config = config or {}
        class_to_explain = config.get(CLASS_KEY,
                                      self.config_spec()[CLASS_KEY].default)
        interpolation_steps = int(
            config.get(INTERPOLATION_KEY,
                       self.config_spec()[INTERPOLATION_KEY].default))
        normalization = config.get(
            NORMALIZATION_KEY,
            self.config_spec()[NORMALIZATION_KEY].default)

        # Find gradient fields to interpret
        input_spec = model.input_spec()
        output_spec = model.output_spec()
        grad_fields = self.find_fields(input_spec, output_spec)
        logging.info('Found fields for integrated gradients: %s',
                     str(grad_fields))
        if len(grad_fields) == 0:  # pylint: disable=g-explicit-length-test
            return None

        # Run model, if needed.
        if model_outputs is None:
            model_outputs = list(model.predict(inputs))

        all_results = []
        for model_output, model_input in zip(model_outputs, inputs):
            result = self.get_salience_result(model_input, model,
                                              interpolation_steps,
                                              normalization, class_to_explain,
                                              model_output, grad_fields)
            all_results.append(result)
        return all_results
Example #8
    def generate(self,
                 example: JsonDict,
                 model: lit_model.Model,
                 dataset: lit_dataset.Dataset,
                 config: Optional[JsonDict] = None,
                 num_examples: int = 1) -> List[JsonDict]:
        """Use gradient to find/substitute the token with largest impact on loss."""
        del dataset  # Unused.

        assert model is not None, "Please provide a model for this generator."
        logging.info(r"W3lc0m3 t0 H0tFl1p \o/")
        logging.info("Original example: %r", example)

        # Find gradient fields to use for HotFlip
        input_spec = model.input_spec()
        output_spec = model.output_spec()
        grad_fields = self.find_fields(output_spec)
        logging.info("Found gradient fields for HotFlip use: %s",
                     str(grad_fields))
        if len(grad_fields) == 0:  # pylint: disable=g-explicit-length-test
            logging.info("No gradient fields found. Cannot use HotFlip. :-(")
            return []  # Cannot generate examples without gradients.

        # Get model outputs.
        logging.info(
            "Performing a forward/backward pass on the input example.")
        model_output = model.predict_single(example)
        logging.info(model_output.keys())

        # Get model word embeddings and vocab.
        inv_vocab, embed = model.get_embedding_table()
        assert len(
            inv_vocab) == embed.shape[0], "Vocab/embeddings size mismatch."
        logging.info("Vocab size: %d, Embedding size: %r", len(inv_vocab),
                     embed.shape)

        # Perform a flip in each sequence for which we have gradients (separately).
        # Each sequence may give rise to multiple new examples, depending on how
        # many words we flip.
        # TODO(lit-team): make configurable how many new examples are desired.
        # TODO(lit-team): use only 1 sequence as input (configurable in UI).
        new_examples = []
        for grad_field in grad_fields:

            # Get the tokens and their gradient vectors.
            token_field = output_spec[grad_field].align  # pytype: disable=attribute-error
            tokens = model_output[token_field]
            grads = model_output[grad_field]

            # Identify the token with the largest gradient norm.
            # TODO(lit-team): consider normalizing across all grad fields or just
            # across each one individually.
            grad_norm = np.linalg.norm(grads, axis=1)
            grad_norm = grad_norm / np.sum(
                grad_norm)  # Match grad attribution value.

            # Get a list of indices of input tokens, sorted by norm, highest first.
            sorted_by_grad_norm = np.argsort(grad_norm)[::-1]

            for i in range(min(num_examples, len(tokens))):
                token_id = sorted_by_grad_norm[i]
                logging.info(
                    "Selected token: %s (pos=%d) with gradient norm %f",
                    tokens[token_id], token_id, grad_norm[token_id])
                token_grad = grads[token_id]

                # Take dot product with all word embeddings. Get largest value.
                scores = np.dot(embed, token_grad)

                # TODO(lit-team): Can add criteria to the winner e.g. cosine distance.
                winner = np.argmax(scores)
                logging.info(
                    "Replacing [%s] (pos=%d) with option %d: [%s] (id=%d)",
                    tokens[token_id], token_id, i, inv_vocab[winner], winner)

                # Create a new input to the model.
                # TODO(iftenney, bastings): enforce somewhere that this field has the
                # same name in the input and output specs.
                input_token_field = token_field
                input_text_field = input_spec[input_token_field].parent  # pytype: disable=attribute-error
                new_example = copy.deepcopy(example)
                modified_tokens = copy.copy(tokens)
                modified_tokens[token_id] = inv_vocab[winner]
                new_example[input_token_field] = modified_tokens
                # TODO(iftenney, bastings): call a model-provided detokenizer here?
                # Though in general tokenization isn't invertible and it's possible for
                # HotFlip to produce wordpiece sequences that don't correspond to any
                # input string.
                new_example[input_text_field] = " ".join(modified_tokens)

                # Predict a new label for this example.
                new_output = model.predict_single(new_example)

                # Update label if multi-class prediction.
                # TODO(lit-dev): provide a general system for handling labels on
                # generated examples.
                for pred_key, pred_type in model.output_spec().items():
                    if isinstance(pred_type, types.MulticlassPreds):
                        probabilities = new_output[pred_key]
                        prediction = np.argmax(probabilities)
                        label_key = output_spec[pred_key].parent
                        label_names = output_spec[pred_key].vocab
                        new_label = label_names[prediction]
                        new_example[label_key] = new_label
                        logging.info("Updated example with new label: %s",
                                     new_label)

                new_examples.append(new_example)

        return new_examples
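
The substitution step reduces to one matrix-vector product over the embedding table; a self-contained numeric sketch (toy vocabulary and gradient values, not model outputs) follows:

import numpy as np

inv_vocab = ["good", "bad", "okay"]                       # row index -> token
embed = np.array([[1.0, 0.0], [-1.0, 0.0], [0.0, 1.0]])   # [vocab_size, emb_size]
token_grad = np.array([-0.5, 0.1])                        # gradient at the selected token
scores = np.dot(embed, token_grad)                        # one score per vocabulary entry
winner = int(np.argmax(scores))                           # largest dot product wins
print(inv_vocab[winner])                                  # -> "bad"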
Example #9
    def run_with_metadata(
            self,
            indexed_inputs: Sequence[IndexedInput],
            model: lit_model.Model,
            dataset: lit_dataset.IndexedDataset,
            model_outputs: Optional[List[JsonDict]] = None,
            config: Optional[JsonDict] = None) -> Optional[List[JsonDict]]:
        """Runs the TCAV method given the params in the inputs and config.

    Args:
      indexed_inputs: all examples in the dataset, in the indexed input format.
      model: the model being explained.
      dataset: the dataset which the current examples belong to.
      model_outputs: optional model outputs from calling model.predict(inputs).
      config: a config which should specify: {
          'concept_set_ids': [list of ids to use in concept set]
          'class_to_explain': [gradient class to explain],
          'grad_layer': [the Gradient field key of the layer to explain],
          'random_state': [an optional seed to make outputs deterministic]
          'dataset_name': [the name of the dataset (used for caching)]
          'test_size': [Percentage of the example set to use in the LM test set]
          'negative_set_ids': [optional list of ids to use as negative set] }

    Returns:
      A JsonDict containing the TCAV scores, directional derivatives,
      statistical test p-values, and LM accuracies.
    """
        config = TCAVConfig(**config)
        # TODO(b/171513556): get these from the Dataset object once indices are
        # available there.
        dataset_examples = indexed_inputs

        # Get this layer's output spec keys for gradients and embeddings.
        grad_layer = config.grad_layer
        output_spec = model.output_spec()
        emb_layer = cast(types.Gradients, output_spec[grad_layer]).grad_for

        # Get the class that the gradients were computed for.
        grad_class_key = cast(types.Gradients,
                              output_spec[grad_layer]).grad_target_field_key

        ids_set = set(config.concept_set_ids)
        concept_set = [ex for ex in dataset_examples if ex['id'] in ids_set]
        non_concept_set = [
            ex for ex in dataset_examples if ex['id'] not in ids_set
        ]

        # Get outputs using model.predict().
        dataset_outputs = list(
            model.predict_with_metadata(dataset_examples,
                                        dataset_name=config.dataset_name))

        if config.negative_set_ids:
            negative_ids_set = set(config.negative_set_ids)
            negative_set = [
                ex for ex in dataset_examples if ex['id'] in negative_ids_set
            ]
            return self._run_relative_tcav(grad_layer, emb_layer,
                                           grad_class_key, concept_set,
                                           negative_set, dataset_outputs,
                                           model, config)
        else:
            return self._run_default_tcav(grad_layer, emb_layer,
                                          grad_class_key, concept_set,
                                          non_concept_set, dataset_outputs,
                                          model, config)
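
The `_run_default_tcav` internals are not shown; the underlying idea can be sketched independently (toy data, the scikit-learn classifier, and all names here are assumptions, not the LIT implementation): fit a linear separator between concept and non-concept embeddings, take its normal vector as the concept activation vector (CAV), and report the fraction of examples whose gradient has a positive component along it.

import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.RandomState(0)
concept_embs = rng.normal(1.0, 1.0, size=(20, 8))        # embeddings of the concept set
non_concept_embs = rng.normal(-1.0, 1.0, size=(20, 8))   # embeddings of other examples
x = np.vstack([concept_embs, non_concept_embs])
y = np.array([1] * 20 + [0] * 20)
cav = LogisticRegression().fit(x, y).coef_[0]             # concept activation vector
grads = rng.normal(0.0, 1.0, size=(40, 8))                # d(class score)/d(embedding)
tcav_score = float(np.mean(np.dot(grads, cav) > 0))       # fraction of positive derivatives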
Example #10
    def generate(self,
                 example: JsonDict,
                 model: lit_model.Model,
                 dataset: lit_dataset.Dataset,
                 config: Optional[JsonDict] = None) -> List[JsonDict]:
        """Identify minimal sets of token albations that alter the prediction."""
        del dataset  # Unused.

        config = config or {}
        num_examples = int(config.get(NUM_EXAMPLES_KEY, NUM_EXAMPLES_DEFAULT))
        max_ablations = int(
            config.get(MAX_ABLATIONS_KEY, MAX_ABLATIONS_DEFAULT))
        assert model is not None, "Please provide a model for this generator."

        input_spec = model.input_spec()
        pred_key = config.get(PREDICTION_KEY, "")
        regression_thresh = float(
            config.get(REGRESSION_THRESH_KEY, REGRESSION_THRESH_DEFAULT))

        output_spec = model.output_spec()
        assert pred_key, "Please provide the prediction key"
        assert pred_key in output_spec, "Invalid prediction key"

        is_regression = isinstance(output_spec[pred_key],
                                   types.RegressionScore)
        if not is_regression:
            assert isinstance(output_spec[pred_key], types.MulticlassPreds), (
                "Only classification or regression models are supported")
        logging.info(r"W3lc0m3 t0 Ablatl0nFl1p \o/")
        logging.info("Original example: %r", example)

        # Check for fields to ablate.
        fields_to_ablate = list(config.get(FIELDS_TO_ABLATE_KEY, []))
        if not fields_to_ablate:
            return []

        # Get model outputs.
        orig_output = list(model.predict([example]))[0]
        loo_scores = self._generate_leave_one_out_ablation_score(
            example, model, input_spec, output_spec, orig_output, pred_key,
            fields_to_ablate)

        if isinstance(output_spec[pred_key], types.RegressionScore):
            ablation_idxs_generator = self._gen_ablation_idxs(
                loo_scores, max_ablations, orig_output[pred_key],
                regression_thresh)
        else:
            ablation_idxs_generator = self._gen_ablation_idxs(
                loo_scores, max_ablations)

        tokens_map = {}
        for field in input_spec.keys():
            tokens = self._get_tokens(example, input_spec, field)
            if not tokens:
                continue
            tokens_map[field] = tokens

        successful_cfs = []
        successful_positions = []
        for ablation_idxs in ablation_idxs_generator:
            if len(successful_cfs) >= num_examples:
                return successful_cfs

            # If a subset of this set of tokens has already been successful in
            # obtaining a flip, we skip it. This ensures that we only consider
            # sets of tokens that are minimal.
            if self._subset_exists(set(ablation_idxs), successful_positions):
                continue

            # Create counterfactual and obtain model prediction.
            cf = self._create_cf(example, input_spec, ablation_idxs)
            cf_output = list(model.predict([cf]))[0]

            # Check if counterfactual results in a prediction flip.
            if cf_utils.is_prediction_flip(cf_output, orig_output, output_spec,
                                           pred_key, regression_thresh):
                # Prediction flip found!
                cf_utils.update_prediction(cf, cf_output, output_spec,
                                           pred_key)
                cf[ABLATED_TOKENS_KEY] = str([
                    f"{field}[{tokens_map[field][idx]}]"
                    for field, idx in ablation_idxs
                ])
                successful_cfs.append(cf)
                successful_positions.append(set(ablation_idxs))
        return successful_cfs
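
The minimality check `_subset_exists` is not shown; a plausible sketch (an assumption about its contract) simply tests whether any previously successful position set is contained in the candidate set:

def subset_exists_sketch(candidate_idxs: set, successful_positions: list) -> bool:
    """Illustrative only: True if some earlier flip used a subset of these positions."""
    return any(prev.issubset(candidate_idxs) for prev in successful_positions)

assert subset_exists_sketch({1, 2, 3}, [{2, 3}])       # {2, 3} already flipped -> skip
assert not subset_exists_sketch({1, 4}, [{2, 3}])      # no successful subset -> keep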
Example #11
 def is_compatible(self, model: lit_model.Model) -> bool:
     input_spec = model.input_spec()
     output_spec = model.output_spec()
     return find_supported_fields(input_spec, output_spec) is not None
Example #12
    def get_salience_result(self, model_input: JsonDict,
                            model: lit_model.Model, model_output: JsonDict,
                            grad_fields: List[Text]):
        result = {}

        output_spec = model.output_spec()
        # We ensure that the embedding and gradient class fields are present in the
        # model's input spec in find_fields().
        embeddings_fields = [
            cast(types.TokenGradients, output_spec[grad_field]).grad_for
            for grad_field in grad_fields
        ]

        # The gradient class input is used to specify the target class of the
        # gradient calculation (if unspecified, this option defaults to the argmax,
        # which could flip between interpolated inputs).
        grad_class_key = cast(types.TokenGradients,
                              output_spec[grad_fields[0]]).grad_target
        # TODO(b/168042999): Add option to specify the class to explain in the UI.
        grad_class = model_output[grad_class_key]

        interpolated_inputs = {}
        all_embeddings = []
        all_baselines = []
        for embed_field in embeddings_fields:
            # <float32>[num_tokens, emb_size]
            embeddings = np.array(model_output[embed_field])
            all_embeddings.append(embeddings)

            # Starts with baseline of zeros. <float32>[num_tokens, emb_size]
            baseline = self.get_baseline(embeddings)
            all_baselines.append(baseline)

            # Get interpolated inputs from baseline to original embedding.
            # <float32>[interpolation_steps, num_tokens, emb_size]
            interpolated_inputs[embed_field] = self.get_interpolated_inputs(
                baseline, embeddings, self.interpolation_steps)

        # Create model inputs and populate embedding field(s).
        inputs_with_embeds = []
        for i in range(self.interpolation_steps):
            input_copy = model_input.copy()
            # Interpolates embeddings for all inputs simultaneously.
            for embed_field in embeddings_fields:
                # <float32>[num_tokens, emb_size]
                input_copy[embed_field] = interpolated_inputs[embed_field][i]
                input_copy[grad_class_key] = grad_class

            inputs_with_embeds.append(input_copy)
        embed_outputs = model.predict(inputs_with_embeds)

        # Create list with concatenated gradients for each interpolated input.
        gradients = []
        for o in embed_outputs:
            # <float32>[total_num_tokens, emb_size]
            interp_gradients = np.concatenate(
                [o[field] for field in grad_fields])
            gradients.append(interp_gradients)
        # <float32>[interpolation_steps, total_num_tokens, emb_size]
        path_gradients = np.stack(gradients, axis=0)

        # Calculate integral
        # <float32>[total_num_tokens, emb_size]
        integral = self.estimate_integral(path_gradients)

        # <float32>[total_num_tokens, emb_size]
        concat_embeddings = np.concatenate(all_embeddings)

        # <float32>[total_num_tokens, emb_size]
        concat_baseline = np.concatenate(all_baselines)

        # <float32>[total_num_tokens, emb_size]
        integrated_gradients = integral * (np.array(concat_embeddings) -
                                           np.array(concat_baseline))
        # Dot product of integral values and (embeddings - baseline).
        # <float32>[total_num_tokens]
        attributions = np.sum(integrated_gradients, axis=-1)

        # TODO(b/168042999): Make normalization customizable in the UI.
        # <float32>[total_num_tokens]
        scores = citrus_utils.normalize_scores(attributions)

        for grad_field in grad_fields:
            # Format as salience map result.
            token_field = cast(types.TokenGradients,
                               output_spec[grad_field]).align
            tokens = model_output[token_field]

            # Only use the scores that correspond to the tokens in this grad_field.
            # The gradients for all input embeddings were concatenated in the order
            # of the grad fields, so they can be sliced out in the same order.
            sliced_scores = scores[:len(
                tokens)]  # <float32>[num_tokens in field]
            scores = scores[len(tokens):]  # <float32>[num_remaining_tokens]

            assert len(tokens) == len(sliced_scores)
            result[grad_field] = dtypes.SalienceMap(tokens, sliced_scores)
        return result
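
The same integral can be checked on a toy differentiable function without a model; the sketch below (hypothetical helper, analytic gradient `2*x`) mirrors the interpolation, Riemann-style integral estimate, and `(input - baseline)` scaling used above, and verifies the completeness property that attributions sum to `f(x) - f(baseline)`:

import numpy as np

def integrated_gradients(x, baseline, grad_fn, steps=64):
    alphas = np.linspace(0.0, 1.0, steps)[:, None]       # [steps, 1]
    interpolated = baseline + alphas * (x - baseline)    # [steps, dim]
    path_gradients = np.stack([grad_fn(p) for p in interpolated])
    integral = path_gradients.mean(axis=0)               # Riemann estimate of the integral
    return integral * (x - baseline)                     # per-feature attributions

x = np.array([1.0, -2.0, 3.0])
baseline = np.zeros_like(x)
attrs = integrated_gradients(x, baseline, grad_fn=lambda v: 2 * v)  # f(x) = sum(x**2)
assert np.isclose(attrs.sum(), np.sum(x**2) - np.sum(baseline**2), rtol=0.05)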
Example #13
 def is_compatible(self, model: lit_model.Model):
     compatible_fields = self.find_fields(model.output_spec())
     return len(compatible_fields)
Example #14
    def run(
        self,
        inputs: List[JsonDict],
        model: lit_model.Model,
        dataset: lit_dataset.Dataset,
        model_outputs: Optional[List[JsonDict]] = None,
        config: Optional[JsonDict] = None,
    ) -> Optional[List[JsonDict]]:
        """Run this component, given a model and input(s)."""

        class_to_explain = int(config[CLASS_KEY]) if config else CLASS_DEFAULT
        kernel_width = int(
            config[KERNEL_WIDTH_KEY]) if config else KERNEL_WIDTH_DEFAULT
        mask_string = config[MASK_KEY] if config else MASK_DEFAULT
        num_samples = int(
            config[NUM_SAMPLES_KEY]) if config else NUM_SAMPLES_DEFAULT
        seed = config[SEED_KEY] if config else SEED_DEFAULT

        # Find keys of input (text) segments to explain.
        # Search in the input spec, since it's only useful to look at ones that are
        # used by the model.
        text_keys = utils.find_spec_keys(model.input_spec(), types.TextSegment)
        if not text_keys:
            logging.warning('LIME requires text inputs.')
            return None
        logging.info('Found text fields for LIME attribution: %s',
                     str(text_keys))

        # Find the key of output probabilities field(s).
        pred_keys = utils.find_spec_keys(
            model.output_spec(),
            (types.MulticlassPreds, types.RegressionScore))
        if not pred_keys:
            logging.warning('LIME did not find any supported output fields.')
            return None

        pred_key = pred_keys[
            0]  # TODO(lit-dev): configure which prob field to use.
        all_results = []

        # Explain each input.
        for input_ in inputs:
            # Dict[field name -> interpretations]
            result = {}
            predict_fn = functools.partial(_predict_fn,
                                           model=model,
                                           original_example=input_,
                                           pred_key=pred_key)

            # Explain each text segment in the input, keeping the others constant.
            for text_key in text_keys:
                input_string = input_[text_key]
                if not input_string:
                    logging.info('Could not explain empty string for %s',
                                 text_key)
                    continue
                logging.info('Explaining: %s', input_string)

                # Perturbs the input string, gets model predictions, fits linear model.
                explanation = lime.explain(
                    sentence=input_string,
                    predict_fn=functools.partial(predict_fn,
                                                 text_key=text_key),
                    # `class_to_explain` is ignored when predict_fn output is a scalar.
                    class_to_explain=
                    class_to_explain,  # Index of the class to explain.
                    num_samples=num_samples,
                    tokenizer=str.split,
                    mask_token=mask_string,
                    kernel=functools.partial(lime.exponential_kernel,
                                             kernel_width=kernel_width),
                    seed=seed)

                # Turn the LIME explanation into a list following original word order.
                scores = explanation.feature_importance
                # TODO(lit-dev): Move score normalization to the UI.
                scores = citrus_util.normalize_scores(scores)
                result[text_key] = dtypes.SalienceMap(input_string.split(),
                                                      scores)

            all_results.append(result)

        return all_results
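
`lime.explain` comes from the citrus library and is not shown; a rough standalone sketch of the same idea (mask tokens, query the model, fit a weighted linear surrogate; all names and the toy `predict_fn` are assumptions, not the citrus code) is:

import numpy as np
from sklearn.linear_model import Ridge

def lime_scores_sketch(tokens, predict_fn, num_samples=256, mask_token="[MASK]",
                       kernel_width=25.0, seed=0):
    rng = np.random.RandomState(seed)
    masks = rng.randint(0, 2, size=(num_samples, len(tokens)))    # 1 = keep token
    sentences = [" ".join(tok if keep else mask_token
                          for tok, keep in zip(tokens, mask)) for mask in masks]
    preds = np.array([predict_fn(s) for s in sentences])
    distances = 1.0 - masks.mean(axis=1)                          # fraction of tokens masked
    weights = np.sqrt(np.exp(-(distances ** 2) / kernel_width ** 2))
    surrogate = Ridge(alpha=1.0).fit(masks, preds, sample_weight=weights)
    return surrogate.coef_                                        # one importance per token

# Toy model: the score is 1.0 whenever the word "great" survives masking.
scores = lime_scores_sketch("this movie is great".split(),
                            predict_fn=lambda s: float("great" in s))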
Example #15
 def find_fields(self, model: lit_model.Model) -> Dict[str, List[str]]:
   src_fields = utils.find_spec_keys(model.input_spec(), types.TextSegment)
   gen_fields = utils.find_spec_keys(
       model.output_spec(),
       (types.GeneratedText, types.GeneratedTextCandidates))
   return {f: src_fields for f in gen_fields}
Example #16
    def _find_closer_flip_using_interpolation(
            self,
            ref_example: JsonDict,
            known_flip: JsonDict,
            target_pred: JsonDict,
            pred_key: Text,
            model: lit_model.Model,
            dataset: lit_dataset.Dataset,
            regression_threshold: Optional[float] = None,
            max_attempts: int = 4) -> Optional[JsonDict]:
        """Looks for the decision boundary between two examples using interpolation.

    The method searches for a flip that is closer to `ref_example` than
    `known_flip` is. The method performs a binary search by interpolating
    scalar values between the two examples.

    Args:
      ref_example: an example for which the flip is searched.
      known_flip: an example that represents a known flip.
      target_pred: the model prediction at `ref_example`.
      pred_key: the name of the field inside `target_pred` that holds the
        prediction value.
      model: model to use for running predictions.
      dataset: dataset that contains `known_flip`.
      regression_threshold: threshold to use for regression models.
      max_attempts: number of binary search attempts.

    Returns:
      The counterfactual (flip) if found; 'None' otherwise.
    """
        min_alpha = 0.0
        max_alpha = 1.0
        closest_flip = None
        input_spec = model.input_spec()
        has_scalar = False
        for _ in range(max_attempts):
            # Interpolate the scalar values using binary search.
            current_alpha = (min_alpha + max_alpha) / 2
            candidate = known_flip.copy()
            for field in ref_example:
                if (field in candidate and field in input_spec
                        and isinstance(input_spec[field], lit_types.Scalar)
                        and candidate[field] is not None
                        and ref_example[field] is not None):
                    candidate[field] = known_flip[field] * (
                        1 - current_alpha) + ref_example[field] * current_alpha
                    has_scalar = True
            # The interpolation makes sense only for scalar values. If there are no
            # scalar fields that can be interpolated then terminate the search.
            if not has_scalar:
                return None
            flip, predicted_value = self._is_flip(
                model=model,
                cf_example=candidate,
                orig_output=target_pred,
                pred_key=pred_key,
                regression_thresh=regression_threshold)
            if flip:
                self._find_dataset_parent_and_set(
                    model_output_spec=model.output_spec(),
                    pred_key=pred_key,
                    dataset_spec=dataset.spec(),
                    example=candidate,
                    predicted_value=predicted_value)
                closest_flip = candidate
                min_alpha = current_alpha
            else:
                max_alpha = current_alpha
        return closest_flip
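
Reduced to a single scalar feature, the binary search above can be sketched as follows (the toy `is_flip` decision and names are assumptions):

def find_closer_flip_sketch(ref_val, known_flip_val, is_flip, max_attempts=4):
    min_alpha, max_alpha, closest = 0.0, 1.0, None
    for _ in range(max_attempts):
        alpha = (min_alpha + max_alpha) / 2
        candidate = known_flip_val * (1 - alpha) + ref_val * alpha
        if is_flip(candidate):
            closest, min_alpha = candidate, alpha   # still flips: move toward ref_val
        else:
            max_alpha = alpha                       # no longer flips: back off
    return closest

# Toy decision boundary at 0.3: values below it are "flips".
print(find_closer_flip_sketch(ref_val=1.0, known_flip_val=0.0, is_flip=lambda v: v < 0.3))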
Example #17
    def _find_hot_flip(
        self,
        ref_example: JsonDict,
        ds_example: JsonDict,
        features_to_consider: List[Text],
        model: lit_model.Model,
        target_pred: JsonDict,
        pred_key: Text,
        dataset: lit_dataset.Dataset,
        interpolate: bool,
        regression_threshold: Optional[float] = None,
    ) -> Optional[JsonDict]:
        """Finds a hot-flip example for a given target example and DS example.

    Args:
      ref_example: target example for which the counterfactuals should be found.
      ds_example: a dataset example that should be used as a starting point for
        the search.
      features_to_consider: the list of feature keys that can be changed during
        the search.
      model: model to use for getting predictions.
      target_pred: model prediction that corresponds to `ref_example`.
      pred_key: the name of the field in model predictions that contains the
        prediction value for the counterfactual search.
      dataset: a dataset object that contains `ds_example`.
      interpolate: if True, the method tries to find a closer counterfactual
        using interpolation.
      regression_threshold: the threshold to use if `model` is a regression
        model. This parameter is ignored for classification models.

    Returns:
      A hot-flip counterfactual that satisfies the criteria.
    """
        # All features other than `features_to_consider` should be assigned the
        # value of the target example.
        candidate_example = ds_example.copy()
        for field_name in ref_example:
            if (field_name not in features_to_consider
                    and field_name in model.input_spec()):
                candidate_example[field_name] = ref_example[field_name]

        flip, predicted_value = self._is_flip(
            model=model,
            cf_example=candidate_example,
            orig_output=target_pred,
            pred_key=pred_key,
            regression_thresh=regression_threshold)

        if not flip:
            return None

        # Find closest flip by moving scalar values closer to the target.
        closest_flip = None
        if interpolate:
            closest_flip = self._find_closer_flip_using_interpolation(
                ref_example, candidate_example, target_pred, pred_key, model,
                dataset, regression_threshold)
        # If we found a closer flip through interpolation then use it,
        # otherwise use the previously found flip.
        if closest_flip is not None:
            return closest_flip
        else:
            self._find_dataset_parent_and_set(
                model_output_spec=model.output_spec(),
                pred_key=pred_key,
                dataset_spec=dataset.spec(),
                example=candidate_example,
                predicted_value=predicted_value)
            return candidate_example
Example #18
    def generate(self,
                 example: JsonDict,
                 model: lit_model.Model,
                 dataset: lit_dataset.Dataset,
                 config: Optional[JsonDict] = None) -> List[JsonDict]:

        # Perform validation and retrieve configuration.
        if not model:
            raise ValueError('Please provide a model for this generator.')

        config = config or {}
        num_examples = int(config.get(NUM_EXAMPLES_KEY, NUM_EXAMPLES_DEFAULT))
        max_flips = int(config.get(MAX_FLIPS_KEY, MAX_FLIPS_DEFAULT))

        pred_key = config.get(PREDICTION_KEY, '')
        regression_thresh = float(
            config.get(REGRESSION_THRESH_KEY, REGRESSION_THRESH_DEFAULT))

        dataset_name = config.get('dataset_name')
        if not dataset_name:
            raise ValueError('The dataset name must be in the config.')

        output_spec = model.output_spec()
        if not pred_key:
            raise ValueError('Please provide the prediction key.')
        if pred_key not in output_spec:
            raise ValueError('Invalid prediction key.')

        if (not (isinstance(output_spec[pred_key], lit_types.MulticlassPreds)
                 or isinstance(output_spec[pred_key],
                               lit_types.RegressionScore))):
            raise ValueError(
                'Only classification and regression models are supported')

        # Calculate dataset statistics if it has never been calculated. The
        # statistics include such information as 'standard deviation' for scalar
        # features and probabilities for categorical features.
        if dataset_name not in self._datasets_stats:
            self._calculate_stats(dataset, dataset_name)

        # Find predicted class of the original example.
        original_pred = list(model.predict([example]))[0]

        # Find dataset examples that are flips.
        filtered_examples = self._filter_ds_examples(
            dataset=dataset,
            dataset_name=dataset_name,
            model=model,
            reference_output=original_pred,
            pred_key=pred_key,
            regression_thresh=regression_thresh)

        supported_field_names = self._find_all_fields_to_consider(
            ds_spec=dataset.spec(),
            model_input_spec=model.input_spec(),
            example=example)

        candidates: List[JsonDict] = []

        # Iterate through all possible feature combinations.
        combs = utils.find_all_combinations(supported_field_names, 1,
                                            max_flips)
        for comb in combs:
            # Sort all dataset examples with respect to the given combination.
            sorted_examples = self._sort_and_filter_examples(
                examples=filtered_examples,
                ref_example=example,
                fields=comb,
                dataset=dataset,
                dataset_name=dataset_name)
            if not sorted_examples:
                continue

            # As an optimization trick, check whether the farthest example is a flip.
            # If it is not a flip then skip the current combination of features.
            # This optimization makes the minimum set guarantees weaker but
            # significantly improves the search speed.
            flip = self._find_hot_flip(ref_example=example,
                                       ds_example=sorted_examples[-1],
                                       features_to_consider=comb,
                                       model=model,
                                       target_pred=original_pred,
                                       pred_key=pred_key,
                                       dataset=dataset,
                                       interpolate=False,
                                       regression_threshold=regression_thresh)
            if not flip:
                logging.info('Skipped combination %s', comb)
                continue

            # Iterate through the sorted examples until the first flip is found.
            # TODO(b/204200758): improve performance by batching the predict requests.
            for ds_example in sorted_examples:
                flip = self._find_hot_flip(
                    ref_example=example,
                    ds_example=ds_example,
                    features_to_consider=comb,
                    model=model,
                    target_pred=original_pred,
                    pred_key=pred_key,
                    dataset=dataset,
                    interpolate=True,
                    regression_threshold=regression_thresh)

                if flip:
                    self._add_if_not_strictly_worse(example=flip,
                                                    other_examples=candidates,
                                                    ref_example=example,
                                                    dataset=dataset,
                                                    dataset_name=dataset_name,
                                                    model=model)
                    break

            if len(candidates) >= num_examples:
                break

        # Calculate distances for the found hot flips.
        candidate_tuples = []
        for flip_example in candidates:
            distance, diff_fields = self._calculate_L1_distance(
                example_1=example,
                example_2=flip_example,
                dataset=dataset,
                dataset_name=dataset_name,
                model=model)
            if distance > 0:
                candidate_tuples.append((distance, diff_fields, flip_example))

        # Order the dataset entries based on the distance to the given example.
        candidate_tuples.sort(key=lambda e: e[0])

        if len(candidate_tuples) > num_examples:
            candidate_tuples = candidate_tuples[0:num_examples]

        # e[2] contains the hot-flip examples in the distances list of tuples.
        return [e[2] for e in candidate_tuples]
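
`utils.find_all_combinations(supported_field_names, 1, max_flips)` is assumed here to enumerate every feature combination of size 1 up to `max_flips`; a sketch of that assumed behavior with `itertools`:

import itertools

def find_all_combinations_sketch(fields, min_size, max_size):
    return [list(comb) for size in range(min_size, max_size + 1)
            for comb in itertools.combinations(fields, size)]

print(find_all_combinations_sketch(["age", "hours", "education"], 1, 2))
# [['age'], ['hours'], ['education'], ['age', 'hours'], ['age', 'education'],
#  ['hours', 'education']]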
Example #19
 def is_compatible(self, model: lit_model.Model):
     text_keys = utils.find_spec_keys(model.input_spec(), types.TextSegment)
     pred_keys = utils.find_spec_keys(
         model.output_spec(),
         (types.MulticlassPreds, types.RegressionScore))
     return len(text_keys) and len(pred_keys)
Example #20
 def find_fields(self, model: lit_model.Model) -> List[str]:
   sal_keys = utils.find_spec_keys(
       model.output_spec(),
       (types.FeatureSalience, types.ImageSalience, types.TokenSalience,
        types.SequenceSalience))
   return sal_keys
Example #21
    def run(self,
            inputs: List[types.JsonDict],
            model: lit_model.Model,
            dataset: lit_dataset.Dataset,
            model_outputs: Optional[List[types.JsonDict]] = None,
            config: Optional[types.JsonDict] = None):
        """Create PDP chart info using provided inputs.

    Args:
      inputs: sequence of inputs, following model.input_spec()
      model: the model to run predictions with.
      dataset: dataset which the current examples belong to.
      model_outputs: optional precomputed model outputs
      config: optional runtime config.

    Returns:
      a dict mapping each prediction key to a dict of alternate feature values
      to mean model outputs. The outputs will be a number for regression models
      and a list of numbers for multiclass models.
    """

        pred_keys = utils.find_spec_keys(
            model.output_spec(),
            (types.MulticlassPreds, types.RegressionScore))
        if not pred_keys:
            logging.warning('PDP did not find any supported output fields.')
            return None

        assert 'feature' in config, 'No feature to test provided'
        feature = config['feature']
        provided_range = config['range'] if 'range' in config else []
        edited_outputs = {}
        for pred_key in pred_keys:
            edited_outputs[pred_key] = {}

        # If a range was provided, use that to create the possible values.
        vals_to_test = (np.linspace(provided_range[0], provided_range[1], 10)
                        if len(provided_range) == 2 else self.get_vals_to_test(
                            feature, dataset))

        # If no specific inputs provided, use the entire dataset.
        inputs_to_use = inputs if inputs else dataset.examples

        # For each alternate value for a given feature.
        for new_val in vals_to_test:
            # Create copies of all provided inputs with the value replaced.
            edited_inputs = []
            for inp in inputs_to_use:
                edited_input = copy.deepcopy(inp)
                edited_input[feature] = new_val
                edited_inputs.append(edited_input)

            # Run prediction on the altered inputs.
            outputs = list(model.predict(edited_inputs))

            # Store the mean of the prediction for the alternate value.
            for pred_key in pred_keys:
                numeric = isinstance(model.output_spec()[pred_key],
                                     types.RegressionScore)
                if numeric:
                    edited_outputs[pred_key][new_val] = np.mean(
                        [output[pred_key] for output in outputs])
                else:
                    edited_outputs[pred_key][new_val] = np.mean(
                        [output[pred_key] for output in outputs], axis=0)

        return edited_outputs
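
The loop above amounts to classic partial dependence; a toy, self-contained version (the stand-in model and features are assumptions) makes the shape of the output concrete:

import numpy as np

examples = [{"age": 30, "income": 40.0}, {"age": 50, "income": 90.0}]

def toy_model(ex):                                      # stand-in for model.predict
    return 0.1 * ex["age"] + 0.01 * ex["income"]

pdp = {}
for new_val in np.linspace(20, 60, 5):                  # swept values for the "age" feature
    edited = [dict(ex, age=new_val) for ex in examples]
    pdp[float(new_val)] = float(np.mean([toy_model(ex) for ex in edited]))
# pdp maps each tested value of "age" to the mean prediction over the dataset.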
Example #22
    def generate(self,
                 example: JsonDict,
                 model: lit_model.Model,
                 dataset: lit_dataset.Dataset,
                 config: Optional[JsonDict] = None) -> List[JsonDict]:
        """Identify minimal sets of token flips that alter the prediction."""
        del dataset  # Unused.

        config = config or {}
        num_examples = int(config.get(NUM_EXAMPLES_KEY, NUM_EXAMPLES_DEFAULT))
        max_flips = int(config.get(MAX_FLIPS_KEY, MAX_FLIPS_DEFAULT))
        tokens_to_ignore = config.get(TOKENS_TO_IGNORE_KEY,
                                      TOKENS_TO_IGNORE_DEFAULT)
        pred_key = config.get(PREDICTION_KEY, "")
        regression_thresh = float(
            config.get(REGRESSION_THRESH_KEY, REGRESSION_THRESH_DEFAULT))
        assert model is not None, "Please provide a model for this generator."

        input_spec = model.input_spec()
        output_spec = model.output_spec()
        assert pred_key, "Please provide the prediction key"
        assert pred_key in output_spec, "Invalid prediction key"

        is_regression = False
        if isinstance(output_spec[pred_key], types.RegressionScore):
            is_regression = True
        else:
            assert isinstance(output_spec[pred_key], types.MulticlassPreds), (
                "Only classification or regression models are supported")
        logging.info(r"W3lc0m3 t0 H0tFl1p \o/")
        logging.info("Original example: %r", example)

        # Get model outputs.
        orig_output = list(model.predict([example]))[0]

        # Check config for selected fields.
        selected_fields = list(config.get(FIELDS_TO_HOTFLIP_KEY, []))
        if not selected_fields:
            return []

        # Get tokens (corresponding to each text input field) and corresponding
        # gradients.
        tokens_and_gradients = self._get_tokens_and_gradients(
            input_spec, output_spec, orig_output, selected_fields)
        assert tokens_and_gradients, (
            "No token fields found. Cannot use HotFlip. :-(")

        # Copy tokens into input example.
        example = copy.deepcopy(example)
        for token_field, v in tokens_and_gradients.items():
            tokens, _ = v
            example[token_field] = tokens

        inv_vocab, embedding_matrix = model.get_embedding_table()
        assert len(inv_vocab) == embedding_matrix.shape[0], (
            "Vocab/embeddings size mismatch.")

        successful_cfs = []
        # TODO(lit-team): use only 1 sequence as input (configurable in UI).
        # TODO(lit-team): Refactor the following code so that it's not so deeply
        # nested (and easier to track loop state).
        for token_field, v in tokens_and_gradients.items():
            tokens, grads = v
            text_field = input_spec[token_field].parent  # pytype: disable=attribute-error
            logging.info("Identifying Hotflips for input field: %s",
                         str(text_field))
            direction = -1
            if is_regression:
                # We want the replacements to increase the prediction score if the
                # original score is below the threshold, and decrease otherwise.
                direction = (1 if orig_output[pred_key] <= regression_thresh
                             else -1)
            replacement_tokens = self._get_replacement_tokens(
                embedding_matrix, inv_vocab, grads, direction)

            successful_positions = []
            for token_idxs in self._gen_token_idxs_to_flip(
                    tokens, grads, max_flips, tokens_to_ignore):
                if len(successful_cfs) >= num_examples:
                    return successful_cfs
                # If a subset of this set of tokens has already been successful in
                # obtaining a flip, we skip it. This ensures that we only consider
                # sets of token flips that are minimal.
                if self._subset_exists(set(token_idxs), successful_positions):
                    continue

                # Create counterfactual.
                cf = self._create_cf(example, token_field, text_field, tokens,
                                     token_idxs, replacement_tokens)
                # Obtain model prediction.
                cf_output = list(model.predict([cf]))[0]

                if cf_utils.is_prediction_flip(cf_output, orig_output,
                                               output_spec, pred_key,
                                               regression_thresh):
                    # Prediction flip found!
                    cf_utils.update_prediction(cf, cf_output, output_spec,
                                               pred_key)
                    successful_cfs.append(cf)
                    successful_positions.append(set(token_idxs))
        return successful_cfs
Example #23
    def generate(self,
                 example: JsonDict,
                 model: lit_model.Model,
                 dataset: lit_dataset.Dataset,
                 config: Optional[JsonDict] = None,
                 num_examples: int = 1) -> List[JsonDict]:
        """Use gradient to find/substitute the token with largest impact on loss."""
        # TODO(lit-team): This function is quite long. Consider breaking it
        # into small functions.
        del dataset  # Unused.

        assert model is not None, "Please provide a model for this generator."
        logging.info(r"W3lc0m3 t0 H0tFl1p \o/")
        logging.info("Original example: %r", example)

        # Find classification prediction key.
        pred_keys = self.find_fields(model.output_spec(),
                                     types.MulticlassPreds, None)
        if len(pred_keys) == 0:  # pylint: disable=g-explicit-length-test
            # TODO(ataly): Add support for regression models.
            logging.warning("The model does not have a classification head."
                            "Cannot use HotFlip. :-(")
            return []  # Cannot generate examples.
        if len(pred_keys) > 1:
            # TODO(ataly): Use a config argument when there are multiple prediction
            # heads.
            logging.warning("Multiple classification heads found."
                            "Cannot use HotFlip. :-(")
            return []  # Cannot generate examples.
        pred_key = pred_keys[0]

        # Find gradient fields to use for HotFlip
        input_spec = model.input_spec()
        output_spec = model.output_spec()
        grad_fields = self.find_fields(output_spec, types.TokenGradients,
                                       types.Tokens)
        logging.info("Found gradient fields for HotFlip use: %s",
                     str(grad_fields))
        if len(grad_fields) == 0:  # pylint: disable=g-explicit-length-test
            logging.info("No gradient fields found. Cannot use HotFlip. :-(")
            return []  # Cannot generate examples without gradients.

        # Get model outputs.
        logging.info(
            "Performing a forward/backward pass on the input example.")
        orig_output = list(model.predict([example]))[0]
        logging.info(orig_output.keys())

        # Get model word embeddings and vocab.
        inv_vocab, embed = model.get_embedding_table()
        assert len(
            inv_vocab) == embed.shape[0], "Vocab/embeddings size mismatch."
        logging.info("Vocab size: %d, Embedding size: %r", len(inv_vocab),
                     embed.shape)

        # Get original prediction class
        orig_probabilities = orig_output[pred_key]
        orig_prediction = np.argmax(orig_probabilities)

        # Perform a flip in each sequence for which we have gradients (separately).
        # Each sequence may give rise to multiple new examples, depending on how
        # many words we flip.
        # TODO(lit-team): make configurable how many new examples are desired.
        # TODO(lit-team): use only 1 sequence as input (configurable in UI).
        new_examples = []
        for grad_field in grad_fields:
            # Get the tokens and their gradient vectors.
            token_field = output_spec[grad_field].align  # pytype: disable=attribute-error
            tokens = orig_output[token_field]
            grads = orig_output[grad_field]
            token_emb_fields = self.find_fields(output_spec,
                                                types.TokenEmbeddings,
                                                types.Tokens)
            assert len(
                token_emb_fields) == 1, "Found multiple token embeddings"
            token_embs = orig_output[token_emb_fields[0]]

            # Identify the token with the largest gradient attribution,
            # defined as the dot product between the token embedding and gradient
            # of the output wrt the embedding.
            assert token_embs.shape[0] == grads.shape[0]
            token_grad_attrs = np.sum(token_embs * grads, axis=-1)
            # Get a list of indices of input tokens, sorted by gradient attribution,
            # highest first. We will flip tokens in this order.
            sorted_by_grad_attrs = np.argsort(token_grad_attrs)[::-1]

            for i in range(min(num_examples, len(tokens))):
                token_id = sorted_by_grad_attrs[i]
                logging.info(
                    "Selected token: %s (pos=%d) with gradient attribution %f",
                    tokens[token_id], token_id, token_grad_attrs[token_id])
                token_grad = grads[token_id]

                # Take dot product with all word embeddings. Get smallest value.
                # (We are looking for a replacement token that will lower the
                # score of the current class, thereby increasing the chances of
                # a label flip.)
                # TODO(lit-team): Can add criteria to the winner e.g. cosine distance.
                scores = np.dot(embed, token_grad)
                winner = np.argmin(scores)
                logging.info(
                    "Replacing [%s] (pos=%d) with option %d: [%s] (id=%d)",
                    tokens[token_id], token_id, i, inv_vocab[winner], winner)

                # Create a new input to the model.
                # TODO(iftenney, bastings): enforce somewhere that this field has the
                # same name in the input and output specs.
                input_token_field = token_field
                input_text_field = input_spec[input_token_field].parent  # pytype: disable=attribute-error
                new_example = copy.deepcopy(example)
                modified_tokens = copy.copy(tokens)
                modified_tokens[token_id] = inv_vocab[winner]
                new_example[input_token_field] = modified_tokens
                # TODO(iftenney, bastings): call a model-provided detokenizer here?
                # Though in general tokenization isn't invertible and it's possible for
                # HotFlip to produce wordpiece sequences that don't correspond to any
                # input string.
                new_example[input_text_field] = " ".join(modified_tokens)

                # Predict a new label for this example.
                new_output = list(model.predict([new_example]))[0]

                # Update label if multi-class prediction.
                # TODO(lit-dev): provide a general system for handling labels on
                # generated examples.
                probabilities = new_output[pred_key]
                new_prediction = np.argmax(probabilities)
                label_key = cast(types.MulticlassPreds,
                                 output_spec[pred_key]).parent
                label_names = cast(types.MulticlassPreds,
                                   output_spec[pred_key]).vocab
                new_label = label_names[new_prediction]
                new_example[label_key] = new_label
                logging.info("Updated example with new label: %s", new_label)

                if new_prediction != orig_prediction:
                    # HotFlip found a label flip; keep this new example.
                    new_examples.append(new_example)
                else:
                    # Use new_example as the new base example and continue with
                    # more token flips.
                    example = new_example
                    tokens = modified_tokens
        return new_examples
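The flip selection above reduces to two NumPy operations: score each token by the dot product of its embedding with its gradient, then score every vocabulary embedding against the chosen token's gradient and take the minimum. A minimal standalone sketch of just that step (toy shapes and random data; hotflip_candidate is an illustrative helper, not part of the LIT API):

import numpy as np

def hotflip_candidate(token_embs, grads, embed_table):
    """Returns (position_to_flip, replacement_vocab_id) for one flip step."""
    # Gradient attribution per token: dot product of embedding and gradient.
    token_grad_attrs = np.sum(token_embs * grads, axis=-1)  # <float>[num_tokens]
    pos = int(np.argmax(token_grad_attrs))  # token with the largest attribution
    # Score every vocabulary embedding against that token's gradient; the
    # minimum is the replacement most likely to lower the current class score.
    scores = embed_table @ grads[pos]  # <float>[vocab_size]
    return pos, int(np.argmin(scores))

# Toy usage; shapes are illustrative only.
rng = np.random.default_rng(0)
token_embs = rng.normal(size=(6, 8))     # 6 tokens, embedding dim 8
grads = rng.normal(size=(6, 8))          # d(score)/d(embedding) per token
embed_table = rng.normal(size=(100, 8))  # vocabulary of 100 entries
print(hotflip_candidate(token_embs, grads, embed_table))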
Example #24
    def run_with_metadata(
            self,
            indexed_inputs: Sequence[IndexedInput],
            model: lit_model.Model,
            dataset: lit_dataset.IndexedDataset,
            model_outputs: Optional[List[JsonDict]] = None,
            config: Optional[JsonDict] = None) -> Optional[List[JsonDict]]:
        """Calculates optimal thresholds on the provided data.

    Args:
      indexed_inputs: all examples in the dataset, in the indexed input format.
      model: the model being explained.
      dataset: the dataset which the current examples belong to.
      model_outputs: optional model outputs from calling model.predict(inputs).
      config: a config which should specify TresholderConfig options.

    Returns:
      A JsonDict containing the calculated thresholds.
    """
        config = TresholderConfig(**config) if config else TresholderConfig()

        pred_keys = []
        for pred_key, field_spec in model.output_spec().items():
            if self.metrics_gen.is_compatible(field_spec) and cast(
                    types.MulticlassPreds, field_spec).parent:
                pred_keys.append(pred_key)

        indexed_outputs = {
            ex['id']: output
            for (ex, output) in zip(indexed_inputs, model_outputs)
        }

        # Try all margins for thresholds from 0 to 1, by hundredths.
        margins_to_try = [
            self.threshold_to_margin(t) for t in np.linspace(0, 1, 101)
        ]

        # Get binary classification metrics for all margins, for the entire
        # dataset, and also for each facet specified in the config.
        dataset_results = []
        faceted_results = {}
        # Loop over each margin/threshold to check.
        for margin in margins_to_try:
            # Set up an empty config to pass to the metrics generator.
            metrics_config = {}
            for pred_key in pred_keys:
                metrics_config[pred_key] = {'': {'margin': margin}}

            # Get and store the metrics for the entire dataset for this margin.
            dataset_results.append(
                self.metrics_gen.run_with_metadata(indexed_inputs, model,
                                                   dataset, model_outputs,
                                                   metrics_config))

            # Get and store the metrics for each facet of the dataset for this margin.
            for facet_key in config.facets:
                if 'data' not in config.facets[facet_key]:
                    continue
                if facet_key not in faceted_results:
                    faceted_results[facet_key] = []
                faceted_model_outputs = [
                    indexed_outputs[ex['id']]
                    for ex in config.facets[facet_key]['data']
                ]
                faceted_results[facet_key].append(
                    self.metrics_gen.run_with_metadata(
                        config.facets[facet_key]['data'], model, dataset,
                        faceted_model_outputs, metrics_config))

        pred_keys = [result['pred_key'] for result in dataset_results[0]]
        ret = []

        # Find threshold information for each prediction key.
        for i, pred_key in enumerate(pred_keys):
            ret.append(
                self.get_thresholds_for_pred_key(pred_key, i, margins_to_try,
                                                 dataset_results,
                                                 faceted_results, config))
        return ret
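The thresholder example above sweeps 101 candidate thresholds through self.threshold_to_margin, which is not shown in the snippet. A plausible stand-in, assuming the usual log-odds relation between a probability threshold and a decision margin with clamped endpoints (an assumption, not necessarily LIT's implementation):

import math

import numpy as np

def threshold_to_margin(threshold, clamp=5.0):
    """Maps a classification threshold in [0, 1] to a logit-style margin."""
    # margin = log-odds of the threshold, clamped at the endpoints.
    if threshold <= 0.0:
        return -clamp
    if threshold >= 1.0:
        return clamp
    return math.log(threshold / (1.0 - threshold))

margins_to_try = [threshold_to_margin(t) for t in np.linspace(0, 1, 101)]
print(margins_to_try[0], margins_to_try[50], margins_to_try[-1])  # roughly -5.0, 0.0, 5.0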
Example #25
    def run(self,
            inputs: List[JsonDict],
            model: lit_model.Model,
            dataset: lit_dataset.Dataset,
            model_outputs: Optional[List[JsonDict]] = None,
            config: Optional[JsonDict] = None) -> Optional[List[JsonDict]]:
        """Runs the component, given a model and input(s)."""
        input_spec = model.input_spec()
        output_spec = model.output_spec()

        if not inputs:
            return []

        # Find all fields required for the interpretation.
        supported_fields = find_supported_fields(input_spec, output_spec)
        if supported_fields is None:
            return
        grad_field_key = supported_fields.grad_field_key
        image_field_key = supported_fields.image_field_key
        grad_target_field_key = supported_fields.grad_target_field_key
        preds_field_key = supported_fields.preds_field_key
        multiclass = grad_target_field_key is not None

        # Determine the shape of gradients by calling the model with a single input
        # and extracting the shape from the gradient output.
        model_output = list(model.predict([inputs[0]]))
        grad_shape = model_output[0][grad_field_key].shape

        # If it is a multiclass model, find the labels with respect to which we
        # should compute the gradients.
        if multiclass:
            # Get class labels.
            label_vocab = list(
                cast(types.MulticlassPreds,
                     output_spec[preds_field_key]).vocab)

            # Run the model in order to find the gradient target labels.
            outputs = list(model.predict(inputs))
            grad_target_labels = []
            for model_input, model_output in zip(inputs, outputs):
                if model_input.get(grad_target_field_key) is not None:
                    grad_target_labels.append(
                        model_input[grad_target_field_key])
                else:
                    max_idx = int(np.argmax(model_output[preds_field_key]))
                    grad_target_labels.append(label_vocab[max_idx])
        else:
            grad_target_labels = [None] * len(inputs)

        saliency_object = self.get_saliency_object()
        extra_saliency_params = self.get_extra_saliency_params(config)
        all_results = []
        for example, grad_target_label in zip(inputs, grad_target_labels):
            result = {}
            image_str = example[image_field_key]
            saliency_input = image_utils.convert_image_str_to_array(
                image_str=image_str, shape=grad_shape)
            call_model_func = get_call_model_func(
                model=model,
                model_input=example,
                image_field_key=image_field_key,
                grad_field_key=grad_field_key,
                grad_target_field_key=grad_target_field_key,
                grad_target_label=grad_target_label)
            attribution = self.make_saliency_call(
                saliency_object=saliency_object,
                x_value=saliency_input,
                call_model_function=call_model_func,
                extra_saliency_params=extra_saliency_params)
            if attribution.ndim == 3:
                attribution = attribution.sum(axis=2)
            viz_params = self.get_visualization_params()
            overlaid_image = image_utils.overlay_pixel_saliency(
                image_str, attribution, **viz_params)
            result[grad_field_key] = image_utils.convert_pil_to_image_str(
                overlaid_image)
            all_results.append(result)
        return all_results
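image_utils.overlay_pixel_saliency in the image saliency example above is not shown; before any overlay, a per-channel attribution has to be collapsed to a single 2-D map, as the attribution.sum(axis=2) line does. A NumPy-only sketch of that reduction plus a simple [0, 1] normalization (reduce_and_normalize is a hypothetical helper for illustration):

import numpy as np

def reduce_and_normalize(attribution):
    """Collapses <float>[H, W, C] attributions to a normalized <float>[H, W] map."""
    if attribution.ndim == 3:
        attribution = attribution.sum(axis=2)  # combine per-channel attributions
    attribution = np.abs(attribution)
    max_val = attribution.max()
    return attribution / max_val if max_val > 0 else attribution

rng = np.random.default_rng(0)
print(reduce_and_normalize(rng.normal(size=(4, 4, 3))).shape)  # (4, 4)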
Example #26
    def run(
        self,
        inputs: List[JsonDict],
        model: lit_model.Model,
        dataset: lit_dataset.Dataset,
        model_outputs: Optional[List[JsonDict]] = None,
        config: Optional[JsonDict] = None,
    ) -> Optional[List[JsonDict]]:
        """Run this component, given a model and input(s)."""
        config_defaults = {k: v.default for k, v in self.config_spec().items()}
        config = dict(config_defaults, **(config or {}))  # update and return

        provided_class_to_explain = int(config[CLASS_KEY])
        kernel_width = int(config[KERNEL_WIDTH_KEY])
        num_samples = int(config[NUM_SAMPLES_KEY])
        mask_string = config[MASK_KEY]
        # pylint: disable=g-explicit-bool-comparison
        seed = int(config[SEED_KEY]) if config[SEED_KEY] != '' else None
        # pylint: enable=g-explicit-bool-comparison

        # Find keys of input (text) segments to explain.
        # Search in the input spec, since it's only useful to look at ones that are
        # used by the model.
        text_keys = utils.find_spec_keys(model.input_spec(), types.TextSegment)
        if not text_keys:
            logging.warning('LIME requires text inputs.')
            return None
        logging.info('Found text fields for LIME attribution: %s',
                     str(text_keys))

        # Find the key of output probabilities field(s).
        pred_keys = utils.find_spec_keys(
            model.output_spec(), (types.MulticlassPreds, types.RegressionScore,
                                  types.SparseMultilabelPreds))
        if not pred_keys:
            logging.warning('LIME did not find any supported output fields.')
            return None
        pred_key = config[TARGET_HEAD_KEY] or pred_keys[0]

        all_results = []

        # Explain each input.
        for input_ in inputs:
            # Dict[field name -> interpretations]
            result = {}
            predict_fn = functools.partial(
                _predict_fn,
                model=model,
                original_example=input_,
                pred_key=pred_key,
                pred_type_info=model.output_spec()[pred_key])

            class_to_explain = get_class_to_explain(provided_class_to_explain,
                                                    model, pred_key, input_)

            # Explain each text segment in the input, keeping the others constant.
            for text_key in text_keys:
                input_string = input_[text_key]
                if not input_string:
                    logging.info('Could not explain empty string for %s',
                                 text_key)
                    continue
                logging.info('Explaining: %s', input_string)

                # Perturbs the input string, gets model predictions, fits linear model.
                explanation = lime.explain(
                    sentence=input_string,
                    predict_fn=functools.partial(predict_fn,
                                                 text_key=text_key),
                    # `class_to_explain` is ignored when predict_fn output is a scalar.
                    class_to_explain=class_to_explain,
                    num_samples=num_samples,
                    tokenizer=str.split,
                    mask_token=mask_string,
                    kernel=functools.partial(lime.exponential_kernel,
                                             kernel_width=kernel_width),
                    seed=seed)

                # Turn the LIME explanation into a list following original word order.
                scores = explanation.feature_importance
                # TODO(lit-dev): Move score normalization to the UI.
                scores = citrus_util.normalize_scores(scores)
                result[text_key] = dtypes.TokenSalience(
                    input_string.split(), scores)

            all_results.append(result)

        return all_results
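_predict_fn in the LIME example above is referenced through functools.partial but not shown. A minimal hypothetical version (an assumption, not LIT's implementation; it ignores pred_type_info, which the real code presumably uses to post-process sparse or regression outputs) substitutes each perturbed sentence into the original example and returns one prediction per sentence:

from typing import List

import numpy as np

def _predict_fn(sentences: List[str], model, original_example, text_key,
                pred_key, pred_type_info=None) -> np.ndarray:
    """Scores perturbed sentences for LIME; returns [len(sentences), ...]."""
    examples = []
    for sentence in sentences:
        example = dict(original_example)  # keep all other fields constant
        example[text_key] = sentence
        examples.append(example)
    outputs = list(model.predict(examples))
    # MulticlassPreds yields a probability vector per example; RegressionScore
    # yields a scalar. pred_type_info is unused in this simplified sketch.
    return np.array([output[pred_key] for output in outputs])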
Example #27
    def run(
            self,
            inputs: List[JsonDict],
            model: lit_model.Model,
            dataset: lit_dataset.Dataset,
            model_outputs: Optional[List[JsonDict]] = None,
            config: Optional[JsonDict] = None,
            kernel_width: int = 25,  # TODO(lit-dev): make configurable in UI.
            mask_string: str = '[MASK]',  # TODO(lit-dev): make configurable in UI.
            num_samples: int = 256,  # TODO(lit-dev): make configurable in UI.
    ) -> Optional[List[JsonDict]]:
        """Run this component, given a model and input(s)."""

        # Find keys of input (text) segments to explain.
        # Search in the input spec, since it's only useful to look at ones that are
        # used by the model.
        text_keys = utils.find_spec_keys(model.input_spec(), types.TextSegment)
        if not text_keys:
            logging.warning('LIME requires text inputs.')
            return None
        logging.info('Found text fields for LIME attribution: %s',
                     str(text_keys))

        # Find the key of output probabilities field(s).
        pred_keys = utils.find_spec_keys(model.output_spec(),
                                         types.MulticlassPreds)
        if not pred_keys:
            logging.warning(
                'LIME did not find a multi-class predictions field.')
            return None

        pred_key = pred_keys[0]  # TODO(lit-dev): configure which prob field to use.
        pred_spec = cast(types.MulticlassPreds, model.output_spec()[pred_key])
        label_names = pred_spec.vocab

        # Create a LIME text explainer instance.
        explainer = lime_text.LimeTextExplainer(
            class_names=label_names,
            split_expression=str.split,
            kernel_width=kernel_width,
            mask_string=mask_string,  # This is the string used to mask words.
            bow=False
        )  # bow=False masks inputs, instead of deleting them entirely.

        all_results = []

        # Explain each input.
        for input_ in inputs:
            # Dict[field name -> interpretations]
            result = {}

            # Explain each text segment in the input, keeping the others constant.
            for text_key in text_keys:
                input_string = input_[text_key]
                logging.info('Explaining: %s', input_string)

                # Use the number of words as the number of features.
                num_features = len(input_string.split())

                def _predict_proba(strings: List[Text]):
                    """Given raw strings, return probabilities. Used by `explainer`."""
                    input_examples = [
                        new_example(input_, text_key, s) for s in strings
                    ]
                    model_outputs = model.predict(input_examples)
                    probs = np.array(
                        [output[pred_key] for output in model_outputs])
                    return probs  # <float32>[len(strings), num_labels]

                # Perturbs the input string, gets model predictions, fits linear model.
                explanation = explainer.explain_instance(
                    input_string,
                    _predict_proba,
                    num_features=num_features,
                    num_samples=num_samples)

                # Turn the LIME explanation into a list following original word order.
                scores = explanation_to_array(explanation)
                result[text_key] = dtypes.SalienceMap(input_string.split(),
                                                      scores)

            all_results.append(result)

        return all_results
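new_example and explanation_to_array in the second LIME run above are also used but not defined in the snippet. Hypothetical versions are sketched below, assuming the lime_text Explanation API (available_labels / as_map) and bow=False, so that feature ids index token positions:

from typing import List

import numpy as np

def new_example(original_example: dict, text_key: str, new_value: str) -> dict:
    """Copies the example, replacing only the text field being explained."""
    example = dict(original_example)
    example[text_key] = new_value
    return example

def explanation_to_array(explanation) -> np.ndarray:
    """Flattens a lime_text explanation into scores in original word order."""
    label = explanation.available_labels()[0]
    # as_map() gives {label: [(feature_id, weight), ...]}; with bow=False the
    # feature id is the token position, so sorting by it restores word order.
    pairs = sorted(explanation.as_map()[label], key=lambda pair: pair[0])
    return np.array([weight for _, weight in pairs], dtype=np.float32)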
Example #28
  def run_with_metadata(
      self,
      indexed_inputs: Sequence[IndexedInput],
      model: lit_model.Model,
      dataset: lit_dataset.IndexedDataset,
      model_outputs: Optional[List[JsonDict]] = None,
      config: Optional[JsonDict] = None) -> Optional[List[JsonDict]]:
    """Runs the TCAV method given the params in the inputs and config.

    Args:
      indexed_inputs: all examples in the dataset, in the indexed input format.
      model: the model being explained.
      dataset: the dataset which the current examples belong to.
      model_outputs: optional model outputs from calling model.predict(inputs).
      config: a config which should specify:
        {
          'concept_set_ids': [list of ids to use in concept set]
          'class_to_explain': [gradient class to explain],
          'grad_layer': [the Gradient field key of the layer to explain],
          'random_state': [an optional seed to make outputs deterministic]
        }

    Returns:
      A JsonDict containing the TCAV scores, directional derivatives,
      statistical test p-values, and LM accuracies.
    """
    config = TCAVConfig(**config)
    # TODO(b/171513556): get these from the Dataset object once indices are
    # available there.
    dataset_examples = indexed_inputs

    # Get this layer's output spec keys for gradients and embeddings.
    grad_layer = config.grad_layer
    output_spec = model.output_spec()
    emb_layer = cast(types.Gradients, output_spec[grad_layer]).grad_for

    # Get the class that the gradients were computed for.
    grad_class_key = cast(types.Gradients, output_spec[grad_layer]).grad_target

    ids_set = set(config.concept_set_ids)
    concept_set = [ex for ex in dataset_examples if ex['id'] in ids_set]
    non_concept_set = [ex for ex in dataset_examples if ex['id'] not in ids_set]

    # Get outputs using model.predict().
    dataset_outputs = list(model.predict_with_metadata(dataset_examples))

    def _subsample(examples, n):
      return random.sample(examples, n) if n < len(examples) else examples

    concept_outputs = list(model.predict_with_metadata(concept_set))
    non_concept_outputs = list(model.predict_with_metadata(non_concept_set))

    concept_results = []
    # If there are more concept set examples than non-concept set examples, we
    # use random splits of the concept examples as the concept set and use the
    # remainder of the dataset as the comparison set. Otherwise, we use random
    # splits of the non-concept examples as the comparison set.
    n = min(len(concept_set), len(non_concept_set))

    # If there are an equal number of concept and non-concept examples, we
    # decrease n by one so that we also sample a different set in each TCAV run.
    if len(concept_set) == len(non_concept_set):
      n -= 1
    for _ in range(NUM_SPLITS):
      concept_split_outputs = _subsample(concept_outputs, n)
      comparison_split_outputs = _subsample(non_concept_outputs, n)
      concept_results.append(self._run_tcav(concept_split_outputs,
                                            comparison_split_outputs,
                                            dataset_outputs,
                                            config.class_to_explain,
                                            emb_layer,
                                            grad_layer,
                                            grad_class_key,
                                            config.test_size,
                                            config.random_state))

    cav_scores = [res['score'] for res in concept_results]
    p_val = self.hyp_test(cav_scores)

    # Get index of CAV result with the highest accuracy.
    accuracies = [res['accuracy'] for res in concept_results]
    index = np.argmax(accuracies)

    # Many CAVs are trained and their scores are used in the statistical test
    # that produces the p-value. The values of the CAV with the highest
    # accuracy are returned.
    results = {'result': concept_results[index], 'p_val': p_val}
    return [results]
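self.hyp_test in the TCAV example above is not shown. One plausible implementation (an assumption, not necessarily LIT's exact code) is a two-sided one-sample t-test of the NUM_SPLITS TCAV scores against 0.5, the value expected when the concept has no directional effect on the prediction:

from typing import Sequence

import scipy.stats

def hyp_test(scores: Sequence[float]) -> float:
    """Returns the p-value of a t-test of TCAV scores against a null mean of 0.5."""
    _, p_val = scipy.stats.ttest_1samp(scores, 0.5)
    return float(p_val)

print(hyp_test([0.61, 0.58, 0.66, 0.72, 0.64]))  # illustrative scores; small p-value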