def test_enumerate_spans_enumerates_all_spans(self):
    """``enumerate_spans`` must yield every (start, end) pair, honouring the
    width bounds, the ``offset`` shift and the ``filter_function`` hook.

    Modernized: dropped the legacy Python-2 ``u""`` string prefixes so this
    test matches the rest of the (Python-3) test suite.
    """
    tokenizer = SpacyWordSplitter(pos_tags=True)
    sentence = tokenizer.split_words("This is a sentence.")

    # Default call: every span of every width over the 5 tokens.
    spans = span_utils.enumerate_spans(sentence)
    assert spans == [(0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (1, 1), (1, 2),
                     (1, 3), (1, 4), (2, 2), (2, 3), (2, 4), (3, 3), (3, 4),
                     (4, 4)]

    # Width bounds are inclusive on both ends.
    spans = span_utils.enumerate_spans(sentence, max_span_width=3, min_span_width=2)
    assert spans == [(0, 1), (0, 2), (1, 2), (1, 3), (2, 3), (2, 4), (3, 4)]

    # ``offset`` shifts both endpoints of every span.
    spans = span_utils.enumerate_spans(sentence, max_span_width=3,
                                       min_span_width=2, offset=20)
    assert spans == [(20, 21), (20, 22), (21, 22), (21, 23), (22, 23),
                     (22, 24), (23, 24)]

    def no_prefixed_punctuation(tokens):
        # Only include spans which don't start or end with punctuation.
        return tokens[0].pos_ != "PUNCT" and tokens[-1].pos_ != "PUNCT"

    spans = span_utils.enumerate_spans(sentence,
                                       max_span_width=3,
                                       min_span_width=2,
                                       filter_function=no_prefixed_punctuation)
    # No longer includes (2, 4) or (3, 4) as these include punctuation
    # as their last element.
    assert spans == [(0, 1), (0, 2), (1, 2), (1, 3), (2, 3)]
def test_enumerate_spans_enumerates_all_spans(self):
    """Exercise ``enumerate_spans``: default widths, width bounds, offsets,
    and the ``filter_function`` hook."""
    word_splitter = SpacyWordSplitter(pos_tags=True)
    words = word_splitter.split_words("This is a sentence.")

    # With no arguments, every span of every width over the five tokens
    # appears, ordered by start index then end index.
    all_expected = [(i, j) for i in range(5) for j in range(i, 5)]
    assert span_utils.enumerate_spans(words) == all_expected

    # Width-restricted enumeration: bounds are inclusive.
    restricted = span_utils.enumerate_spans(words, max_span_width=3, min_span_width=2)
    assert restricted == [(0, 1), (0, 2), (1, 2), (1, 3), (2, 3), (2, 4), (3, 4)]

    # ``offset`` shifts both endpoints of each span by the same amount.
    shifted = span_utils.enumerate_spans(words, max_span_width=3,
                                         min_span_width=2, offset=20)
    assert shifted == [(begin + 20, stop + 20) for begin, stop in restricted]

    def keep_span(span_tokens: List[Token]):
        # Reject spans that begin or end with punctuation.
        return span_tokens[0].pos_ != "PUNCT" and span_tokens[-1].pos_ != "PUNCT"

    filtered = span_utils.enumerate_spans(words,
                                          max_span_width=3,
                                          min_span_width=2,
                                          filter_function=keep_span)
    # (2, 4) and (3, 4) end on the trailing '.' token and are dropped.
    assert filtered == [(0, 1), (0, 2), (1, 2), (1, 3), (2, 3)]
def test_read_from_file(self):
    """End-to-end check of the PTB constituency-span reader on the two-tree
    fixture file: tokens, POS tags, enumerated spans, span labels and the
    gold tree stored in metadata."""
    ptb_reader = PennTreeBankConstituencySpanDatasetReader()
    instances = ptb_reader.read('tests/fixtures/data/example_ptb.trees')
    assert len(instances) == 2

    # ---- First instance ----
    fields = instances[0].fields
    tokens = [x.text for x in fields["tokens"].tokens]
    pos_tags = fields["pos_tags"].labels
    spans = [(x.span_start, x.span_end) for x in fields["spans"].field_list]
    span_labels = fields["span_labels"].labels
    assert tokens == ['Also', ',', 'because', 'UAL', 'Chairman', 'Stephen',
                      'Wolf', 'and', 'other', 'UAL', 'executives', 'have',
                      'joined', 'the', 'pilots', "'", 'bid', ',', 'the',
                      'board', 'might', 'be', 'forced', 'to', 'exclude',
                      'him', 'from', 'its', 'deliberations', 'in', 'order',
                      'to', 'be', 'fair', 'to', 'other', 'bidders', '.']
    assert pos_tags == ['RB', ',', 'IN', 'NNP', 'NNP', 'NNP', 'NNP', 'CC',
                        'JJ', 'NNP', 'NNS', 'VBP', 'VBN', 'DT', 'NNS', 'POS',
                        'NN', ',', 'DT', 'NN', 'MD', 'VB', 'VBN', 'TO', 'VB',
                        'PRP', 'IN', 'PRP$', 'NNS', 'IN', 'NN', 'TO', 'VB',
                        'JJ', 'TO', 'JJ', 'NNS', '.']
    # The reader enumerates every possible subspan of the sentence.
    assert spans == enumerate_spans(tokens)
    gold_tree = Tree.fromstring("(VROOT(S(ADVP(RB Also))(, ,)(SBAR(IN because)"
                                "(S(NP(NP(NNP UAL)(NNP Chairman)(NNP Stephen)(NNP Wolf))"
                                "(CC and)(NP(JJ other)(NNP UAL)(NNS executives)))(VP(VBP have)"
                                "(VP(VBN joined)(NP(NP(DT the)(NNS pilots)(POS '))(NN bid))))))"
                                "(, ,)(NP(DT the)(NN board))(VP(MD might)(VP(VB be)(VP(VBN "
                                "forced)(S(VP(TO to)(VP(VB exclude)(NP(PRP him))(PP(IN from)"
                                "(NP(PRP$ its)(NNS deliberations)))(SBAR(IN in)(NN order)(S("
                                "VP(TO to)(VP(VB be)(ADJP(JJ fair)(PP(TO to)(NP(JJ other)(NNS "
                                "bidders))))))))))))))(. .)))")
    assert fields["metadata"].metadata["gold_tree"] == gold_tree
    assert fields["metadata"].metadata["tokens"] == tokens
    # Every span carrying a real label must agree with the gold span -> label map.
    correct_spans_and_labels = {}
    ptb_reader._get_gold_spans(gold_tree, 0, correct_spans_and_labels)
    for span, label in zip(spans, span_labels):
        if label != "NO-LABEL":
            assert correct_spans_and_labels[span] == label

    # ---- Second instance ----
    fields = instances[1].fields
    tokens = [x.text for x in fields["tokens"].tokens]
    pos_tags = fields["pos_tags"].labels
    spans = [(x.span_start, x.span_end) for x in fields["spans"].field_list]
    span_labels = fields["span_labels"].labels
    assert tokens == ['That', 'could', 'cost', 'him', 'the', 'chance', 'to',
                      'influence', 'the', 'outcome', 'and', 'perhaps', 'join',
                      'the', 'winning', 'bidder', '.']
    assert pos_tags == ['DT', 'MD', 'VB', 'PRP', 'DT', 'NN', 'TO', 'VB', 'DT',
                        'NN', 'CC', 'RB', 'VB', 'DT', 'VBG', 'NN', '.']
    assert spans == enumerate_spans(tokens)
    gold_tree = Tree.fromstring("(VROOT(S(NP(DT That))(VP(MD could)(VP(VB cost)(NP(PRP him))"
                                "(NP(DT the)(NN chance)(S(VP(TO to)(VP(VP(VB influence)(NP(DT the)"
                                "(NN outcome)))(CC and)(VP(ADVP(RB perhaps))(VB join)(NP(DT the)"
                                "(VBG winning)(NN bidder)))))))))(. .)))")
    assert fields["metadata"].metadata["gold_tree"] == gold_tree
    assert fields["metadata"].metadata["tokens"] == tokens
    correct_spans_and_labels = {}
    ptb_reader._get_gold_spans(gold_tree, 0, correct_spans_and_labels)
    for span, label in zip(spans, span_labels):
        if label != "NO-LABEL":
            assert correct_spans_and_labels[span] == label
def text_to_instance(self,  # type: ignore
                     tokens,
                     pos_tags=None,
                     gold_tree=None):
    """
    We take `pre-tokenized` input here, because we don't have a tokenizer in this class.

    Parameters
    ----------
    tokens : ``List[str]``, required.
        The tokens in a given sentence.
    pos_tags : ``List[str]``, optional, (default = None).
        The POS tags for the words in the sentence.
    gold_tree : ``Tree``, optional (default = None).
        The gold parse tree to create span labels from.

    Returns
    -------
    An ``Instance`` containing the following fields:
        tokens : ``TextField``
            The tokens in the sentence.
        pos_tags : ``SequenceLabelField``
            The POS tags of the words in the sentence.
            Only returned if ``use_pos_tags`` is ``True``
        spans : ``ListField[SpanField]``
            A ListField containing all possible subspans of the sentence.
        span_labels : ``SequenceLabelField``, optional.
            The constituency tags for each of the possible spans, with
            respect to a gold parse tree. If a span is not contained
            within the tree, a span will have a ``NO-LABEL`` label.
        gold_tree : ``MetadataField(Tree)``
            The gold NLTK parse tree for use in evaluation.
    """
    # pylint: disable=arguments-differ
    text_field = TextField([Token(x) for x in tokens],
                           token_indexers=self._token_indexers)
    fields = {"tokens": text_field}
    if self._use_pos_tags and pos_tags is not None:
        pos_tag_field = SequenceLabelField(pos_tags, text_field, label_namespace="pos")
        fields["pos_tags"] = pos_tag_field
    elif self._use_pos_tags:
        raise ConfigurationError("use_pos_tags was set to True but no gold pos"
                                 " tags were passed to the dataset reader.")

    spans = []
    gold_labels = []
    if gold_tree is not None:
        gold_spans = {}
        self._get_gold_spans(gold_tree, 0, gold_spans)
    else:
        gold_spans = None

    for start, end in enumerate_spans(tokens):
        spans.append(SpanField(start, end, text_field))
        if gold_spans is not None:
            # Fix: a single O(1) dict lookup. Previously this rebuilt
            # ``list(gold_spans.keys())`` and scanned it for every span,
            # which was quadratic in the number of spans.
            gold_labels.append(gold_spans.get((start, end), "NO-LABEL"))

    metadata = {"tokens": tokens}
    if gold_tree:
        metadata["gold_tree"] = gold_tree
    if self._use_pos_tags:
        metadata["pos_tags"] = pos_tags
    fields["metadata"] = MetadataField(metadata)

    span_list_field = ListField(spans)
    fields["spans"] = span_list_field
    if gold_tree is not None:
        fields["span_labels"] = SequenceLabelField(gold_labels, span_list_field)
    return Instance(fields)
def text_to_instance(self,  # type: ignore
                     tokens: List[str],
                     pos_tags: List[str] = None,
                     gold_tree: Tree = None) -> Instance:
    """
    Build an ``Instance`` from pre-tokenized input (this reader has no
    tokenizer of its own).

    Parameters
    ----------
    tokens : ``List[str]``, required.
        The tokens of one sentence.
    pos_tags : ``List[str]``, optional, (default = None).
        Gold POS tags for the sentence.
    gold_tree : ``Tree``, optional (default = None).
        Gold parse tree used to derive span labels.

    Returns
    -------
    An ``Instance`` with ``tokens``; ``pos_tags`` (only when
    ``use_pos_tags`` is ``True``); ``spans`` holding every subspan of the
    sentence; ``span_labels`` giving each span its constituency tag, or
    ``NO-LABEL`` when the span is not in the gold tree; and ``metadata``
    carrying the raw tokens, the gold tree and the POS tags for evaluation.
    """
    if self._convert_parentheses:
        # Map PTB parenthesis escapes back to literal bracket tokens.
        tokens = [PTB_PARENTHESES.get(token, token) for token in tokens]
    text_field = TextField([Token(x) for x in tokens],
                           token_indexers=self._token_indexers)
    fields: Dict[str, Field] = {"tokens": text_field}

    if self._use_pos_tags:
        if pos_tags is None:
            raise ConfigurationError("use_pos_tags was set to True but no gold pos"
                                     " tags were passed to the dataset reader.")
        pos_namespace = self._label_namespace_prefix + self._pos_label_namespace
        fields["pos_tags"] = SequenceLabelField(pos_tags, text_field,
                                                label_namespace=pos_namespace)

    # Enumerate every subspan once, then build both the span fields and
    # (when a gold tree is available) the matching label sequence.
    all_spans = list(enumerate_spans(tokens))
    span_fields: List[Field] = [SpanField(begin, stop, text_field)
                                for begin, stop in all_spans]
    gold_labels: List[str] = []
    if gold_tree is not None:
        gold_spans: Dict[Tuple[int, int], str] = {}
        self._get_gold_spans(gold_tree, 0, gold_spans)
        gold_labels = [gold_spans.get(pair, "NO-LABEL") for pair in all_spans]

    metadata = {"tokens": tokens}
    if gold_tree:
        metadata["gold_tree"] = gold_tree
    if self._use_pos_tags:
        metadata["pos_tags"] = pos_tags
    fields["metadata"] = MetadataField(metadata)

    span_list_field: ListField = ListField(span_fields)
    fields["spans"] = span_list_field
    if gold_tree is not None:
        fields["span_labels"] = SequenceLabelField(
            gold_labels,
            span_list_field,
            label_namespace=self._label_namespace_prefix + "labels",
        )
    return Instance(fields)
def text_to_instance(self,
                     source_string: str,
                     gold_spans: Dict[Tuple[int, int], str],
                     scene_string: str,
                     answer: str,
                     program: str) -> Instance:  # type: ignore
    """Turns raw source string and target string into an ``Instance``."""
    # Two tokenizations of the utterance are kept: regular tokens (used as
    # the "tokens" field) and BERT-style wordpieces wrapped in [CLS]...[SEP]
    # (used as the base for span fields).
    tokens = self.tokenizer.tokenize(source_string)
    word_pieces = self._get_wordpieces(source_string)
    word_pieces_tokens = [Token('[CLS]')] + [Token(wp) for wp in word_pieces] + [Token('[SEP]')]
    text_field = TextField(tokens, self._token_indexers)
    wp_field = TextField(word_pieces_tokens, self._token_indexers)
    fields: Dict[str, Field] = {"tokens": text_field}
    if gold_spans is None:
        # No gold supervision: constants extracted from the program are used
        # as surrogate span labels in the loop below.
        # NOTE(review): assumes ``constants[0]`` is a poppable collection of
        # labels — confirm against ``_domain_utils.get_constants``.
        constants = self._domain_utils.get_constants(program)
    spans: List[Field] = []
    gold_labels = []
    for start, end in enumerate_spans(word_pieces):
        # Shift by 1 due to CLS token
        spans.append(SpanField(start + 1, end + 1, wp_field))
        if gold_spans is not None:
            # Shift by 1 due to CLS token
            gold_labels.append(gold_spans.get((start + 1, end + 1), "NO-LABEL"))
        else:
            # Create random labels for each span so that labels would be collected. When no
            # more true labels are left, draw between NO-LABEL and span. These randomly assigned
            # labels would be ignored during training
            if constants[0]:
                gold_labels.append(constants[0].pop())
            else:
                # Non-deterministic: unseeded draw, 70% NO-LABEL / 30% "span".
                rand_label = np.random.choice(a=["NO-LABEL", "span"], size=1, p=[0.7, 0.3])
                gold_labels.append(rand_label[0])
    span_list_field: ListField = ListField(spans)
    fields["spans"] = span_list_field
    fields["span_labels"] = SequenceLabelField(
        gold_labels,
        span_list_field,
        label_namespace="labels",
    )
    # Metadata stores the wordpieces (not the [CLS]/[SEP]-wrapped tokens).
    metadata = {"tokens": word_pieces, "scene_str": scene_string, "answer": answer}
    if program:
        metadata["program"] = program
    if gold_spans:
        metadata["gold_spans"] = gold_spans
    fields["metadata"] = MetadataField(metadata)
    return Instance(fields)
def text_to_instance(self,  # type: ignore
                     tokens: List[str],
                     verb_label: List[int],
                     tags: List[str] = None,
                     pos_tags: List[str] = None,
                     gold_tree: Tree = None) -> Instance:
    """
    We take `pre-tokenized` input here, because we don't have a tokenizer in this
    class.

    Parameters
    ----------
    tokens : ``List[str]``, required.
        The tokens in a given sentence.
    verb_label : ``List[int]``, required.
        The verb label should be a one-hot binary vector, the same length as
        the tokens, indicating the position of the verb to find arguments for.
    tags : ``List[str]``, optional (default = None).
        SRL tags.
    pos_tags : ``List[str]``, optional (default = None).
        The pos tags for the words in the sentence.
    gold_tree : ``Tree``, optional (default = None).
        The gold parse tree to create span labels from.

    Returns
    -------
    An ``Instance`` containing the following fields:
        tokens : ``TextField``
            The tokens in the sentence.
        pos_tags : ``SequenceLabelField``
            The pos tags of the words in the sentence.
        spans : ``ListField[SpanField]``
            A ListField containing all possible subspans of the sentence.
        span_labels : ``SequenceLabelField``, optional.
            The constituency tags for each of the possible spans, with respect
            to a gold parse tree. If a span is not contained within the tree,
            a span will have a ``NO-LABEL`` label.
    """
    # pylint: disable=arguments-differ
    fields: Dict[str, Field] = {}
    # NOTE(review): despite the ``List[str]`` annotation, ``tokens`` is passed
    # straight into ``TextField`` and so is expected to already be ``Token``
    # objects — confirm against callers.
    text_field = TextField(tokens, token_indexers=self._token_indexers)
    fields['tokens'] = text_field
    fields['verb_indicator'] = SequenceLabelField(verb_label, text_field)

    metadata: Dict[str, Any] = {}
    if tags:
        fields['tags'] = SequenceLabelField(tags, text_field)

    # The pos_tags field is always present; fall back to dummy 'X' tags and
    # record in the metadata whether real tags were supplied.
    if not pos_tags:
        pos_tags = ['X' for _ in tokens]
        metadata['pos_tags'] = False
    else:
        metadata['pos_tags'] = True
    fields['pos_tags'] = SequenceLabelField(pos_tags, text_field, "pos_tags")

    spans: List[Field] = []
    gold_labels = []
    if gold_tree is not None:
        gold_spans_with_pos_tags: Dict[Tuple[int, int], str] = {}
        self._get_gold_spans(gold_tree, 0, gold_spans_with_pos_tags)
        # Drop pre-terminal (POS-level) constituents; keep phrase labels only.
        gold_spans = {span: label
                      for (span, label) in gold_spans_with_pos_tags.items()
                      if "-POS" not in label}
    else:
        gold_spans = None

    for start, end in enumerate_spans(tokens):
        spans.append(SpanField(start, end, text_field))
        if gold_spans is not None:
            # Single dict lookup via ``dict.get`` instead of the redundant
            # ``in gold_spans.keys()`` membership test plus a second lookup.
            gold_labels.append(gold_spans.get((start, end), "NO-LABEL"))
        else:
            gold_labels.append("NO-LABEL")

    span_list_field: ListField = ListField(spans)
    fields['spans'] = span_list_field
    # The label field was previously built identically in both branches of an
    # if/else; only the metadata flag depends on whether a gold tree exists.
    fields['span_labels'] = SequenceLabelField(gold_labels, span_list_field,
                                               "constituent_labels")
    metadata['span_labels'] = gold_tree is not None

    fields['metadata'] = MetadataField(metadata)
    return Instance(fields)
def text_to_instance(self,  # type: ignore
                     tokens: List[str],
                     pos_tags: List[str] = None,
                     gold_tree: Tree = None) -> Instance:
    """
    We take `pre-tokenized` input here, because we don't have a tokenizer in this class.

    Parameters
    ----------
    tokens : ``List[str]``, required.
        The tokens in a given sentence.
    pos_tags : ``List[str]``, optional, (default = None).
        The POS tags for the words in the sentence.
    gold_tree : ``Tree``, optional (default = None).
        The gold parse tree to create span labels from.

    Returns
    -------
    An ``Instance`` containing the following fields:
        tokens : ``TextField``
            The tokens in the sentence.
        pos_tags : ``SequenceLabelField``
            The POS tags of the words in the sentence.
            Only returned if ``use_pos_tags`` is ``True``
        spans : ``ListField[SpanField]``
            A ListField containing all possible subspans of the sentence.
        span_labels : ``SequenceLabelField``, optional.
            The constituency tags for each of the possible spans, with
            respect to a gold parse tree. If a span is not contained
            within the tree, a span will have a ``NO-LABEL`` label.
    """
    # pylint: disable=arguments-differ
    text_field = TextField([Token(x) for x in tokens],
                           token_indexers=self._token_indexers)
    fields: Dict[str, Field] = {"tokens": text_field}
    if self._use_pos_tags and pos_tags is not None:
        pos_tag_field = SequenceLabelField(pos_tags, text_field, "pos_tags")
        fields["pos_tags"] = pos_tag_field
    elif self._use_pos_tags:
        raise ConfigurationError("use_pos_tags was set to True but no gold pos"
                                 " tags were passed to the dataset reader.")

    spans: List[Field] = []
    gold_labels = []
    if gold_tree is not None:
        gold_spans_with_pos_tags: Dict[Tuple[int, int], str] = {}
        self._get_gold_spans(gold_tree, 0, gold_spans_with_pos_tags)
        # Drop pre-terminal (POS-level) constituents; keep phrase labels only.
        gold_spans = {span: label
                      for (span, label) in gold_spans_with_pos_tags.items()
                      if "-POS" not in label}
    else:
        gold_spans = None

    for start, end in enumerate_spans(tokens):
        spans.append(SpanField(start, end, text_field))
        if gold_spans is not None:
            # Single dict lookup via ``dict.get`` instead of the redundant
            # ``in gold_spans.keys()`` membership test plus a second lookup.
            gold_labels.append(gold_spans.get((start, end), "NO-LABEL"))

    span_list_field: ListField = ListField(spans)
    fields["spans"] = span_list_field
    if gold_tree is not None:
        fields["span_labels"] = SequenceLabelField(gold_labels, span_list_field)
    return Instance(fields)
def test_read_from_file(self):
    """End-to-end check of the PTB span reader against the two-tree fixture
    file: tokens, POS tags, enumerated spans, span labels and metadata."""
    ptb_reader = PennTreeBankConstituencySpanDatasetReader()
    instances = ptb_reader.read(
        str(FIXTURES_ROOT / "structured_prediction" / "example_ptb.trees"))
    assert len(instances) == 2

    # ---- First instance ----
    fields = instances[0].fields
    tokens = [x.text for x in fields["tokens"].tokens]
    pos_tags = fields["pos_tags"].labels
    spans = [(x.span_start, x.span_end) for x in fields["spans"].field_list]
    span_labels = fields["span_labels"].labels
    assert tokens == [
        "Also", ",", "because", "UAL", "Chairman", "Stephen", "Wolf", "and",
        "other", "UAL", "executives", "have", "joined", "the", "pilots", "'",
        "bid", ",", "the", "board", "might", "be", "forced", "to", "exclude",
        "him", "from", "its", "deliberations", "in", "order", "to", "be",
        "fair", "to", "other", "bidders", ".",
    ]
    assert pos_tags == [
        "RB", ",", "IN", "NNP", "NNP", "NNP", "NNP", "CC", "JJ", "NNP",
        "NNS", "VBP", "VBN", "DT", "NNS", "POS", "NN", ",", "DT", "NN",
        "MD", "VB", "VBN", "TO", "VB", "PRP", "IN", "PRP$", "NNS", "IN",
        "NN", "TO", "VB", "JJ", "TO", "JJ", "NNS", ".",
    ]
    # The reader enumerates every possible subspan of the sentence.
    assert spans == enumerate_spans(tokens)
    gold_tree = Tree.fromstring(
        "(S(ADVP(RB Also))(, ,)(SBAR(IN because)"
        "(S(NP(NP(NNP UAL)(NNP Chairman)(NNP Stephen)(NNP Wolf))"
        "(CC and)(NP(JJ other)(NNP UAL)(NNS executives)))(VP(VBP have)"
        "(VP(VBN joined)(NP(NP(DT the)(NNS pilots)(POS '))(NN bid))))))"
        "(, ,)(NP(DT the)(NN board))(VP(MD might)(VP(VB be)(VP(VBN "
        "forced)(S(VP(TO to)(VP(VB exclude)(NP(PRP him))(PP(IN from)"
        "(NP(PRP$ its)(NNS deliberations)))(SBAR(IN in)(NN order)(S("
        "VP(TO to)(VP(VB be)(ADJP(JJ fair)(PP(TO to)(NP(JJ other)(NNS "
        "bidders))))))))))))))(. .))")
    assert fields["metadata"].metadata["gold_tree"] == gold_tree
    assert fields["metadata"].metadata["tokens"] == tokens
    # Every span carrying a real label must agree with the gold span -> label map.
    correct_spans_and_labels = {}
    ptb_reader._get_gold_spans(gold_tree, 0, correct_spans_and_labels)
    for span, label in zip(spans, span_labels):
        if label != "NO-LABEL":
            assert correct_spans_and_labels[span] == label

    # ---- Second instance ----
    fields = instances[1].fields
    tokens = [x.text for x in fields["tokens"].tokens]
    pos_tags = fields["pos_tags"].labels
    spans = [(x.span_start, x.span_end) for x in fields["spans"].field_list]
    span_labels = fields["span_labels"].labels
    assert tokens == [
        "That", "could", "cost", "him", "the", "chance", "to", "influence",
        "the", "outcome", "and", "perhaps", "join", "the", "winning",
        "bidder", ".",
    ]
    assert pos_tags == [
        "DT", "MD", "VB", "PRP", "DT", "NN", "TO", "VB", "DT", "NN", "CC",
        "RB", "VB", "DT", "VBG", "NN", ".",
    ]
    assert spans == enumerate_spans(tokens)
    gold_tree = Tree.fromstring(
        "(S(NP(DT That))(VP(MD could)(VP(VB cost)(NP(PRP him))"
        "(NP(DT the)(NN chance)(S(VP(TO to)(VP(VP(VB influence)(NP(DT the)"
        "(NN outcome)))(CC and)(VP(ADVP(RB perhaps))(VB join)(NP(DT the)"
        "(VBG winning)(NN bidder)))))))))(. .))")
    assert fields["metadata"].metadata["gold_tree"] == gold_tree
    assert fields["metadata"].metadata["tokens"] == tokens
    correct_spans_and_labels = {}
    ptb_reader._get_gold_spans(gold_tree, 0, correct_spans_and_labels)
    for span, label in zip(spans, span_labels):
        if label != "NO-LABEL":
            assert correct_spans_and_labels[span] == label
def text_to_instance(self,  # type: ignore
                     tokens: List[str],
                     pos_tags: List[str] = None,
                     gold_tree: Tree = None) -> Instance:
    """
    We take `pre-tokenized` input here, because we don't have a tokenizer in this class.

    Parameters
    ----------
    tokens : ``List[str]``, required.
        The tokens in a given sentence.
    pos_tags : ``List[str]``, optional, (default = None).
        The POS tags for the words in the sentence.
    gold_tree : ``Tree``, optional (default = None).
        The gold parse tree to create span labels from.

    Returns
    -------
    An ``Instance`` containing the following fields:
        tokens : ``TextField``
            The tokens in the sentence.
        pos_tags : ``SequenceLabelField``
            The POS tags of the words in the sentence.
            Only returned if ``use_pos_tags`` is ``True``
        spans : ``ListField[SpanField]``
            A ListField containing all possible subspans of the sentence.
        span_labels : ``SequenceLabelField``, optional.
            The constituency tags for each of the possible spans, with
            respect to a gold parse tree. If a span is not contained
            within the tree, a span will have a ``NO-LABEL`` label.
        gold_tree : ``MetadataField(Tree)``
            The gold NLTK parse tree for use in evaluation.
    """
    # pylint: disable=arguments-differ
    text_field = TextField([Token(x) for x in tokens],
                           token_indexers=self._token_indexers)
    fields: Dict[str, Field] = {"tokens": text_field}
    pos_namespace = self._label_namespace_prefix + self._pos_label_namespace
    if self._use_pos_tags and pos_tags is not None:
        pos_tag_field = SequenceLabelField(pos_tags, text_field,
                                           label_namespace=pos_namespace)
        fields["pos_tags"] = pos_tag_field
    elif self._use_pos_tags:
        raise ConfigurationError("use_pos_tags was set to True but no gold pos"
                                 " tags were passed to the dataset reader.")

    spans: List[Field] = []
    gold_labels = []
    if gold_tree is not None:
        gold_spans: Dict[Tuple[int, int], str] = {}
        self._get_gold_spans(gold_tree, 0, gold_spans)
    else:
        gold_spans = None

    for start, end in enumerate_spans(tokens):
        spans.append(SpanField(start, end, text_field))
        if gold_spans is not None:
            # Single dict lookup via ``dict.get`` instead of the redundant
            # ``in gold_spans.keys()`` membership test plus a second lookup.
            gold_labels.append(gold_spans.get((start, end), "NO-LABEL"))

    metadata = {"tokens": tokens}
    if gold_tree:
        metadata["gold_tree"] = gold_tree
    if self._use_pos_tags:
        metadata["pos_tags"] = pos_tags
    fields["metadata"] = MetadataField(metadata)

    span_list_field: ListField = ListField(spans)
    fields["spans"] = span_list_field
    if gold_tree is not None:
        fields["span_labels"] = SequenceLabelField(
            gold_labels, span_list_field,
            label_namespace=self._label_namespace_prefix + "labels")
    return Instance(fields)
def test_read_from_file(self):
    """End-to-end check of the PTB constituency-span reader on the two-tree
    fixture file: tokens, POS tags, enumerated spans, span labels and the
    gold tree stored in metadata."""
    ptb_reader = PennTreeBankConstituencySpanDatasetReader()
    instances = ptb_reader.read(str(self.FIXTURES_ROOT / 'data' / 'example_ptb.trees'))
    assert len(instances) == 2

    # ---- First instance ----
    fields = instances[0].fields
    tokens = [x.text for x in fields["tokens"].tokens]
    pos_tags = fields["pos_tags"].labels
    spans = [(x.span_start, x.span_end) for x in fields["spans"].field_list]
    span_labels = fields["span_labels"].labels
    assert tokens == ['Also', ',', 'because', 'UAL', 'Chairman', 'Stephen',
                      'Wolf', 'and', 'other', 'UAL', 'executives', 'have',
                      'joined', 'the', 'pilots', "'", 'bid', ',', 'the',
                      'board', 'might', 'be', 'forced', 'to', 'exclude',
                      'him', 'from', 'its', 'deliberations', 'in', 'order',
                      'to', 'be', 'fair', 'to', 'other', 'bidders', '.']
    assert pos_tags == ['RB', ',', 'IN', 'NNP', 'NNP', 'NNP', 'NNP', 'CC',
                        'JJ', 'NNP', 'NNS', 'VBP', 'VBN', 'DT', 'NNS', 'POS',
                        'NN', ',', 'DT', 'NN', 'MD', 'VB', 'VBN', 'TO', 'VB',
                        'PRP', 'IN', 'PRP$', 'NNS', 'IN', 'NN', 'TO', 'VB',
                        'JJ', 'TO', 'JJ', 'NNS', '.']
    # The reader enumerates every possible subspan of the sentence.
    assert spans == enumerate_spans(tokens)
    gold_tree = Tree.fromstring("(S(ADVP(RB Also))(, ,)(SBAR(IN because)"
                                "(S(NP(NP(NNP UAL)(NNP Chairman)(NNP Stephen)(NNP Wolf))"
                                "(CC and)(NP(JJ other)(NNP UAL)(NNS executives)))(VP(VBP have)"
                                "(VP(VBN joined)(NP(NP(DT the)(NNS pilots)(POS '))(NN bid))))))"
                                "(, ,)(NP(DT the)(NN board))(VP(MD might)(VP(VB be)(VP(VBN "
                                "forced)(S(VP(TO to)(VP(VB exclude)(NP(PRP him))(PP(IN from)"
                                "(NP(PRP$ its)(NNS deliberations)))(SBAR(IN in)(NN order)(S("
                                "VP(TO to)(VP(VB be)(ADJP(JJ fair)(PP(TO to)(NP(JJ other)(NNS "
                                "bidders))))))))))))))(. .))")
    assert fields["metadata"].metadata["gold_tree"] == gold_tree
    assert fields["metadata"].metadata["tokens"] == tokens
    # Every span carrying a real label must agree with the gold span -> label map.
    correct_spans_and_labels = {}
    ptb_reader._get_gold_spans(gold_tree, 0, correct_spans_and_labels)
    for span, label in zip(spans, span_labels):
        if label != "NO-LABEL":
            assert correct_spans_and_labels[span] == label

    # ---- Second instance ----
    fields = instances[1].fields
    tokens = [x.text for x in fields["tokens"].tokens]
    pos_tags = fields["pos_tags"].labels
    spans = [(x.span_start, x.span_end) for x in fields["spans"].field_list]
    span_labels = fields["span_labels"].labels
    assert tokens == ['That', 'could', 'cost', 'him', 'the', 'chance', 'to',
                      'influence', 'the', 'outcome', 'and', 'perhaps', 'join',
                      'the', 'winning', 'bidder', '.']
    assert pos_tags == ['DT', 'MD', 'VB', 'PRP', 'DT', 'NN', 'TO', 'VB', 'DT',
                        'NN', 'CC', 'RB', 'VB', 'DT', 'VBG', 'NN', '.']
    assert spans == enumerate_spans(tokens)
    gold_tree = Tree.fromstring("(S(NP(DT That))(VP(MD could)(VP(VB cost)(NP(PRP him))"
                                "(NP(DT the)(NN chance)(S(VP(TO to)(VP(VP(VB influence)(NP(DT the)"
                                "(NN outcome)))(CC and)(VP(ADVP(RB perhaps))(VB join)(NP(DT the)"
                                "(VBG winning)(NN bidder)))))))))(. .))")
    assert fields["metadata"].metadata["gold_tree"] == gold_tree
    assert fields["metadata"].metadata["tokens"] == tokens
    correct_spans_and_labels = {}
    ptb_reader._get_gold_spans(gold_tree, 0, correct_spans_and_labels)
    for span, label in zip(spans, span_labels):
        if label != "NO-LABEL":
            assert correct_spans_and_labels[span] == label
from allennlp.modules.token_embedders import Embedding # Create an instance with multiple spans tokens = ['I', 'shot', 'an', 'elephant', 'in', 'my', 'pajamas', '.'] tokens = [Token(token) for token in tokens] token_indexers = {'tokens': SingleIdTokenIndexer()} text_field = TextField(tokens, token_indexers=token_indexers) spans = [(2, 3), (5, 6)] # ('an', 'elephant') and ('my', 'pajamas) span_fields = ListField( [SpanField(start, end, text_field) for start, end in spans]) instance = Instance({'tokens': text_field, 'spans': span_fields}) # Alternatively, you can also enumerate all spans spans = enumerate_spans(tokens, max_span_width=3) print('all spans up to length 3:') print(spans) def filter_function(span_tokens): return not any(t == Token('.') for t in span_tokens) spans = enumerate_spans(tokens, max_span_width=3, filter_function=filter_function) print('all spans up to length 3, excluding punctuation:') print(spans) # Index and convert to tensors