예제 #1
0
    def __init__(self, encoder: encoders.BertEncoderWithOffsets,
                 classifier: edge_predictor.SingleEdgePredictor):
        """Stores the encoder/classifier pair and the extractor field names.

        Args:
          encoder: model whose output spec supplies a TokenEmbeddings field and
            a SubwordOffsets field.
          classifier: edge predictor used together with the encoder.
        """
        self.encoder = encoder
        self.classifier = classifier

        # The first field of each required type in the encoder's output spec
        # becomes a keyword argument for the feature extractor.
        emb_key = utils.find_spec_keys(self.encoder.output_spec(),
                                       lit_types.TokenEmbeddings)[0]
        offsets_key = utils.find_spec_keys(self.encoder.output_spec(),
                                           lit_types.SubwordOffsets)[0]
        self.extractor_kw = {
            'edge_field': 'coref',
            'embs_field': emb_key,
            'offset_field': offsets_key,
        }
예제 #2
0
    def generate(self,
                 example: JsonDict,
                 model: lit_model.Model,
                 dataset: lit_dataset.Dataset,
                 config: Optional[JsonDict] = None) -> List[JsonDict]:
        """Produces counterfactuals by substituting words in each text field.

        Args:
          example: the example to perturb; it is never modified.
          model: unused.
          dataset: its spec determines which TextSegment fields are perturbed.
          config: optional; a truthy 'subs' entry is parsed into the
            replacement list, otherwise `self.default_replacements` is used.

        Returns:
          One deep copy of `example` per generated substitution, each with a
          single text field replaced.
        """
        del model  # Unused.

        subs_string = config.get('subs') if config else None
        replacements = (self.parse_subs_string(subs_string)
                        if subs_string else self.default_replacements)

        # TODO(lit-dev): move this to generate_all(), so we read the spec once
        # instead of on every example.
        text_keys = utils.find_spec_keys(dataset.spec(), types.TextSegment)
        outputs = []
        for key in text_keys:
            original_text = example[key]
            # Lazily yields (start, end) spans of each token match.
            spans = (tok.span() for tok in
                     self.tokenization_pattern.finditer(original_text))
            for replaced_text in self.generate_counterfactuals(
                    original_text, spans, replacements):
                perturbed = copy.deepcopy(example)
                perturbed[key] = replaced_text
                outputs.append(perturbed)
        return outputs
예제 #3
0
 def output_spec(self):
     """Returns the STS-B output spec with every TokenGradients field dropped."""
     spec = glue_models.STSBModel.output_spec(self)
     # find_spec_keys returns a list, so popping while looping is safe.
     for key in utils.find_spec_keys(spec, lit_types.TokenGradients):
         spec.pop(key, None)
     return spec
예제 #4
0
    def find_fields(self, input_spec: Spec, output_spec: Spec) -> List[Text]:
        """Returns the TokenGradients fields usable for gradient attribution.

        Every TokenGradients field in `output_spec` must be aligned to a Tokens
        field (enforced by assertions). A field is returned only when it sets
        both `grad_for` and `grad_target`. `grad_for` must name a
        TokenEmbeddings field present in both specs. `grad_target` must be a
        key present in both specs. Fields missing either attribute are skipped
        with an info log.
        """
        usable = []
        for grad_key in utils.find_spec_keys(output_spec, types.TokenGradients):
            grad_spec = output_spec[grad_key]
            align_key = grad_spec.align  # pytype: disable=attribute-error
            assert align_key in output_spec
            assert isinstance(output_spec[align_key], types.Tokens)

            emb_key = grad_spec.grad_for
            target_key = grad_spec.grad_target
            if emb_key is None or target_key is None:
                logging.info('Skipping %s since embeddings field not found.',
                             str(grad_key))
                continue

            assert emb_key in input_spec
            assert isinstance(input_spec[emb_key], types.TokenEmbeddings)
            assert emb_key in output_spec
            assert isinstance(output_spec[emb_key], types.TokenEmbeddings)

            assert target_key in input_spec
            assert target_key in output_spec
            usable.append(grad_key)
        return usable
예제 #5
0
  def run(self,
          inputs: List[JsonDict],
          dataset: lit_dataset.Dataset,
          config: Optional[JsonDict] = None):
    """Generates new examples by rewriting one text field at a time.

    Args:
      inputs: sequence of inputs, following dataset.spec()
      dataset: dataset, used to access dataset.spec()
      config: additional runtime options (not read by this method)

    Returns:
      A list with one entry per input. Each entry lists copies of that input
      in which a single TextSegment field is replaced by one of the candidates
      returned by `self.generate_from_texts`.
    """
    all_outputs = [[] for _ in inputs]

    # TODO(lit-team): configure a subset of fields to operate on
    text_fields = utils.find_spec_keys(dataset.spec(), types.TextSegment)
    candidates_by_field = {
        field: self.generate_from_texts([ex[field] for ex in inputs])
        for field in text_fields
    }

    # Generate by substituting in each field.
    # TODO(lit-team): substitute on a combination of fields?
    for field, candidates in candidates_by_field.items():
      for i, ex in enumerate(inputs):
        all_outputs[i].extend(
            utils.copy_and_update(ex, {field: cand}) for cand in candidates[i])
    return all_outputs
