예제 #1
0
파일: annotators.py 프로젝트: PAIR-code/lit
    def annotate(self,
                 inputs: List[JsonDict],
                 dataset: lit_dataset.Dataset,
                 dataset_spec_to_annotate: Optional[types.Spec] = None):
        """Annotate ``inputs`` in place with the annotator model's outputs.

        For every dataset field whose spec type matches the annotator model's
        single input field, the model is run on that field and each of its
        outputs is stored on the examples under the key
        ``'<annotator name>:<output name>:<dataset field>'``.

        Args:
          inputs: examples to annotate; each dict is mutated in place.
          dataset: dataset whose spec is searched for fields to annotate.
          dataset_spec_to_annotate: optional spec dict that is updated in place
            with an entry (marked ``annotated=True``) per new field.

        Raises:
          ValueError: if the annotator model does not take exactly one input
            field.
        """
        input_spec = self._annotator_model.input_spec()
        if len(input_spec) != 1:
            raise ValueError(
                'Annotator model provided to PerFieldAnnotator does not '
                'operate on a single field')

        datasets = {}
        for input_name, input_type in input_spec.items():
            # Do remap of inputs based on input name needed by annotator.
            ds_keys = utils.find_spec_keys(dataset.spec(), type(input_type))
            for ds_key in ds_keys:
                temp_ds = lit_dataset.Dataset(examples=inputs, base=dataset)
                datasets[ds_key] = temp_ds.remap({ds_key: input_name})

        for ds_key, ds in datasets.items():
            # Materialize the predictions: they are zipped once per output
            # field below, so a one-shot iterable would be exhausted after the
            # first field.
            outputs = list(self._annotator_model.predict(ds.examples))
            for output_name, output_type in self._annotator_model.output_spec(
            ).items():
                # Update dataset spec with new annotated field.
                field_name = f'{self._name}:{output_name}:{ds_key}'
                # Compare against None so that an (initially) empty spec dict
                # still receives the new fields.
                if dataset_spec_to_annotate is not None:
                    dataset_spec_to_annotate[field_name] = attr.evolve(
                        output_type, annotated=True)

                # Update all examples with annotator output.
                for example, output in zip(inputs, outputs):
                    example[field_name] = output[output_name]
예제 #2
0
파일: app.py 프로젝트: PAIR-code/lit
 def _run_annotators(self,
                     dataset: lit_dataset.Dataset) -> lit_dataset.Dataset:
     """Return a new dataset with every registered annotator applied.

     The examples and the spec are copied first, so the annotators write
     their new fields into the copies rather than into ``dataset``'s own
     example dicts or spec.
     """
     examples_copy = list(map(dict, dataset.examples))
     spec_copy = dict(dataset.spec())
     for current_annotator in self._annotators:
         current_annotator.annotate(examples_copy, dataset, spec_copy)
     return lit_dataset.Dataset(
         base=dataset, examples=examples_copy, spec=spec_copy)
예제 #3
0
def symmetrize_edges(dataset: lit_dataset.Dataset) -> lit_dataset.Dataset:
    """Symmetrize edges by adding copies with span1 and span2 interchanged.

    Args:
      dataset: dataset whose ``EdgeLabels`` fields should be symmetrized.

    Returns:
      A new dataset with the same spec, in which every ``EdgeLabels`` field
      holds the original edges followed by their span-swapped copies. The
      examples of the input dataset are left untouched.
    """
    def _swap(edge):
        return lit_dtypes.EdgeLabel(edge.span2, edge.span1, edge.label)

    edge_fields = utils.find_spec_keys(dataset.spec(), lit_types.EdgeLabels)
    examples = []
    for ex in dataset.examples:
        # Shallow copy: the field values are still shared with ``ex``.
        new_ex = copy.copy(ex)
        for field in edge_fields:
            # Build a fresh list instead of using ``+=``, which would extend
            # the list shared with the original example in place.
            edges = list(ex[field])
            new_ex[field] = edges + [_swap(edge) for edge in edges]
        examples.append(new_ex)
    return lit_dataset.Dataset(dataset.spec(), examples)
