def test_enumerate_spans_enumerates_all_spans(self):
        tokenizer = SpacyWordSplitter(pos_tags=True)
        sentence = tokenizer.split_words(u"This is a sentence.")

        spans = span_utils.enumerate_spans(sentence)
        assert spans == [(0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (1, 1),
                         (1, 2), (1, 3), (1, 4), (2, 2), (2, 3), (2, 4),
                         (3, 3), (3, 4), (4, 4)]

        spans = span_utils.enumerate_spans(sentence,
                                           max_span_width=3,
                                           min_span_width=2)
        assert spans == [(0, 1), (0, 2), (1, 2), (1, 3), (2, 3), (2, 4),
                         (3, 4)]

        spans = span_utils.enumerate_spans(sentence,
                                           max_span_width=3,
                                           min_span_width=2,
                                           offset=20)
        assert spans == [(20, 21), (20, 22), (21, 22), (21, 23), (22, 23),
                         (22, 24), (23, 24)]

        def no_prefixed_punctuation(tokens):
            # Only include spans which don't start or end with punctuation.
            return tokens[0].pos_ != u"PUNCT" and tokens[-1].pos_ != u"PUNCT"

        spans = span_utils.enumerate_spans(
            sentence,
            max_span_width=3,
            min_span_width=2,
            filter_function=no_prefixed_punctuation)

        # No longer includes (2, 4) or (3, 4) as these include punctuation
        # as their last element.
        assert spans == [(0, 1), (0, 2), (1, 2), (1, 3), (2, 3)]
    def test_enumerate_spans_enumerates_all_spans(self):
        tokenizer = SpacyWordSplitter(pos_tags=True)
        sentence = tokenizer.split_words("This is a sentence.")

        spans = span_utils.enumerate_spans(sentence)
        assert spans == [(0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (1, 1), (1, 2),
                         (1, 3), (1, 4), (2, 2), (2, 3), (2, 4), (3, 3), (3, 4), (4, 4)]

        spans = span_utils.enumerate_spans(sentence, max_span_width=3, min_span_width=2)
        assert spans == [(0, 1), (0, 2), (1, 2), (1, 3), (2, 3), (2, 4), (3, 4)]

        spans = span_utils.enumerate_spans(sentence, max_span_width=3, min_span_width=2, offset=20)
        assert spans == [(20, 21), (20, 22), (21, 22), (21, 23), (22, 23), (22, 24), (23, 24)]

        def no_prefixed_punctuation(tokens: List[Token]):
            # Only include spans which don't start or end with punctuation.
            return tokens[0].pos_ != "PUNCT" and tokens[-1].pos_ != "PUNCT"

        spans = span_utils.enumerate_spans(sentence,
                                           max_span_width=3,
                                           min_span_width=2,
                                           filter_function=no_prefixed_punctuation)

        # No longer includes (2, 4) or (3, 4) as these include punctuation
        # as their last element.
        assert spans == [(0, 1), (0, 2), (1, 2), (1, 3), (2, 3)]
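For reference, here is a minimal sketch, based only on the behaviour the assertions above exercise, of what an enumerate_spans-style helper does: inclusive (start, end) indices, optional width bounds, an additive offset, and a filter applied to each candidate token slice. It is an assumption-based reimplementation, not AllenNLP's own code.

from typing import Callable, List, Optional, Tuple


def enumerate_spans_sketch(tokens: List,
                           offset: int = 0,
                           max_span_width: Optional[int] = None,
                           min_span_width: int = 1,
                           filter_function: Optional[Callable[[List], bool]] = None
                           ) -> List[Tuple[int, int]]:
    # Enumerate every inclusive (start, end) span whose width lies in
    # [min_span_width, max_span_width], shifted by `offset`, keeping only
    # spans whose token slice passes filter_function.
    max_span_width = max_span_width or len(tokens)
    filter_function = filter_function or (lambda span_tokens: True)
    spans = []
    for start in range(len(tokens)):
        first_end = start + min_span_width - 1
        last_end = min(start + max_span_width, len(tokens))
        for end in range(first_end, last_end):
            if filter_function(tokens[start:end + 1]):
                spans.append((start + offset, end + offset))
    return spans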
    def test_read_from_file(self):

        ptb_reader = PennTreeBankConstituencySpanDatasetReader()
        instances = ptb_reader.read('tests/fixtures/data/example_ptb.trees')

        assert len(instances) == 2

        fields = instances[0].fields
        tokens = [x.text for x in fields["tokens"].tokens]
        pos_tags = fields["pos_tags"].labels
        spans = [(x.span_start, x.span_end) for x in fields["spans"].field_list]
        span_labels = fields["span_labels"].labels

        assert tokens == ['Also', ',', 'because', 'UAL', 'Chairman', 'Stephen', 'Wolf',
                          'and', 'other', 'UAL', 'executives', 'have', 'joined', 'the',
                          'pilots', "'", 'bid', ',', 'the', 'board', 'might', 'be', 'forced',
                          'to', 'exclude', 'him', 'from', 'its', 'deliberations', 'in',
                          'order', 'to', 'be', 'fair', 'to', 'other', 'bidders', '.']
        assert pos_tags == ['RB', ',', 'IN', 'NNP', 'NNP', 'NNP', 'NNP', 'CC', 'JJ', 'NNP',
                            'NNS', 'VBP', 'VBN', 'DT', 'NNS', 'POS', 'NN', ',', 'DT', 'NN',
                            'MD', 'VB', 'VBN', 'TO', 'VB', 'PRP', 'IN', 'PRP$',
                            'NNS', 'IN', 'NN', 'TO', 'VB', 'JJ', 'TO', 'JJ', 'NNS', '.']

        assert spans == enumerate_spans(tokens)
        gold_tree = Tree.fromstring("(VROOT(S(ADVP(RB Also))(, ,)(SBAR(IN because)"
                                    "(S(NP(NP(NNP UAL)(NNP Chairman)(NNP Stephen)(NNP Wolf))"
                                    "(CC and)(NP(JJ other)(NNP UAL)(NNS executives)))(VP(VBP have)"
                                    "(VP(VBN joined)(NP(NP(DT the)(NNS pilots)(POS '))(NN bid))))))"
                                    "(, ,)(NP(DT the)(NN board))(VP(MD might)(VP(VB be)(VP(VBN "
                                    "forced)(S(VP(TO to)(VP(VB exclude)(NP(PRP him))(PP(IN from)"
                                    "(NP(PRP$ its)(NNS deliberations)))(SBAR(IN in)(NN order)(S("
                                    "VP(TO to)(VP(VB be)(ADJP(JJ fair)(PP(TO to)(NP(JJ other)(NNS "
                                    "bidders))))))))))))))(. .)))")

        assert fields["metadata"].metadata["gold_tree"] == gold_tree
        assert fields["metadata"].metadata["tokens"] == tokens

        correct_spans_and_labels = {}
        ptb_reader._get_gold_spans(gold_tree, 0, correct_spans_and_labels)
        for span, label in zip(spans, span_labels):
            if label != "NO-LABEL":
                assert correct_spans_and_labels[span] == label


        fields = instances[1].fields
        tokens = [x.text for x in fields["tokens"].tokens]
        pos_tags = fields["pos_tags"].labels
        spans = [(x.span_start, x.span_end) for x in fields["spans"].field_list]
        span_labels = fields["span_labels"].labels

        assert tokens == ['That', 'could', 'cost', 'him', 'the', 'chance',
                          'to', 'influence', 'the', 'outcome', 'and', 'perhaps',
                          'join', 'the', 'winning', 'bidder', '.']

        assert pos_tags == ['DT', 'MD', 'VB', 'PRP', 'DT', 'NN',
                            'TO', 'VB', 'DT', 'NN', 'CC', 'RB', 'VB', 'DT',
                            'VBG', 'NN', '.']

        assert spans == enumerate_spans(tokens)

        gold_tree = Tree.fromstring("(VROOT(S(NP(DT That))(VP(MD could)(VP(VB cost)(NP(PRP him))"
                                    "(NP(DT the)(NN chance)(S(VP(TO to)(VP(VP(VB influence)(NP(DT the)"
                                    "(NN outcome)))(CC and)(VP(ADVP(RB perhaps))(VB join)(NP(DT the)"
                                    "(VBG winning)(NN bidder)))))))))(. .)))")

        assert fields["metadata"].metadata["gold_tree"] == gold_tree
        assert fields["metadata"].metadata["tokens"] == tokens

        correct_spans_and_labels = {}
        ptb_reader._get_gold_spans(gold_tree, 0, correct_spans_and_labels)
        for span, label in zip(spans, span_labels):
            if label != "NO-LABEL":
                assert correct_spans_and_labels[span] == label
Example #4
    def text_to_instance(
            self,  # type: ignore
            tokens,
            pos_tags=None,
            gold_tree=None):
        u"""
        We take `pre-tokenized` input here, because we don't have a tokenizer in this class.

        Parameters
        ----------
        tokens : ``List[str]``, required.
            The tokens in a given sentence.
        pos_tags : ``List[str]``, optional (default = None).
            The POS tags for the words in the sentence.
        gold_tree : ``Tree``, optional (default = None).
            The gold parse tree to create span labels from.

        Returns
        -------
        An ``Instance`` containing the following fields:
            tokens : ``TextField``
                The tokens in the sentence.
            pos_tags : ``SequenceLabelField``
                The POS tags of the words in the sentence.
                Only returned if ``use_pos_tags`` is ``True``
            spans : ``ListField[SpanField]``
                A ListField containing all possible subspans of the
                sentence.
            span_labels : ``SequenceLabelField``, optional.
                The constituency tags for each of the possible spans, with
                respect to a gold parse tree. If a span is not contained
                within the tree, a span will have a ``NO-LABEL`` label.
            gold_tree : ``MetadataField(Tree)``
                The gold NLTK parse tree for use in evaluation.
        """
        # pylint: disable=arguments-differ
        text_field = TextField([Token(x) for x in tokens],
                               token_indexers=self._token_indexers)
        fields = {u"tokens": text_field}

        if self._use_pos_tags and pos_tags is not None:
            pos_tag_field = SequenceLabelField(pos_tags,
                                               text_field,
                                               label_namespace=u"pos")
            fields[u"pos_tags"] = pos_tag_field
        elif self._use_pos_tags:
            raise ConfigurationError(
                u"use_pos_tags was set to True but no gold pos"
                u" tags were passed to the dataset reader.")
        spans = []
        gold_labels = []

        if gold_tree is not None:
            gold_spans = {}
            self._get_gold_spans(gold_tree, 0, gold_spans)

        else:
            gold_spans = None
        for start, end in enumerate_spans(tokens):
            spans.append(SpanField(start, end, text_field))

            if gold_spans is not None:
                if (start, end) in gold_spans:
                    gold_labels.append(gold_spans[(start, end)])
                else:
                    gold_labels.append(u"NO-LABEL")

        metadata = {u"tokens": tokens}
        if gold_tree:
            metadata[u"gold_tree"] = gold_tree
        if self._use_pos_tags:
            metadata[u"pos_tags"] = pos_tags

        fields[u"metadata"] = MetadataField(metadata)

        span_list_field = ListField(spans)
        fields[u"spans"] = span_list_field
        if gold_tree is not None:
            fields[u"span_labels"] = SequenceLabelField(
                gold_labels, span_list_field)
        return Instance(fields)
    def text_to_instance(
        self,  # type: ignore
        tokens: List[str],
        pos_tags: List[str] = None,
        gold_tree: Tree = None,
    ) -> Instance:
        """
        We take `pre-tokenized` input here, because we don't have a tokenizer in this class.

        Parameters
        ----------
        tokens : ``List[str]``, required.
            The tokens in a given sentence.
        pos_tags : ``List[str]``, optional, (default = None).
            The POS tags for the words in the sentence.
        gold_tree : ``Tree``, optional (default = None).
            The gold parse tree to create span labels from.

        Returns
        -------
        An ``Instance`` containing the following fields:
            tokens : ``TextField``
                The tokens in the sentence.
            pos_tags : ``SequenceLabelField``
                The POS tags of the words in the sentence.
                Only returned if ``use_pos_tags`` is ``True``
            spans : ``ListField[SpanField]``
                A ListField containing all possible subspans of the
                sentence.
            span_labels : ``SequenceLabelField``, optional.
                The constituency tags for each of the possible spans, with
                respect to a gold parse tree. If a span is not contained
                within the tree, a span will have a ``NO-LABEL`` label.
            gold_tree : ``MetadataField(Tree)``
                The gold NLTK parse tree for use in evaluation.
        """

        if self._convert_parentheses:
            tokens = [PTB_PARENTHESES.get(token, token) for token in tokens]
        text_field = TextField([Token(x) for x in tokens],
                               token_indexers=self._token_indexers)
        fields: Dict[str, Field] = {"tokens": text_field}

        pos_namespace = self._label_namespace_prefix + self._pos_label_namespace
        if self._use_pos_tags and pos_tags is not None:
            pos_tag_field = SequenceLabelField(pos_tags,
                                               text_field,
                                               label_namespace=pos_namespace)
            fields["pos_tags"] = pos_tag_field
        elif self._use_pos_tags:
            raise ConfigurationError(
                "use_pos_tags was set to True but no gold pos"
                " tags were passed to the dataset reader.")
        spans: List[Field] = []
        gold_labels = []

        if gold_tree is not None:
            gold_spans: Dict[Tuple[int, int], str] = {}
            self._get_gold_spans(gold_tree, 0, gold_spans)

        else:
            gold_spans = None
        for start, end in enumerate_spans(tokens):
            spans.append(SpanField(start, end, text_field))

            if gold_spans is not None:
                gold_labels.append(gold_spans.get((start, end), "NO-LABEL"))

        metadata = {"tokens": tokens}
        if gold_tree:
            metadata["gold_tree"] = gold_tree
        if self._use_pos_tags:
            metadata["pos_tags"] = pos_tags

        fields["metadata"] = MetadataField(metadata)

        span_list_field: ListField = ListField(spans)
        fields["spans"] = span_list_field
        if gold_tree is not None:
            fields["span_labels"] = SequenceLabelField(
                gold_labels,
                span_list_field,
                label_namespace=self._label_namespace_prefix + "labels",
            )
        return Instance(fields)
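A hypothetical usage sketch for the method above, assuming the PennTreeBankConstituencySpanDatasetReader from these snippets and a toy NLTK tree; the constructor argument and field names are taken from the surrounding code and docstring, so treat the exact call as an assumption rather than the library's documented API.

from nltk import Tree

# Hypothetical usage sketch; the reader name and use_pos_tags argument are
# assumptions based on the snippets above.
reader = PennTreeBankConstituencySpanDatasetReader(use_pos_tags=True)
gold_tree = Tree.fromstring("(S (NP (DT The) (NN cat)) (VP (VBD sat)) (. .))")
tokens = gold_tree.leaves()                     # ['The', 'cat', 'sat', '.']
pos_tags = [tag for _, tag in gold_tree.pos()]  # ['DT', 'NN', 'VBD', '.']

instance = reader.text_to_instance(tokens, pos_tags, gold_tree)
# One label per enumerated span; spans that are not constituents in gold_tree
# get "NO-LABEL".
print(instance.fields["span_labels"].labels)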
    def text_to_instance(self,  # type: ignore
                         source_string: str,
                         gold_spans: Dict[Tuple[int, int], str],
                         scene_string: str,
                         answer: str,
                         program: str) -> Instance:
        """Turns raw source string and target string into an ``Instance``."""
        tokens = self.tokenizer.tokenize(source_string)
        word_pieces = self._get_wordpieces(source_string)
        word_pieces_tokens = ([Token('[CLS]')]
                              + [Token(wp) for wp in word_pieces]
                              + [Token('[SEP]')])

        text_field = TextField(tokens, self._token_indexers)
        wp_field = TextField(word_pieces_tokens, self._token_indexers)
        fields: Dict[str, Field] = {"tokens": text_field}

        if gold_spans is None:
            constants = self._domain_utils.get_constants(program)

        spans: List[Field] = []
        gold_labels = []

        for start, end in enumerate_spans(word_pieces):
            # Shift by 1 due to CLS token
            spans.append(SpanField(start + 1, end + 1, wp_field))

            if gold_spans is not None:
                # Shift by 1 due to CLS token
                gold_labels.append(
                    gold_spans.get((start + 1, end + 1), "NO-LABEL"))
            else:
                # Create random labels for each span so that labels would be collected. When no
                # more true labels are left, draw between NO-LABEL and span. These randomly assigned
                # labels would be ignored during training
                if constants[0]:
                    gold_labels.append(constants[0].pop())
                else:
                    rand_label = np.random.choice(a=["NO-LABEL", "span"],
                                                  size=1,
                                                  p=[0.7, 0.3])
                    gold_labels.append(rand_label[0])

        span_list_field: ListField = ListField(spans)
        fields["spans"] = span_list_field

        fields["span_labels"] = SequenceLabelField(
            gold_labels,
            span_list_field,
            label_namespace="labels",
        )

        metadata = {
            "tokens": word_pieces,
            "scene_str": scene_string,
            "answer": answer
        }
        if program:
            metadata["program"] = program
        if gold_spans:
            metadata["gold_spans"] = gold_spans

        fields["metadata"] = MetadataField(metadata)

        return Instance(fields)
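A small, self-contained illustration of the "+1 for [CLS]" index shift used above; the wordpieces here are made up for the example.

# Made-up wordpieces illustrating why spans over `word_pieces` are shifted by
# one when placed on the [CLS]-prefixed field.
word_pieces = ["what", "color", "is", "the", "ball", "?"]
wp_with_special = ["[CLS]"] + word_pieces + ["[SEP]"]

# The span (0, 1) over word_pieces ("what color") lands at (1, 2) once [CLS]
# is prepended, hence SpanField(start + 1, end + 1, wp_field) above.
start, end = 0, 1
assert wp_with_special[start + 1:end + 2] == word_pieces[start:end + 1]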
Example #7
    def text_to_instance(
            self,  # type: ignore
            tokens: List[str],
            verb_label: List[int],
            tags: List[str] = None,
            pos_tags: List[str] = None,
            gold_tree: Tree = None) -> Instance:
        """
        We take `pre-tokenized` input here, because we don't have a tokenizer in this class.

        Parameters
        ----------
        tokens : ``List[str]``, required.
            The tokens in a given sentence.
        verb_label: ``List[int]``, required
            The verb label should be a one-hot binary vector,
            the same length as the tokens, indicating the position of the verb to find arguments for.
        tags : ``List[str]``, optional (default = None).
            The SRL tags for the sentence, one per token.
        pos_tags : ``List[str]``, optional (default = None).
            The pos tags for the words in the sentence.
        gold_tree : ``Tree``, optional (default = None).
            The gold parse tree to create span labels from.

        Returns
        -------
        An ``Instance`` containing the following fields:
            tokens : ``TextField``
                The tokens in the sentence.
            pos_tags : ``SequenceLabelField``
                The pos tags of the words in the sentence.
            spans : ``ListField[SpanField]``
                A ListField containing all possible subspans of the
                sentence.
            span_labels : ``SequenceLabelField``, optional.
                The constituency tags for each of the possible spans, with
                respect to a gold parse tree. If a span is not contained
                within the tree, a span will have a ``NO-LABEL`` label.
        """
        # pylint: disable=arguments-differ

        fields: Dict[str, Field] = {}
        text_field = TextField(tokens, token_indexers=self._token_indexers)
        fields['tokens'] = text_field
        fields['verb_indicator'] = SequenceLabelField(verb_label, text_field)

        metadata: Dict[str, Any] = {}

        if tags:
            fields['tags'] = SequenceLabelField(tags, text_field)

        if pos_tags:
            pos_tag_field = SequenceLabelField(pos_tags, text_field,
                                               "pos_tags")
            fields['pos_tags'] = pos_tag_field
            metadata['pos_tags'] = True
        else:
            pos_tags = ['X' for _ in tokens]
            fields['pos_tags'] = SequenceLabelField(pos_tags, text_field,
                                                    "pos_tags")
            metadata['pos_tags'] = False

        spans: List[Field] = []
        gold_labels = []

        if gold_tree is not None:
            gold_spans_with_pos_tags: Dict[Tuple[int, int], str] = {}
            self._get_gold_spans(gold_tree, 0, gold_spans_with_pos_tags)
            gold_spans = {
                span: label
                for (span, label) in gold_spans_with_pos_tags.items()
                if "-POS" not in label
            }
        else:
            gold_spans = None

        for start, end in enumerate_spans(tokens):
            spans.append(SpanField(start, end, text_field))

            if gold_spans is not None:
                if (start, end) in gold_spans.keys():
                    gold_labels.append(gold_spans[(start, end)])
                else:
                    gold_labels.append("NO-LABEL")
            else:
                gold_labels.append("NO-LABEL")

        span_list_field: ListField = ListField(spans)
        fields['spans'] = span_list_field

        if gold_tree is not None:
            fields['span_labels'] = SequenceLabelField(gold_labels,
                                                       span_list_field,
                                                       "constituent_labels")
            metadata['span_labels'] = True
        else:
            fields['span_labels'] = SequenceLabelField(gold_labels,
                                                       span_list_field,
                                                       "constituent_labels")
            metadata['span_labels'] = False

        metadata_field = MetadataField(metadata)
        fields['metadata'] = metadata_field

        return Instance(fields)
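As a quick illustration of the one-hot verb_label format described in the docstring above, using a made-up sentence:

# Made-up sentence; the verb indicator is 1 only at the predicate position.
tokens = ["The", "cat", "sat", "on", "the", "mat", "."]
verb_index = 2  # position of the predicate "sat"
verb_label = [1 if i == verb_index else 0 for i in range(len(tokens))]
assert verb_label == [0, 0, 1, 0, 0, 0, 0]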
Example #8
    def text_to_instance(self, # type: ignore
                         tokens: List[str],
                         pos_tags: List[str] = None,
                         gold_tree: Tree = None) -> Instance:
        """
        We take `pre-tokenized` input here, because we don't have a tokenizer in this class.

        Parameters
        ----------
        tokens : ``List[str]``, required.
            The tokens in a given sentence.
        pos_tags : ``List[str]``, optional (default = None).
            The POS tags for the words in the sentence.
        gold_tree : ``Tree``, optional (default = None).
            The gold parse tree to create span labels from.

        Returns
        -------
        An ``Instance`` containing the following fields:
            tokens : ``TextField``
                The tokens in the sentence.
            pos_tags : ``SequenceLabelField``
                The POS tags of the words in the sentence.
                Only returned if ``use_pos_tags`` is ``True``
            spans : ``ListField[SpanField]``
                A ListField containing all possible subspans of the
                sentence.
            span_labels : ``SequenceLabelField``, optional.
                The constituency tags for each of the possible spans, with
                respect to a gold parse tree. If a span is not contained
                within the tree, a span will have a ``NO-LABEL`` label.
        """
        # pylint: disable=arguments-differ
        text_field = TextField([Token(x) for x in tokens], token_indexers=self._token_indexers)
        fields: Dict[str, Field] = {"tokens": text_field}

        if self._use_pos_tags and pos_tags is not None:
            pos_tag_field = SequenceLabelField(pos_tags, text_field, "pos_tags")
            fields["pos_tags"] = pos_tag_field
        elif self._use_pos_tags:
            raise ConfigurationError("use_pos_tags was set to True but no gold pos"
                                     " tags were passed to the dataset reader.")
        spans: List[Field] = []
        gold_labels = []

        if gold_tree is not None:
            gold_spans_with_pos_tags: Dict[Tuple[int, int], str] = {}
            self._get_gold_spans(gold_tree, 0, gold_spans_with_pos_tags)
            gold_spans = {span: label for (span, label)
                          in gold_spans_with_pos_tags.items() if "-POS" not in label}
        else:
            gold_spans = None
        for start, end in enumerate_spans(tokens):
            spans.append(SpanField(start, end, text_field))

            if gold_spans is not None:
                if (start, end) in gold_spans.keys():
                    gold_labels.append(gold_spans[(start, end)])
                else:
                    gold_labels.append("NO-LABEL")

        span_list_field: ListField = ListField(spans)
        fields["spans"] = span_list_field
        if gold_tree is not None:
            fields["span_labels"] = SequenceLabelField(gold_labels, span_list_field)

        return Instance(fields)
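A toy illustration of the "-POS" filtering performed above; the span-to-label mapping here is made up rather than produced by _get_gold_spans.

# Made-up mapping: POS-level entries carry a "-POS" suffix and are dropped.
gold_spans_with_pos_tags = {(0, 0): "DT-POS", (1, 1): "NN-POS",
                            (0, 1): "NP", (0, 2): "S"}
gold_spans = {span: label for span, label in gold_spans_with_pos_tags.items()
              if "-POS" not in label}
assert gold_spans == {(0, 1): "NP", (0, 2): "S"}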
Example #9
    def test_read_from_file(self):

        ptb_reader = PennTreeBankConstituencySpanDatasetReader()
        instances = ptb_reader.read(
            str(FIXTURES_ROOT / "structured_prediction" / "example_ptb.trees"))

        assert len(instances) == 2

        fields = instances[0].fields
        tokens = [x.text for x in fields["tokens"].tokens]
        pos_tags = fields["pos_tags"].labels
        spans = [(x.span_start, x.span_end)
                 for x in fields["spans"].field_list]
        span_labels = fields["span_labels"].labels

        assert tokens == [
            "Also",
            ",",
            "because",
            "UAL",
            "Chairman",
            "Stephen",
            "Wolf",
            "and",
            "other",
            "UAL",
            "executives",
            "have",
            "joined",
            "the",
            "pilots",
            "'",
            "bid",
            ",",
            "the",
            "board",
            "might",
            "be",
            "forced",
            "to",
            "exclude",
            "him",
            "from",
            "its",
            "deliberations",
            "in",
            "order",
            "to",
            "be",
            "fair",
            "to",
            "other",
            "bidders",
            ".",
        ]
        assert pos_tags == [
            "RB",
            ",",
            "IN",
            "NNP",
            "NNP",
            "NNP",
            "NNP",
            "CC",
            "JJ",
            "NNP",
            "NNS",
            "VBP",
            "VBN",
            "DT",
            "NNS",
            "POS",
            "NN",
            ",",
            "DT",
            "NN",
            "MD",
            "VB",
            "VBN",
            "TO",
            "VB",
            "PRP",
            "IN",
            "PRP$",
            "NNS",
            "IN",
            "NN",
            "TO",
            "VB",
            "JJ",
            "TO",
            "JJ",
            "NNS",
            ".",
        ]

        assert spans == enumerate_spans(tokens)
        gold_tree = Tree.fromstring(
            "(S(ADVP(RB Also))(, ,)(SBAR(IN because)"
            "(S(NP(NP(NNP UAL)(NNP Chairman)(NNP Stephen)(NNP Wolf))"
            "(CC and)(NP(JJ other)(NNP UAL)(NNS executives)))(VP(VBP have)"
            "(VP(VBN joined)(NP(NP(DT the)(NNS pilots)(POS '))(NN bid))))))"
            "(, ,)(NP(DT the)(NN board))(VP(MD might)(VP(VB be)(VP(VBN "
            "forced)(S(VP(TO to)(VP(VB exclude)(NP(PRP him))(PP(IN from)"
            "(NP(PRP$ its)(NNS deliberations)))(SBAR(IN in)(NN order)(S("
            "VP(TO to)(VP(VB be)(ADJP(JJ fair)(PP(TO to)(NP(JJ other)(NNS "
            "bidders))))))))))))))(. .))")

        assert fields["metadata"].metadata["gold_tree"] == gold_tree
        assert fields["metadata"].metadata["tokens"] == tokens

        correct_spans_and_labels = {}
        ptb_reader._get_gold_spans(gold_tree, 0, correct_spans_and_labels)
        for span, label in zip(spans, span_labels):
            if label != "NO-LABEL":
                assert correct_spans_and_labels[span] == label

        fields = instances[1].fields
        tokens = [x.text for x in fields["tokens"].tokens]
        pos_tags = fields["pos_tags"].labels
        spans = [(x.span_start, x.span_end)
                 for x in fields["spans"].field_list]
        span_labels = fields["span_labels"].labels

        assert tokens == [
            "That",
            "could",
            "cost",
            "him",
            "the",
            "chance",
            "to",
            "influence",
            "the",
            "outcome",
            "and",
            "perhaps",
            "join",
            "the",
            "winning",
            "bidder",
            ".",
        ]

        assert pos_tags == [
            "DT",
            "MD",
            "VB",
            "PRP",
            "DT",
            "NN",
            "TO",
            "VB",
            "DT",
            "NN",
            "CC",
            "RB",
            "VB",
            "DT",
            "VBG",
            "NN",
            ".",
        ]

        assert spans == enumerate_spans(tokens)

        gold_tree = Tree.fromstring(
            "(S(NP(DT That))(VP(MD could)(VP(VB cost)(NP(PRP him))"
            "(NP(DT the)(NN chance)(S(VP(TO to)(VP(VP(VB influence)(NP(DT the)"
            "(NN outcome)))(CC and)(VP(ADVP(RB perhaps))(VB join)(NP(DT the)"
            "(VBG winning)(NN bidder)))))))))(. .))")

        assert fields["metadata"].metadata["gold_tree"] == gold_tree
        assert fields["metadata"].metadata["tokens"] == tokens

        correct_spans_and_labels = {}
        ptb_reader._get_gold_spans(gold_tree, 0, correct_spans_and_labels)
        for span, label in zip(spans, span_labels):
            if label != "NO-LABEL":
                assert correct_spans_and_labels[span] == label
Example #10
    def text_to_instance(self, # type: ignore
                         tokens: List[str],
                         pos_tags: List[str] = None,
                         gold_tree: Tree = None) -> Instance:
        """
        We take `pre-tokenized` input here, because we don't have a tokenizer in this class.

        Parameters
        ----------
        tokens : ``List[str]``, required.
            The tokens in a given sentence.
        pos_tags : ``List[str]``, optional (default = None).
            The POS tags for the words in the sentence.
        gold_tree : ``Tree``, optional (default = None).
            The gold parse tree to create span labels from.

        Returns
        -------
        An ``Instance`` containing the following fields:
            tokens : ``TextField``
                The tokens in the sentence.
            pos_tags : ``SequenceLabelField``
                The POS tags of the words in the sentence.
                Only returned if ``use_pos_tags`` is ``True``
            spans : ``ListField[SpanField]``
                A ListField containing all possible subspans of the
                sentence.
            span_labels : ``SequenceLabelField``, optional.
                The constituency tags for each of the possible spans, with
                respect to a gold parse tree. If a span is not contained
                within the tree, a span will have a ``NO-LABEL`` label.
            gold_tree : ``MetadataField(Tree)``
                The gold NLTK parse tree for use in evaluation.
        """
        # pylint: disable=arguments-differ
        text_field = TextField([Token(x) for x in tokens], token_indexers=self._token_indexers)
        fields: Dict[str, Field] = {"tokens": text_field}

        pos_namespace = self._label_namespace_prefix + self._pos_label_namespace
        if self._use_pos_tags and pos_tags is not None:
            pos_tag_field = SequenceLabelField(pos_tags, text_field,
                                               label_namespace=pos_namespace)
            fields["pos_tags"] = pos_tag_field
        elif self._use_pos_tags:
            raise ConfigurationError("use_pos_tags was set to True but no gold pos"
                                     " tags were passed to the dataset reader.")
        spans: List[Field] = []
        gold_labels = []

        if gold_tree is not None:
            gold_spans: Dict[Tuple[int, int], str] = {}
            self._get_gold_spans(gold_tree, 0, gold_spans)

        else:
            gold_spans = None
        for start, end in enumerate_spans(tokens):
            spans.append(SpanField(start, end, text_field))

            if gold_spans is not None:
                if (start, end) in gold_spans.keys():
                    gold_labels.append(gold_spans[(start, end)])
                else:
                    gold_labels.append("NO-LABEL")

        metadata = {"tokens": tokens}
        if gold_tree:
            metadata["gold_tree"] = gold_tree
        if self._use_pos_tags:
            metadata["pos_tags"] = pos_tags

        fields["metadata"] = MetadataField(metadata)

        span_list_field: ListField = ListField(spans)
        fields["spans"] = span_list_field
        if gold_tree is not None:
            fields["span_labels"] = SequenceLabelField(gold_labels,
                                                       span_list_field,
                                                       label_namespace=self._label_namespace_prefix + "labels")
        return Instance(fields)
    def test_read_from_file(self):

        ptb_reader = PennTreeBankConstituencySpanDatasetReader()
        instances = ptb_reader.read(str(self.FIXTURES_ROOT / 'data' / 'example_ptb.trees'))

        assert len(instances) == 2

        fields = instances[0].fields
        tokens = [x.text for x in fields["tokens"].tokens]
        pos_tags = fields["pos_tags"].labels
        spans = [(x.span_start, x.span_end) for x in fields["spans"].field_list]
        span_labels = fields["span_labels"].labels

        assert tokens == ['Also', ',', 'because', 'UAL', 'Chairman', 'Stephen', 'Wolf',
                          'and', 'other', 'UAL', 'executives', 'have', 'joined', 'the',
                          'pilots', "'", 'bid', ',', 'the', 'board', 'might', 'be', 'forced',
                          'to', 'exclude', 'him', 'from', 'its', 'deliberations', 'in',
                          'order', 'to', 'be', 'fair', 'to', 'other', 'bidders', '.']
        assert pos_tags == ['RB', ',', 'IN', 'NNP', 'NNP', 'NNP', 'NNP', 'CC', 'JJ', 'NNP',
                            'NNS', 'VBP', 'VBN', 'DT', 'NNS', 'POS', 'NN', ',', 'DT', 'NN',
                            'MD', 'VB', 'VBN', 'TO', 'VB', 'PRP', 'IN', 'PRP$',
                            'NNS', 'IN', 'NN', 'TO', 'VB', 'JJ', 'TO', 'JJ', 'NNS', '.']

        assert spans == enumerate_spans(tokens)
        gold_tree = Tree.fromstring("(S(ADVP(RB Also))(, ,)(SBAR(IN because)"
                                    "(S(NP(NP(NNP UAL)(NNP Chairman)(NNP Stephen)(NNP Wolf))"
                                    "(CC and)(NP(JJ other)(NNP UAL)(NNS executives)))(VP(VBP have)"
                                    "(VP(VBN joined)(NP(NP(DT the)(NNS pilots)(POS '))(NN bid))))))"
                                    "(, ,)(NP(DT the)(NN board))(VP(MD might)(VP(VB be)(VP(VBN "
                                    "forced)(S(VP(TO to)(VP(VB exclude)(NP(PRP him))(PP(IN from)"
                                    "(NP(PRP$ its)(NNS deliberations)))(SBAR(IN in)(NN order)(S("
                                    "VP(TO to)(VP(VB be)(ADJP(JJ fair)(PP(TO to)(NP(JJ other)(NNS "
                                    "bidders))))))))))))))(. .))")

        assert fields["metadata"].metadata["gold_tree"] == gold_tree
        assert fields["metadata"].metadata["tokens"] == tokens

        correct_spans_and_labels = {}
        ptb_reader._get_gold_spans(gold_tree, 0, correct_spans_and_labels)
        for span, label in zip(spans, span_labels):
            if label != "NO-LABEL":
                assert correct_spans_and_labels[span] == label


        fields = instances[1].fields
        tokens = [x.text for x in fields["tokens"].tokens]
        pos_tags = fields["pos_tags"].labels
        spans = [(x.span_start, x.span_end) for x in fields["spans"].field_list]
        span_labels = fields["span_labels"].labels

        assert tokens == ['That', 'could', 'cost', 'him', 'the', 'chance',
                          'to', 'influence', 'the', 'outcome', 'and', 'perhaps',
                          'join', 'the', 'winning', 'bidder', '.']

        assert pos_tags == ['DT', 'MD', 'VB', 'PRP', 'DT', 'NN',
                            'TO', 'VB', 'DT', 'NN', 'CC', 'RB', 'VB', 'DT',
                            'VBG', 'NN', '.']

        assert spans == enumerate_spans(tokens)

        gold_tree = Tree.fromstring("(S(NP(DT That))(VP(MD could)(VP(VB cost)(NP(PRP him))"
                                    "(NP(DT the)(NN chance)(S(VP(TO to)(VP(VP(VB influence)(NP(DT the)"
                                    "(NN outcome)))(CC and)(VP(ADVP(RB perhaps))(VB join)(NP(DT the)"
                                    "(VBG winning)(NN bidder)))))))))(. .))")

        assert fields["metadata"].metadata["gold_tree"] == gold_tree
        assert fields["metadata"].metadata["tokens"] == tokens

        correct_spans_and_labels = {}
        ptb_reader._get_gold_spans(gold_tree, 0, correct_spans_and_labels)
        for span, label in zip(spans, span_labels):
            if label != "NO-LABEL":
                assert correct_spans_and_labels[span] == label
from allennlp.modules.token_embedders import Embedding

# Create an instance with multiple spans
tokens = ['I', 'shot', 'an', 'elephant', 'in', 'my', 'pajamas', '.']
tokens = [Token(token) for token in tokens]
token_indexers = {'tokens': SingleIdTokenIndexer()}
text_field = TextField(tokens, token_indexers=token_indexers)

spans = [(2, 3), (5, 6)]  # ('an', 'elephant') and ('my', 'pajamas')
span_fields = ListField(
    [SpanField(start, end, text_field) for start, end in spans])

instance = Instance({'tokens': text_field, 'spans': span_fields})

# Alternatively, you can also enumerate all spans
spans = enumerate_spans(tokens, max_span_width=3)
print('all spans up to length 3:')
print(spans)


def filter_function(span_tokens):
    return not any(token.text == '.' for token in span_tokens)


spans = enumerate_spans(tokens,
                        max_span_width=3,
                        filter_function=filter_function)
print('all spans up to length 3, excluding punctuation:')
print(spans)

# Index and convert to tensors
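A hedged sketch of the step the final comment refers to: build a vocabulary from the instance, index its fields, and convert them to padded tensors. Vocabulary.from_instances, Instance.index_fields and Instance.as_tensor_dict are standard AllenNLP calls, but the exact tensor layout can differ between AllenNLP versions.

from allennlp.data import Vocabulary

vocab = Vocabulary.from_instances([instance])
instance.index_fields(vocab)
tensors = instance.as_tensor_dict()

print(tensors["tokens"])  # indexed token ids for the TextField
print(tensors["spans"])   # shape (2, 2): one inclusive (start, end) per SpanField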