예제 #6
0
파일: annotators.py 프로젝트: PAIR-code/lit
    def annotate(self,
                 inputs: List[JsonDict],
                 dataset: lit_dataset.Dataset,
                 dataset_spec_to_annotate: Optional[types.Spec] = None):
        """Runs the annotator on every compatible field and writes results in place.

        For each dataset field whose type matches the annotator's single input,
        the annotator is run on `inputs` remapped to that field. Each output is
        stored on the examples under '<name>:<output_name>:<dataset_field>'.

        Args:
          inputs: examples to annotate; modified in place.
          dataset: dataset whose spec is searched for compatible fields.
          dataset_spec_to_annotate: optional spec to extend with the new
            annotated fields (marked `annotated=True`). An empty spec is still
            updated.

        Raises:
          ValueError: if the annotator model does not take exactly one input.
        """
        annotator_input_spec = self._annotator_model.input_spec()
        if len(annotator_input_spec) != 1:
            raise ValueError(
                'Annotator model provided to PerFieldAnnotator does not '
                'operate on a single field')

        datasets = {}
        for input_name, input_type in annotator_input_spec.items():
            # Do remap of inputs based on input name needed by annotator.
            ds_keys = utils.find_spec_keys(dataset.spec(), type(input_type))
            for ds_key in ds_keys:
                temp_ds = lit_dataset.Dataset(examples=inputs, base=dataset)
                datasets[ds_key] = temp_ds.remap({ds_key: input_name})

        annotator_output_spec = self._annotator_model.output_spec()
        for ds_key, ds in datasets.items():
            # Materialize: the outputs are iterated once per output field below,
            # which would silently yield nothing after the first field if
            # predict() returned a one-shot iterator.
            outputs = list(self._annotator_model.predict(ds.examples))
            for output_name, output_type in annotator_output_spec.items():
                # Update dataset spec with new annotated field.
                field_name = f'{self._name}:{output_name}:{ds_key}'
                # Compare against None so an empty (falsy) spec is still updated.
                if dataset_spec_to_annotate is not None:
                    dataset_spec_to_annotate[field_name] = attr.evolve(
                        output_type, annotated=True)

                # Update all examples with annotator output.
                for example, output in zip(inputs, outputs):
                    example[field_name] = output[output_name]
예제 #7
0
    def _get_tokens_and_gradients(self, input_spec: JsonDict,
                                  output_spec: JsonDict, output: JsonDict,
                                  selected_fields: List[str]):
        """Returns a dictionary mapping token fields to tokens and gradients.

        Args:
          input_spec: model input spec; searched for Tokens fields.
          output_spec: model output spec; must hold a compatible spec for each
            chosen token field and exactly one aligned TokenGradients field.
          output: a single model output record.
          selected_fields: names of fields the caller wants to consider.

        Returns:
          Dict from each selected, compatible Tokens field name to a
          two-element list `[tokens, gradients]`, in input-spec order. Empty if
          no field qualifies.

        Raises:
          AssertionError: if a chosen token field has zero or more than one
            aligned TokenGradients field.
        """
        # Find selected token fields.
        input_spec_keys = utils.find_spec_keys(input_spec, types.Tokens)
        logging.info("input_spec_keys: %r", input_spec_keys)
        selected = set(selected_fields)
        # Keep spec order (not set order) so the result order is deterministic.
        selected_input_spec_keys = [k for k in input_spec_keys if k in selected]
        logging.info("selected_input_spec_keys: %r", selected_input_spec_keys)
        token_fields = [
            key for key in selected_input_spec_keys
            if input_spec[key].is_compatible(output_spec.get(key))
        ]

        ret = {}
        for token_field in token_fields:
            grad_fields = self.find_fields(output_spec, types.TokenGradients,
                                           token_field)
            assert grad_fields, (
                f"No gradients found for {token_field}. Cannot use HotFlip. :-("
            )
            assert len(grad_fields) == 1, (
                f"Multiple gradients found for {token_field}. "
                f"Cannot use HotFlip. :-(")
            # grad_fields is guaranteed non-empty by the asserts above.
            ret[token_field] = [output[token_field], output[grad_fields[0]]]
        return ret
예제 #8
0
    def run(self,
            inputs: List[JsonDict],
            model: lit_model.Model,
            dataset: lit_dataset.Dataset,
            model_outputs: Optional[List[JsonDict]] = None,
            config: Optional[JsonDict] = None) -> Optional[List[JsonDict]]:
        """Explains inputs[0] with LEMON, using inputs[1:] as counterfactuals.

        Args:
          inputs: the example to explain, followed by its counterfactuals.
          model: its input spec selects the TextSegment fields to explain. It
            is also used to compute predictions when `model_outputs` is None.
          dataset: unused.
          model_outputs: predictions aligned with `inputs`. Each must contain
            the field named by config['pred_key'].
          config: must provide 'pred_key', 'class_to_explain' and
            'lowercase_tokens'.

        Returns:
          None if `inputs` is empty or the model has no TextSegment inputs.
          Otherwise a one-element list holding a dict from text field name to
          its TokenSalience.

        Raises:
          ValueError: if `config` is not provided.
        """
        if not inputs:
            return None

        # Find keys of input (text) segments to explain.
        # Search in the input spec, since it's only useful to look at ones that are
        # used by the model.
        text_keys = utils.find_spec_keys(model.input_spec(), types.TextSegment)
        if not text_keys:
            logging.warning('LEMON requires text inputs.')
            return None
        logging.info('Found text fields for LEMON attribution: %s',
                     str(text_keys))

        if config is None:
            raise ValueError('LEMON requires a config with pred_key, '
                             'class_to_explain and lowercase_tokens.')
        if model_outputs is None:
            model_outputs = list(model.predict(inputs))

        pred_key = config['pred_key']
        output_probs = np.array([output[pred_key] for output in model_outputs])

        # Dict[field name -> interpretations]
        result = {}

        # Explain each text segment in the input, keeping the others constant.
        for text_key in text_keys:
            sentences = [item[text_key] for item in inputs]
            # Sentence -> prediction table handed to make_predict_fn.
            input_to_prediction = dict(zip(sentences, output_probs))

            input_string = sentences[0]
            # Remove duplicate counterfactuals while keeping first-seen order,
            # so the sequence passed to lemon.explain is reproducible
            # (set iteration order varies with the string hash seed).
            counterfactuals = list(dict.fromkeys(sentences[1:]))

            logging.info('Explaining: %s', input_string)

            predict_proba = make_predict_fn(input_to_prediction)

            # Perturbs the input string, gets model predictions, fits linear model.
            explanation = lemon.explain(
                input_string,
                counterfactuals,
                predict_proba,
                class_to_explain=config['class_to_explain'],
                lowercase_tokens=config['lowercase_tokens'])

            # Normalize feature values.
            scores = citrus_utils.normalize_scores(
                np.array(explanation.feature_importance))

            result[text_key] = dtypes.TokenSalience(input_string.split(),
                                                    scores)

        return [result]
