def from_tree(proto_tree):
    """
    Convert a FlattenedParseTree back into a Tree

    The nodes of the proto are processed in order using a stack:
      - an open node is pushed as a marker
      - a node which is neither open nor close becomes a childless Tree
        built from its value
      - a close node pops everything back to the most recent open marker.
        The first item after the marker supplies the label of the new
        subtree and the remaining items become its children.

    returns Tree, score
      (score might be None if it is missing)
      The score is taken from the first node which has one.

    Raises ValueError if the proto has no nodes or if its open / close
    operations do not describe exactly one well formed tree.
    """
    score = None
    stack = deque()
    for node in proto_tree.nodes:
        # only the first score found is kept
        if node.HasField("score") and score is None:
            score = node.score

        if node.openNode:
            # two open markers in a row means the first one never got a label
            if stack and isinstance(stack[-1], FlattenedParseTree.Node) and stack[-1].openNode:
                raise ValueError("Got a proto with no label on a node: {}".format(proto_tree))
            stack.append(node)
            continue

        if not node.closeNode:
            # TODO: do something with the score
            stack.append(Tree(label=node.value))
            continue

        # must be a close operation...
        if len(stack) <= 1:
            raise ValueError("Got a proto with too many close operations: {}".format(proto_tree))

        # on a close operation, pop until we hit the open
        # then turn everything in that span into a new node
        children = []
        while stack:
            next_node = stack.pop()
            if isinstance(next_node, FlattenedParseTree.Node):
                break
            children.append(next_node)
        else:
            # the stack ran out without finding an open marker, so this
            # close has nothing to match
            raise ValueError("Got a proto with too many close operations: {}".format(proto_tree))

        if not children:
            raise ValueError("Got a proto with an open immediately followed by a close: {}".format(proto_tree))

        # items were popped last-to-first; restore the original order
        children.reverse()
        label = children[0]
        subtree = Tree(label=label.label, children=children[1:])
        stack.append(subtree)

    # every pass through the loop either raises or pushes something,
    # so the stack can only be empty here if there were no nodes at all
    if not stack:
        raise ValueError("Got a proto with no nodes: {}".format(proto_tree))
    if len(stack) > 1:
        raise ValueError("Got a proto which does not close all of the nodes: {}".format(proto_tree))
    tree = stack.pop()
    if not isinstance(tree, Tree):
        raise ValueError("Got a proto which was just one Open operation: {}".format(proto_tree))
    return tree, score
def initial_state_from_words(word_lists, model):
    """
    Build initial states from sentences given as (word, tag) pairs

    Each pair is turned into a preterminal: a Tree labeled with the tag
    whose only child is a leaf Tree labeled with the word.  The
    preterminals of each sentence are produced in reverse order, and
    the result is handed to initial_state_from_preterminals with no
    gold trees.
    """
    # TODO: stop reversing the words
    preterminal_lists = [
        [Tree(label=tag, children=[Tree(label=word)]) for word, tag in reversed(words)]
        for words in word_lists
    ]
    return initial_state_from_preterminals(preterminal_lists, model, gold_trees=None)
def build_constituents(self, labels, children_lists):
    """
    Build one new Constituent per (label, children) pair

    The hidden state of each new constituent comes from running
    constituent_reduce_lstm over the sequence
      label embedding, child outputs..., label embedding
    and combining the final forward and backward states with
    reduce_linear followed by the nonlinearity.

    A label may be a single string or a sequence of strings.  For a
    sequence, the labels are stacked so that the first label is the
    outermost node of the resulting tree.
    """
    embedded_labels = []
    for label in labels:
        label_tensor = self.open_node_tensors[self.open_node_map[label]]
        embedded_labels.append(self.open_node_embedding(label_tensor))

    longest = max(len(children) for children in children_lists)
    padding = torch.zeros(self.hidden_size, device=embedded_labels[0].device)
    child_outputs = [[child.output for child in children] for children in children_lists]

    # weirdly, this is faster than using pack_sequence
    lstm_inputs = []
    for label_hx, outputs in zip(embedded_labels, child_outputs):
        # the label goes on both ends of the children, then zeros
        # pad every sequence out to the same length
        sequence = [label_hx] + outputs + [label_hx] + [padding] * (longest - len(outputs))
        lstm_inputs.append(self.lstm_input_dropout(torch.stack(sequence)))

    stacked_inputs = torch.stack(lstm_inputs, axis=1)
    # +2 for the label at the start and end of each sequence
    lengths = [len(children) + 2 for children in children_lists]
    packed_inputs = torch.nn.utils.rnn.pack_padded_sequence(stacked_inputs, lengths, enforce_sorted=False)

    # take just the output of the final layer
    #   result of lstm is output, (hx, cx)
    #   so [1][0] gets hx
    # will be shape len(children_lists) * 2, hidden_size for bidirectional
    # where forward outputs are -2 and backwards are -1
    final_hx = self.constituent_reduce_lstm(packed_inputs)[1][0]
    forward_hx = final_hx[-2, :]
    backward_hx = final_hx[-1, :]

    hx = self.nonlinearity(self.reduce_linear(torch.cat((forward_hx, backward_hx), axis=1)))

    constituents = []
    for idx, (label, children) in enumerate(zip(labels, children_lists)):
        subtrees = [child.value for child in children]
        # treat a single string label as a chain of length one
        chain = (label,) if isinstance(label, str) else label
        for value in reversed(chain):
            node = Tree(label=value, children=subtrees)
            subtrees = node
        constituents.append(Constituent(value=node, hx=hx[idx, :]))
    return constituents