예제 #4
0
    def test_interpreter(self):
        """VanillaGradients accepts the test model and emits a 'grads' string."""
        grads_interpreter = image_gradient_maps.VanillaGradients()
        test_model = ClassificationTestModel()
        self.assertTrue(grads_interpreter.is_compatible(test_model))

        # A 300x200 RGB image, encoded the way the interpreter receives it.
        blank_image = PILImage.new(mode='RGB', size=(300, 200))
        example = {
            'image': image_utils.convert_pil_to_image_str(blank_image),
        }
        all_results = grads_interpreter.run(
            inputs=[example],
            model=test_model,
            dataset=lit_dataset.Dataset())
        first_result = all_results[0]

        self.assertIn('grads', first_result)
        self.assertIsInstance(first_result['grads'], str)
예제 #5
0
    def test_gradient_maps(self):
        """IntegratedGradients yields uniform salience for the dummy model."""
        self.ig = gradient_maps.IntegratedGradients()

        # Basic test with dummy outputs from the model.
        classifier = testing_utils.TestModelClassification()
        empty_dataset = lit_dataset.Dataset(None, None)
        results = self.ig.run([{'segment': '_'}], classifier, empty_dataset)

        self.assertLen(results, 1)

        actual_salience = results[0]['input_embs_grad'].salience
        expected_salience = np.array([0.25] * 4)
        self.assertTrue((actual_salience == expected_salience).all())
예제 #6
0
파일: app.py 프로젝트: PAIR-code/lit
    def _annotate_new_data(self,
                           data,
                           dataset_name: Optional[Text] = None,
                           **unused_kw) -> List[IndexedInput]:
        """Annotate the datapoints in ``data['inputs']`` and assign their IDs.

        Each entry's ``'data'`` is replaced by its annotated version and its
        ``'id'`` is set to the hash of that annotated example. The entries are
        updated in place and the same list is returned.
        """
        # TODO(lit-dev): unify this with hash fn on dataset objects.
        assert dataset_name is not None, 'No dataset specified.'

        # Run the annotators over the raw examples of the new datapoints.
        base_dataset = self._datasets[dataset_name]
        indexed_inputs = data['inputs']
        raw_examples = [entry['data'] for entry in indexed_inputs]
        annotated = self._run_annotators(
            lit_dataset.Dataset(base=base_dataset, examples=raw_examples))

        # Write the annotated example and its ID back onto each datapoint.
        for position, entry in enumerate(indexed_inputs):
            annotated_example = annotated.examples[position]
            entry['data'] = annotated_example
            entry['id'] = caching.input_hash(annotated_example)

        return indexed_inputs
예제 #7
0
    def test_regression(self):
        """GuidedIG on the regression model returns a decodable 'grads' image."""
        guided_ig = image_gradient_maps.GuidedIG()
        regression_model = RegressionTestModel()
        self.assertTrue(guided_ig.is_compatible(regression_model))

        # Mostly-black 20x15 RGB image with two non-zero red values.
        pixels = np.zeros(shape=[20, 15, 3], dtype=np.uint8)
        pixels[0, 0, 0] = 10
        pixels[1, 0, 0] = 20
        image_str = image_utils.convert_pil_to_image_str(
            PILImage.fromarray(pixels, mode='RGB'))

        all_results = guided_ig.run(
            inputs=[{'image': image_str}],
            model=regression_model,
            dataset=lit_dataset.Dataset())
        first_result = all_results[0]
        self.assertIn('grads', first_result)

        # The overlay must decode into an array of the model's gradient shape.
        decoded_overlay = image_utils.convert_image_str_to_array(
            first_result['grads'], shape=RegressionTestModel.GRADIENT_SHAPE)
        self.assertIsNotNone(decoded_overlay)
        self.assertSequenceEqual(decoded_overlay.shape,
                                 RegressionTestModel.GRADIENT_SHAPE)
