def run(self,
        inputs: List[JsonDict],
        model: lit_model.Model,
        dataset: lit_dataset.Dataset,
        model_outputs: Optional[List[JsonDict]] = None,
        config: Optional[JsonDict] = None) -> Optional[List[JsonDict]]:
  """Run this component, given a model and input(s)."""
  # Find gradient fields to interpret.
  input_spec = model.input_spec()
  output_spec = model.output_spec()
  grad_fields = self.find_fields(input_spec, output_spec)
  logging.info('Found fields for integrated gradients: %s', str(grad_fields))
  if len(grad_fields) == 0:  # pylint: disable=g-explicit-length-test
    return None

  # Run model, if needed.
  if model_outputs is None:
    model_outputs = list(model.predict(inputs))

  all_results = []
  for model_output, model_input in zip(model_outputs, inputs):
    result = self.get_salience_result(model_input, model, model_output,
                                      grad_fields)
    all_results.append(result)
  return all_results
def run(self,
        inputs: List[JsonDict],
        model: lit_model.Model,
        dataset: lit_dataset.Dataset,
        model_outputs: Optional[List[JsonDict]] = None,
        config: Optional[JsonDict] = None) -> Optional[List[JsonDict]]:
  """Run this component, given a model and input(s)."""
  # Find gradient fields to interpret.
  output_spec = model.output_spec()
  grad_fields = self.find_fields(output_spec)
  logging.info('Found fields for gradient attribution: %s', str(grad_fields))
  if len(grad_fields) == 0:  # pylint: disable=g-explicit-length-test
    return None

  # Run model, if needed.
  if model_outputs is None:
    model_outputs = list(model.predict(inputs))
  assert len(model_outputs) == len(inputs)

  all_results = []
  for o in model_outputs:
    # Dict[field name -> interpretations]
    result = {}
    for grad_field in grad_fields:
      token_field = cast(types.TokenGradients, output_spec[grad_field]).align
      tokens = o[token_field]
      scores = self._interpret(o[grad_field], tokens)
      result[grad_field] = dtypes.SalienceMap(tokens, scores)
    all_results.append(result)
  return all_results
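
# NOTE: `self._interpret` above is defined elsewhere in this component. As a
# point of reference, a minimal sketch of one plausible scoring scheme
# (per-token L2 norm of the gradient vectors, normalized to sum to one) is
# shown below; the actual implementation may differ.
def _interpret_sketch(grads: np.ndarray, tokens: List[str]) -> np.ndarray:
  """Illustrative only: scores each token by the norm of its gradient."""
  assert len(tokens) == grads.shape[0]
  norms = np.linalg.norm(grads, axis=-1)  # <float32>[num_tokens]
  return norms / np.sum(norms)  # Normalize so the scores sum to 1.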
def _is_flip(self,
             model: lit_model.Model,
             cf_example: JsonDict,
             orig_output: JsonDict,
             pred_key: Text,
             regression_thresh: Optional[float] = None) -> Tuple[bool, Any]:
  cf_output = list(model.predict([cf_example]))[0]
  feature_predicted_value = cf_output[pred_key]
  return cf_utils.is_prediction_flip(
      cf_output=cf_output,
      orig_output=orig_output,
      output_spec=model.output_spec(),
      pred_key=pred_key,
      regression_thresh=regression_thresh), feature_predicted_value
def _run(self,
         model: lit_model.Model,
         indexed_inputs: List[JsonDict],
         model_outputs: Optional[List[JsonDict]] = None,
         do_fit=False):
  # Run model, if needed.
  if model_outputs is None:
    model_outputs = list(model.predict(indexed_inputs))
  assert len(model_outputs) == len(indexed_inputs)

  converted_inputs = list(
      map(self.convert_input, indexed_inputs, model_outputs))
  if do_fit:
    return self._projector.fit_transform_with_metadata(
        converted_inputs, dataset_name="")
  else:
    return self._projector.predict_with_metadata(
        converted_inputs, dataset_name="")
def build(cls,
          inputs: List[JsonDict],
          encoder: lit_model.Model,
          edge_field: str,
          embs_field: str,
          offset_field: str,
          progress=lambda x: x,
          verbose=False):
  """Run encoder and extract span representations for coreference.

  'encoder' should be a model returning one TokenEmbeddings field, from which
  span features will be extracted, as well as a TokenOffsets field which maps
  input tokens to output tokens.

  The returned dataset will contain one example for each label in the inputs'
  EdgeLabels field.

  Args:
    inputs: input Dataset
    encoder: encoder model, compatible with inputs
    edge_field: name of edge field in data
    embs_field: name of embeddings field in model output
    offset_field: name of offset field in model output
    progress: optional pass-through progress indicator
    verbose: if true, print estimated memory usage

  Returns:
    EdgeFeaturesDataset with extracted span representations
  """
  examples = []
  encoder_outputs = progress(encoder.predict(inputs))
  for i, output in enumerate(encoder_outputs):
    exs = _make_probe_inputs(
        inputs[i][edge_field],
        output[embs_field],
        output[offset_field],
        src_idx=i)
    examples.extend(exs)
    if verbose and i == 10:
      _estimate_memory_needs(inputs, edge_field, examples[0])

  return cls(examples)
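
# Example usage (illustrative only; the dataset, model, and field names below
# are placeholders, not part of this module):
#   edge_ds = EdgeFeaturesDataset.build(
#       inputs=coref_data.examples,
#       encoder=coref_encoder,
#       edge_field='coref',
#       embs_field='tokens_emb',
#       offset_field='token_offsets',
#       verbose=True)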
def run(self,
        inputs: List[JsonDict],
        model: lit_model.Model,
        dataset: lit_dataset.Dataset,
        model_outputs: Optional[List[JsonDict]] = None,
        config: Optional[JsonDict] = None) -> Optional[List[JsonDict]]:
  del dataset
  del config

  field_map = self.find_fields(model)

  # Run model, if needed.
  if model_outputs is None:
    model_outputs = list(model.predict(inputs))
  assert len(model_outputs) == len(inputs)

  return [
      self._run_single(ex, mo, field_map)
      for ex, mo in zip(inputs, model_outputs)
  ]
def _generate_leave_one_out_ablation_score(
    self, example: JsonDict, model: lit_model.Model, input_spec: Spec,
    output_spec: Spec, orig_output: JsonDict, pred_key: Text,
    fields_to_ablate: List[str]) -> List[Tuple[str, int, float]]:
  # Returns a list of triples: field, token_idx and leave-one-out score.
  ret = []
  for field in input_spec.keys():
    if field not in example or field not in fields_to_ablate:
      continue
    tokens = self._get_tokens(example, input_spec, field)
    cfs = [
        self._create_cf(example, input_spec, [(field, i)])
        for i in range(len(tokens))
    ]
    cf_outputs = model.predict(cfs)
    for i, cf_output in enumerate(cf_outputs):
      loo_score = cf_utils.prediction_difference(cf_output, orig_output,
                                                 output_spec, pred_key)
      ret.append((field, i, loo_score))
  return ret
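
# `cf_utils.prediction_difference` lives in a separate utility module. A rough
# sketch of the behavior assumed here is given below: for regression heads it
# returns the signed change in the score, and for classification heads the
# change in probability of the originally predicted class. Treat this as an
# assumption for illustration, not the actual library code.
def prediction_difference_sketch(cf_output, orig_output, output_spec,
                                 pred_key) -> float:
  if isinstance(output_spec[pred_key], types.RegressionScore):
    return float(cf_output[pred_key] - orig_output[pred_key])
  orig_probs = np.asarray(orig_output[pred_key])
  cf_probs = np.asarray(cf_output[pred_key])
  orig_class = int(np.argmax(orig_probs))
  return float(cf_probs[orig_class] - orig_probs[orig_class])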
def run(self,
        inputs: List[JsonDict],
        model: lit_model.Model,
        dataset: lit_dataset.Dataset,
        model_outputs: Optional[List[JsonDict]] = None,
        config: Optional[JsonDict] = None):
  if model_outputs is None:
    model_outputs = list(model.predict(inputs))

  spec = model.spec()
  field_map = map_pred_keys(dataset.spec(), spec.output, self.is_compatible)
  ret = []
  for pred_key, label_key in field_map.items():
    # Extract fields.
    labels = [ex[label_key] for ex in inputs]
    preds = [mo[pred_key] for mo in model_outputs]
    # Compute metrics, as dict(str -> float).
    metrics = self.compute(
        labels,
        preds,
        label_spec=dataset.spec()[label_key],
        pred_spec=spec.output[pred_key],
        config=config.get(pred_key) if config else None)
    # NaN is not a valid JSON value, so replace with None which will be
    # serialized as null.
    # TODO(lit-team): move this logic into serialize.py somewhere instead?
    metrics = {
        k: (v if not np.isnan(v) else None) for k, v in metrics.items()
    }
    # Format for frontend.
    ret.append({
        'pred_key': pred_key,
        'label_key': label_key,
        'metrics': metrics
    })
  return ret
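
# For reference, a minimal sketch of what a `compute` implementation might
# look like for a classification metric (simple accuracy over argmax
# predictions). This is illustrative only, not the actual metrics subclass
# used above; label_spec and config are accepted but unused here.
def compute_sketch(labels, preds, label_spec=None, pred_spec=None,
                   config=None) -> Dict[str, float]:
  vocab = pred_spec.vocab if pred_spec is not None else None
  pred_idxs = [int(np.argmax(p)) for p in preds]
  pred_labels = [vocab[i] for i in pred_idxs] if vocab else pred_idxs
  correct = [y == y_hat for y, y_hat in zip(labels, pred_labels)]
  return {'accuracy': float(np.mean(correct))}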
def run(self,
        inputs: List[JsonDict],
        model: lit_model.Model,
        dataset: lit_dataset.Dataset,
        model_outputs: Optional[List[JsonDict]] = None,
        config: Optional[JsonDict] = None) -> Optional[List[JsonDict]]:
  """Run this component, given a model and input(s)."""
  config = config or {}
  class_to_explain = config.get(CLASS_KEY,
                                self.config_spec()[CLASS_KEY].default)
  interpolation_steps = int(
      config.get(INTERPOLATION_KEY,
                 self.config_spec()[INTERPOLATION_KEY].default))
  normalization = config.get(NORMALIZATION_KEY,
                             self.config_spec()[NORMALIZATION_KEY].default)

  # Find gradient fields to interpret.
  input_spec = model.input_spec()
  output_spec = model.output_spec()
  grad_fields = self.find_fields(input_spec, output_spec)
  logging.info('Found fields for integrated gradients: %s', str(grad_fields))
  if len(grad_fields) == 0:  # pylint: disable=g-explicit-length-test
    return None

  # Run model, if needed.
  if model_outputs is None:
    model_outputs = list(model.predict(inputs))

  all_results = []
  for model_output, model_input in zip(model_outputs, inputs):
    result = self.get_salience_result(model_input, model, interpolation_steps,
                                      normalization, class_to_explain,
                                      model_output, grad_fields)
    all_results.append(result)
  return all_results
def run(self,
        inputs: List[JsonDict],
        model: lit_model.Model,
        dataset: lit_dataset.Dataset,
        model_outputs: Optional[List[JsonDict]] = None,
        config: Optional[JsonDict] = None) -> Optional[List[JsonDict]]:
  """Runs the component, given a model and input(s)."""
  input_spec = model.input_spec()
  output_spec = model.output_spec()

  if not inputs:
    return []

  # Find all fields required for the interpretation.
  supported_fields = find_supported_fields(input_spec, output_spec)
  if supported_fields is None:
    return None
  grad_field_key = supported_fields.grad_field_key
  image_field_key = supported_fields.image_field_key
  grad_target_field_key = supported_fields.grad_target_field_key
  preds_field_key = supported_fields.preds_field_key
  multiclass = grad_target_field_key is not None

  # Determine the shape of gradients by calling the model with a single input
  # and extracting the shape from the gradient output.
  model_output = list(model.predict([inputs[0]]))
  grad_shape = model_output[0][grad_field_key].shape

  # If it is a multiclass model, find the labels with respect to which we
  # should compute the gradients.
  if multiclass:
    # Get class labels.
    label_vocab = list(
        cast(types.MulticlassPreds, output_spec[preds_field_key]).vocab)

    # Run the model in order to find the gradient target labels.
    outputs = list(model.predict(inputs))
    grad_target_labels = []
    for model_input, model_output in zip(inputs, outputs):
      if model_input.get(grad_target_field_key) is not None:
        grad_target_labels.append(model_input[grad_target_field_key])
      else:
        max_idx = int(np.argmax(model_output[preds_field_key]))
        grad_target_labels.append(label_vocab[max_idx])
  else:
    grad_target_labels = [None] * len(inputs)

  saliency_object = self.get_saliency_object()
  extra_saliency_params = self.get_extra_saliency_params(config)
  all_results = []
  for example, grad_target_label in zip(inputs, grad_target_labels):
    result = {}
    image_str = example[image_field_key]
    saliency_input = image_utils.convert_image_str_to_array(
        image_str=image_str, shape=grad_shape)
    call_model_func = get_call_model_func(
        model=model,
        model_input=example,
        image_field_key=image_field_key,
        grad_field_key=grad_field_key,
        grad_target_field_key=grad_target_field_key,
        grad_target_label=grad_target_label)
    attribution = self.make_saliency_call(
        saliency_object=saliency_object,
        x_value=saliency_input,
        call_model_function=call_model_func,
        extra_saliency_params=extra_saliency_params)
    if attribution.ndim == 3:
      attribution = attribution.sum(axis=2)
    viz_params = self.get_visualization_params()
    overlaid_image = image_utils.overlay_pixel_saliency(
        image_str, attribution, **viz_params)
    result[grad_field_key] = image_utils.convert_pil_to_image_str(
        overlaid_image)
    all_results.append(result)
  return all_results
def get_salience_result(self, model_input: JsonDict, model: lit_model.Model,
                        model_output: JsonDict, grad_fields: List[Text]):
  result = {}

  output_spec = model.output_spec()
  # We ensure that the embedding and gradient class fields are present in the
  # model's input spec in find_fields().
  embeddings_fields = [
      cast(types.TokenGradients, output_spec[grad_field]).grad_for
      for grad_field in grad_fields
  ]

  # The gradient class input is used to specify the target class of the
  # gradient calculation (if unspecified, this option defaults to the argmax,
  # which could flip between interpolated inputs).
  grad_class_key = cast(types.TokenGradients,
                        output_spec[grad_fields[0]]).grad_target
  # TODO(b/168042999): Add option to specify the class to explain in the UI.
  grad_class = model_output[grad_class_key]

  interpolated_inputs = {}
  all_embeddings = []
  all_baselines = []
  for embed_field in embeddings_fields:
    # <float32>[num_tokens, emb_size]
    embeddings = np.array(model_output[embed_field])
    all_embeddings.append(embeddings)

    # Starts with baseline of zeros. <float32>[num_tokens, emb_size]
    baseline = self.get_baseline(embeddings)
    all_baselines.append(baseline)

    # Get interpolated inputs from baseline to original embedding.
    # <float32>[interpolation_steps, num_tokens, emb_size]
    interpolated_inputs[embed_field] = self.get_interpolated_inputs(
        baseline, embeddings, self.interpolation_steps)

  # Create model inputs and populate embedding field(s).
  inputs_with_embeds = []
  for i in range(self.interpolation_steps):
    input_copy = model_input.copy()
    # Interpolates embeddings for all inputs simultaneously.
    for embed_field in embeddings_fields:
      # <float32>[num_tokens, emb_size]
      input_copy[embed_field] = interpolated_inputs[embed_field][i]
      input_copy[grad_class_key] = grad_class

    inputs_with_embeds.append(input_copy)
  embed_outputs = model.predict(inputs_with_embeds)

  # Create list with concatenated gradients for each interpolated input.
  gradients = []
  for o in embed_outputs:
    # <float32>[total_num_tokens, emb_size]
    interp_gradients = np.concatenate([o[field] for field in grad_fields])
    gradients.append(interp_gradients)
  # <float32>[interpolation_steps, total_num_tokens, emb_size]
  path_gradients = np.stack(gradients, axis=0)

  # Calculate integral.
  # <float32>[total_num_tokens, emb_size]
  integral = self.estimate_integral(path_gradients)

  # <float32>[total_num_tokens, emb_size]
  concat_embeddings = np.concatenate(all_embeddings)

  # <float32>[total_num_tokens, emb_size]
  concat_baseline = np.concatenate(all_baselines)

  # <float32>[total_num_tokens, emb_size]
  integrated_gradients = integral * (
      np.array(concat_embeddings) - np.array(concat_baseline))
  # Dot product of integral values and (embeddings - baseline).
  # <float32>[total_num_tokens]
  attributions = np.sum(integrated_gradients, axis=-1)

  # TODO(b/168042999): Make normalization customizable in the UI.
  # <float32>[total_num_tokens]
  scores = citrus_utils.normalize_scores(attributions)

  for grad_field in grad_fields:
    # Format as salience map result.
    token_field = cast(types.TokenGradients, output_spec[grad_field]).align
    tokens = model_output[token_field]

    # Only use the scores that correspond to the tokens in this grad_field.
    # The gradients for all input embeddings were concatenated in the order
    # of the grad fields, so they can be sliced out in the same order.
    sliced_scores = scores[:len(tokens)]  # <float32>[num_tokens in field]
    scores = scores[len(tokens):]  # <float32>[num_remaining_tokens]

    assert len(tokens) == len(sliced_scores)
    result[grad_field] = dtypes.SalienceMap(tokens, sliced_scores)
  return result
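
# The helpers `get_interpolated_inputs` and `estimate_integral` used above are
# defined elsewhere in this component. Minimal sketches of the standard
# integrated-gradients formulation are given below for reference: a linear
# path from the baseline to the input, and an average of the gradients along
# that path. The actual implementations may differ in details such as the
# integration rule (e.g. trapezoidal instead of Riemann).
def get_interpolated_inputs_sketch(baseline: np.ndarray, target: np.ndarray,
                                   num_steps: int) -> np.ndarray:
  """<float32>[num_steps, num_tokens, emb_size]: linear path baseline -> target."""
  alphas = np.linspace(0., 1., num_steps).reshape(-1, 1, 1)
  return baseline[np.newaxis, ...] + alphas * (target - baseline)[np.newaxis, ...]


def estimate_integral_sketch(path_gradients: np.ndarray) -> np.ndarray:
  """Averages gradients over the interpolation path (Riemann approximation)."""
  return np.mean(path_gradients, axis=0)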
def run(self,
        inputs: List[types.JsonDict],
        model: lit_model.Model,
        dataset: lit_dataset.Dataset,
        model_outputs: Optional[List[types.JsonDict]] = None,
        config: Optional[types.JsonDict] = None):
  """Create PDP chart info using provided inputs.

  Args:
    inputs: sequence of inputs, following model.input_spec()
    model: model used to predict on the edited examples.
    dataset: dataset which the current examples belong to.
    model_outputs: optional precomputed model outputs
    config: optional runtime config.

  Returns:
    a dict of alternate feature values to model outputs. The model outputs
    will be a number for regression models and a list of numbers for
    multiclass models.
  """
  pred_keys = utils.find_spec_keys(
      model.output_spec(), (types.MulticlassPreds, types.RegressionScore))
  if not pred_keys:
    logging.warning('PDP did not find any supported output fields.')
    return None

  assert 'feature' in config, 'No feature to test provided'
  feature = config['feature']
  provided_range = config['range'] if 'range' in config else []
  edited_outputs = {}
  for pred_key in pred_keys:
    edited_outputs[pred_key] = {}

  # If a range was provided, use that to create the possible values.
  vals_to_test = (
      np.linspace(provided_range[0], provided_range[1], 10)
      if len(provided_range) == 2
      else self.get_vals_to_test(feature, dataset))

  # If no specific inputs provided, use the entire dataset.
  inputs_to_use = inputs if inputs else dataset.examples

  # For each alternate value for a given feature.
  for new_val in vals_to_test:
    # Create copies of all provided inputs with the value replaced.
    edited_inputs = []
    for inp in inputs_to_use:
      edited_input = copy.deepcopy(inp)
      edited_input[feature] = new_val
      edited_inputs.append(edited_input)

    # Run prediction on the altered inputs.
    outputs = list(model.predict(edited_inputs))

    # Store the mean of the prediction for the alternate value.
    for pred_key in pred_keys:
      numeric = isinstance(model.output_spec()[pred_key],
                           types.RegressionScore)
      if numeric:
        edited_outputs[pred_key][new_val] = np.mean(
            [output[pred_key] for output in outputs])
      else:
        edited_outputs[pred_key][new_val] = np.mean(
            [output[pred_key] for output in outputs], axis=0)

  return edited_outputs
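
# Example invocation (illustrative only; the interpreter, model, dataset, and
# feature names below are placeholders): compute a partial-dependence curve
# for a numeric feature over a fixed range.
#   pdp = pdp_interpreter.run(
#       inputs=selected_examples,
#       model=income_model,
#       dataset=income_dataset,
#       config={'feature': 'age', 'range': [18, 90]})
#   # pdp maps each prediction key to {feature value -> mean prediction}.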
def generate(self,
             example: JsonDict,
             model: lit_model.Model,
             dataset: lit_dataset.Dataset,
             config: Optional[JsonDict] = None) -> List[JsonDict]:
  """Identify minimal sets of token ablations that alter the prediction."""
  del dataset  # Unused.

  config = config or {}
  num_examples = int(config.get(NUM_EXAMPLES_KEY, NUM_EXAMPLES_DEFAULT))
  max_ablations = int(config.get(MAX_ABLATIONS_KEY, MAX_ABLATIONS_DEFAULT))
  assert model is not None, "Please provide a model for this generator."

  input_spec = model.input_spec()
  pred_key = config.get(PREDICTION_KEY, "")
  regression_thresh = float(
      config.get(REGRESSION_THRESH_KEY, REGRESSION_THRESH_DEFAULT))

  output_spec = model.output_spec()
  assert pred_key, "Please provide the prediction key"
  assert pred_key in output_spec, "Invalid prediction key"
  is_regression = isinstance(output_spec[pred_key], types.RegressionScore)
  if not is_regression:
    assert isinstance(output_spec[pred_key], types.MulticlassPreds), (
        "Only classification or regression models are supported")
  logging.info(r"W3lc0m3 t0 Ablatl0nFl1p \o/")
  logging.info("Original example: %r", example)

  # Check for fields to ablate.
  fields_to_ablate = list(config.get(FIELDS_TO_ABLATE_KEY, []))
  if not fields_to_ablate:
    return []

  # Get model outputs.
  orig_output = list(model.predict([example]))[0]

  loo_scores = self._generate_leave_one_out_ablation_score(
      example, model, input_spec, output_spec, orig_output, pred_key,
      fields_to_ablate)

  if isinstance(output_spec[pred_key], types.RegressionScore):
    ablation_idxs_generator = self._gen_ablation_idxs(
        loo_scores, max_ablations, orig_output[pred_key], regression_thresh)
  else:
    ablation_idxs_generator = self._gen_ablation_idxs(loo_scores,
                                                      max_ablations)

  tokens_map = {}
  for field in input_spec.keys():
    tokens = self._get_tokens(example, input_spec, field)
    if not tokens:
      continue
    tokens_map[field] = tokens

  successful_cfs = []
  successful_positions = []
  for ablation_idxs in ablation_idxs_generator:
    if len(successful_cfs) >= num_examples:
      return successful_cfs

    # If a subset of this set of tokens has already been successful in
    # obtaining a flip, we continue. This ensures that we only consider
    # sets of tokens that are minimal.
    if self._subset_exists(set(ablation_idxs), successful_positions):
      continue

    # Create counterfactual and obtain model prediction.
    cf = self._create_cf(example, input_spec, ablation_idxs)
    cf_output = list(model.predict([cf]))[0]

    # Check if counterfactual results in a prediction flip.
    if cf_utils.is_prediction_flip(cf_output, orig_output, output_spec,
                                   pred_key, regression_thresh):
      # Prediction flip found!
      cf_utils.update_prediction(cf, cf_output, output_spec, pred_key)
      cf[ABLATED_TOKENS_KEY] = str([
          f"{field}[{tokens_map[field][idx]}]"
          for field, idx in ablation_idxs
      ])
      successful_cfs.append(cf)
      successful_positions.append(set(ablation_idxs))
  return successful_cfs
def generate(self,
             example: JsonDict,
             model: lit_model.Model,
             dataset: lit_dataset.Dataset,
             config: Optional[JsonDict] = None,
             num_examples: int = 1) -> List[JsonDict]:
  """Use gradient to find/substitute the token with largest impact on loss."""
  # TODO(lit-team): This function is quite long. Consider breaking it
  # into small functions.
  del dataset  # Unused.
  assert model is not None, "Please provide a model for this generator."
  logging.info(r"W3lc0m3 t0 H0tFl1p \o/")
  logging.info("Original example: %r", example)

  # Find classification prediction key.
  pred_keys = self.find_fields(model.output_spec(), types.MulticlassPreds,
                               None)
  if len(pred_keys) == 0:  # pylint: disable=g-explicit-length-test
    # TODO(ataly): Add support for regression models.
    logging.warning("The model does not have a classification head. "
                    "Cannot use HotFlip. :-(")
    return []  # Cannot generate examples.
  if len(pred_keys) > 1:
    # TODO(ataly): Use a config argument when there are multiple prediction
    # heads.
    logging.warning("Multiple classification heads found. "
                    "Cannot use HotFlip. :-(")
    return []  # Cannot generate examples.
  pred_key = pred_keys[0]

  # Find gradient fields to use for HotFlip.
  input_spec = model.input_spec()
  output_spec = model.output_spec()
  grad_fields = self.find_fields(output_spec, types.TokenGradients,
                                 types.Tokens)
  logging.info("Found gradient fields for HotFlip use: %s", str(grad_fields))
  if len(grad_fields) == 0:  # pylint: disable=g-explicit-length-test
    logging.info("No gradient fields found. Cannot use HotFlip. :-(")
    return []  # Cannot generate examples without gradients.

  # Get model outputs.
  logging.info("Performing a forward/backward pass on the input example.")
  orig_output = list(model.predict([example]))[0]
  logging.info(orig_output.keys())

  # Get model word embeddings and vocab.
  inv_vocab, embed = model.get_embedding_table()
  assert len(inv_vocab) == embed.shape[0], "Vocab/embeddings size mismatch."
  logging.info("Vocab size: %d, Embedding size: %r", len(inv_vocab),
               embed.shape)

  # Get original prediction class.
  orig_probabilities = orig_output[pred_key]
  orig_prediction = np.argmax(orig_probabilities)

  # Perform a flip in each sequence for which we have gradients (separately).
  # Each sequence may give rise to multiple new examples, depending on how
  # many words we flip.
  # TODO(lit-team): make configurable how many new examples are desired.
  # TODO(lit-team): use only 1 sequence as input (configurable in UI).
  new_examples = []
  for grad_field in grad_fields:
    # Get the tokens and their gradient vectors.
    token_field = output_spec[grad_field].align  # pytype: disable=attribute-error
    tokens = orig_output[token_field]
    grads = orig_output[grad_field]
    token_emb_fields = self.find_fields(output_spec, types.TokenEmbeddings,
                                        types.Tokens)
    assert len(token_emb_fields) == 1, "Found multiple token embeddings"
    token_embs = orig_output[token_emb_fields[0]]

    # Identify the token with the largest gradient attribution,
    # defined as the dot product between the token embedding and gradient
    # of the output wrt the embedding.
    assert token_embs.shape[0] == grads.shape[0]
    token_grad_attrs = np.sum(token_embs * grads, axis=-1)
    # Get a list of indices of input tokens, sorted by gradient attribution,
    # highest first. We will flip tokens in this order.
    sorted_by_grad_attrs = np.argsort(token_grad_attrs)[::-1]

    for i in range(min(num_examples, len(tokens))):
      token_id = sorted_by_grad_attrs[i]
      logging.info("Selected token: %s (pos=%d) with gradient attribution %f",
                   tokens[token_id], token_id, token_grad_attrs[token_id])
      token_grad = grads[token_id]

      # Take dot product with all word embeddings. Get smallest value.
      # (We are looking for a replacement token that will lower the score
      # for the current class, thereby increasing the chances of a label
      # flip.)
      # TODO(lit-team): Can add criteria to the winner e.g. cosine distance.
      scores = np.dot(embed, token_grad)
      winner = np.argmin(scores)
      logging.info("Replacing [%s] (pos=%d) with option %d: [%s] (id=%d)",
                   tokens[token_id], token_id, i, inv_vocab[winner], winner)

      # Create a new input to the model.
      # TODO(iftenney, bastings): enforce somewhere that this field has the
      # same name in the input and output specs.
      input_token_field = token_field
      input_text_field = input_spec[input_token_field].parent  # pytype: disable=attribute-error
      new_example = copy.deepcopy(example)
      modified_tokens = copy.copy(tokens)
      modified_tokens[token_id] = inv_vocab[winner]
      new_example[input_token_field] = modified_tokens
      # TODO(iftenney, bastings): call a model-provided detokenizer here?
      # Though in general tokenization isn't invertible and it's possible for
      # HotFlip to produce wordpiece sequences that don't correspond to any
      # input string.
      new_example[input_text_field] = " ".join(modified_tokens)

      # Predict a new label for this example.
      new_output = list(model.predict([new_example]))[0]

      # Update label if multi-class prediction.
      # TODO(lit-dev): provide a general system for handling labels on
      # generated examples.
      probabilities = new_output[pred_key]
      new_prediction = np.argmax(probabilities)
      label_key = cast(types.MulticlassPreds, output_spec[pred_key]).parent
      label_names = cast(types.MulticlassPreds, output_spec[pred_key]).vocab
      new_label = label_names[new_prediction]
      new_example[label_key] = new_label
      logging.info("Updated example with new label: %s", new_label)

      if new_prediction != orig_prediction:
        # HotFlip found.
        new_examples.append(new_example)
      else:
        # We take new_example as our base example and continue with more
        # token flips.
        example = new_example
        tokens = modified_tokens
  return new_examples
def generate(self,
             example: JsonDict,
             model: lit_model.Model,
             dataset: lit_dataset.Dataset,
             config: Optional[JsonDict] = None) -> List[JsonDict]:
  """Identify minimal sets of token flips that alter the prediction."""
  del dataset  # Unused.

  config = config or {}
  num_examples = int(config.get(NUM_EXAMPLES_KEY, NUM_EXAMPLES_DEFAULT))
  max_flips = int(config.get(MAX_FLIPS_KEY, MAX_FLIPS_DEFAULT))
  tokens_to_ignore = config.get(TOKENS_TO_IGNORE_KEY,
                                TOKENS_TO_IGNORE_DEFAULT)
  pred_key = config.get(PREDICTION_KEY, "")
  regression_thresh = float(
      config.get(REGRESSION_THRESH_KEY, REGRESSION_THRESH_DEFAULT))
  assert model is not None, "Please provide a model for this generator."

  input_spec = model.input_spec()
  output_spec = model.output_spec()
  assert pred_key, "Please provide the prediction key"
  assert pred_key in output_spec, "Invalid prediction key"

  is_regression = False
  if isinstance(output_spec[pred_key], types.RegressionScore):
    is_regression = True
  else:
    assert isinstance(output_spec[pred_key], types.MulticlassPreds), (
        "Only classification or regression models are supported")
  logging.info(r"W3lc0m3 t0 H0tFl1p \o/")
  logging.info("Original example: %r", example)

  # Get model outputs.
  orig_output = list(model.predict([example]))[0]

  # Check config for selected fields.
  selected_fields = list(config.get(FIELDS_TO_HOTFLIP_KEY, []))
  if not selected_fields:
    return []

  # Get tokens (corresponding to each text input field) and corresponding
  # gradients.
  tokens_and_gradients = self._get_tokens_and_gradients(
      input_spec, output_spec, orig_output, selected_fields)
  assert tokens_and_gradients, (
      "No token fields found. Cannot use HotFlip. :-(")

  # Copy tokens into input example.
  example = copy.deepcopy(example)
  for token_field, v in tokens_and_gradients.items():
    tokens, _ = v
    example[token_field] = tokens

  inv_vocab, embedding_matrix = model.get_embedding_table()
  assert len(inv_vocab) == embedding_matrix.shape[0], (
      "Vocab/embeddings size mismatch.")

  successful_cfs = []
  # TODO(lit-team): use only 1 sequence as input (configurable in UI).
  # TODO(lit-team): Refactor the following code so that it's not so deeply
  # nested (and easier to track loop state).
  for token_field, v in tokens_and_gradients.items():
    tokens, grads = v
    text_field = input_spec[token_field].parent  # pytype: disable=attribute-error
    logging.info("Identifying HotFlips for input field: %s", str(text_field))
    direction = -1
    if is_regression:
      # We want the replacements to increase the prediction score if the
      # original score is below the threshold, and decrease otherwise.
      direction = (1 if orig_output[pred_key] <= regression_thresh else -1)
    replacement_tokens = self._get_replacement_tokens(
        embedding_matrix, inv_vocab, grads, direction)

    successful_positions = []
    for token_idxs in self._gen_token_idxs_to_flip(
        tokens, grads, max_flips, tokens_to_ignore):
      if len(successful_cfs) >= num_examples:
        return successful_cfs
      # If a subset of this set of tokens has already been successful in
      # obtaining a flip, we continue. This ensures that we only consider
      # sets of token flips that are minimal.
      if self._subset_exists(set(token_idxs), successful_positions):
        continue

      # Create counterfactual.
      cf = self._create_cf(example, token_field, text_field, tokens,
                           token_idxs, replacement_tokens)
      # Obtain model prediction.
      cf_output = list(model.predict([cf]))[0]

      if cf_utils.is_prediction_flip(cf_output, orig_output, output_spec,
                                     pred_key, regression_thresh):
        # Prediction flip found!
        cf_utils.update_prediction(cf, cf_output, output_spec, pred_key)
        successful_cfs.append(cf)
        successful_positions.append(set(token_idxs))
  return successful_cfs
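
# `_get_replacement_tokens` above is defined elsewhere in this component. A
# minimal sketch of the usual first-order HotFlip choice is shown below for
# reference: for every position, pick the vocabulary token whose embedding
# moves the prediction furthest in `direction` along the gradient. The actual
# implementation may apply extra filtering (e.g. dropping special tokens).
def _get_replacement_tokens_sketch(embedding_matrix: np.ndarray,
                                   inv_vocab: List[str],
                                   token_grads: np.ndarray,
                                   direction: int) -> List[str]:
  # <float32>[num_tokens, vocab_size]: first-order estimate of the change in
  # score from swapping in each vocabulary token at each position.
  scores = direction * np.dot(token_grads, embedding_matrix.T)
  best_ids = np.argmax(scores, axis=-1)  # <int>[num_tokens]
  return [inv_vocab[i] for i in best_ids]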
def generate(self,
             example: JsonDict,
             model: lit_model.Model,
             dataset: lit_dataset.Dataset,
             config: Optional[JsonDict] = None) -> List[JsonDict]:
  # Perform validation and retrieve configuration.
  if not model:
    raise ValueError('Please provide a model for this generator.')
  config = config or {}
  num_examples = int(config.get(NUM_EXAMPLES_KEY, NUM_EXAMPLES_DEFAULT))
  max_flips = int(config.get(MAX_FLIPS_KEY, MAX_FLIPS_DEFAULT))
  pred_key = config.get(PREDICTION_KEY, '')
  regression_thresh = float(
      config.get(REGRESSION_THRESH_KEY, REGRESSION_THRESH_DEFAULT))

  dataset_name = config.get('dataset_name')
  if not dataset_name:
    raise ValueError('The dataset name must be in the config.')

  output_spec = model.output_spec()
  if not pred_key:
    raise ValueError('Please provide the prediction key.')
  if pred_key not in output_spec:
    raise ValueError('Invalid prediction key.')

  if (not (isinstance(output_spec[pred_key], lit_types.MulticlassPreds) or
           isinstance(output_spec[pred_key], lit_types.RegressionScore))):
    raise ValueError(
        'Only classification and regression models are supported')

  # Calculate dataset statistics if they have not been calculated yet. The
  # statistics include such information as standard deviation for scalar
  # features and probabilities for categorical features.
  if dataset_name not in self._datasets_stats:
    self._calculate_stats(dataset, dataset_name)

  # Find predicted class of the original example.
  original_pred = list(model.predict([example]))[0]

  # Find dataset examples that are flips.
  filtered_examples = self._filter_ds_examples(
      dataset=dataset,
      dataset_name=dataset_name,
      model=model,
      reference_output=original_pred,
      pred_key=pred_key,
      regression_thresh=regression_thresh)

  supported_field_names = self._find_all_fields_to_consider(
      ds_spec=dataset.spec(),
      model_input_spec=model.input_spec(),
      example=example)

  candidates: List[JsonDict] = []

  # Iterate through all possible feature combinations.
  combs = utils.find_all_combinations(supported_field_names, 1, max_flips)
  for comb in combs:
    # Sort all dataset examples with respect to the given combination.
    sorted_examples = self._sort_and_filter_examples(
        examples=filtered_examples,
        ref_example=example,
        fields=comb,
        dataset=dataset,
        dataset_name=dataset_name)
    if not sorted_examples:
      continue

    # As an optimization trick, check whether the farthest example is a flip.
    # If it is not a flip then skip the current combination of features.
    # This optimization makes the minimum set guarantees weaker but
    # significantly improves the search speed.
    flip = self._find_hot_flip(
        ref_example=example,
        ds_example=sorted_examples[-1],
        features_to_consider=comb,
        model=model,
        target_pred=original_pred,
        pred_key=pred_key,
        dataset=dataset,
        interpolate=False,
        regression_threshold=regression_thresh)
    if not flip:
      logging.info('Skipped combination %s', comb)
      continue

    # Iterate through the sorted examples until the first flip is found.
    # TODO(b/204200758): improve performance by batching the predict requests.
    for ds_example in sorted_examples:
      flip = self._find_hot_flip(
          ref_example=example,
          ds_example=ds_example,
          features_to_consider=comb,
          model=model,
          target_pred=original_pred,
          pred_key=pred_key,
          dataset=dataset,
          interpolate=True,
          regression_threshold=regression_thresh)
      if flip:
        self._add_if_not_strictly_worse(
            example=flip,
            other_examples=candidates,
            ref_example=example,
            dataset=dataset,
            dataset_name=dataset_name,
            model=model)
        break

    if len(candidates) >= num_examples:
      break

  # Calculate distances for the found hot flips.
  candidate_tuples = []
  for flip_example in candidates:
    distance, diff_fields = self._calculate_L1_distance(
        example_1=example,
        example_2=flip_example,
        dataset=dataset,
        dataset_name=dataset_name,
        model=model)
    if distance > 0:
      candidate_tuples.append((distance, diff_fields, flip_example))

  # Order the dataset entries based on the distance to the given example.
  candidate_tuples.sort(key=lambda e: e[0])

  if len(candidate_tuples) > num_examples:
    candidate_tuples = candidate_tuples[0:num_examples]

  # e[2] contains the hot-flip examples in the distances list of tuples.
  return [e[2] for e in candidate_tuples]