コード例 #1
0
ファイル: static_preds.py プロジェクト: byhqsr/PAIR-code-lit
  def __init__(self,
               inputs: lit_dataset.Dataset,
               preds: lit_dataset.Dataset,
               input_identifier_keys: Optional[List[str]] = None):
    """Build a static index from inputs to their precomputed predictions.

    Args:
      inputs: a lit Dataset
      preds: a lit Dataset, parallel to inputs (same length and order)
      input_identifier_keys: (optional), list of keys to treat as identifiers
        for matching inputs. If None or empty, will use all fields in
        inputs.spec()

    Raises:
      ValueError: if inputs and preds contain a different number of examples.
    """
    self._output_spec = preds.spec()
    self._input_spec = inputs.spec()
    # Store a concrete list: a dict_keys view would stay tied to the spec dict
    # and change silently if that dict were mutated after construction.
    self.input_identifier_keys = list(input_identifier_keys or
                                      self._input_spec.keys())
    # Filter to only the identifier keys
    self._input_spec = {
        k: self._input_spec[k] for k in self.input_identifier_keys
    }

    input_examples = inputs.examples
    pred_examples = preds.examples
    # zip() would silently drop the unmatched tail of the longer dataset,
    # leaving some inputs without a prediction; fail loudly instead.
    if len(input_examples) != len(pred_examples):
      raise ValueError(
          f'inputs and preds must be parallel, but have {len(input_examples)} '
          f'and {len(pred_examples)} examples respectively.')

    # Build the index for prediction lookups
    self._index = {
        self.key_fn(ex): pred
        for ex, pred in zip(input_examples, pred_examples)
    }
コード例 #2
0
def symmetrize_edges(dataset: lit_dataset.Dataset) -> lit_dataset.Dataset:
    """Symmetrize edges by adding copies with span1 and span2 interchanged.

    For every EdgeLabels field in the dataset spec, each example's edge list
    is extended with a swapped copy of each of its edges. The input dataset
    and its examples are not modified.

    Args:
      dataset: the dataset whose edge fields should be symmetrized.

    Returns:
      A new Dataset with the same spec and the symmetrized examples.
    """
    def _swap(edge):
        return lit_dtypes.EdgeLabel(edge.span2, edge.span1, edge.label)

    spec = dataset.spec()
    edge_fields = utils.find_spec_keys(spec, lit_types.EdgeLabels)
    examples = []
    for ex in dataset.examples:
        new_ex = copy.copy(ex)
        for field in edge_fields:
            original_edges = ex[field]
            # Build a fresh list: copy.copy() is shallow, so an in-place `+=`
            # would also extend the list held by the original example.
            new_ex[field] = (list(original_edges) +
                             [_swap(edge) for edge in original_edges])
        examples.append(new_ex)
    return lit_dataset.Dataset(spec, examples)
コード例 #3
0
 def _calculate_stats(self, dataset: lit_dataset.Dataset,
                      dataset_name: Text) -> None:
     """Computes and caches per-field statistics for `dataset`.

     Only fields accepted by `self._is_supported` are considered, and None
     values are skipped. Scalar fields get the result of
     `self._calculate_std_dev`; categorical fields get the result of
     `self._calculate_categorical_prob`. A field whose values are all None
     gets no entry. The result is stored in
     `self._datasets_stats[dataset_name]`.

     Args:
       dataset: the dataset to compute statistics for.
       dataset_name: the key under which the statistics are cached.

     Raises:
       ValueError: if a supported field is neither scalar nor categorical.
     """
     spec = dataset.spec()
     supported_fields = [
         name for name in spec if self._is_supported(spec[name])
     ]
     # Collect the non-None values of every supported field into one list
     # per field to facilitate the computations below.
     field_values = {}
     for example in dataset.examples:
         for field_name in supported_fields:
             value = example[field_name]
             if value is not None:
                 field_values.setdefault(field_name, []).append(value)

     # Compute the necessary statistics: standard deviation for scalar fields
     # and probability of having the same value for categorical fields.
     field_stats = {}
     for field_name, values in field_values.items():
         field_spec = spec[field_name]
         if self._is_scalar(field_spec):
             field_stats[field_name] = self._calculate_std_dev(values)
         elif self._is_categorical(field_spec):
             field_stats[field_name] = self._calculate_categorical_prob(
                 values)
         else:
             # Explicit raise rather than `assert`, so the check is not
             # stripped when Python runs with -O.
             raise ValueError(
                 f'Field {field_name!r} is supported but is neither scalar '
                 'nor categorical.')
     # Cache the stats for the given dataset.
     self._datasets_stats[dataset_name] = field_stats