예제 #8
0
 def test_remap(self):
     """Dataset.remap renames mapped fields in both the spec and the examples."""
     original = lit_dataset.Dataset(
         {"score": types.Scalar(), "text": types.TextSegment()},
         [{"score": 0, "text": letter} for letter in ("a", "b")],
     )
     # "nothing" is not a field of the dataset.
     renamed = original.remap({"score": "val", "nothing": "nada"})
     self.assertIn("val", renamed.spec())
     self.assertNotIn("score", renamed.spec())
     self.assertEqual({"val": 0, "text": "a"}, renamed.examples[0])
예제 #9
0
파일: analyze.py 프로젝트: sillsdev/silnlp
def create_train_dataset(config: Config,
                         max_examples: int = 2000) -> lit_dataset.Dataset:
    """Build a LIT dataset from the experiment's training files.

    Reads ``train.src.txt`` and ``train.trg.txt`` from ``config.exp_dir`` in
    lockstep and turns each line pair into one example.

    Args:
      config: experiment configuration; provides ``exp_dir``, ``src_isos``,
        ``default_src_iso`` and ``default_trg_iso``.
      max_examples: maximum number of line pairs to load (default 2000).

    Returns:
      A dataset with ``vref``, ``src_text``, ``ref_text``, ``src_iso`` and
      ``trg_iso`` fields. ``vref`` is always ``"?"``.

    Raises:
      ValueError: if a source line starts with ``"<2"`` but contains no
        closing ``">"``.
    """
    src_path = config.exp_dir / "train.src.txt"
    trg_path = config.exp_dir / "train.trg.txt"
    default_trg_iso = config.default_trg_iso
    # The source language does not depend on the line, so resolve it once:
    # with several source languages it is reported as unknown ("?").
    src_iso = "?" if len(config.src_isos) > 1 else config.default_src_iso

    examples: List[lit_types.JsonDict] = []
    with src_path.open("r", encoding="utf-8") as src_file, trg_path.open(
            "r", encoding="utf-8") as trg_file:
        for src_line, trg_line in zip(src_file, trg_file):
            if len(examples) >= max_examples:
                break
            src_line = src_line.strip()
            trg_line = trg_line.strip()
            trg_iso = default_trg_iso
            # A leading "<2xxx>" tag on the source line overrides the target
            # language, unless the tag value is "qaa".
            if src_line.startswith("<2"):
                tag = src_line[2:src_line.index(">")]
                if tag != "qaa":
                    trg_iso = tag
            examples.append({
                "vref": "?",
                "src_text": decode_sp(src_line),
                "ref_text": decode_sp(trg_line),
                "src_iso": src_iso,
                "trg_iso": trg_iso,
            })

    spec: lit_types.JsonDict = {
        "vref": lit_types.CategoryLabel(),
        "src_text": lit_types.TextSegment(),
        "ref_text": lit_types.TextSegment(),
        "src_iso": lit_types.CategoryLabel(),
        "trg_iso": lit_types.CategoryLabel(),
    }
    return lit_dataset.Dataset(spec, examples, description="train dataset")