def test_compound_constituents():
    """
    Test get_compound_constituents on a skinny tree, a larger tree, and both together
    """
    # TODO: add skinny trees like this to the various transition tests
    cases = [
        ("((VP (VB Unban)))",
         [('ROOT', 'VP')]),
        ("(ROOT (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in)))) (. ?)))",
         [('PP', ), ('ROOT', 'SBARQ'), ('SQ', 'VP'), ('WHNP', )]),
        ("((VP (VB Unban))) (ROOT (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in)))) (. ?)))",
         [('PP', ), ('ROOT', 'SBARQ'), ('ROOT', 'VP'), ('SQ', 'VP'), ('WHNP', )]),
    ]
    for text, expected in cases:
        trees = tree_reader.read_trees(text)
        assert Tree.get_compound_constituents(trees) == expected
def test_root_labels():
    """
    Test get_root_labels on one tree, on repeated trees, and on trees with different roots
    """
    cases = [
        # a single tree with an unlabeled root
        ("( (SBARQ-FOO (WHNP-BAR (WP Who)) (SQ#ASDF (VP=1 (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))",
         ["ROOT"]),
        # the same tree three times, with nothing between the copies
        ("( (SBARQ-FOO (WHNP-BAR (WP Who)) (SQ#ASDF (VP=1 (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))"
         "( (SBARQ-FOO (WHNP-BAR (WP Who)) (SQ#ASDF (VP=1 (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))"
         "( (SBARQ-FOO (WHNP-BAR (WP Who)) (SQ#ASDF (VP=1 (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))",
         ["ROOT"]),
        # two trees with explicitly labeled roots
        ("(FOO) (BAR)",
         ["BAR", "FOO"]),
    ]
    for text, expected in cases:
        trees = tree_reader.read_trees(text)
        assert expected == Tree.get_root_labels(trees)
def build_constituents(self, labels, children_lists):
    """
    Wrap each list of children in a Tree (or chain of Trees) with the given label

    A label may be a single string or a sequence of strings.  For a
    sequence, the last label is applied first, so the first label
    ends up as the outermost node.

    Returns a list with one entry per (label, children) pair.
    """
    def wrap(label, children):
        # treat a single string label as a chain of length one
        chain = (label,) if isinstance(label, str) else label
        node = children
        for value in reversed(chain):
            node = Tree(label=value, children=node)
        return node

    return [wrap(label, children) for label, children in zip(labels, children_lists)]
def test_rare_words():
    """
    Test get_rare_words with a threshold of 0.5 on two small trees
    """
    text = "((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?))) ((SBARQ (NP (DT this) (NN seat)) (. ?)))"
    rare_words = Tree.get_rare_words(tree_reader.read_trees(text), 0.5)
    assert rare_words == ['Who', 'in', 'sits']
def test_unique_tags():
    """
    Test that get_unique_tags returns the expected list of tags for a single tree
    """
    text = "((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))"
    unique_tags = Tree.get_unique_tags(tree_reader.read_trees(text))
    assert unique_tags == ['.', 'DT', 'IN', 'NN', 'VBZ', 'WP']
def unary_transform(self, constituents, labels):
    """
    Stack unary nodes on top of the constituent held in constituents.value

    The labels are applied last-to-first, so the first label ends up
    as the outermost node.  For each label, the hidden state is passed
    through that label's entry in unary_transforms and then through
    the nonlinearity.

    Returns a Constituent holding the final tree and hidden state.
    """
    top = constituents.value
    tree, hidden = top.value, top.output
    for label in reversed(labels):
        tree = Tree(label=label, children=[tree])
        # non-linearity after the unary transform
        hidden = self.nonlinearity(self.unary_transforms[label](hidden))
    return Constituent(value=tree, hx=hidden)
