def test_find_all_combinations_min_is_negative(self):
  """A negative min_element_count behaves as a minimum size of 1."""
  values = [1, 2, 3, 4]
  result = utils.find_all_combinations(
      values, min_element_count=-1, max_element_count=2)
  # All 1- and 2-element combinations, in enumeration order.
  expected = [[1], [2], [3], [4],
              [1, 2], [1, 3], [1, 4], [2, 3], [2, 4], [3, 4]]
  self.assertListEqual(result, expected)
def test_find_all_combinations_max_is_greater_than_len(self):
  """A max_element_count above len(values) is capped at len(values)."""
  values = [1, 2, 3, 4]
  result = utils.find_all_combinations(
      values, min_element_count=2, max_element_count=10)
  # Combinations of sizes 2, 3 and 4 — nothing larger exists.
  expected = [
      [1, 2], [1, 3], [1, 4], [2, 3], [2, 4], [3, 4],
      [1, 2, 3], [1, 2, 4], [1, 3, 4], [2, 3, 4],
      [1, 2, 3, 4],
  ]
  self.assertListEqual(result, expected)
def test_find_all_combinations_min_is_greater_than_max(self):
  """No combinations exist when min_element_count exceeds max_element_count."""
  values = [1, 2, 3, 4]
  result = utils.find_all_combinations(
      values, min_element_count=3, max_element_count=2)
  self.assertListEqual(result, [])
def generate(self,
             example: JsonDict,
             model: lit_model.Model,
             dataset: lit_dataset.Dataset,
             config: Optional[JsonDict] = None) -> List[JsonDict]:
  """Finds dataset-grounded counterfactual ("hot-flip") examples.

  Searches the dataset for examples that flip the model's prediction for
  `example`, trying feature subsets of increasing size (up to `max_flips`
  features changed) and keeping the candidates closest to the original
  example by L1 distance.

  Args:
    example: The input example to generate counterfactuals for.
    model: The model whose prediction should be flipped. Must expose a
      MulticlassPreds or RegressionScore output for the configured key.
    dataset: The dataset to draw candidate counterfactuals from.
    config: Optional settings: number of examples to return, maximum number
      of features to flip, prediction key, regression threshold, and the
      (required) dataset name.

  Returns:
    Up to `num_examples` flip examples, ordered by ascending L1 distance
    from `example`.

  Raises:
    ValueError: If the model, dataset name, or prediction key is missing or
      invalid, or the model is neither a classifier nor a regressor.
  """
  # --- Validation and configuration. ---
  if not model:
    raise ValueError('Please provide a model for this generator.')
  config = config or {}
  num_examples = int(config.get(NUM_EXAMPLES_KEY, NUM_EXAMPLES_DEFAULT))
  max_flips = int(config.get(MAX_FLIPS_KEY, MAX_FLIPS_DEFAULT))
  pred_key = config.get(PREDICTION_KEY, '')
  regression_thresh = float(
      config.get(REGRESSION_THRESH_KEY, REGRESSION_THRESH_DEFAULT))
  dataset_name = config.get('dataset_name')
  if not dataset_name:
    raise ValueError('The dataset name must be in the config.')
  output_spec = model.output_spec()
  if not pred_key:
    raise ValueError('Please provide the prediction key.')
  if pred_key not in output_spec:
    raise ValueError('Invalid prediction key.')
  if not isinstance(output_spec[pred_key],
                    (lit_types.MulticlassPreds, lit_types.RegressionScore)):
    raise ValueError(
        'Only classification and regression models are supported')

  # Calculate dataset statistics if it has never been calculated. The
  # statistics include such information as 'standard deviation' for scalar
  # features and probabilities for categorical features.
  if dataset_name not in self._datasets_stats:
    self._calculate_stats(dataset, dataset_name)

  # Prediction for the original example, used as the flip reference.
  original_pred = list(model.predict([example]))[0]

  # Keep only dataset examples whose prediction differs from the original.
  filtered_examples = self._filter_ds_examples(
      dataset=dataset,
      dataset_name=dataset_name,
      model=model,
      reference_output=original_pred,
      pred_key=pred_key,
      regression_thresh=regression_thresh)

  supported_field_names = self._find_all_fields_to_consider(
      ds_spec=dataset.spec(),
      model_input_spec=model.input_spec(),
      example=example)

  candidates: List[JsonDict] = []

  # Iterate through all feature subsets of size 1..max_flips.
  field_combinations = utils.find_all_combinations(
      supported_field_names, 1, max_flips)
  for fields in field_combinations:
    # Sort all dataset examples with respect to the given combination.
    sorted_examples = self._sort_and_filter_examples(
        examples=filtered_examples,
        ref_example=example,
        fields=fields,
        dataset=dataset,
        dataset_name=dataset_name)
    if not sorted_examples:
      continue

    # As an optimization trick, check whether the farthest example is a
    # flip. If it is not a flip then skip the current combination of
    # features. This optimization makes the minimum set guarantees weaker
    # but significantly improves the search speed.
    flipped = self._find_hot_flip(
        ref_example=example,
        ds_example=sorted_examples[-1],
        features_to_consider=fields,
        model=model,
        target_pred=original_pred,
        pred_key=pred_key,
        dataset=dataset,
        interpolate=False,
        regression_threshold=regression_thresh)
    if not flipped:
      logging.info('Skipped combination %s', fields)
      continue

    # Iterate through the sorted examples until the first flip is found.
    # TODO(b/204200758): improve performance by batching the predict
    # requests.
    for ds_example in sorted_examples:
      flipped = self._find_hot_flip(
          ref_example=example,
          ds_example=ds_example,
          features_to_consider=fields,
          model=model,
          target_pred=original_pred,
          pred_key=pred_key,
          dataset=dataset,
          interpolate=True,
          regression_threshold=regression_thresh)
      if flipped:
        self._add_if_not_strictly_worse(
            example=flipped,
            other_examples=candidates,
            ref_example=example,
            dataset=dataset,
            dataset_name=dataset_name,
            model=model)
        break

    if len(candidates) >= num_examples:
      break

  # Score each found flip by its L1 distance to the original example,
  # dropping any that are identical (distance 0).
  scored = []
  for flip_example in candidates:
    distance, diff_fields = self._calculate_L1_distance(
        example_1=example,
        example_2=flip_example,
        dataset=dataset,
        dataset_name=dataset_name,
        model=model)
    if distance > 0:
      scored.append((distance, diff_fields, flip_example))

  # Closest flips first; return at most num_examples of them.
  scored.sort(key=lambda entry: entry[0])
  return [entry[2] for entry in scored[:num_examples]]