def run(self,
        inputs: List[JsonDict],
        model: lit_model.Model,
        dataset: lit_dataset.Dataset,
        model_outputs: Optional[List[JsonDict]] = None,
        config: Optional[JsonDict] = None) -> Optional[List[JsonDict]]:
  """Run this component, given a model and input(s).

  Computes a per-token salience map for every gradient-aligned field in the
  model's output spec, for each input example.

  Args:
    inputs: Input examples to explain.
    model: The model to interpret.
    dataset: Dataset the inputs come from (unused here, part of the
      component interface).
    model_outputs: Optional cached predictions for `inputs`; computed via
      `model.predict(inputs)` when not provided.
    config: Optional configuration (unused by this component).

  Returns:
    One dict per input, mapping each gradient field name to a
    `dtypes.SalienceMap`, or None if the model exposes no gradient fields.
  """
  # Find gradient fields to interpret.
  output_spec = model.output_spec()
  grad_fields = self.find_fields(output_spec)
  logging.info('Found fields for gradient attribution: %s', str(grad_fields))
  # Idiomatic emptiness check; no pylint suppression needed.
  if not grad_fields:
    return None

  # Run model only if the caller did not supply cached predictions.
  if model_outputs is None:
    model_outputs = list(model.predict(inputs))
  assert len(model_outputs) == len(inputs)

  all_results = []
  for o in model_outputs:
    # Dict[field name -> interpretations]
    result = {}
    for grad_field in grad_fields:
      # Each TokenGradients field declares which token field it aligns to.
      token_field = cast(types.TokenGradients, output_spec[grad_field]).align
      tokens = o[token_field]
      scores = self._interpret(o[grad_field], tokens)
      result[grad_field] = dtypes.SalienceMap(tokens, scores)
    all_results.append(result)
  return all_results
def run(self,
        inputs: List[JsonDict],
        model: lit_model.Model,
        dataset: lit_dataset.Dataset,
        model_outputs: Optional[List[JsonDict]] = None,
        config: Optional[JsonDict] = None) -> Optional[List[JsonDict]]:
  """Run this component, given a model and input(s).

  The first input is treated as the example to explain; the remaining inputs
  are counterfactuals of it. LEMON fits a linear model over the
  counterfactuals' predictions to attribute salience to the first example's
  tokens.

  Args:
    inputs: The example to explain followed by its counterfactuals.
    model: The model to interpret.
    dataset: Dataset the inputs come from (unused here, part of the
      component interface).
    model_outputs: Optional cached predictions for `inputs`; computed via
      `model.predict(inputs)` when not provided.
    config: Must contain 'pred_key', 'class_to_explain' and
      'lowercase_tokens'.

  Returns:
    A single-element list with a dict mapping each text field name to a
    `dtypes.SalienceMap`, or None if there are no inputs or no text fields.
  """
  if not inputs:
    return None

  # Find keys of input (text) segments to explain.
  # Search in the input spec, since it's only useful to look at ones that are
  # used by the model.
  text_keys = utils.find_spec_keys(model.input_spec(), types.TextSegment)
  if not text_keys:
    logging.warning('LEMON requires text inputs.')
    return None
  logging.info('Found text fields for LEMON attribution: %s', str(text_keys))

  pred_key = config['pred_key']
  # Bug fix: run the model when predictions were not supplied, instead of
  # crashing on `model_outputs=None` (which is the parameter's default).
  if model_outputs is None:
    model_outputs = list(model.predict(inputs))
  output_probs = np.array([output[pred_key] for output in model_outputs])

  # Explain the input given counterfactuals.

  # Dict[field name -> interpretations]
  result = {}

  # Explain each text segment in the input, keeping the others constant.
  for text_key in text_keys:
    sentences = [item[text_key] for item in inputs]
    input_to_prediction = dict(zip(sentences, output_probs))

    input_string = sentences[0]
    counterfactuals = sentences[1:]

    # Remove duplicate counterfactuals, preserving first-seen order so
    # results are deterministic (list(set(...)) depends on hash seed).
    counterfactuals = list(dict.fromkeys(counterfactuals))

    logging.info('Explaining: %s', input_string)

    predict_proba = make_predict_fn(input_to_prediction)

    # Perturbs the input string, gets model predictions, fits linear model.
    explanation = lemon.explain(
        input_string,
        counterfactuals,
        predict_proba,
        class_to_explain=config['class_to_explain'],
        lowercase_tokens=config['lowercase_tokens'])

    scores = np.array(explanation.feature_importance)

    # Normalize feature values.
    scores = citrus_utils.normalize_scores(scores)

    result[text_key] = dtypes.SalienceMap(input_string.split(), scores)

  return [result]
