def __init__(self,
             inputs: lit_dataset.Dataset,
             preds: lit_dataset.Dataset,
             input_identifier_keys: Optional[List[str]] = None):
  """Build a static index.

  Args:
    inputs: a lit Dataset
    preds: a lit Dataset, parallel to inputs
    input_identifier_keys: (optional) list of keys to treat as identifiers
      for matching inputs. If None, will use all fields in inputs.spec().
  """
  self._output_spec = preds.spec()
  self._input_spec = inputs.spec()
  self.input_identifier_keys = (
      input_identifier_keys or self._input_spec.keys())
  # Filter to only the identifier keys.
  self._input_spec = {
      k: self._input_spec[k] for k in self.input_identifier_keys
  }
  # Build the index for prediction lookups.
  self._index = {
      self.key_fn(ex): pred
      for ex, pred in zip(inputs.examples, preds.examples)
  }
def symmetrize_edges(dataset: lit_dataset.Dataset) -> lit_dataset.Dataset:
  """Symmetrize edges by adding copies with span1 and span2 interchanged."""

  def _swap(edge):
    return lit_dtypes.EdgeLabel(edge.span2, edge.span1, edge.label)

  edge_fields = utils.find_spec_keys(dataset.spec(), lit_types.EdgeLabels)
  examples = []
  for ex in dataset.examples:
    new_ex = copy.copy(ex)
    for field in edge_fields:
      new_ex[field] += [_swap(edge) for edge in ex[field]]
    examples.append(new_ex)
  return lit_dataset.Dataset(dataset.spec(), examples)
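# Illustrative usage sketch for symmetrize_edges, not from the original
# source (field names and spans are hypothetical; assumes the spec declares
# an EdgeLabels field, here called 'coref'):
#
#   ds = lit_dataset.Dataset(
#       spec={'tokens': lit_types.Tokens(), 'coref': lit_types.EdgeLabels()},
#       examples=[{
#           'tokens': ['Alice', 'saw', 'her'],
#           'coref': [lit_dtypes.EdgeLabel((0, 1), (2, 3), 1)],
#       }])
#   sym_ds = symmetrize_edges(ds)
#   # Each example now also contains the mirrored edge ((2, 3), (0, 1), 1).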
def _calculate_stats(self, dataset: lit_dataset.Dataset,
                     dataset_name: Text) -> None:
  # Iterate through all examples in the dataset and store column values
  # in individual lists to facilitate future computation.
  field_values = {}
  spec = dataset.spec()
  supported_fields = [
      name for name in spec if self._is_supported(spec[name])
  ]
  for example in dataset.examples:
    for field_name in supported_fields:
      if example[field_name] is None:
        continue
      if field_name not in field_values:
        field_values[field_name] = []
      field_values[field_name].append(example[field_name])

  # Compute the necessary statistics: standard deviation for scalar fields
  # and the probability of two examples sharing the same value for
  # categorical fields.
  field_stats = {}
  for field_name, values in field_values.items():
    field_spec = spec[field_name]
    if self._is_scalar(field_spec):
      field_stats[field_name] = self._calculate_std_dev(values)
    elif self._is_categorical(field_spec):
      field_stats[field_name] = self._calculate_categorical_prob(values)
    else:
      assert False, 'Should never be reached.'

  # Cache the stats for the given dataset.
  self._datasets_stats[dataset_name] = field_stats
def run(self,
        inputs: List[JsonDict],
        dataset: lit_dataset.Dataset,
        config: Optional[JsonDict] = None):
  """Run generation on a set of inputs.

  Args:
    inputs: sequence of inputs, following dataset.spec()
    dataset: dataset, used to access dataset.spec()
    config: additional runtime options

  Returns:
    list of lists of new generated inputs, following dataset.spec()
  """
  all_outputs = [[] for _ in inputs]

  # Find text fields.
  text_fields = utils.find_spec_keys(dataset.spec(), types.TextSegment)
  # TODO(lit-team): configure a subset of fields to operate on.
  candidates_by_field = {}
  for field_name in text_fields:
    texts = [ex[field_name] for ex in inputs]
    candidates_by_field[field_name] = self.generate_from_texts(texts)

  # Generate by substituting in each field.
  # TODO(lit-team): substitute on a combination of fields?
  for field_name in candidates_by_field:
    candidates = candidates_by_field[field_name]
    for i, ex in enumerate(inputs):
      for candidate in candidates[i]:
        new_ex = utils.copy_and_update(ex, {field_name: candidate})
        all_outputs[i].append(new_ex)
  return all_outputs
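# Illustrative call, not from the original source (hypothetical generator
# instance and field name; shows the list-of-lists return shape described in
# the docstring above):
#
#   outputs = generator.run(
#       inputs=[{'sentence': 'The movie was great.'}], dataset=dataset)
#   # outputs[0] holds all candidates generated for inputs[0]; each candidate
#   # is a full example with 'sentence' replaced by one generated text.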
def annotate(self,
             inputs: List[JsonDict],
             dataset: lit_dataset.Dataset,
             dataset_spec_to_annotate: Optional[types.Spec] = None):
  if len(self._annotator_model.input_spec().items()) != 1:
    raise ValueError(
        'Annotator model provided to PerFieldAnnotator does not '
        'operate on a single field')

  datasets = {}
  for input_name, input_type in self._annotator_model.input_spec().items():
    # Remap dataset fields to the input name expected by the annotator.
    ds_keys = utils.find_spec_keys(dataset.spec(), type(input_type))
    for ds_key in ds_keys:
      temp_ds = lit_dataset.Dataset(examples=inputs, base=dataset)
      datasets[ds_key] = temp_ds.remap({ds_key: input_name})

  for ds_key, ds in datasets.items():
    outputs = self._annotator_model.predict(ds.examples)
    for output_name, output_type in self._annotator_model.output_spec(
    ).items():
      # Update dataset spec with the new annotated field.
      field_name = f'{self._name}:{output_name}:{ds_key}'
      if dataset_spec_to_annotate:
        dataset_spec_to_annotate[field_name] = attr.evolve(
            output_type, annotated=True)

      # Update all examples with annotator output.
      for example, output in zip(inputs, outputs):
        example[field_name] = output[output_name]
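# For reference, the naming convention produced by the f-string above
# (hypothetical names): an annotator named 'lang_id' whose model outputs a
# field 'lang', applied to the dataset field 'sentence', adds the field
# 'lang_id:lang:sentence' to every example and to the annotated spec.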
def generate(self,
             example: JsonDict,
             model: lit_model.Model,
             dataset: lit_dataset.Dataset,
             config: Optional[JsonDict] = None) -> List[JsonDict]:
  """Replace words based on replacement list."""
  del model  # Unused.

  subs_string = config.get('subs') if config else None
  if subs_string:
    replacements = self.parse_subs_string(subs_string)
  else:
    replacements = self.default_replacements

  new_examples = []
  # TODO(lit-dev): move this to generate_all(), so we read the spec once
  # instead of on every example.
  text_keys = utils.find_spec_keys(dataset.spec(), types.TextSegment)
  for text_key in text_keys:
    text_data = example[text_key]
    token_spans = map(lambda x: x.span(),
                      self.tokenization_pattern.finditer(text_data))
    for new_val in self.generate_counterfactuals(text_data, token_spans,
                                                 replacements):
      new_example = copy.deepcopy(example)
      new_example[text_key] = new_val
      new_examples.append(new_example)

  return new_examples
def _run_annotators(self,
                    dataset: lit_dataset.Dataset) -> lit_dataset.Dataset:
  """Run all configured annotators and return an annotated copy of dataset."""
  datapoints = [dict(ex) for ex in dataset.examples]
  annotated_spec = dict(dataset.spec())
  for annotator in self._annotators:
    annotator.annotate(datapoints, dataset, annotated_spec)
  return lit_dataset.Dataset(
      base=dataset, examples=datapoints, spec=annotated_spec)
def run_with_metadata(self,
                      indexed_inputs: List[JsonDict],
                      model: lit_model.Model,
                      dataset: lit_dataset.Dataset,
                      model_outputs: Optional[List[JsonDict]] = None,
                      config: Optional[JsonDict] = None) -> List[JsonDict]:
  # TODO(lit-team): pre-compute this mapping in the constructor?
  # This would require passing a model name to this function so we can
  # reference a pre-computed list.
  spec = model.spec()
  field_map = map_pred_keys(dataset.spec(), spec.output, self.is_compatible)
  ret = []
  for pred_key, label_key in field_map.items():
    # Extract fields.
    labels = [ex['data'][label_key] for ex in indexed_inputs]
    preds = [mo[pred_key] for mo in model_outputs]
    indices = [ex['id'] for ex in indexed_inputs]
    metas = [ex['meta'] for ex in indexed_inputs]
    # Compute metrics, as dict(str -> float).
    metrics = self.compute_with_metadata(
        labels,
        preds,
        label_spec=dataset.spec()[label_key],
        pred_spec=spec.output[pred_key],
        indices=indices,
        metas=metas,
        config=config.get(label_key) if config else None)
    # NaN is not a valid JSON value, so replace it with None, which will be
    # serialized as null.
    # TODO(lit-team): move this logic into serialize.py somewhere instead?
    metrics = {
        k: (v if not np.isnan(v) else None) for k, v in metrics.items()
    }
    # Format for the frontend.
    ret.append({
        'pred_key': pred_key,
        'label_key': label_key,
        'metrics': metrics
    })
  return ret
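# For reference, each indexed input accessed above is expected to carry
# 'id', 'meta', and 'data' keys (values below are invented):
#
#   {'id': 'd3b07384', 'meta': {}, 'data': {'sentence': '...', 'label': '1'}}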
def run(self,
        inputs: List[JsonDict],
        model: lit_model.Model,
        dataset: lit_dataset.Dataset,
        model_outputs: Optional[List[JsonDict]] = None,
        config: Optional[JsonDict] = None):
  if model_outputs is None:
    model_outputs = list(model.predict(inputs))

  spec = model.spec()
  field_map = map_pred_keys(dataset.spec(), spec.output, self.is_compatible)
  ret = []
  for pred_key, label_key in field_map.items():
    # Extract fields.
    labels = [ex[label_key] for ex in inputs]
    preds = [mo[pred_key] for mo in model_outputs]
    # Compute metrics, as dict(str -> float).
    metrics = self.compute(
        labels,
        preds,
        label_spec=dataset.spec()[label_key],
        pred_spec=spec.output[pred_key],
        config=config.get(pred_key) if config else None)
    # NaN is not a valid JSON value, so replace it with None, which will be
    # serialized as null.
    # TODO(lit-team): move this logic into serialize.py somewhere instead?
    metrics = {
        k: (v if not np.isnan(v) else None) for k, v in metrics.items()
    }
    # Format for the frontend.
    ret.append({
        'pred_key': pred_key,
        'label_key': label_key,
        'metrics': metrics
    })
  return ret
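# Shape of the returned list, for reference (metric names and values are
# invented; the real keys depend on the model spec and metric class):
#
#   [{'pred_key': 'probas',
#     'label_key': 'label',
#     'metrics': {'accuracy': 0.87, 'f1': 0.85}}]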
def generate(self,
             example: JsonDict,
             model: lit_model.Model,
             dataset: lit_dataset.Dataset,
             config: Optional[JsonDict] = None) -> List[JsonDict]:
  """Naively scramble all words in an example."""
  del model  # Unused.
  del config  # Unused.

  # TODO(lit-dev): move this to generate_all(), so we read the spec once
  # instead of on every example.
  text_keys = utils.find_spec_keys(dataset.spec(), types.TextSegment)
  new_example = copy.deepcopy(example)
  for text_key in text_keys:
    new_example[text_key] = self.scramble(example[text_key])
  return [new_example]
def run(self,
        inputs: List[JsonDict],
        model: lit_model.Model,
        dataset: lit_dataset.Dataset,
        model_outputs: Optional[List[JsonDict]] = None,
        config: Optional[JsonDict] = None):
  # Get the margin for each input for each pred key and add them to a config
  # dict to pass to the wrapped metrics.
  field_map = map_pred_keys(dataset.spec(), model.spec().output,
                            self.is_compatible)
  margin_config = {}
  for pred_key in field_map:
    field_config = config.get(pred_key) if config else None
    margins = [get_margin_for_input(field_config, inp) for inp in inputs]
    margin_config[pred_key] = margins
  return self._metrics.run(inputs, model, dataset, model_outputs,
                           margin_config)
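# The config forwarded to the wrapped metrics maps each prediction key to one
# margin per input, e.g. (hypothetical key and values):
#
#   margin_config = {'probas': [0.0, 0.2, -0.1]}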
def generate(self,
             example: JsonDict,
             model: lit_model.Model,
             dataset: lit_dataset.Dataset,
             config: Optional[JsonDict] = None) -> List[JsonDict]:
  """Naively scramble all words in an example.

  Note: Even if more than one field is to be scrambled, only a single example
  will be produced, unlike other generators which will produce multiple
  examples, one per field.

  Args:
    example: the example used as the basis for the generated examples.
    model: the model.
    dataset: the dataset.
    config: user-provided config properties.

  Returns:
    examples: a list of generated examples.
  """
  del model  # Unused.

  config = config or {}

  # If the config key is missing, generate no examples.
  fields_to_scramble = list(config.get(FIELDS_TO_SCRAMBLE_KEY, []))
  if not fields_to_scramble:
    return []

  # TODO(lit-dev): move this to generate_all(), so we read the spec once
  # instead of on every example.
  text_keys = utils.find_spec_keys(dataset.spec(), types.TextSegment)
  if not text_keys:
    return []

  text_keys = [key for key in text_keys if key in fields_to_scramble]

  new_example = copy.deepcopy(example)
  for text_key in text_keys:
    new_example[text_key] = self.scramble(example[text_key])
  return [new_example]
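# Illustrative call, not from the original source (hypothetical field name;
# FIELDS_TO_SCRAMBLE_KEY is the config key used above):
#
#   new_examples = scrambler.generate(
#       example, model, dataset,
#       config={FIELDS_TO_SCRAMBLE_KEY: ['sentence']})
#   # At most one example is returned, with 'sentence' scrambled and all
#   # other fields copied unchanged.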
def generate(self,
             example: JsonDict,
             model: lit_model.Model,
             dataset: lit_dataset.Dataset,
             config: Optional[JsonDict] = None) -> List[JsonDict]:
  """Replace words based on replacement list."""
  del model  # Unused.

  ignore_casing = config.get('ignore_casing', True) if config else True
  subs_string = config.get('Substitutions') if config else None
  if subs_string:
    replacements = self.parse_subs_string(subs_string,
                                          ignore_casing=ignore_casing)
  else:
    replacements = self.default_replacements

  # If the replacements dictionary is empty, do not attempt to match.
  if not replacements:
    return []

  replacement_regex = self._get_replacement_pattern(
      replacements, ignore_casing=ignore_casing)

  new_examples = []
  # TODO(lit-dev): move this to generate_all(), so we read the spec once
  # instead of on every example.
  text_keys = utils.find_spec_keys(dataset.spec(), types.TextSegment)
  for text_key in text_keys:
    text_data = example[text_key]
    for new_val in self.generate_counterfactuals(
        text_data, replacement_regex, replacements,
        ignore_casing=ignore_casing):
      new_example = copy.deepcopy(example)
      new_example[text_key] = new_val
      new_examples.append(new_example)

  return new_examples
def run(self,
        inputs: List[JsonDict],
        dataset: lit_dataset.Dataset,
        config: Optional[JsonDict] = None):
  """Run generation on a set of inputs.

  Args:
    inputs: sequence of inputs, following dataset.spec()
    dataset: dataset, used to access dataset.spec()
    config: additional runtime options

  Returns:
    list of lists of new generated inputs, following dataset.spec()
  """
  all_outputs = [[] for _ in inputs]
  config = config or {}

  # Find text fields.
  text_fields = utils.find_spec_keys(dataset.spec(), types.TextSegment)
  # If the config key is missing, backtranslate all text fields.
  fields_to_backtranslate = list(
      config.get(FIELDS_TO_BACKTRANSLATE_KEY, text_fields))
  candidates_by_field = {}
  for field_name in fields_to_backtranslate:
    texts = [ex[field_name] for ex in inputs]
    candidates_by_field[field_name] = self.generate_from_texts(texts)

  # Generate by substituting in each field.
  # TODO(lit-team): substitute on a combination of fields?
  for field_name in candidates_by_field:
    candidates = candidates_by_field[field_name]
    for i, ex in enumerate(inputs):
      for candidate in candidates[i]:
        new_ex = utils.copy_and_update(ex, {field_name: candidate})
        all_outputs[i].append(new_ex)
  return all_outputs
def _calculate_L1_distance(
    self,
    example_1: JsonDict,
    example_2: JsonDict,
    dataset: lit_dataset.Dataset,
    dataset_name: Text,
    model: Optional[lit_model.Model] = None,
    field_names: Optional[List[Text]] = None) -> Tuple[float, List[Text]]:
  """Calculates the L1 distance between two input examples.

  Only categorical and scalar example features are considered. For
  categorical features, the distance is the probability of the feature
  having the same value for two random (with replacement) examples. For
  scalar features, the unit of distance is equal to the standard deviation
  of all feature values.

  Only features that are in the intersection of the model and dataset
  features are considered.

  If a feature value of either example is None, the feature is ignored in
  the distance calculation and its name is not included in the result
  feature list (see Returns description).

  Args:
    example_1: the first example to measure distance for.
    example_2: the second example to measure distance for.
    dataset: a dataset that contains the information about the feature
      types.
    dataset_name: name of the dataset.
    model: a model that contains the information about the input feature
      types.
    field_names: if set, the distance calculation only considers these
      fields.

  Returns:
    A tuple that contains the L1 distance and the list of features that were
    used in the distance calculation. The list of features will only contain
    features whose (non-None) values differ between the two examples.
  """
  assert model or field_names

  distance = 0
  diff_fields = []
  if field_names is None:
    assert model
    field_names = self._find_all_fields_to_consider(
        ds_spec=dataset.spec(), model_input_spec=model.input_spec())
  for field_name in field_names:
    field_spec = dataset.spec()[field_name]
    field_stats = self._datasets_stats[dataset_name]
    assert self._is_supported(field_spec)
    assert field_name in field_stats, f'{field_name}, {field_stats.keys()}'
    if example_1[field_name] == example_2[field_name]:
      continue
    if (example_1[field_name] is None) or (example_2[field_name] is None):
      continue
    diff_fields.append(field_name)
    if self._is_scalar(field_spec):
      std_dev = field_stats[field_name]
      if std_dev != 0:
        distance += abs(
            example_1[field_name] - example_2[field_name]) / std_dev
    else:
      same_prob = field_stats[field_name]
      distance += same_prob
  return distance, diff_fields
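# Worked example of the distance computation above (numbers are invented):
# a scalar field with dataset std-dev 2.0 and values 5.0 vs. 9.0 contributes
# |5.0 - 9.0| / 2.0 = 2.0; a categorical field with same-value probability
# 0.3 and differing values contributes 0.3. Total distance: 2.3, with both
# field names in diff_fields.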
def generate(self,
             example: JsonDict,
             model: lit_model.Model,
             dataset: lit_dataset.Dataset,
             config: Optional[JsonDict] = None) -> List[JsonDict]:
  """Replace words based on replacement list.

  Note: If multiple fields are selected for replacement, this method will
  generate an example per field. For example, if there are two fields on
  which to perform replacement, the method will perform replacement first on
  one field to produce an example (other fields left intact), and then
  perform replacement on the second field (again copying all other fields
  from the original datum).

  Args:
    example: the example used as the basis for the generated examples.
    model: the model.
    dataset: the dataset.
    config: user-provided config properties.

  Returns:
    examples: a list of generated examples.
  """
  del model  # Unused.

  config = config or {}
  ignore_casing = config.get(IGNORE_CASING_KEY, True)
  subs_string = config.get(SUBSTITUTIONS_KEY, None)
  if subs_string:
    replacements = self.parse_subs_string(subs_string,
                                          ignore_casing=ignore_casing)
  else:
    replacements = self.default_replacements

  # If the replacements dictionary is empty, do not attempt to match.
  if not replacements:
    return []

  replacement_regex = self._get_replacement_pattern(
      replacements, ignore_casing=ignore_casing)

  # If the config key is missing, generate no examples.
  fields_to_replace = list(config.get(FIELDS_TO_REPLACE_KEY, []))
  if not fields_to_replace:
    return []

  # TODO(lit-dev): move this to generate_all(), so we read the spec once
  # instead of on every example.
  text_keys = utils.find_spec_keys(dataset.spec(), types.TextSegment)
  if not text_keys:
    return []

  text_keys = [key for key in text_keys if key in fields_to_replace]

  new_examples = []
  for text_key in text_keys:
    text_data = example[text_key]
    for new_val in self.generate_counterfactuals(
        text_data, replacement_regex, replacements,
        ignore_casing=ignore_casing):
      new_example = copy.deepcopy(example)
      new_example[text_key] = new_val
      new_examples.append(new_example)

  return new_examples
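# Illustrative call, not from the original source (hypothetical field name
# and substitution string; SUBSTITUTIONS_KEY and FIELDS_TO_REPLACE_KEY are
# the config keys used above):
#
#   new_examples = replacer.generate(
#       example, model, dataset,
#       config={SUBSTITUTIONS_KEY: 'great -> terrible',
#               FIELDS_TO_REPLACE_KEY: ['sentence']})
#   # One new example per match, with only 'sentence' modified.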
def generate(self,
             example: JsonDict,
             model: lit_model.Model,
             dataset: lit_dataset.Dataset,
             config: Optional[JsonDict] = None) -> List[JsonDict]:
  # Perform validation and retrieve configuration.
  if not model:
    raise ValueError('Please provide a model for this generator.')

  config = config or {}
  num_examples = int(config.get(NUM_EXAMPLES_KEY, NUM_EXAMPLES_DEFAULT))
  max_flips = int(config.get(MAX_FLIPS_KEY, MAX_FLIPS_DEFAULT))

  pred_key = config.get(PREDICTION_KEY, '')
  regression_thresh = float(
      config.get(REGRESSION_THRESH_KEY, REGRESSION_THRESH_DEFAULT))

  dataset_name = config.get('dataset_name')
  if not dataset_name:
    raise ValueError('The dataset name must be in the config.')

  output_spec = model.output_spec()
  if not pred_key:
    raise ValueError('Please provide the prediction key.')
  if pred_key not in output_spec:
    raise ValueError('Invalid prediction key.')

  if (not (isinstance(output_spec[pred_key], lit_types.MulticlassPreds) or
           isinstance(output_spec[pred_key], lit_types.RegressionScore))):
    raise ValueError(
        'Only classification and regression models are supported')

  # Calculate dataset statistics if they have not been calculated yet. The
  # statistics include information such as the standard deviation for scalar
  # features and same-value probabilities for categorical features.
  if dataset_name not in self._datasets_stats:
    self._calculate_stats(dataset, dataset_name)

  # Find the predicted class of the original example.
  original_pred = list(model.predict([example]))[0]

  # Find dataset examples that are flips.
  filtered_examples = self._filter_ds_examples(
      dataset=dataset,
      dataset_name=dataset_name,
      model=model,
      reference_output=original_pred,
      pred_key=pred_key,
      regression_thresh=regression_thresh)

  supported_field_names = self._find_all_fields_to_consider(
      ds_spec=dataset.spec(),
      model_input_spec=model.input_spec(),
      example=example)

  candidates: List[JsonDict] = []

  # Iterate through all possible feature combinations.
  combs = utils.find_all_combinations(supported_field_names, 1, max_flips)
  for comb in combs:
    # Sort all dataset examples with respect to the given combination.
    sorted_examples = self._sort_and_filter_examples(
        examples=filtered_examples,
        ref_example=example,
        fields=comb,
        dataset=dataset,
        dataset_name=dataset_name)
    if not sorted_examples:
      continue

    # As an optimization trick, check whether the farthest example is a
    # flip. If it is not a flip then skip the current combination of
    # features. This optimization makes the minimum-set guarantees weaker
    # but significantly improves the search speed.
    flip = self._find_hot_flip(
        ref_example=example,
        ds_example=sorted_examples[-1],
        features_to_consider=comb,
        model=model,
        target_pred=original_pred,
        pred_key=pred_key,
        dataset=dataset,
        interpolate=False,
        regression_threshold=regression_thresh)
    if not flip:
      logging.info('Skipped combination %s', comb)
      continue

    # Iterate through the sorted examples until the first flip is found.
    # TODO(b/204200758): improve performance by batching the predict
    # requests.
    for ds_example in sorted_examples:
      flip = self._find_hot_flip(
          ref_example=example,
          ds_example=ds_example,
          features_to_consider=comb,
          model=model,
          target_pred=original_pred,
          pred_key=pred_key,
          dataset=dataset,
          interpolate=True,
          regression_threshold=regression_thresh)
      if flip:
        self._add_if_not_strictly_worse(
            example=flip,
            other_examples=candidates,
            ref_example=example,
            dataset=dataset,
            dataset_name=dataset_name,
            model=model)
        break

    if len(candidates) >= num_examples:
      break

  # Calculate distances for the found hot flips.
  candidate_tuples = []
  for flip_example in candidates:
    distance, diff_fields = self._calculate_L1_distance(
        example_1=example,
        example_2=flip_example,
        dataset=dataset,
        dataset_name=dataset_name,
        model=model)
    if distance > 0:
      candidate_tuples.append((distance, diff_fields, flip_example))

  # Order the candidates by their distance to the given example.
  candidate_tuples.sort(key=lambda e: e[0])

  if len(candidate_tuples) > num_examples:
    candidate_tuples = candidate_tuples[0:num_examples]

  # Each tuple is (distance, diff_fields, flip_example); return only the
  # examples.
  return [e[2] for e in candidate_tuples]
def _find_hot_flip(
    self,
    ref_example: JsonDict,
    ds_example: JsonDict,
    features_to_consider: List[Text],
    model: lit_model.Model,
    target_pred: JsonDict,
    pred_key: Text,
    dataset: lit_dataset.Dataset,
    interpolate: bool,
    regression_threshold: Optional[float] = None,
) -> Optional[JsonDict]:
  """Finds a hot-flip example for a given target example and dataset example.

  Args:
    ref_example: target example for which the counterfactuals should be
      found.
    ds_example: a dataset example that should be used as a starting point
      for the search.
    features_to_consider: the list of feature keys that can be changed
      during the search.
    model: model to use for getting predictions.
    target_pred: model prediction that corresponds to `ref_example`.
    pred_key: the name of the field in model predictions that contains the
      prediction value for the counterfactual search.
    dataset: a dataset object that contains `ds_example`.
    interpolate: if True, the method tries to find a closer counterfactual
      using interpolation.
    regression_threshold: the threshold to use if `model` is a regression
      model. This parameter is ignored for classification models.

  Returns:
    A hot-flip counterfactual that satisfies the criteria.
  """
  # All features other than `features_to_consider` should be assigned the
  # value of the target example.
  candidate_example = ds_example.copy()
  for field_name in ref_example:
    if (field_name not in features_to_consider and
        field_name in model.input_spec()):
      candidate_example[field_name] = ref_example[field_name]

  flip, predicted_value = self._is_flip(
      model=model,
      cf_example=candidate_example,
      orig_output=target_pred,
      pred_key=pred_key,
      regression_thresh=regression_threshold)

  if not flip:
    return None

  # Find the closest flip by moving scalar values closer to the target.
  closest_flip = None
  if interpolate:
    closest_flip = self._find_closer_flip_using_interpolation(
        ref_example, candidate_example, target_pred, pred_key, model,
        dataset, regression_threshold)

  # If we found a closer flip through interpolation then use it,
  # otherwise use the previously found flip.
  if closest_flip is not None:
    return closest_flip
  else:
    self._find_dataset_parent_and_set(
        model_output_spec=model.output_spec(),
        pred_key=pred_key,
        dataset_spec=dataset.spec(),
        example=candidate_example,
        predicted_value=predicted_value)
    return candidate_example
def _find_closer_flip_using_interpolation(
    self,
    ref_example: JsonDict,
    known_flip: JsonDict,
    target_pred: JsonDict,
    pred_key: Text,
    model: lit_model.Model,
    dataset: lit_dataset.Dataset,
    regression_threshold: Optional[float] = None,
    max_attempts: int = 4) -> Optional[JsonDict]:
  """Looks for the decision boundary between two examples via interpolation.

  The method searches for a flip that is closer to `ref_example` than
  `known_flip`. It performs a binary search by interpolating scalar values.

  Args:
    ref_example: an example for which the flip is searched.
    known_flip: an example that represents a known flip.
    target_pred: the model prediction at `ref_example`.
    pred_key: the name of the field inside `target_pred` that holds the
      prediction value.
    model: model to use for running predictions.
    dataset: dataset that contains `known_flip`.
    regression_threshold: threshold to use for regression models.
    max_attempts: number of binary search attempts.

  Returns:
    The counterfactual (flip) if found; None otherwise.
  """
  min_alpha = 0.0
  max_alpha = 1.0
  closest_flip = None
  input_spec = model.input_spec()
  has_scalar = False
  for _ in range(max_attempts):
    # Interpolate the scalar values using binary search.
    current_alpha = (min_alpha + max_alpha) / 2
    candidate = known_flip.copy()
    for field in ref_example:
      if (field in candidate and field in input_spec and
          isinstance(input_spec[field], lit_types.Scalar) and
          candidate[field] is not None and
          ref_example[field] is not None):
        candidate[field] = known_flip[field] * (
            1 - current_alpha) + ref_example[field] * current_alpha
        has_scalar = True
    # The interpolation makes sense only for scalar values. If there are no
    # scalar fields that can be interpolated then terminate the search.
    if not has_scalar:
      return None
    flip, predicted_value = self._is_flip(
        model=model,
        cf_example=candidate,
        orig_output=target_pred,
        pred_key=pred_key,
        regression_thresh=regression_threshold)
    if flip:
      self._find_dataset_parent_and_set(
          model_output_spec=model.output_spec(),
          pred_key=pred_key,
          dataset_spec=dataset.spec(),
          example=candidate,
          predicted_value=predicted_value)
      closest_flip = candidate
      min_alpha = current_alpha
    else:
      max_alpha = current_alpha
  return closest_flip
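# Note on the search above: alpha is the interpolation weight toward
# ref_example, starting at 0.5. On a flip, min_alpha is raised to keep
# searching closer to ref_example; otherwise max_alpha is lowered, halving
# the interval on each of the max_attempts iterations.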