예제 #9
0
 def test_find_spec_keys(self):
   """find_spec_keys matches exact types, tuples of types, and subclasses."""
   spec = {
       "score": types.RegressionScore(),
       "scalar_foo": types.Scalar(),
       "text": types.TextSegment(),
       "emb_0": types.Embeddings(),
       "emb_1": types.Embeddings(),
       "tokens": types.Tokens(),
       "generated_text": types.GeneratedText(),
   }
   find = utils.find_spec_keys

   # Single types, a tuple of types, and a type absent from the spec.
   self.assertEqual(["score"], find(spec, types.RegressionScore))
   self.assertEqual(["text", "tokens", "generated_text"],
                    find(spec, (types.TextSegment, types.Tokens)))
   self.assertEqual(["emb_0", "emb_1"], find(spec, types.Embeddings))
   self.assertEqual([], find(spec, types.AttentionHeads))

   # Base classes also match instances of their subclasses, in spec order.
   self.assertEqual(list(spec), find(spec, types.LitType))
   self.assertEqual(["text", "generated_text"], find(spec, types.TextSegment))
   self.assertEqual(["score", "scalar_foo"], find(spec, types.Scalar))
예제 #10
0
def find_supported_fields(input_spec: Spec,
                          output_spec: Spec) -> Optional[SupportedFields]:
    """Locates the spec fields an image-saliency method needs.

    Returns None unless the model has exactly one ImageGradients output and
    exactly one matching prediction output. The prediction output is
    MulticlassPreds when the gradient declares a target field, and
    RegressionScore otherwise.
    """
    grad_keys = lit_utils.find_spec_keys(output_spec, types.ImageGradients)
    # Models with zero or several gradient fields are not supported.
    if len(grad_keys) != 1:
        return None
    grad_key = grad_keys[0]
    grad_spec = cast(types.ImageGradients, output_spec[grad_key])

    # The gradient is aligned to the image input it was computed for.
    image_key = grad_spec.align
    assert isinstance(input_spec[image_key], types.ImageBytes)

    target_key = grad_spec.grad_target_field_key
    if target_key is None:
        # Regression or single-class classification model.
        pred_type = types.RegressionScore
    else:
        # Multiclass model: the target class comes from a CategoryLabel input.
        assert isinstance(input_spec[target_key], types.CategoryLabel)
        pred_type = types.MulticlassPreds

    pred_keys = lit_utils.find_spec_keys(output_spec, pred_type)
    # Models with zero or several prediction fields are not supported.
    if len(pred_keys) != 1:
        return None

    return SupportedFields(grad_field_key=grad_key,
                           image_field_key=image_key,
                           grad_target_field_key=target_key,
                           preds_field_key=pred_keys[0])
예제 #11
0
    def find_fields(self, output_spec: Spec) -> List[Text]:
        """Returns every TokenGradients field, asserting each aligns to Tokens."""
        grad_keys = utils.find_spec_keys(output_spec, types.TokenGradients)
        for key in grad_keys:
            target = output_spec[key].align  # pytype: disable=attribute-error
            assert target in output_spec
            assert isinstance(output_spec[target], types.Tokens)
        return grad_keys
예제 #12
0
파일: index.py 프로젝트: oceanfly/lit
    def _fill_indices(self, model_name, dataset_name):
        """Create all indices for a single model.

        First loads any serialized indices and the example lookup table from
        disk. Then, only if `self._initialize_new_indices` is set, it fills the
        embedding fields whose index is still uninitialized. It runs the model
        over the dataset, adds each example's embedding to the field's index
        (item id = example position), records the example in the lookup table,
        and builds the trees.
        """
        model = self._models.get(model_name)
        assert model is not None, "Invalid model name."
        examples = self.datasets[dataset_name].indexed_examples
        model_embeddings_names = utils.find_spec_keys(model.output_spec(),
                                                      lit_types.Embeddings)
        lookup_key = self._get_lookup_key(model_name, dataset_name)

        # If the model has no embeddings to extract, skip the following.
        if not model_embeddings_names:
            return

        # Load from file if it exists.
        for emb_name in model_embeddings_names:
            # Initialize the index object in self._indices with serialized index.
            self._init_index_from_file(model_name, dataset_name, emb_name)
        # Load example lookup dictionary from file.
        self._example_lookup[lookup_key] = self._load_lookup(lookup_key)

        # Identify which indices need to be initialized.
        embeddings_to_index = [
            emb_name for emb_name in model_embeddings_names if
            not self._is_index_initialized(model_name, dataset_name, emb_name)
        ]
        # Early exit if all embeddings are now initialized.
        if not embeddings_to_index:
            return

        # Cold start is only performed when explicitly enabled.
        if not self._initialize_new_indices:
            return

        # Index keys do not depend on the example; compute them once.
        index_keys = [(emb_name,
                       self._get_index_key(model_name, dataset_name, emb_name))
                      for emb_name in embeddings_to_index]

        # Cold start: Get embeddings for non-initialized settings.
        for res_ix, (result, example) in enumerate(
                zip(model.predict_with_metadata(examples), examples)):
            for emb_name, index_key in index_keys:
                # Check existence before any use of the index.
                index = self._indices.get(index_key)
                assert index is not None, "Index needs to be created first."
                # Initialize saving in the first iteration.
                if res_ix == 0:
                    index.on_disk_build(self._get_index_path(index_key))
                # Each item has an incrementing ID res_ix.
                index.add_item(res_ix, result[emb_name])
            # Add item to lookup table.
            self._example_lookup[lookup_key][res_ix] = example["data"]

        # Create the trees from the indices - using 10 as recommended by doc.
        for _, index_key in index_keys:
            logging.info("Creating new index: %s", index_key)
            index = self._indices[index_key]
            index.build(10)
            logging.info("Created new index with %s items.",
                         index.get_n_items())