コード例 #4
0
  def run(self,
          inputs: List[JsonDict],
          dataset: lit_dataset.Dataset,
          config: Optional[JsonDict] = None):
    """Generates new examples by substituting generated text into each field.

    For every TextSegment field in the dataset spec, `generate_from_texts` is
    called once with that field's value from every input. Each candidate it
    returns for input i becomes a copy of input i with that single field
    replaced.

    Args:
      inputs: sequence of inputs, following dataset.spec()
      dataset: dataset, used to access dataset.spec()
      config: additional runtime options (not read here)

    Returns:
      list of list of new generated inputs, following dataset.spec(); the
      i-th inner list holds the examples generated from inputs[i].
    """
    generated = [[] for _ in inputs]

    # TODO(lit-team): configure a subset of fields to operate on
    spec_text_keys = utils.find_spec_keys(dataset.spec(), types.TextSegment)
    per_field_candidates = {
        key: self.generate_from_texts([item[key] for item in inputs])
        for key in spec_text_keys
    }

    # One field is replaced per generated example.
    # TODO(lit-team): substitute on a combination of fields?
    for key, candidate_lists in per_field_candidates.items():
      for idx, item in enumerate(inputs):
        generated[idx].extend(
            utils.copy_and_update(item, {key: replacement})
            for replacement in candidate_lists[idx])
    return generated
コード例 #5
0
ファイル: annotators.py プロジェクト: PAIR-code/lit
    def annotate(self,
                 inputs: List[JsonDict],
                 dataset: lit_dataset.Dataset,
                 dataset_spec_to_annotate: Optional[types.Spec] = None):
        """Annotates every matching dataset field using the annotator model.

        The annotator model must take exactly one input field. For each
        dataset field whose spec type matches that input's type, the model is
        run over `inputs` (remapped so the field is fed under the model's
        input name). Each model output `output_name` is then written into
        every example of `inputs` under the key
        '{self._name}:{output_name}:{dataset field name}'.

        Args:
          inputs: examples to annotate; modified in place.
          dataset: dataset whose spec is searched for matching fields.
          dataset_spec_to_annotate: optional spec dict; if given, the new
            fields are added to it in place, marked with annotated=True.

        Raises:
          ValueError: if the annotator model does not take exactly one input.
        """
        input_spec = self._annotator_model.input_spec()
        if len(input_spec) != 1:
            raise ValueError(
                'Annotator model provided to PerFieldAnnotator does not '
                'operate on a single field')
        output_spec = self._annotator_model.output_spec()

        datasets = {}
        for input_name, input_type in input_spec.items():
            # Do remap of inputs based on input name needed by annotator.
            ds_keys = utils.find_spec_keys(dataset.spec(), type(input_type))
            for ds_key in ds_keys:
                temp_ds = lit_dataset.Dataset(examples=inputs, base=dataset)
                datasets[ds_key] = temp_ds.remap({ds_key: input_name})

        for ds_key, ds in datasets.items():
            # Materialize: predict() may return a one-shot iterator, but the
            # outputs are read once per output field below.
            outputs = list(self._annotator_model.predict(ds.examples))
            for output_name, output_type in output_spec.items():
                # Update dataset spec with new annotated field.
                field_name = f'{self._name}:{output_name}:{ds_key}'
                # `is not None` so that an empty (falsy) spec is still updated.
                if dataset_spec_to_annotate is not None:
                    dataset_spec_to_annotate[field_name] = attr.evolve(
                        output_type, annotated=True)

                # Update all examples with annotator output.
                for example, output in zip(inputs, outputs):
                    example[field_name] = output[output_name]
コード例 #6
0
    def generate(self,
                 example: JsonDict,
                 model: lit_model.Model,
                 dataset: lit_dataset.Dataset,
                 config: Optional[JsonDict] = None) -> List[JsonDict]:
        """Replace words based on replacement list.

        Produces one new example per counterfactual value, each a deep copy
        of `example` with a single TextSegment field replaced.

        Args:
          example: the example used as the basis of generated examples.
          model: unused.
          dataset: dataset whose spec determines which fields are text.
          config: optional; a non-empty 'subs' entry is parsed with
            `parse_subs_string`, otherwise `default_replacements` are used.

        Returns:
          A list of generated examples.
        """
        del model  # Unused.

        subs_string = config.get('subs') if config else None
        if subs_string:
            replacements = self.parse_subs_string(subs_string)
        else:
            replacements = self.default_replacements

        new_examples = []
        # TODO(lit-dev): move this to generate_all(), so we read the spec once
        # instead of on every example.
        text_keys = utils.find_spec_keys(dataset.spec(), types.TextSegment)
        for text_key in text_keys:
            text_data = example[text_key]
            # (start, end) spans of every token match; a list (not a lazy map
            # object) so the spans can safely be iterated more than once.
            token_spans = [
                match.span()
                for match in self.tokenization_pattern.finditer(text_data)
            ]
            for new_val in self.generate_counterfactuals(
                    text_data, token_spans, replacements):
                new_example = copy.deepcopy(example)
                new_example[text_key] = new_val
                new_examples.append(new_example)

        return new_examples
