def test_normalize_scores(self, scores, make_positive):
  """Check if the scores sum to 1 after taking their absolute values."""
  original_min = np.min(scores)
  scores = utils.normalize_scores(scores, make_positive=make_positive)
  self.assertAllClose(1.0, np.abs(scores).sum(-1))
  if not make_positive:
    # Keep the sign of originally negative values.
    self.assertLessEqual(np.min(scores), np.max([0.0, original_min]))
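
# A minimal sketch of the behavior the test above relies on, assuming
# normalize_scores divides by the L1 norm (sum of absolute values) and,
# when make_positive=True, also takes the absolute value of each score.
# This helper is illustrative only and is not the LIT implementation.
import numpy as np


def normalize_scores_sketch(scores, make_positive=False):
  """Illustrative L1 normalization: np.abs(result).sum(-1) == 1."""
  scores = np.asarray(scores, dtype=np.float32)
  if make_positive:
    scores = np.abs(scores)
  total = np.abs(scores).sum(-1, keepdims=True)
  return scores / (total + 1e-12)  # Guard against all-zero inputs.

# e.g. normalize_scores_sketch([2.0, -1.0, 1.0]) -> [0.5, -0.25, 0.25]
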
def run(self,
        inputs: List[JsonDict],
        model: lit_model.Model,
        dataset: lit_dataset.Dataset,
        model_outputs: Optional[List[JsonDict]] = None,
        config: Optional[JsonDict] = None) -> Optional[List[JsonDict]]:
  """Run this component, given a model and input(s)."""
  if not inputs:
    return

  # Find keys of input (text) segments to explain.
  # Search in the input spec, since it's only useful to look at ones that are
  # used by the model.
  text_keys = utils.find_spec_keys(model.input_spec(), types.TextSegment)
  if not text_keys:
    logging.warning('LEMON requires text inputs.')
    return None
  logging.info('Found text fields for LEMON attribution: %s', str(text_keys))

  pred_key = config['pred_key']
  output_probs = np.array([output[pred_key] for output in model_outputs])

  # Explain the input given counterfactuals.

  # Dict[field name -> interpretations]
  result = {}

  # Explain each text segment in the input, keeping the others constant.
  for text_key in text_keys:
    sentences = [item[text_key] for item in inputs]
    input_to_prediction = dict(zip(sentences, output_probs))

    input_string = sentences[0]
    counterfactuals = sentences[1:]

    # Remove duplicate counterfactuals.
    counterfactuals = list(set(counterfactuals))

    logging.info('Explaining: %s', input_string)

    predict_proba = make_predict_fn(input_to_prediction)

    # Perturbs the input string, gets model predictions, fits linear model.
    explanation = lemon.explain(
        input_string,
        counterfactuals,
        predict_proba,
        class_to_explain=config['class_to_explain'],
        lowercase_tokens=config['lowercase_tokens'])

    scores = np.array(explanation.feature_importance)

    # Normalize feature values.
    scores = citrus_utils.normalize_scores(scores)

    result[text_key] = dtypes.TokenSalience(input_string.split(), scores)

  return [result]
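
# `make_predict_fn` is used above but not shown in this excerpt. A plausible
# sketch, assuming it simply returns the cached predictions for the original
# sentence and its counterfactuals; the name and exact behavior here are
# assumptions, not the LIT implementation.
import numpy as np


def make_predict_fn_sketch(input_to_prediction):
  """Returns a predict function backed by precomputed model outputs."""
  def predict_proba(sentences):
    # <float32>[len(sentences), num_classes]
    return np.stack([input_to_prediction[s] for s in sentences], axis=0)
  return predict_proba
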
def _interpret(self, grads: np.ndarray, embs: np.ndarray):
  assert grads.shape == embs.shape

  # Dot product of gradients and embeddings.
  # <float32>[num_tokens]
  grad_dot_input = np.sum(grads * embs, axis=-1)
  scores = citrus_utils.normalize_scores(grad_dot_input)
  return scores
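
# Worked toy example of the gradient-times-input computation above, with
# illustrative shapes (<float32>[num_tokens=2, emb_size=3]); the values are
# made up for the sketch.
import numpy as np

grads = np.array([[1.0, 0.0, -1.0], [0.5, 0.5, 0.0]], dtype=np.float32)
embs = np.array([[2.0, 3.0, 1.0], [1.0, 1.0, 4.0]], dtype=np.float32)
grad_dot_input = np.sum(grads * embs, axis=-1)  # -> [1.0, 1.0]
# After L1 normalization, the per-token scores would be [0.5, 0.5].
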
def get_salience_result(self, model_input: JsonDict, model: lit_model.Model,
                        model_output: JsonDict, grad_fields: List[Text]):
  result = {}

  output_spec = model.output_spec()
  # We ensure that the embedding and gradient class fields are present in the
  # model's input spec in find_fields().
  embeddings_fields = [
      cast(types.TokenGradients, output_spec[grad_field]).grad_for
      for grad_field in grad_fields
  ]

  # The gradient class input is used to specify the target class of the
  # gradient calculation (if unspecified, this option defaults to the argmax,
  # which could flip between interpolated inputs).
  grad_class_key = cast(types.TokenGradients,
                        output_spec[grad_fields[0]]).grad_target
  # TODO(b/168042999): Add option to specify the class to explain in the UI.
  grad_class = model_output[grad_class_key]

  interpolated_inputs = {}
  all_embeddings = []
  all_baselines = []
  for embed_field in embeddings_fields:
    # <float32>[num_tokens, emb_size]
    embeddings = np.array(model_output[embed_field])
    all_embeddings.append(embeddings)

    # Starts with baseline of zeros. <float32>[num_tokens, emb_size]
    baseline = self.get_baseline(embeddings)
    all_baselines.append(baseline)

    # Get interpolated inputs from baseline to original embedding.
    # <float32>[interpolation_steps, num_tokens, emb_size]
    interpolated_inputs[embed_field] = self.get_interpolated_inputs(
        baseline, embeddings, self.interpolation_steps)

  # Create model inputs and populate embedding field(s).
  inputs_with_embeds = []
  for i in range(self.interpolation_steps):
    input_copy = model_input.copy()
    # Interpolates embeddings for all inputs simultaneously.
    for embed_field in embeddings_fields:
      # <float32>[num_tokens, emb_size]
      input_copy[embed_field] = interpolated_inputs[embed_field][i]
      input_copy[grad_class_key] = grad_class

    inputs_with_embeds.append(input_copy)
  embed_outputs = model.predict(inputs_with_embeds)

  # Create list with concatenated gradients for each interpolated input.
  gradients = []
  for o in embed_outputs:
    # <float32>[total_num_tokens, emb_size]
    interp_gradients = np.concatenate([o[field] for field in grad_fields])
    gradients.append(interp_gradients)
  # <float32>[interpolation_steps, total_num_tokens, emb_size]
  path_gradients = np.stack(gradients, axis=0)

  # Calculate the integral.
  # <float32>[total_num_tokens, emb_size]
  integral = self.estimate_integral(path_gradients)

  # <float32>[total_num_tokens, emb_size]
  concat_embeddings = np.concatenate(all_embeddings)

  # <float32>[total_num_tokens, emb_size]
  concat_baseline = np.concatenate(all_baselines)

  # <float32>[total_num_tokens, emb_size]
  integrated_gradients = integral * (np.array(concat_embeddings) -
                                     np.array(concat_baseline))
  # Dot product of integral values and (embeddings - baseline).
  # <float32>[total_num_tokens]
  attributions = np.sum(integrated_gradients, axis=-1)

  # TODO(b/168042999): Make normalization customizable in the UI.
  # <float32>[total_num_tokens]
  scores = citrus_utils.normalize_scores(attributions)

  for grad_field in grad_fields:
    # Format as salience map result.
    token_field = cast(types.TokenGradients, output_spec[grad_field]).align
    tokens = model_output[token_field]

    # Only use the scores that correspond to the tokens in this grad_field.
    # The gradients for all input embeddings were concatenated in the order
    # of the grad fields, so they can be sliced out in the same order.
    sliced_scores = scores[:len(tokens)]  # <float32>[num_tokens in field]
    scores = scores[len(tokens):]  # <float32>[num_remaining_tokens]

    assert len(tokens) == len(sliced_scores)
    result[grad_field] = dtypes.SalienceMap(tokens, sliced_scores)

  return result
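
# The helpers used above (get_interpolated_inputs, estimate_integral) are not
# shown in this excerpt. A minimal sketch of standard integrated-gradients
# machinery, assuming a linear interpolation path and a trapezoidal rule for
# the path integral; the real LIT implementations may differ.
import numpy as np


def get_interpolated_inputs_sketch(baseline, target, num_steps):
  """<float32>[num_steps, num_tokens, emb_size] points along the path."""
  alphas = np.linspace(0.0, 1.0, num_steps).reshape(-1, 1, 1)
  delta = (target - baseline)[np.newaxis, ...]
  return baseline[np.newaxis, ...] + alphas * delta


def estimate_integral_sketch(path_gradients):
  """Approximates the path integral by averaging gradients over steps."""
  # Trapezoidal rule: average adjacent steps, then take the mean over steps.
  trapezoid = (path_gradients[:-1] + path_gradients[1:]) / 2.0
  return np.mean(trapezoid, axis=0)  # <float32>[total_num_tokens, emb_size]
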
def run(
    self,
    inputs: List[JsonDict],
    model: lit_model.Model,
    dataset: lit_dataset.Dataset,
    model_outputs: Optional[List[JsonDict]] = None,
    config: Optional[JsonDict] = None,
) -> Optional[List[JsonDict]]:
  """Run this component, given a model and input(s)."""
  class_to_explain = int(config[CLASS_KEY]) if config else CLASS_DEFAULT
  kernel_width = int(
      config[KERNEL_WIDTH_KEY]) if config else KERNEL_WIDTH_DEFAULT
  mask_string = config[MASK_KEY] if config else MASK_DEFAULT
  num_samples = int(
      config[NUM_SAMPLES_KEY]) if config else NUM_SAMPLES_DEFAULT
  seed = config[SEED_KEY] if config else SEED_DEFAULT

  # Find keys of input (text) segments to explain.
  # Search in the input spec, since it's only useful to look at ones that are
  # used by the model.
  text_keys = utils.find_spec_keys(model.input_spec(), types.TextSegment)
  if not text_keys:
    logging.warning('LIME requires text inputs.')
    return None
  logging.info('Found text fields for LIME attribution: %s', str(text_keys))

  # Find the key of output probabilities field(s).
  pred_keys = utils.find_spec_keys(
      model.output_spec(), (types.MulticlassPreds, types.RegressionScore))
  if not pred_keys:
    logging.warning('LIME did not find any supported output fields.')
    return None

  pred_key = pred_keys[0]  # TODO(lit-dev): configure which prob field to use.

  all_results = []

  # Explain each input.
  for input_ in inputs:
    # Dict[field name -> interpretations]
    result = {}
    predict_fn = functools.partial(
        _predict_fn, model=model, original_example=input_, pred_key=pred_key)

    # Explain each text segment in the input, keeping the others constant.
    for text_key in text_keys:
      input_string = input_[text_key]
      if not input_string:
        logging.info('Could not explain empty string for %s', text_key)
        continue
      logging.info('Explaining: %s', input_string)

      # Perturbs the input string, gets model predictions, fits linear model.
      explanation = lime.explain(
          sentence=input_string,
          predict_fn=functools.partial(predict_fn, text_key=text_key),
          # `class_to_explain` is ignored when predict_fn output is a scalar.
          class_to_explain=class_to_explain,  # Index of the class to explain.
          num_samples=num_samples,
          tokenizer=str.split,
          mask_token=mask_string,
          kernel=functools.partial(
              lime.exponential_kernel, kernel_width=kernel_width),
          seed=seed)

      # Turn the LIME explanation into a list following original word order.
      scores = explanation.feature_importance
      # TODO(lit-dev): Move score normalization to the UI.
      scores = citrus_util.normalize_scores(scores)
      result[text_key] = dtypes.SalienceMap(input_string.split(), scores)

    all_results.append(result)

  return all_results
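
# `_predict_fn` above is defined elsewhere in this module. A rough sketch,
# assuming it substitutes each perturbed sentence into a copy of the original
# example, runs the model, and returns the predictions for `pred_key`; the
# exact signature and batching behavior are assumptions.
import numpy as np


def _predict_fn_sketch(sentences, model, original_example, pred_key, text_key):
  perturbed = []
  for sentence in sentences:
    example = dict(original_example)
    example[text_key] = sentence
    perturbed.append(example)
  outputs = list(model.predict(perturbed))
  # <float32>[len(sentences), num_classes], or [len(sentences)] for regression.
  return np.array([output[pred_key] for output in outputs])
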
def run(
    self,
    inputs: List[JsonDict],
    model: lit_model.Model,
    dataset: lit_dataset.Dataset,
    model_outputs: Optional[List[JsonDict]] = None,
    config: Optional[JsonDict] = None,
) -> Optional[List[JsonDict]]:
  """Run this component, given a model and input(s)."""
  config_defaults = {k: v.default for k, v in self.config_spec().items()}
  config = dict(config_defaults, **(config or {}))  # update and return

  provided_class_to_explain = int(config[CLASS_KEY])
  kernel_width = int(config[KERNEL_WIDTH_KEY])
  num_samples = int(config[NUM_SAMPLES_KEY])
  mask_string = config[MASK_KEY]
  # pylint: disable=g-explicit-bool-comparison
  seed = int(config[SEED_KEY]) if config[SEED_KEY] != '' else None
  # pylint: enable=g-explicit-bool-comparison

  # Find keys of input (text) segments to explain.
  # Search in the input spec, since it's only useful to look at ones that are
  # used by the model.
  text_keys = utils.find_spec_keys(model.input_spec(), types.TextSegment)
  if not text_keys:
    logging.warning('LIME requires text inputs.')
    return None
  logging.info('Found text fields for LIME attribution: %s', str(text_keys))

  # Find the key of output probabilities field(s).
  pred_keys = utils.find_spec_keys(
      model.output_spec(),
      (types.MulticlassPreds, types.RegressionScore,
       types.SparseMultilabelPreds))
  if not pred_keys:
    logging.warning('LIME did not find any supported output fields.')
    return None

  pred_key = config[TARGET_HEAD_KEY] or pred_keys[0]

  all_results = []

  # Explain each input.
  for input_ in inputs:
    # Dict[field name -> interpretations]
    result = {}
    predict_fn = functools.partial(
        _predict_fn,
        model=model,
        original_example=input_,
        pred_key=pred_key,
        pred_type_info=model.output_spec()[pred_key])
    class_to_explain = get_class_to_explain(provided_class_to_explain, model,
                                            pred_key, input_)

    # Explain each text segment in the input, keeping the others constant.
    for text_key in text_keys:
      input_string = input_[text_key]
      if not input_string:
        logging.info('Could not explain empty string for %s', text_key)
        continue
      logging.info('Explaining: %s', input_string)

      # Perturbs the input string, gets model predictions, fits linear model.
      explanation = lime.explain(
          sentence=input_string,
          predict_fn=functools.partial(predict_fn, text_key=text_key),
          # `class_to_explain` is ignored when predict_fn output is a scalar.
          class_to_explain=class_to_explain,
          num_samples=num_samples,
          tokenizer=str.split,
          mask_token=mask_string,
          kernel=functools.partial(
              lime.exponential_kernel, kernel_width=kernel_width),
          seed=seed)

      # Turn the LIME explanation into a list following original word order.
      scores = explanation.feature_importance
      # TODO(lit-dev): Move score normalization to the UI.
      scores = citrus_util.normalize_scores(scores)
      result[text_key] = dtypes.TokenSalience(input_string.split(), scores)

    all_results.append(result)

  return all_results
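
# `get_class_to_explain` above is defined elsewhere in this module. A hedged
# sketch, assuming the convention that a provided class of -1 means "explain
# the model's argmax class for this example"; the fallback logic and signature
# here are assumptions, not the LIT implementation.
import numpy as np


def get_class_to_explain_sketch(provided_class, model, pred_key, example):
  if provided_class >= 0:
    return provided_class
  # Fall back to the highest-probability class for this example.
  output = list(model.predict([example]))[0]
  return int(np.argmax(output[pred_key]))
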
def test_explain_matches_original_lime(self, sentence, num_samples,
                                       num_classes, class_to_explain):
  """Tests if Citrus LIME matches the original implementation."""
  list('abcdefghijklmnopqrstuvwxyz')
  # Assign some weight to each token a-z.
  # Each token contributes positively/negatively to the prediction.
  rs = np.random.RandomState(seed=0)
  token_weights = {token: rs.normal() for token in sentence.split()}
  token_weights[lime.DEFAULT_MASK_TOKEN] = 0.

  def _predict_fn(sentences):
    """Mock prediction function."""
    rs = np.random.RandomState(seed=0)
    predictions = []
    for sentence in sentences:
      probs = rs.normal(0., 0.1, size=num_classes)
      # To check if LIME finds the right positive/negative correlations.
      for token in sentence.split():
        probs[class_to_explain] += token_weights[token]
      predictions.append(probs)
    return np.stack(predictions, axis=0)

  # Explain the prediction using Citrus LIME.
  explanation = lime.explain(
      sentence,
      _predict_fn,
      class_to_explain=class_to_explain,
      num_samples=num_samples,
      tokenizer=str.split,
      mask_token=lime.DEFAULT_MASK_TOKEN,
      kernel=functools.partial(
          lime.exponential_kernel, kernel_width=lime.DEFAULT_KERNEL_WIDTH))
  scores = explanation.feature_importance  # <float32>[seq_len]
  scores = utils.normalize_scores(scores, make_positive=False)

  # Explain the prediction using original LIME.
  original_lime_explainer = original_lime.LimeTextExplainer(
      class_names=map(str, np.arange(num_classes)),
      mask_string=lime.DEFAULT_MASK_TOKEN,
      kernel_width=lime.DEFAULT_KERNEL_WIDTH,
      split_expression=str.split,
      bow=False)
  num_features = len(sentence.split())
  original_explanation = original_lime_explainer.explain_instance(
      sentence,
      _predict_fn,
      labels=(class_to_explain,),
      num_features=num_features,
      num_samples=num_samples)

  # original_explanation.local_exp is a dict that has a key class_to_explain,
  # which gives a sequence of (index, score) pairs.
  # We convert it to an array <float32>[seq_len] with a score per position.
  original_scores = np.zeros(num_features)
  for index, score in original_explanation.local_exp[class_to_explain]:
    original_scores[index] = score
  original_scores = utils.normalize_scores(
      original_scores, make_positive=False)

  # Test that Citrus LIME and original LIME match.
  np.testing.assert_allclose(scores, original_scores, atol=0.01)
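
# Both explainers above weight perturbed samples by an exponential kernel over
# the distance to the original sentence. A sketch of the usual form; the exact
# citrus/LIME definitions may differ in minor details.
import numpy as np


def exponential_kernel_sketch(distance, kernel_width):
  """Similarity weight in (0, 1]; closer perturbations get higher weight."""
  return np.sqrt(np.exp(-(distance ** 2) / kernel_width ** 2))
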