def __init__(self, encoder: encoders.BertEncoderWithOffsets,
             classifier: edge_predictor.SingleEdgePredictor):
  self.encoder = encoder
  self.classifier = classifier
  embs_field = utils.find_spec_keys(self.encoder.output_spec(),
                                    lit_types.TokenEmbeddings)[0]
  offset_field = utils.find_spec_keys(self.encoder.output_spec(),
                                      lit_types.SubwordOffsets)[0]
  self.extractor_kw = dict(
      edge_field='coref', embs_field=embs_field, offset_field=offset_field)
def generate(self,
             example: JsonDict,
             model: lit_model.Model,
             dataset: lit_dataset.Dataset,
             config: Optional[JsonDict] = None) -> List[JsonDict]:
  """Replace words based on replacement list."""
  del model  # Unused.
  subs_string = config.get('subs') if config else None
  if subs_string:
    replacements = self.parse_subs_string(subs_string)
  else:
    replacements = self.default_replacements

  new_examples = []
  # TODO(lit-dev): move this to generate_all(), so we read the spec once
  # instead of on every example.
  text_keys = utils.find_spec_keys(dataset.spec(), types.TextSegment)
  for text_key in text_keys:
    text_data = example[text_key]
    token_spans = map(lambda x: x.span(),
                      self.tokenization_pattern.finditer(text_data))
    for new_val in self.generate_counterfactuals(text_data, token_spans,
                                                 replacements):
      new_example = copy.deepcopy(example)
      new_example[text_key] = new_val
      new_examples.append(new_example)
  return new_examples
def output_spec(self):
  ret = glue_models.STSBModel.output_spec(self)
  token_gradient_keys = utils.find_spec_keys(ret, lit_types.TokenGradients)
  for k in token_gradient_keys:
    ret.pop(k, None)
  return ret
def find_fields(self, input_spec: Spec, output_spec: Spec) -> List[Text]:
  # Find TokenGradients fields.
  grad_fields = utils.find_spec_keys(output_spec, types.TokenGradients)

  # Check that these are aligned to Tokens fields.
  aligned_fields = []
  for f in grad_fields:
    tokens_field = output_spec[f].align  # pytype: disable=attribute-error
    assert tokens_field in output_spec
    assert isinstance(output_spec[tokens_field], types.Tokens)

    embeddings_field = output_spec[f].grad_for
    grad_class_key = output_spec[f].grad_target
    if embeddings_field is not None and grad_class_key is not None:
      assert embeddings_field in input_spec
      assert isinstance(input_spec[embeddings_field], types.TokenEmbeddings)
      assert embeddings_field in output_spec
      assert isinstance(output_spec[embeddings_field], types.TokenEmbeddings)

      assert grad_class_key in input_spec
      assert grad_class_key in output_spec
      aligned_fields.append(f)
    else:
      logging.info('Skipping %s since embeddings field not found.', str(f))
  return aligned_fields
def run(self,
        inputs: List[JsonDict],
        dataset: lit_dataset.Dataset,
        config: Optional[JsonDict] = None):
  """Run generation on a set of inputs.

  Args:
    inputs: sequence of inputs, following dataset.spec()
    dataset: dataset, used to access dataset.spec()
    config: additional runtime options

  Returns:
    list of list of new generated inputs, following dataset.spec()
  """
  all_outputs = [[] for _ in inputs]

  # Find text fields.
  text_fields = utils.find_spec_keys(dataset.spec(), types.TextSegment)
  # TODO(lit-team): configure a subset of fields to operate on.
  candidates_by_field = {}
  for field_name in text_fields:
    texts = [ex[field_name] for ex in inputs]
    candidates_by_field[field_name] = self.generate_from_texts(texts)

  # Generate by substituting in each field.
  # TODO(lit-team): substitute on a combination of fields?
  for field_name in candidates_by_field:
    candidates = candidates_by_field[field_name]
    for i, ex in enumerate(inputs):
      for candidate in candidates[i]:
        new_ex = utils.copy_and_update(ex, {field_name: candidate})
        all_outputs[i].append(new_ex)
  return all_outputs
def annotate(self,
             inputs: List[JsonDict],
             dataset: lit_dataset.Dataset,
             dataset_spec_to_annotate: Optional[types.Spec] = None):
  if len(self._annotator_model.input_spec().items()) != 1:
    raise ValueError('Annotator model provided to PerFieldAnnotator does not '
                     'operate on a single field')

  datasets = {}
  for input_name, input_type in self._annotator_model.input_spec().items():
    # Do remap of inputs based on input name needed by annotator.
    ds_keys = utils.find_spec_keys(dataset.spec(), type(input_type))
    for ds_key in ds_keys:
      temp_ds = lit_dataset.Dataset(examples=inputs, base=dataset)
      datasets[ds_key] = temp_ds.remap({ds_key: input_name})

  for ds_key, ds in datasets.items():
    outputs = self._annotator_model.predict(ds.examples)
    for output_name, output_type in self._annotator_model.output_spec().items():
      # Update dataset spec with new annotated field.
      field_name = f'{self._name}:{output_name}:{ds_key}'
      if dataset_spec_to_annotate:
        dataset_spec_to_annotate[field_name] = attr.evolve(
            output_type, annotated=True)

      # Update all examples with annotator output.
      for example, output in zip(inputs, outputs):
        example[field_name] = output[output_name]
def _get_tokens_and_gradients(self, input_spec: JsonDict,
                              output_spec: JsonDict, output: JsonDict,
                              selected_fields: List[str]):
  """Returns a dictionary mapping token fields to tokens and gradients."""
  # Find selected token fields.
  input_spec_keys = set(utils.find_spec_keys(input_spec, types.Tokens))
  logging.info("input_spec_keys: %r", input_spec_keys)
  selected_input_spec_keys = list(input_spec_keys & set(selected_fields))
  logging.info("selected_input_spec_keys: %r", selected_input_spec_keys)
  token_fields = [
      key for key in selected_input_spec_keys
      if input_spec[key].is_compatible(output_spec.get(key))
  ]

  if len(token_fields) == 0:  # pylint: disable=g-explicit-length-test
    return {}

  ret = {}
  for token_field in token_fields:
    # Get tokens, token gradients and token embeddings.
    tokens = output[token_field]
    grad_fields = self.find_fields(output_spec, types.TokenGradients,
                                   token_field)
    assert grad_fields, (
        f"No gradients found for {token_field}. Cannot use HotFlip. :-(")
    assert len(grad_fields) == 1, (
        f"Multiple gradients found for {token_field}. "
        f"Cannot use HotFlip. :-(")
    grads = output[grad_fields[0]] if grad_fields else None
    ret[token_field] = [tokens, grads]
  return ret
def run(self,
        inputs: List[JsonDict],
        model: lit_model.Model,
        dataset: lit_dataset.Dataset,
        model_outputs: Optional[List[JsonDict]] = None,
        config: Optional[JsonDict] = None) -> Optional[List[JsonDict]]:
  """Run this component, given a model and input(s)."""
  if not inputs:
    return None

  # Find keys of input (text) segments to explain.
  # Search in the input spec, since it's only useful to look at ones that are
  # used by the model.
  text_keys = utils.find_spec_keys(model.input_spec(), types.TextSegment)
  if not text_keys:
    logging.warning('LEMON requires text inputs.')
    return None
  logging.info('Found text fields for LEMON attribution: %s', str(text_keys))

  pred_key = config['pred_key']
  output_probs = np.array([output[pred_key] for output in model_outputs])

  # Explain the input given counterfactuals.

  # Dict[field name -> interpretations]
  result = {}

  # Explain each text segment in the input, keeping the others constant.
  for text_key in text_keys:
    sentences = [item[text_key] for item in inputs]
    input_to_prediction = dict(zip(sentences, output_probs))

    input_string = sentences[0]
    counterfactuals = sentences[1:]

    # Remove duplicate counterfactuals.
    counterfactuals = list(set(counterfactuals))

    logging.info('Explaining: %s', input_string)

    predict_proba = make_predict_fn(input_to_prediction)

    # Perturbs the input string, gets model predictions, fits linear model.
    explanation = lemon.explain(
        input_string,
        counterfactuals,
        predict_proba,
        class_to_explain=config['class_to_explain'],
        lowercase_tokens=config['lowercase_tokens'])

    scores = np.array(explanation.feature_importance)

    # Normalize feature values.
    scores = citrus_utils.normalize_scores(scores)

    result[text_key] = dtypes.TokenSalience(input_string.split(), scores)

  return [result]
def test_find_spec_keys(self):
  spec = {
      "score": types.RegressionScore(),
      "scalar_foo": types.Scalar(),
      "text": types.TextSegment(),
      "emb_0": types.Embeddings(),
      "emb_1": types.Embeddings(),
      "tokens": types.Tokens(),
      "generated_text": types.GeneratedText(),
  }
  self.assertEqual(["score"],
                   utils.find_spec_keys(spec, types.RegressionScore))
  self.assertEqual(["text", "tokens", "generated_text"],
                   utils.find_spec_keys(spec,
                                        (types.TextSegment, types.Tokens)))
  self.assertEqual(["emb_0", "emb_1"],
                   utils.find_spec_keys(spec, types.Embeddings))
  self.assertEqual([], utils.find_spec_keys(spec, types.AttentionHeads))
  # Check subclasses
  self.assertEqual(
      list(spec.keys()), utils.find_spec_keys(spec, types.LitType))
  self.assertEqual(["text", "generated_text"],
                   utils.find_spec_keys(spec, types.TextSegment))
  self.assertEqual(["score", "scalar_foo"],
                   utils.find_spec_keys(spec, types.Scalar))
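# A minimal sketch of the behavior exercised by the test above: keep every
# spec key whose value is an instance of the requested type (or tuple of
# types). Subclasses match via isinstance(), which is why "generated_text"
# is returned for TextSegment and "score" for Scalar. This is an illustrative
# re-implementation for reference, not necessarily the library's own code.
from typing import Any, Dict, List, Tuple, Type, Union


def find_spec_keys_sketch(
    spec: Dict[str, Any],
    types_to_match: Union[Type[Any], Tuple[Type[Any], ...]]) -> List[str]:
  """Returns keys of `spec` whose values are instances of the given types."""
  return [
      key for key, value in spec.items() if isinstance(value, types_to_match)
  ]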
def find_supported_fields(input_spec: Spec,
                          output_spec: Spec) -> Optional[SupportedFields]:
  """Returns fields from the model specs that are needed for saliency."""

  # Find all ImageGradients fields.
  grad_field_keys = lit_utils.find_spec_keys(output_spec, types.ImageGradients)
  # Models with more than one gradient field are not supported.
  if len(grad_field_keys) > 1 or not grad_field_keys:
    return None
  grad_field_key = grad_field_keys[0]
  grad_field_value = cast(types.ImageGradients, output_spec[grad_field_key])

  # Find image fields that correspond to grad_field.
  image_field_key = grad_field_value.align
  assert isinstance(input_spec[image_field_key], types.ImageBytes)

  # Find gradient target fields in the input if it is a multiclass
  # classification model. The value of None means that it is a regression or
  # single class classification model.
  multiclass = grad_field_value.grad_target_field_key is not None
  if multiclass:
    grad_target_field_key = grad_field_value.grad_target_field_key
    assert isinstance(input_spec[grad_target_field_key], types.CategoryLabel)
  else:
    grad_target_field_key = None

  # Find prediction field names.
  if multiclass:
    preds_field_keys = lit_utils.find_spec_keys(output_spec,
                                                types.MulticlassPreds)
  else:
    preds_field_keys = lit_utils.find_spec_keys(output_spec,
                                                types.RegressionScore)
  # Models with more than one prediction field are not supported.
  if len(preds_field_keys) > 1 or not preds_field_keys:
    return None
  preds_field_key = preds_field_keys[0]

  return SupportedFields(
      grad_field_key=grad_field_key,
      image_field_key=image_field_key,
      grad_target_field_key=grad_target_field_key,
      preds_field_key=preds_field_key)
def find_fields(self, output_spec: Spec) -> List[Text]:
  # Find TokenGradients fields.
  grad_fields = utils.find_spec_keys(output_spec, types.TokenGradients)
  # Check that these are aligned to Tokens fields.
  for f in grad_fields:
    tokens_field = output_spec[f].align  # pytype: disable=attribute-error
    assert tokens_field in output_spec
    assert isinstance(output_spec[tokens_field], types.Tokens)
  return grad_fields
def _fill_indices(self, model_name, dataset_name):
  """Create all indices for a single model."""
  model = self._models.get(model_name)
  assert model is not None, "Invalid model name."
  examples = self.datasets[dataset_name].indexed_examples
  model_embeddings_names = utils.find_spec_keys(model.output_spec(),
                                                lit_types.Embeddings)
  lookup_key = self._get_lookup_key(model_name, dataset_name)

  # If the model has no embeddings to extract, skip the following.
  if not model_embeddings_names:
    return

  # Load from file if it exists.
  for emb_name in model_embeddings_names:
    # Initialize the index object in self._indices with serialized index.
    self._init_index_from_file(model_name, dataset_name, emb_name)
  # Load example lookup dictionary from file.
  self._example_lookup[lookup_key] = self._load_lookup(lookup_key)

  # Identify which indices need to be initialized.
  embeddings_to_index = [
      emb_name for emb_name in model_embeddings_names
      if not self._is_index_initialized(model_name, dataset_name, emb_name)
  ]
  # Early exit if all embeddings are now initialized.
  if not embeddings_to_index:
    return

  # Cold start: Get embeddings for non-initialized settings.
  if self._initialize_new_indices:
    for res_ix, (result, example) in enumerate(
        zip(model.predict_with_metadata(examples), examples)):
      for emb_name in embeddings_to_index:
        index_key = self._get_index_key(model_name, dataset_name, emb_name)
        # Initialize saving in the first iteration.
        if res_ix == 0:
          file_path = self._get_index_path(index_key)
          self._indices[index_key].on_disk_build(file_path)
        index = self._indices.get(index_key)
        assert index is not None, "Index needs to be created first."
        # Each item has an incrementing ID res_ix.
        self._indices[index_key].add_item(res_ix, result[emb_name])
        # Add item to lookup table.
        self._example_lookup[lookup_key][res_ix] = example["data"]

    # Create the trees from the indices - using 10 as recommended by doc.
    for emb_name in embeddings_to_index:
      index_key = self._get_index_key(model_name, dataset_name, emb_name)
      logging.info("Creating new index: %s", index_key)
      self._indices[index_key].build(10)
      index_size = self._indices[index_key].get_n_items()
      logging.info("Created new index with %s items.", index_size)
def symmetrize_edges(dataset: lit_dataset.Dataset) -> lit_dataset.Dataset:
  """Symmetrize edges by adding copies with span1 and span2 interchanged."""

  def _swap(edge):
    return lit_dtypes.EdgeLabel(edge.span2, edge.span1, edge.label)

  edge_fields = utils.find_spec_keys(dataset.spec(), lit_types.EdgeLabels)
  examples = []
  for ex in dataset.examples:
    new_ex = copy.copy(ex)
    for field in edge_fields:
      new_ex[field] += [_swap(edge) for edge in ex[field]]
    examples.append(new_ex)
  return lit_dataset.Dataset(dataset.spec(), examples)
def _warm_projections(self, interpreters: List[Text]):
  """Pre-compute UMAP/PCA projections with default arguments."""
  for model, model_info in self._info['models'].items():
    for dataset_name in model_info['datasets']:
      for field_name in utils.find_spec_keys(model_info['spec']['output'],
                                             types.Embeddings):
        config = dict(
            dataset_name=dataset_name,
            model_name=model,
            field_name=field_name,
            proj_kw={'n_components': 3})
        data = {'inputs': [], 'config': config}
        for interpreter_name in interpreters:
          _ = self._get_interpretations(
              data, model, dataset_name, interpreter=interpreter_name)
def find_fields(self,
                spec: Spec,
                typ: Type[types.LitType],
                align_field: Optional[Text] = None) -> List[Text]:
  # Find fields of provided 'typ'.
  fields = utils.find_spec_keys(spec, typ)

  if align_field is None:
    return fields

  # Only return fields that are aligned to fields with name specified by
  # align_field.
  return [
      f for f in fields if getattr(spec[f], "align", None) == align_field
  ]
def generate(self,
             example: JsonDict,
             model: lit_model.Model,
             dataset: lit_dataset.Dataset,
             config: Optional[JsonDict] = None) -> List[JsonDict]:
  """Naively scramble all words in an example."""
  del model  # Unused.
  del config  # Unused.
  # TODO(lit-dev): move this to generate_all(), so we read the spec once
  # instead of on every example.
  text_keys = utils.find_spec_keys(dataset.spec(), types.TextSegment)
  new_example = copy.deepcopy(example)
  for text_key in text_keys:
    new_example[text_key] = self.scramble(example[text_key])
  return [new_example]
def find_fields(
    self,
    output_spec: Spec,
    typ: Type[types.LitType],
    align_typ: Optional[Type[types.LitType]] = None) -> List[Text]:
  # Find fields of provided 'typ'.
  fields = utils.find_spec_keys(output_spec, typ)

  if align_typ is None:
    return fields

  # Check that these are aligned to fields of type 'align_typ'.
  for f in fields:
    align_field = output_spec[f].align  # pytype: disable=attribute-error
    assert align_field in output_spec, "Align field not in output_spec"
    assert isinstance(output_spec[align_field], align_typ)
  return fields
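# A small, self-contained sketch of the two-step pattern used above: select
# fields by type with find_spec_keys, then keep only those whose `align`
# attribute points at a field of the expected aligned type. The 'tokens' and
# 'token_grads' field names below are illustrative assumptions, not from the
# source; import paths assume the usual lit_nlp package layout.
from lit_nlp.api import types
from lit_nlp.lib import utils

output_spec = {
    'tokens': types.Tokens(),
    'token_grads': types.TokenGradients(align='tokens'),
}
grad_fields = utils.find_spec_keys(output_spec, types.TokenGradients)
aligned = [
    f for f in grad_fields
    if isinstance(output_spec.get(getattr(output_spec[f], 'align', None)),
                  types.Tokens)
]
# aligned == ['token_grads']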
def _create_empty_indices(self, model_name, dataset_name):
  """Create the empty indices for a model and dataset."""
  model = self._models[model_name]
  examples = self.datasets[dataset_name].indexed_examples
  model_embeddings_names = utils.find_spec_keys(model.output_spec(),
                                                lit_types.Embeddings)
  if not model_embeddings_names:
    return

  # To first create an index, we need to know the shapes - peek at first ex.
  peeked_example = list(model.predict([examples[0]["data"]]))[0]
  for emb_name in model_embeddings_names:
    index_key = self._get_index_key(model_name, dataset_name, emb_name)
    emb_dimension = len(peeked_example[emb_name])
    assert self._indices.get(index_key) is None, "Index already exists."
    self._indices[index_key] = annoy.AnnoyIndex(emb_dimension, "euclidean")
def _warm_projections(self, interpreters: List[Text]):
  """Pre-compute UMAP/PCA projections with default arguments."""
  for model, model_info in self._info['models'].items():
    for dataset_name in model_info['datasets']:
      embedding_fields = utils.find_spec_keys(model_info['spec']['output'],
                                              types.Embeddings)
      # Only warm-start on the first embedding field, since if models return
      # many different embeddings this can take a long time.
      for field_name in embedding_fields[:1]:
        config = dict(
            dataset_name=dataset_name,
            model_name=model,
            field_name=field_name,
            proj_kw={'n_components': 3})
        data = {'inputs': [], 'config': config}
        for interpreter_name in interpreters:
          _ = self._get_interpretations(
              data, model, dataset_name, interpreter=interpreter_name)
def generate(self,
             example: JsonDict,
             model: lit_model.Model,
             dataset: lit_dataset.Dataset,
             config: Optional[JsonDict] = None) -> List[JsonDict]:
  """Naively scramble all words in an example.

  Note: Even if more than one field is to be scrambled, only a single example
  will be produced, unlike other generators which will produce multiple
  examples, one per field.

  Args:
    example: the example used for basis of generated examples.
    model: the model.
    dataset: the dataset.
    config: user-provided config properties.

  Returns:
    examples: a list of generated examples.
  """
  del model  # Unused.

  config = config or {}

  # If config key is missing, generate no examples.
  fields_to_scramble = list(config.get(FIELDS_TO_SCRAMBLE_KEY, []))
  if not fields_to_scramble:
    return []

  # TODO(lit-dev): move this to generate_all(), so we read the spec once
  # instead of on every example.
  text_keys = utils.find_spec_keys(dataset.spec(), types.TextSegment)
  if not text_keys:
    return []

  text_keys = [key for key in text_keys if key in fields_to_scramble]

  new_example = copy.deepcopy(example)
  for text_key in text_keys:
    new_example[text_key] = self.scramble(example[text_key])
  return [new_example]
def generate(self,
             example: JsonDict,
             model: lit_model.Model,
             dataset: lit_dataset.Dataset,
             config: Optional[JsonDict] = None) -> List[JsonDict]:
  """Replace words based on replacement list."""
  del model  # Unused.
  ignore_casing = config.get('ignore_casing', True) if config else True
  subs_string = config.get('Substitutions') if config else None
  if subs_string:
    replacements = self.parse_subs_string(
        subs_string, ignore_casing=ignore_casing)
  else:
    replacements = self.default_replacements

  # If replacements dictionary is empty, do not attempt to match.
  if not replacements:
    return []

  replacement_regex = self._get_replacement_pattern(
      replacements, ignore_casing=ignore_casing)

  new_examples = []
  # TODO(lit-dev): move this to generate_all(), so we read the spec once
  # instead of on every example.
  text_keys = utils.find_spec_keys(dataset.spec(), types.TextSegment)
  for text_key in text_keys:
    text_data = example[text_key]
    for new_val in self.generate_counterfactuals(
        text_data,
        replacement_regex,
        replacements,
        ignore_casing=ignore_casing):
      new_example = copy.deepcopy(example)
      new_example[text_key] = new_val
      new_examples.append(new_example)

  return new_examples
def _get_preds(self,
               data,
               model: Text,
               dataset_name: Optional[Text] = None,
               requested_types: Text = 'LitType',
               **unused_kw):
  """Get model predictions.

  Args:
    data: data payload, containing 'inputs' field
    model: name of the model to run
    dataset_name: name of the active dataset
    requested_types: optional, comma-separated list of types to return

  Returns:
    List[JsonDict] containing requested fields of model predictions
  """
  preds = list(self._predict(data['inputs'], model, dataset_name))

  # Figure out what to return to the frontend.
  output_spec = self._get_spec(model)['output']
  requested_types = requested_types.split(',')
  logging.info('Requested types: %s', str(requested_types))
  ret_keys = []
  for t_name in requested_types:
    t_class = getattr(types, t_name, None)
    assert issubclass(
        t_class, types.LitType), f"Class '{t_name}' is not a valid LitType."
    ret_keys.extend(utils.find_spec_keys(output_spec, t_class))
  ret_keys = set(ret_keys)  # de-dupe

  # Return selected keys.
  logging.info('Will return keys: %s', str(ret_keys))
  # One record per input.
  ret = [utils.filter_by_keys(p, ret_keys.__contains__) for p in preds]
  return ret
def run(self,
        inputs: List[JsonDict],
        dataset: lit_dataset.Dataset,
        config: Optional[JsonDict] = None):
  """Run generation on a set of inputs.

  Args:
    inputs: sequence of inputs, following dataset.spec()
    dataset: dataset, used to access dataset.spec()
    config: additional runtime options

  Returns:
    list of list of new generated inputs, following dataset.spec()
  """
  all_outputs = [[] for _ in inputs]
  config = config or {}

  # Find text fields.
  text_fields = utils.find_spec_keys(dataset.spec(), types.TextSegment)
  # If config key is missing, backtranslate all text fields.
  fields_to_backtranslate = list(
      config.get(FIELDS_TO_BACKTRANSLATE_KEY, text_fields))
  candidates_by_field = {}
  for field_name in fields_to_backtranslate:
    texts = [ex[field_name] for ex in inputs]
    candidates_by_field[field_name] = self.generate_from_texts(texts)

  # Generate by substituting in each field.
  # TODO(lit-team): substitute on a combination of fields?
  for field_name in candidates_by_field:
    candidates = candidates_by_field[field_name]
    for i, ex in enumerate(inputs):
      for candidate in candidates[i]:
        new_ex = utils.copy_and_update(ex, {field_name: candidate})
        all_outputs[i].append(new_ex)
  return all_outputs
def run(
    self,
    inputs: List[JsonDict],
    model: lit_model.Model,
    dataset: lit_dataset.Dataset,
    model_outputs: Optional[List[JsonDict]] = None,
    config: Optional[JsonDict] = None,
    kernel_width: int = 25,  # TODO(lit-dev): make configurable in UI.
    mask_string: str = '[MASK]',  # TODO(lit-dev): make configurable in UI.
    num_samples: int = 256,  # TODO(lit-dev): make configurable in UI.
) -> Optional[List[JsonDict]]:
  """Run this component, given a model and input(s)."""
  # Find keys of input (text) segments to explain.
  # Search in the input spec, since it's only useful to look at ones that are
  # used by the model.
  text_keys = utils.find_spec_keys(model.input_spec(), types.TextSegment)
  if not text_keys:
    logging.warning('LIME requires text inputs.')
    return None
  logging.info('Found text fields for LIME attribution: %s', str(text_keys))

  # Find the key of output probabilities field(s).
  pred_keys = utils.find_spec_keys(model.output_spec(), types.MulticlassPreds)
  if not pred_keys:
    logging.warning('LIME did not find a multi-class predictions field.')
    return None

  pred_key = pred_keys[0]  # TODO(lit-dev): configure which prob field to use.
  pred_spec = cast(types.MulticlassPreds, model.output_spec()[pred_key])
  label_names = pred_spec.vocab

  # Create a LIME text explainer instance.
  explainer = lime_text.LimeTextExplainer(
      class_names=label_names,
      split_expression=str.split,
      kernel_width=kernel_width,
      mask_string=mask_string,  # This is the string used to mask words.
      bow=False)  # bow=False masks inputs, instead of deleting them entirely.

  all_results = []

  # Explain each input.
  for input_ in inputs:
    # Dict[field name -> interpretations]
    result = {}

    # Explain each text segment in the input, keeping the others constant.
    for text_key in text_keys:
      input_string = input_[text_key]
      logging.info('Explaining: %s', input_string)

      # Use the number of words as the number of features.
      num_features = len(input_string.split())

      def _predict_proba(strings: List[Text]):
        """Given raw strings, return probabilities. Used by `explainer`."""
        input_examples = [new_example(input_, text_key, s) for s in strings]
        model_outputs = model.predict(input_examples)
        probs = np.array([output[pred_key] for output in model_outputs])
        return probs  # <float32>[len(strings), num_labels]

      # Perturbs the input string, gets model predictions, fits linear model.
      explanation = explainer.explain_instance(
          input_string,
          _predict_proba,
          num_features=num_features,
          num_samples=num_samples)

      # Turn the LIME explanation into a list following original word order.
      scores = explanation_to_array(explanation)

      result[text_key] = dtypes.SalienceMap(input_string.split(), scores)

    all_results.append(result)

  return all_results
def input_spec(self):
  ret = glue_models.STSBModel.input_spec(self)
  token_keys = utils.find_spec_keys(ret, lit_types.Tokens)
  for k in token_keys:
    ret.pop(k, None)
  return ret
def run(
    self,
    inputs: List[JsonDict],
    model: lit_model.Model,
    dataset: lit_dataset.Dataset,
    model_outputs: Optional[List[JsonDict]] = None,
    config: Optional[JsonDict] = None,
) -> Optional[List[JsonDict]]:
  """Run this component, given a model and input(s)."""
  config_defaults = {k: v.default for k, v in self.config_spec().items()}
  config = dict(config_defaults, **(config or {}))  # update and return

  provided_class_to_explain = int(config[CLASS_KEY])
  kernel_width = int(config[KERNEL_WIDTH_KEY])
  num_samples = int(config[NUM_SAMPLES_KEY])
  mask_string = config[MASK_KEY]
  # pylint: disable=g-explicit-bool-comparison
  seed = int(config[SEED_KEY]) if config[SEED_KEY] != '' else None
  # pylint: enable=g-explicit-bool-comparison

  # Find keys of input (text) segments to explain.
  # Search in the input spec, since it's only useful to look at ones that are
  # used by the model.
  text_keys = utils.find_spec_keys(model.input_spec(), types.TextSegment)
  if not text_keys:
    logging.warning('LIME requires text inputs.')
    return None
  logging.info('Found text fields for LIME attribution: %s', str(text_keys))

  # Find the key of output probabilities field(s).
  pred_keys = utils.find_spec_keys(
      model.output_spec(),
      (types.MulticlassPreds, types.RegressionScore,
       types.SparseMultilabelPreds))
  if not pred_keys:
    logging.warning('LIME did not find any supported output fields.')
    return None

  pred_key = config[TARGET_HEAD_KEY] or pred_keys[0]

  all_results = []

  # Explain each input.
  for input_ in inputs:
    # Dict[field name -> interpretations]
    result = {}
    predict_fn = functools.partial(
        _predict_fn,
        model=model,
        original_example=input_,
        pred_key=pred_key,
        pred_type_info=model.output_spec()[pred_key])
    class_to_explain = get_class_to_explain(provided_class_to_explain, model,
                                            pred_key, input_)

    # Explain each text segment in the input, keeping the others constant.
    for text_key in text_keys:
      input_string = input_[text_key]
      if not input_string:
        logging.info('Could not explain empty string for %s', text_key)
        continue
      logging.info('Explaining: %s', input_string)

      # Perturbs the input string, gets model predictions, fits linear model.
      explanation = lime.explain(
          sentence=input_string,
          predict_fn=functools.partial(predict_fn, text_key=text_key),
          # `class_to_explain` is ignored when predict_fn output is a scalar.
          class_to_explain=class_to_explain,
          num_samples=num_samples,
          tokenizer=str.split,
          mask_token=mask_string,
          kernel=functools.partial(
              lime.exponential_kernel, kernel_width=kernel_width),
          seed=seed)

      # Turn the LIME explanation into a list following original word order.
      scores = explanation.feature_importance
      # TODO(lit-dev): Move score normalization to the UI.
      scores = citrus_util.normalize_scores(scores)
      result[text_key] = dtypes.TokenSalience(input_string.split(), scores)

    all_results.append(result)

  return all_results
def find_fields(self, model: lit_model.Model) -> Dict[str, List[str]]:
  src_fields = utils.find_spec_keys(model.input_spec(), types.TextSegment)
  gen_fields = utils.find_spec_keys(
      model.output_spec(),
      (types.GeneratedText, types.GeneratedTextCandidates))
  return {f: src_fields for f in gen_fields}
def run(self,
        inputs: List[types.JsonDict],
        model: lit_model.Model,
        dataset: lit_dataset.Dataset,
        model_outputs: Optional[List[types.JsonDict]] = None,
        config: Optional[types.JsonDict] = None):
  """Create PDP chart info using provided inputs.

  Args:
    inputs: sequence of inputs, following model.input_spec()
    model: the model used to generate predictions for the edited inputs.
    dataset: dataset which the current examples belong to.
    model_outputs: optional precomputed model outputs
    config: optional runtime config.

  Returns:
    a dict of alternate feature values to model outputs. The model outputs
    will be a number for regression models and a list of numbers for
    multiclass models.
  """
  pred_keys = utils.find_spec_keys(
      model.output_spec(), (types.MulticlassPreds, types.RegressionScore))
  if not pred_keys:
    logging.warning('PDP did not find any supported output fields.')
    return None

  assert 'feature' in config, 'No feature to test provided'
  feature = config['feature']
  provided_range = config['range'] if 'range' in config else []
  edited_outputs = {}
  for pred_key in pred_keys:
    edited_outputs[pred_key] = {}

  # If a range was provided, use that to create the possible values.
  vals_to_test = (
      np.linspace(provided_range[0], provided_range[1], 10)
      if len(provided_range) == 2 else self.get_vals_to_test(feature, dataset))

  # If no specific inputs provided, use the entire dataset.
  inputs_to_use = inputs if inputs else dataset.examples

  # For each alternate value for a given feature.
  for new_val in vals_to_test:
    # Create copies of all provided inputs with the value replaced.
    edited_inputs = []
    for inp in inputs_to_use:
      edited_input = copy.deepcopy(inp)
      edited_input[feature] = new_val
      edited_inputs.append(edited_input)

    # Run prediction on the altered inputs.
    outputs = list(model.predict(edited_inputs))

    # Store the mean of the prediction for the alternate value.
    for pred_key in pred_keys:
      numeric = isinstance(model.output_spec()[pred_key],
                           types.RegressionScore)
      if numeric:
        edited_outputs[pred_key][new_val] = np.mean(
            [output[pred_key] for output in outputs])
      else:
        edited_outputs[pred_key][new_val] = np.mean(
            [output[pred_key] for output in outputs], axis=0)

  return edited_outputs
def is_compatible(self, model: lit_model.Model):
  text_keys = utils.find_spec_keys(model.input_spec(), types.TextSegment)
  pred_keys = utils.find_spec_keys(
      model.output_spec(), (types.MulticlassPreds, types.RegressionScore))
  return len(text_keys) and len(pred_keys)
def run(
    self,
    inputs: List[JsonDict],
    model: lit_model.Model,
    dataset: lit_dataset.Dataset,
    model_outputs: Optional[List[JsonDict]] = None,
    config: Optional[JsonDict] = None,
) -> Optional[List[JsonDict]]:
  """Run this component, given a model and input(s)."""
  class_to_explain = int(config[CLASS_KEY]) if config else CLASS_DEFAULT
  kernel_width = int(
      config[KERNEL_WIDTH_KEY]) if config else KERNEL_WIDTH_DEFAULT
  mask_string = config[MASK_KEY] if config else MASK_DEFAULT
  num_samples = int(
      config[NUM_SAMPLES_KEY]) if config else NUM_SAMPLES_DEFAULT
  seed = config[SEED_KEY] if config else SEED_DEFAULT

  # Find keys of input (text) segments to explain.
  # Search in the input spec, since it's only useful to look at ones that are
  # used by the model.
  text_keys = utils.find_spec_keys(model.input_spec(), types.TextSegment)
  if not text_keys:
    logging.warning('LIME requires text inputs.')
    return None
  logging.info('Found text fields for LIME attribution: %s', str(text_keys))

  # Find the key of output probabilities field(s).
  pred_keys = utils.find_spec_keys(
      model.output_spec(), (types.MulticlassPreds, types.RegressionScore))
  if not pred_keys:
    logging.warning('LIME did not find any supported output fields.')
    return None

  pred_key = pred_keys[0]  # TODO(lit-dev): configure which prob field to use.

  all_results = []

  # Explain each input.
  for input_ in inputs:
    # Dict[field name -> interpretations]
    result = {}
    predict_fn = functools.partial(
        _predict_fn, model=model, original_example=input_, pred_key=pred_key)

    # Explain each text segment in the input, keeping the others constant.
    for text_key in text_keys:
      input_string = input_[text_key]
      if not input_string:
        logging.info('Could not explain empty string for %s', text_key)
        continue
      logging.info('Explaining: %s', input_string)

      # Perturbs the input string, gets model predictions, fits linear model.
      explanation = lime.explain(
          sentence=input_string,
          predict_fn=functools.partial(predict_fn, text_key=text_key),
          # `class_to_explain` is ignored when predict_fn output is a scalar.
          class_to_explain=class_to_explain,  # Index of the class to explain.
          num_samples=num_samples,
          tokenizer=str.split,
          mask_token=mask_string,
          kernel=functools.partial(
              lime.exponential_kernel, kernel_width=kernel_width),
          seed=seed)

      # Turn the LIME explanation into a list following original word order.
      scores = explanation.feature_importance
      # TODO(lit-dev): Move score normalization to the UI.
      scores = citrus_util.normalize_scores(scores)
      result[text_key] = dtypes.SalienceMap(input_string.split(), scores)

    all_results.append(result)

  return all_results