コード例 #1
0
    def test_count_vocab_items_correctly_indexes_tags(self):
        """count_vocab_items should tally each tag under the field's namespace only."""
        sequence_tags = ["B", "I", "O", "O", "O"]
        field = TagField(sequence_tags, self.text, tag_namespace="tags")

        namespace_counts = defaultdict(lambda: defaultdict(int))
        field.count_vocab_items(namespace_counts)

        # No namespace other than "tags" should have been touched.
        assert set(namespace_counts.keys()) == {"tags"}
        for tag, expected_count in {"B": 1, "I": 1, "O": 3}.items():
            assert namespace_counts["tags"][tag] == expected_count
コード例 #2
0
    def test_count_vocab_items_correctly_indexes_tags(self):
        """Counting vocab items should record every tag under the "tags" namespace."""
        sentence = TextField(["here", "are", "some", "words", "."],
                             [token_indexers["single id"]("words")])
        field = TagField(["B", "I", "O", "O", "O"], sentence, tag_namespace="tags")

        counts = defaultdict(lambda: defaultdict(int))
        field.count_vocab_items(counts)

        assert counts["tags"]["B"] == 1
        assert counts["tags"]["I"] == 1
        assert counts["tags"]["O"] == 3
        # Only the "tags" namespace should exist in the counter.
        assert set(counts.keys()) == {"tags"}
コード例 #3
0
    def test_index_converts_field_correctly(self):
        """Indexing a TagField should map each tag to its id in the vocabulary."""
        vocab = Vocabulary()
        index_of = {tag: vocab.add_token_to_namespace(tag, namespace='*tags')
                    for tag in ("B", "I", "O")}

        sequence_tags = ["B", "I", "O", "O", "O"]
        field = TagField(sequence_tags, self.text, tag_namespace="*tags")
        field.index(vocab)

        # pylint: disable=protected-access
        assert field._indexed_tags == [index_of[tag] for tag in sequence_tags]
        assert field._num_tags == 3
コード例 #4
0
    def read(self, file_path):
        """Read one sentence per line from ``file_path`` into a Dataset.

        Each non-blank line holds delimiter-separated tokens, where every
        token is a word/tag pair joined by ``self._word_tag_delimiter``.
        Raises ConfigurationError if the file yields no instances.
        """
        instances = []
        with open(file_path, "r") as data_file:
            logger.info("Reading instances from lines in file at: %s",
                        file_path)
            for line in tqdm.tqdm(data_file):
                stripped = line.strip("\n")

                # Blank lines carry no annotation; skip them.
                if not stripped:
                    continue

                # rsplit with maxsplit=1 so the delimiter may also appear
                # inside the word itself.
                pairs = [token.rsplit(self._word_tag_delimiter, 1)
                         for token in stripped.split(self._token_delimiter)]
                words = [pair[0] for pair in pairs]
                labels = [pair[1] for pair in pairs]

                text_field = TextField(words, self._token_indexers)
                tag_field = TagField(labels, text_field)
                instances.append(Instance({'tokens': text_field,
                                           'tags': tag_field}))
        if not instances:
            raise ConfigurationError(
                "No instances were read from the given filepath {}. "
                "Is the path correct?".format(file_path))
        return Dataset(instances)
コード例 #5
0
    def _process_sentence(
            self, sentence: List[str], verbal_predicates: List[int],
            predicate_argument_labels: List[List[str]]) -> List[Instance]:
        """
        Build one Instance per annotated verbal predicate in ``sentence``.

        Parameters
        ----------
        sentence : List[str], required.
            The tokenised sentence.
        verbal_predicates : List[int], required.
            Indexes of the verbal predicates in the sentence that carry an
            annotation.
        predicate_argument_labels : List[List[str]], required.
            One label sequence (length ``len(sentence)``) per verbal predicate.

        Returns
        -------
        A list of Instances.

        """
        text_field = TextField(sentence, self._token_indexers)

        def build_instance(verb_index, labels):
            # Every instance shares the same token field; only the verb
            # indicator and tag sequence differ.
            return Instance(
                fields={
                    "tokens": text_field,
                    "verb_indicator": IndexField(verb_index, text_field),
                    "tags": TagField(labels, text_field)
                })

        if not verbal_predicates:
            # Sentence contains no predicates: emit a single all-"O"
            # instance with no verb marked.
            return [build_instance(None, ["O" for _ in sentence])]

        return [build_instance(verb_index, labels)
                for verb_index, labels in zip(verbal_predicates,
                                              predicate_argument_labels)]