예제 #13
0
def symmetrize_edges(dataset: lit_dataset.Dataset) -> lit_dataset.Dataset:
    """Symmetrize edges by adding copies with span1 and span2 interchanged.

    Returns a new dataset with the same spec. Each EdgeLabels field holds the
    original edges followed by their swapped copies. The input dataset's
    examples and edge lists are left unmodified.
    """
    def _swap(edge):
        return lit_dtypes.EdgeLabel(edge.span2, edge.span1, edge.label)

    edge_fields = utils.find_spec_keys(dataset.spec(), lit_types.EdgeLabels)
    examples = []
    for ex in dataset.examples:
        new_ex = copy.copy(ex)
        for field in edge_fields:
            # Build a fresh list. `new_ex[field] += ...` would extend the list
            # object that the shallow copy still shares with `ex`, mutating
            # the original dataset.
            new_ex[field] = list(ex[field]) + [_swap(e) for e in ex[field]]
        examples.append(new_ex)
    return lit_dataset.Dataset(dataset.spec(), examples)
예제 #14
0
파일: app.py 프로젝트: zhiyiZeng/lit
 def _warm_projections(self, interpreters: List[Text]):
   """Pre-computes projections for every model/dataset/embedding-field triple.

   Each named interpreter is run with an empty input list and a
   3-component projection config. Results are discarded; presumably the
   call populates a cache (verify in _get_interpretations).
   """
   for model_name, info in self._info['models'].items():
     for dataset_name in info['datasets']:
       emb_fields = utils.find_spec_keys(info['spec']['output'],
                                         types.Embeddings)
       for field_name in emb_fields:
         data = {
             'inputs': [],
             'config': {
                 'dataset_name': dataset_name,
                 'model_name': model_name,
                 'field_name': field_name,
                 'proj_kw': {'n_components': 3},
             },
         }
         for interpreter_name in interpreters:
           self._get_interpretations(
               data, model_name, dataset_name, interpreter=interpreter_name)
예제 #15
0
    def find_fields(self,
                    spec: Spec,
                    typ: Type[types.LitType],
                    align_field: Optional[Text] = None) -> List[Text]:
        """Returns names of `typ` fields in `spec`, optionally filtered by alignment.

        When `align_field` is given, only fields whose `align` attribute equals
        it are kept. Fields without an `align` attribute are dropped.
        """
        candidates = utils.find_spec_keys(spec, typ)
        if align_field is not None:
            candidates = [
                name for name in candidates
                if getattr(spec[name], "align", None) == align_field
            ]
        return candidates
예제 #16
0
  def generate(self,
               example: JsonDict,
               model: lit_model.Model,
               dataset: lit_dataset.Dataset,
               config: Optional[JsonDict] = None) -> List[JsonDict]:
    """Returns one deep copy of `example` with every text field scrambled.

    Each TextSegment field (per `dataset.spec()`) is replaced by
    `self.scramble` applied to the original example's value.
    """
    del model, config  # Unused.

    # TODO(lit-dev): move this to generate_all(), so we read the spec once
    # instead of on every example.
    text_keys = utils.find_spec_keys(dataset.spec(), types.TextSegment)
    scrambled = copy.deepcopy(example)
    scrambled.update({key: self.scramble(example[key]) for key in text_keys})
    return [scrambled]
예제 #17
0
    def find_fields(
            self,
            output_spec: Spec,
            typ: Type[types.LitType],
            align_typ: Optional[Type[types.LitType]] = None) -> List[Text]:
        """Returns all `typ` fields in `output_spec`.

        When `align_typ` is given, each returned field must have an `align`
        target that exists in `output_spec` and is an instance of `align_typ`.
        This is enforced by assertions, not by filtering.
        """
        fields = utils.find_spec_keys(output_spec, typ)
        if align_typ is not None:
            for name in fields:
                target = output_spec[name].align  # pytype: disable=attribute-error
                assert target in output_spec, "Align field not in output_spec"
                assert isinstance(output_spec[target], align_typ)
        return fields
예제 #18
0
파일: index.py 프로젝트: oceanfly/lit
    def _create_empty_indices(self, model_name, dataset_name):
        """Allocates one empty Annoy index per embedding field of the model.

        Each index's dimension is the length of that field in the model's
        prediction for the first dataset example. Asserts that no index
        already exists for the same key.
        """
        model = self._models[model_name]
        examples = self.datasets[dataset_name].indexed_examples
        emb_names = utils.find_spec_keys(model.output_spec(),
                                         lit_types.Embeddings)
        if not emb_names:
            return

        # Peek at a single prediction to learn each embedding's dimension.
        first_pred = list(model.predict([examples[0]["data"]]))[0]
        for emb_name in emb_names:
            key = self._get_index_key(model_name, dataset_name, emb_name)
            dim = len(first_pred[emb_name])
            assert self._indices.get(key) is None, "Index already exists."
            self._indices[key] = annoy.AnnoyIndex(dim, "euclidean")
