def annotate(self, inputs: List[JsonDict],
             dataset: lit_dataset.Dataset,
             dataset_spec_to_annotate: Optional[types.Spec] = None):
  if len(self._annotator_model.input_spec().items()) != 1:
    raise ValueError(
        'Annotator model provided to PerFieldAnnotator does not '
        'operate on a single field')

  datasets = {}
  for input_name, input_type in self._annotator_model.input_spec().items():
    # Remap inputs to the field name expected by the annotator model.
    ds_keys = utils.find_spec_keys(dataset.spec(), type(input_type))
    for ds_key in ds_keys:
      temp_ds = lit_dataset.Dataset(examples=inputs, base=dataset)
      datasets[ds_key] = temp_ds.remap({ds_key: input_name})

  for ds_key, ds in datasets.items():
    # Materialize predictions so they can be iterated once per output field.
    outputs = list(self._annotator_model.predict(ds.examples))
    output_spec = self._annotator_model.output_spec()
    for output_name, output_type in output_spec.items():
      # Update dataset spec with the new annotated field.
      field_name = f'{self._name}:{output_name}:{ds_key}'
      if dataset_spec_to_annotate:
        dataset_spec_to_annotate[field_name] = attr.evolve(
            output_type, annotated=True)

      # Update all examples with the annotator output.
      for example, output in zip(inputs, outputs):
        example[field_name] = output[output_name]
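# A hedged sketch of an annotator model satisfying the single-field contract
# enforced above. For brevity this is a plain class relying on duck typing; a
# real annotator model would subclass lit_nlp.api.model.Model. The class and
# field names are illustrative, not from the original codebase.
class CharCountAnnotatorModel:
  """Maps one TextSegment input field to a character-count output."""

  def input_spec(self):
    return {'text': lit_types.TextSegment()}

  def output_spec(self):
    return {'num_chars': lit_types.Scalar()}

  def predict(self, inputs):
    return [{'num_chars': len(ex['text'])} for ex in inputs]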
def _run_annotators(self,
                    dataset: lit_dataset.Dataset) -> lit_dataset.Dataset:
  datapoints = [dict(ex) for ex in dataset.examples]
  annotated_spec = dict(dataset.spec())
  for annotator in self._annotators:
    annotator.annotate(datapoints, dataset, annotated_spec)
  return lit_dataset.Dataset(
      base=dataset, examples=datapoints, spec=annotated_spec)
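# A hedged sketch of the Annotator contract that _run_annotators relies on:
# annotate() mutates both the datapoint dicts and the spec dict in place,
# which is why fresh copies are made above. Class and field names here are
# illustrative, not from the original codebase.
class TextLengthAnnotator:
  def annotate(self, inputs, dataset, dataset_spec_to_annotate=None):
    if dataset_spec_to_annotate is not None:
      dataset_spec_to_annotate['length:text'] = lit_types.Scalar()
    for example in inputs:
      example['length:text'] = len(example['text'])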
def symmetrize_edges(dataset: lit_dataset.Dataset) -> lit_dataset.Dataset:
  """Symmetrize edges by adding copies with span1 and span2 interchanged."""

  def _swap(edge):
    return lit_dtypes.EdgeLabel(edge.span2, edge.span1, edge.label)

  edge_fields = utils.find_spec_keys(dataset.spec(), lit_types.EdgeLabels)
  examples = []
  for ex in dataset.examples:
    new_ex = copy.copy(ex)
    for field in edge_fields:
      # Build a new list rather than using +=, which would extend the list
      # still shared with the original example through the shallow copy.
      new_ex[field] = ex[field] + [_swap(edge) for edge in ex[field]]
    examples.append(new_ex)
  return lit_dataset.Dataset(dataset.spec(), examples)
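# A toy check of symmetrize_edges, assuming the EdgeLabel(span1, span2, label)
# constructor order used in _swap above; the spec and example values are
# illustrative only.
spec = {
    'tokens': lit_types.Tokens(),
    'edges': lit_types.EdgeLabels(align='tokens'),
}
example = {
    'tokens': ['a', 'b'],
    'edges': [lit_dtypes.EdgeLabel((0, 1), (1, 2), 'rel')],
}
symmetrized = symmetrize_edges(lit_dataset.Dataset(spec, [example]))
# The single edge gains a mirrored copy with span1 and span2 swapped.
assert len(symmetrized.examples[0]['edges']) == 2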
def test_interpreter(self):
  interpreter = image_gradient_maps.VanillaGradients()
  model = ClassificationTestModel()
  self.assertTrue(interpreter.is_compatible(model))

  pil_image = PILImage.new(mode='RGB', size=(300, 200))
  inp = {'image': image_utils.convert_pil_to_image_str(pil_image)}
  run_output = interpreter.run(
      inputs=[inp], model=model, dataset=lit_dataset.Dataset())[0]
  self.assertIn('grads', run_output)
  self.assertIsInstance(run_output['grads'], str)
def test_gradient_maps(self):
  self.ig = gradient_maps.IntegratedGradients()

  # Basic test with dummy outputs from the model.
  inputs = [{'segment': '_'}]
  model = testing_utils.TestModelClassification()
  dataset = lit_dataset.Dataset(None, None)
  output = self.ig.run(inputs, model, dataset)

  self.assertLen(output, 1)
  salience = output[0]['input_embs_grad'].salience
  target = np.array([0.25, 0.25, 0.25, 0.25])
  self.assertTrue((salience == target).all())
def _annotate_new_data(self,
                       data,
                       dataset_name: Optional[Text] = None,
                       **unused_kw) -> List[IndexedInput]:
  """Fill in index and other extra data for the provided datapoints."""
  # TODO(lit-dev): unify this with hash fn on dataset objects.
  assert dataset_name is not None, 'No dataset specified.'

  # Generate annotated versions of the new datapoints.
  dataset = self._datasets[dataset_name]
  input_examples = [example['data'] for example in data['inputs']]
  dataset_to_annotate = lit_dataset.Dataset(
      base=dataset, examples=input_examples)
  annotated_dataset = self._run_annotators(dataset_to_annotate)

  # Add annotations and IDs to the new datapoints.
  for i, example in enumerate(data['inputs']):
    example['data'] = annotated_dataset.examples[i]
    example['id'] = caching.input_hash(example['data'])

  return data['inputs']
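# A hedged sketch of the payload shape _annotate_new_data expects, inferred
# from how the function indexes into `data`; the literal values and dataset
# name are illustrative only.
#
# data = {
#     'inputs': [
#         {'data': {'segment': 'a new datapoint'}},  # 'id' gets filled in.
#     ]
# }
# indexed = self._annotate_new_data(data, dataset_name='my_dataset')
# # Each returned entry now carries annotated 'data' plus a content-hash 'id'.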
def test_regression(self):
  interpreter = image_gradient_maps.GuidedIG()
  model = RegressionTestModel()
  self.assertTrue(interpreter.is_compatible(model))

  input_image_array = np.zeros(shape=[20, 15, 3], dtype=np.uint8)
  input_image_array[0, 0, 0] = 10
  input_image_array[1, 0, 0] = 20
  pil_image = PILImage.fromarray(input_image_array, mode='RGB')
  inp = {'image': image_utils.convert_pil_to_image_str(pil_image)}
  run_output = interpreter.run(
      inputs=[inp], model=model, dataset=lit_dataset.Dataset())[0]
  self.assertIn('grads', run_output)

  overlay_str = run_output['grads']
  overlay_array = image_utils.convert_image_str_to_array(
      overlay_str, shape=RegressionTestModel.GRADIENT_SHAPE)
  self.assertIsNotNone(overlay_array)
  self.assertSequenceEqual(overlay_array.shape,
                           RegressionTestModel.GRADIENT_SHAPE)
def test_remap(self):
  """Test remap method."""
  spec = {
      "score": types.Scalar(),
      "text": types.TextSegment(),
  }
  datapoints = [
      {"score": 0, "text": "a"},
      {"score": 0, "text": "b"},
  ]
  dset = lit_dataset.Dataset(spec, datapoints)
  remap_dict = {"score": "val", "nothing": "nada"}
  remapped_dset = dset.remap(remap_dict)
  self.assertIn("val", remapped_dset.spec())
  self.assertNotIn("score", remapped_dset.spec())
  self.assertEqual({"val": 0, "text": "a"}, remapped_dset.examples[0])
def create_train_dataset(config: Config) -> lit_dataset.Dataset:
  src_path = config.exp_dir / "train.src.txt"
  trg_path = config.exp_dir / "train.trg.txt"
  default_src_iso = config.default_src_iso
  default_trg_iso = config.default_trg_iso
  examples: List[lit_types.JsonDict] = []
  with src_path.open("r", encoding="utf-8") as src_file, trg_path.open(
      "r", encoding="utf-8") as trg_file:
    for src_line, trg_line in zip(src_file, trg_file):
      src_line = src_line.strip()
      trg_line = trg_line.strip()
      src_iso = default_src_iso
      if len(config.src_isos) > 1:
        src_iso = "?"
      trg_iso = default_trg_iso
      if src_line.startswith("<2"):
        index = src_line.index(">")
        val = src_line[2:index]
        if val != "qaa":
          trg_iso = val
      example: lit_types.JsonDict = {
          "vref": "?",
          "src_text": decode_sp(src_line),
          "ref_text": decode_sp(trg_line),
          "src_iso": src_iso,
          "trg_iso": trg_iso,
      }
      examples.append(example)
      if len(examples) == 2000:
        break

  spec: lit_types.JsonDict = {
      "vref": lit_types.CategoryLabel(),
      "src_text": lit_types.TextSegment(),
      "ref_text": lit_types.TextSegment(),
      "src_iso": lit_types.CategoryLabel(),
      "trg_iso": lit_types.CategoryLabel(),
  }
  return lit_dataset.Dataset(spec, examples, description="train dataset")
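# The "<2xx>" handling above, restated as a standalone helper for clarity.
# This is an editorial sketch, not part of the original module; it assumes
# "qaa" is a placeholder code that should not override the default target ISO.
def parse_trg_iso(src_line: str, default_trg_iso: str) -> str:
  if src_line.startswith("<2"):
    val = src_line[2:src_line.index(">")]
    if val != "qaa":
      return val
  return default_trg_iso

# parse_trg_iso("<2es> hola mundo", "en")  # -> "es"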
def create_test_dataset(config: Config) -> lit_dataset.Dataset:
  vref_file_names: List[str] = []
  features_file_names: List[str] = []
  refs_patterns: List[str] = []
  for src_iso in sorted(config.src_isos):
    prefix = "test" if len(config.src_isos) == 1 else f"test.{src_iso}"
    features_file_name = f"{prefix}.src.txt"
    if (config.exp_dir / features_file_name).is_file():
      # All target data is stored in a single file.
      vref_file_names.append(f"{prefix}.vref.txt")
      features_file_names.append(features_file_name)
      refs_patterns.append(f"{prefix}.trg.detok*.txt")
    else:
      # Target data is split into separate files.
      for trg_iso in sorted(config.trg_isos):
        prefix = f"test.{src_iso}.{trg_iso}"
        vref_file_names.append(f"{prefix}.vref.txt")
        features_file_names.append(f"{prefix}.src.txt")
        refs_patterns.append(f"{prefix}.trg.detok*.txt")

  default_src_iso = config.default_src_iso
  default_trg_iso = config.default_trg_iso
  spec: lit_types.JsonDict = {
      "vref": lit_types.CategoryLabel(),
      "src_text": lit_types.TextSegment(),
      "ref_text": lit_types.TextSegment(),
      "src_iso": lit_types.CategoryLabel(),
      "trg_iso": lit_types.CategoryLabel(),
  }
  examples: List[lit_types.JsonDict] = []
  for vref_file_name, features_file_name, refs_pattern in zip(
      vref_file_names, features_file_names, refs_patterns):
    src_iso = default_src_iso
    if features_file_name != "test.src.txt":
      src_iso = features_file_name.split(".")[1]
    with (config.exp_dir / features_file_name).open(
        "r", encoding="utf-8") as src_file, (
            config.exp_dir / vref_file_name).open(
                "r", encoding="utf-8") as vref_file:
      ref_file_paths = config.exp_dir.glob(refs_pattern)
      ref_files: List[IO] = []
      try:
        for ref_file_path in ref_file_paths:
          ref_files.append(ref_file_path.open("r", encoding="utf-8"))
        for lines in zip(src_file, vref_file, *ref_files):
          src_line = lines[0].strip()
          vref_line = lines[1].strip()
          trg_iso = default_trg_iso
          if src_line.startswith("<2"):
            index = src_line.index(">")
            val = src_line[2:index]
            if val != "qaa":
              trg_iso = val
          example: lit_types.JsonDict = {
              "vref": vref_line,
              "src_text": decode_sp(src_line),
              "src_iso": src_iso,
              "trg_iso": trg_iso,
          }
          for ref_index in range(len(ref_files)):
            ref_line = lines[ref_index + 2].strip()
            ref_key = "ref_text" if ref_index == 0 else f"ref_text_{ref_index}"
            example[ref_key] = ref_line
            if ref_key not in spec:
              spec[ref_key] = lit_types.TextSegment()
          examples.append(example)
      finally:
        for ref_file in ref_files:
          ref_file.close()
  return lit_dataset.Dataset(spec, examples, description="test dataset")
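# The multi-reference key convention built above, isolated as a helper: the
# first reference file maps to "ref_text", additional ones to "ref_text_1",
# "ref_text_2", and so on. This is an editorial restatement, not original code.
def ref_key_for(ref_index: int) -> str:
  return "ref_text" if ref_index == 0 else f"ref_text_{ref_index}"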
def test_tcav(self):
  random.seed(0)  # Sets seed since create_comparison_splits() uses random.

  # Basic test with dummy outputs from the model.
  examples = [
      {'segment': 'a'},
      {'segment': 'b'},
      {'segment': 'c'},
      {'segment': 'd'},
      {'segment': 'e'},
      {'segment': 'f'},
      {'segment': 'g'},
      {'segment': 'h'},
  ]
  indexed_inputs = [
      {'id': '1', 'data': {'segment': 'a'}},
      {'id': '2', 'data': {'segment': 'b'}},
      {'id': '3', 'data': {'segment': 'c'}},
      {'id': '4', 'data': {'segment': 'd'}},
      {'id': '5', 'data': {'segment': 'e'}},
      {'id': '6', 'data': {'segment': 'f'}},
      {'id': '7', 'data': {'segment': 'g'}},
      {'id': '8', 'data': {'segment': 'h'}},
      {'id': '9', 'data': {'segment': 'i'}},
  ]
  model = TestModelClassificationTCAV()
  dataset_spec = {'segment': lit_types.TextSegment()}
  dataset = lit_dataset.Dataset(dataset_spec, examples)
  config = {
      'concept_set_ids': ['1', '3', '4', '8'],
      'class_to_explain': '1',
      'grad_layer': 'cls_grad',
      'random_state': 0
  }
  result = self.tcav.run_with_metadata(
      indexed_inputs, model, dataset, config=config)

  self.assertLen(result, 1)
  expected = {
      'p_val': 0.0,
      'result': {
          'score': 1.0,
          'cos_sim': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
          'dot_prods': [1.6669444907484283] * 9,
          'accuracy': 0.3333333333333333
      }
  }
  self.assertDictEqual(expected, result[0])
def test_tcav_sample_from_positive(self):
  # Tests the case where more concept examples are passed than non-concept
  # examples, so the concept set is sampled from the concept examples.
  random.seed(0)  # Sets seed since create_comparison_splits() uses random.

  # Basic test with dummy outputs from the model.
  examples = [
      {'segment': 'a'},
      {'segment': 'b'},
      {'segment': 'c'},
      {'segment': 'd'},
      {'segment': 'e'},
      {'segment': 'f'},
      {'segment': 'g'},
      {'segment': 'h'},
  ]
  indexed_inputs = [
      {'id': '1', 'data': {'segment': 'a'}},
      {'id': '2', 'data': {'segment': 'b'}},
      {'id': '3', 'data': {'segment': 'c'}},
      {'id': '4', 'data': {'segment': 'd'}},
      {'id': '5', 'data': {'segment': 'e'}},
      {'id': '6', 'data': {'segment': 'f'}},
      {'id': '7', 'data': {'segment': 'g'}},
      {'id': '8', 'data': {'segment': 'h'}},
  ]
  model = TestModelClassificationTCAV()
  dataset_spec = {'segment': lit_types.TextSegment()}
  dataset = lit_dataset.Dataset(dataset_spec, examples)
  config = {
      'concept_set_ids': ['1', '3', '4', '5', '8'],
      'class_to_explain': '1',
      'grad_layer': 'cls_grad',
      'random_state': 0
  }
  result = self.tcav.run_with_metadata(
      indexed_inputs, model, dataset, config=config)

  self.assertLen(result, 1)
  expected = {
      'p_val': 0.0,
      'result': {
          'score': 1.0,
          'cos_sim': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
          'dot_prods': [2.0589251447995237e-14] * 8,
          'accuracy': 0.5
      }
  }
  self.assertDictEqual(expected, result[0])
def load(self, path: str):
  datapoints = self.load_datapoints(path)
  return lit_dataset.Dataset(base=self, examples=datapoints)
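# A hedged usage sketch of load(): it delegates to a load_datapoints() method
# defined elsewhere on the class, and returns a new Dataset that inherits this
# dataset's spec via base=. The path below is illustrative only.
#
# extended = my_dataset.load('/tmp/saved_datapoints.jsonl')
# assert extended.spec() == my_dataset.spec()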
def test_all_replacements(self):
  input_spec = {'text': lit_types.TextSegment()}
  model = testing_utils.TestRegressionModel(input_spec)
  # Dataset is only used for spec in word_replacer, so define it once.
  dataset = lit_dataset.Dataset(input_spec, [{'text': 'blank'}])

  ## Test replacements
  generator = word_replacer.WordReplacer()

  # Unicode to Unicode
  input_dict = {'text': '♞ is a black chess knight.'}
  config_dict = {
      word_replacer.SUBSTITUTIONS_KEY: '♞ -> ♟',
      word_replacer.FIELDS_TO_REPLACE_KEY: ['text'],
  }
  expected = [{'text': '♟ is a black chess knight.'}]
  self.assertEqual(
      generator.generate(input_dict, model, dataset, config=config_dict),
      expected)

  # Unicode to ASCII
  input_dict = {'text': 'Is répertoire a unicode word?'}
  config_dict = {
      word_replacer.SUBSTITUTIONS_KEY: 'répertoire -> repertoire',
      word_replacer.FIELDS_TO_REPLACE_KEY: ['text'],
  }
  expected = [{'text': 'Is repertoire a unicode word?'}]
  self.assertEqual(
      generator.generate(input_dict, model, dataset, config=config_dict),
      expected)

  # Ignore capitalization
  input_dict = {'text': 'Capitalization is ignored.'}
  config_dict = {
      word_replacer.SUBSTITUTIONS_KEY: 'Capitalization -> blank',
      word_replacer.FIELDS_TO_REPLACE_KEY: ['text'],
  }
  expected = [{'text': 'blank is ignored.'}]
  self.assertEqual(
      generator.generate(input_dict, model, dataset, config=config_dict),
      expected)

  input_dict = {'text': 'Capitalization is ignored.'}
  config_dict = {
      word_replacer.SUBSTITUTIONS_KEY: 'capitalization -> blank',
      word_replacer.FIELDS_TO_REPLACE_KEY: ['text'],
  }
  expected = [{'text': 'blank is ignored.'}]
  self.assertEqual(
      generator.generate(input_dict, model, dataset, config=config_dict),
      expected)

  # Do not ignore capitalization
  input_dict = {'text': 'Capitalization is important.'}
  config_dict = {
      word_replacer.SUBSTITUTIONS_KEY: 'Capitalization -> blank',
      word_replacer.IGNORE_CASING_KEY: False,
      word_replacer.FIELDS_TO_REPLACE_KEY: ['text'],
  }
  expected = [{'text': 'blank is important.'}]
  self.assertEqual(
      generator.generate(input_dict, model, dataset, config=config_dict),
      expected)

  input_dict = {'text': 'Capitalization is important.'}
  config_dict = {
      word_replacer.SUBSTITUTIONS_KEY: 'capitalization -> blank',
      word_replacer.IGNORE_CASING_KEY: False,
      word_replacer.FIELDS_TO_REPLACE_KEY: ['text'],
  }
  expected = []
  self.assertEqual(
      generator.generate(input_dict, model, dataset, config=config_dict),
      expected)

  # Repetition
  input_dict = {'text': 'maybe repetition repetition maybe'}
  config_dict = {
      word_replacer.SUBSTITUTIONS_KEY: 'repetition -> blank',
      word_replacer.FIELDS_TO_REPLACE_KEY: ['text'],
  }
  expected = [{'text': 'maybe blank repetition maybe'},
              {'text': 'maybe repetition blank maybe'}]
  self.assertCountEqual(
      generator.generate(input_dict, model, dataset, config=config_dict),
      expected)

  # No partial match
  input_dict = {'text': 'A catastrophic storm'}
  config_dict = {
      word_replacer.SUBSTITUTIONS_KEY: 'cat -> blank',
      word_replacer.FIELDS_TO_REPLACE_KEY: ['text'],
  }
  expected = []
  self.assertEqual(
      generator.generate(input_dict, model, dataset, config=config_dict),
      expected)

  ## Special characters
  # Punctuation
  input_dict = {'text': 'A catastrophic storm .'}
  config_dict = {
      word_replacer.SUBSTITUTIONS_KEY: '. -> -',
      word_replacer.FIELDS_TO_REPLACE_KEY: ['text'],
  }
  expected = [{'text': 'A catastrophic storm -'}]
  self.assertEqual(
      generator.generate(input_dict, model, dataset, config=config_dict),
      expected)

  input_dict = {'text': 'A.catastrophic. storm'}
  config_dict = {
      word_replacer.SUBSTITUTIONS_KEY: '. -> -',
      word_replacer.FIELDS_TO_REPLACE_KEY: ['text'],
  }
  expected = [{'text': 'A-catastrophic. storm'},
              {'text': 'A.catastrophic- storm'}]
  self.assertEqual(
      generator.generate(input_dict, model, dataset, config=config_dict),
      expected)

  input_dict = {'text': 'A...catastrophic.... storm'}
  config_dict = {
      word_replacer.SUBSTITUTIONS_KEY: '.. -> --',
      word_replacer.FIELDS_TO_REPLACE_KEY: ['text'],
  }
  expected = [{'text': 'A--.catastrophic.... storm'},
              {'text': 'A...catastrophic--.. storm'},
              {'text': 'A...catastrophic..-- storm'}]
  self.assertEqual(
      generator.generate(input_dict, model, dataset, config=config_dict),
      expected)

  # Underscore
  input_dict = {'text': 'A catastrophic_storm is raging.'}
  config_dict = {
      word_replacer.SUBSTITUTIONS_KEY: 'catastrophic_storm -> nice_storm',
      word_replacer.FIELDS_TO_REPLACE_KEY: ['text'],
  }
  expected = [{'text': 'A nice_storm is raging.'}]
  self.assertEqual(
      generator.generate(input_dict, model, dataset, config=config_dict),
      expected)

  # Deletion
  input_dict = {'text': 'A storm is raging.'}
  config_dict = {
      word_replacer.SUBSTITUTIONS_KEY: 'storm -> ',
      word_replacer.FIELDS_TO_REPLACE_KEY: ['text'],
  }
  expected = [{'text': 'A is raging.'}]
  self.assertEqual(
      generator.generate(input_dict, model, dataset, config=config_dict),
      expected)

  # Word next to punctuation and words with punctuation.
  input_dict = {'text': 'It`s raining cats and dogs.'}
  config_dict = {
      word_replacer.SUBSTITUTIONS_KEY: 'dogs -> blank',
      word_replacer.FIELDS_TO_REPLACE_KEY: ['text'],
  }
  expected = [{'text': 'It`s raining cats and blank.'}]
  self.assertEqual(
      generator.generate(input_dict, model, dataset, config=config_dict),
      expected)

  # Multiple target tokens.
  input_dict = {'text': 'It`s raining cats and dogs.'}
  config_dict = {
      word_replacer.SUBSTITUTIONS_KEY: 'dogs -> horses|donkeys',
      word_replacer.FIELDS_TO_REPLACE_KEY: ['text'],
  }
  expected = [{'text': 'It`s raining cats and horses.'},
              {'text': 'It`s raining cats and donkeys.'}]
  self.assertEqual(
      generator.generate(input_dict, model, dataset, config=config_dict),
      expected)

  ## Test default_replacements applied at init.
  replacements = {'tree': ['car']}
  generator = word_replacer.WordReplacer(replacements=replacements)
  input_dict = {'text': 'black truck hit the tree'}
  expected = [{'text': 'black truck hit the car'}]
  config_dict = {
      word_replacer.FIELDS_TO_REPLACE_KEY: ['text'],
  }
  self.assertEqual(
      generator.generate(input_dict, model, dataset, config=config_dict),
      expected)

  ## Test that passing no replacements does not break.
  generator = word_replacer.WordReplacer()
  input_dict = {'text': 'xyz yzy zzz.'}
  expected = []
  self.assertEqual(generator.generate(input_dict, model, dataset), expected)

  # Multi-word match.
  input_dict = {'text': 'A red cat is coming.'}
  config_dict = {
      word_replacer.SUBSTITUTIONS_KEY: 'red cat -> black dog',
      word_replacer.FIELDS_TO_REPLACE_KEY: ['text'],
  }
  expected = [{'text': 'A black dog is coming.'}]
  self.assertEqual(
      generator.generate(input_dict, model, dataset, config=config_dict),
      expected)
def test_all_replacements(self):
  input_spec = {'text': lit_types.TextSegment()}
  model = testing_utils.TestRegressionModel(input_spec)
  # Dataset is only used for spec in word_replacer, so define it once.
  # Note: examples must be a list of dicts, not a bare dict.
  dataset = lit_dataset.Dataset(input_spec, [{'text': 'blank'}])

  ## Test replacements
  generator = word_replacer.WordReplacer()

  # Unicode to Unicode
  input_dict = {'text': '♞ is a black chess knight.'}
  config_dict = {'subs': '♞ -> ♟'}
  expected = [{'text': '♟ is a black chess knight.'}]
  self.assertEqual(
      generator.generate(input_dict, model, dataset, config=config_dict),
      expected)

  # Unicode to ASCII
  input_dict = {'text': 'Is répertoire a unicode word?'}
  config_dict = {'subs': 'répertoire -> repertoire'}
  expected = [{'text': 'Is repertoire a unicode word?'}]
  self.assertEqual(
      generator.generate(input_dict, model, dataset, config=config_dict),
      expected)

  # Capitalization
  input_dict = {'text': 'Capitalization is important.'}
  config_dict = {'subs': 'Capitalization -> blank'}
  expected = [{'text': 'blank is important.'}]
  self.assertEqual(
      generator.generate(input_dict, model, dataset, config=config_dict),
      expected)

  input_dict = {'text': 'Capitalization is important.'}
  config_dict = {'subs': 'capitalization -> blank'}
  expected = []
  self.assertEqual(
      generator.generate(input_dict, model, dataset, config=config_dict),
      expected)

  # Repetition
  input_dict = {'text': 'maybe repetition repetition maybe'}
  config_dict = {'subs': 'repetition -> blank'}
  expected = [{'text': 'maybe blank repetition maybe'},
              {'text': 'maybe repetition blank maybe'}]
  self.assertCountEqual(
      generator.generate(input_dict, model, dataset, config=config_dict),
      expected)

  # No partial match
  input_dict = {'text': 'A catastrophic storm'}
  config_dict = {'subs': 'cat -> blank'}
  expected = []
  self.assertEqual(
      generator.generate(input_dict, model, dataset, config=config_dict),
      expected)

  ## Special characters
  # Punctuation
  input_dict = {'text': 'A catastrophic storm .'}
  config_dict = {'subs': '. -> -'}
  expected = [{'text': 'A catastrophic storm -'}]
  self.assertEqual(
      generator.generate(input_dict, model, dataset, config=config_dict),
      expected)

  # Underscore
  input_dict = {'text': 'A catastrophic_storm is raging.'}
  config_dict = {'subs': 'catastrophic_storm -> nice_storm'}
  expected = [{'text': 'A nice_storm is raging.'}]
  self.assertEqual(
      generator.generate(input_dict, model, dataset, config=config_dict),
      expected)

  # Word next to punctuation and words with punctuation.
  input_dict = {'text': 'It`s raining cats and dogs.'}
  config_dict = {'subs': 'dogs -> blank'}
  expected = [{'text': 'It`s raining cats and blank.'}]
  self.assertEqual(
      generator.generate(input_dict, model, dataset, config=config_dict),
      expected)

  ## Test default_replacements applied at init.
  replacements = {'tree': 'car'}
  generator = word_replacer.WordReplacer(replacements=replacements)
  input_dict = {'text': 'black truck hit the tree'}
  expected = [{'text': 'black truck hit the car'}]
  self.assertEqual(generator.generate(input_dict, model, dataset), expected)

  ## Test that passing no replacements does not break.
  generator = word_replacer.WordReplacer()
  input_dict = {'text': 'xyz yzy zzz.'}
  expected = []
  self.assertEqual(generator.generate(input_dict, model, dataset), expected)
def annotate_generated(datapoints):
  dataset_to_annotate = lit_dataset.Dataset(
      base=dataset, examples=datapoints)
  annotated_dataset = self._run_annotators(dataset_to_annotate)
  return annotated_dataset.examples
def test_clustering(self):
  inputs = [
      {'data': {'segment': 'a b c d'}},
      {'data': {'segment': 'a b c d'}},
      {'data': {'segment': 'e f e f'}},
      {'data': {'segment': 'e f e f'}},
      {'data': {'segment': 'e f e f'}},
  ]
  model = testing_utils.TestModelClassification()
  dataset = lit_dataset.Dataset(None, None)
  config = {'salience_mapper': 'grad-l2', 'n_clusters': 2}
  model_outputs = [
      {
          'input_embs_grad':
              np.array([[1, 1, 1, 1], [0, 1, 0, 1], [1, 1, 1, 1],
                        [1, 1, 1, 1]]),
          'tokens': ['a', 'b', 'c', 'd'],
          'grad_class': '1'
      },
      {
          'input_embs_grad':
              np.array([[1, 1, 1, 1], [0, 1, 0, 1], [1, 1, 1, 1],
                        [1, 1, 1, 1]]),
          'tokens': ['a', 'b', 'c', 'd'],
          'grad_class': '1'
      },
      {
          'input_embs_grad':
              np.array([[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1],
                        [1, 1, 1, 1]]),
          'tokens': ['e', 'f', 'e', 'f'],
          'grad_class': '1'
      },
      {
          'input_embs_grad':
              np.array([[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1],
                        [1, 1, 1, 1]]),
          'tokens': ['e', 'f', 'e', 'f'],
          'grad_class': '1'
      },
      {
          'input_embs_grad':
              np.array([[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1],
                        [1, 1, 1, 1]]),
          'tokens': ['e', 'f', 'e', 'f'],
          'grad_class': '1'
      },
  ]
  clustering_component = salience_clustering.SalienceClustering(
      self.salience_mappers)
  result = clustering_component.run_with_metadata(inputs, model, dataset,
                                                  model_outputs, config)

  # Cluster id assignment is random, so in one run the first two examples may
  # be in cluster 0, and in the next run they may be in cluster 1.
  cluster_ids = result[salience_clustering.CLUSTER_ID_KEY]['input_embs_grad']
  cluster_id_of_first = cluster_ids[0]
  cluster_id_of_last = cluster_ids[-1]
  np.testing.assert_equal(cluster_ids, [
      cluster_id_of_first, cluster_id_of_first, cluster_id_of_last,
      cluster_id_of_last, cluster_id_of_last
  ])

  representations = result[
      salience_clustering.REPRESENTATION_KEY]['input_embs_grad']
  np.testing.assert_allclose(representations[0], representations[1])
  np.testing.assert_allclose(representations[2], representations[3])
  np.testing.assert_allclose(representations[2], representations[4])

  self.assertIn('input_embs_grad', clustering_component.kmeans)
  self.assertIsNotNone(clustering_component.kmeans['input_embs_grad'])