예제 #10
0
파일: analyze.py 프로젝트: sillsdev/silnlp
def create_test_dataset(config: Config) -> lit_dataset.Dataset:
    """Build a LIT dataset from the experiment's test files.

    For every source language the function looks for ``<prefix>.src.txt`` in
    ``config.exp_dir``; if that file does not exist it falls back to one set
    of files per (source, target) language pair. Source lines, verse
    references and all matching reference files are read in lockstep.

    Args:
      config: experiment configuration; provides ``exp_dir``, ``src_isos``,
        ``trg_isos``, ``default_src_iso`` and ``default_trg_iso``.

    Returns:
      A dataset with ``vref``, ``src_text``, ``src_iso`` and ``trg_iso``
      fields plus one reference field per reference file: ``ref_text`` for
      the first and ``ref_text_<n>`` for the following ones.

    Raises:
      ValueError: if a source line starts with ``"<2"`` but contains no
        closing ``">"``.
    """
    vref_file_names: List[str] = []
    features_file_names: List[str] = []
    refs_patterns: List[str] = []
    for src_iso in sorted(config.src_isos):
        prefix = "test" if len(config.src_isos) == 1 else f"test.{src_iso}"
        features_file_name = f"{prefix}.src.txt"
        if (config.exp_dir / features_file_name).is_file():
            # all target data is stored in a single file
            vref_file_names.append(f"{prefix}.vref.txt")
            features_file_names.append(features_file_name)
            refs_patterns.append(f"{prefix}.trg.detok*.txt")
        else:
            # target data is split into separate files
            for trg_iso in sorted(config.trg_isos):
                prefix = f"test.{src_iso}.{trg_iso}"
                vref_file_names.append(f"{prefix}.vref.txt")
                features_file_names.append(f"{prefix}.src.txt")
                refs_patterns.append(f"{prefix}.trg.detok*.txt")

    default_src_iso = config.default_src_iso
    default_trg_iso = config.default_trg_iso
    # This must be an annotation (":"), not a chained assignment ("="): the
    # latter would overwrite the ``lit_types.JsonDict`` module attribute.
    spec: lit_types.JsonDict = {
        "vref": lit_types.CategoryLabel(),
        "src_text": lit_types.TextSegment(),
        "ref_text": lit_types.TextSegment(),
        "src_iso": lit_types.CategoryLabel(),
        "trg_iso": lit_types.CategoryLabel(),
    }
    examples: List[lit_types.JsonDict] = []
    for vref_file_name, features_file_name, refs_pattern in zip(
            vref_file_names, features_file_names, refs_patterns):
        # Any file other than the plain "test.src.txt" carries the source
        # language as the second dot-separated component of its name.
        src_iso = default_src_iso
        if features_file_name != "test.src.txt":
            src_iso = features_file_name.split(".")[1]

        with (config.exp_dir / features_file_name).open(
                "r", encoding="utf-8") as src_file, (
                    config.exp_dir / vref_file_name).open(
                        "r", encoding="utf-8") as vref_file:
            ref_file_paths = config.exp_dir.glob(refs_pattern)
            ref_files: List[IO] = []
            try:
                for ref_file_path in ref_file_paths:
                    ref_files.append(ref_file_path.open("r", encoding="utf-8"))
                for lines in zip(src_file, vref_file, *ref_files):
                    src_line = lines[0].strip()
                    vref_line = lines[1].strip()
                    trg_iso = default_trg_iso
                    # A leading "<2xxx>" tag overrides the target language,
                    # unless the tag value is "qaa".
                    if src_line.startswith("<2"):
                        index = src_line.index(">")
                        val = src_line[2:index]
                        if val != "qaa":
                            trg_iso = val
                    example: lit_types.JsonDict = {
                        "vref": vref_line,
                        "src_text": decode_sp(src_line),
                        "src_iso": src_iso,
                        "trg_iso": trg_iso,
                    }
                    # Everything after the source and vref lines is one line
                    # per reference file, in the order the files were opened.
                    for ref_index, raw_ref_line in enumerate(lines[2:]):
                        ref_key = ("ref_text" if ref_index == 0 else
                                   f"ref_text_{ref_index}")
                        example[ref_key] = raw_ref_line.strip()
                        if ref_key not in spec:
                            spec[ref_key] = lit_types.TextSegment()
                    examples.append(example)
            finally:
                # Close every reference file that was opened, even on error.
                for ref_file in ref_files:
                    ref_file.close()

    return lit_dataset.Dataset(spec, examples, description="test dataset")