コード例 #7
0
ファイル: app.py プロジェクト: PAIR-code/lit
 def _run_annotators(self,
                     dataset: lit_dataset.Dataset) -> lit_dataset.Dataset:
     """Returns a new dataset with every annotator in `self._annotators` applied.

     The examples are shallow-copied into new dicts and the spec dict is
     copied, so annotators add their fields to these copies rather than to
     the original example dicts and spec dict.
     """
     annotated_examples = list(map(dict, dataset.examples))
     spec_copy = dict(dataset.spec())
     for annotator in self._annotators:
         annotator.annotate(annotated_examples, dataset, spec_copy)
     return lit_dataset.Dataset(
         base=dataset, examples=annotated_examples, spec=spec_copy)
コード例 #8
0
 def run_with_metadata(self,
                       indexed_inputs: List[JsonDict],
                       model: lit_model.Model,
                       dataset: lit_dataset.Dataset,
                       model_outputs: Optional[List[JsonDict]] = None,
                       config: Optional[JsonDict] = None) -> List[JsonDict]:
     """Computes metrics for each matched (prediction, label) field pair.

     Args:
       indexed_inputs: inputs with 'id', 'data' and 'meta' entries; labels
         are read from 'data'.
       model: model whose output spec is matched against the dataset spec.
       dataset: dataset providing the label spec.
       model_outputs: predictions parallel to `indexed_inputs`. If None, they
         are computed by running `model.predict` on the 'data' entries.
       config: optional dict of per-field configs, looked up by label key.

     Returns:
       One dict per matched field pair with keys 'pred_key', 'label_key' and
       'metrics'; NaN metric values are replaced by None.
     """
     # TODO(lit-team): pre-compute this mapping in constructor?
     # This would require passing a model name to this function so we can
     # reference a pre-computed list.
     spec = model.spec()
     dataset_spec = dataset.spec()
     field_map = map_pred_keys(dataset_spec, spec.output, self.is_compatible)
     if not field_map:
         return []
     # `model_outputs` is optional but indexed below; compute predictions
     # instead of failing with a TypeError when it is None.
     if model_outputs is None:
         model_outputs = list(
             model.predict([ex['data'] for ex in indexed_inputs]))

     ret = []
     for pred_key, label_key in field_map.items():
         # Extract fields
         labels = [ex['data'][label_key] for ex in indexed_inputs]
         preds = [mo[pred_key] for mo in model_outputs]
         indices = [ex['id'] for ex in indexed_inputs]
         metas = [ex['meta'] for ex in indexed_inputs]
         # Compute metrics, as dict(str -> float)
         metrics = self.compute_with_metadata(
             labels,
             preds,
             label_spec=dataset_spec[label_key],
             pred_spec=spec.output[pred_key],
             indices=indices,
             metas=metas,
             config=config.get(label_key) if config else None)
         # NaN is not a valid JSON value, so replace with None which will be
         # serialized as null.
         # TODO(lit-team): move this logic into serialize.py somewhere instead?
         metrics = {
             k: (v if not np.isnan(v) else None)
             for k, v in metrics.items()
         }
         # Format for frontend.
         ret.append({
             'pred_key': pred_key,
             'label_key': label_key,
             'metrics': metrics
         })
     return ret
