def test_explain_regression(self, sentence, num_samples, positive_token,
                            negative_token):
  """Tests explaining regression models that output a scalar per input."""

  def _predict_fn(sentences):
    """Mock prediction function."""
    rs = np.random.RandomState(seed=0)
    predictions = []
    for sentence in sentences:
      output = rs.uniform(-2., 2.)
      # To check if LIME finds the right positive/negative correlations.
      if negative_token in sentence:
        output -= rs.uniform(0., 2.)
      if positive_token in sentence:
        output += rs.uniform(0., 2.)
      predictions.append(output)
    predictions = np.stack(predictions, axis=0)
    return predictions

  explanation = lime.explain(
      sentence, _predict_fn, num_samples=num_samples, tokenizer=str.split)

  self.assertLen(explanation.feature_importance, len(sentence.split()))

  # The positive word should have the highest attribution score.
  positive_token_idx = sentence.split().index(positive_token)
  self.assertEqual(positive_token_idx,
                   np.argmax(explanation.feature_importance))

  # The negative word should have the lowest attribution score.
  negative_token_idx = sentence.split().index(negative_token)
  self.assertEqual(negative_token_idx,
                   np.argmin(explanation.feature_importance))
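# Illustrative usage sketch (not part of the test above): calling lime.explain
# directly on a scalar-output (regression) predict function. The import path
# and the toy `sentiment_score` model are assumptions; only the lime.explain
# call pattern and the `feature_importance` field come from the tests.
import numpy as np
from lit_nlp.components.citrus import lime  # Import path assumed.

def sentiment_score(sentences):
  """Toy regression model: +1 if 'good' appears, -1 if 'bad' appears."""
  return np.array([float('good' in s) - float('bad' in s) for s in sentences])

explanation = lime.explain(
    'this movie is good', sentiment_score, num_samples=500,
    tokenizer=str.split)
print(explanation.feature_importance)  # One score per token; 'good' should rank highest.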
def test_explain_returns_explanation_with_intercept(self):
  """Tests if the explanation contains an intercept value."""

  def _predict_fn(sentences):
    return np.random.uniform(0., 1., [len(list(sentences)), 2])

  explanation = lime.explain('Test sentence', _predict_fn, 1, num_samples=5)
  self.assertNotEqual(explanation.intercept, 0.)
def test_explain_returns_explanation_with_prediction(self):
  """Tests if the explanation contains a prediction."""

  def _predict_fn(sentences):
    return np.random.uniform(0., 1., [len(list(sentences)), 2])

  explanation = lime.explain(
      'Test sentence',
      _predict_fn,
      class_to_explain=1,
      num_samples=5,
      return_prediction=True)
  self.assertIsNotNone(explanation.prediction)
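# Hedged sketch of the explanation object these tests inspect. The field names
# (feature_importance, intercept, prediction) are exactly those accessed in the
# tests; the actual dataclass in the citrus lime module may carry more fields
# and different defaults.
import dataclasses
from typing import Any, Optional
import numpy as np

@dataclasses.dataclass
class Explanation:
  feature_importance: np.ndarray    # <float>[num_tokens], one weight per token.
  intercept: float = 0.             # Intercept of the fitted linear model.
  prediction: Optional[Any] = None  # Model output, when return_prediction=True.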
def test_explain(self, sentence, num_samples, positive_token, negative_token,
                 num_classes, class_to_explain):
  """Tests explaining text classifiers with various output dimensions."""

  def _predict_fn(sentences):
    """Mock prediction function."""
    rs = np.random.RandomState(seed=0)
    predictions = []
    for sentence in sentences:
      probs = rs.uniform(0., 1., num_classes)
      # To check if LIME finds the right positive/negative correlations.
      if negative_token in sentence:
        probs[class_to_explain] = probs[class_to_explain] - 1.
      if positive_token in sentence:
        probs[class_to_explain] = probs[class_to_explain] + 1.
      predictions.append(probs)
    predictions = np.stack(predictions, axis=0)
    if num_classes == 1:
      return np.squeeze(special.expit(predictions), -1)
    else:
      return special.softmax(predictions, axis=-1)

  explanation = lime.explain(
      sentence,
      _predict_fn,
      class_to_explain=class_to_explain,
      num_samples=num_samples,
      tokenizer=str.split)

  self.assertLen(explanation.feature_importance, len(sentence.split()))

  # The positive word should have the highest attribution score.
  positive_token_idx = sentence.split().index(positive_token)
  self.assertEqual(positive_token_idx,
                   np.argmax(explanation.feature_importance))

  # The negative word should have the lowest attribution score.
  negative_token_idx = sentence.split().index(negative_token)
  self.assertEqual(negative_token_idx,
                   np.argmin(explanation.feature_importance))
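# Hedged sketch (not from the source): the test methods above take extra
# arguments such as `sentence` and `num_samples`, which in absl-based test
# suites are normally supplied by a `parameterized` decorator. The class name
# and the single test case below are invented purely to show the wiring.
from absl.testing import parameterized

class LimeTestSketch(parameterized.TestCase):

  @parameterized.parameters(
      # (sentence, num_samples, positive_token, negative_token,
      #  num_classes, class_to_explain)
      ('happy sad movie', 1000, 'happy', 'sad', 2, 1),
  )
  def test_explain(self, sentence, num_samples, positive_token, negative_token,
                   num_classes, class_to_explain):
    ...  # Body as in test_explain above.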
def run(
    self,
    inputs: List[JsonDict],
    model: lit_model.Model,
    dataset: lit_dataset.Dataset,
    model_outputs: Optional[List[JsonDict]] = None,
    config: Optional[JsonDict] = None,
) -> Optional[List[JsonDict]]:
  """Run this component, given a model and input(s)."""
  class_to_explain = int(config[CLASS_KEY]) if config else CLASS_DEFAULT
  kernel_width = int(
      config[KERNEL_WIDTH_KEY]) if config else KERNEL_WIDTH_DEFAULT
  mask_string = config[MASK_KEY] if config else MASK_DEFAULT
  num_samples = int(
      config[NUM_SAMPLES_KEY]) if config else NUM_SAMPLES_DEFAULT
  seed = config[SEED_KEY] if config else SEED_DEFAULT

  # Find keys of input (text) segments to explain.
  # Search in the input spec, since it's only useful to look at ones that are
  # used by the model.
  text_keys = utils.find_spec_keys(model.input_spec(), types.TextSegment)
  if not text_keys:
    logging.warning('LIME requires text inputs.')
    return None
  logging.info('Found text fields for LIME attribution: %s', str(text_keys))

  # Find the key of output probabilities field(s).
  pred_keys = utils.find_spec_keys(
      model.output_spec(), (types.MulticlassPreds, types.RegressionScore))
  if not pred_keys:
    logging.warning('LIME did not find any supported output fields.')
    return None

  pred_key = pred_keys[0]  # TODO(lit-dev): configure which prob field to use.
  all_results = []

  # Explain each input.
  for input_ in inputs:
    # Dict[field name -> interpretations]
    result = {}
    predict_fn = functools.partial(
        _predict_fn, model=model, original_example=input_, pred_key=pred_key)

    # Explain each text segment in the input, keeping the others constant.
    for text_key in text_keys:
      input_string = input_[text_key]
      if not input_string:
        logging.info('Could not explain empty string for %s', text_key)
        continue
      logging.info('Explaining: %s', input_string)

      # Perturbs the input string, gets model predictions, fits linear model.
      explanation = lime.explain(
          sentence=input_string,
          predict_fn=functools.partial(predict_fn, text_key=text_key),
          # `class_to_explain` is ignored when predict_fn output is a scalar.
          class_to_explain=class_to_explain,  # Index of the class to explain.
          num_samples=num_samples,
          tokenizer=str.split,
          mask_token=mask_string,
          kernel=functools.partial(
              lime.exponential_kernel, kernel_width=kernel_width),
          seed=seed)

      # Turn the LIME explanation into a list following original word order.
      scores = explanation.feature_importance
      # TODO(lit-dev): Move score normalization to the UI.
      scores = citrus_util.normalize_scores(scores)
      result[text_key] = dtypes.SalienceMap(input_string.split(), scores)

    all_results.append(result)

  return all_results
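# Hedged reconstruction of the `_predict_fn` helper that run() binds with
# functools.partial above. This is inferred from the call sites only; the real
# helper also post-processes different output types (e.g. MulticlassPreds vs.
# RegressionScore) before returning.
import numpy as np

def _predict_fn_sketch(sentences, model, original_example, pred_key, text_key):
  """Runs the model on copies of the example with `text_key` perturbed."""
  # One modified copy of the original example per perturbed sentence.
  perturbed = [
      dict(original_example, **{text_key: sentence}) for sentence in sentences
  ]
  outputs = list(model.predict(perturbed))
  # Return <float>[num_sentences, ...] for lime.explain to consume.
  return np.stack([np.array(output[pred_key]) for output in outputs])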
def run(
    self,
    inputs: List[JsonDict],
    model: lit_model.Model,
    dataset: lit_dataset.Dataset,
    model_outputs: Optional[List[JsonDict]] = None,
    config: Optional[JsonDict] = None,
) -> Optional[List[JsonDict]]:
  """Run this component, given a model and input(s)."""
  config_defaults = {k: v.default for k, v in self.config_spec().items()}
  config = dict(config_defaults, **(config or {}))  # Merge provided config over defaults.

  provided_class_to_explain = int(config[CLASS_KEY])
  kernel_width = int(config[KERNEL_WIDTH_KEY])
  num_samples = int(config[NUM_SAMPLES_KEY])
  mask_string = config[MASK_KEY]
  # pylint: disable=g-explicit-bool-comparison
  seed = int(config[SEED_KEY]) if config[SEED_KEY] != '' else None
  # pylint: enable=g-explicit-bool-comparison

  # Find keys of input (text) segments to explain.
  # Search in the input spec, since it's only useful to look at ones that are
  # used by the model.
  text_keys = utils.find_spec_keys(model.input_spec(), types.TextSegment)
  if not text_keys:
    logging.warning('LIME requires text inputs.')
    return None
  logging.info('Found text fields for LIME attribution: %s', str(text_keys))

  # Find the key of output probabilities field(s).
  pred_keys = utils.find_spec_keys(
      model.output_spec(),
      (types.MulticlassPreds, types.RegressionScore,
       types.SparseMultilabelPreds))
  if not pred_keys:
    logging.warning('LIME did not find any supported output fields.')
    return None

  pred_key = config[TARGET_HEAD_KEY] or pred_keys[0]
  all_results = []

  # Explain each input.
  for input_ in inputs:
    # Dict[field name -> interpretations]
    result = {}
    predict_fn = functools.partial(
        _predict_fn,
        model=model,
        original_example=input_,
        pred_key=pred_key,
        pred_type_info=model.output_spec()[pred_key])
    class_to_explain = get_class_to_explain(provided_class_to_explain, model,
                                            pred_key, input_)

    # Explain each text segment in the input, keeping the others constant.
    for text_key in text_keys:
      input_string = input_[text_key]
      if not input_string:
        logging.info('Could not explain empty string for %s', text_key)
        continue
      logging.info('Explaining: %s', input_string)

      # Perturbs the input string, gets model predictions, fits linear model.
      explanation = lime.explain(
          sentence=input_string,
          predict_fn=functools.partial(predict_fn, text_key=text_key),
          # `class_to_explain` is ignored when predict_fn output is a scalar.
          class_to_explain=class_to_explain,
          num_samples=num_samples,
          tokenizer=str.split,
          mask_token=mask_string,
          kernel=functools.partial(
              lime.exponential_kernel, kernel_width=kernel_width),
          seed=seed)

      # Turn the LIME explanation into a list following original word order.
      scores = explanation.feature_importance
      # TODO(lit-dev): Move score normalization to the UI.
      scores = citrus_util.normalize_scores(scores)
      result[text_key] = dtypes.TokenSalience(input_string.split(), scores)

    all_results.append(result)

  return all_results
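# Hedged usage sketch (not from the source): invoking the component on one
# example. `LIME`, `my_model`, and `my_dataset` are placeholders; the config
# keys are the module-level constants read by run() above. Any key left out of
# `config` falls back to the default from config_spec().
lime_component = LIME()
results = lime_component.run(
    inputs=[{'sentence': 'a great movie'}],
    model=my_model,      # Any lit_model.Model with a TextSegment input field.
    dataset=my_dataset,  # Required by the signature; not used in the body above.
    config={NUM_SAMPLES_KEY: 256, SEED_KEY: '0'})
# results[0]['sentence'] is a dtypes.TokenSalience: the tokens plus one
# normalized score per token.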
def test_explain_matches_original_lime(self, sentence, num_samples,
                                       num_classes, class_to_explain):
  """Tests if Citrus LIME matches the original implementation."""
  # Assign some weight to each token a-z.
  # Each token contributes positively/negatively to the prediction.
  rs = np.random.RandomState(seed=0)
  token_weights = {token: rs.normal() for token in sentence.split()}
  token_weights[lime.DEFAULT_MASK_TOKEN] = 0.

  def _predict_fn(sentences):
    """Mock prediction function."""
    rs = np.random.RandomState(seed=0)
    predictions = []
    for sentence in sentences:
      probs = rs.normal(0., 0.1, size=num_classes)
      # To check if LIME finds the right positive/negative correlations.
      for token in sentence.split():
        probs[class_to_explain] += token_weights[token]
      predictions.append(probs)
    return np.stack(predictions, axis=0)

  # Explain the prediction using Citrus LIME.
  explanation = lime.explain(
      sentence,
      _predict_fn,
      class_to_explain=class_to_explain,
      num_samples=num_samples,
      tokenizer=str.split,
      mask_token=lime.DEFAULT_MASK_TOKEN,
      kernel=functools.partial(
          lime.exponential_kernel, kernel_width=lime.DEFAULT_KERNEL_WIDTH))
  scores = explanation.feature_importance  # <float32>[seq_len]
  scores = utils.normalize_scores(scores, make_positive=False)

  # Explain the prediction using the original LIME implementation.
  original_lime_explainer = original_lime.LimeTextExplainer(
      class_names=map(str, np.arange(num_classes)),
      mask_string=lime.DEFAULT_MASK_TOKEN,
      kernel_width=lime.DEFAULT_KERNEL_WIDTH,
      split_expression=str.split,
      bow=False)
  num_features = len(sentence.split())
  original_explanation = original_lime_explainer.explain_instance(
      sentence,
      _predict_fn,
      labels=(class_to_explain,),
      num_features=num_features,
      num_samples=num_samples)

  # original_explanation.local_exp is a dict that has a key class_to_explain,
  # which gives a sequence of (index, score) pairs.
  # We convert it to an array <float32>[seq_len] with a score per position.
  original_scores = np.zeros(num_features)
  for index, score in original_explanation.local_exp[class_to_explain]:
    original_scores[index] = score
  original_scores = utils.normalize_scores(original_scores,
                                           make_positive=False)

  # Test that Citrus LIME and original LIME match.
  np.testing.assert_allclose(scores, original_scores, atol=0.01)
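# For intuition only: both implementations weight perturbed samples by an
# exponential kernel over the distance between the perturbed and the original
# sentence. The standard LIME form is sketched below; treat it as an
# approximation of lime.exponential_kernel, not a copy of it.
import numpy as np

def exponential_kernel_sketch(distance, kernel_width):
  """Exponential smoothing kernel: weight 1 at distance 0, decaying with distance."""
  return np.sqrt(np.exp(-(distance ** 2) / kernel_width ** 2))

# Samples identical to the original sentence get full weight in the weighted
# linear fit; samples with many masked tokens contribute much less.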