Example #1
    def run(self,
            inputs: List[JsonDict],
            model: lit_model.Model,
            dataset: lit_dataset.Dataset,
            model_outputs: Optional[List[JsonDict]] = None,
            config: Optional[JsonDict] = None) -> Optional[List[JsonDict]]:
        """Run this component, given a model and input(s)."""
        # Find gradient fields to interpret
        output_spec = model.output_spec()
        grad_fields = self.find_fields(output_spec)
        logging.info('Found fields for gradient attribution: %s',
                     str(grad_fields))
        if len(grad_fields) == 0:  # pylint: disable=g-explicit-length-test
            return None

        # Run model, if needed.
        if model_outputs is None:
            model_outputs = list(model.predict(inputs))
        assert len(model_outputs) == len(inputs)

        all_results = []
        for o in model_outputs:
            # Dict[field name -> interpretations]
            result = {}
            for grad_field in grad_fields:
                token_field = cast(types.TokenGradients,
                                   output_spec[grad_field]).align
                tokens = o[token_field]
                scores = self._interpret(o[grad_field], tokens)
                result[grad_field] = dtypes.SalienceMap(tokens, scores)
            all_results.append(result)

        return all_results
Example #2
File: t5.py (Project: PAIR-code/lit)
def validate_t5_model(model: lit_model.Model) -> lit_model.Model:
    """Validate that a given model looks like a T5 model.

  This checks the model spec at runtime; it is intended to be used before server
  start, such as in the __init__() method of a wrapper class.

  Args:
    model: a LIT model

  Returns:
    model: the same model

  Raises:
    AssertionError: if the model's spec does not match that expected for a T5
    model.
  """
    # Check inputs
    ispec = model.input_spec()
    assert "input_text" in ispec
    assert isinstance(ispec["input_text"], lit_types.TextSegment)
    if "target_text" in ispec:
        assert isinstance(ispec["target_text"], lit_types.TextSegment)

    # Check outputs
    ospec = model.output_spec()
    assert "output_text" in ospec
    assert isinstance(
        ospec["output_text"],
        (lit_types.GeneratedText, lit_types.GeneratedTextCandidates))
    assert ospec["output_text"].parent == "target_text"

    return model
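The docstring above notes that this validator is meant to run before server start, for example inside a wrapper's __init__(). A minimal sketch of that pattern follows; the T5Wrapper class name and its pass-through methods are illustrative, not part of the example.

class T5Wrapper(lit_model.Model):
    """Hypothetical wrapper that validates the inner model on construction."""

    def __init__(self, wrapped: lit_model.Model):
        # Raises AssertionError immediately if the spec is not T5-shaped.
        self._wrapped = validate_t5_model(wrapped)

    def input_spec(self):
        return self._wrapped.input_spec()

    def output_spec(self):
        return self._wrapped.output_spec()

    def predict_minibatch(self, inputs):
        return self._wrapped.predict_minibatch(inputs)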
Example #3
    def run(self,
            inputs: List[JsonDict],
            model: lit_model.Model,
            dataset: lit_dataset.Dataset,
            model_outputs: Optional[List[JsonDict]] = None,
            config: Optional[JsonDict] = None) -> Optional[List[JsonDict]]:
        """Run this component, given a model and input(s)."""
        # Find gradient fields to interpret
        input_spec = model.input_spec()
        output_spec = model.output_spec()
        grad_fields = self.find_fields(input_spec, output_spec)
        logging.info('Found fields for integrated gradients: %s',
                     str(grad_fields))
        if len(grad_fields) == 0:  # pylint: disable=g-explicit-length-test
            return None

        # Run model, if needed.
        if model_outputs is None:
            model_outputs = list(model.predict(inputs))

        all_results = []
        for model_output, model_input in zip(model_outputs, inputs):
            result = self.get_salience_result(model_input, model, model_output,
                                              grad_fields)
            all_results.append(result)
        return all_results
Example #4
def _get_results_from_model(model: lit_model.Model, data: lit_dataset.Dataset, notebook: bool) -> List[Dict]:
    tqdm = notebook_tqdm if notebook else normal_tqdm

    batch_size = model.max_minibatch_size()
    results = []
    for i in tqdm(range(0, len(data), batch_size), desc='processing batches'):
        batch_examples = data.examples[i:i + batch_size]
        results.extend(model.predict_minibatch(batch_examples))

    return results
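A hedged usage sketch for the batching helper above; the two tqdm aliases it expects are assumptions about the surrounding imports, which are not shown.

# Assumed imports for the progress-bar aliases used above.
from tqdm import tqdm as normal_tqdm
from tqdm.notebook import tqdm as notebook_tqdm

results = _get_results_from_model(model, data=dataset, notebook=False)
assert len(results) == len(dataset)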
Example #5
    def run_with_metadata(
            self,
            indexed_inputs: Sequence[IndexedInput],
            model: lit_model.Model,
            dataset: lit_dataset.IndexedDataset,
            model_outputs: Optional[List[JsonDict]] = None,
            config: Optional[JsonDict] = None) -> Optional[List[JsonDict]]:
        """Finds the nearest neighbors of the example specified in the config.

    Args:
      indexed_inputs: the dataset example to find nearest neighbors for.
      model: the model being explained.
      dataset: the dataset which the current examples belong to.
      model_outputs: optional model outputs from calling model.predict(inputs).
      config: a config which should specify:
        {
          'num_neighbors': [the number of nearest neighbors to return]
          'dataset_name': [the name of the dataset (used for caching)]
          'embedding_name': [the name of the embedding field to use]
        }

    Returns:
      A JsonDict containing a list of the num_neighbors nearest neighbors,
      where each has the example id and distance from the main example.
    """
        config = NearestNeighborsConfig(**config)

        dataset_outputs = list(
            model.predict_with_metadata(dataset.indexed_examples,
                                        dataset_name=config.dataset_name))

        example_outputs = list(
            model.predict_with_metadata(indexed_inputs,
                                        dataset_name=config.dataset_name))
        # TODO(lit-dev): Add support for selecting nearest neighbors of a set.
        if len(example_outputs) != 1:
            raise ValueError('More than one selected example was passed in.')
        example_output = example_outputs[0]

        # <float32>[emb_size]
        dataset_embs = [
            output[config.embedding_name] for output in dataset_outputs
        ]
        example_embs = [example_output[config.embedding_name]]
        distances = distance.cdist(example_embs, dataset_embs)[0]
        sorted_indices = np.argsort(distances)
        k = config.num_neighbors
        k_nearest_neighbors = [{
            'id': dataset.indexed_examples[original_index]['id'],
            'nn_distance': distances[original_index]
        } for original_index in sorted_indices[:k]]

        return [{'nearest_neighbors': k_nearest_neighbors}]
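The config keys in the docstring map directly onto NearestNeighborsConfig. A minimal invocation sketch, assuming an already-constructed interpreter, model, and indexed dataset; the field and dataset names are illustrative.

config = {
    'num_neighbors': 5,
    'dataset_name': 'sst_dev',     # used for caching; illustrative name
    'embedding_name': 'cls_emb',   # an embedding field in the model's output spec
}
result = nn_interpreter.run_with_metadata(
    indexed_inputs=[selected_example], model=model, dataset=dataset,
    config=config)
# result == [{'nearest_neighbors': [{'id': ..., 'nn_distance': ...}, ...]}]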
Example #6
    def _is_flip(
            self,
            model: lit_model.Model,
            cf_example: JsonDict,
            orig_output: JsonDict,
            pred_key: Text,
            regression_thresh: Optional[float] = None) -> Tuple[bool, Any]:

        cf_output = list(model.predict([cf_example]))[0]
        feature_predicted_value = cf_output[pred_key]
        return cf_utils.is_prediction_flip(
            cf_output=cf_output,
            orig_output=orig_output,
            output_spec=model.output_spec(),
            pred_key=pred_key,
            regression_thresh=regression_thresh), feature_predicted_value
Example #7
    def run(self,
            inputs: List[JsonDict],
            model: lit_model.Model,
            dataset: lit_dataset.Dataset,
            model_outputs: Optional[List[JsonDict]] = None,
            config: Optional[JsonDict] = None) -> Optional[List[JsonDict]]:
        """Run this component, given a model and input(s)."""
        if not inputs: return

        # Find keys of input (text) segments to explain.
        # Search in the input spec, since it's only useful to look at ones that are
        # used by the model.
        text_keys = utils.find_spec_keys(model.input_spec(), types.TextSegment)
        if not text_keys:
            logging.warning('LEMON requires text inputs.')
            return None
        logging.info('Found text fields for LEMON attribution: %s',
                     str(text_keys))

        pred_key = config['pred_key']
        output_probs = np.array([output[pred_key] for output in model_outputs])

        # Explain the input given counterfactuals.

        # Dict[field name -> interpretations]
        result = {}

        # Explain each text segment in the input, keeping the others constant.
        for text_key in text_keys:
            sentences = [item[text_key] for item in inputs]
            input_to_prediction = dict(zip(sentences, output_probs))

            input_string = sentences[0]
            counterfactuals = sentences[1:]

            # Remove duplicate counterfactuals.
            counterfactuals = list(set(counterfactuals))

            logging.info('Explaining: %s', input_string)

            predict_proba = make_predict_fn(input_to_prediction)

            # Perturbs the input string, gets model predictions, fits linear model.
            explanation = lemon.explain(
                input_string,
                counterfactuals,
                predict_proba,
                class_to_explain=config['class_to_explain'],
                lowercase_tokens=config['lowercase_tokens'])

            scores = np.array(explanation.feature_importance)

            # Normalize feature values.
            scores = citrus_utils.normalize_scores(scores)

            result[text_key] = dtypes.TokenSalience(input_string.split(),
                                                    scores)

        return [result]
Example #8
File: metrics.py (Project: PAIR-code/lit)
    def run_with_metadata(self,
                          indexed_inputs: Sequence[IndexedInput],
                          model: lit_model.Model,
                          dataset: lit_dataset.IndexedDataset,
                          model_outputs: Optional[List[JsonDict]] = None,
                          config: Optional[JsonDict] = None) -> List[JsonDict]:
        if model_outputs is None:
            model_outputs = list(model.predict_with_metadata(indexed_inputs))

        # TODO(lit-team): pre-compute this mapping in constructor?
        # This would require passing a model name to this function so we can
        # reference a pre-computed list.
        spec = model.spec()
        field_map = map_pred_keys(dataset.spec(), spec.output,
                                  self.is_compatible)
        ret = []
        for pred_key, label_key in field_map.items():
            # Extract fields
            labels = [ex['data'][label_key] for ex in indexed_inputs]
            preds = [mo[pred_key] for mo in model_outputs]
            indices = [ex['id'] for ex in indexed_inputs]
            metas = [ex.get('meta', {}) for ex in indexed_inputs]
            # Compute metrics, as dict(str -> float)
            metrics = self.compute_with_metadata(
                labels,
                preds,
                label_spec=dataset.spec()[label_key],
                pred_spec=spec.output[pred_key],
                indices=indices,
                metas=metas,
                config=config.get(pred_key) if config else None)
            # NaN is not a valid JSON value, so replace with None which will be
            # serialized as null.
            # TODO(lit-team): move this logic into serialize.py somewhere instead?
            metrics = {
                k: (v if not np.isnan(v) else None)
                for k, v in metrics.items()
            }
            # Format for frontend.
            ret.append({
                'pred_key': pred_key,
                'label_key': label_key,
                'metrics': metrics
            })
        return ret
Example #9
 def _get_embedding(self, example: JsonDict, model: lit_model.Model,
                    dataset: lit_dataset.IndexedDataset,
                    embedding_name: str, dataset_name: str):
     """Calls the model on the example to get the embedding."""
     model_input = dataset.index_inputs([example])
     model_output = model.predict_with_metadata(model_input,
                                                dataset_name=dataset_name)
     embedding = list(model_output)[0][embedding_name]
     return embedding
Example #10
    def run_with_metadata(
            self,
            indexed_inputs: Sequence[IndexedInput],
            model: lit_model.Model,
            dataset: lit_dataset.IndexedDataset,
            model_outputs: Optional[List[JsonDict]] = None,
            config: Optional[JsonDict] = None) -> Optional[JsonDict]:
        """Run this component, given a model and input(s).

    Args:
      indexed_inputs: Inputs to cluster.
      model: Model that provides salience maps.
      dataset: Dataset to compute salience maps for.
      model_outputs: Precomputed model outputs.
      config: Config for clustering and salience computation

    Returns:
      Dict with 2 keys:
        `CLUSTER_ID_KEY`: Contains the cluster assignments. One cluster id per
          dataset example.
        `REPRESENTATION_KEY`: Contains the representations of all examples in
          the dataset that were used in the clustering.
    """
        config = config or {}
        # Find gradient fields to interpret
        grad_fields = self.find_fields(model.output_spec())
        token_saliencies = self.salience_mappers[
            config['salience_mapper']].run_with_metadata(
                indexed_inputs, model, dataset, model_outputs, config)

        if not token_saliencies:
            return None

        vocab = self._build_vocab(token_saliencies)
        representations = self._compute_fixed_length_representation(
            token_saliencies, vocab)

        cluster_ids = {}
        grad_field_to_representations = {}

        for grad_field in grad_fields:
            weight_matrix = np.vstack([
                representation[grad_field]
                for representation in representations
            ])
            self.kmeans[grad_field] = cluster.KMeans(n_clusters=config.get(
                N_CLUSTERS_KEY,
                self.config_spec()[N_CLUSTERS_KEY].default))
            cluster_ids[grad_field] = self.kmeans[grad_field].fit_predict(
                weight_matrix).tolist()
            grad_field_to_representations[grad_field] = weight_matrix

        return {
            CLUSTER_ID_KEY: cluster_ids,
            REPRESENTATION_KEY: grad_field_to_representations
        }
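The clustering component above reads two config keys. A hedged sketch of a call, with the salience mapper name and cluster count purely illustrative:

config = {
    'salience_mapper': 'Grad L2 Norm',  # must be a key of self.salience_mappers
    N_CLUSTERS_KEY: 5,                  # optional; falls back to the config_spec() default
}
clusters = clustering_component.run_with_metadata(
    indexed_inputs, model, dataset, config=config)
# clusters[CLUSTER_ID_KEY][grad_field] holds one cluster id per example.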
Example #11
    def run(self,
            inputs: List[JsonDict],
            model: lit_model.Model,
            dataset: lit_dataset.Dataset,
            model_outputs: Optional[List[JsonDict]] = None,
            config: Optional[JsonDict] = None) -> Optional[List[JsonDict]]:
        """Run this component, given a model and input(s)."""
        config = config or {}
        class_to_explain = config.get(CLASS_KEY,
                                      self.config_spec()[CLASS_KEY].default)
        interpolation_steps = int(
            config.get(INTERPOLATION_KEY,
                       self.config_spec()[INTERPOLATION_KEY].default))
        normalization = config.get(
            NORMALIZATION_KEY,
            self.config_spec()[NORMALIZATION_KEY].default)

        # Find gradient fields to interpret
        input_spec = model.input_spec()
        output_spec = model.output_spec()
        grad_fields = self.find_fields(input_spec, output_spec)
        logging.info('Found fields for integrated gradients: %s',
                     str(grad_fields))
        if len(grad_fields) == 0:  # pylint: disable=g-explicit-length-test
            return None

        # Run model, if needed.
        if model_outputs is None:
            model_outputs = list(model.predict(inputs))

        all_results = []
        for model_output, model_input in zip(model_outputs, inputs):
            result = self.get_salience_result(model_input, model,
                                              interpolation_steps,
                                              normalization, class_to_explain,
                                              model_output, grad_fields)
            all_results.append(result)
        return all_results
Example #12
File: metrics.py (Project: PAIR-code/lit)
    def run(self,
            inputs: List[JsonDict],
            model: lit_model.Model,
            dataset: lit_dataset.Dataset,
            model_outputs: Optional[List[JsonDict]] = None,
            config: Optional[JsonDict] = None):
        if model_outputs is None:
            model_outputs = list(model.predict(inputs))

        spec = model.spec()
        field_map = map_pred_keys(dataset.spec(), spec.output,
                                  self.is_compatible)
        ret = []
        for pred_key, label_key in field_map.items():
            # Extract fields
            labels = [ex[label_key] for ex in inputs]
            preds = [mo[pred_key] for mo in model_outputs]
            # Compute metrics, as dict(str -> float)
            metrics = self.compute(
                labels,
                preds,
                label_spec=dataset.spec()[label_key],
                pred_spec=spec.output[pred_key],
                config=config.get(pred_key) if config else None)
            # NaN is not a valid JSON value, so replace with None which will be
            # serialized as null.
            # TODO(lit-team): move this logic into serialize.py somewhere instead?
            metrics = {
                k: (v if not np.isnan(v) else None)
                for k, v in metrics.items()
            }
            # Format for frontend.
            ret.append({
                'pred_key': pred_key,
                'label_key': label_key,
                'metrics': metrics
            })
        return ret
Example #13
    def _filter_ds_examples(
            self,
            dataset: lit_dataset.IndexedDataset,
            dataset_name: Text,
            model: lit_model.Model,
            reference_output: JsonDict,
            pred_key: Text,
            regression_thresh: Optional[float] = None) -> List[JsonDict]:
        """Reads all dataset examples and returns only those that are flips."""
        if not isinstance(dataset, lit_dataset.IndexedDataset):
            raise ValueError(
                'Only indexed datasets are currently supported by the '
                'TabularMTC generator.')

        indexed_examples = list(dataset.indexed_examples)
        filtered_examples = []
        preds = model.predict_with_metadata(indexed_examples,
                                            dataset_name=dataset_name)

        # Find all DS examples that are flips with respect to the reference example.
        for indexed_example, pred in zip(indexed_examples, preds):
            flip = cf_utils.is_prediction_flip(
                cf_output=pred,
                orig_output=reference_output,
                output_spec=model.output_spec(),
                pred_key=pred_key,
                regression_thresh=regression_thresh)
            if flip:
                candidate_example = indexed_example['data'].copy()
                self._find_dataset_parent_and_set(
                    model_output_spec=model.output_spec(),
                    pred_key=pred_key,
                    dataset_spec=dataset.spec(),
                    example=candidate_example,
                    predicted_value=pred[pred_key])
                filtered_examples.append(candidate_example)
        return filtered_examples
Example #14
  def _run(self,
           model: lit_model.Model,
           indexed_inputs: List[JsonDict],
           model_outputs: Optional[List[JsonDict]] = None,
           do_fit=False):
    # Run model, if needed.
    if model_outputs is None:
      model_outputs = list(model.predict(indexed_inputs))
    assert len(model_outputs) == len(indexed_inputs)

    converted_inputs = list(
        map(self.convert_input, indexed_inputs, model_outputs))
    if do_fit:
      return self._projector.fit_transform_with_metadata(
          converted_inputs, dataset_name="")
    else:
      return self._projector.predict_with_metadata(
          converted_inputs, dataset_name="")
Example #15
File: metrics.py (Project: PAIR-code/lit)
 def run(self,
         inputs: List[JsonDict],
         model: lit_model.Model,
         dataset: lit_dataset.Dataset,
         model_outputs: Optional[List[JsonDict]] = None,
         config: Optional[JsonDict] = None):
     # Get margin for each input for each pred key and add them to a config dict
     # to pass to the wrapped metrics.
     field_map = map_pred_keys(dataset.spec(),
                               model.spec().output, self.is_compatible)
     margin_config = {}
     for pred_key in field_map:
         field_config = config.get(pred_key) if config else None
         margins = [
             get_margin_for_input(field_config, inp) for inp in inputs
         ]
         margin_config[pred_key] = margins
     return self._metrics.run(inputs, model, dataset, model_outputs,
                              margin_config)
Example #16
  def build(cls,
            inputs: List[JsonDict],
            encoder: lit_model.Model,
            edge_field: str,
            embs_field: str,
            offset_field: str,
            progress=lambda x: x,
            verbose=False):
    """Run encoder and extract span representations for coreference.

    'encoder' should be a model returning one TokenEmbeddings field,
    from which span features will be extracted, as well as a TokenOffsets field
    which maps input tokens to output tokens.

    The returned dataset will contain one example for each label in the inputs'
    EdgeLabels field.

    Args:
      inputs: input Dataset
      encoder: encoder model, compatible with inputs
      edge_field: name of edge field in data
      embs_field: name of embeddings field in model output
      offset_field: name of offset field in model output
      progress: optional pass-through progress indicator
      verbose: if true, print estimated memory usage

    Returns:
      EdgeFeaturesDataset with extracted span representations
    """
    examples = []
    encoder_outputs = progress(encoder.predict(inputs))
    for i, output in enumerate(encoder_outputs):
      exs = _make_probe_inputs(
          inputs[i][edge_field],
          output[embs_field],
          output[offset_field],
          src_idx=i)
      examples.extend(exs)
      if verbose and i == 10:
        _estimate_memory_needs(inputs, edge_field, examples[0])

    return cls(examples)
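A hedged usage sketch for the classmethod above; the dataset, encoder, and field names are placeholders chosen to match the documented arguments, and the class name is taken from the docstring's return description.

edge_dataset = EdgeFeaturesDataset.build(
    inputs=coref_data.examples,       # examples with an EdgeLabels field
    encoder=encoder_model,            # returns TokenEmbeddings + TokenOffsets
    edge_field='coref',
    embs_field='tokens_emb',
    offset_field='token_offsets',
    progress=tqdm.tqdm,               # optional pass-through progress indicator
    verbose=True)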
Example #17
 def _generate_leave_one_out_ablation_score(
         self, example: JsonDict, model: lit_model.Model, input_spec: Spec,
         output_spec: Spec, orig_output: JsonDict, pred_key: Text,
         fields_to_ablate: List[str]) -> List[Tuple[str, int, float]]:
     # Returns a list of triples: field, token_idx and leave-one-out score.
     ret = []
     for field in input_spec.keys():
         if field not in example or field not in fields_to_ablate:
             continue
         tokens = self._get_tokens(example, input_spec, field)
         cfs = [
             self._create_cf(example, input_spec, [(field, i)])
             for i in range(len(tokens))
         ]
         cf_outputs = model.predict(cfs)
         for i, cf_output in enumerate(cf_outputs):
             loo_score = cf_utils.prediction_difference(
                 cf_output, orig_output, output_spec, pred_key)
             ret.append((field, i, loo_score))
     return ret
Example #18
 def _train_instance(self, model: lit_model.Model,
                     dataset: lit_dataset.IndexedDataset, config: JsonDict,
                     name: Text) -> ProjectionInterpreter:
     # Ignore pytype warning about abstract methods, since this should always
     # be a subclass of ProjectorModel which has these implemented.
     projector = self._model_factory(**config.get("proj_kw", {}))  # pytype: disable=not-instantiable
     train_inputs = dataset.indexed_examples
     # TODO(lit-dev): remove 'dataset_name' from caching logic so we don't need
     # to track it here or elsewhere.
     train_outputs = list(
         model.predict_with_metadata(
             train_inputs, dataset_name=config.get("dataset_name")))
     logging.info("Creating new projection instance on %d points",
                  len(train_inputs))
     return ProjectionInterpreter(model,
                                  train_inputs,
                                  train_outputs,
                                  projector=projector,
                                  field_name=config["field_name"],
                                  name=name)
Example #19
  def run(self,
          inputs: List[JsonDict],
          model: lit_model.Model,
          dataset: lit_dataset.Dataset,
          model_outputs: Optional[List[JsonDict]] = None,
          config: Optional[JsonDict] = None) -> Optional[List[JsonDict]]:
    del dataset
    del config

    field_map = self.find_fields(model)

    # Run model, if needed.
    if model_outputs is None:
      model_outputs = list(model.predict(inputs))
    assert len(model_outputs) == len(inputs)

    return [
        self._run_single(ex, mo, field_map)
        for ex, mo in zip(inputs, model_outputs)
    ]
Example #20
 def _train_instance(self, model: lit_model.Model,
                     dataset: lit_dataset.Dataset, config: Dict[Text, Any],
                     name: Text) -> ProjectionInterpreter:
   # Ignore pytype warning about abstract methods, since this should always
   # be a subclass of ProjectorModel which has these implemented.
   projector = self._model_factory(**config.get("proj_kw", {}))  # pytype: disable=not-instantiable
   # TODO(lit-dev): recomputing hashes here is a bit wasteful - consider
   # creating an 'IndexedDataset' class in the server, and passing that
   # around so that components can access IndexedInputs directly.
   train_inputs = caching.add_hashes_to_input(dataset.examples)
   # TODO(lit-dev): remove 'dataset_name' from caching logic so we don't need
   # to track it here or elsewhere.
   train_outputs = list(
       model.predict_with_metadata(
           train_inputs, dataset_name=config.get("dataset_name")))
   logging.info("Creating new projection instance on %d points",
                len(train_inputs))
   return ProjectionInterpreter(
       model,
       train_inputs,
       train_outputs,
       projector=projector,
       field_name=config["field_name"],
       name=name)
Example #21
    def get_salience_result(self, model_input: JsonDict,
                            model: lit_model.Model, model_output: JsonDict,
                            grad_fields: List[Text]):
        result = {}

        output_spec = model.output_spec()
        # We ensure that the embedding and gradient class fields are present in the
        # model's input spec in find_fields().
        embeddings_fields = [
            cast(types.TokenGradients, output_spec[grad_field]).grad_for
            for grad_field in grad_fields
        ]

        # The gradient class input is used to specify the target class of the
        # gradient calculation (if unspecified, this option defaults to the argmax,
        # which could flip between interpolated inputs).
        grad_class_key = cast(types.TokenGradients,
                              output_spec[grad_fields[0]]).grad_target
        # TODO(b/168042999): Add option to specify the class to explain in the UI.
        grad_class = model_output[grad_class_key]

        interpolated_inputs = {}
        all_embeddings = []
        all_baselines = []
        for embed_field in embeddings_fields:
            # <float32>[num_tokens, emb_size]
            embeddings = np.array(model_output[embed_field])
            all_embeddings.append(embeddings)

            # Starts with baseline of zeros. <float32>[num_tokens, emb_size]
            baseline = self.get_baseline(embeddings)
            all_baselines.append(baseline)

            # Get interpolated inputs from baseline to original embedding.
            # <float32>[interpolation_steps, num_tokens, emb_size]
            interpolated_inputs[embed_field] = self.get_interpolated_inputs(
                baseline, embeddings, self.interpolation_steps)

        # Create model inputs and populate embedding field(s).
        inputs_with_embeds = []
        for i in range(self.interpolation_steps):
            input_copy = model_input.copy()
            # Interpolates embeddings for all inputs simultaneously.
            for embed_field in embeddings_fields:
                # <float32>[num_tokens, emb_size]
                input_copy[embed_field] = interpolated_inputs[embed_field][i]
                input_copy[grad_class_key] = grad_class

            inputs_with_embeds.append(input_copy)
        embed_outputs = model.predict(inputs_with_embeds)

        # Create list with concatenated gradients for each interpolate input.
        gradients = []
        for o in embed_outputs:
            # <float32>[total_num_tokens, emb_size]
            interp_gradients = np.concatenate(
                [o[field] for field in grad_fields])
            gradients.append(interp_gradients)
        # <float32>[interpolation_steps, total_num_tokens, emb_size]
        path_gradients = np.stack(gradients, axis=0)

        # Calculate integral
        # <float32>[total_num_tokens, emb_size]
        integral = self.estimate_integral(path_gradients)

        # <float32>[total_num_tokens, emb_size]
        concat_embeddings = np.concatenate(all_embeddings)

        # <float32>[total_num_tokens, emb_size]
        concat_baseline = np.concatenate(all_baselines)

        # <float32>[total_num_tokens, emb_size]
        integrated_gradients = integral * (np.array(concat_embeddings) -
                                           np.array(concat_baseline))
        # Dot product of integral values and (embeddings - baseline).
        # <float32>[total_num_tokens]
        attributions = np.sum(integrated_gradients, axis=-1)

        # TODO(b/168042999): Make normalization customizable in the UI.
        # <float32>[total_num_tokens]
        scores = citrus_utils.normalize_scores(attributions)

        for grad_field in grad_fields:
            # Format as salience map result.
            token_field = cast(types.TokenGradients,
                               output_spec[grad_field]).align
            tokens = model_output[token_field]

            # Only use the scores that correspond to the tokens in this grad_field.
            # The gradients for all input embeddings were concatenated in the order
            # of the grad fields, so they can be sliced out in the same order.
            sliced_scores = scores[:len(tokens)]  # <float32>[num_tokens in field]
            scores = scores[len(tokens):]  # <float32>[num_remaining_tokens]

            assert len(tokens) == len(sliced_scores)
            result[grad_field] = dtypes.SalienceMap(tokens, sliced_scores)
        return result
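The helpers referenced above (get_baseline, get_interpolated_inputs, estimate_integral) are not shown; below is a minimal sketch of what they would compute for standard integrated gradients, offered as an assumption rather than the actual implementation.

def get_interpolated_inputs(baseline, target, num_steps):
    # Linear path from baseline to target.
    # <float32>[num_steps, num_tokens, emb_size]
    alphas = np.linspace(0., 1., num_steps).reshape(-1, 1, 1)
    return baseline[np.newaxis, ...] + alphas * (target - baseline)[np.newaxis, ...]

def estimate_integral(path_gradients):
    # Riemann approximation of the path integral: average the gradients
    # collected along the interpolation path.
    # <float32>[total_num_tokens, emb_size]
    return np.mean(path_gradients, axis=0)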
Example #22
File: pdp.py (Project: PAIR-code/lit)
    def run(self,
            inputs: List[types.JsonDict],
            model: lit_model.Model,
            dataset: lit_dataset.Dataset,
            model_outputs: Optional[List[types.JsonDict]] = None,
            config: Optional[types.JsonDict] = None):
        """Create PDP chart info using provided inputs.

    Args:
      inputs: sequence of inputs, following model.input_spec()
      model: optional model to use to generate new examples.
      dataset: dataset which the current examples belong to.
      model_outputs: optional precomputed model outputs
      config: optional runtime config.

    Returns:
      a dict of alternate feature values to model outputs. The model
      outputs will be a number for regression models and a list of numbers for
      multiclass models.
    """

        pred_keys = utils.find_spec_keys(
            model.output_spec(),
            (types.MulticlassPreds, types.RegressionScore))
        if not pred_keys:
            logging.warning('PDP did not find any supported output fields.')
            return None

        assert 'feature' in config, 'No feature to test provided'
        feature = config['feature']
        provided_range = config['range'] if 'range' in config else []
        edited_outputs = {}
        for pred_key in pred_keys:
            edited_outputs[pred_key] = {}

        # If a range was provided, use that to create the possible values.
        vals_to_test = (np.linspace(provided_range[0], provided_range[1], 10)
                        if len(provided_range) == 2 else self.get_vals_to_test(
                            feature, dataset))

        # If no specific inputs provided, use the entire dataset.
        inputs_to_use = inputs if inputs else dataset.examples

        # For each alternate value for a given feature.
        for new_val in vals_to_test:
            # Create copies of all provided inputs with the value replaced.
            edited_inputs = []
            for inp in inputs_to_use:
                edited_input = copy.deepcopy(inp)
                edited_input[feature] = new_val
                edited_inputs.append(edited_input)

            # Run prediction on the altered inputs.
            outputs = list(model.predict(edited_inputs))

            # Store the mean of the prediction for the alternate value.
            for pred_key in pred_keys:
                numeric = isinstance(model.output_spec()[pred_key],
                                     types.RegressionScore)
                if numeric:
                    edited_outputs[pred_key][new_val] = np.mean(
                        [output[pred_key] for output in outputs])
                else:
                    edited_outputs[pred_key][new_val] = np.mean(
                        [output[pred_key] for output in outputs], axis=0)

        return edited_outputs
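A hedged example of the config this PDP component expects, based on the keys read above; 'feature' is required, while 'range' is optional and, when given, is sampled at 10 evenly spaced values.

pdp_config = {
    'feature': 'age',      # hypothetical scalar feature from the dataset spec
    'range': [18, 90],     # optional; otherwise get_vals_to_test() picks values
}
chart_data = pdp_interpreter.run(
    inputs=[], model=model, dataset=dataset, config=pdp_config)
# With empty inputs the whole dataset is used; chart_data maps each pred_key
# to {feature_value: mean prediction}.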
Example #23
    def run(
        self,
        inputs: List[JsonDict],
        model: lit_model.Model,
        dataset: lit_dataset.Dataset,
        model_outputs: Optional[List[JsonDict]] = None,
        config: Optional[JsonDict] = None,
    ) -> Optional[List[JsonDict]]:
        """Run this component, given a model and input(s)."""

        class_to_explain = int(config[CLASS_KEY]) if config else CLASS_DEFAULT
        kernel_width = int(
            config[KERNEL_WIDTH_KEY]) if config else KERNEL_WIDTH_DEFAULT
        mask_string = config[MASK_KEY] if config else MASK_DEFAULT
        num_samples = int(
            config[NUM_SAMPLES_KEY]) if config else NUM_SAMPLES_DEFAULT
        seed = config[SEED_KEY] if config else SEED_DEFAULT

        # Find keys of input (text) segments to explain.
        # Search in the input spec, since it's only useful to look at ones that are
        # used by the model.
        text_keys = utils.find_spec_keys(model.input_spec(), types.TextSegment)
        if not text_keys:
            logging.warning('LIME requires text inputs.')
            return None
        logging.info('Found text fields for LIME attribution: %s',
                     str(text_keys))

        # Find the key of output probabilities field(s).
        pred_keys = utils.find_spec_keys(
            model.output_spec(),
            (types.MulticlassPreds, types.RegressionScore))
        if not pred_keys:
            logging.warning('LIME did not find any supported output fields.')
            return None

        pred_key = pred_keys[0]  # TODO(lit-dev): configure which prob field to use.
        all_results = []

        # Explain each input.
        for input_ in inputs:
            # Dict[field name -> interpretations]
            result = {}
            predict_fn = functools.partial(_predict_fn,
                                           model=model,
                                           original_example=input_,
                                           pred_key=pred_key)

            # Explain each text segment in the input, keeping the others constant.
            for text_key in text_keys:
                input_string = input_[text_key]
                if not input_string:
                    logging.info('Could not explain empty string for %s',
                                 text_key)
                    continue
                logging.info('Explaining: %s', input_string)

                # Perturbs the input string, gets model predictions, fits linear model.
                explanation = lime.explain(
                    sentence=input_string,
                    predict_fn=functools.partial(predict_fn,
                                                 text_key=text_key),
                    # Index of the class to explain; `class_to_explain` is
                    # ignored when predict_fn output is a scalar.
                    class_to_explain=class_to_explain,
                    num_samples=num_samples,
                    tokenizer=str.split,
                    mask_token=mask_string,
                    kernel=functools.partial(lime.exponential_kernel,
                                             kernel_width=kernel_width),
                    seed=seed)

                # Turn the LIME explanation into a list following original word order.
                scores = explanation.feature_importance
                # TODO(lit-dev): Move score normalization to the UI.
                scores = citrus_util.normalize_scores(scores)
                result[text_key] = dtypes.SalienceMap(input_string.split(),
                                                      scores)

            all_results.append(result)

        return all_results
Example #24
 def is_compatible(self, model: lit_model.Model):
     text_keys = utils.find_spec_keys(model.input_spec(), types.TextSegment)
     pred_keys = utils.find_spec_keys(
         model.output_spec(),
         (types.MulticlassPreds, types.RegressionScore))
     return len(text_keys) and len(pred_keys)
Example #25
 def find_fields(self, model: lit_model.Model) -> Dict[str, List[str]]:
   src_fields = utils.find_spec_keys(model.input_spec(), types.TextSegment)
   gen_fields = utils.find_spec_keys(
       model.output_spec(),
       (types.GeneratedText, types.GeneratedTextCandidates))
   return {f: src_fields for f in gen_fields}
Example #26
 def is_compatible(self, model: lit_model.Model) -> bool:
     input_spec = model.input_spec()
     output_spec = model.output_spec()
     return find_supported_fields(input_spec, output_spec) is not None
Example #27
    def run(self,
            inputs: List[JsonDict],
            model: lit_model.Model,
            dataset: lit_dataset.Dataset,
            model_outputs: Optional[List[JsonDict]] = None,
            config: Optional[JsonDict] = None) -> Optional[List[JsonDict]]:
        """Runs the component, given a model and input(s)."""
        input_spec = model.input_spec()
        output_spec = model.output_spec()

        if not inputs:
            return []

        # Find all fields required for the interpretation.
        supported_fields = find_supported_fields(input_spec, output_spec)
        if supported_fields is None:
            return
        grad_field_key = supported_fields.grad_field_key
        image_field_key = supported_fields.image_field_key
        grad_target_field_key = supported_fields.grad_target_field_key
        preds_field_key = supported_fields.preds_field_key
        multiclass = grad_target_field_key is not None

        # Determine the shape of gradients by calling the model with a single input
        # and extracting the shape from the gradient output.
        model_output = list(model.predict([inputs[0]]))
        grad_shape = model_output[0][grad_field_key].shape

        # If it is a multiclass model, find the labels with respect to which we
        # should compute the gradients.
        if multiclass:
            # Get class labels.
            label_vocab = list(
                cast(types.MulticlassPreds,
                     output_spec[preds_field_key]).vocab)

            # Run the model in order to find the gradient target labels.
            outputs = list(model.predict(inputs))
            grad_target_labels = []
            for model_input, model_output in zip(inputs, outputs):
                if model_input.get(grad_target_field_key) is not None:
                    grad_target_labels.append(
                        model_input[grad_target_field_key])
                else:
                    max_idx = int(np.argmax(model_output[preds_field_key]))
                    grad_target_labels.append(label_vocab[max_idx])
        else:
            grad_target_labels = [None] * len(inputs)

        saliency_object = self.get_saliency_object()
        extra_saliency_params = self.get_extra_saliency_params(config)
        all_results = []
        for example, grad_target_label in zip(inputs, grad_target_labels):
            result = {}
            image_str = example[image_field_key]
            saliency_input = image_utils.convert_image_str_to_array(
                image_str=image_str, shape=grad_shape)
            call_model_func = get_call_model_func(
                model=model,
                model_input=example,
                image_field_key=image_field_key,
                grad_field_key=grad_field_key,
                grad_target_field_key=grad_target_field_key,
                grad_target_label=grad_target_label)
            attribution = self.make_saliency_call(
                saliency_object=saliency_object,
                x_value=saliency_input,
                call_model_function=call_model_func,
                extra_saliency_params=extra_saliency_params)
            if attribution.ndim == 3:
                attribution = attribution.sum(axis=2)
            viz_params = self.get_visualization_params()
            overlaid_image = image_utils.overlay_pixel_saliency(
                image_str, attribution, **viz_params)
            result[grad_field_key] = image_utils.convert_pil_to_image_str(
                overlaid_image)
            all_results.append(result)
        return all_results
Example #28
    def run_with_metadata(
            self,
            indexed_inputs: Sequence[IndexedInput],
            model: lit_model.Model,
            dataset: lit_dataset.IndexedDataset,
            model_outputs: Optional[List[JsonDict]] = None,
            config: Optional[JsonDict] = None) -> Optional[List[JsonDict]]:
        """Runs the TCAV method given the params in the inputs and config.

    Args:
      indexed_inputs: all examples in the dataset, in the indexed input format.
      model: the model being explained.
      dataset: the dataset which the current examples belong to.
      model_outputs: optional model outputs from calling model.predict(inputs).
      config: a config which should specify: {
          'concept_set_ids': [list of ids to use in concept set]
          'class_to_explain': [gradient class to explain],
          'grad_layer': [the Gradient field key of the layer to explain],
          'random_state': [an optional seed to make outputs deterministic]
          'dataset_name': [the name of the dataset (used for caching)]
          'test_size': [Percentage of the example set to use in the LM test set]
          'negative_set_ids': [optional list of ids to use as negative set] }

    Returns:
      A JsonDict containing the TCAV scores, directional derivatives,
      statistical test p-values, and LM accuracies.
    """
        config = TCAVConfig(**config)
        # TODO(b/171513556): get these from the Dataset object once indices are
        # available there.
        dataset_examples = indexed_inputs

        # Get this layer's output spec keys for gradients and embeddings.
        grad_layer = config.grad_layer
        output_spec = model.output_spec()
        emb_layer = cast(types.Gradients, output_spec[grad_layer]).grad_for

        # Get the class that the gradients were computed for.
        grad_class_key = cast(types.Gradients,
                              output_spec[grad_layer]).grad_target_field_key

        ids_set = set(config.concept_set_ids)
        concept_set = [ex for ex in dataset_examples if ex['id'] in ids_set]
        non_concept_set = [
            ex for ex in dataset_examples if ex['id'] not in ids_set
        ]

        # Get outputs using model.predict().
        dataset_outputs = list(
            model.predict_with_metadata(dataset_examples,
                                        dataset_name=config.dataset_name))

        if config.negative_set_ids:
            negative_ids_set = set(config.negative_set_ids)
            negative_set = [
                ex for ex in dataset_examples if ex['id'] in negative_ids_set
            ]
            return self._run_relative_tcav(grad_layer, emb_layer,
                                           grad_class_key, concept_set,
                                           negative_set, dataset_outputs,
                                           model, config)
        else:
            return self._run_default_tcav(grad_layer, emb_layer,
                                          grad_class_key, concept_set,
                                          non_concept_set, dataset_outputs,
                                          model, config)
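A hedged sketch of the config described in the docstring above; the example ids, the Gradients field key, and the dataset name are illustrative.

tcav_config = {
    'concept_set_ids': concept_ids,   # ids of examples forming the concept set
    'class_to_explain': '1',          # gradient class to explain
    'grad_layer': 'cls_grad',         # a Gradients key in the model's output spec
    'dataset_name': 'sst_dev',        # used for caching
    'random_state': 0,                # optional, for deterministic outputs
    'test_size': 0.33,                # fraction of examples held out for the LM test
}
tcav_results = tcav_interpreter.run_with_metadata(
    indexed_inputs, model, dataset, config=tcav_config)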
Example #29
    def run(
            self,
            inputs: List[JsonDict],
            model: lit_model.Model,
            dataset: lit_dataset.Dataset,
            model_outputs: Optional[List[JsonDict]] = None,
            config: Optional[JsonDict] = None,
            kernel_width: int = 25,  # TODO(lit-dev): make configurable in UI.
            mask_string: str = '[MASK]',  # TODO(lit-dev): make configurable in UI.
            num_samples: int = 256,  # TODO(lit-dev): make configurable in UI.
    ) -> Optional[List[JsonDict]]:
        """Run this component, given a model and input(s)."""

        # Find keys of input (text) segments to explain.
        # Search in the input spec, since it's only useful to look at ones that are
        # used by the model.
        text_keys = utils.find_spec_keys(model.input_spec(), types.TextSegment)
        if not text_keys:
            logging.warning('LIME requires text inputs.')
            return None
        logging.info('Found text fields for LIME attribution: %s',
                     str(text_keys))

        # Find the key of output probabilities field(s).
        pred_keys = utils.find_spec_keys(model.output_spec(),
                                         types.MulticlassPreds)
        if not pred_keys:
            logging.warning(
                'LIME did not find a multi-class predictions field.')
            return None

        pred_key = pred_keys[0]  # TODO(lit-dev): configure which prob field to use.
        pred_spec = cast(types.MulticlassPreds, model.output_spec()[pred_key])
        label_names = pred_spec.vocab

        # Create a LIME text explainer instance.
        explainer = lime_text.LimeTextExplainer(
            class_names=label_names,
            split_expression=str.split,
            kernel_width=kernel_width,
            mask_string=mask_string,  # This is the string used to mask words.
            bow=False)  # bow=False masks inputs, instead of deleting them entirely.

        all_results = []

        # Explain each input.
        for input_ in inputs:
            # Dict[field name -> interpretations]
            result = {}

            # Explain each text segment in the input, keeping the others constant.
            for text_key in text_keys:
                input_string = input_[text_key]
                logging.info('Explaining: %s', input_string)

                # Use the number of words as the number of features.
                num_features = len(input_string.split())

                def _predict_proba(strings: List[Text]):
                    """Given raw strings, return probabilities. Used by `explainer`."""
                    input_examples = [
                        new_example(input_, text_key, s) for s in strings
                    ]
                    model_outputs = model.predict(input_examples)
                    probs = np.array(
                        [output[pred_key] for output in model_outputs])
                    return probs  # <float32>[len(strings), num_labels]

                # Perturbs the input string, gets model predictions, fits linear model.
                explanation = explainer.explain_instance(
                    input_string,
                    _predict_proba,
                    num_features=num_features,
                    num_samples=num_samples)

                # Turn the LIME explanation into a list following original word order.
                scores = explanation_to_array(explanation)
                result[text_key] = dtypes.SalienceMap(input_string.split(),
                                                      scores)

            all_results.append(result)

        return all_results
Example #30
    def generate(self,
                 example: JsonDict,
                 model: lit_model.Model,
                 dataset: lit_dataset.Dataset,
                 config: Optional[JsonDict] = None,
                 num_examples: int = 1) -> List[JsonDict]:
        """Use gradient to find/substitute the token with largest impact on loss."""
        del dataset  # Unused.

        assert model is not None, "Please provide a model for this generator."
        logging.info(r"W3lc0m3 t0 H0tFl1p \o/")
        logging.info("Original example: %r", example)

        # Find gradient fields to use for HotFlip
        input_spec = model.input_spec()
        output_spec = model.output_spec()
        grad_fields = self.find_fields(output_spec)
        logging.info("Found gradient fields for HotFlip use: %s",
                     str(grad_fields))
        if len(grad_fields) == 0:  # pylint: disable=g-explicit-length-test
            logging.info("No gradient fields found. Cannot use HotFlip. :-(")
            return []  # Cannot generate examples without gradients.

        # Get model outputs.
        logging.info(
            "Performing a forward/backward pass on the input example.")
        model_output = model.predict_single(example)
        logging.info(model_output.keys())

        # Get model word embeddings and vocab.
        inv_vocab, embed = model.get_embedding_table()
        assert len(inv_vocab) == embed.shape[0], "Vocab/embeddings size mismatch."
        logging.info("Vocab size: %d, Embedding size: %r", len(inv_vocab),
                     embed.shape)

        # Perform a flip in each sequence for which we have gradients (separately).
        # Each sequence may give rise to multiple new examples, depending on how
        # many words we flip.
        # TODO(lit-team): make configurable how many new examples are desired.
        # TODO(lit-team): use only 1 sequence as input (configurable in UI).
        new_examples = []
        for grad_field in grad_fields:

            # Get the tokens and their gradient vectors.
            token_field = output_spec[grad_field].align  # pytype: disable=attribute-error
            tokens = model_output[token_field]
            grads = model_output[grad_field]

            # Identify the token with the largest gradient norm.
            # TODO(lit-team): consider normalizing across all grad fields or just
            # across each one individually.
            grad_norm = np.linalg.norm(grads, axis=1)
            grad_norm = grad_norm / np.sum(grad_norm)  # Match grad attribution value.

            # Get a list of indices of input tokens, sorted by norm, highest first.
            sorted_by_grad_norm = np.argsort(grad_norm)[::-1]

            for i in range(min(num_examples, len(tokens))):
                token_id = sorted_by_grad_norm[i]
                logging.info(
                    "Selected token: %s (pos=%d) with gradient norm %f",
                    tokens[token_id], token_id, grad_norm[token_id])
                token_grad = grads[token_id]

                # Take dot product with all word embeddings. Get largest value.
                scores = np.dot(embed, token_grad)

                # TODO(lit-team): Can add criteria to the winner e.g. cosine distance.
                winner = np.argmax(scores)
                logging.info(
                    "Replacing [%s] (pos=%d) with option %d: [%s] (id=%d)",
                    tokens[token_id], token_id, i, inv_vocab[winner], winner)

                # Create a new input to the model.
                # TODO(iftenney, bastings): enforce somewhere that this field has the
                # same name in the input and output specs.
                input_token_field = token_field
                input_text_field = input_spec[input_token_field].parent  # pytype: disable=attribute-error
                new_example = copy.deepcopy(example)
                modified_tokens = copy.copy(tokens)
                modified_tokens[token_id] = inv_vocab[winner]
                new_example[input_token_field] = modified_tokens
                # TODO(iftenney, bastings): call a model-provided detokenizer here?
                # Though in general tokenization isn't invertible and it's possible for
                # HotFlip to produce wordpiece sequences that don't correspond to any
                # input string.
                new_example[input_text_field] = " ".join(modified_tokens)

                # Predict a new label for this example.
                new_output = model.predict_single(new_example)

                # Update label if multi-class prediction.
                # TODO(lit-dev): provide a general system for handling labels on
                # generated examples.
                for pred_key, pred_type in model.output_spec().items():
                    if isinstance(pred_type, types.MulticlassPreds):
                        probabilities = new_output[pred_key]
                        prediction = np.argmax(probabilities)
                        label_key = output_spec[pred_key].parent
                        label_names = output_spec[pred_key].vocab
                        new_label = label_names[prediction]
                        new_example[label_key] = new_label
                        logging.info("Updated example with new label: %s",
                                     new_label)

                new_examples.append(new_example)

        return new_examples