예제 #19
0
파일: app.py 프로젝트: PAIR-code/lit
 def _warm_projections(self, interpreters: List[Text]):
     """Pre-computes projections for the first embedding field of each model.

     For every model and each of its datasets, each named interpreter is run
     with an empty input list and a 3-component projection config. Results are
     discarded.
     """
     for model, model_info in self._info['models'].items():
         for dataset_name in model_info['datasets']:
             embedding_fields = utils.find_spec_keys(
                 model_info['spec']['output'], types.Embeddings)
             # Only warm-start on the first embedding field, since if models
             # return many different embeddings this can take a long time.
             if not embedding_fields:
                 continue
             data = {
                 'inputs': [],
                 'config': {
                     'dataset_name': dataset_name,
                     'model_name': model,
                     'field_name': embedding_fields[0],
                     'proj_kw': {'n_components': 3},
                 },
             }
             for interpreter_name in interpreters:
                 self._get_interpretations(data,
                                           model,
                                           dataset_name,
                                           interpreter=interpreter_name)
예제 #20
0
파일: scrambler.py 프로젝트: PAIR-code/lit
    def generate(self,
                 example: JsonDict,
                 model: lit_model.Model,
                 dataset: lit_dataset.Dataset,
                 config: Optional[JsonDict] = None) -> List[JsonDict]:
        """Naively scramble the selected words in an example.

        Note: Even if more than one field is to be scrambled, only a single
        example will be produced, unlike other generators which will produce
        multiple examples, one per field.

        Args:
          example: the example used for basis of generated examples.
          model: the model (unused).
          dataset: the dataset; its spec determines which fields are text.
          config: user-provided config properties. The entry under
            FIELDS_TO_SCRAMBLE_KEY lists the fields to scramble.

        Returns:
          An empty list if no fields were requested or the dataset has no
          TextSegment fields. Otherwise a single-element list holding a deep
          copy of `example` in which every requested TextSegment field has
          been scrambled.
        """
        del model  # Unused.

        # If config key is missing, generate no examples.
        requested = list((config or {}).get(FIELDS_TO_SCRAMBLE_KEY, []))
        if not requested:
            return []

        # TODO(lit-dev): move this to generate_all(), so we read the spec once
        # instead of on every example.
        text_keys = utils.find_spec_keys(dataset.spec(), types.TextSegment)
        if not text_keys:
            return []

        scrambled = copy.deepcopy(example)
        for key in text_keys:
            if key in requested:
                scrambled[key] = self.scramble(example[key])
        return [scrambled]
예제 #21
0
    def generate(self,
                 example: JsonDict,
                 model: lit_model.Model,
                 dataset: lit_dataset.Dataset,
                 config: Optional[JsonDict] = None) -> List[JsonDict]:
        """Produces counterfactuals by substituting words in each text field.

        Args:
          example: the example to perturb; it is never modified.
          model: unused.
          dataset: its spec determines which TextSegment fields are perturbed.
          config: optional. 'ignore_casing' defaults to True. A truthy
            'Substitutions' entry is parsed into the replacement list;
            otherwise `self.default_replacements` is used.

        Returns:
          One deep copy of `example` per generated substitution, or an empty
          list if there are no replacements.
        """
        del model  # Unused.

        if config:
            ignore_casing = config.get('ignore_casing', True)
            subs_string = config.get('Substitutions')
        else:
            ignore_casing = True
            subs_string = None

        replacements = (
            self.parse_subs_string(subs_string, ignore_casing=ignore_casing)
            if subs_string else self.default_replacements)
        # Nothing to match against, so no counterfactuals can be made.
        if not replacements:
            return []

        pattern = self._get_replacement_pattern(replacements,
                                                ignore_casing=ignore_casing)

        # TODO(lit-dev): move this to generate_all(), so we read the spec once
        # instead of on every example.
        text_keys = utils.find_spec_keys(dataset.spec(), types.TextSegment)
        results = []
        for key in text_keys:
            source_text = example[key]
            for candidate in self.generate_counterfactuals(
                    source_text, pattern, replacements,
                    ignore_casing=ignore_casing):
                updated = copy.deepcopy(example)
                updated[key] = candidate
                results.append(updated)
        return results
예제 #22
0
    def _get_preds(self,
                   data,
                   model: Text,
                   dataset_name: Optional[Text] = None,
                   requested_types: Text = 'LitType',
                   **unused_kw):
        """Get model predictions.

        Args:
          data: data payload, containing 'inputs' field
          model: name of the model to run
          dataset_name: name of the active dataset
          requested_types: optional, comma-separated list of LitType class
            names to return. Surrounding whitespace and empty entries are
            ignored.

        Returns:
          List[JsonDict] containing requested fields of model predictions,
          one record per input.

        Raises:
          AssertionError: if a requested name is not a LitType class.
        """
        preds = list(self._predict(data['inputs'], model, dataset_name))

        # Figure out what to return to the frontend.
        output_spec = self._get_spec(model)['output']
        type_names = [t.strip() for t in requested_types.split(',')]
        type_names = [t for t in type_names if t]
        logging.info('Requested types: %s', str(type_names))
        ret_keys = set()  # A set de-dupes keys matched by several types.
        for t_name in type_names:
            t_class = getattr(types, t_name, None)
            # Check for a class first: issubclass() raises TypeError on None or
            # on non-class attributes, which would hide the message below.
            assert isinstance(t_class, type) and issubclass(
                t_class,
                types.LitType), f"Class '{t_name}' is not a valid LitType."
            ret_keys.update(utils.find_spec_keys(output_spec, t_class))

        # Return selected keys.
        logging.info('Will return keys: %s', str(ret_keys))
        # One record per input.
        return [utils.filter_by_keys(p, ret_keys.__contains__) for p in preds]