コード例 #9
0
ファイル: metrics.py プロジェクト: PAIR-code/lit
    def run(self,
            inputs: List[JsonDict],
            model: lit_model.Model,
            dataset: lit_dataset.Dataset,
            model_outputs: Optional[List[JsonDict]] = None,
            config: Optional[JsonDict] = None):
        """Computes metrics for every compatible (prediction, label) pair.

        Args:
          inputs: input examples; labels are read directly from them.
          model: model whose output spec is matched against the dataset spec.
          dataset: dataset providing the label spec.
          model_outputs: predictions parallel to `inputs`; computed with
            `model.predict` when None.
          config: optional dict of per-field configs, looked up by pred key.

        Returns:
          One dict per matched pair with keys 'pred_key', 'label_key' and
          'metrics'; NaN metric values are replaced by None.
        """
        outputs = (list(model.predict(inputs))
                   if model_outputs is None else model_outputs)

        model_spec = model.spec()
        pairs = map_pred_keys(dataset.spec(), model_spec.output,
                              self.is_compatible)

        results = []
        for pred_key, label_key in pairs.items():
            gold = [item[label_key] for item in inputs]
            predicted = [out[pred_key] for out in outputs]
            raw_metrics = self.compute(
                gold,
                predicted,
                label_spec=dataset.spec()[label_key],
                pred_spec=model_spec.output[pred_key],
                config=config.get(pred_key) if config else None)
            # JSON has no NaN; None serializes as null instead.
            # TODO(lit-team): move this logic into serialize.py somewhere instead?
            cleaned = {
                name: None if np.isnan(value) else value
                for name, value in raw_metrics.items()
            }
            results.append(
                dict(pred_key=pred_key, label_key=label_key, metrics=cleaned))
        return results
コード例 #10
0
ファイル: scrambler.py プロジェクト: byhqsr/PAIR-code-lit
  def generate(self,
               example: JsonDict,
               model: lit_model.Model,
               dataset: lit_dataset.Dataset,
               config: Optional[JsonDict] = None) -> List[JsonDict]:
    """Naively scramble all words in an example.

    Every TextSegment field of the dataset spec is replaced with the result of
    `self.scramble` on its original value; exactly one example is returned.
    """
    del model, config  # Unused.

    # TODO(lit-dev): move this to generate_all(), so we read the spec once
    # instead of on every example.
    keys_to_scramble = utils.find_spec_keys(dataset.spec(), types.TextSegment)
    scrambled = copy.deepcopy(example)
    for key in keys_to_scramble:
      scrambled[key] = self.scramble(example[key])
    return [scrambled]
コード例 #11
0
ファイル: metrics.py プロジェクト: PAIR-code/lit
 def run(self,
         inputs: List[JsonDict],
         model: lit_model.Model,
         dataset: lit_dataset.Dataset,
         model_outputs: Optional[List[JsonDict]] = None,
         config: Optional[JsonDict] = None):
     """Runs the wrapped metrics with per-input margins as their config.

     For every prediction field matched against the dataset spec, the field's
     config (looked up by prediction key) is passed to `get_margin_for_input`
     for each input. The resulting margin lists, keyed by prediction key, are
     passed as the config of `self._metrics.run`.
     """
     def _margins_for(pred_key):
         field_config = config.get(pred_key) if config else None
         return [get_margin_for_input(field_config, inp) for inp in inputs]

     matched_fields = map_pred_keys(dataset.spec(),
                                    model.spec().output, self.is_compatible)
     margin_config = {key: _margins_for(key) for key in matched_fields}
     return self._metrics.run(inputs, model, dataset, model_outputs,
                              margin_config)
コード例 #12
0
ファイル: scrambler.py プロジェクト: PAIR-code/lit
    def generate(self,
                 example: JsonDict,
                 model: lit_model.Model,
                 dataset: lit_dataset.Dataset,
                 config: Optional[JsonDict] = None) -> List[JsonDict]:
        """Naively scramble all words in an example.

        Note: Even if more than one field is to be scrambled, only a single
        example will be produced, unlike other generators which will produce
        multiple examples, one per field.

        Args:
          example: the example used for basis of generated examples.
          model: the model.
          dataset: the dataset.
          config: user-provided config properties.

        Returns:
          examples: a list with one generated example, or an empty list if no
          requested field is a TextSegment field of the dataset.
        """
        del model  # Unused.

        config = config or {}

        # If config key is missing, generate no examples.
        fields_to_scramble = list(config.get(FIELDS_TO_SCRAMBLE_KEY, []))
        if not fields_to_scramble:
            return []

        # TODO(lit-dev): move this to generate_all(), so we read the spec once
        # instead of on every example.
        text_keys = utils.find_spec_keys(dataset.spec(), types.TextSegment)
        text_keys = [key for key in text_keys if key in fields_to_scramble]
        # Check after filtering: with nothing to scramble the copy would be
        # identical to the original, which is not a useful new example.
        if not text_keys:
            return []

        new_example = copy.deepcopy(example)
        for text_key in text_keys:
            new_example[text_key] = self.scramble(example[text_key])
        return [new_example]
