def run(self,
        inputs: List[JsonDict],
        model: lit_model.Model,
        dataset: lit_dataset.Dataset,
        model_outputs: Optional[List[JsonDict]] = None,
        config: Optional[JsonDict] = None) -> Optional[List[JsonDict]]:
  """Run this component, given a model and input(s)."""
  # Find gradient fields to interpret.
  output_spec = model.output_spec()
  grad_fields = self.find_fields(output_spec)
  logging.info('Found fields for gradient attribution: %s', str(grad_fields))
  if len(grad_fields) == 0:  # pylint: disable=g-explicit-length-test
    return None

  # Run model, if needed.
  if model_outputs is None:
    model_outputs = list(model.predict(inputs))
  assert len(model_outputs) == len(inputs)

  all_results = []
  for o in model_outputs:
    # Dict[field name -> interpretations]
    result = {}
    for grad_field in grad_fields:
      token_field = cast(types.TokenGradients, output_spec[grad_field]).align
      tokens = o[token_field]
      scores = self._interpret(o[grad_field], tokens)
      result[grad_field] = dtypes.SalienceMap(tokens, scores)
    all_results.append(result)
  return all_results
def validate_t5_model(model: lit_model.Model) -> lit_model.Model:
  """Validate that a given model looks like a T5 model.

  This checks the model spec at runtime; it is intended to be used before
  server start, such as in the __init__() method of a wrapper class.

  Args:
    model: a LIT model

  Returns:
    model: the same model

  Raises:
    AssertionError: if the model's spec does not match that expected for a T5
      model.
  """
  # Check inputs.
  ispec = model.input_spec()
  assert "input_text" in ispec
  assert isinstance(ispec["input_text"], lit_types.TextSegment)
  if "target_text" in ispec:
    assert isinstance(ispec["target_text"], lit_types.TextSegment)

  # Check outputs.
  ospec = model.output_spec()
  assert "output_text" in ospec
  assert isinstance(
      ospec["output_text"],
      (lit_types.GeneratedText, lit_types.GeneratedTextCandidates))
  assert ospec["output_text"].parent == "target_text"

  return model
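# Usage sketch (assumption, not part of the original file): the validator is
# meant to run once, before the server starts, so that a spec mismatch fails
# fast rather than at request time. `load_t5_model` below is a hypothetical
# loader used only for illustration.
def build_validated_model(path: str) -> lit_model.Model:
  model = load_t5_model(path)  # hypothetical loader
  # Raises AssertionError early if the spec does not look like a T5 model.
  return validate_t5_model(model)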
def run(self,
        inputs: List[JsonDict],
        model: lit_model.Model,
        dataset: lit_dataset.Dataset,
        model_outputs: Optional[List[JsonDict]] = None,
        config: Optional[JsonDict] = None) -> Optional[List[JsonDict]]:
  """Run this component, given a model and input(s)."""
  # Find gradient fields to interpret.
  input_spec = model.input_spec()
  output_spec = model.output_spec()
  grad_fields = self.find_fields(input_spec, output_spec)
  logging.info('Found fields for integrated gradients: %s', str(grad_fields))
  if len(grad_fields) == 0:  # pylint: disable=g-explicit-length-test
    return None

  # Run model, if needed.
  if model_outputs is None:
    model_outputs = list(model.predict(inputs))

  all_results = []
  for model_output, model_input in zip(model_outputs, inputs):
    result = self.get_salience_result(model_input, model, model_output,
                                      grad_fields)
    all_results.append(result)
  return all_results
def run_with_metadata(
    self,
    indexed_inputs: Sequence[IndexedInput],
    model: lit_model.Model,
    dataset: lit_dataset.IndexedDataset,
    model_outputs: Optional[List[JsonDict]] = None,
    config: Optional[JsonDict] = None) -> Optional[JsonDict]:
  """Run this component, given a model and input(s).

  Args:
    indexed_inputs: Inputs to cluster.
    model: Model that provides salience maps.
    dataset: Dataset to compute salience maps for.
    model_outputs: Precomputed model outputs.
    config: Config for clustering and salience computation.

  Returns:
    Dict with 2 keys:
      `CLUSTER_ID_KEY`: Contains the cluster assignments, one cluster id per
        dataset example.
      `REPRESENTATION_KEY`: Contains the representations of all examples in
        the dataset that were used in the clustering.
  """
  config = config or {}
  # Find gradient fields to interpret.
  grad_fields = self.find_fields(model.output_spec())
  token_saliencies = self.salience_mappers[
      config['salience_mapper']].run_with_metadata(indexed_inputs, model,
                                                   dataset, model_outputs,
                                                   config)

  if not token_saliencies:
    return None

  vocab = self._build_vocab(token_saliencies)
  representations = self._compute_fixed_length_representation(
      token_saliencies, vocab)

  cluster_ids = {}
  grad_field_to_representations = {}

  for grad_field in grad_fields:
    weight_matrix = np.vstack(
        [representation[grad_field] for representation in representations])
    self.kmeans[grad_field] = cluster.KMeans(
        n_clusters=config.get(N_CLUSTERS_KEY,
                              self.config_spec()[N_CLUSTERS_KEY].default))
    cluster_ids[grad_field] = self.kmeans[grad_field].fit_predict(
        weight_matrix).tolist()
    grad_field_to_representations[grad_field] = weight_matrix

  return {
      CLUSTER_ID_KEY: cluster_ids,
      REPRESENTATION_KEY: grad_field_to_representations
  }
def _filter_ds_examples(
    self,
    dataset: lit_dataset.IndexedDataset,
    dataset_name: Text,
    model: lit_model.Model,
    reference_output: JsonDict,
    pred_key: Text,
    regression_thresh: Optional[float] = None) -> List[JsonDict]:
  """Reads all dataset examples and returns only those that are flips."""
  if not isinstance(dataset, lit_dataset.IndexedDataset):
    raise ValueError(
        'Only indexed datasets are currently supported by the TabularMTC '
        'generator.')

  indexed_examples = list(dataset.indexed_examples)
  filtered_examples = []
  preds = model.predict_with_metadata(
      indexed_examples, dataset_name=dataset_name)

  # Find all dataset examples that are flips with respect to the reference
  # example.
  for indexed_example, pred in zip(indexed_examples, preds):
    flip = cf_utils.is_prediction_flip(
        cf_output=pred,
        orig_output=reference_output,
        output_spec=model.output_spec(),
        pred_key=pred_key,
        regression_thresh=regression_thresh)
    if flip:
      candidate_example = indexed_example['data'].copy()
      self._find_dataset_parent_and_set(
          model_output_spec=model.output_spec(),
          pred_key=pred_key,
          dataset_spec=dataset.spec(),
          example=candidate_example,
          predicted_value=pred[pred_key])
      filtered_examples.append(candidate_example)
  return filtered_examples
def _is_flip(self,
             model: lit_model.Model,
             cf_example: JsonDict,
             orig_output: JsonDict,
             pred_key: Text,
             regression_thresh: Optional[float] = None) -> Tuple[bool, Any]:

  cf_output = list(model.predict([cf_example]))[0]
  feature_predicted_value = cf_output[pred_key]
  return cf_utils.is_prediction_flip(
      cf_output=cf_output,
      orig_output=orig_output,
      output_spec=model.output_spec(),
      pred_key=pred_key,
      regression_thresh=regression_thresh), feature_predicted_value
def run(self,
        inputs: List[JsonDict],
        model: lit_model.Model,
        dataset: lit_dataset.Dataset,
        model_outputs: Optional[List[JsonDict]] = None,
        config: Optional[JsonDict] = None) -> Optional[List[JsonDict]]:
  """Run this component, given a model and input(s)."""
  config = config or {}
  class_to_explain = config.get(CLASS_KEY,
                                self.config_spec()[CLASS_KEY].default)
  interpolation_steps = int(
      config.get(INTERPOLATION_KEY,
                 self.config_spec()[INTERPOLATION_KEY].default))
  normalization = config.get(NORMALIZATION_KEY,
                             self.config_spec()[NORMALIZATION_KEY].default)

  # Find gradient fields to interpret.
  input_spec = model.input_spec()
  output_spec = model.output_spec()
  grad_fields = self.find_fields(input_spec, output_spec)
  logging.info('Found fields for integrated gradients: %s', str(grad_fields))
  if len(grad_fields) == 0:  # pylint: disable=g-explicit-length-test
    return None

  # Run model, if needed.
  if model_outputs is None:
    model_outputs = list(model.predict(inputs))

  all_results = []
  for model_output, model_input in zip(model_outputs, inputs):
    result = self.get_salience_result(model_input, model, interpolation_steps,
                                      normalization, class_to_explain,
                                      model_output, grad_fields)
    all_results.append(result)
  return all_results
def generate(self,
             example: JsonDict,
             model: lit_model.Model,
             dataset: lit_dataset.Dataset,
             config: Optional[JsonDict] = None,
             num_examples: int = 1) -> List[JsonDict]:
  """Use gradient to find/substitute the token with largest impact on loss."""
  del dataset  # Unused.

  assert model is not None, "Please provide a model for this generator."
  logging.info(r"W3lc0m3 t0 H0tFl1p \o/")
  logging.info("Original example: %r", example)

  # Find gradient fields to use for HotFlip.
  input_spec = model.input_spec()
  output_spec = model.output_spec()
  grad_fields = self.find_fields(output_spec)
  logging.info("Found gradient fields for HotFlip use: %s", str(grad_fields))
  if len(grad_fields) == 0:  # pylint: disable=g-explicit-length-test
    logging.info("No gradient fields found. Cannot use HotFlip. :-(")
    return []  # Cannot generate examples without gradients.

  # Get model outputs.
  logging.info("Performing a forward/backward pass on the input example.")
  model_output = model.predict_single(example)
  logging.info(model_output.keys())

  # Get model word embeddings and vocab.
  inv_vocab, embed = model.get_embedding_table()
  assert len(inv_vocab) == embed.shape[0], "Vocab/embeddings size mismatch."
  logging.info("Vocab size: %d, Embedding size: %r", len(inv_vocab),
               embed.shape)

  # Perform a flip in each sequence for which we have gradients (separately).
  # Each sequence may give rise to multiple new examples, depending on how
  # many words we flip.
  # TODO(lit-team): make configurable how many new examples are desired.
  # TODO(lit-team): use only 1 sequence as input (configurable in UI).
  new_examples = []
  for grad_field in grad_fields:
    # Get the tokens and their gradient vectors.
    token_field = output_spec[grad_field].align  # pytype: disable=attribute-error
    tokens = model_output[token_field]
    grads = model_output[grad_field]

    # Identify the token with the largest gradient norm.
    # TODO(lit-team): consider normalizing across all grad fields or just
    # across each one individually.
    grad_norm = np.linalg.norm(grads, axis=1)
    grad_norm = grad_norm / np.sum(grad_norm)  # Match grad attribution value.

    # Get a list of indices of input tokens, sorted by norm, highest first.
    sorted_by_grad_norm = np.argsort(grad_norm)[::-1]
    for i in range(min(num_examples, len(tokens))):
      token_id = sorted_by_grad_norm[i]
      logging.info("Selected token: %s (pos=%d) with gradient norm %f",
                   tokens[token_id], token_id, grad_norm[token_id])
      token_grad = grads[token_id]

      # Take dot product with all word embeddings. Get largest value.
      scores = np.dot(embed, token_grad)

      # TODO(lit-team): Can add criteria to the winner e.g. cosine distance.
      winner = np.argmax(scores)
      logging.info("Replacing [%s] (pos=%d) with option %d: [%s] (id=%d)",
                   tokens[token_id], token_id, i, inv_vocab[winner], winner)

      # Create a new input to the model.
      # TODO(iftenney, bastings): enforce somewhere that this field has the
      # same name in the input and output specs.
      input_token_field = token_field
      input_text_field = input_spec[input_token_field].parent  # pytype: disable=attribute-error
      new_example = copy.deepcopy(example)
      modified_tokens = copy.copy(tokens)
      modified_tokens[token_id] = inv_vocab[winner]
      new_example[input_token_field] = modified_tokens
      # TODO(iftenney, bastings): call a model-provided detokenizer here?
      # Though in general tokenization isn't invertible and it's possible for
      # HotFlip to produce wordpiece sequences that don't correspond to any
      # input string.
      new_example[input_text_field] = " ".join(modified_tokens)

      # Predict a new label for this example.
      new_output = model.predict_single(new_example)

      # Update label if multi-class prediction.
      # TODO(lit-dev): provide a general system for handling labels on
      # generated examples.
      for pred_key, pred_type in model.output_spec().items():
        if isinstance(pred_type, types.MulticlassPreds):
          probabilities = new_output[pred_key]
          prediction = np.argmax(probabilities)
          label_key = output_spec[pred_key].parent
          label_names = output_spec[pred_key].vocab
          new_label = label_names[prediction]
          new_example[label_key] = new_label
          logging.info("Updated example with new label: %s", new_label)

      new_examples.append(new_example)

  return new_examples
def run_with_metadata(
    self,
    indexed_inputs: Sequence[IndexedInput],
    model: lit_model.Model,
    dataset: lit_dataset.IndexedDataset,
    model_outputs: Optional[List[JsonDict]] = None,
    config: Optional[JsonDict] = None) -> Optional[List[JsonDict]]:
  """Runs the TCAV method given the params in the inputs and config.

  Args:
    indexed_inputs: all examples in the dataset, in the indexed input format.
    model: the model being explained.
    dataset: the dataset which the current examples belong to.
    model_outputs: optional model outputs from calling model.predict(inputs).
    config: a config which should specify:
      {
        'concept_set_ids': [list of ids to use in concept set],
        'class_to_explain': [gradient class to explain],
        'grad_layer': [the Gradient field key of the layer to explain],
        'random_state': [an optional seed to make outputs deterministic],
        'dataset_name': [the name of the dataset (used for caching)],
        'test_size': [percentage of the example set to use in the LM test set],
        'negative_set_ids': [optional list of ids to use as negative set]
      }

  Returns:
    A JsonDict containing the TCAV scores, directional derivatives,
    statistical test p-values, and LM accuracies.
  """
  config = TCAVConfig(**config)
  # TODO(b/171513556): get these from the Dataset object once indices are
  # available there.
  dataset_examples = indexed_inputs

  # Get this layer's output spec keys for gradients and embeddings.
  grad_layer = config.grad_layer
  output_spec = model.output_spec()
  emb_layer = cast(types.Gradients, output_spec[grad_layer]).grad_for

  # Get the class that the gradients were computed for.
  grad_class_key = cast(types.Gradients,
                        output_spec[grad_layer]).grad_target_field_key

  ids_set = set(config.concept_set_ids)
  concept_set = [ex for ex in dataset_examples if ex['id'] in ids_set]
  non_concept_set = [ex for ex in dataset_examples if ex['id'] not in ids_set]

  # Get outputs using model.predict().
  dataset_outputs = list(
      model.predict_with_metadata(
          dataset_examples, dataset_name=config.dataset_name))

  if config.negative_set_ids:
    negative_ids_set = set(config.negative_set_ids)
    negative_set = [
        ex for ex in dataset_examples if ex['id'] in negative_ids_set
    ]
    return self._run_relative_tcav(grad_layer, emb_layer, grad_class_key,
                                   concept_set, negative_set, dataset_outputs,
                                   model, config)
  else:
    return self._run_default_tcav(grad_layer, emb_layer, grad_class_key,
                                  concept_set, non_concept_set,
                                  dataset_outputs, model, config)
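# Config sketch (assumption, not from the original file): the shape of the
# `config` dict documented in the docstring above, with placeholder ids and
# field keys. Passing 'negative_set_ids' selects the relative-TCAV path;
# omitting it selects the default path.
example_tcav_config = {
    'concept_set_ids': ['id_001', 'id_002', 'id_003'],  # placeholder ids
    'class_to_explain': '1',
    'grad_layer': 'cls_grad',  # a Gradients field key from the output spec
    'random_state': 0,
    'dataset_name': 'sst_dev',  # placeholder dataset name
    'test_size': 0.33,
    # 'negative_set_ids': ['id_010', 'id_011'],  # optional negative set
}
# results = tcav.run_with_metadata(indexed_inputs, model, dataset,
#                                  config=example_tcav_config)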
def generate(self,
             example: JsonDict,
             model: lit_model.Model,
             dataset: lit_dataset.Dataset,
             config: Optional[JsonDict] = None) -> List[JsonDict]:
  """Identify minimal sets of token ablations that alter the prediction."""
  del dataset  # Unused.

  config = config or {}
  num_examples = int(config.get(NUM_EXAMPLES_KEY, NUM_EXAMPLES_DEFAULT))
  max_ablations = int(config.get(MAX_ABLATIONS_KEY, MAX_ABLATIONS_DEFAULT))
  assert model is not None, "Please provide a model for this generator."

  input_spec = model.input_spec()
  pred_key = config.get(PREDICTION_KEY, "")
  regression_thresh = float(
      config.get(REGRESSION_THRESH_KEY, REGRESSION_THRESH_DEFAULT))

  output_spec = model.output_spec()
  assert pred_key, "Please provide the prediction key"
  assert pred_key in output_spec, "Invalid prediction key"

  is_regression = isinstance(output_spec[pred_key], types.RegressionScore)
  if not is_regression:
    assert isinstance(output_spec[pred_key], types.MulticlassPreds), (
        "Only classification or regression models are supported")
  logging.info(r"W3lc0m3 t0 Ablatl0nFl1p \o/")
  logging.info("Original example: %r", example)

  # Check for fields to ablate.
  fields_to_ablate = list(config.get(FIELDS_TO_ABLATE_KEY, []))
  if not fields_to_ablate:
    return []

  # Get model outputs.
  orig_output = list(model.predict([example]))[0]
  loo_scores = self._generate_leave_one_out_ablation_score(
      example, model, input_spec, output_spec, orig_output, pred_key,
      fields_to_ablate)

  if isinstance(output_spec[pred_key], types.RegressionScore):
    ablation_idxs_generator = self._gen_ablation_idxs(
        loo_scores, max_ablations, orig_output[pred_key], regression_thresh)
  else:
    ablation_idxs_generator = self._gen_ablation_idxs(loo_scores,
                                                      max_ablations)

  tokens_map = {}
  for field in input_spec.keys():
    tokens = self._get_tokens(example, input_spec, field)
    if not tokens:
      continue
    tokens_map[field] = tokens

  successful_cfs = []
  successful_positions = []
  for ablation_idxs in ablation_idxs_generator:
    if len(successful_cfs) >= num_examples:
      return successful_cfs

    # If a subset of this set of tokens has already been successful in
    # obtaining a flip, we continue. This ensures that we only consider
    # sets of tokens that are minimal.
    if self._subset_exists(set(ablation_idxs), successful_positions):
      continue

    # Create the counterfactual and obtain the model prediction.
    cf = self._create_cf(example, input_spec, ablation_idxs)
    cf_output = list(model.predict([cf]))[0]

    # Check if the counterfactual results in a prediction flip.
    if cf_utils.is_prediction_flip(cf_output, orig_output, output_spec,
                                   pred_key, regression_thresh):
      # Prediction flip found!
      cf_utils.update_prediction(cf, cf_output, output_spec, pred_key)
      cf[ABLATED_TOKENS_KEY] = str([
          f"{field}[{tokens_map[field][idx]}]" for field, idx in ablation_idxs
      ])
      successful_cfs.append(cf)
      successful_positions.append(set(ablation_idxs))
  return successful_cfs
def is_compatible(self, model: lit_model.Model) -> bool:
  input_spec = model.input_spec()
  output_spec = model.output_spec()
  return find_supported_fields(input_spec, output_spec) is not None
def get_salience_result(self, model_input: JsonDict, model: lit_model.Model,
                        model_output: JsonDict, grad_fields: List[Text]):
  result = {}

  output_spec = model.output_spec()
  # We ensure that the embedding and gradient class fields are present in the
  # model's input spec in find_fields().
  embeddings_fields = [
      cast(types.TokenGradients, output_spec[grad_field]).grad_for
      for grad_field in grad_fields
  ]

  # The gradient class input is used to specify the target class of the
  # gradient calculation (if unspecified, this option defaults to the argmax,
  # which could flip between interpolated inputs).
  grad_class_key = cast(types.TokenGradients,
                        output_spec[grad_fields[0]]).grad_target
  # TODO(b/168042999): Add option to specify the class to explain in the UI.
  grad_class = model_output[grad_class_key]

  interpolated_inputs = {}
  all_embeddings = []
  all_baselines = []
  for embed_field in embeddings_fields:
    # <float32>[num_tokens, emb_size]
    embeddings = np.array(model_output[embed_field])
    all_embeddings.append(embeddings)

    # Starts with a baseline of zeros. <float32>[num_tokens, emb_size]
    baseline = self.get_baseline(embeddings)
    all_baselines.append(baseline)

    # Get interpolated inputs from baseline to original embedding.
    # <float32>[interpolation_steps, num_tokens, emb_size]
    interpolated_inputs[embed_field] = self.get_interpolated_inputs(
        baseline, embeddings, self.interpolation_steps)

  # Create model inputs and populate embedding field(s).
  inputs_with_embeds = []
  for i in range(self.interpolation_steps):
    input_copy = model_input.copy()
    # Interpolates embeddings for all inputs simultaneously.
    for embed_field in embeddings_fields:
      # <float32>[num_tokens, emb_size]
      input_copy[embed_field] = interpolated_inputs[embed_field][i]
    input_copy[grad_class_key] = grad_class
    inputs_with_embeds.append(input_copy)
  embed_outputs = model.predict(inputs_with_embeds)

  # Create a list with concatenated gradients for each interpolated input.
  gradients = []
  for o in embed_outputs:
    # <float32>[total_num_tokens, emb_size]
    interp_gradients = np.concatenate([o[field] for field in grad_fields])
    gradients.append(interp_gradients)
  # <float32>[interpolation_steps, total_num_tokens, emb_size]
  path_gradients = np.stack(gradients, axis=0)

  # Calculate the integral.
  # <float32>[total_num_tokens, emb_size]
  integral = self.estimate_integral(path_gradients)

  # <float32>[total_num_tokens, emb_size]
  concat_embeddings = np.concatenate(all_embeddings)

  # <float32>[total_num_tokens, emb_size]
  concat_baseline = np.concatenate(all_baselines)

  # <float32>[total_num_tokens, emb_size]
  integrated_gradients = integral * (np.array(concat_embeddings) -
                                     np.array(concat_baseline))
  # Dot product of integral values and (embeddings - baseline).
  # <float32>[total_num_tokens]
  attributions = np.sum(integrated_gradients, axis=-1)

  # TODO(b/168042999): Make normalization customizable in the UI.
  # <float32>[total_num_tokens]
  scores = citrus_utils.normalize_scores(attributions)

  for grad_field in grad_fields:
    # Format as salience map result.
    token_field = cast(types.TokenGradients, output_spec[grad_field]).align
    tokens = model_output[token_field]

    # Only use the scores that correspond to the tokens in this grad_field.
    # The gradients for all input embeddings were concatenated in the order
    # of the grad fields, so they can be sliced out in the same order.
    sliced_scores = scores[:len(tokens)]  # <float32>[num_tokens in field]
    scores = scores[len(tokens):]  # <float32>[num_remaining_tokens]

    assert len(tokens) == len(sliced_scores)
    result[grad_field] = dtypes.SalienceMap(tokens, sliced_scores)
  return result
def is_compatible(self, model: lit_model.Model) -> bool:
  compatible_fields = self.find_fields(model.output_spec())
  return bool(compatible_fields)
def run(
    self,
    inputs: List[JsonDict],
    model: lit_model.Model,
    dataset: lit_dataset.Dataset,
    model_outputs: Optional[List[JsonDict]] = None,
    config: Optional[JsonDict] = None,
) -> Optional[List[JsonDict]]:
  """Run this component, given a model and input(s)."""
  class_to_explain = int(config[CLASS_KEY]) if config else CLASS_DEFAULT
  kernel_width = int(
      config[KERNEL_WIDTH_KEY]) if config else KERNEL_WIDTH_DEFAULT
  mask_string = config[MASK_KEY] if config else MASK_DEFAULT
  num_samples = int(
      config[NUM_SAMPLES_KEY]) if config else NUM_SAMPLES_DEFAULT
  seed = config[SEED_KEY] if config else SEED_DEFAULT

  # Find keys of input (text) segments to explain.
  # Search in the input spec, since it's only useful to look at ones that are
  # used by the model.
  text_keys = utils.find_spec_keys(model.input_spec(), types.TextSegment)
  if not text_keys:
    logging.warning('LIME requires text inputs.')
    return None
  logging.info('Found text fields for LIME attribution: %s', str(text_keys))

  # Find the key of output probabilities field(s).
  pred_keys = utils.find_spec_keys(
      model.output_spec(), (types.MulticlassPreds, types.RegressionScore))
  if not pred_keys:
    logging.warning('LIME did not find any supported output fields.')
    return None

  pred_key = pred_keys[0]  # TODO(lit-dev): configure which prob field to use.

  all_results = []

  # Explain each input.
  for input_ in inputs:
    # Dict[field name -> interpretations]
    result = {}
    predict_fn = functools.partial(
        _predict_fn, model=model, original_example=input_, pred_key=pred_key)

    # Explain each text segment in the input, keeping the others constant.
    for text_key in text_keys:
      input_string = input_[text_key]
      if not input_string:
        logging.info('Could not explain empty string for %s', text_key)
        continue
      logging.info('Explaining: %s', input_string)

      # Perturbs the input string, gets model predictions, fits linear model.
      explanation = lime.explain(
          sentence=input_string,
          predict_fn=functools.partial(predict_fn, text_key=text_key),
          # `class_to_explain` is ignored when predict_fn output is a scalar.
          class_to_explain=class_to_explain,  # Index of the class to explain.
          num_samples=num_samples,
          tokenizer=str.split,
          mask_token=mask_string,
          kernel=functools.partial(
              lime.exponential_kernel, kernel_width=kernel_width),
          seed=seed)

      # Turn the LIME explanation into a list following original word order.
      scores = explanation.feature_importance
      # TODO(lit-dev): Move score normalization to the UI.
      scores = citrus_util.normalize_scores(scores)
      result[text_key] = dtypes.SalienceMap(input_string.split(), scores)

    all_results.append(result)

  return all_results
def find_fields(self, model: lit_model.Model) -> Dict[str, List[str]]:
  src_fields = utils.find_spec_keys(model.input_spec(), types.TextSegment)
  gen_fields = utils.find_spec_keys(
      model.output_spec(),
      (types.GeneratedText, types.GeneratedTextCandidates))
  return {f: src_fields for f in gen_fields}
def _find_closer_flip_using_interpolation(
    self,
    ref_example: JsonDict,
    known_flip: JsonDict,
    target_pred: JsonDict,
    pred_key: Text,
    model: lit_model.Model,
    dataset: lit_dataset.Dataset,
    regression_threshold: Optional[float] = None,
    max_attempts: int = 4) -> Optional[JsonDict]:
  """Looks for the decision boundary between two examples using interpolation.

  The method searches for a flip that is closer to `ref_example` than
  `known_flip` is, performing a binary search over interpolated scalar
  values.

  Args:
    ref_example: an example for which the flip is searched.
    known_flip: an example that represents a known flip.
    target_pred: the model prediction at `ref_example`.
    pred_key: the name of the field inside `target_pred` that holds the
      prediction value.
    model: model to use for running predictions.
    dataset: dataset that contains `known_flip`.
    regression_threshold: threshold to use for regression models.
    max_attempts: number of binary search attempts.

  Returns:
    The counterfactual (flip) if found; 'None' otherwise.
  """
  min_alpha = 0.0
  max_alpha = 1.0
  closest_flip = None
  input_spec = model.input_spec()
  has_scalar = False
  for _ in range(max_attempts):
    # Interpolate the scalar values using binary search.
    current_alpha = (min_alpha + max_alpha) / 2
    candidate = known_flip.copy()
    for field in ref_example:
      if (field in candidate and field in input_spec and
          isinstance(input_spec[field], lit_types.Scalar) and
          candidate[field] is not None and ref_example[field] is not None):
        candidate[field] = known_flip[field] * (
            1 - current_alpha) + ref_example[field] * current_alpha
        has_scalar = True
    # The interpolation makes sense only for scalar values. If there are no
    # scalar fields that can be interpolated then terminate the search.
    if not has_scalar:
      return None
    flip, predicted_value = self._is_flip(
        model=model,
        cf_example=candidate,
        orig_output=target_pred,
        pred_key=pred_key,
        regression_thresh=regression_threshold)
    if flip:
      self._find_dataset_parent_and_set(
          model_output_spec=model.output_spec(),
          pred_key=pred_key,
          dataset_spec=dataset.spec(),
          example=candidate,
          predicted_value=predicted_value)
      closest_flip = candidate
      min_alpha = current_alpha
    else:
      max_alpha = current_alpha
  return closest_flip
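# Worked sketch (assumption, illustrative numbers only): with a single scalar
# feature where known_flip = 10.0 and ref_example = 2.0, the binary search
# above probes these interpolated candidates:
#   alpha = 0.5   -> candidate = 10.0 * 0.5  + 2.0 * 0.5  = 6.0
#   if 6.0 still flips:      min_alpha = 0.5, next alpha = 0.75 -> candidate 4.0
#   if 6.0 does not flip:    max_alpha = 0.5, next alpha = 0.25 -> candidate 8.0
# After `max_attempts` rounds, the closest flipping candidate seen is returned.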
def _find_hot_flip(
    self,
    ref_example: JsonDict,
    ds_example: JsonDict,
    features_to_consider: List[Text],
    model: lit_model.Model,
    target_pred: JsonDict,
    pred_key: Text,
    dataset: lit_dataset.Dataset,
    interpolate: bool,
    regression_threshold: Optional[float] = None,
) -> Optional[JsonDict]:
  """Finds a hot-flip example for a given target example and dataset example.

  Args:
    ref_example: target example for which the counterfactuals should be
      found.
    ds_example: a dataset example that should be used as a starting point for
      the search.
    features_to_consider: the list of feature keys that can be changed during
      the search.
    model: model to use for getting predictions.
    target_pred: model prediction that corresponds to `ref_example`.
    pred_key: the name of the field in model predictions that contains the
      prediction value for the counterfactual search.
    dataset: a dataset object that contains `ds_example`.
    interpolate: if True, the method tries to find a closer counterfactual
      using interpolation.
    regression_threshold: the threshold to use if `model` is a regression
      model. This parameter is ignored for classification models.

  Returns:
    A hot-flip counterfactual that satisfies the criteria.
  """
  # All features other than `features_to_consider` should be assigned the
  # value of the target example.
  candidate_example = ds_example.copy()
  for field_name in ref_example:
    if (field_name not in features_to_consider and
        field_name in model.input_spec()):
      candidate_example[field_name] = ref_example[field_name]

  flip, predicted_value = self._is_flip(
      model=model,
      cf_example=candidate_example,
      orig_output=target_pred,
      pred_key=pred_key,
      regression_thresh=regression_threshold)

  if not flip:
    return None

  # Find the closest flip by moving scalar values closer to the target.
  closest_flip = None
  if interpolate:
    closest_flip = self._find_closer_flip_using_interpolation(
        ref_example, candidate_example, target_pred, pred_key, model,
        dataset, regression_threshold)

  # If we found a closer flip through interpolation then use it,
  # otherwise use the previously found flip.
  if closest_flip is not None:
    return closest_flip
  else:
    self._find_dataset_parent_and_set(
        model_output_spec=model.output_spec(),
        pred_key=pred_key,
        dataset_spec=dataset.spec(),
        example=candidate_example,
        predicted_value=predicted_value)
    return candidate_example
def generate(self,
             example: JsonDict,
             model: lit_model.Model,
             dataset: lit_dataset.Dataset,
             config: Optional[JsonDict] = None) -> List[JsonDict]:
  # Perform validation and retrieve configuration.
  if not model:
    raise ValueError('Please provide a model for this generator.')

  config = config or {}
  num_examples = int(config.get(NUM_EXAMPLES_KEY, NUM_EXAMPLES_DEFAULT))
  max_flips = int(config.get(MAX_FLIPS_KEY, MAX_FLIPS_DEFAULT))

  pred_key = config.get(PREDICTION_KEY, '')
  regression_thresh = float(
      config.get(REGRESSION_THRESH_KEY, REGRESSION_THRESH_DEFAULT))

  dataset_name = config.get('dataset_name')
  if not dataset_name:
    raise ValueError('The dataset name must be in the config.')

  output_spec = model.output_spec()
  if not pred_key:
    raise ValueError('Please provide the prediction key.')
  if pred_key not in output_spec:
    raise ValueError('Invalid prediction key.')

  if (not (isinstance(output_spec[pred_key], lit_types.MulticlassPreds) or
           isinstance(output_spec[pred_key], lit_types.RegressionScore))):
    raise ValueError(
        'Only classification and regression models are supported')

  # Calculate dataset statistics if they have not been calculated yet. The
  # statistics include information such as the standard deviation for scalar
  # features and probabilities for categorical features.
  if dataset_name not in self._datasets_stats:
    self._calculate_stats(dataset, dataset_name)

  # Find the predicted class of the original example.
  original_pred = list(model.predict([example]))[0]

  # Find dataset examples that are flips.
  filtered_examples = self._filter_ds_examples(
      dataset=dataset,
      dataset_name=dataset_name,
      model=model,
      reference_output=original_pred,
      pred_key=pred_key,
      regression_thresh=regression_thresh)

  supported_field_names = self._find_all_fields_to_consider(
      ds_spec=dataset.spec(),
      model_input_spec=model.input_spec(),
      example=example)

  candidates: List[JsonDict] = []

  # Iterate through all possible feature combinations.
  combs = utils.find_all_combinations(supported_field_names, 1, max_flips)
  for comb in combs:
    # Sort all dataset examples with respect to the given combination.
    sorted_examples = self._sort_and_filter_examples(
        examples=filtered_examples,
        ref_example=example,
        fields=comb,
        dataset=dataset,
        dataset_name=dataset_name)
    if not sorted_examples:
      continue

    # As an optimization trick, check whether the farthest example is a flip.
    # If it is not a flip then skip the current combination of features.
    # This optimization makes the minimum set guarantees weaker but
    # significantly improves the search speed.
    flip = self._find_hot_flip(
        ref_example=example,
        ds_example=sorted_examples[-1],
        features_to_consider=comb,
        model=model,
        target_pred=original_pred,
        pred_key=pred_key,
        dataset=dataset,
        interpolate=False,
        regression_threshold=regression_thresh)
    if not flip:
      logging.info('Skipped combination %s', comb)
      continue

    # Iterate through the sorted examples until the first flip is found.
    # TODO(b/204200758): improve performance by batching the predict requests.
    for ds_example in sorted_examples:
      flip = self._find_hot_flip(
          ref_example=example,
          ds_example=ds_example,
          features_to_consider=comb,
          model=model,
          target_pred=original_pred,
          pred_key=pred_key,
          dataset=dataset,
          interpolate=True,
          regression_threshold=regression_thresh)

      if flip:
        self._add_if_not_strictly_worse(
            example=flip,
            other_examples=candidates,
            ref_example=example,
            dataset=dataset,
            dataset_name=dataset_name,
            model=model)
        break

    if len(candidates) >= num_examples:
      break

  # Calculate distances for the found hot flips.
  candidate_tuples = []
  for flip_example in candidates:
    distance, diff_fields = self._calculate_L1_distance(
        example_1=example,
        example_2=flip_example,
        dataset=dataset,
        dataset_name=dataset_name,
        model=model)
    if distance > 0:
      candidate_tuples.append((distance, diff_fields, flip_example))

  # Order the dataset entries based on the distance to the given example.
  candidate_tuples.sort(key=lambda e: e[0])

  if len(candidate_tuples) > num_examples:
    candidate_tuples = candidate_tuples[0:num_examples]

  # e[2] contains the hot-flip examples in the distances list of tuples.
  return [e[2] for e in candidate_tuples]
def is_compatible(self, model: lit_model.Model) -> bool:
  text_keys = utils.find_spec_keys(model.input_spec(), types.TextSegment)
  pred_keys = utils.find_spec_keys(
      model.output_spec(), (types.MulticlassPreds, types.RegressionScore))
  return bool(text_keys) and bool(pred_keys)
def find_fields(self, model: lit_model.Model) -> List[str]:
  sal_keys = utils.find_spec_keys(
      model.output_spec(),
      (types.FeatureSalience, types.ImageSalience, types.TokenSalience,
       types.SequenceSalience))
  return sal_keys
def run(self,
        inputs: List[types.JsonDict],
        model: lit_model.Model,
        dataset: lit_dataset.Dataset,
        model_outputs: Optional[List[types.JsonDict]] = None,
        config: Optional[types.JsonDict] = None):
  """Create PDP chart info using provided inputs.

  Args:
    inputs: sequence of inputs, following model.input_spec().
    model: optional model to use to generate new examples.
    dataset: dataset which the current examples belong to.
    model_outputs: optional precomputed model outputs.
    config: optional runtime config.

  Returns:
    A dict of alternate feature values to model outputs. The model outputs
    will be a number for regression models and a list of numbers for
    multiclass models.
  """
  pred_keys = utils.find_spec_keys(
      model.output_spec(), (types.MulticlassPreds, types.RegressionScore))
  if not pred_keys:
    logging.warning('PDP did not find any supported output fields.')
    return None

  assert 'feature' in config, 'No feature to test provided'
  feature = config['feature']
  provided_range = config['range'] if 'range' in config else []
  edited_outputs = {}
  for pred_key in pred_keys:
    edited_outputs[pred_key] = {}

  # If a range was provided, use that to create the possible values.
  vals_to_test = (
      np.linspace(provided_range[0], provided_range[1], 10)
      if len(provided_range) == 2
      else self.get_vals_to_test(feature, dataset))

  # If no specific inputs provided, use the entire dataset.
  inputs_to_use = inputs if inputs else dataset.examples

  # For each alternate value for a given feature.
  for new_val in vals_to_test:
    # Create copies of all provided inputs with the value replaced.
    edited_inputs = []
    for inp in inputs_to_use:
      edited_input = copy.deepcopy(inp)
      edited_input[feature] = new_val
      edited_inputs.append(edited_input)

    # Run prediction on the altered inputs.
    outputs = list(model.predict(edited_inputs))

    # Store the mean of the prediction for the alternate value.
    for pred_key in pred_keys:
      numeric = isinstance(model.output_spec()[pred_key],
                           types.RegressionScore)
      if numeric:
        edited_outputs[pred_key][new_val] = np.mean(
            [output[pred_key] for output in outputs])
      else:
        edited_outputs[pred_key][new_val] = np.mean(
            [output[pred_key] for output in outputs], axis=0)

  return edited_outputs
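# Config sketch (assumption, not from the original file): the PDP component
# reads the feature name and an optional numeric range from `config`. With a
# two-element range, 10 evenly spaced values are tested; without one, values
# come from self.get_vals_to_test(feature, dataset). The feature name below is
# a placeholder.
example_pdp_config = {
    'feature': 'age',    # placeholder scalar feature name
    'range': [18, 90],   # optional; omit to derive values from the dataset
}
# chart_data = pdp.run(selected_inputs, model, dataset,
#                      config=example_pdp_config)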
def generate(self,
             example: JsonDict,
             model: lit_model.Model,
             dataset: lit_dataset.Dataset,
             config: Optional[JsonDict] = None) -> List[JsonDict]:
  """Identify minimal sets of token flips that alter the prediction."""
  del dataset  # Unused.

  config = config or {}
  num_examples = int(config.get(NUM_EXAMPLES_KEY, NUM_EXAMPLES_DEFAULT))
  max_flips = int(config.get(MAX_FLIPS_KEY, MAX_FLIPS_DEFAULT))
  tokens_to_ignore = config.get(TOKENS_TO_IGNORE_KEY,
                                TOKENS_TO_IGNORE_DEFAULT)
  pred_key = config.get(PREDICTION_KEY, "")
  regression_thresh = float(
      config.get(REGRESSION_THRESH_KEY, REGRESSION_THRESH_DEFAULT))
  assert model is not None, "Please provide a model for this generator."

  input_spec = model.input_spec()
  output_spec = model.output_spec()
  assert pred_key, "Please provide the prediction key"
  assert pred_key in output_spec, "Invalid prediction key"

  is_regression = False
  if isinstance(output_spec[pred_key], types.RegressionScore):
    is_regression = True
  else:
    assert isinstance(output_spec[pred_key], types.MulticlassPreds), (
        "Only classification or regression models are supported")
  logging.info(r"W3lc0m3 t0 H0tFl1p \o/")
  logging.info("Original example: %r", example)

  # Get model outputs.
  orig_output = list(model.predict([example]))[0]

  # Check config for selected fields.
  selected_fields = list(config.get(FIELDS_TO_HOTFLIP_KEY, []))
  if not selected_fields:
    return []

  # Get tokens (corresponding to each text input field) and corresponding
  # gradients.
  tokens_and_gradients = self._get_tokens_and_gradients(
      input_spec, output_spec, orig_output, selected_fields)
  assert tokens_and_gradients, (
      "No token fields found. Cannot use HotFlip. :-(")

  # Copy tokens into the input example.
  example = copy.deepcopy(example)
  for token_field, v in tokens_and_gradients.items():
    tokens, _ = v
    example[token_field] = tokens

  inv_vocab, embedding_matrix = model.get_embedding_table()
  assert len(inv_vocab) == embedding_matrix.shape[0], (
      "Vocab/embeddings size mismatch.")

  successful_cfs = []
  # TODO(lit-team): use only 1 sequence as input (configurable in UI).
  # TODO(lit-team): Refactor the following code so that it's not so deeply
  # nested (and easier to track loop state).
  for token_field, v in tokens_and_gradients.items():
    tokens, grads = v
    text_field = input_spec[token_field].parent  # pytype: disable=attribute-error
    logging.info("Identifying HotFlips for input field: %s", str(text_field))
    direction = -1
    if is_regression:
      # We want the replacements to increase the prediction score if the
      # original score is below the threshold, and decrease it otherwise.
      direction = (1 if orig_output[pred_key] <= regression_thresh else -1)
    replacement_tokens = self._get_replacement_tokens(
        embedding_matrix, inv_vocab, grads, direction)

    successful_positions = []
    for token_idxs in self._gen_token_idxs_to_flip(tokens, grads, max_flips,
                                                   tokens_to_ignore):
      if len(successful_cfs) >= num_examples:
        return successful_cfs
      # If a subset of this set of tokens has already been successful in
      # obtaining a flip, we continue. This ensures that we only consider
      # sets of token flips that are minimal.
      if self._subset_exists(set(token_idxs), successful_positions):
        continue

      # Create the counterfactual.
      cf = self._create_cf(example, token_field, text_field, tokens,
                           token_idxs, replacement_tokens)
      # Obtain the model prediction.
      cf_output = list(model.predict([cf]))[0]

      if cf_utils.is_prediction_flip(cf_output, orig_output, output_spec,
                                     pred_key, regression_thresh):
        # Prediction flip found!
        cf_utils.update_prediction(cf, cf_output, output_spec, pred_key)
        successful_cfs.append(cf)
        successful_positions.append(set(token_idxs))
  return successful_cfs
def generate(self,
             example: JsonDict,
             model: lit_model.Model,
             dataset: lit_dataset.Dataset,
             config: Optional[JsonDict] = None,
             num_examples: int = 1) -> List[JsonDict]:
  """Use gradient to find/substitute the token with largest impact on loss."""
  # TODO(lit-team): This function is quite long. Consider breaking it
  # into small functions.
  del dataset  # Unused.

  assert model is not None, "Please provide a model for this generator."
  logging.info(r"W3lc0m3 t0 H0tFl1p \o/")
  logging.info("Original example: %r", example)

  # Find the classification prediction key.
  pred_keys = self.find_fields(model.output_spec(), types.MulticlassPreds,
                               None)
  if len(pred_keys) == 0:  # pylint: disable=g-explicit-length-test
    # TODO(ataly): Add support for regression models.
    logging.warning("The model does not have a classification head. "
                    "Cannot use HotFlip. :-(")
    return []  # Cannot generate examples.
  if len(pred_keys) > 1:
    # TODO(ataly): Use a config argument when there are multiple prediction
    # heads.
    logging.warning("Multiple classification heads found. "
                    "Cannot use HotFlip. :-(")
    return []  # Cannot generate examples.
  pred_key = pred_keys[0]

  # Find gradient fields to use for HotFlip.
  input_spec = model.input_spec()
  output_spec = model.output_spec()
  grad_fields = self.find_fields(output_spec, types.TokenGradients,
                                 types.Tokens)
  logging.info("Found gradient fields for HotFlip use: %s", str(grad_fields))
  if len(grad_fields) == 0:  # pylint: disable=g-explicit-length-test
    logging.info("No gradient fields found. Cannot use HotFlip. :-(")
    return []  # Cannot generate examples without gradients.

  # Get model outputs.
  logging.info("Performing a forward/backward pass on the input example.")
  orig_output = list(model.predict([example]))[0]
  logging.info(orig_output.keys())

  # Get model word embeddings and vocab.
  inv_vocab, embed = model.get_embedding_table()
  assert len(inv_vocab) == embed.shape[0], "Vocab/embeddings size mismatch."
  logging.info("Vocab size: %d, Embedding size: %r", len(inv_vocab),
               embed.shape)

  # Get the original prediction class.
  orig_probabilities = orig_output[pred_key]
  orig_prediction = np.argmax(orig_probabilities)

  # Perform a flip in each sequence for which we have gradients (separately).
  # Each sequence may give rise to multiple new examples, depending on how
  # many words we flip.
  # TODO(lit-team): make configurable how many new examples are desired.
  # TODO(lit-team): use only 1 sequence as input (configurable in UI).
  new_examples = []
  for grad_field in grad_fields:
    # Get the tokens and their gradient vectors.
    token_field = output_spec[grad_field].align  # pytype: disable=attribute-error
    tokens = orig_output[token_field]
    grads = orig_output[grad_field]
    token_emb_fields = self.find_fields(output_spec, types.TokenEmbeddings,
                                        types.Tokens)
    assert len(token_emb_fields) == 1, "Found multiple token embeddings"
    token_embs = orig_output[token_emb_fields[0]]

    # Identify the token with the largest gradient attribution, defined as
    # the dot product between the token embedding and the gradient of the
    # output wrt the embedding.
    assert token_embs.shape[0] == grads.shape[0]
    token_grad_attrs = np.sum(token_embs * grads, axis=-1)
    # Get a list of indices of input tokens, sorted by gradient attribution,
    # highest first. We will flip tokens in this order.
    sorted_by_grad_attrs = np.argsort(token_grad_attrs)[::-1]

    for i in range(min(num_examples, len(tokens))):
      token_id = sorted_by_grad_attrs[i]
      logging.info("Selected token: %s (pos=%d) with gradient attribution %f",
                   tokens[token_id], token_id, token_grad_attrs[token_id])
      token_grad = grads[token_id]

      # Take dot product with all word embeddings. Get the smallest value.
      # (We are looking for a replacement token that will lower the score of
      # the current class, thereby increasing the chances of a label flip.)
      # TODO(lit-team): Can add criteria to the winner e.g. cosine distance.
      scores = np.dot(embed, token_grad)
      winner = np.argmin(scores)
      logging.info("Replacing [%s] (pos=%d) with option %d: [%s] (id=%d)",
                   tokens[token_id], token_id, i, inv_vocab[winner], winner)

      # Create a new input to the model.
      # TODO(iftenney, bastings): enforce somewhere that this field has the
      # same name in the input and output specs.
      input_token_field = token_field
      input_text_field = input_spec[input_token_field].parent  # pytype: disable=attribute-error
      new_example = copy.deepcopy(example)
      modified_tokens = copy.copy(tokens)
      modified_tokens[token_id] = inv_vocab[winner]
      new_example[input_token_field] = modified_tokens
      # TODO(iftenney, bastings): call a model-provided detokenizer here?
      # Though in general tokenization isn't invertible and it's possible for
      # HotFlip to produce wordpiece sequences that don't correspond to any
      # input string.
      new_example[input_text_field] = " ".join(modified_tokens)

      # Predict a new label for this example.
      new_output = list(model.predict([new_example]))[0]

      # Update label if multi-class prediction.
      # TODO(lit-dev): provide a general system for handling labels on
      # generated examples.
      probabilities = new_output[pred_key]
      new_prediction = np.argmax(probabilities)
      label_key = cast(types.MulticlassPreds, output_spec[pred_key]).parent
      label_names = cast(types.MulticlassPreds, output_spec[pred_key]).vocab
      new_label = label_names[new_prediction]
      new_example[label_key] = new_label
      logging.info("Updated example with new label: %s", new_label)

      if new_prediction != orig_prediction:
        # HotFlip found.
        new_examples.append(new_example)
      else:
        # Use new_example as our base example and continue with more token
        # flips.
        example = new_example
        tokens = modified_tokens
  return new_examples
def run_with_metadata(
    self,
    indexed_inputs: Sequence[IndexedInput],
    model: lit_model.Model,
    dataset: lit_dataset.IndexedDataset,
    model_outputs: Optional[List[JsonDict]] = None,
    config: Optional[JsonDict] = None) -> Optional[List[JsonDict]]:
  """Calculates optimal thresholds on the provided data.

  Args:
    indexed_inputs: all examples in the dataset, in the indexed input format.
    model: the model being explained.
    dataset: the dataset which the current examples belong to.
    model_outputs: optional model outputs from calling model.predict(inputs).
    config: a config which should specify TresholderConfig options.

  Returns:
    A JsonDict containing the calculated thresholds.
  """
  config = TresholderConfig(**config) if config else TresholderConfig()

  pred_keys = []
  for pred_key, field_spec in model.output_spec().items():
    if self.metrics_gen.is_compatible(field_spec) and cast(
        types.MulticlassPreds, field_spec).parent:
      pred_keys.append(pred_key)

  indexed_outputs = {
      ex['id']: output for (ex, output) in zip(indexed_inputs, model_outputs)
  }

  # Try all margins for thresholds from 0 to 1, by hundredths.
  margins_to_try = [
      self.threshold_to_margin(t) for t in np.linspace(0, 1, 101)
  ]

  # Get binary classification metrics for all margins, for the entire
  # dataset, and also for each facet specified in the config.
  dataset_results = []
  faceted_results = {}

  # Loop over each margin/threshold to check.
  for margin in margins_to_try:
    # Set up an empty config to pass to the metrics generator.
    metrics_config = {}
    for pred_key in pred_keys:
      metrics_config[pred_key] = {'': {'margin': margin}}

    # Get and store the metrics for the entire dataset for this margin.
    dataset_results.append(
        self.metrics_gen.run_with_metadata(indexed_inputs, model, dataset,
                                           model_outputs, metrics_config))

    # Get and store the metrics for each facet of the dataset for this
    # margin.
    for facet_key in config.facets:
      if 'data' not in config.facets[facet_key]:
        continue
      if facet_key not in faceted_results:
        faceted_results[facet_key] = []
      faceted_model_outputs = [
          indexed_outputs[ex['id']]
          for ex in config.facets[facet_key]['data']
      ]
      faceted_results[facet_key].append(
          self.metrics_gen.run_with_metadata(config.facets[facet_key]['data'],
                                             model, dataset,
                                             faceted_model_outputs,
                                             metrics_config))

  pred_keys = [result['pred_key'] for result in dataset_results[0]]
  ret = []

  # Find threshold information for each prediction key.
  for i, pred_key in enumerate(pred_keys):
    ret.append(
        self.get_thresholds_for_pred_key(pred_key, i, margins_to_try,
                                         dataset_results, faceted_results,
                                         config))
  return ret
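# Config sketch (assumption, not from the original file): a facets mapping of
# the shape read above, keyed by facet name, where each entry carries the
# indexed examples belonging to that facet under 'data'. The facet names and
# example lists below are hypothetical placeholders.
example_thresholder_config = {
    'facets': {
        'gender:male': {'data': male_indexed_examples},      # placeholder list
        'gender:female': {'data': female_indexed_examples},  # of IndexedInputs
    },
}
# thresholds = thresholder.run_with_metadata(
#     indexed_inputs, model, dataset, model_outputs,
#     config=example_thresholder_config)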
def run(self,
        inputs: List[JsonDict],
        model: lit_model.Model,
        dataset: lit_dataset.Dataset,
        model_outputs: Optional[List[JsonDict]] = None,
        config: Optional[JsonDict] = None) -> Optional[List[JsonDict]]:
  """Runs the component, given a model and input(s)."""
  input_spec = model.input_spec()
  output_spec = model.output_spec()

  if not inputs:
    return []

  # Find all fields required for the interpretation.
  supported_fields = find_supported_fields(input_spec, output_spec)
  if supported_fields is None:
    return
  grad_field_key = supported_fields.grad_field_key
  image_field_key = supported_fields.image_field_key
  grad_target_field_key = supported_fields.grad_target_field_key
  preds_field_key = supported_fields.preds_field_key
  multiclass = grad_target_field_key is not None

  # Determine the shape of gradients by calling the model with a single input
  # and extracting the shape from the gradient output.
  model_output = list(model.predict([inputs[0]]))
  grad_shape = model_output[0][grad_field_key].shape

  # If it is a multiclass model, find the labels with respect to which we
  # should compute the gradients.
  if multiclass:
    # Get class labels.
    label_vocab = list(
        cast(types.MulticlassPreds, output_spec[preds_field_key]).vocab)

    # Run the model in order to find the gradient target labels.
    outputs = list(model.predict(inputs))
    grad_target_labels = []
    for model_input, model_output in zip(inputs, outputs):
      if model_input.get(grad_target_field_key) is not None:
        grad_target_labels.append(model_input[grad_target_field_key])
      else:
        max_idx = int(np.argmax(model_output[preds_field_key]))
        grad_target_labels.append(label_vocab[max_idx])
  else:
    grad_target_labels = [None] * len(inputs)

  saliency_object = self.get_saliency_object()
  extra_saliency_params = self.get_extra_saliency_params(config)
  all_results = []
  for example, grad_target_label in zip(inputs, grad_target_labels):
    result = {}
    image_str = example[image_field_key]
    saliency_input = image_utils.convert_image_str_to_array(
        image_str=image_str, shape=grad_shape)
    call_model_func = get_call_model_func(
        model=model,
        model_input=example,
        image_field_key=image_field_key,
        grad_field_key=grad_field_key,
        grad_target_field_key=grad_target_field_key,
        grad_target_label=grad_target_label)
    attribution = self.make_saliency_call(
        saliency_object=saliency_object,
        x_value=saliency_input,
        call_model_function=call_model_func,
        extra_saliency_params=extra_saliency_params)
    if attribution.ndim == 3:
      attribution = attribution.sum(axis=2)
    viz_params = self.get_visualization_params()
    overlaid_image = image_utils.overlay_pixel_saliency(
        image_str, attribution, **viz_params)
    result[grad_field_key] = image_utils.convert_pil_to_image_str(
        overlaid_image)
    all_results.append(result)
  return all_results
def run(
    self,
    inputs: List[JsonDict],
    model: lit_model.Model,
    dataset: lit_dataset.Dataset,
    model_outputs: Optional[List[JsonDict]] = None,
    config: Optional[JsonDict] = None,
) -> Optional[List[JsonDict]]:
  """Run this component, given a model and input(s)."""
  config_defaults = {k: v.default for k, v in self.config_spec().items()}
  config = dict(config_defaults, **(config or {}))  # update and return

  provided_class_to_explain = int(config[CLASS_KEY])
  kernel_width = int(config[KERNEL_WIDTH_KEY])
  num_samples = int(config[NUM_SAMPLES_KEY])
  mask_string = (config[MASK_KEY])
  # pylint: disable=g-explicit-bool-comparison
  seed = int(config[SEED_KEY]) if config[SEED_KEY] != '' else None
  # pylint: enable=g-explicit-bool-comparison

  # Find keys of input (text) segments to explain.
  # Search in the input spec, since it's only useful to look at ones that are
  # used by the model.
  text_keys = utils.find_spec_keys(model.input_spec(), types.TextSegment)
  if not text_keys:
    logging.warning('LIME requires text inputs.')
    return None
  logging.info('Found text fields for LIME attribution: %s', str(text_keys))

  # Find the key of output probabilities field(s).
  pred_keys = utils.find_spec_keys(
      model.output_spec(),
      (types.MulticlassPreds, types.RegressionScore,
       types.SparseMultilabelPreds))
  if not pred_keys:
    logging.warning('LIME did not find any supported output fields.')
    return None

  pred_key = config[TARGET_HEAD_KEY] or pred_keys[0]

  all_results = []

  # Explain each input.
  for input_ in inputs:
    # Dict[field name -> interpretations]
    result = {}
    predict_fn = functools.partial(
        _predict_fn,
        model=model,
        original_example=input_,
        pred_key=pred_key,
        pred_type_info=model.output_spec()[pred_key])
    class_to_explain = get_class_to_explain(provided_class_to_explain, model,
                                            pred_key, input_)

    # Explain each text segment in the input, keeping the others constant.
    for text_key in text_keys:
      input_string = input_[text_key]
      if not input_string:
        logging.info('Could not explain empty string for %s', text_key)
        continue
      logging.info('Explaining: %s', input_string)

      # Perturbs the input string, gets model predictions, fits linear model.
      explanation = lime.explain(
          sentence=input_string,
          predict_fn=functools.partial(predict_fn, text_key=text_key),
          # `class_to_explain` is ignored when predict_fn output is a scalar.
          class_to_explain=class_to_explain,
          num_samples=num_samples,
          tokenizer=str.split,
          mask_token=mask_string,
          kernel=functools.partial(
              lime.exponential_kernel, kernel_width=kernel_width),
          seed=seed)

      # Turn the LIME explanation into a list following original word order.
      scores = explanation.feature_importance
      # TODO(lit-dev): Move score normalization to the UI.
      scores = citrus_util.normalize_scores(scores)
      result[text_key] = dtypes.TokenSalience(input_string.split(), scores)

    all_results.append(result)

  return all_results
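# Config sketch (assumption, not from the original file): the keys consumed
# above come from self.config_spec(); any key left out falls back to its
# declared default. The constant names are the ones referenced in the code;
# the values below are illustrative placeholders, not the actual defaults.
example_lime_config = {
    CLASS_KEY: '-1',       # placeholder; the class index to explain
    KERNEL_WIDTH_KEY: '128',
    NUM_SAMPLES_KEY: '256',
    MASK_KEY: '[MASK]',
    SEED_KEY: '',          # empty string means no fixed seed
    TARGET_HEAD_KEY: '',   # empty string means use the first supported head
}
# salience = lime_interpreter.run(inputs, model, dataset,
#                                 config=example_lime_config)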
def run(
    self,
    inputs: List[JsonDict],
    model: lit_model.Model,
    dataset: lit_dataset.Dataset,
    model_outputs: Optional[List[JsonDict]] = None,
    config: Optional[JsonDict] = None,
    kernel_width: int = 25,  # TODO(lit-dev): make configurable in UI.
    mask_string: str = '[MASK]',  # TODO(lit-dev): make configurable in UI.
    num_samples: int = 256,  # TODO(lit-dev): make configurable in UI.
) -> Optional[List[JsonDict]]:
  """Run this component, given a model and input(s)."""
  # Find keys of input (text) segments to explain.
  # Search in the input spec, since it's only useful to look at ones that are
  # used by the model.
  text_keys = utils.find_spec_keys(model.input_spec(), types.TextSegment)
  if not text_keys:
    logging.warning('LIME requires text inputs.')
    return None
  logging.info('Found text fields for LIME attribution: %s', str(text_keys))

  # Find the key of output probabilities field(s).
  pred_keys = utils.find_spec_keys(model.output_spec(),
                                   types.MulticlassPreds)
  if not pred_keys:
    logging.warning('LIME did not find a multi-class predictions field.')
    return None

  pred_key = pred_keys[0]  # TODO(lit-dev): configure which prob field to use.
  pred_spec = cast(types.MulticlassPreds, model.output_spec()[pred_key])
  label_names = pred_spec.vocab

  # Create a LIME text explainer instance.
  explainer = lime_text.LimeTextExplainer(
      class_names=label_names,
      split_expression=str.split,
      kernel_width=kernel_width,
      mask_string=mask_string,  # This is the string used to mask words.
      bow=False)  # bow=False masks inputs, instead of deleting them entirely.

  all_results = []

  # Explain each input.
  for input_ in inputs:
    # Dict[field name -> interpretations]
    result = {}

    # Explain each text segment in the input, keeping the others constant.
    for text_key in text_keys:
      input_string = input_[text_key]
      logging.info('Explaining: %s', input_string)

      # Use the number of words as the number of features.
      num_features = len(input_string.split())

      def _predict_proba(strings: List[Text]):
        """Given raw strings, return probabilities. Used by `explainer`."""
        input_examples = [new_example(input_, text_key, s) for s in strings]
        model_outputs = model.predict(input_examples)
        probs = np.array([output[pred_key] for output in model_outputs])
        return probs  # <float32>[len(strings), num_labels]

      # Perturbs the input string, gets model predictions, fits linear model.
      explanation = explainer.explain_instance(
          input_string,
          _predict_proba,
          num_features=num_features,
          num_samples=num_samples)

      # Turn the LIME explanation into a list following original word order.
      scores = explanation_to_array(explanation)
      result[text_key] = dtypes.SalienceMap(input_string.split(), scores)

    all_results.append(result)

  return all_results
def run_with_metadata(
    self,
    indexed_inputs: Sequence[IndexedInput],
    model: lit_model.Model,
    dataset: lit_dataset.IndexedDataset,
    model_outputs: Optional[List[JsonDict]] = None,
    config: Optional[JsonDict] = None) -> Optional[List[JsonDict]]:
  """Runs the TCAV method given the params in the inputs and config.

  Args:
    indexed_inputs: all examples in the dataset, in the indexed input format.
    model: the model being explained.
    dataset: the dataset which the current examples belong to.
    model_outputs: optional model outputs from calling model.predict(inputs).
    config: a config which should specify:
      {
        'concept_set_ids': [list of ids to use in concept set],
        'class_to_explain': [gradient class to explain],
        'grad_layer': [the Gradient field key of the layer to explain],
        'random_state': [an optional seed to make outputs deterministic]
      }

  Returns:
    A JsonDict containing the TCAV scores, directional derivatives,
    statistical test p-values, and LM accuracies.
  """
  config = TCAVConfig(**config)
  # TODO(b/171513556): get these from the Dataset object once indices are
  # available there.
  dataset_examples = indexed_inputs

  # Get this layer's output spec keys for gradients and embeddings.
  grad_layer = config.grad_layer
  output_spec = model.output_spec()
  emb_layer = cast(types.Gradients, output_spec[grad_layer]).grad_for

  # Get the class that the gradients were computed for.
  grad_class_key = cast(types.Gradients, output_spec[grad_layer]).grad_target

  ids_set = set(config.concept_set_ids)
  concept_set = [ex for ex in dataset_examples if ex['id'] in ids_set]
  non_concept_set = [ex for ex in dataset_examples if ex['id'] not in ids_set]

  # Get outputs using model.predict().
  dataset_outputs = list(model.predict_with_metadata(dataset_examples))

  def _subsample(examples, n):
    return random.sample(examples, n) if n < len(examples) else examples

  concept_outputs = list(model.predict_with_metadata(concept_set))
  non_concept_outputs = list(model.predict_with_metadata(non_concept_set))

  concept_results = []
  # If there are more concept set examples than non-concept set examples, we
  # use random splits of the concept examples as the concept set and use the
  # remainder of the dataset as the comparison set. Otherwise, we use random
  # splits of the non-concept examples as the comparison set.
  n = min(len(concept_set), len(non_concept_set))

  # If there are an equal number of concept and non-concept examples, we
  # decrease n by one so that we also sample a different set in each TCAV run.
  if len(concept_set) == len(non_concept_set):
    n -= 1
  for _ in range(NUM_SPLITS):
    concept_split_outputs = _subsample(concept_outputs, n)
    comparison_split_outputs = _subsample(non_concept_outputs, n)
    concept_results.append(
        self._run_tcav(concept_split_outputs, comparison_split_outputs,
                       dataset_outputs, config.class_to_explain, emb_layer,
                       grad_layer, grad_class_key, config.test_size,
                       config.random_state))

  cav_scores = [res['score'] for res in concept_results]
  p_val = self.hyp_test(cav_scores)

  # Get the index of the CAV result with the highest accuracy.
  accuracies = [res['accuracy'] for res in concept_results]
  index = np.argmax(accuracies)

  # Many CAVs are trained and checked with a statistical test to calculate
  # the p-value; the values of the CAV with the highest accuracy are
  # returned.
  results = {'result': concept_results[index], 'p_val': p_val}
  return [results]