コード例 #6
0
    def test_index_converts_field_correctly(self):
        """After index(), _indexed_tags holds vocab ids and _num_tags the namespace size."""
        vocab = Vocabulary()
        id_of = {}
        for tag in ("B", "I", "O"):
            id_of[tag] = vocab.add_token_to_namespace(tag, namespace='*tags')

        sentence = TextField(["here", "are", "some", "words", "."],
                             [token_indexers["single id"]("words")])
        field = TagField(["B", "I", "O", "O", "O"], sentence, tag_namespace="*tags")
        field.index(vocab)

        # pylint: disable=protected-access
        assert field._indexed_tags == [id_of["B"], id_of["I"], id_of["O"],
                                       id_of["O"], id_of["O"]]
        assert field._num_tags == 3
コード例 #7
0
    def read(self, file_path):
        """Read a Dataset of tagged sentences from ``file_path``.

        Each non-blank line is one sentence: tokens separated by tabs,
        each token of the form ``word###tag``.

        Returns
        -------
        A Dataset of Instances, each with a 'tokens' TextField and a
        'tags' TagField.
        """
        with open(file_path, "r") as data_file:

            instances = []
            for line in data_file:
                line = line.strip("\n")

                # Skip blank lines (e.g. a trailing newline at EOF):
                # "".split("\t") yields [""], whose "###"-split has no
                # second element, so x[1] below would raise IndexError.
                if not line:
                    continue

                tokens_and_tags = [
                    pair.split("###") for pair in line.split("\t")
                ]
                tokens = [x[0] for x in tokens_and_tags]
                tags = [x[1] for x in tokens_and_tags]

                sequence = TextField(tokens, self._token_indexers)
                sequence_tags = TagField(tags, sequence)
                instances.append(
                    Instance({
                        'tokens': sequence,
                        'tags': sequence_tags
                    }))
        return Dataset(instances)
コード例 #8
0
ファイル: tag_field_test.py プロジェクト: Taekyoon/allennlp
    def test_pad_produces_one_hot_targets(self):
        """as_array() should one-hot encode each indexed tag, one row per token."""
        vocab = Vocabulary()
        for tag in ("B", "I", "O"):
            vocab.add_token_to_namespace(tag, namespace='*tags')

        field = TagField(["B", "I", "O", "O", "O"], self.text, tag_namespace="*tags")
        field.index(vocab)
        one_hot = field.as_array(field.get_padding_lengths())
        expected = numpy.array([[1, 0, 0],
                                [0, 1, 0],
                                [0, 0, 1],
                                [0, 0, 1],
                                [0, 0, 1]])
        numpy.testing.assert_array_almost_equal(one_hot, expected)
コード例 #9
0
    def test_pad_produces_one_hot_targets(self):
        """pad() should return a one-hot matrix with one row per tag."""
        vocab = Vocabulary()
        for tag in ["B", "I", "O"]:
            vocab.add_token_to_namespace(tag, namespace='*tags')

        sentence = TextField(["here", "are", "some", "words", "."],
                             [token_indexers["single id"]("words")])
        field = TagField(["B", "I", "O", "O", "O"], sentence, tag_namespace="*tags")
        field.index(vocab)
        padded = field.pad(field.get_padding_lengths())
        numpy.testing.assert_array_almost_equal(
            padded,
            numpy.array([[1, 0, 0],
                         [0, 1, 0],
                         [0, 0, 1],
                         [0, 0, 1],
                         [0, 0, 1]]))
コード例 #10
0
 def test_tag_length_mismatch_raises(self):
     """Three tags for a five-token sentence must raise ConfigurationError."""
     with pytest.raises(ConfigurationError):
         TagField(["B", "O", "O"],
                  TextField(["here", "are", "some", "words", "."], []))
コード例 #11
0
 def test_tag_length_mismatch_raises(self):
     """Too few tags for the shared text fixture must raise ConfigurationError."""
     with pytest.raises(ConfigurationError):
         TagField(["B", "O", "O"], self.text)