コード例 #13
0
ファイル: word_replacer.py プロジェクト: zhiyiZeng/lit
    def generate(self,
                 example: JsonDict,
                 model: lit_model.Model,
                 dataset: lit_dataset.Dataset,
                 config: Optional[JsonDict] = None) -> List[JsonDict]:
        """Replace words based on replacement list.

        Replacements come from the config's 'Substitutions' string when it is
        non-empty, otherwise from `self.default_replacements`; 'ignore_casing'
        defaults to True. One deep-copied example is produced per
        counterfactual value of each TextSegment field.
        """
        del model  # Unused.

        options = config or {}
        ignore_casing = options.get('ignore_casing', True)
        subs_string = options.get('Substitutions')
        replacements = (
            self.parse_subs_string(subs_string, ignore_casing=ignore_casing)
            if subs_string else self.default_replacements)

        # Nothing to match against: no examples.
        if not replacements:
            return []

        pattern = self._get_replacement_pattern(
            replacements, ignore_casing=ignore_casing)

        # TODO(lit-dev): move this to generate_all(), so we read the spec once
        # instead of on every example.
        generated = []
        for key in utils.find_spec_keys(dataset.spec(), types.TextSegment):
            source_text = example[key]
            counterfactuals = self.generate_counterfactuals(
                source_text, pattern, replacements,
                ignore_casing=ignore_casing)
            for replaced_text in counterfactuals:
                variant = copy.deepcopy(example)
                variant[key] = replaced_text
                generated.append(variant)
        return generated
コード例 #14
0
    def run(self,
            inputs: List[JsonDict],
            dataset: lit_dataset.Dataset,
            config: Optional[JsonDict] = None):
        """Run generation on a set of inputs.

        Args:
          inputs: sequence of inputs, following dataset.spec()
          dataset: dataset, used to access dataset.spec()
          config: additional runtime options. FIELDS_TO_BACKTRANSLATE_KEY
            selects the fields to backtranslate; if it is missing, all text
            fields are used. Requested fields that are not TextSegment fields
            of the dataset spec are ignored.

        Returns:
          list of list of new generated inputs, following dataset.spec(); the
          i-th inner list holds the examples generated from inputs[i].
        """
        all_outputs = [[] for _ in inputs]

        config = config or {}

        # Find text fields.
        text_fields = utils.find_spec_keys(dataset.spec(), types.TextSegment)
        # If config key is missing, backtranslate all text fields. Requested
        # fields are restricted to text fields so unknown or non-text keys do
        # not raise KeyError or feed non-text values to generate_from_texts.
        fields_to_backtranslate = [
            field for field in config.get(FIELDS_TO_BACKTRANSLATE_KEY,
                                          text_fields)
            if field in text_fields
        ]
        candidates_by_field = {}
        for field_name in fields_to_backtranslate:
            texts = [ex[field_name] for ex in inputs]
            candidates_by_field[field_name] = self.generate_from_texts(texts)
        # Generate by substituting in each field.
        # TODO(lit-team): substitute on a combination of fields?
        for field_name, candidates in candidates_by_field.items():
            for i, ex in enumerate(inputs):
                for candidate in candidates[i]:
                    new_ex = utils.copy_and_update(ex, {field_name: candidate})
                    all_outputs[i].append(new_ex)
        return all_outputs
コード例 #15
0
    def _calculate_L1_distance(
            self,
            example_1: JsonDict,
            example_2: JsonDict,
            dataset: lit_dataset.Dataset,
            dataset_name: Text,
            model: Optional[lit_model.Model] = None,
            field_names: Optional[List[Text]] = None
    ) -> Tuple[float, List[Text]]:
        """Calculates L1 distance between two input examples.

        Only categorical and scalar example features are considered. For
        categorical features, the distance is calculated as the probability of
        the feature having the same value for two random (with replacement)
        examples. For scalar features, the unit of distance is equal to the
        standard deviation of all feature values; a feature whose standard
        deviation is 0 contributes no distance.

        If `field_names` is not given, only features that are in the
        intersection of the model and dataset features are considered.

        If a feature value of either of the examples is None, such feature is
        ignored in distance calculation and the name of the feature is not
        included in the result feature list (see Returns description).

        Args:
          example_1: a first example to measure distance for.
          example_2: a second example to measure distance for.
          dataset: a dataset that contains the information about the feature
            types.
          dataset_name: name of the dataset; its statistics must already be
            cached in `self._datasets_stats`.
          model: a model that contains the information about the input feature
            types. Required when `field_names` is None.
          field_names: if set then the distance calculation only considers
            these fields.

        Returns:
          A tuple that contains the L1 distance and the list of features that
          were used in the distance calculation. The list of features will
          only contain features whose values differ between the two examples
          and are not None in either of them.
        """
        assert model or field_names
        ds_spec = dataset.spec()
        if field_names is None:
            assert model
            field_names = self._find_all_fields_to_consider(
                ds_spec=ds_spec, model_input_spec=model.input_spec())
        distance = 0.0
        diff_fields = []
        for field_name in field_names:
            field_spec = ds_spec[field_name]
            field_stats = self._datasets_stats[dataset_name]
            assert self._is_supported(field_spec)
            assert field_name in field_stats, f'{field_name}, {field_stats.keys()}'
            value_1 = example_1[field_name]
            value_2 = example_2[field_name]
            # Identical or missing values contribute nothing.
            if value_1 == value_2 or value_1 is None or value_2 is None:
                continue
            diff_fields.append(field_name)
            if self._is_scalar(field_spec):
                std_dev = field_stats[field_name]
                if std_dev != 0:
                    distance += abs(value_1 - value_2) / std_dev
            else:
                same_prob = field_stats[field_name]
                distance += same_prob
        return distance, diff_fields