예제 #23
0
    def run(self,
            inputs: List[JsonDict],
            dataset: lit_dataset.Dataset,
            config: Optional[JsonDict] = None):
        """Run generation on a set of inputs.

        Args:
          inputs: sequence of inputs, following dataset.spec()
          dataset: dataset, used to access dataset.spec()
          config: additional runtime options. The entry under
            FIELDS_TO_BACKTRANSLATE_KEY selects the fields to rewrite; by
            default, all TextSegment fields are rewritten.

        Returns:
          list of list of new generated inputs, following dataset.spec(). The
          outer list has one entry per input. Each new input replaces a single
          field with one candidate from `self.generate_from_texts`.
        """
        all_outputs = [[] for _ in inputs]
        options = config or {}

        # If config key is missing, backtranslate all text fields.
        text_fields = utils.find_spec_keys(dataset.spec(), types.TextSegment)
        chosen_fields = list(options.get(FIELDS_TO_BACKTRANSLATE_KEY,
                                         text_fields))
        candidates_by_field = {
            field: self.generate_from_texts([ex[field] for ex in inputs])
            for field in chosen_fields
        }

        # Generate by substituting in each field.
        # TODO(lit-team): substitute on a combination of fields?
        for field, candidates in candidates_by_field.items():
            for i, ex in enumerate(inputs):
                all_outputs[i].extend(
                    utils.copy_and_update(ex, {field: cand})
                    for cand in candidates[i])
        return all_outputs
예제 #24
0
    def run(
            self,
            inputs: List[JsonDict],
            model: lit_model.Model,
            dataset: lit_dataset.Dataset,
            model_outputs: Optional[List[JsonDict]] = None,
            config: Optional[JsonDict] = None,
            kernel_width: int = 25,  # TODO(lit-dev): make configurable in UI.
            mask_string: str = '[MASK]',  # TODO(lit-dev): make configurable in UI.
            num_samples: int = 256,  # TODO(lit-dev): make configurable in UI.
    ) -> Optional[List[JsonDict]]:
        """Computes LIME word attributions for every text field of every input.

        Returns None (with a warning) when the model has no TextSegment inputs
        or no MulticlassPreds output. Otherwise returns one dict per input,
        mapping each text field name to a SalienceMap over its
        whitespace-split words. Only the first MulticlassPreds field is
        explained.
        """
        # Only fields in the model's input spec matter for attribution.
        text_keys = utils.find_spec_keys(model.input_spec(), types.TextSegment)
        if not text_keys:
            logging.warning('LIME requires text inputs.')
            return None
        logging.info('Found text fields for LIME attribution: %s',
                     str(text_keys))

        pred_keys = utils.find_spec_keys(model.output_spec(),
                                         types.MulticlassPreds)
        if not pred_keys:
            logging.warning(
                'LIME did not find a multi-class predictions field.')
            return None

        # TODO(lit-dev): configure which prob field to use.
        pred_key = pred_keys[0]
        pred_spec = cast(types.MulticlassPreds, model.output_spec()[pred_key])

        # bow=False masks words with `mask_string` instead of deleting them.
        explainer = lime_text.LimeTextExplainer(
            class_names=pred_spec.vocab,
            split_expression=str.split,
            kernel_width=kernel_width,
            mask_string=mask_string,
            bow=False)

        def _explain_field(example, field):
            """Runs LIME on one text field, holding the other fields fixed."""
            text = example[field]
            logging.info('Explaining: %s', text)

            def _predict_proba(strings: List[Text]):
                """Given raw strings, return probabilities. Used by `explainer`."""
                batch = [new_example(example, field, s) for s in strings]
                preds = model.predict(batch)
                # <float32>[len(strings), num_labels]
                return np.array([p[pred_key] for p in preds])

            # One feature per word; LIME perturbs, predicts, fits a linear model.
            explanation = explainer.explain_instance(
                text,
                _predict_proba,
                num_features=len(text.split()),
                num_samples=num_samples)
            # Scores follow the original word order.
            return dtypes.SalienceMap(text.split(),
                                      explanation_to_array(explanation))

        return [{field: _explain_field(example, field) for field in text_keys}
                for example in inputs]
예제 #25
0
 def input_spec(self):
     """Returns the STS-B input spec with every `lit_types.Tokens` field removed.

     The spec comes from `glue_models.STSBModel.input_spec` and is edited in
     place before being returned.
     """
     spec = glue_models.STSBModel.input_spec(self)
     # Strip every token field from the spec; missing keys are tolerated.
     for tokens_field in utils.find_spec_keys(spec, lit_types.Tokens):
         spec.pop(tokens_field, None)
     return spec