def run(
    self,
    inputs: List[JsonDict],
    model: lit_model.Model,
    dataset: lit_dataset.Dataset,
    model_outputs: Optional[List[JsonDict]] = None,
    config: Optional[JsonDict] = None,
    kernel_width: int = 25,  # TODO(lit-dev): make configurable in UI.
    mask_string: str = '[MASK]',  # TODO(lit-dev): make configurable in UI.
    num_samples: int = 256,  # TODO(lit-dev): make configurable in UI.
) -> Optional[List[JsonDict]]:
  """Run this component, given a model and input(s)."""
  # Only text fields present in the input spec are worth explaining, since
  # those are the ones the model actually consumes.
  text_keys = utils.find_spec_keys(model.input_spec(), types.TextSegment)
  if not text_keys:
    logging.warning('LIME requires text inputs.')
    return None
  logging.info('Found text fields for LIME attribution: %s', str(text_keys))

  # Locate the multiclass probability output to explain.
  pred_keys = utils.find_spec_keys(model.output_spec(), types.MulticlassPreds)
  if not pred_keys:
    logging.warning('LIME did not find a multi-class predictions field.')
    return None

  # TODO(lit-dev): configure which prob field to use.
  pred_key = pred_keys[0]
  pred_spec = cast(types.MulticlassPreds, model.output_spec()[pred_key])
  label_names = pred_spec.vocab

  # One shared explainer instance for all examples. bow=False masks inputs,
  # instead of deleting them entirely.
  explainer = lime_text.LimeTextExplainer(
      class_names=label_names,
      split_expression=str.split,
      kernel_width=kernel_width,
      mask_string=mask_string,  # This is the string used to mask words.
      bow=False)

  all_results = []
  for example in inputs:
    # Dict[field name -> interpretations]
    field_salience = {}
    # Explain each text segment in the input, keeping the others constant.
    for text_key in text_keys:
      text_to_explain = example[text_key]
      logging.info('Explaining: %s', text_to_explain)

      def _predict_proba(strings: List[Text]):
        """Given raw strings, return probabilities. Used by `explainer`."""
        perturbed = [new_example(example, text_key, s) for s in strings]
        outputs = model.predict(perturbed)
        # <float32>[len(strings), num_labels]
        return np.array([output[pred_key] for output in outputs])

      # Use the number of words as the number of features.
      token_count = len(text_to_explain.split())

      # Perturbs the input string, gets model predictions, fits linear model.
      explanation = explainer.explain_instance(
          text_to_explain,
          _predict_proba,
          num_features=token_count,
          num_samples=num_samples)

      # Turn the LIME explanation into a list following original word order.
      scores = explanation_to_array(explanation)
      field_salience[text_key] = dtypes.SalienceMap(
          text_to_explain.split(), scores)

    all_results.append(field_salience)
  return all_results