コード例 #16
0
ファイル: word_replacer.py プロジェクト: PAIR-code/lit
    def generate(self,
                 example: JsonDict,
                 model: lit_model.Model,
                 dataset: lit_dataset.Dataset,
                 config: Optional[JsonDict] = None) -> List[JsonDict]:
        """Replace words based on replacement list.

        When several fields are selected, one example is generated per
        replacement in each field: replacement is applied to a single field at
        a time and every other field is copied unchanged from `example`.

        Args:
          example: the example used for basis of generated examples.
          model: the model (unused).
          dataset: the dataset whose spec determines the text fields.
          config: user-provided config properties (casing, substitutions and
            the fields to replace).

        Returns:
          examples: a list of generated examples; empty when there are no
          replacements or no selected text fields.
        """
        del model  # Unused.

        settings = config or {}

        ignore_casing = settings.get(IGNORE_CASING_KEY, True)
        subs_string = settings.get(SUBSTITUTIONS_KEY, None)
        replacements = (
            self.parse_subs_string(subs_string, ignore_casing=ignore_casing)
            if subs_string else self.default_replacements)

        # With an empty replacements dictionary there is nothing to match.
        if not replacements:
            return []

        replacement_regex = self._get_replacement_pattern(
            replacements, ignore_casing=ignore_casing)

        # A missing or empty field selection produces no examples.
        requested_fields = list(settings.get(FIELDS_TO_REPLACE_KEY, []))
        if not requested_fields:
            return []

        # TODO(lit-dev): move this to generate_all(), so we read the spec once
        # instead of on every example.
        spec_text_keys = utils.find_spec_keys(dataset.spec(), types.TextSegment)

        results = []
        for key in (k for k in spec_text_keys if k in requested_fields):
            source_text = example[key]
            counterfactuals = self.generate_counterfactuals(
                source_text,
                replacement_regex,
                replacements,
                ignore_casing=ignore_casing)
            for replaced_text in counterfactuals:
                variant = copy.deepcopy(example)
                variant[key] = replaced_text
                results.append(variant)
        return results