def test_unique_labels():
    """
    Test get_unique_constituent_labels on two copies of the same tree

    Relies on tree_reader to build the trees from text
    """
    text = "((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?))) ((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))"
    unique_labels = Tree.get_unique_constituent_labels(tree_reader.read_trees(text))
    assert unique_labels == ['NP', 'PP', 'ROOT', 'SBARQ', 'SQ', 'VP', 'WHNP']
def recursive_open_tree(token_iterator, at_root, broken_ok):
    """
    Build a tree from the tokens in the token_iterator

    Called after an open paren has been consumed.  Reads tokens until
    the matching close paren, recursing for each nested open paren.

    at_root: a node with no label is given the label ROOT
    broken_ok: instead of raising, allow nodes with no label (label
      None) and nodes with both text and bracketed children

    Returns the Tree, or None if the tokens run out before the close
    paren.  Raises ValueError for a malformed node when broken_ok is
    not set.
    """
    # TODO: unwind the recursion
    text = []
    children = []

    while True:
        token = next(token_iterator, None)
        if token is None:
            # ran out of tokens before this node was closed
            return None

        if token is OPEN_PAREN:
            children.append(recursive_open_tree(token_iterator, at_root=False, broken_ok=broken_ok))
            continue

        if token is not CLOSE_PAREN:
            text.append(token)
            continue

        # everything below handles the close paren for this node
        if not text:
            if at_root:
                return Tree(label="ROOT", children=children)
            if broken_ok:
                return Tree(label=None, children=children)
            raise ValueError("Found a tree with no label on a node! Line number %d" % token_iterator.line_num)

        pieces = " ".join(text).split()
        if len(pieces) == 1:
            return Tree(label=pieces[0], children=children)

        # the assumption here is that a language such as VI may
        # have spaces in the words, but it still represents
        # just one child
        label = pieces[0]
        child_label = " ".join(pieces[1:])
        if not children:
            return Tree(label=label, children=Tree(label=child_label))
        if broken_ok:
            return Tree(label=label, children=children + [Tree(label=child_label)])
        raise ValueError("Found a tree with both text children and bracketed children! Line number %d" % token_iterator.line_num)
def initial_state_from_gold_trees(trees, model):
    """
    Build initial states from gold trees

    For each tree, a fresh copy of every preterminal is made: a Tree
    with the preterminal's label whose child is a new leaf with the
    label of the preterminal's first child.  The gold trees themselves
    are passed along to initial_state_from_preterminals.
    """
    preterminal_lists = []
    for tree in trees:
        preterminals = []
        # reversed so we put the words on the stack backwards
        for pt in tree.yield_reversed_preterminals():
            tag = pt.label
            word_node = Tree(label=pt.children[0].label)
            preterminals.append(Tree(label=tag, children=word_node))
        preterminal_lists.append(preterminals)
    return initial_state_from_preterminals(preterminal_lists, model, gold_trees=trees)
def test_leaf_preterminal():
    """
    Test is_leaf, is_preterminal, the number of children, and str() for a leaf, a preterminal, and an inner node
    """
    def check(tree, leaf, preterminal, num_children, text):
        assert bool(tree.is_leaf()) is leaf
        assert bool(tree.is_preterminal()) is preterminal
        assert len(tree.children) == num_children
        assert str(tree) == text

    # a node with no children is a leaf
    foo = Tree(label="foo")
    check(foo, leaf=True, preterminal=False, num_children=0, text='foo')

    # a node whose only child is a leaf is a preterminal
    bar = Tree(label="bar", children=foo)
    check(bar, leaf=False, preterminal=True, num_children=1, text="(bar foo)")

    # a node above a preterminal is neither
    baz = Tree(label="baz", children=[bar])
    check(baz, leaf=False, preterminal=False, num_children=1, text="(baz (bar foo))")
def unary_transform(self, constituents, labels):
    """
    Stack unary nodes on top of the tree held in constituents.value

    The labels are applied last-to-first, so the first label ends up
    as the outermost node.  With no labels, the tree is returned as is.
    """
    tree = constituents.value
    remaining = reversed(labels)
    for label in remaining:
        # each new node has the tree built so far as its only child
        tree = Tree(label=label, children=[tree])
    return tree