def get_salience_result(self, model_input: JsonDict, model: lit_model.Model,
                        model_output: JsonDict, grad_fields: List[Text]):
  """Computes integrated-gradients salience for one example.

  Interpolates the input embeddings from a baseline to the actual embeddings,
  accumulates gradients along that path, and converts the result into one
  salience map per gradient field.

  Args:
    model_input: The example being explained.
    model: The model to interpret; must accept embedding overrides.
    model_output: The model's prediction for `model_input`, containing the
      embeddings, target class, and aligned tokens read below.
    grad_fields: Names of TokenGradients fields in the model's output spec.

  Returns:
    Dict mapping each grad field name to a `dtypes.SalienceMap`.
  """
  result = {}

  output_spec = model.output_spec()
  # We ensure that the embedding and gradient class fields are present in the
  # model's input spec in find_fields().
  embeddings_fields = [
      cast(types.TokenGradients, output_spec[grad_field]).grad_for
      for grad_field in grad_fields
  ]

  # The gradient class input is used to specify the target class of the
  # gradient calculation (if unspecified, this option defaults to the argmax,
  # which could flip between interpolated inputs).
  grad_class_key = cast(types.TokenGradients,
                        output_spec[grad_fields[0]]).grad_target
  # TODO(b/168042999): Add option to specify the class to explain in the UI.
  grad_class = model_output[grad_class_key]

  interpolated_inputs = {}
  all_embeddings = []
  all_baselines = []
  for embed_field in embeddings_fields:
    # <float32>[num_tokens, emb_size]
    embeddings = np.array(model_output[embed_field])
    all_embeddings.append(embeddings)

    # Starts with baseline of zeros. <float32>[num_tokens, emb_size]
    baseline = self.get_baseline(embeddings)
    all_baselines.append(baseline)

    # Get interpolated inputs from baseline to original embedding.
    # <float32>[interpolation_steps, num_tokens, emb_size]
    interpolated_inputs[embed_field] = self.get_interpolated_inputs(
        baseline, embeddings, self.interpolation_steps)

  # Create model inputs and populate embedding field(s).
  inputs_with_embeds = []
  for i in range(self.interpolation_steps):
    input_copy = model_input.copy()
    # Interpolates embeddings for all inputs simultaneously.
    for embed_field in embeddings_fields:
      # <float32>[num_tokens, emb_size]
      input_copy[embed_field] = interpolated_inputs[embed_field][i]
      # Pin the target class so it cannot flip between interpolation steps.
      input_copy[grad_class_key] = grad_class

    inputs_with_embeds.append(input_copy)
  embed_outputs = model.predict(inputs_with_embeds)

  # Create list with concatenated gradients for each interpolate input.
  gradients = []
  for o in embed_outputs:
    # <float32>[total_num_tokens, emb_size]
    interp_gradients = np.concatenate([o[field] for field in grad_fields])
    gradients.append(interp_gradients)
  # <float32>[interpolation_steps, total_num_tokens, emb_size]
  path_gradients = np.stack(gradients, axis=0)

  # Calculate integral
  # <float32>[total_num_tokens, emb_size]
  integral = self.estimate_integral(path_gradients)

  # <float32>[total_num_tokens, emb_size]
  concat_embeddings = np.concatenate(all_embeddings)

  # <float32>[total_num_tokens, emb_size]
  concat_baseline = np.concatenate(all_baselines)

  # <float32>[total_num_tokens, emb_size]
  integrated_gradients = integral * (np.array(concat_embeddings) -
                                     np.array(concat_baseline))
  # Dot product of integral values and (embeddings - baseline).
  # <float32>[total_num_tokens]
  attributions = np.sum(integrated_gradients, axis=-1)

  # TODO(b/168042999): Make normalization customizable in the UI.
  # <float32>[total_num_tokens]
  scores = citrus_utils.normalize_scores(attributions)

  for grad_field in grad_fields:
    # Format as salience map result.
    token_field = cast(types.TokenGradients, output_spec[grad_field]).align
    tokens = model_output[token_field]

    # Only use the scores that correspond to the tokens in this grad_field.
    # The gradients for all input embeddings were concatenated in the order
    # of the grad fields, so they can be sliced out in the same order.
    sliced_scores = scores[:len(tokens)]  # <float32>[num_tokens in field]
    scores = scores[len(tokens):]  # <float32>[num_remaining_tokens]

    assert len(tokens) == len(sliced_scores)
    result[grad_field] = dtypes.SalienceMap(tokens, sliced_scores)
  return result