コード例 #17
0
    def generate(self,
                 example: JsonDict,
                 model: lit_model.Model,
                 dataset: lit_dataset.Dataset,
                 config: Optional[JsonDict] = None) -> List[JsonDict]:
        """Generates hot-flip counterfactuals for `example` from dataset examples.

        Dataset examples whose prediction is a flip relative to the prediction
        for `example` are used as starting points. For each combination of 1 to
        `max_flips` supported features, the sorted starting points are tried
        in order, and the first that yields a flip (after copying the
        remaining features from `example`) becomes a candidate. The search
        stops once at least `num_examples` candidates are collected.

        Args:
          example: the reference example.
          model: the model to query; must be provided.
          dataset: the dataset to draw starting points from.
          config: user-provided config. Reads NUM_EXAMPLES_KEY, MAX_FLIPS_KEY,
            PREDICTION_KEY (required), REGRESSION_THRESH_KEY and
            'dataset_name' (required).

        Returns:
          Up to `num_examples` counterfactuals with a positive L1 distance to
          `example`, ordered from closest to farthest.

        Raises:
          ValueError: if the model, the prediction key or the dataset name is
            missing, the prediction key is not in the model output spec, or
            its type is neither MulticlassPreds nor RegressionScore.
        """

        # Perform validation and retrieve configuration.
        if not model:
            raise ValueError('Please provide a model for this generator.')

        config = config or {}
        num_examples = int(config.get(NUM_EXAMPLES_KEY, NUM_EXAMPLES_DEFAULT))
        max_flips = int(config.get(MAX_FLIPS_KEY, MAX_FLIPS_DEFAULT))

        pred_key = config.get(PREDICTION_KEY, '')
        regression_thresh = float(
            config.get(REGRESSION_THRESH_KEY, REGRESSION_THRESH_DEFAULT))

        dataset_name = config.get('dataset_name')
        if not dataset_name:
            raise ValueError('The dataset name must be in the config.')

        output_spec = model.output_spec()
        if not pred_key:
            raise ValueError('Please provide the prediction key.')
        if pred_key not in output_spec:
            raise ValueError('Invalid prediction key.')

        if not isinstance(output_spec[pred_key],
                          (lit_types.MulticlassPreds, lit_types.RegressionScore)):
            raise ValueError(
                'Only classification and regression models are supported')

        # Calculate dataset statistics if it has never been calculated. The
        # statistics include such information as 'standard deviation' for scalar
        # features and probabilities for categorical features.
        if dataset_name not in self._datasets_stats:
            self._calculate_stats(dataset, dataset_name)

        # Find predicted class of the original example.
        original_pred = list(model.predict([example]))[0]

        # Find dataset examples that are flips.
        filtered_examples = self._filter_ds_examples(
            dataset=dataset,
            dataset_name=dataset_name,
            model=model,
            reference_output=original_pred,
            pred_key=pred_key,
            regression_thresh=regression_thresh)

        supported_field_names = self._find_all_fields_to_consider(
            ds_spec=dataset.spec(),
            model_input_spec=model.input_spec(),
            example=example)

        candidates: List[JsonDict] = []

        # Iterate through all possible feature combinations.
        combs = utils.find_all_combinations(supported_field_names, 1,
                                            max_flips)
        for comb in combs:
            # Sort all dataset examples with respect to the given combination.
            sorted_examples = self._sort_and_filter_examples(
                examples=filtered_examples,
                ref_example=example,
                fields=comb,
                dataset=dataset,
                dataset_name=dataset_name)
            if not sorted_examples:
                continue

            # As an optimization trick, check whether the farthest example is a flip.
            # If it is not a flip then skip the current combination of features.
            # This optimization makes the minimum set guarantees weaker but
            # significantly improves the search speed.
            flip = self._find_hot_flip(ref_example=example,
                                       ds_example=sorted_examples[-1],
                                       features_to_consider=comb,
                                       model=model,
                                       target_pred=original_pred,
                                       pred_key=pred_key,
                                       dataset=dataset,
                                       interpolate=False,
                                       regression_threshold=regression_thresh)
            if not flip:
                logging.info('Skipped combination %s', comb)
                continue

            # Iterate through the sorted examples until the first flip is found.
            # TODO(b/204200758): improve performance by batching the predict requests.
            for ds_example in sorted_examples:
                flip = self._find_hot_flip(
                    ref_example=example,
                    ds_example=ds_example,
                    features_to_consider=comb,
                    model=model,
                    target_pred=original_pred,
                    pred_key=pred_key,
                    dataset=dataset,
                    interpolate=True,
                    regression_threshold=regression_thresh)

                if flip:
                    self._add_if_not_strictly_worse(example=flip,
                                                    other_examples=candidates,
                                                    ref_example=example,
                                                    dataset=dataset,
                                                    dataset_name=dataset_name,
                                                    model=model)
                    break

            if len(candidates) >= num_examples:
                break

        # Calculate distances for the found hot flips.
        candidate_tuples = []
        for flip_example in candidates:
            distance, diff_fields = self._calculate_L1_distance(
                example_1=example,
                example_2=flip_example,
                dataset=dataset,
                dataset_name=dataset_name,
                model=model)
            if distance > 0:
                candidate_tuples.append((distance, diff_fields, flip_example))

        # Order the candidates by their distance to the given example.
        candidate_tuples.sort(key=lambda e: e[0])

        # Keep the closest `num_examples`; e[2] is the hot-flip example in each
        # (distance, diff_fields, example) tuple.
        return [e[2] for e in candidate_tuples[:num_examples]]
