Example #1
    def run(self,
            inputs: List[JsonDict],
            model: lit_model.Model,
            dataset: lit_dataset.Dataset,
            model_outputs: Optional[List[JsonDict]] = None,
            config: Optional[JsonDict] = None) -> Optional[List[JsonDict]]:
        """Run this component, given a model and input(s)."""
        # Find gradient fields to interpret
        input_spec = model.input_spec()
        output_spec = model.output_spec()
        grad_fields = self.find_fields(input_spec, output_spec)
        logging.info('Found fields for integrated gradients: %s',
                     str(grad_fields))
        if len(grad_fields) == 0:  # pylint: disable=g-explicit-length-test
            return None

        # Run model, if needed.
        if model_outputs is None:
            model_outputs = list(model.predict(inputs))

        all_results = []
        for model_output, model_input in zip(model_outputs, inputs):
            result = self.get_salience_result(model_input, model, model_output,
                                              grad_fields)
            all_results.append(result)
        return all_results
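The run() method above only orchestrates batching; the pairing of gradient fields with their token and embedding fields happens in find_fields, which is not shown. Below is a minimal sketch of what such a lookup could do, assuming the spec attributes (align, grad_for) used elsewhere in these examples; it is illustrative, not the actual LIT implementation.

# Illustrative sketch of a find_fields helper; not the actual LIT code.
from lit_nlp.api import types


def find_fields(input_spec, output_spec):
    grad_fields = []
    for name, field_spec in output_spec.items():
        if not isinstance(field_spec, types.TokenGradients):
            continue
        # Keep gradient fields whose aligned tokens are produced by the model
        # and whose source embeddings (grad_for) are accepted as model input.
        if field_spec.align in output_spec and field_spec.grad_for in input_spec:
            grad_fields.append(name)
    return grad_fields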
Example #2
    def run(self,
            inputs: List[JsonDict],
            model: lit_model.Model,
            dataset: lit_dataset.Dataset,
            model_outputs: Optional[List[JsonDict]] = None,
            config: Optional[JsonDict] = None) -> Optional[List[JsonDict]]:
        """Run this component, given a model and input(s)."""
        # Find gradient fields to interpret
        output_spec = model.output_spec()
        grad_fields = self.find_fields(output_spec)
        logging.info('Found fields for gradient attribution: %s',
                     str(grad_fields))
        if len(grad_fields) == 0:  # pylint: disable=g-explicit-length-test
            return None

        # Run model, if needed.
        if model_outputs is None:
            model_outputs = list(model.predict(inputs))
        assert len(model_outputs) == len(inputs)

        all_results = []
        for o in model_outputs:
            # Dict[field name -> interpretations]
            result = {}
            for grad_field in grad_fields:
                token_field = cast(types.TokenGradients,
                                   output_spec[grad_field]).align
                tokens = o[token_field]
                scores = self._interpret(o[grad_field], tokens)
                result[grad_field] = dtypes.SalienceMap(tokens, scores)
            all_results.append(result)

        return all_results
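_interpret is not shown above; a common choice for plain gradient salience is the L2 norm of each token's gradient vector, normalized across tokens. A numpy-only sketch under that assumption (the component may use a different scoring rule):

import numpy as np


def _interpret(grads, tokens):
    # grads: <float32>[num_tokens, emb_size] gradient w.r.t. each token.
    grads = np.asarray(grads)
    assert len(tokens) == grads.shape[0]
    scores = np.linalg.norm(grads, axis=-1)  # <float32>[num_tokens]
    total = scores.sum()
    return scores / total if total > 0 else scores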
Example #3
    def _is_flip(
            self,
            model: lit_model.Model,
            cf_example: JsonDict,
            orig_output: JsonDict,
            pred_key: Text,
            regression_thresh: Optional[float] = None) -> Tuple[bool, Any]:

        cf_output = list(model.predict([cf_example]))[0]
        feature_predicted_value = cf_output[pred_key]
        return cf_utils.is_prediction_flip(
            cf_output=cf_output,
            orig_output=orig_output,
            output_spec=model.output_spec(),
            pred_key=pred_key,
            regression_thresh=regression_thresh), feature_predicted_value
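cf_utils.is_prediction_flip is treated as a black box here. Conceptually, a flip means the argmax class changed for classification, or the score crossed the decision threshold for regression. A simplified stand-in illustrating that idea (the actual cf_utils helper may differ in details):

import numpy as np
from lit_nlp.api import types


def is_prediction_flip_sketch(cf_output, orig_output, output_spec, pred_key,
                              regression_thresh=None):
    if isinstance(output_spec[pred_key], types.RegressionScore):
        # Regression: a flip means the two scores fall on opposite sides of
        # the threshold.
        return ((cf_output[pred_key] <= regression_thresh) !=
                (orig_output[pred_key] <= regression_thresh))
    # Classification: a flip means the argmax class changed.
    return int(np.argmax(cf_output[pred_key])) != int(
        np.argmax(orig_output[pred_key]))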
Example #4
  def _run(self,
           model: lit_model.Model,
           indexed_inputs: List[JsonDict],
           model_outputs: Optional[List[JsonDict]] = None,
           do_fit=False):
    # Run model, if needed.
    if model_outputs is None:
      model_outputs = list(model.predict(indexed_inputs))
    assert len(model_outputs) == len(indexed_inputs)

    converted_inputs = list(
        map(self.convert_input, indexed_inputs, model_outputs))
    if do_fit:
      return self._projector.fit_transform_with_metadata(
          converted_inputs, dataset_name="")
    else:
      return self._projector.predict_with_metadata(
          converted_inputs, dataset_name="")
Example #5
  def build(cls,
            inputs: List[JsonDict],
            encoder: lit_model.Model,
            edge_field: str,
            embs_field: str,
            offset_field: str,
            progress=lambda x: x,
            verbose=False):
    """Run encoder and extract span representations for coreference.

    'encoder' should be a model returning one TokenEmbeddings field,
    from which span features will be extracted, as well as a TokenOffsets field
    which maps input tokens to output tokens.

    The returned dataset will contain one example for each label in the inputs'
    EdgeLabels field.

    Args:
      inputs: input Dataset
      encoder: encoder model, compatible with inputs
      edge_field: name of edge field in data
      embs_field: name of embeddings field in model output
      offset_field: name of offset field in model output
      progress: optional pass-through progress indicator
      verbose: if true, print estimated memory usage

    Returns:
      EdgeFeaturesDataset with extracted span representations
    """
    examples = []
    encoder_outputs = progress(encoder.predict(inputs))
    for i, output in enumerate(encoder_outputs):
      exs = _make_probe_inputs(
          inputs[i][edge_field],
          output[embs_field],
          output[offset_field],
          src_idx=i)
      examples.extend(exs)
      if verbose and i == 10:
        _estimate_memory_needs(inputs, edge_field, examples[0])

    return cls(examples)
Example #6
  def run(self,
          inputs: List[JsonDict],
          model: lit_model.Model,
          dataset: lit_dataset.Dataset,
          model_outputs: Optional[List[JsonDict]] = None,
          config: Optional[JsonDict] = None) -> Optional[List[JsonDict]]:
    del dataset
    del config

    field_map = self.find_fields(model)

    # Run model, if needed.
    if model_outputs is None:
      model_outputs = list(model.predict(inputs))
    assert len(model_outputs) == len(inputs)

    return [
        self._run_single(ex, mo, field_map)
        for ex, mo in zip(inputs, model_outputs)
    ]
Example #7
    def _generate_leave_one_out_ablation_score(
            self, example: JsonDict, model: lit_model.Model, input_spec: Spec,
            output_spec: Spec, orig_output: JsonDict, pred_key: Text,
            fields_to_ablate: List[str]) -> List[Tuple[str, int, float]]:
        # Returns a list of triples: field, token_idx and leave-one-out score.
        ret = []
        for field in input_spec.keys():
            if field not in example or field not in fields_to_ablate:
                continue
            tokens = self._get_tokens(example, input_spec, field)
            cfs = [
                self._create_cf(example, input_spec, [(field, i)])
                for i in range(len(tokens))
            ]
            cf_outputs = model.predict(cfs)
            for i, cf_output in enumerate(cf_outputs):
                loo_score = cf_utils.prediction_difference(
                    cf_output, orig_output, output_spec, pred_key)
                ret.append((field, i, loo_score))
        return ret
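cf_utils.prediction_difference is also only referenced; for a classification head, a natural leave-one-out score is the drop in probability of the originally predicted class once a token is ablated. A numpy sketch under that assumption (illustrative only):

import numpy as np


def prediction_difference_sketch(cf_output, orig_output, pred_key):
    # Drop in probability of the original argmax class after the ablation.
    orig_probs = np.asarray(orig_output[pred_key])
    cf_probs = np.asarray(cf_output[pred_key])
    orig_class = int(np.argmax(orig_probs))
    return float(orig_probs[orig_class] - cf_probs[orig_class])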
Example #8
    def run(self,
            inputs: List[JsonDict],
            model: lit_model.Model,
            dataset: lit_dataset.Dataset,
            model_outputs: Optional[List[JsonDict]] = None,
            config: Optional[JsonDict] = None):
        if model_outputs is None:
            model_outputs = list(model.predict(inputs))

        spec = model.spec()
        field_map = map_pred_keys(dataset.spec(), spec.output,
                                  self.is_compatible)
        ret = []
        for pred_key, label_key in field_map.items():
            # Extract fields
            labels = [ex[label_key] for ex in inputs]
            preds = [mo[pred_key] for mo in model_outputs]
            # Compute metrics, as dict(str -> float)
            metrics = self.compute(
                labels,
                preds,
                label_spec=dataset.spec()[label_key],
                pred_spec=spec.output[pred_key],
                config=config.get(pred_key) if config else None)
            # NaN is not a valid JSON value, so replace with None which will be
            # serialized as null.
            # TODO(lit-team): move this logic into serialize.py somewhere instead?
            metrics = {
                k: (v if not np.isnan(v) else None)
                for k, v in metrics.items()
            }
            # Format for frontend.
            ret.append({
                'pred_key': pred_key,
                'label_key': label_key,
                'metrics': metrics
            })
        return ret
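The metrics component defers the actual computation to self.compute, which returns a dict mapping metric names to floats. As a concrete illustration, a toy accuracy metric for a multiclass head might look like the following (hypothetical helper; it ignores the label_spec/pred_spec handling done by the real compute):

import numpy as np


def compute_accuracy(labels, preds, vocab):
    # labels: gold label strings; preds: probability vectors per example;
    # vocab: class names in the same order as the prediction vector.
    pred_labels = [vocab[int(np.argmax(p))] for p in preds]
    correct = sum(gold == pred for gold, pred in zip(labels, pred_labels))
    return {'accuracy': correct / max(len(labels), 1)}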
Example #9
    def run(self,
            inputs: List[JsonDict],
            model: lit_model.Model,
            dataset: lit_dataset.Dataset,
            model_outputs: Optional[List[JsonDict]] = None,
            config: Optional[JsonDict] = None) -> Optional[List[JsonDict]]:
        """Run this component, given a model and input(s)."""
        config = config or {}
        class_to_explain = config.get(CLASS_KEY,
                                      self.config_spec()[CLASS_KEY].default)
        interpolation_steps = int(
            config.get(INTERPOLATION_KEY,
                       self.config_spec()[INTERPOLATION_KEY].default))
        normalization = config.get(
            NORMALIZATION_KEY,
            self.config_spec()[NORMALIZATION_KEY].default)

        # Find gradient fields to interpret
        input_spec = model.input_spec()
        output_spec = model.output_spec()
        grad_fields = self.find_fields(input_spec, output_spec)
        logging.info('Found fields for integrated gradients: %s',
                     str(grad_fields))
        if len(grad_fields) == 0:  # pylint: disable=g-explicit-length-test
            return None

        # Run model, if needed.
        if model_outputs is None:
            model_outputs = list(model.predict(inputs))

        all_results = []
        for model_output, model_input in zip(model_outputs, inputs):
            result = self.get_salience_result(model_input, model,
                                              interpolation_steps,
                                              normalization, class_to_explain,
                                              model_output, grad_fields)
            all_results.append(result)
        return all_results
Example #10
    def run(self,
            inputs: List[JsonDict],
            model: lit_model.Model,
            dataset: lit_dataset.Dataset,
            model_outputs: Optional[List[JsonDict]] = None,
            config: Optional[JsonDict] = None) -> Optional[List[JsonDict]]:
        """Runs the component, given a model and input(s)."""
        input_spec = model.input_spec()
        output_spec = model.output_spec()

        if not inputs:
            return []

        # Find all fields required for the interpretation.
        supported_fields = find_supported_fields(input_spec, output_spec)
        if supported_fields is None:
            return
        grad_field_key = supported_fields.grad_field_key
        image_field_key = supported_fields.image_field_key
        grad_target_field_key = supported_fields.grad_target_field_key
        preds_field_key = supported_fields.preds_field_key
        multiclass = grad_target_field_key is not None

        # Determine the shape of gradients by calling the model with a single input
        # and extracting the shape from the gradient output.
        model_output = list(model.predict([inputs[0]]))
        grad_shape = model_output[0][grad_field_key].shape

        # If it is a multiclass model, find the labels with respect to which we
        # should compute the gradients.
        if multiclass:
            # Get class labels.
            label_vocab = list(
                cast(types.MulticlassPreds,
                     output_spec[preds_field_key]).vocab)

            # Run the model in order to find the gradient target labels.
            outputs = list(model.predict(inputs))
            grad_target_labels = []
            for model_input, model_output in zip(inputs, outputs):
                if model_input.get(grad_target_field_key) is not None:
                    grad_target_labels.append(
                        model_input[grad_target_field_key])
                else:
                    max_idx = int(np.argmax(model_output[preds_field_key]))
                    grad_target_labels.append(label_vocab[max_idx])
        else:
            grad_target_labels = [None] * len(inputs)

        saliency_object = self.get_saliency_object()
        extra_saliency_params = self.get_extra_saliency_params(config)
        all_results = []
        for example, grad_target_label in zip(inputs, grad_target_labels):
            result = {}
            image_str = example[image_field_key]
            saliency_input = image_utils.convert_image_str_to_array(
                image_str=image_str, shape=grad_shape)
            call_model_func = get_call_model_func(
                model=model,
                model_input=example,
                image_field_key=image_field_key,
                grad_field_key=grad_field_key,
                grad_target_field_key=grad_target_field_key,
                grad_target_label=grad_target_label)
            attribution = self.make_saliency_call(
                saliency_object=saliency_object,
                x_value=saliency_input,
                call_model_function=call_model_func,
                extra_saliency_params=extra_saliency_params)
            if attribution.ndim == 3:
                attribution = attribution.sum(axis=2)
            viz_params = self.get_visualization_params()
            overlaid_image = image_utils.overlay_pixel_saliency(
                image_str, attribution, **viz_params)
            result[grad_field_key] = image_utils.convert_pil_to_image_str(
                overlaid_image)
            all_results.append(result)
        return all_results
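The post-processing step above collapses an HxWxC attribution into a per-pixel map before overlaying it on the image. A small sketch of that reduction plus a [0, 1] normalization, which is roughly what a visualization step needs (the actual image_utils overlay applies its own colormap and parameters):

import numpy as np


def to_overlay_mask(attribution):
    # Collapse per-channel attributions into a single per-pixel value.
    if attribution.ndim == 3:
        attribution = attribution.sum(axis=2)
    # Rescale to [0, 1] so the mask can drive an alpha channel or colormap.
    lo, hi = attribution.min(), attribution.max()
    if hi <= lo:
        return np.zeros_like(attribution)
    return (attribution - lo) / (hi - lo)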
Example #11
    def get_salience_result(self, model_input: JsonDict,
                            model: lit_model.Model, model_output: JsonDict,
                            grad_fields: List[Text]):
        result = {}

        output_spec = model.output_spec()
        # We ensure that the embedding and gradient class fields are present in the
        # model's input spec in find_fields().
        embeddings_fields = [
            cast(types.TokenGradients, output_spec[grad_field]).grad_for
            for grad_field in grad_fields
        ]

        # The gradient class input is used to specify the target class of the
        # gradient calculation (if unspecified, this option defaults to the argmax,
        # which could flip between interpolated inputs).
        grad_class_key = cast(types.TokenGradients,
                              output_spec[grad_fields[0]]).grad_target
        # TODO(b/168042999): Add option to specify the class to explain in the UI.
        grad_class = model_output[grad_class_key]

        interpolated_inputs = {}
        all_embeddings = []
        all_baselines = []
        for embed_field in embeddings_fields:
            # <float32>[num_tokens, emb_size]
            embeddings = np.array(model_output[embed_field])
            all_embeddings.append(embeddings)

            # Starts with baseline of zeros. <float32>[num_tokens, emb_size]
            baseline = self.get_baseline(embeddings)
            all_baselines.append(baseline)

            # Get interpolated inputs from baseline to original embedding.
            # <float32>[interpolation_steps, num_tokens, emb_size]
            interpolated_inputs[embed_field] = self.get_interpolated_inputs(
                baseline, embeddings, self.interpolation_steps)

        # Create model inputs and populate embedding field(s).
        inputs_with_embeds = []
        for i in range(self.interpolation_steps):
            input_copy = model_input.copy()
            # Interpolates embeddings for all inputs simultaneously.
            for embed_field in embeddings_fields:
                # <float32>[num_tokens, emb_size]
                input_copy[embed_field] = interpolated_inputs[embed_field][i]
                input_copy[grad_class_key] = grad_class

            inputs_with_embeds.append(input_copy)
        embed_outputs = model.predict(inputs_with_embeds)

        # Create list with concatenated gradients for each interpolated input.
        gradients = []
        for o in embed_outputs:
            # <float32>[total_num_tokens, emb_size]
            interp_gradients = np.concatenate(
                [o[field] for field in grad_fields])
            gradients.append(interp_gradients)
        # <float32>[interpolation_steps, total_num_tokens, emb_size]
        path_gradients = np.stack(gradients, axis=0)

        # Calculate integral
        # <float32>[total_num_tokens, emb_size]
        integral = self.estimate_integral(path_gradients)

        # <float32>[total_num_tokens, emb_size]
        concat_embeddings = np.concatenate(all_embeddings)

        # <float32>[total_num_tokens, emb_size]
        concat_baseline = np.concatenate(all_baselines)

        # <float32>[total_num_tokens, emb_size]
        integrated_gradients = integral * (np.array(concat_embeddings) -
                                           np.array(concat_baseline))
        # Dot product of integral values and (embeddings - baseline).
        # <float32>[total_num_tokens]
        attributions = np.sum(integrated_gradients, axis=-1)

        # TODO(b/168042999): Make normalization customizable in the UI.
        # <float32>[total_num_tokens]
        scores = citrus_utils.normalize_scores(attributions)

        for grad_field in grad_fields:
            # Format as salience map result.
            token_field = cast(types.TokenGradients,
                               output_spec[grad_field]).align
            tokens = model_output[token_field]

            # Only use the scores that correspond to the tokens in this grad_field.
            # The gradients for all input embeddings were concatenated in the order
            # of the grad fields, so they can be sliced out in the same order.
            sliced_scores = scores[:len(
                tokens)]  # <float32>[num_tokens in field]
            scores = scores[len(tokens):]  # <float32>[num_remaining_tokens]

            assert len(tokens) == len(sliced_scores)
            result[grad_field] = dtypes.SalienceMap(tokens, sliced_scores)
        return result
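The two numeric helpers used above, get_interpolated_inputs and estimate_integral, follow directly from the integrated-gradients formula: interpolate linearly from the baseline to the input embeddings and average the gradients along that path. A numpy sketch consistent with the shape annotations above (the exact integration rule inside the component may differ, e.g. Riemann vs. trapezoidal):

import numpy as np


def get_interpolated_inputs(baseline, target, num_steps):
    # <float32>[num_steps, num_tokens, emb_size]: linear path from the
    # baseline to the original embeddings.
    alphas = np.linspace(0.0, 1.0, num_steps).reshape(-1, 1, 1)
    delta = (target - baseline)[np.newaxis, ...]
    return baseline[np.newaxis, ...] + alphas * delta


def estimate_integral(path_gradients):
    # path_gradients: <float32>[interpolation_steps, num_tokens, emb_size].
    # Averaging over the steps approximates the integral over alpha in [0, 1].
    return path_gradients.mean(axis=0)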
Example #12
File: pdp.py Project: PAIR-code/lit
    def run(self,
            inputs: List[types.JsonDict],
            model: lit_model.Model,
            dataset: lit_dataset.Dataset,
            model_outputs: Optional[List[types.JsonDict]] = None,
            config: Optional[types.JsonDict] = None):
        """Create PDP chart info using provided inputs.

    Args:
      inputs: sequence of inputs, following model.input_spec()
      model: optional model to use to generate new examples.
      dataset: dataset which the current examples belong to.
      model_outputs: optional precomputed model outputs
      config: optional runtime config.

    Returns:
      a dict of alternate feature values to model outputs. The model
      outputs will be a number for regression models and a list of numbers for
      multiclass models.
    """

        pred_keys = utils.find_spec_keys(
            model.output_spec(),
            (types.MulticlassPreds, types.RegressionScore))
        if not pred_keys:
            logging.warning('PDP did not find any supported output fields.')
            return None

        assert 'feature' in config, 'No feature to test provided'
        feature = config['feature']
        provided_range = config['range'] if 'range' in config else []
        edited_outputs = {}
        for pred_key in pred_keys:
            edited_outputs[pred_key] = {}

        # If a range was provided, use that to create the possible values.
        vals_to_test = (np.linspace(provided_range[0], provided_range[1], 10)
                        if len(provided_range) == 2 else self.get_vals_to_test(
                            feature, dataset))

        # If no specific inputs provided, use the entire dataset.
        inputs_to_use = inputs if inputs else dataset.examples

        # For each alternate value for a given feature.
        for new_val in vals_to_test:
            # Create copies of all provided inputs with the value replaced.
            edited_inputs = []
            for inp in inputs_to_use:
                edited_input = copy.deepcopy(inp)
                edited_input[feature] = new_val
                edited_inputs.append(edited_input)

            # Run prediction on the altered inputs.
            outputs = list(model.predict(edited_inputs))

            # Store the mean of the prediction for the alternate value.
            for pred_key in pred_keys:
                numeric = isinstance(model.output_spec()[pred_key],
                                     types.RegressionScore)
                if numeric:
                    edited_outputs[pred_key][new_val] = np.mean(
                        [output[pred_key] for output in outputs])
                else:
                    edited_outputs[pred_key][new_val] = np.mean(
                        [output[pred_key] for output in outputs], axis=0)

        return edited_outputs
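get_vals_to_test is the only piece not shown; for a PDP sweep it typically returns evenly spaced points over a numeric feature's observed range, or every distinct value for a categorical feature. A sketch under that assumption, treating dataset.examples as plain dicts (illustrative only):

import numpy as np


def get_vals_to_test_sketch(feature, dataset, num_points=10):
    observed = [ex[feature] for ex in dataset.examples if feature in ex]
    if observed and all(isinstance(v, (int, float)) for v in observed):
        # Numeric feature: sweep linearly over the observed range.
        return np.linspace(min(observed), max(observed), num_points)
    # Categorical feature: test each distinct observed value.
    return sorted(set(observed))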
Example #13
    def generate(self,
                 example: JsonDict,
                 model: lit_model.Model,
                 dataset: lit_dataset.Dataset,
                 config: Optional[JsonDict] = None) -> List[JsonDict]:
        """Identify minimal sets of token albations that alter the prediction."""
        del dataset  # Unused.

        config = config or {}
        num_examples = int(config.get(NUM_EXAMPLES_KEY, NUM_EXAMPLES_DEFAULT))
        max_ablations = int(
            config.get(MAX_ABLATIONS_KEY, MAX_ABLATIONS_DEFAULT))
        assert model is not None, "Please provide a model for this generator."

        input_spec = model.input_spec()
        pred_key = config.get(PREDICTION_KEY, "")
        regression_thresh = float(
            config.get(REGRESSION_THRESH_KEY, REGRESSION_THRESH_DEFAULT))

        output_spec = model.output_spec()
        assert pred_key, "Please provide the prediction key"
        assert pred_key in output_spec, "Invalid prediction key"

        is_regression = isinstance(output_spec[pred_key],
                                   types.RegressionScore)
        if not is_regression:
            assert isinstance(output_spec[pred_key], types.MulticlassPreds), (
                "Only classification or regression models are supported")
        logging.info(r"W3lc0m3 t0 Ablatl0nFl1p \o/")
        logging.info("Original example: %r", example)

        # Check for fields to ablate.
        fields_to_ablate = list(config.get(FIELDS_TO_ABLATE_KEY, []))
        if not fields_to_ablate:
            return []

        # Get model outputs.
        orig_output = list(model.predict([example]))[0]
        loo_scores = self._generate_leave_one_out_ablation_score(
            example, model, input_spec, output_spec, orig_output, pred_key,
            fields_to_ablate)

        if isinstance(output_spec[pred_key], types.RegressionScore):
            ablation_idxs_generator = self._gen_ablation_idxs(
                loo_scores, max_ablations, orig_output[pred_key],
                regression_thresh)
        else:
            ablation_idxs_generator = self._gen_ablation_idxs(
                loo_scores, max_ablations)

        tokens_map = {}
        for field in input_spec.keys():
            tokens = self._get_tokens(example, input_spec, field)
            if not tokens:
                continue
            tokens_map[field] = tokens

        successful_cfs = []
        successful_positions = []
        for ablation_idxs in ablation_idxs_generator:
            if len(successful_cfs) >= num_examples:
                return successful_cfs

            # If a subset of this set of tokens has already been successful in
            # obtaining a flip, we continue. This ensures that we only consider
            # sets of tokens that are minimal.
            if self._subset_exists(set(ablation_idxs), successful_positions):
                continue

            # Create counterfactual and obtain model prediction.
            cf = self._create_cf(example, input_spec, ablation_idxs)
            cf_output = list(model.predict([cf]))[0]

            # Check if counterfactual results in a prediction flip.
            if cf_utils.is_prediction_flip(cf_output, orig_output, output_spec,
                                           pred_key, regression_thresh):
                # Prediction flip found!
                cf_utils.update_prediction(cf, cf_output, output_spec,
                                           pred_key)
                cf[ABLATED_TOKENS_KEY] = str([
                    f"{field}[{tokens_map[field][idx]}]"
                    for field, idx in ablation_idxs
                ])
                successful_cfs.append(cf)
                successful_positions.append(set(ablation_idxs))
        return successful_cfs
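_gen_ablation_idxs is expected to yield candidate sets of (field, token_idx) pairs, trying the highest-impact tokens and the smallest sets first so that the returned counterfactuals stay minimal. A simplified generator along those lines (illustrative; the real ordering and the regression-specific variant may differ):

import itertools


def gen_ablation_idxs_sketch(loo_scores, max_ablations):
    # loo_scores: list of (field, token_idx, leave-one-out score) triples.
    ranked = sorted(loo_scores, key=lambda t: abs(t[2]), reverse=True)
    positions = [(field, idx) for field, idx, _ in ranked]
    for size in range(1, max_ablations + 1):
        for combo in itertools.combinations(positions, size):
            yield combo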
Example #14
    def generate(self,
                 example: JsonDict,
                 model: lit_model.Model,
                 dataset: lit_dataset.Dataset,
                 config: Optional[JsonDict] = None,
                 num_examples: int = 1) -> List[JsonDict]:
        """Use gradient to find/substitute the token with largest impact on loss."""
        # TODO(lit-team): This function is quite long. Consider breaking it
        # into small functions.
        del dataset  # Unused.

        assert model is not None, "Please provide a model for this generator."
        logging.info(r"W3lc0m3 t0 H0tFl1p \o/")
        logging.info("Original example: %r", example)

        # Find classification prediction key.
        pred_keys = self.find_fields(model.output_spec(),
                                     types.MulticlassPreds, None)
        if len(pred_keys) == 0:  # pylint: disable=g-explicit-length-test
            # TODO(ataly): Add support for regression models.
            logging.warning("The model does not have a classification head."
                            "Cannot use HotFlip. :-(")
            return []  # Cannot generate examples.
        if len(pred_keys) > 1:
            # TODO(ataly): Use a config argument when there are multiple prediction
            # heads.
            logging.warning("Multiple classification heads found."
                            "Cannot use HotFlip. :-(")
            return []  # Cannot generate examples.
        pred_key = pred_keys[0]

        # Find gradient fields to use for HotFlip
        input_spec = model.input_spec()
        output_spec = model.output_spec()
        grad_fields = self.find_fields(output_spec, types.TokenGradients,
                                       types.Tokens)
        logging.info("Found gradient fields for HotFlip use: %s",
                     str(grad_fields))
        if len(grad_fields) == 0:  # pylint: disable=g-explicit-length-test
            logging.info("No gradient fields found. Cannot use HotFlip. :-(")
            return []  # Cannot generate examples without gradients.

        # Get model outputs.
        logging.info(
            "Performing a forward/backward pass on the input example.")
        orig_output = list(model.predict([example]))[0]
        logging.info(orig_output.keys())

        # Get model word embeddings and vocab.
        inv_vocab, embed = model.get_embedding_table()
        assert len(
            inv_vocab) == embed.shape[0], "Vocab/embeddings size mismatch."
        logging.info("Vocab size: %d, Embedding size: %r", len(inv_vocab),
                     embed.shape)

        # Get original prediction class
        orig_probabilities = orig_output[pred_key]
        orig_prediction = np.argmax(orig_probabilities)

        # Perform a flip in each sequence for which we have gradients (separately).
        # Each sequence may give rise to multiple new examples, depending on how
        # many words we flip.
        # TODO(lit-team): make configurable how many new examples are desired.
        # TODO(lit-team): use only 1 sequence as input (configurable in UI).
        new_examples = []
        for grad_field in grad_fields:
            # Get the tokens and their gradient vectors.
            token_field = output_spec[grad_field].align  # pytype: disable=attribute-error
            tokens = orig_output[token_field]
            grads = orig_output[grad_field]
            token_emb_fields = self.find_fields(output_spec,
                                                types.TokenEmbeddings,
                                                types.Tokens)
            assert len(
                token_emb_fields) == 1, "Found multiple token embeddings"
            token_embs = orig_output[token_emb_fields[0]]

            # Identify the token with the largest gradient attribution,
            # defined as the dot product between the token embedding and gradient
            # of the output wrt the embedding.
            assert token_embs.shape[0] == grads.shape[0]
            token_grad_attrs = np.sum(token_embs * grads, axis=-1)
            # Get a list of indices of input tokens, sorted by gradient attribution,
            # highest first. We will flip tokens in this order.
            sorted_by_grad_attrs = np.argsort(token_grad_attrs)[::-1]

            for i in range(min(num_examples, len(tokens))):
                token_id = sorted_by_grad_attrs[i]
                logging.info(
                    "Selected token: %s (pos=%d) with gradient attribution %f",
                    tokens[token_id], token_id, token_grad_attrs[token_id])
                token_grad = grads[token_id]

                # Take dot product with all word embeddings. Get smallest value.
                # (We are looking for a replacement token that will lower the
                # score of the current class, thereby increasing the chances
                # of a label flip.)
                # TODO(lit-team): Can add criteria to the winner e.g. cosine distance.
                scores = np.dot(embed, token_grad)
                winner = np.argmin(scores)
                logging.info(
                    "Replacing [%s] (pos=%d) with option %d: [%s] (id=%d)",
                    tokens[token_id], token_id, i, inv_vocab[winner], winner)

                # Create a new input to the model.
                # TODO(iftenney, bastings): enforce somewhere that this field has the
                # same name in the input and output specs.
                input_token_field = token_field
                input_text_field = input_spec[input_token_field].parent  # pytype: disable=attribute-error
                new_example = copy.deepcopy(example)
                modified_tokens = copy.copy(tokens)
                modified_tokens[token_id] = inv_vocab[winner]
                new_example[input_token_field] = modified_tokens
                # TODO(iftenney, bastings): call a model-provided detokenizer here?
                # Though in general tokenization isn't invertible and it's possible for
                # HotFlip to produce wordpiece sequences that don't correspond to any
                # input string.
                new_example[input_text_field] = " ".join(modified_tokens)

                # Predict a new label for this example.
                new_output = list(model.predict([new_example]))[0]

                # Update label if multi-class prediction.
                # TODO(lit-dev): provide a general system for handling labels on
                # generated examples.
                probabilities = new_output[pred_key]
                new_prediction = np.argmax(probabilities)
                label_key = cast(types.MulticlassPreds,
                                 output_spec[pred_key]).parent
                label_names = cast(types.MulticlassPreds,
                                   output_spec[pred_key]).vocab
                new_label = label_names[new_prediction]
                new_example[label_key] = new_label
                logging.info("Updated example with new label: %s", new_label)

                if new_prediction != orig_prediction:
                    # Hotflip found
                    new_examples.append(new_example)
                else:
                    # We make new_example our base example and continue with
                    # more token flips.
                    example = new_example
                    tokens = modified_tokens
        return new_examples
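The replacement choice above, np.argmin(np.dot(embed, token_grad)), is the standard first-order HotFlip approximation: the change in the class score from swapping embedding e_old for e_new is roughly (e_new - e_old) . grad, so minimizing e_new . grad over the vocabulary minimizes the approximated score of the current class. The selection step in isolation (sketch; exclude_id is a hypothetical extra guard against re-picking the original token):

import numpy as np


def pick_hotflip_replacement(embedding_matrix, token_grad, exclude_id=None):
    # First-order score change for every vocab entry: e_new . grad.
    scores = embedding_matrix @ token_grad  # <float32>[vocab_size]
    if exclude_id is not None:
        scores = scores.copy()
        scores[exclude_id] = np.inf  # never re-pick the original token
    return int(np.argmin(scores))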
Example #15
    def generate(self,
                 example: JsonDict,
                 model: lit_model.Model,
                 dataset: lit_dataset.Dataset,
                 config: Optional[JsonDict] = None) -> List[JsonDict]:
        """Identify minimal sets of token flips that alter the prediction."""
        del dataset  # Unused.

        config = config or {}
        num_examples = int(config.get(NUM_EXAMPLES_KEY, NUM_EXAMPLES_DEFAULT))
        max_flips = int(config.get(MAX_FLIPS_KEY, MAX_FLIPS_DEFAULT))
        tokens_to_ignore = config.get(TOKENS_TO_IGNORE_KEY,
                                      TOKENS_TO_IGNORE_DEFAULT)
        pred_key = config.get(PREDICTION_KEY, "")
        regression_thresh = float(
            config.get(REGRESSION_THRESH_KEY, REGRESSION_THRESH_DEFAULT))
        assert model is not None, "Please provide a model for this generator."

        input_spec = model.input_spec()
        output_spec = model.output_spec()
        assert pred_key, "Please provide the prediction key"
        assert pred_key in output_spec, "Invalid prediction key"

        is_regression = False
        if isinstance(output_spec[pred_key], types.RegressionScore):
            is_regression = True
        else:
            assert isinstance(output_spec[pred_key], types.MulticlassPreds), (
                "Only classification or regression models are supported")
        logging.info(r"W3lc0m3 t0 H0tFl1p \o/")
        logging.info("Original example: %r", example)

        # Get model outputs.
        orig_output = list(model.predict([example]))[0]

        # Check config for selected fields.
        selected_fields = list(config.get(FIELDS_TO_HOTFLIP_KEY, []))
        if not selected_fields:
            return []

        # Get tokens (corresponding to each text input field) and corresponding
        # gradients.
        tokens_and_gradients = self._get_tokens_and_gradients(
            input_spec, output_spec, orig_output, selected_fields)
        assert tokens_and_gradients, (
            "No token fields found. Cannot use HotFlip. :-(")

        # Copy tokens into input example.
        example = copy.deepcopy(example)
        for token_field, v in tokens_and_gradients.items():
            tokens, _ = v
            example[token_field] = tokens

        inv_vocab, embedding_matrix = model.get_embedding_table()
        assert len(inv_vocab) == embedding_matrix.shape[0], (
            "Vocab/embeddings size mismatch.")

        successful_cfs = []
        # TODO(lit-team): use only 1 sequence as input (configurable in UI).
        # TODO(lit-team): Refactor the following code so that it's not so deeply
        # nested (and easier to track loop state).
        for token_field, v in tokens_and_gradients.items():
            tokens, grads = v
            text_field = input_spec[token_field].parent  # pytype: disable=attribute-error
            logging.info("Identifying Hotflips for input field: %s",
                         str(text_field))
            direction = -1
            if is_regression:
                # We want the replacements to increase the prediction score if the
                # original score is below the threshold, and decrease otherwise.
                direction = (1 if orig_output[pred_key] <= regression_thresh
                             else -1)
            replacement_tokens = self._get_replacement_tokens(
                embedding_matrix, inv_vocab, grads, direction)

            successful_positions = []
            for token_idxs in self._gen_token_idxs_to_flip(
                    tokens, grads, max_flips, tokens_to_ignore):
                if len(successful_cfs) >= num_examples:
                    return successful_cfs
                # If a subset of this set of tokens has already been successful
                # in obtaining a flip, we continue. This ensures that we only
                # consider sets of token flips that are minimal.
                if self._subset_exists(set(token_idxs), successful_positions):
                    continue

                # Create counterfactual.
                cf = self._create_cf(example, token_field, text_field, tokens,
                                     token_idxs, replacement_tokens)
                # Obtain model prediction.
                cf_output = list(model.predict([cf]))[0]

                if cf_utils.is_prediction_flip(cf_output, orig_output,
                                               output_spec, pred_key,
                                               regression_thresh):
                    # Prediction flip found!
                    cf_utils.update_prediction(cf, cf_output, output_spec,
                                               pred_key)
                    successful_cfs.append(cf)
                    successful_positions.append(set(token_idxs))
        return successful_cfs
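Both HotFlip variants keep the returned counterfactuals minimal by skipping any candidate whose positions contain an already-successful set of flips. The _subset_exists check used for this is small enough to sketch directly:

def subset_exists_sketch(candidate_positions, successful_positions):
    # candidate_positions: set of positions to flip; successful_positions:
    # list of sets that already produced a prediction flip.
    return any(prev <= candidate_positions for prev in successful_positions)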
Example #16
    def generate(self,
                 example: JsonDict,
                 model: lit_model.Model,
                 dataset: lit_dataset.Dataset,
                 config: Optional[JsonDict] = None) -> List[JsonDict]:

        # Perform validation and retrieve configuration.
        if not model:
            raise ValueError('Please provide a model for this generator.')

        config = config or {}
        num_examples = int(config.get(NUM_EXAMPLES_KEY, NUM_EXAMPLES_DEFAULT))
        max_flips = int(config.get(MAX_FLIPS_KEY, MAX_FLIPS_DEFAULT))

        pred_key = config.get(PREDICTION_KEY, '')
        regression_thresh = float(
            config.get(REGRESSION_THRESH_KEY, REGRESSION_THRESH_DEFAULT))

        dataset_name = config.get('dataset_name')
        if not dataset_name:
            raise ValueError('The dataset name must be in the config.')

        output_spec = model.output_spec()
        if not pred_key:
            raise ValueError('Please provide the prediction key.')
        if pred_key not in output_spec:
            raise ValueError('Invalid prediction key.')

        if (not (isinstance(output_spec[pred_key], lit_types.MulticlassPreds)
                 or isinstance(output_spec[pred_key],
                               lit_types.RegressionScore))):
            raise ValueError(
                'Only classification and regression models are supported')

        # Calculate dataset statistics if it has never been calculated. The
        # statistics include such information as 'standard deviation' for scalar
        # features and probabilities for categorical features.
        if dataset_name not in self._datasets_stats:
            self._calculate_stats(dataset, dataset_name)

        # Find predicted class of the original example.
        original_pred = list(model.predict([example]))[0]

        # Find dataset examples that are flips.
        filtered_examples = self._filter_ds_examples(
            dataset=dataset,
            dataset_name=dataset_name,
            model=model,
            reference_output=original_pred,
            pred_key=pred_key,
            regression_thresh=regression_thresh)

        supported_field_names = self._find_all_fields_to_consider(
            ds_spec=dataset.spec(),
            model_input_spec=model.input_spec(),
            example=example)

        candidates: List[JsonDict] = []

        # Iterate through all possible feature combinations.
        combs = utils.find_all_combinations(supported_field_names, 1,
                                            max_flips)
        for comb in combs:
            # Sort all dataset examples with respect to the given combination.
            sorted_examples = self._sort_and_filter_examples(
                examples=filtered_examples,
                ref_example=example,
                fields=comb,
                dataset=dataset,
                dataset_name=dataset_name)
            if not sorted_examples:
                continue

            # As an optimization trick, check whether the farthest example is a flip.
            # If it is not a flip then skip the current combination of features.
            # This optimization makes the minimum set guarantees weaker but
            # significantly improves the search speed.
            flip = self._find_hot_flip(ref_example=example,
                                       ds_example=sorted_examples[-1],
                                       features_to_consider=comb,
                                       model=model,
                                       target_pred=original_pred,
                                       pred_key=pred_key,
                                       dataset=dataset,
                                       interpolate=False,
                                       regression_threshold=regression_thresh)
            if not flip:
                logging.info('Skipped combination %s', comb)
                continue

            # Iterate through the sorted examples until the first flip is found.
            # TODO(b/204200758): improve performance by batching the predict requests.
            for ds_example in sorted_examples:
                flip = self._find_hot_flip(
                    ref_example=example,
                    ds_example=ds_example,
                    features_to_consider=comb,
                    model=model,
                    target_pred=original_pred,
                    pred_key=pred_key,
                    dataset=dataset,
                    interpolate=True,
                    regression_threshold=regression_thresh)

                if flip:
                    self._add_if_not_strictly_worse(example=flip,
                                                    other_examples=candidates,
                                                    ref_example=example,
                                                    dataset=dataset,
                                                    dataset_name=dataset_name,
                                                    model=model)
                    break

            if len(candidates) >= num_examples:
                break

        # Calculate distances for the found hot flips.
        candidate_tuples = []
        for flip_example in candidates:
            distance, diff_fields = self._calculate_L1_distance(
                example_1=example,
                example_2=flip_example,
                dataset=dataset,
                dataset_name=dataset_name,
                model=model)
            if distance > 0:
                candidate_tuples.append((distance, diff_fields, flip_example))

        # Order the dataset entries based on the distance to the given example.
        candidate_tuples.sort(key=lambda e: e[0])

        if len(candidate_tuples) > num_examples:
            candidate_tuples = candidate_tuples[0:num_examples]

        # e[2] contains the hot-flip examples in the distances list of tuples.
        return [e[2] for e in candidate_tuples]
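_calculate_L1_distance leans on the per-dataset statistics computed earlier: one natural choice is to scale scalar differences by the feature's standard deviation and count any categorical mismatch as 1, summing over the fields the model consumes. A sketch under that assumption (field names and the scalar_stds mapping are hypothetical stand-ins for the cached stats):

def l1_distance_sketch(example_1, example_2, fields, scalar_stds):
    distance = 0.0
    diff_fields = []
    for field in fields:
        v1, v2 = example_1.get(field), example_2.get(field)
        if v1 == v2:
            continue
        diff_fields.append(field)
        if field in scalar_stds and scalar_stds[field] > 0:
            # Scalar feature: normalize the difference by the dataset std.
            distance += abs(float(v1) - float(v2)) / scalar_stds[field]
        else:
            # Categorical feature: any mismatch counts as 1.
            distance += 1.0
    return distance, diff_fields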