예제 #11
0
  def test_tcav(self):
    """TCAV on the dummy model with a four-example concept set."""
    random.seed(0)  # Sets seed since create_comparison_splits() uses random.

    # Basic test with dummy outputs from the model.
    # The dataset holds segments 'a'..'h'; the indexed inputs additionally
    # contain 'i' and carry the ids '1'..'9'.
    examples = [{'segment': letter} for letter in 'abcdefgh']
    indexed_inputs = [
        {'id': str(number), 'data': {'segment': letter}}
        for number, letter in enumerate('abcdefghi', start=1)
    ]
    model = TestModelClassificationTCAV()
    dataset = lit_dataset.Dataset(
        {'segment': lit_types.TextSegment()}, examples)
    config = {
        'concept_set_ids': ['1', '3', '4', '8'],
        'class_to_explain': '1',
        'grad_layer': 'cls_grad',
        'random_state': 0
    }
    result = self.tcav.run_with_metadata(
        indexed_inputs, model, dataset, config=config)

    self.assertLen(result, 1)
    # One cosine similarity and one dot product per indexed input.
    expected = {
        'p_val': 0.0,
        'result': {
            'score': 1.0,
            'cos_sim': [1.0] * 9,
            'dot_prods': [1.6669444907484283] * 9,
            'accuracy': 0.3333333333333333
        }
    }
    self.assertDictEqual(expected, result[0])
예제 #12
0
  def test_tcav_sample_from_positive(self):
    """TCAV when concept examples outnumber the non-concept examples."""
    # Tests the case where more concept examples are passed than non-concept
    # examples, so the concept set is sampled from the concept examples.

    random.seed(0)  # Sets seed since create_comparison_splits() uses random.

    # Basic test with dummy outputs from the model.
    # Segments 'a'..'h' serve both as dataset examples and, with ids
    # '1'..'8', as the indexed inputs.
    letters = 'abcdefgh'
    examples = [{'segment': letter} for letter in letters]
    indexed_inputs = [
        {'id': str(number), 'data': {'segment': letter}}
        for number, letter in enumerate(letters, start=1)
    ]
    model = TestModelClassificationTCAV()
    dataset = lit_dataset.Dataset(
        {'segment': lit_types.TextSegment()}, examples)
    config = {
        'concept_set_ids': ['1', '3', '4', '5', '8'],
        'class_to_explain': '1',
        'grad_layer': 'cls_grad',
        'random_state': 0
    }
    result = self.tcav.run_with_metadata(
        indexed_inputs, model, dataset, config=config)

    self.assertLen(result, 1)
    # One cosine similarity and one dot product per indexed input.
    expected = {
        'p_val': 0.0,
        'result': {
            'score': 1.0,
            'cos_sim': [1.0] * 8,
            'dot_prods': [2.0589251447995237e-14] * 8,
            'accuracy': 0.5
        }
    }
    self.assertDictEqual(expected, result[0])
예제 #13
0
 def load(self, path: str):
     """Return a new dataset based on this one, with examples read from ``path``."""
     return lit_dataset.Dataset(
         examples=self.load_datapoints(path), base=self)