コード例 #18
0
    def _find_hot_flip(
        self,
        ref_example: JsonDict,
        ds_example: JsonDict,
        features_to_consider: List[Text],
        model: lit_model.Model,
        target_pred: JsonDict,
        pred_key: Text,
        dataset: lit_dataset.Dataset,
        interpolate: bool,
        regression_threshold: Optional[float] = None,
    ) -> Optional[JsonDict]:
        """Finds a hot-flip example for a given target example and DS example.

        Args:
          ref_example: target example for which the counterfactuals should be
            found.
          ds_example: a dataset example that should be used as a starting point
            for the search.
          features_to_consider: the list of feature keys that can be changed
            during the search.
          model: model to use for getting predictions.
          target_pred: model prediction that corresponds to `ref_example`.
          pred_key: the name of the field in model predictions that contains
            the prediction value for the counterfactual search.
          dataset: a dataset object that contains `ds_example`.
          interpolate: if True, the method tries to find a closer
            counterfactual using interpolation.
          regression_threshold: the threshold to use if `model` is a regression
            model. This parameter is ignored for classification models.

        Returns:
          A hot-flip counterfactual that satisfies the criteria, or None if the
          candidate built from `ds_example` is not a flip.
        """
        # All model-input features other than `features_to_consider` should be
        # assigned the value of the target example.
        input_spec = model.input_spec()
        candidate_example = ds_example.copy()
        for field_name in ref_example:
            if (field_name not in features_to_consider
                    and field_name in input_spec):
                candidate_example[field_name] = ref_example[field_name]

        flip, predicted_value = self._is_flip(
            model=model,
            cf_example=candidate_example,
            orig_output=target_pred,
            pred_key=pred_key,
            regression_thresh=regression_threshold)

        if not flip:
            return None

        # Find closest flip by moving scalar values closer to the target.
        closest_flip = None
        if interpolate:
            closest_flip = self._find_closer_flip_using_interpolation(
                ref_example, candidate_example, target_pred, pred_key, model,
                dataset, regression_threshold)
        # If we found a closer flip through interpolation then use it,
        # otherwise use the previously found flip.
        if closest_flip is not None:
            return closest_flip
        self._find_dataset_parent_and_set(
            model_output_spec=model.output_spec(),
            pred_key=pred_key,
            dataset_spec=dataset.spec(),
            example=candidate_example,
            predicted_value=predicted_value)
        return candidate_example
コード例 #19
0
    def _find_closer_flip_using_interpolation(
            self,
            ref_example: JsonDict,
            known_flip: JsonDict,
            target_pred: JsonDict,
            pred_key: Text,
            model: lit_model.Model,
            dataset: lit_dataset.Dataset,
            regression_threshold: Optional[float] = None,
            max_attempts: int = 4) -> Optional[JsonDict]:
        """Looks for the decision boundary between two examples using interpolation.

        The method searches for a flip that is closer to `ref_example` than
        `known_flip`. It performs a binary search over an interpolation weight
        alpha applied to scalar model inputs: alpha = 0 gives the values of
        `known_flip` and alpha = 1 the values of `ref_example`.

        Args:
          ref_example: an example for which the flip is searched.
          known_flip: an example that represents a known flip.
          target_pred: the model prediction at `ref_example`.
          pred_key: the name of the field inside `target_pred` that holds the
            prediction value.
          model: model to use for running predictions.
          dataset: dataset that contains `known_flip`.
          regression_threshold: threshold to use for regression models.
          max_attempts: number of binary search attempts.

        Returns:
          The counterfactual (flip) if found; 'None' otherwise.
        """
        input_spec = model.input_spec()
        # Scalar model inputs that are set in both examples; these are the
        # only fields that can be interpolated and they do not change between
        # attempts, so determine them once.
        scalar_fields = [
            field for field in ref_example
            if (field in known_flip and field in input_spec
                and isinstance(input_spec[field], lit_types.Scalar)
                and known_flip[field] is not None
                and ref_example[field] is not None)
        ]
        # The interpolation makes sense only for scalar values. If there are no
        # scalar fields that can be interpolated then terminate the search.
        if not scalar_fields:
            return None

        min_alpha = 0.0
        max_alpha = 1.0
        closest_flip = None
        for _ in range(max_attempts):
            # Interpolate the scalar values using binary search.
            current_alpha = (min_alpha + max_alpha) / 2
            candidate = known_flip.copy()
            for field in scalar_fields:
                candidate[field] = known_flip[field] * (
                    1 - current_alpha) + ref_example[field] * current_alpha
            flip, predicted_value = self._is_flip(
                model=model,
                cf_example=candidate,
                orig_output=target_pred,
                pred_key=pred_key,
                regression_thresh=regression_threshold)
            if flip:
                self._find_dataset_parent_and_set(
                    model_output_spec=model.output_spec(),
                    pred_key=pred_key,
                    dataset_spec=dataset.spec(),
                    example=candidate,
                    predicted_value=predicted_value)
                # Still a flip: keep it and move closer to `ref_example`.
                closest_flip = candidate
                min_alpha = current_alpha
            else:
                # Crossed the boundary: move back towards `known_flip`.
                max_alpha = current_alpha
        return closest_flip