예제 #26
0
    def run(
        self,
        inputs: List[JsonDict],
        model: lit_model.Model,
        dataset: lit_dataset.Dataset,
        model_outputs: Optional[List[JsonDict]] = None,
        config: Optional[JsonDict] = None,
    ) -> Optional[List[JsonDict]]:
        """Run LIME on each text segment of each input.

        Args:
          inputs: examples to explain, following `model.input_spec()`.
          model: model whose predictions are explained.
          dataset: unused; part of the component interface.
          model_outputs: unused; part of the component interface.
          config: optional overrides for the defaults in `self.config_spec()`.

        Returns:
          One dict per input, mapping each non-empty text field name to a
          `dtypes.TokenSalience`. Returns None if the model has no text inputs,
          has no supported prediction output, or the configured target head is
          not one of the supported prediction outputs.
        """
        del dataset, model_outputs  # Unused.

        config_defaults = {k: v.default for k, v in self.config_spec().items()}
        config = dict(config_defaults, **(config or {}))  # update and return

        provided_class_to_explain = int(config[CLASS_KEY])
        kernel_width = int(config[KERNEL_WIDTH_KEY])
        num_samples = int(config[NUM_SAMPLES_KEY])
        mask_string = config[MASK_KEY]
        # An empty (or missing) seed means "no fixed seed". 0 is a valid seed,
        # so compare explicitly instead of relying on truthiness.
        raw_seed = config[SEED_KEY]
        seed = None if raw_seed in ('', None) else int(raw_seed)

        # Find keys of input (text) segments to explain.
        # Search in the input spec, since it's only useful to look at ones that
        # are used by the model.
        text_keys = utils.find_spec_keys(model.input_spec(), types.TextSegment)
        if not text_keys:
            logging.warning('LIME requires text inputs.')
            return None
        logging.info('Found text fields for LIME attribution: %s',
                     str(text_keys))

        # Find the key of output probabilities field(s). The output spec is
        # fetched once here and reused for every input below.
        output_spec = model.output_spec()
        pred_keys = utils.find_spec_keys(
            output_spec, (types.MulticlassPreds, types.RegressionScore,
                          types.SparseMultilabelPreds))
        if not pred_keys:
            logging.warning('LIME did not find any supported output fields.')
            return None
        pred_key = config[TARGET_HEAD_KEY] or pred_keys[0]
        # A user-configured head must be one of the supported output fields.
        if pred_key not in pred_keys:
            logging.warning(
                'LIME target head %s is not a supported output field.',
                pred_key)
            return None
        pred_type_info = output_spec[pred_key]

        all_results = []

        # Explain each input.
        for input_ in inputs:
            # Dict[field name -> interpretations]
            result = {}
            predict_fn = functools.partial(
                _predict_fn,
                model=model,
                original_example=input_,
                pred_key=pred_key,
                pred_type_info=pred_type_info)

            class_to_explain = get_class_to_explain(provided_class_to_explain,
                                                    model, pred_key, input_)

            # Explain each text segment in the input, keeping the others
            # constant.
            for text_key in text_keys:
                input_string = input_[text_key]
                if not input_string:
                    logging.info('Could not explain empty string for %s',
                                 text_key)
                    continue
                logging.info('Explaining: %s', input_string)

                # Perturbs the input string, gets model predictions, fits a
                # linear model.
                explanation = lime.explain(
                    sentence=input_string,
                    predict_fn=functools.partial(predict_fn,
                                                 text_key=text_key),
                    # `class_to_explain` is ignored when predict_fn output is
                    # a scalar.
                    class_to_explain=class_to_explain,
                    num_samples=num_samples,
                    tokenizer=str.split,
                    mask_token=mask_string,
                    kernel=functools.partial(lime.exponential_kernel,
                                             kernel_width=kernel_width),
                    seed=seed)

                # Turn the LIME explanation into a list following original
                # word order.
                scores = explanation.feature_importance
                # TODO(lit-dev): Move score normalization to the UI.
                scores = citrus_util.normalize_scores(scores)
                result[text_key] = dtypes.TokenSalience(
                    input_string.split(), scores)

            all_results.append(result)

        return all_results
예제 #27
0
 def find_fields(self, model: lit_model.Model) -> Dict[str, List[str]]:
   """Maps each generated-text output field to the model's text input fields.

   Returns:
     A dict keyed by every `GeneratedText` / `GeneratedTextCandidates` field in
     the model's output spec. Every value is the same object: the collection
     of `TextSegment` field names found in the model's input spec.
   """
   source_fields = utils.find_spec_keys(model.input_spec(), types.TextSegment)
   generated_types = (types.GeneratedText, types.GeneratedTextCandidates)
   generated_fields = utils.find_spec_keys(model.output_spec(), generated_types)
   return dict.fromkeys(generated_fields, source_fields)
예제 #28
0
파일: pdp.py 프로젝트: PAIR-code/lit
    def run(self,
            inputs: List[types.JsonDict],
            model: lit_model.Model,
            dataset: lit_dataset.Dataset,
            model_outputs: Optional[List[types.JsonDict]] = None,
            config: Optional[types.JsonDict] = None):
        """Create PDP chart info using provided inputs.

        For each candidate value of the configured feature, every input is
        copied with that feature overwritten. The copies are run through the
        model, and the mean prediction per output field is recorded.

        Args:
          inputs: sequence of inputs, following model.input_spec(). If empty,
            `dataset.examples` is used instead.
          model: model used to make predictions on the edited examples.
          dataset: dataset which the current examples belong to.
          model_outputs: optional precomputed model outputs (unused).
          config: runtime config. Must contain 'feature', the input field to
            vary. May contain 'range', a two-element [min, max]; if so, 10
            evenly spaced values in that range are tested. Otherwise the
            values come from `self.get_vals_to_test`.

        Returns:
          A dict mapping each MulticlassPreds / RegressionScore output field
          name to a dict of alternate feature values to mean model outputs.
          The mean is a number for regression outputs and a per-class array
          for multiclass outputs. Returns None if the model has no such output
          fields.

        Raises:
          AssertionError: if `config` does not provide a 'feature'.
        """
        del model_outputs  # Unused.

        # The output spec is read once and reused for every candidate value.
        output_spec = model.output_spec()
        pred_keys = utils.find_spec_keys(
            output_spec, (types.MulticlassPreds, types.RegressionScore))
        if not pred_keys:
            logging.warning('PDP did not find any supported output fields.')
            return None

        # `config` is Optional: treat a missing config as empty so the check
        # below fails with its intended message instead of a TypeError.
        config = config or {}
        assert 'feature' in config, 'No feature to test provided'
        feature = config['feature']
        provided_range = config.get('range')
        if provided_range is None:  # Treat an explicit None like no range.
            provided_range = []

        # Regression outputs are averaged as scalars, multiclass outputs
        # per class.
        is_regression = {
            pred_key: isinstance(output_spec[pred_key], types.RegressionScore)
            for pred_key in pred_keys
        }
        edited_outputs = {pred_key: {} for pred_key in pred_keys}

        # If a range was provided, use that to create the possible values.
        if len(provided_range) == 2:
            vals_to_test = np.linspace(provided_range[0], provided_range[1], 10)
        else:
            vals_to_test = self.get_vals_to_test(feature, dataset)

        # If no specific inputs provided, use the entire dataset.
        inputs_to_use = inputs if inputs else dataset.examples

        # For each alternate value for a given feature.
        for new_val in vals_to_test:
            # Create copies of all provided inputs with the value replaced.
            edited_inputs = []
            for inp in inputs_to_use:
                edited_input = copy.deepcopy(inp)
                edited_input[feature] = new_val
                edited_inputs.append(edited_input)

            # Run prediction on the altered inputs.
            outputs = list(model.predict(edited_inputs))

            # Store the mean of the prediction for the alternate value.
            for pred_key in pred_keys:
                preds = [output[pred_key] for output in outputs]
                if is_regression[pred_key]:
                    edited_outputs[pred_key][new_val] = np.mean(preds)
                else:
                    edited_outputs[pred_key][new_val] = np.mean(preds, axis=0)

        return edited_outputs