예제 #14
0
  def test_all_replacements(self):
    """Run WordReplacer.generate over a table of substitution scenarios."""
    input_spec = {'text': lit_types.TextSegment()}
    model = testing_utils.TestRegressionModel(input_spec)
    # Dataset is only used for spec in word_replacer so define once
    dataset = lit_dataset.Dataset(input_spec, [{'text': 'blank'}])

    def run(replacer, text, subs=None, ignore_casing=None):
      """Call generate() on `text` with a freshly built config dict."""
      config = {}
      if subs is not None:
        config[word_replacer.SUBSTITUTIONS_KEY] = subs
      if ignore_casing is not None:
        config[word_replacer.IGNORE_CASING_KEY] = ignore_casing
      config[word_replacer.FIELDS_TO_REPLACE_KEY] = ['text']
      return replacer.generate({'text': text}, model, dataset, config=config)

    def as_examples(texts):
      return [{'text': text} for text in texts]

    ## Test replacements
    # Each row: (input text, substitution rule, ignore-casing override or
    # None, expected output texts, whether the output order is irrelevant).
    cases = [
        # Unicode to Unicode
        ('♞ is a black chess knight.', '♞ -> ♟', None,
         ['♟ is a black chess knight.'], False),
        # Unicode to ASCII
        ('Is répertoire a unicode word?', 'répertoire -> repertoire', None,
         ['Is repertoire a unicode word?'], False),
        # Ignore capitalization
        ('Capitalization is ignored.', 'Capitalization -> blank', None,
         ['blank is ignored.'], False),
        ('Capitalization is ignored.', 'capitalization -> blank', None,
         ['blank is ignored.'], False),
        # Do not Ignore capitalization
        ('Capitalization is important.', 'Capitalization -> blank', False,
         ['blank is important.'], False),
        ('Capitalization is important.', 'capitalization -> blank', False,
         [], False),
        # Repetition
        ('maybe repetition repetition maybe', 'repetition -> blank', None,
         ['maybe blank repetition maybe', 'maybe repetition blank maybe'],
         True),
        # No partial match
        ('A catastrophic storm', 'cat -> blank', None, [], False),
        ## Special characters
        # Punctuation
        ('A catastrophic storm .', '. -> -', None,
         ['A catastrophic storm -'], False),
        ('A.catastrophic. storm', '. -> -', None,
         ['A-catastrophic. storm', 'A.catastrophic- storm'], False),
        ('A...catastrophic.... storm', '.. -> --', None,
         ['A--.catastrophic.... storm',
          'A...catastrophic--.. storm',
          'A...catastrophic..-- storm'], False),
        # Underscore
        ('A catastrophic_storm is raging.',
         'catastrophic_storm -> nice_storm', None,
         ['A nice_storm is raging.'], False),
        # Deletion
        ('A storm is raging.', 'storm -> ', None,
         ['A  is raging.'], False),
        # Word next to punctuation and words with punctuation.
        ('It`s raining cats and dogs.', 'dogs -> blank', None,
         ['It`s raining cats and blank.'], False),
        # Multiple target tokens.
        ('It`s raining cats and dogs.', 'dogs -> horses|donkeys', None,
         ['It`s raining cats and horses.',
          'It`s raining cats and donkeys.'], False),
    ]
    generator = word_replacer.WordReplacer()
    for text, subs, ignore_casing, expected_texts, any_order in cases:
      compare = self.assertCountEqual if any_order else self.assertEqual
      compare(run(generator, text, subs, ignore_casing),
              as_examples(expected_texts))

    ## Test default_replacements applied at init.
    generator = word_replacer.WordReplacer(replacements={'tree': ['car']})
    self.assertEqual(
        run(generator, 'black truck hit the tree'),
        as_examples(['black truck hit the car']))

    ## Test not passing replacements not breaking.
    generator = word_replacer.WordReplacer()
    self.assertEqual(
        generator.generate({'text': 'xyz yzy zzz.'}, model, dataset), [])

    # Multi word match.
    self.assertEqual(
        run(generator, 'A red cat is coming.', 'red cat -> black dog'),
        as_examples(['A black dog is coming.']))
