def run(self,
        inputs: List[JsonDict],
        model: lit_model.Model,
        dataset: lit_dataset.Dataset,
        model_outputs: Optional[List[JsonDict]] = None,
        config: Optional[JsonDict] = None) -> Optional[List[JsonDict]]:
  """Run this component, given a model and input(s)."""
  # Find gradient fields to interpret
  output_spec = model.output_spec()
  grad_fields = self.find_fields(output_spec)
  logging.info('Found fields for gradient attribution: %s', str(grad_fields))
  if len(grad_fields) == 0:  # pylint: disable=g-explicit-length-test
    return None

  # Run model, if needed.
  if model_outputs is None:
    model_outputs = list(model.predict(inputs))
  assert len(model_outputs) == len(inputs)

  all_results = []
  for o in model_outputs:
    # Dict[field name -> interpretations]
    result = {}
    for grad_field in grad_fields:
      token_field = cast(types.TokenGradients, output_spec[grad_field]).align
      tokens = o[token_field]
      scores = self._interpret(o[grad_field], tokens)
      result[grad_field] = dtypes.SalienceMap(tokens, scores)
    all_results.append(result)
  return all_results
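

# A minimal, hedged sketch of the per-token scoring that `self._interpret`
# presumably performs for gradient-norm salience: take the L2 norm of each
# token's gradient vector and normalize the norms to sum to one. The helper
# name `_grad_l2_norm_scores` is illustrative and not part of the original
# component.
import numpy as np


def _grad_l2_norm_scores(grads: np.ndarray) -> np.ndarray:
  """grads: <float32>[num_tokens, emb_size]; returns <float32>[num_tokens]."""
  norms = np.linalg.norm(grads, axis=1)
  return norms / np.sum(norms)


# Example: _grad_l2_norm_scores(np.ones((3, 4))) -> array([1/3, 1/3, 1/3]).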
def validate_t5_model(model: lit_model.Model) -> lit_model.Model:
  """Validate that a given model looks like a T5 model.

  This checks the model spec at runtime; it is intended to be used before
  server start, such as in the __init__() method of a wrapper class.

  Args:
    model: a LIT model

  Returns:
    model: the same model

  Raises:
    AssertionError: if the model's spec does not match that expected for a T5
      model.
  """
  # Check inputs
  ispec = model.input_spec()
  assert "input_text" in ispec
  assert isinstance(ispec["input_text"], lit_types.TextSegment)
  if "target_text" in ispec:
    assert isinstance(ispec["target_text"], lit_types.TextSegment)

  # Check outputs
  ospec = model.output_spec()
  assert "output_text" in ospec
  assert isinstance(
      ospec["output_text"],
      (lit_types.GeneratedText, lit_types.GeneratedTextCandidates))
  assert ospec["output_text"].parent == "target_text"

  return model
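

# Hedged usage sketch for validate_t5_model(): run the check once at wrapper
# construction so a spec mismatch fails fast, before the server starts. The
# `ModelWrapper` base class is assumed to be available from lit_model; the
# underlying T5 model instance is left abstract here.
class ValidatedT5Wrapper(lit_model.ModelWrapper):

  def __init__(self, model: lit_model.Model):
    # Raises AssertionError immediately if `model` does not look like T5.
    super().__init__(validate_t5_model(model))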
def run(self,
        inputs: List[JsonDict],
        model: lit_model.Model,
        dataset: lit_dataset.Dataset,
        model_outputs: Optional[List[JsonDict]] = None,
        config: Optional[JsonDict] = None) -> Optional[List[JsonDict]]:
  """Run this component, given a model and input(s)."""
  # Find gradient fields to interpret
  input_spec = model.input_spec()
  output_spec = model.output_spec()
  grad_fields = self.find_fields(input_spec, output_spec)
  logging.info('Found fields for integrated gradients: %s', str(grad_fields))
  if len(grad_fields) == 0:  # pylint: disable=g-explicit-length-test
    return None

  # Run model, if needed.
  if model_outputs is None:
    model_outputs = list(model.predict(inputs))

  all_results = []
  for model_output, model_input in zip(model_outputs, inputs):
    result = self.get_salience_result(model_input, model, model_output,
                                      grad_fields)
    all_results.append(result)
  return all_results
def _get_results_from_model(model: lit_model.Model, data: lit_dataset.Dataset,
                            notebook: bool) -> List[Dict]:
  tqdm = notebook_tqdm if notebook else normal_tqdm
  batch_size = model.max_minibatch_size()
  results = []
  for i in tqdm(range(0, len(data), batch_size), desc='processing batches'):
    batch_examples = data.examples[i:i + batch_size]
    results.extend(model.predict_minibatch(batch_examples))
  return results
def run_with_metadata(
    self,
    indexed_inputs: Sequence[IndexedInput],
    model: lit_model.Model,
    dataset: lit_dataset.IndexedDataset,
    model_outputs: Optional[List[JsonDict]] = None,
    config: Optional[JsonDict] = None) -> Optional[List[JsonDict]]:
  """Finds the nearest neighbors of the example specified in the config.

  Args:
    indexed_inputs: the selected example to find nearest neighbors for; must
      contain exactly one example.
    model: the model being explained.
    dataset: the dataset which the current examples belong to.
    model_outputs: optional model outputs from calling model.predict(inputs).
    config: a config which should specify:
      {
        'num_neighbors': [the number of nearest neighbors to return]
        'dataset_name': [the name of the dataset (used for caching)]
        'embedding_name': [the name of the embedding field to use]
      }

  Returns:
    A JsonDict containing a list of the num_neighbors nearest neighbors,
    where each has the example id and distance from the main example.
  """
  config = NearestNeighborsConfig(**config)

  dataset_outputs = list(
      model.predict_with_metadata(
          dataset.indexed_examples, dataset_name=config.dataset_name))

  example_outputs = list(
      model.predict_with_metadata(
          indexed_inputs, dataset_name=config.dataset_name))
  # TODO(lit-dev): Add support for selecting nearest neighbors of a set.
  if len(example_outputs) != 1:
    raise ValueError('More than one selected example was passed in.')
  example_output = example_outputs[0]

  # <float32>[emb_size]
  dataset_embs = [output[config.embedding_name] for output in dataset_outputs]
  example_embs = [example_output[config.embedding_name]]
  distances = distance.cdist(example_embs, dataset_embs)[0]
  sorted_indices = np.argsort(distances)
  k = config.num_neighbors
  k_nearest_neighbors = [{
      'id': dataset.indexed_examples[original_index]['id'],
      'nn_distance': distances[original_index]
  } for original_index in sorted_indices[:k]]

  return [{'nearest_neighbors': k_nearest_neighbors}]
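

# Hedged, self-contained sketch of the nearest-neighbor math above, stripped
# of the LIT model/dataset plumbing: one query embedding is compared against
# all dataset embeddings with scipy's cdist (Euclidean by default) and the k
# smallest distances are returned. The function name `k_nearest` is
# illustrative only.
import numpy as np
from scipy.spatial import distance


def k_nearest(query_emb: np.ndarray, dataset_embs: np.ndarray, k: int):
  """query_emb: <float32>[emb_size]; dataset_embs: <float32>[n, emb_size]."""
  dists = distance.cdist([query_emb], dataset_embs)[0]  # <float32>[n]
  order = np.argsort(dists)[:k]
  return [(int(i), float(dists[i])) for i in order]


# k_nearest(np.zeros(4), np.random.rand(10, 4), k=3)
# -> [(row_index, distance), ...] for the 3 closest dataset rows.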
def _is_flip(self,
             model: lit_model.Model,
             cf_example: JsonDict,
             orig_output: JsonDict,
             pred_key: Text,
             regression_thresh: Optional[float] = None) -> Tuple[bool, Any]:

  cf_output = list(model.predict([cf_example]))[0]
  feature_predicted_value = cf_output[pred_key]
  return cf_utils.is_prediction_flip(
      cf_output=cf_output,
      orig_output=orig_output,
      output_spec=model.output_spec(),
      pred_key=pred_key,
      regression_thresh=regression_thresh), feature_predicted_value
def run(self,
        inputs: List[JsonDict],
        model: lit_model.Model,
        dataset: lit_dataset.Dataset,
        model_outputs: Optional[List[JsonDict]] = None,
        config: Optional[JsonDict] = None) -> Optional[List[JsonDict]]:
  """Run this component, given a model and input(s)."""
  if not inputs:
    return

  # Find keys of input (text) segments to explain.
  # Search in the input spec, since it's only useful to look at ones that are
  # used by the model.
  text_keys = utils.find_spec_keys(model.input_spec(), types.TextSegment)
  if not text_keys:
    logging.warning('LEMON requires text inputs.')
    return None
  logging.info('Found text fields for LEMON attribution: %s', str(text_keys))

  pred_key = config['pred_key']
  output_probs = np.array([output[pred_key] for output in model_outputs])

  # Explain the input given counterfactuals.

  # Dict[field name -> interpretations]
  result = {}

  # Explain each text segment in the input, keeping the others constant.
  for text_key in text_keys:
    sentences = [item[text_key] for item in inputs]
    input_to_prediction = dict(zip(sentences, output_probs))

    input_string = sentences[0]
    counterfactuals = sentences[1:]

    # Remove duplicate counterfactuals.
    counterfactuals = list(set(counterfactuals))

    logging.info('Explaining: %s', input_string)

    predict_proba = make_predict_fn(input_to_prediction)

    # Perturbs the input string, gets model predictions, fits linear model.
    explanation = lemon.explain(
        input_string,
        counterfactuals,
        predict_proba,
        class_to_explain=config['class_to_explain'],
        lowercase_tokens=config['lowercase_tokens'])

    scores = np.array(explanation.feature_importance)

    # Normalize feature values.
    scores = citrus_utils.normalize_scores(scores)

    result[text_key] = dtypes.TokenSalience(input_string.split(), scores)

  return [result]
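

# Hedged sketch of the `make_predict_fn` helper referenced above. LEMON only
# perturbs among the provided counterfactual sentences, so the "prediction
# function" can be a simple lookup over the probabilities that were already
# computed; this is an assumption about the helper's behavior, not the
# original implementation.
from typing import Dict, List

import numpy as np


def make_predict_fn(input_to_prediction: Dict[str, np.ndarray]):

  def predict_proba(sentences: List[str]) -> np.ndarray:
    # <float32>[len(sentences), num_classes]
    return np.stack([input_to_prediction[s] for s in sentences])

  return predict_proba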
def run_with_metadata(self,
                      indexed_inputs: Sequence[IndexedInput],
                      model: lit_model.Model,
                      dataset: lit_dataset.IndexedDataset,
                      model_outputs: Optional[List[JsonDict]] = None,
                      config: Optional[JsonDict] = None) -> List[JsonDict]:
  if model_outputs is None:
    model_outputs = list(model.predict_with_metadata(indexed_inputs))

  # TODO(lit-team): pre-compute this mapping in constructor?
  # This would require passing a model name to this function so we can
  # reference a pre-computed list.
  spec = model.spec()
  field_map = map_pred_keys(dataset.spec(), spec.output, self.is_compatible)
  ret = []
  for pred_key, label_key in field_map.items():
    # Extract fields
    labels = [ex['data'][label_key] for ex in indexed_inputs]
    preds = [mo[pred_key] for mo in model_outputs]
    indices = [ex['id'] for ex in indexed_inputs]
    metas = [ex.get('meta', {}) for ex in indexed_inputs]
    # Compute metrics, as dict(str -> float)
    metrics = self.compute_with_metadata(
        labels,
        preds,
        label_spec=dataset.spec()[label_key],
        pred_spec=spec.output[pred_key],
        indices=indices,
        metas=metas,
        config=config.get(pred_key) if config else None)
    # NaN is not a valid JSON value, so replace with None which will be
    # serialized as null.
    # TODO(lit-team): move this logic into serialize.py somewhere instead?
    metrics = {
        k: (v if not np.isnan(v) else None) for k, v in metrics.items()
    }
    # Format for frontend.
    ret.append({
        'pred_key': pred_key,
        'label_key': label_key,
        'metrics': metrics
    })
  return ret
def _get_embedding(self, example: JsonDict, model: lit_model.Model,
                   dataset: lit_dataset.IndexedDataset, embedding_name: str,
                   dataset_name: str):
  """Calls the model on the example to get the embedding."""
  model_input = dataset.index_inputs([example])
  model_output = model.predict_with_metadata(
      model_input, dataset_name=dataset_name)
  embedding = list(model_output)[0][embedding_name]
  return embedding
def run_with_metadata(
    self,
    indexed_inputs: Sequence[IndexedInput],
    model: lit_model.Model,
    dataset: lit_dataset.IndexedDataset,
    model_outputs: Optional[List[JsonDict]] = None,
    config: Optional[JsonDict] = None) -> Optional[JsonDict]:
  """Run this component, given a model and input(s).

  Args:
    indexed_inputs: Inputs to cluster.
    model: Model that provides salience maps.
    dataset: Dataset to compute salience maps for.
    model_outputs: Precomputed model outputs.
    config: Config for clustering and salience computation.

  Returns:
    Dict with 2 keys:
      `CLUSTER_ID_KEY`: Contains the cluster assignments. One cluster id per
        dataset example.
      `REPRESENTATION_KEY`: Contains the representations of all examples in
        the dataset that were used in the clustering.
  """
  config = config or {}
  # Find gradient fields to interpret
  grad_fields = self.find_fields(model.output_spec())

  token_saliencies = self.salience_mappers[
      config['salience_mapper']].run_with_metadata(indexed_inputs, model,
                                                   dataset, model_outputs,
                                                   config)

  if not token_saliencies:
    return None

  vocab = self._build_vocab(token_saliencies)
  representations = self._compute_fixed_length_representation(
      token_saliencies, vocab)

  cluster_ids = {}
  grad_field_to_representations = {}

  for grad_field in grad_fields:
    # Pass a list (not a generator) to np.vstack.
    weight_matrix = np.vstack(
        [representation[grad_field] for representation in representations])
    self.kmeans[grad_field] = cluster.KMeans(
        n_clusters=config.get(N_CLUSTERS_KEY,
                              self.config_spec()[N_CLUSTERS_KEY].default))
    cluster_ids[grad_field] = self.kmeans[grad_field].fit_predict(
        weight_matrix).tolist()
    grad_field_to_representations[grad_field] = weight_matrix

  return {
      CLUSTER_ID_KEY: cluster_ids,
      REPRESENTATION_KEY: grad_field_to_representations
  }
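

# Hedged sketch of the clustering step above, outside the LIT plumbing: fold
# each example's token saliencies into a fixed-length vector over a shared
# vocabulary (a salience-weighted bag of words is one plausible choice for
# `_compute_fixed_length_representation`), then let sklearn KMeans assign one
# cluster id per example. All names here are illustrative.
from typing import Dict, List

import numpy as np
from sklearn import cluster


def cluster_salience_vectors(token_scores: List[Dict[str, float]],
                             vocab: List[str],
                             n_clusters: int = 2) -> List[int]:
  index = {tok: i for i, tok in enumerate(vocab)}
  weight_matrix = np.zeros((len(token_scores), len(vocab)), dtype=np.float32)
  for row, scores in enumerate(token_scores):
    for tok, score in scores.items():
      weight_matrix[row, index[tok]] = score
  return cluster.KMeans(n_clusters=n_clusters).fit_predict(
      weight_matrix).tolist()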
def run(self,
        inputs: List[JsonDict],
        model: lit_model.Model,
        dataset: lit_dataset.Dataset,
        model_outputs: Optional[List[JsonDict]] = None,
        config: Optional[JsonDict] = None) -> Optional[List[JsonDict]]:
  """Run this component, given a model and input(s)."""
  config = config or {}
  class_to_explain = config.get(CLASS_KEY,
                                self.config_spec()[CLASS_KEY].default)
  interpolation_steps = int(
      config.get(INTERPOLATION_KEY,
                 self.config_spec()[INTERPOLATION_KEY].default))
  normalization = config.get(NORMALIZATION_KEY,
                             self.config_spec()[NORMALIZATION_KEY].default)

  # Find gradient fields to interpret
  input_spec = model.input_spec()
  output_spec = model.output_spec()
  grad_fields = self.find_fields(input_spec, output_spec)
  logging.info('Found fields for integrated gradients: %s', str(grad_fields))
  if len(grad_fields) == 0:  # pylint: disable=g-explicit-length-test
    return None

  # Run model, if needed.
  if model_outputs is None:
    model_outputs = list(model.predict(inputs))

  all_results = []
  for model_output, model_input in zip(model_outputs, inputs):
    result = self.get_salience_result(model_input, model, interpolation_steps,
                                      normalization, class_to_explain,
                                      model_output, grad_fields)
    all_results.append(result)
  return all_results
def run(self,
        inputs: List[JsonDict],
        model: lit_model.Model,
        dataset: lit_dataset.Dataset,
        model_outputs: Optional[List[JsonDict]] = None,
        config: Optional[JsonDict] = None):
  if model_outputs is None:
    model_outputs = list(model.predict(inputs))

  spec = model.spec()
  field_map = map_pred_keys(dataset.spec(), spec.output, self.is_compatible)
  ret = []
  for pred_key, label_key in field_map.items():
    # Extract fields
    labels = [ex[label_key] for ex in inputs]
    preds = [mo[pred_key] for mo in model_outputs]
    # Compute metrics, as dict(str -> float)
    metrics = self.compute(
        labels,
        preds,
        label_spec=dataset.spec()[label_key],
        pred_spec=spec.output[pred_key],
        config=config.get(pred_key) if config else None)
    # NaN is not a valid JSON value, so replace with None which will be
    # serialized as null.
    # TODO(lit-team): move this logic into serialize.py somewhere instead?
    metrics = {
        k: (v if not np.isnan(v) else None) for k, v in metrics.items()
    }
    # Format for frontend.
    ret.append({
        'pred_key': pred_key,
        'label_key': label_key,
        'metrics': metrics
    })
  return ret
def _filter_ds_examples(
    self,
    dataset: lit_dataset.IndexedDataset,
    dataset_name: Text,
    model: lit_model.Model,
    reference_output: JsonDict,
    pred_key: Text,
    regression_thresh: Optional[float] = None) -> List[JsonDict]:
  """Reads all dataset examples and returns only those that are flips."""
  if not isinstance(dataset, lit_dataset.IndexedDataset):
    raise ValueError(
        'Only indexed datasets are currently supported by the TabularMTC '
        'generator.')

  indexed_examples = list(dataset.indexed_examples)
  filtered_examples = []
  preds = model.predict_with_metadata(
      indexed_examples, dataset_name=dataset_name)

  # Find all DS examples that are flips with respect to the reference example.
  for indexed_example, pred in zip(indexed_examples, preds):
    flip = cf_utils.is_prediction_flip(
        cf_output=pred,
        orig_output=reference_output,
        output_spec=model.output_spec(),
        pred_key=pred_key,
        regression_thresh=regression_thresh)
    if flip:
      candidate_example = indexed_example['data'].copy()
      self._find_dataset_parent_and_set(
          model_output_spec=model.output_spec(),
          pred_key=pred_key,
          dataset_spec=dataset.spec(),
          example=candidate_example,
          predicted_value=pred[pred_key])
      filtered_examples.append(candidate_example)
  return filtered_examples
def _run(self,
         model: lit_model.Model,
         indexed_inputs: List[JsonDict],
         model_outputs: Optional[List[JsonDict]] = None,
         do_fit=False):
  # Run model, if needed.
  if model_outputs is None:
    model_outputs = list(model.predict(indexed_inputs))
  assert len(model_outputs) == len(indexed_inputs)

  converted_inputs = list(
      map(self.convert_input, indexed_inputs, model_outputs))
  if do_fit:
    return self._projector.fit_transform_with_metadata(
        converted_inputs, dataset_name="")
  else:
    return self._projector.predict_with_metadata(
        converted_inputs, dataset_name="")
def run(self,
        inputs: List[JsonDict],
        model: lit_model.Model,
        dataset: lit_dataset.Dataset,
        model_outputs: Optional[List[JsonDict]] = None,
        config: Optional[JsonDict] = None):
  # Get margin for each input for each pred key and add them to a config dict
  # to pass to the wrapped metrics.
  field_map = map_pred_keys(dataset.spec(),
                            model.spec().output, self.is_compatible)
  margin_config = {}
  for pred_key in field_map:
    field_config = config.get(pred_key) if config else None
    margins = [get_margin_for_input(field_config, inp) for inp in inputs]
    margin_config[pred_key] = margins
  return self._metrics.run(inputs, model, dataset, model_outputs,
                           margin_config)
def build(cls,
          inputs: List[JsonDict],
          encoder: lit_model.Model,
          edge_field: str,
          embs_field: str,
          offset_field: str,
          progress=lambda x: x,
          verbose=False):
  """Run encoder and extract span representations for coreference.

  'encoder' should be a model returning one TokenEmbeddings field, from which
  span features will be extracted, as well as a TokenOffsets field which maps
  input tokens to output tokens.

  The returned dataset will contain one example for each label in the inputs'
  EdgeLabels field.

  Args:
    inputs: input Dataset
    encoder: encoder model, compatible with inputs
    edge_field: name of edge field in data
    embs_field: name of embeddings field in model output
    offset_field: name of offset field in model output
    progress: optional pass-through progress indicator
    verbose: if true, print estimated memory usage

  Returns:
    EdgeFeaturesDataset with extracted span representations
  """
  examples = []
  encoder_outputs = progress(encoder.predict(inputs))
  for i, output in enumerate(encoder_outputs):
    exs = _make_probe_inputs(
        inputs[i][edge_field],
        output[embs_field],
        output[offset_field],
        src_idx=i)
    examples.extend(exs)
    if verbose and i == 10:
      _estimate_memory_needs(inputs, edge_field, examples[0])

  return cls(examples)
def _generate_leave_one_out_ablation_score(
    self, example: JsonDict, model: lit_model.Model, input_spec: Spec,
    output_spec: Spec, orig_output: JsonDict, pred_key: Text,
    fields_to_ablate: List[str]) -> List[Tuple[str, int, float]]:
  # Returns a list of triples: field, token_idx and leave-one-out score.
  ret = []
  for field in input_spec.keys():
    if field not in example or field not in fields_to_ablate:
      continue

    tokens = self._get_tokens(example, input_spec, field)
    cfs = [
        self._create_cf(example, input_spec, [(field, i)])
        for i in range(len(tokens))
    ]
    cf_outputs = model.predict(cfs)
    for i, cf_output in enumerate(cf_outputs):
      loo_score = cf_utils.prediction_difference(cf_output, orig_output,
                                                 output_spec, pred_key)
      ret.append((field, i, loo_score))
  return ret
def _train_instance(self, model: lit_model.Model,
                    dataset: lit_dataset.IndexedDataset, config: JsonDict,
                    name: Text) -> ProjectionInterpreter:
  # Ignore pytype warning about abstract methods, since this should always
  # be a subclass of ProjectorModel which has these implemented.
  projector = self._model_factory(**config.get("proj_kw", {}))  # pytype: disable=not-instantiable
  train_inputs = dataset.indexed_examples
  # TODO(lit-dev): remove 'dataset_name' from caching logic so we don't need
  # to track it here or elsewhere.
  train_outputs = list(
      model.predict_with_metadata(
          train_inputs, dataset_name=config.get("dataset_name")))
  logging.info("Creating new projection instance on %d points",
               len(train_inputs))
  return ProjectionInterpreter(
      model,
      train_inputs,
      train_outputs,
      projector=projector,
      field_name=config["field_name"],
      name=name)
def run(self,
        inputs: List[JsonDict],
        model: lit_model.Model,
        dataset: lit_dataset.Dataset,
        model_outputs: Optional[List[JsonDict]] = None,
        config: Optional[JsonDict] = None) -> Optional[List[JsonDict]]:
  del dataset
  del config

  field_map = self.find_fields(model)

  # Run model, if needed.
  if model_outputs is None:
    model_outputs = list(model.predict(inputs))
  assert len(model_outputs) == len(inputs)

  return [
      self._run_single(ex, mo, field_map)
      for ex, mo in zip(inputs, model_outputs)
  ]
def _train_instance(self, model: lit_model.Model,
                    dataset: lit_dataset.Dataset, config: Dict[Text, Any],
                    name: Text) -> ProjectionInterpreter:
  # Ignore pytype warning about abstract methods, since this should always
  # be a subclass of ProjectorModel which has these implemented.
  projector = self._model_factory(**config.get("proj_kw", {}))  # pytype: disable=not-instantiable
  # TODO(lit-dev): recomputing hashes here is a bit wasteful - consider
  # creating an 'IndexedDataset' class in the server, and passing that
  # around so that components can access IndexedInputs directly.
  train_inputs = caching.add_hashes_to_input(dataset.examples)
  # TODO(lit-dev): remove 'dataset_name' from caching logic so we don't need
  # to track it here or elsewhere.
  train_outputs = list(
      model.predict_with_metadata(
          train_inputs, dataset_name=config.get("dataset_name")))
  logging.info("Creating new projection instance on %d points",
               len(train_inputs))
  return ProjectionInterpreter(
      model,
      train_inputs,
      train_outputs,
      projector=projector,
      field_name=config["field_name"],
      name=name)
def get_salience_result(self, model_input: JsonDict, model: lit_model.Model,
                        model_output: JsonDict, grad_fields: List[Text]):
  result = {}

  output_spec = model.output_spec()
  # We ensure that the embedding and gradient class fields are present in the
  # model's input spec in find_fields().
  embeddings_fields = [
      cast(types.TokenGradients, output_spec[grad_field]).grad_for
      for grad_field in grad_fields
  ]

  # The gradient class input is used to specify the target class of the
  # gradient calculation (if unspecified, this option defaults to the argmax,
  # which could flip between interpolated inputs).
  grad_class_key = cast(types.TokenGradients,
                        output_spec[grad_fields[0]]).grad_target
  # TODO(b/168042999): Add option to specify the class to explain in the UI.
  grad_class = model_output[grad_class_key]

  interpolated_inputs = {}
  all_embeddings = []
  all_baselines = []
  for embed_field in embeddings_fields:
    # <float32>[num_tokens, emb_size]
    embeddings = np.array(model_output[embed_field])
    all_embeddings.append(embeddings)

    # Starts with baseline of zeros. <float32>[num_tokens, emb_size]
    baseline = self.get_baseline(embeddings)
    all_baselines.append(baseline)

    # Get interpolated inputs from baseline to original embedding.
    # <float32>[interpolation_steps, num_tokens, emb_size]
    interpolated_inputs[embed_field] = self.get_interpolated_inputs(
        baseline, embeddings, self.interpolation_steps)

  # Create model inputs and populate embedding field(s).
  inputs_with_embeds = []
  for i in range(self.interpolation_steps):
    input_copy = model_input.copy()
    # Interpolates embeddings for all inputs simultaneously.
    for embed_field in embeddings_fields:
      # <float32>[num_tokens, emb_size]
      input_copy[embed_field] = interpolated_inputs[embed_field][i]
      input_copy[grad_class_key] = grad_class

    inputs_with_embeds.append(input_copy)
  embed_outputs = model.predict(inputs_with_embeds)

  # Create list with concatenated gradients for each interpolated input.
  gradients = []
  for o in embed_outputs:
    # <float32>[total_num_tokens, emb_size]
    interp_gradients = np.concatenate([o[field] for field in grad_fields])
    gradients.append(interp_gradients)
  # <float32>[interpolation_steps, total_num_tokens, emb_size]
  path_gradients = np.stack(gradients, axis=0)

  # Calculate integral
  # <float32>[total_num_tokens, emb_size]
  integral = self.estimate_integral(path_gradients)

  # <float32>[total_num_tokens, emb_size]
  concat_embeddings = np.concatenate(all_embeddings)

  # <float32>[total_num_tokens, emb_size]
  concat_baseline = np.concatenate(all_baselines)

  # <float32>[total_num_tokens, emb_size]
  integrated_gradients = integral * (np.array(concat_embeddings) -
                                     np.array(concat_baseline))
  # Dot product of integral values and (embeddings - baseline).
  # <float32>[total_num_tokens]
  attributions = np.sum(integrated_gradients, axis=-1)

  # TODO(b/168042999): Make normalization customizable in the UI.
  # <float32>[total_num_tokens]
  scores = citrus_utils.normalize_scores(attributions)

  for grad_field in grad_fields:
    # Format as salience map result.
    token_field = cast(types.TokenGradients, output_spec[grad_field]).align
    tokens = model_output[token_field]

    # Only use the scores that correspond to the tokens in this grad_field.
    # The gradients for all input embeddings were concatenated in the order
    # of the grad fields, so they can be sliced out in the same order.
    sliced_scores = scores[:len(tokens)]  # <float32>[num_tokens in field]
    scores = scores[len(tokens):]  # <float32>[num_remaining_tokens]

    assert len(tokens) == len(sliced_scores)
    result[grad_field] = dtypes.SalienceMap(tokens, sliced_scores)
  return result
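

# Hedged, standalone sketch of the integrated-gradients arithmetic used above:
# interpolate from a baseline to the real embeddings, average the gradients
# along that path (a simple trapezoid estimate of the path integral), and
# multiply elementwise by (embeddings - baseline) before summing over the
# embedding dimension. `grad_fn` stands in for a model backward pass; the
# function name is illustrative, not the component's actual helper.
import numpy as np


def integrated_gradients_sketch(embeddings: np.ndarray,
                                grad_fn,
                                steps: int = 8) -> np.ndarray:
  """embeddings: <float32>[num_tokens, emb]; returns <float32>[num_tokens]."""
  baseline = np.zeros_like(embeddings)
  # <float32>[steps, num_tokens, emb]
  alphas = np.linspace(0.0, 1.0, steps).reshape(-1, 1, 1)
  interpolated = baseline + alphas * (embeddings - baseline)
  # <float32>[steps, num_tokens, emb]
  path_gradients = np.stack([grad_fn(x) for x in interpolated], axis=0)
  # Trapezoid-style average of the gradients along the path.
  integral = (path_gradients[:-1] + path_gradients[1:]).mean(axis=0) / 2.0
  return np.sum(integral * (embeddings - baseline), axis=-1)


# Toy check with a linear "model" whose gradient is constant 1 everywhere:
# integrated_gradients_sketch(np.ones((3, 4)), lambda x: np.ones_like(x))
# -> array([4., 4., 4.])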
def run(self,
        inputs: List[types.JsonDict],
        model: lit_model.Model,
        dataset: lit_dataset.Dataset,
        model_outputs: Optional[List[types.JsonDict]] = None,
        config: Optional[types.JsonDict] = None):
  """Create PDP chart info using provided inputs.

  Args:
    inputs: sequence of inputs, following model.input_spec()
    model: the model to use to generate predictions for the edited examples.
    dataset: dataset which the current examples belong to.
    model_outputs: optional precomputed model outputs
    config: optional runtime config.

  Returns:
    a dict of alternate feature values to model outputs. The model outputs
    will be a number for regression models and a list of numbers for
    multiclass models.
  """

  pred_keys = utils.find_spec_keys(
      model.output_spec(), (types.MulticlassPreds, types.RegressionScore))
  if not pred_keys:
    logging.warning('PDP did not find any supported output fields.')
    return None

  assert 'feature' in config, 'No feature to test provided'
  feature = config['feature']
  provided_range = config['range'] if 'range' in config else []
  edited_outputs = {}
  for pred_key in pred_keys:
    edited_outputs[pred_key] = {}

  # If a range was provided, use that to create the possible values.
  vals_to_test = (
      np.linspace(provided_range[0], provided_range[1], 10)
      if len(provided_range) == 2
      else self.get_vals_to_test(feature, dataset))

  # If no specific inputs provided, use the entire dataset.
  inputs_to_use = inputs if inputs else dataset.examples

  # For each alternate value for a given feature.
  for new_val in vals_to_test:
    # Create copies of all provided inputs with the value replaced.
    edited_inputs = []
    for inp in inputs_to_use:
      edited_input = copy.deepcopy(inp)
      edited_input[feature] = new_val
      edited_inputs.append(edited_input)

    # Run prediction on the altered inputs.
    outputs = list(model.predict(edited_inputs))

    # Store the mean of the prediction for the alternate value.
    for pred_key in pred_keys:
      numeric = isinstance(model.output_spec()[pred_key],
                           types.RegressionScore)
      if numeric:
        edited_outputs[pred_key][new_val] = np.mean(
            [output[pred_key] for output in outputs])
      else:
        edited_outputs[pred_key][new_val] = np.mean(
            [output[pred_key] for output in outputs], axis=0)

  return edited_outputs
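

# Hedged usage sketch for the partial-dependence component above: the config
# names one feature and, optionally, a numeric [min, max] range; when a range
# is given, the component sweeps 10 evenly spaced values via np.linspace and
# averages the model's predictions at each value. The feature name and range
# below are purely illustrative.
import numpy as np

pdp_config = {'feature': 'age', 'range': [18, 80]}
candidate_values = np.linspace(pdp_config['range'][0],
                               pdp_config['range'][1], 10)
# candidate_values: 10 evenly spaced values from 18.0 to 80.0; each one is
# substituted into copies of every input example before re-running
# model.predict() and averaging the resulting predictions.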
def run(
    self,
    inputs: List[JsonDict],
    model: lit_model.Model,
    dataset: lit_dataset.Dataset,
    model_outputs: Optional[List[JsonDict]] = None,
    config: Optional[JsonDict] = None,
) -> Optional[List[JsonDict]]:
  """Run this component, given a model and input(s)."""
  class_to_explain = int(config[CLASS_KEY]) if config else CLASS_DEFAULT
  kernel_width = int(
      config[KERNEL_WIDTH_KEY]) if config else KERNEL_WIDTH_DEFAULT
  mask_string = config[MASK_KEY] if config else MASK_DEFAULT
  num_samples = int(
      config[NUM_SAMPLES_KEY]) if config else NUM_SAMPLES_DEFAULT
  seed = config[SEED_KEY] if config else SEED_DEFAULT

  # Find keys of input (text) segments to explain.
  # Search in the input spec, since it's only useful to look at ones that are
  # used by the model.
  text_keys = utils.find_spec_keys(model.input_spec(), types.TextSegment)
  if not text_keys:
    logging.warning('LIME requires text inputs.')
    return None
  logging.info('Found text fields for LIME attribution: %s', str(text_keys))

  # Find the key of output probabilities field(s).
  pred_keys = utils.find_spec_keys(
      model.output_spec(), (types.MulticlassPreds, types.RegressionScore))
  if not pred_keys:
    logging.warning('LIME did not find any supported output fields.')
    return None

  pred_key = pred_keys[0]  # TODO(lit-dev): configure which prob field to use.

  all_results = []

  # Explain each input.
  for input_ in inputs:
    # Dict[field name -> interpretations]
    result = {}
    predict_fn = functools.partial(
        _predict_fn, model=model, original_example=input_, pred_key=pred_key)

    # Explain each text segment in the input, keeping the others constant.
    for text_key in text_keys:
      input_string = input_[text_key]
      if not input_string:
        logging.info('Could not explain empty string for %s', text_key)
        continue
      logging.info('Explaining: %s', input_string)

      # Perturbs the input string, gets model predictions, fits linear model.
      explanation = lime.explain(
          sentence=input_string,
          predict_fn=functools.partial(predict_fn, text_key=text_key),
          # `class_to_explain` is ignored when predict_fn output is a scalar.
          class_to_explain=class_to_explain,  # Index of the class to explain.
          num_samples=num_samples,
          tokenizer=str.split,
          mask_token=mask_string,
          kernel=functools.partial(
              lime.exponential_kernel, kernel_width=kernel_width),
          seed=seed)

      # Turn the LIME explanation into a list following original word order.
      scores = explanation.feature_importance
      # TODO(lit-dev): Move score normalization to the UI.
      scores = citrus_util.normalize_scores(scores)
      result[text_key] = dtypes.SalienceMap(input_string.split(), scores)

    all_results.append(result)

  return all_results
def is_compatible(self, model: lit_model.Model):
  text_keys = utils.find_spec_keys(model.input_spec(), types.TextSegment)
  pred_keys = utils.find_spec_keys(
      model.output_spec(), (types.MulticlassPreds, types.RegressionScore))
  return len(text_keys) and len(pred_keys)
def find_fields(self, model: lit_model.Model) -> Dict[str, List[str]]:
  src_fields = utils.find_spec_keys(model.input_spec(), types.TextSegment)
  gen_fields = utils.find_spec_keys(
      model.output_spec(),
      (types.GeneratedText, types.GeneratedTextCandidates))
  return {f: src_fields for f in gen_fields}
def is_compatible(self, model: lit_model.Model) -> bool:
  input_spec = model.input_spec()
  output_spec = model.output_spec()
  return find_supported_fields(input_spec, output_spec) is not None
def run(self,
        inputs: List[JsonDict],
        model: lit_model.Model,
        dataset: lit_dataset.Dataset,
        model_outputs: Optional[List[JsonDict]] = None,
        config: Optional[JsonDict] = None) -> Optional[List[JsonDict]]:
  """Runs the component, given a model and input(s)."""
  input_spec = model.input_spec()
  output_spec = model.output_spec()

  if not inputs:
    return []

  # Find all fields required for the interpretation.
  supported_fields = find_supported_fields(input_spec, output_spec)
  if supported_fields is None:
    return
  grad_field_key = supported_fields.grad_field_key
  image_field_key = supported_fields.image_field_key
  grad_target_field_key = supported_fields.grad_target_field_key
  preds_field_key = supported_fields.preds_field_key
  multiclass = grad_target_field_key is not None

  # Determine the shape of gradients by calling the model with a single input
  # and extracting the shape from the gradient output.
  model_output = list(model.predict([inputs[0]]))
  grad_shape = model_output[0][grad_field_key].shape

  # If it is a multiclass model, find the labels with respect to which we
  # should compute the gradients.
  if multiclass:
    # Get class labels.
    label_vocab = list(
        cast(types.MulticlassPreds, output_spec[preds_field_key]).vocab)

    # Run the model in order to find the gradient target labels.
    outputs = list(model.predict(inputs))
    grad_target_labels = []
    for model_input, model_output in zip(inputs, outputs):
      if model_input.get(grad_target_field_key) is not None:
        grad_target_labels.append(model_input[grad_target_field_key])
      else:
        max_idx = int(np.argmax(model_output[preds_field_key]))
        grad_target_labels.append(label_vocab[max_idx])
  else:
    grad_target_labels = [None] * len(inputs)

  saliency_object = self.get_saliency_object()
  extra_saliency_params = self.get_extra_saliency_params(config)
  all_results = []
  for example, grad_target_label in zip(inputs, grad_target_labels):
    result = {}
    image_str = example[image_field_key]
    saliency_input = image_utils.convert_image_str_to_array(
        image_str=image_str, shape=grad_shape)
    call_model_func = get_call_model_func(
        model=model,
        model_input=example,
        image_field_key=image_field_key,
        grad_field_key=grad_field_key,
        grad_target_field_key=grad_target_field_key,
        grad_target_label=grad_target_label)
    attribution = self.make_saliency_call(
        saliency_object=saliency_object,
        x_value=saliency_input,
        call_model_function=call_model_func,
        extra_saliency_params=extra_saliency_params)
    if attribution.ndim == 3:
      attribution = attribution.sum(axis=2)
    viz_params = self.get_visualization_params()
    overlaid_image = image_utils.overlay_pixel_saliency(
        image_str, attribution, **viz_params)
    result[grad_field_key] = image_utils.convert_pil_to_image_str(
        overlaid_image)
    all_results.append(result)
  return all_results
def run_with_metadata(
    self,
    indexed_inputs: Sequence[IndexedInput],
    model: lit_model.Model,
    dataset: lit_dataset.IndexedDataset,
    model_outputs: Optional[List[JsonDict]] = None,
    config: Optional[JsonDict] = None) -> Optional[List[JsonDict]]:
  """Runs the TCAV method given the params in the inputs and config.

  Args:
    indexed_inputs: all examples in the dataset, in the indexed input format.
    model: the model being explained.
    dataset: the dataset which the current examples belong to.
    model_outputs: optional model outputs from calling model.predict(inputs).
    config: a config which should specify:
      {
        'concept_set_ids': [list of ids to use in concept set]
        'class_to_explain': [gradient class to explain],
        'grad_layer': [the Gradient field key of the layer to explain],
        'random_state': [an optional seed to make outputs deterministic]
        'dataset_name': [the name of the dataset (used for caching)]
        'test_size': [Percentage of the example set to use in the LM test set]
        'negative_set_ids': [optional list of ids to use as negative set]
      }

  Returns:
    A JsonDict containing the TCAV scores, directional derivatives,
    statistical test p-values, and LM accuracies.
  """
  config = TCAVConfig(**config)
  # TODO(b/171513556): get these from the Dataset object once indices are
  # available there.
  dataset_examples = indexed_inputs

  # Get this layer's output spec keys for gradients and embeddings.
  grad_layer = config.grad_layer
  output_spec = model.output_spec()
  emb_layer = cast(types.Gradients, output_spec[grad_layer]).grad_for

  # Get the class that the gradients were computed for.
  grad_class_key = cast(types.Gradients,
                        output_spec[grad_layer]).grad_target_field_key

  ids_set = set(config.concept_set_ids)
  concept_set = [ex for ex in dataset_examples if ex['id'] in ids_set]
  non_concept_set = [ex for ex in dataset_examples if ex['id'] not in ids_set]

  # Get outputs using model.predict().
  dataset_outputs = list(
      model.predict_with_metadata(
          dataset_examples, dataset_name=config.dataset_name))

  if config.negative_set_ids:
    negative_ids_set = set(config.negative_set_ids)
    negative_set = [
        ex for ex in dataset_examples if ex['id'] in negative_ids_set
    ]
    return self._run_relative_tcav(grad_layer, emb_layer, grad_class_key,
                                   concept_set, negative_set, dataset_outputs,
                                   model, config)
  else:
    return self._run_default_tcav(grad_layer, emb_layer, grad_class_key,
                                  concept_set, non_concept_set,
                                  dataset_outputs, model, config)
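

# Hedged conceptual sketch of the TCAV scoring that _run_default_tcav builds
# on: fit a linear classifier separating concept from non-concept embeddings,
# take its coefficient vector as the concept activation vector (CAV), and
# report the fraction of examples whose gradient has a positive directional
# derivative along the CAV. This is the standard TCAV recipe in miniature,
# not the LIT implementation; all names are illustrative.
import numpy as np
from sklearn.linear_model import LogisticRegression


def tcav_score_sketch(concept_embs: np.ndarray, other_embs: np.ndarray,
                      gradients: np.ndarray) -> float:
  """concept/other_embs: [n, emb]; gradients: [m, emb] for the explained class."""
  x = np.concatenate([concept_embs, other_embs])
  y = np.concatenate([np.ones(len(concept_embs)), np.zeros(len(other_embs))])
  cav = LogisticRegression().fit(x, y).coef_[0]  # <float32>[emb]
  directional_derivs = gradients @ cav  # <float32>[m]
  return float(np.mean(directional_derivs > 0))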
def run(
    self,
    inputs: List[JsonDict],
    model: lit_model.Model,
    dataset: lit_dataset.Dataset,
    model_outputs: Optional[List[JsonDict]] = None,
    config: Optional[JsonDict] = None,
    kernel_width: int = 25,  # TODO(lit-dev): make configurable in UI.
    mask_string: str = '[MASK]',  # TODO(lit-dev): make configurable in UI.
    num_samples: int = 256,  # TODO(lit-dev): make configurable in UI.
) -> Optional[List[JsonDict]]:
  """Run this component, given a model and input(s)."""
  # Find keys of input (text) segments to explain.
  # Search in the input spec, since it's only useful to look at ones that are
  # used by the model.
  text_keys = utils.find_spec_keys(model.input_spec(), types.TextSegment)
  if not text_keys:
    logging.warning('LIME requires text inputs.')
    return None
  logging.info('Found text fields for LIME attribution: %s', str(text_keys))

  # Find the key of output probabilities field(s).
  pred_keys = utils.find_spec_keys(model.output_spec(),
                                   types.MulticlassPreds)
  if not pred_keys:
    logging.warning('LIME did not find a multi-class predictions field.')
    return None

  pred_key = pred_keys[0]  # TODO(lit-dev): configure which prob field to use.
  pred_spec = cast(types.MulticlassPreds, model.output_spec()[pred_key])
  label_names = pred_spec.vocab

  # Create a LIME text explainer instance.
  explainer = lime_text.LimeTextExplainer(
      class_names=label_names,
      split_expression=str.split,
      kernel_width=kernel_width,
      mask_string=mask_string,  # This is the string used to mask words.
      bow=False)  # bow=False masks inputs, instead of deleting them entirely.

  all_results = []

  # Explain each input.
  for input_ in inputs:
    # Dict[field name -> interpretations]
    result = {}

    # Explain each text segment in the input, keeping the others constant.
    for text_key in text_keys:
      input_string = input_[text_key]
      logging.info('Explaining: %s', input_string)

      # Use the number of words as the number of features.
      num_features = len(input_string.split())

      def _predict_proba(strings: List[Text]):
        """Given raw strings, return probabilities. Used by `explainer`."""
        input_examples = [new_example(input_, text_key, s) for s in strings]
        model_outputs = model.predict(input_examples)
        probs = np.array([output[pred_key] for output in model_outputs])
        return probs  # <float32>[len(strings), num_labels]

      # Perturbs the input string, gets model predictions, fits linear model.
      explanation = explainer.explain_instance(
          input_string,
          _predict_proba,
          num_features=num_features,
          num_samples=num_samples)

      # Turn the LIME explanation into a list following original word order.
      scores = explanation_to_array(explanation)
      result[text_key] = dtypes.SalienceMap(input_string.split(), scores)

    all_results.append(result)

  return all_results
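

# Hedged sketch of the `new_example` helper used by _predict_proba above: LIME
# perturbs one text field at a time, so each perturbed string is dropped into
# a copy of the original example with every other field left unchanged. This
# is an assumption about the helper, not the original implementation.
def new_example(original: JsonDict, field: str, new_value: str) -> JsonDict:
  example = dict(original)
  example[field] = new_value
  return example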
def generate(self,
             example: JsonDict,
             model: lit_model.Model,
             dataset: lit_dataset.Dataset,
             config: Optional[JsonDict] = None,
             num_examples: int = 1) -> List[JsonDict]:
  """Use gradient to find/substitute the token with largest impact on loss."""
  del dataset  # Unused.

  assert model is not None, "Please provide a model for this generator."
  logging.info(r"W3lc0m3 t0 H0tFl1p \o/")
  logging.info("Original example: %r", example)

  # Find gradient fields to use for HotFlip
  input_spec = model.input_spec()
  output_spec = model.output_spec()
  grad_fields = self.find_fields(output_spec)
  logging.info("Found gradient fields for HotFlip use: %s", str(grad_fields))
  if len(grad_fields) == 0:  # pylint: disable=g-explicit-length-test
    logging.info("No gradient fields found. Cannot use HotFlip. :-(")
    return []  # Cannot generate examples without gradients.

  # Get model outputs.
  logging.info("Performing a forward/backward pass on the input example.")
  model_output = model.predict_single(example)
  logging.info(model_output.keys())

  # Get model word embeddings and vocab.
  inv_vocab, embed = model.get_embedding_table()
  assert len(inv_vocab) == embed.shape[0], "Vocab/embeddings size mismatch."
  logging.info("Vocab size: %d, Embedding size: %r", len(inv_vocab),
               embed.shape)

  # Perform a flip in each sequence for which we have gradients (separately).
  # Each sequence may give rise to multiple new examples, depending on how
  # many words we flip.
  # TODO(lit-team): make configurable how many new examples are desired.
  # TODO(lit-team): use only 1 sequence as input (configurable in UI).
  new_examples = []
  for grad_field in grad_fields:
    # Get the tokens and their gradient vectors.
    token_field = output_spec[grad_field].align  # pytype: disable=attribute-error
    tokens = model_output[token_field]
    grads = model_output[grad_field]

    # Identify the token with the largest gradient norm.
    # TODO(lit-team): consider normalizing across all grad fields or just
    # across each one individually.
    grad_norm = np.linalg.norm(grads, axis=1)
    grad_norm = grad_norm / np.sum(grad_norm)  # Match grad attribution value.

    # Get a list of indices of input tokens, sorted by norm, highest first.
    sorted_by_grad_norm = np.argsort(grad_norm)[::-1]

    for i in range(min(num_examples, len(tokens))):
      token_id = sorted_by_grad_norm[i]
      logging.info("Selected token: %s (pos=%d) with gradient norm %f",
                   tokens[token_id], token_id, grad_norm[token_id])
      token_grad = grads[token_id]

      # Take dot product with all word embeddings. Get largest value.
      scores = np.dot(embed, token_grad)

      # TODO(lit-team): Can add criteria to the winner e.g. cosine distance.
      winner = np.argmax(scores)
      logging.info("Replacing [%s] (pos=%d) with option %d: [%s] (id=%d)",
                   tokens[token_id], token_id, i, inv_vocab[winner], winner)

      # Create a new input to the model.
      # TODO(iftenney, bastings): enforce somewhere that this field has the
      # same name in the input and output specs.
      input_token_field = token_field
      input_text_field = input_spec[input_token_field].parent  # pytype: disable=attribute-error
      new_example = copy.deepcopy(example)
      modified_tokens = copy.copy(tokens)
      modified_tokens[token_id] = inv_vocab[winner]
      new_example[input_token_field] = modified_tokens
      # TODO(iftenney, bastings): call a model-provided detokenizer here?
      # Though in general tokenization isn't invertible and it's possible for
      # HotFlip to produce wordpiece sequences that don't correspond to any
      # input string.
      new_example[input_text_field] = " ".join(modified_tokens)

      # Predict a new label for this example.
      new_output = model.predict_single(new_example)

      # Update label if multi-class prediction.
      # TODO(lit-dev): provide a general system for handling labels on
      # generated examples.
      for pred_key, pred_type in model.output_spec().items():
        if isinstance(pred_type, types.MulticlassPreds):
          probabilities = new_output[pred_key]
          prediction = np.argmax(probabilities)
          label_key = output_spec[pred_key].parent
          label_names = output_spec[pred_key].vocab
          new_label = label_names[prediction]
          new_example[label_key] = new_label
          logging.info("Updated example with new label: %s", new_label)

      new_examples.append(new_example)

  return new_examples
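

# Hedged, standalone sketch of the HotFlip selection rule implemented above:
# pick the token whose gradient has the largest L2 norm, then replace it with
# the vocabulary entry whose embedding has the largest dot product with that
# token's gradient. Function and variable names are illustrative.
from typing import List

import numpy as np


def hotflip_pick(tokens: List[str], grads: np.ndarray, embed: np.ndarray,
                 inv_vocab: List[str]) -> List[str]:
  """grads: <float32>[num_tokens, emb]; embed: <float32>[vocab_size, emb]."""
  grad_norm = np.linalg.norm(grads, axis=1)
  token_id = int(np.argmax(grad_norm))     # Position with the largest norm.
  scores = np.dot(embed, grads[token_id])  # <float32>[vocab_size]
  winner = int(np.argmax(scores))          # Replacement vocabulary index.
  flipped = list(tokens)
  flipped[token_id] = inv_vocab[winner]
  return flipped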