def run(
    self,
    inputs: List[JsonDict],
    model: lit_model.Model,
    dataset: lit_dataset.Dataset,
    model_outputs: Optional[List[JsonDict]] = None,
    config: Optional[JsonDict] = None,
) -> Optional[List[JsonDict]]:
  """Run this component, given a model and input(s).

  Args:
    inputs: Input examples to explain.
    model: The model to interpret.
    dataset: Dataset the inputs come from (unused here, part of the
      component interface).
    model_outputs: Optional cached predictions (unused; LIME re-queries the
      model on perturbed inputs).
    config: Optional overrides for class index, kernel width, mask token,
      sample count, and seed; defaults are used for any missing key.

  Returns:
    One dict per input, mapping each text field name to a
    `dtypes.SalienceMap`, or None if no suitable fields are found.
  """
  # Robustness fix: treat `config` as a partial override map. The previous
  # `config[KEY] if config else DEFAULT` pattern raised KeyError whenever a
  # config dict was supplied with only some of the keys.
  config = config or {}
  class_to_explain = int(config.get(CLASS_KEY, CLASS_DEFAULT))
  kernel_width = int(config.get(KERNEL_WIDTH_KEY, KERNEL_WIDTH_DEFAULT))
  mask_string = config.get(MASK_KEY, MASK_DEFAULT)
  num_samples = int(config.get(NUM_SAMPLES_KEY, NUM_SAMPLES_DEFAULT))
  seed = config.get(SEED_KEY, SEED_DEFAULT)

  # Find keys of input (text) segments to explain.
  # Search in the input spec, since it's only useful to look at ones that are
  # used by the model.
  text_keys = utils.find_spec_keys(model.input_spec(), types.TextSegment)
  if not text_keys:
    logging.warning('LIME requires text inputs.')
    return None
  logging.info('Found text fields for LIME attribution: %s', str(text_keys))

  # Find the key of output probabilities field(s).
  pred_keys = utils.find_spec_keys(
      model.output_spec(), (types.MulticlassPreds, types.RegressionScore))
  if not pred_keys:
    logging.warning('LIME did not find any supported output fields.')
    return None

  pred_key = pred_keys[0]  # TODO(lit-dev): configure which prob field to use.

  all_results = []

  # Explain each input.
  for input_ in inputs:
    # Dict[field name -> interpretations]
    result = {}
    predict_fn = functools.partial(
        _predict_fn, model=model, original_example=input_, pred_key=pred_key)

    # Explain each text segment in the input, keeping the others constant.
    for text_key in text_keys:
      input_string = input_[text_key]
      if not input_string:
        logging.info('Could not explain empty string for %s', text_key)
        continue
      logging.info('Explaining: %s', input_string)

      # Perturbs the input string, gets model predictions, fits linear model.
      explanation = lime.explain(
          sentence=input_string,
          predict_fn=functools.partial(predict_fn, text_key=text_key),
          # `class_to_explain` is ignored when predict_fn output is a scalar.
          class_to_explain=class_to_explain,  # Index of the class to explain.
          num_samples=num_samples,
          tokenizer=str.split,
          mask_token=mask_string,
          kernel=functools.partial(
              lime.exponential_kernel, kernel_width=kernel_width),
          seed=seed)

      # Turn the LIME explanation into a list following original word order.
      scores = explanation.feature_importance
      # TODO(lit-dev): Move score normalization to the UI.
      scores = citrus_util.normalize_scores(scores)
      result[text_key] = dtypes.SalienceMap(input_string.split(), scores)

    all_results.append(result)

  return all_results