예제 #29
0
 def is_compatible(self, model: lit_model.Model):
     """Returns whether `model` has text inputs and a supported prediction.

     The model needs at least one `TextSegment` field in its input spec and at
     least one `MulticlassPreds` or `RegressionScore` field in its output spec.
     """
     text_keys = utils.find_spec_keys(model.input_spec(), types.TextSegment)
     pred_keys = utils.find_spec_keys(
         model.output_spec(),
         (types.MulticlassPreds, types.RegressionScore))
     # Return an actual bool; `len(a) and len(b)` would yield an int count.
     return bool(text_keys) and bool(pred_keys)
예제 #30
0
    def run(
        self,
        inputs: List[JsonDict],
        model: lit_model.Model,
        dataset: lit_dataset.Dataset,
        model_outputs: Optional[List[JsonDict]] = None,
        config: Optional[JsonDict] = None,
    ) -> Optional[List[JsonDict]]:
        """Run LIME on each text segment of each input.

        Args:
          inputs: examples to explain, following `model.input_spec()`.
          model: model whose predictions are explained.
          dataset: unused; part of the component interface.
          model_outputs: unused; part of the component interface.
          config: optional runtime config. Any of the class, kernel width,
            mask, sample-count and seed keys that are absent fall back to the
            corresponding `*_DEFAULT` constant.

        Returns:
          One dict per input, mapping each non-empty text field name to a
          `dtypes.SalienceMap`. Returns None if the model has no text inputs or
          no supported prediction output.
        """
        del dataset, model_outputs  # Unused.

        # Fall back to defaults per key, so a partial config does not raise
        # KeyError. Config-provided values are converted to int as before.
        config = config or {}
        class_to_explain = (int(config[CLASS_KEY])
                            if CLASS_KEY in config else CLASS_DEFAULT)
        kernel_width = (int(config[KERNEL_WIDTH_KEY])
                        if KERNEL_WIDTH_KEY in config else KERNEL_WIDTH_DEFAULT)
        mask_string = config.get(MASK_KEY, MASK_DEFAULT)
        num_samples = (int(config[NUM_SAMPLES_KEY])
                       if NUM_SAMPLES_KEY in config else NUM_SAMPLES_DEFAULT)
        if SEED_KEY in config:
            # An empty seed means "no fixed seed". 0 is a valid seed, so
            # compare explicitly instead of relying on truthiness.
            raw_seed = config[SEED_KEY]
            seed = None if raw_seed in ('', None) else int(raw_seed)
        else:
            seed = SEED_DEFAULT

        # Find keys of input (text) segments to explain.
        # Search in the input spec, since it's only useful to look at ones that
        # are used by the model.
        text_keys = utils.find_spec_keys(model.input_spec(), types.TextSegment)
        if not text_keys:
            logging.warning('LIME requires text inputs.')
            return None
        logging.info('Found text fields for LIME attribution: %s',
                     str(text_keys))

        # Find the key of output probabilities field(s).
        pred_keys = utils.find_spec_keys(
            model.output_spec(),
            (types.MulticlassPreds, types.RegressionScore))
        if not pred_keys:
            logging.warning('LIME did not find any supported output fields.')
            return None

        # TODO(lit-dev): configure which prob field to use.
        pred_key = pred_keys[0]
        all_results = []

        # Explain each input.
        for input_ in inputs:
            # Dict[field name -> interpretations]
            result = {}
            predict_fn = functools.partial(_predict_fn,
                                           model=model,
                                           original_example=input_,
                                           pred_key=pred_key)

            # Explain each text segment in the input, keeping the others
            # constant.
            for text_key in text_keys:
                input_string = input_[text_key]
                if not input_string:
                    logging.info('Could not explain empty string for %s',
                                 text_key)
                    continue
                logging.info('Explaining: %s', input_string)

                # Perturbs the input string, gets model predictions, fits a
                # linear model.
                explanation = lime.explain(
                    sentence=input_string,
                    predict_fn=functools.partial(predict_fn,
                                                 text_key=text_key),
                    # Index of the class to explain; ignored when predict_fn
                    # output is a scalar.
                    class_to_explain=class_to_explain,
                    num_samples=num_samples,
                    tokenizer=str.split,
                    mask_token=mask_string,
                    kernel=functools.partial(lime.exponential_kernel,
                                             kernel_width=kernel_width),
                    seed=seed)

                # Turn the LIME explanation into a list following original
                # word order.
                scores = explanation.feature_importance
                # TODO(lit-dev): Move score normalization to the UI.
                scores = citrus_util.normalize_scores(scores)
                result[text_key] = dtypes.SalienceMap(input_string.split(),
                                                      scores)

            all_results.append(result)

        return all_results