예제 #15
0
  def test_all_replacements(self):
    """Run WordReplacer.generate over a table of substitution scenarios."""
    input_spec = {'text': lit_types.TextSegment()}
    model = testing_utils.TestRegressionModel(input_spec)
    # Dataset is only used for spec in word_replacer so define once.
    # Examples are passed as a list of dicts, like every other Dataset
    # construction, rather than as a single bare dict.
    dataset = lit_dataset.Dataset(input_spec, [{'text': 'blank'}])

    def as_examples(texts):
      return [{'text': text} for text in texts]

    ## Test replacements
    # Each row: (input text, substitution rule, expected output texts,
    # whether the output order is irrelevant).
    cases = [
        # Unicode to Unicode
        ('♞ is a black chess knight.', '♞ -> ♟',
         ['♟ is a black chess knight.'], False),
        # Unicode to ASCII
        ('Is répertoire a unicode word?', 'répertoire -> repertoire',
         ['Is repertoire a unicode word?'], False),
        # Capitalization
        ('Capitalization is important.', 'Capitalization -> blank',
         ['blank is important.'], False),
        ('Capitalization is important.', 'capitalization -> blank',
         [], False),
        # Repetition
        ('maybe repetition repetition maybe', 'repetition -> blank',
         ['maybe blank repetition maybe', 'maybe repetition blank maybe'],
         True),
        # No partial match
        ('A catastrophic storm', 'cat -> blank', [], False),
        ## Special characters
        # Punctuation
        ('A catastrophic storm .', '. -> -',
         ['A catastrophic storm -'], False),
        # Underscore
        ('A catastrophic_storm is raging.',
         'catastrophic_storm -> nice_storm',
         ['A nice_storm is raging.'], False),
        # Word next to punctuation and words with punctuation.
        ('It`s raining cats and dogs.', 'dogs -> blank',
         ['It`s raining cats and blank.'], False),
    ]
    generator = word_replacer.WordReplacer()
    for text, subs, expected_texts, any_order in cases:
      compare = self.assertCountEqual if any_order else self.assertEqual
      compare(
          generator.generate(
              {'text': text}, model, dataset, config={'subs': subs}),
          as_examples(expected_texts))

    ## Test default_replacements applied at init.
    generator = word_replacer.WordReplacer(replacements={'tree': 'car'})
    self.assertEqual(
        generator.generate(
            {'text': 'black truck hit the tree'}, model, dataset),
        as_examples(['black truck hit the car']))

    ## Test not passing replacements not breaking.
    generator = word_replacer.WordReplacer()
    self.assertEqual(
        generator.generate({'text': 'xyz yzy zzz.'}, model, dataset), [])
예제 #16
0
파일: app.py 프로젝트: PAIR-code/lit
 def annotate_generated(datapoints):
     """Run the annotators over ``datapoints`` and return the annotated examples."""
     wrapped = lit_dataset.Dataset(examples=datapoints, base=dataset)
     return self._run_annotators(wrapped).examples
예제 #17
0
    def test_clustering(self):
        """Examples with identical gradients must share a cluster."""

        def _indexed(segment):
            return {'data': {'segment': segment}}

        def _model_output(tokens, second_row):
            # A fresh array per output, differing only in the second row.
            return {
                'input_embs_grad':
                np.array([[1, 1, 1, 1], second_row, [1, 1, 1, 1],
                          [1, 1, 1, 1]]),
                'tokens': tokens,
                'grad_class': '1'
            }

        # Two 'a b c d' examples followed by three 'e f e f' examples.
        inputs = ([_indexed('a b c d') for _ in range(2)] +
                  [_indexed('e f e f') for _ in range(3)])
        model = testing_utils.TestModelClassification()
        dataset = lit_dataset.Dataset(None, None)
        config = {'salience_mapper': 'grad-l2', 'n_clusters': 2}

        model_outputs = (
            [_model_output(['a', 'b', 'c', 'd'], [0, 1, 0, 1])
             for _ in range(2)] +
            [_model_output(['e', 'f', 'e', 'f'], [1, 1, 1, 1])
             for _ in range(3)])

        clustering_component = salience_clustering.SalienceClustering(
            self.salience_mappers)
        result = clustering_component.run_with_metadata(
            inputs, model, dataset, model_outputs, config)

        cluster_ids = result[
            salience_clustering.CLUSTER_ID_KEY]['input_embs_grad']
        representations = result[
            salience_clustering.REPRESENTATION_KEY]['input_embs_grad']

        # Cluster id assignment is random, so in one run the first 2 examples may
        # be cluster 0, in the next run they may be in cluster 1.
        first_id = cluster_ids[0]
        last_id = cluster_ids[-1]
        np.testing.assert_equal(cluster_ids,
                                [first_id] * 2 + [last_id] * 3)

        # Examples built from identical gradients get matching representations.
        for left, right in ((0, 1), (2, 3), (2, 4)):
            np.testing.assert_allclose(representations[left],
                                       representations[right])

        self.assertIn('input_embs_grad', clustering_component.kmeans)
        self.assertIsNotNone(clustering_component.kmeans['input_embs_grad'])