def from_tree(proto_tree):
    """
    Convert a FlattenedParseTree back into a Tree

    The nodes of the proto are read in order.  An open node starts a
    new subtree, the first value node after it supplies the label of
    that subtree, anything else before the matching close node becomes
    a child, and the close node finishes the subtree.

    returns Tree, score
      (score might be None if it is missing)
      the score comes from the first node which has a score field

    raises ValueError if the sequence of nodes is malformed: no nodes
      at all, an open with no label, a close with no matching open,
      an open immediately followed by a close, or unclosed nodes
    """
    score = None
    stack = deque()
    for node in proto_tree.nodes:
        if node.HasField("score") and score is None:
            score = node.score

        if node.openNode:
            # two opens in a row means the first one never got a label
            if len(stack) > 0 and isinstance(
                    stack[-1], FlattenedParseTree.Node) and stack[-1].openNode:
                raise ValueError(
                    "Got a proto with no label on a node: {}".format(
                        proto_tree))
            stack.append(node)
            continue
        if not node.closeNode:
            child = Tree(label=node.value)
            # TODO: do something with the score
            stack.append(child)
            continue

        # must be a close operation...
        if len(stack) <= 1:
            raise ValueError(
                "Got a proto with too many close operations: {}".format(
                    proto_tree))
        # on a close operation, pop until we hit the open
        # then turn everything in that span into a new node
        children = []
        while stack and not isinstance(stack[-1], FlattenedParseTree.Node):
            children.append(stack.pop())
        if not stack:
            # only Trees were on the stack: popping further would raise
            # IndexError, so report the malformed proto explicitly
            raise ValueError(
                "Got a proto with a close operation which has no matching open: {}".format(
                    proto_tree))
        # discard the open marker itself
        stack.pop()
        if len(children) == 0:
            raise ValueError(
                "Got a proto with an open immediately followed by a close: {}".
                format(proto_tree))
        # children were popped last-to-first; the first one is the label
        children.reverse()
        label = children[0]
        children = children[1:]
        subtree = Tree(label=label.label, children=children)
        stack.append(subtree)

    if len(stack) == 0:
        # the stack is only empty here when there were no nodes at all
        raise ValueError(
            "Got a proto with no nodes: {}".format(proto_tree))
    if len(stack) > 1:
        raise ValueError(
            "Got a proto which does not close all of the nodes: {}".format(
                proto_tree))
    tree = stack.pop()
    if not isinstance(tree, Tree):
        raise ValueError(
            "Got a proto which was just one Open operation: {}".format(
                proto_tree))
    return tree, score
Esempio n. 2
0
def initial_state_from_words(word_lists, model):
    """
    Turn lists of (word, tag) pairs into preterminal Trees and hand them on

    Each pair becomes a tag node with a single word node as its child.
    The pairs of each list are processed in reverse order.

    returns the result of initial_state_from_preterminals, called with
      no gold trees
    """
    # TODO: stop reversing the words
    preterminal_lists = [
        [Tree(label=tag, children=[Tree(label=word)])
         for word, tag in reversed(tagged_words)]
        for tagged_words in word_lists
    ]
    return initial_state_from_preterminals(preterminal_lists, model, gold_trees=None)
Esempio n. 3
0
    def build_constituents(self, labels, children_lists):
        """
        Build one Constituent per (label, children) pair

        The hidden state of each new constituent comes from running
        the label embedding and the children's outputs through
        self.constituent_reduce_lstm, then self.reduce_linear and
        self.nonlinearity.

        labels: one entry per constituent.  Each entry is either a str
          or an iterable of labels; for an iterable, the first element
          ends up as the outermost node of the chain
        children_lists: one list of children per label.  Each child
          must have .output (used as LSTM input) and .value (used as
          the child of the new Tree)

        returns a list of Constituent, in the same order as labels
        """
        # one embedding per label, looked up via open_node_map
        label_hx = [
            self.open_node_embedding(
                self.open_node_tensors[self.open_node_map[label]])
            for label in labels
        ]

        max_length = max(len(children) for children in children_lists)
        zeros = torch.zeros(self.hidden_size, device=label_hx[0].device)
        node_hx = [[child.output for child in children]
                   for children in children_lists]
        # each sequence is label, children..., label, then zero padding
        # so that every sequence has max_length + 2 entries
        # weirdly, this is faster than using pack_sequence
        unpacked_hx = [[lhx] + nhx + [lhx] + [zeros] * (max_length - len(nhx))
                       for lhx, nhx in zip(label_hx, node_hx)]
        unpacked_hx = [
            self.lstm_input_dropout(torch.stack(nhx)) for nhx in unpacked_hx
        ]
        # presumably (sequence, batch, hidden) after stacking on axis 1
        # TODO confirm against the LSTM's batch_first setting
        packed_hx = torch.stack(unpacked_hx, axis=1)
        # the true lengths are the children plus the two label entries
        packed_hx = torch.nn.utils.rnn.pack_padded_sequence(
            packed_hx, [len(x) + 2 for x in children_lists],
            enforce_sorted=False)
        lstm_output = self.constituent_reduce_lstm(packed_hx)
        # take just the output of the final layer
        #   result of lstm is output, (hx, cx)
        #   so [1][0] gets hx
        #      [1][0][-1] is the final output
        # will be shape len(children_lists) * 2, hidden_size for bidirectional
        # where forward outputs are -2 and backwards are -1
        lstm_output = lstm_output[1][0]
        forward_hx = lstm_output[-2, :]
        backward_hx = lstm_output[-1, :]

        hx = self.reduce_linear(torch.cat((forward_hx, backward_hx), axis=1))
        hx = self.nonlinearity(hx)

        constituents = []
        for idx, (label, children) in enumerate(zip(labels, children_lists)):
            children = [child.value for child in children]
            if isinstance(label, str):
                node = Tree(label=label, children=children)
            else:
                # build the chain from the innermost label outwards;
                # each new node becomes the only child of the next one
                for value in reversed(label):
                    node = Tree(label=value, children=children)
                    children = node
            # row idx of hx is the hidden state for this constituent
            constituents.append(Constituent(value=node, hx=hx[idx, :]))
        return constituents
Esempio n. 4
0
def test_compound_constituents():
    """
    Test getting the compound constituents from one or more trees
    """
    # TODO: add skinny trees like this to the various transition tests
    cases = [
        ("((VP (VB Unban)))",
         [('ROOT', 'VP')]),
        ("(ROOT (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in)))) (. ?)))",
         [('PP', ), ('ROOT', 'SBARQ'), ('SQ', 'VP'), ('WHNP', )]),
        ("((VP (VB Unban)))   (ROOT (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in)))) (. ?)))",
         [('PP', ), ('ROOT', 'SBARQ'), ('ROOT', 'VP'), ('SQ', 'VP'), ('WHNP', )]),
    ]
    for text, expected in cases:
        trees = tree_reader.read_trees(text)
        assert Tree.get_compound_constituents(trees) == expected
Esempio n. 5
0
def test_root_labels():
    """
    Test getting the root labels from one tree, a repeated tree, and bare roots
    """
    sentence = "( (SBARQ-FOO (WHNP-BAR (WP Who)) (SQ#ASDF (VP=1 (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))"

    # the same tree once, and then three times back to back
    for text in (sentence, sentence * 3):
        trees = tree_reader.read_trees(text)
        assert ["ROOT"] == Tree.get_root_labels(trees)

    trees = tree_reader.read_trees("(FOO) (BAR)")
    assert ["BAR", "FOO"] == Tree.get_root_labels(trees)
Esempio n. 6
0
 def build_constituents(self, labels, children_lists):
     """
     Build one Tree per (label, children) pair

     A label may be a single str or an iterable of labels.  For an
     iterable, the labels are nested with the first one outermost.

     returns a list with one entry per label, in order
     """
     built = []
     for label, children in zip(labels, children_lists):
         # a bare string is treated as a chain of exactly one label
         chain = (label, ) if isinstance(label, str) else label
         node = children
         # wrap from the innermost label outwards
         for value in reversed(chain):
             node = Tree(label=value, children=node)
         built.append(node)
     return built
Esempio n. 7
0
def test_rare_words():
    """
    Check the rare words found in two trees at a threshold of 0.5
    """
    trees = tree_reader.read_trees(
        "((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))  ((SBARQ (NP (DT this) (NN seat)) (. ?)))"
    )

    rare = Tree.get_rare_words(trees, 0.5)
    assert rare == ['Who', 'in', 'sits']
Esempio n. 8
0
def test_unique_tags():
    """
    Check the unique tags found in a single tree
    """
    trees = tree_reader.read_trees(
        "((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))"
    )

    found = Tree.get_unique_tags(trees)
    assert found == ['.', 'DT', 'IN', 'NN', 'VBZ', 'WP']
Esempio n. 9
0
 def unary_transform(self, constituents, labels):
     """
     Wrap the top constituent in a chain of unary nodes

     The labels are applied in reverse order, so the first label ends
     up outermost.  For each label the hidden state goes through that
     label's entry in self.unary_transforms and then self.nonlinearity.

     returns a new Constituent with the wrapped tree and final hidden state
     """
     top = constituents.value
     tree = top.value
     hidden = top.output
     for label in reversed(labels):
         tree = Tree(label=label, children=[tree])
         transformed = self.unary_transforms[label](hidden)
         # non-linearity after the unary transform
         hidden = self.nonlinearity(transformed)
     return Constituent(value=tree, hx=hidden)
Esempio n. 10
0
def test_unique_labels():
    """
    Check the unique constituent labels found in two copies of one tree

    Relies on tree_reader to build the trees
    """
    trees = tree_reader.read_trees(
        "((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?))) ((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))"
    )

    found = Tree.get_unique_constituent_labels(trees)
    assert found == ['NP', 'PP', 'ROOT', 'SBARQ', 'SQ', 'VP', 'WHNP']
Esempio n. 11
0
def recursive_open_tree(token_iterator, at_root, broken_ok):
    """
    Read tokens up to the next unmatched close paren and build a Tree

    Each open paren starts a recursive call whose result becomes a
    child.  Any other token is collected as text, which supplies the
    label and possibly a single text child.

    at_root: a node with no text is labeled ROOT
    broken_ok: tolerate unlabeled nodes and nodes with both text and
      bracketed children instead of raising

    returns the Tree, or None if the tokens run out before a close paren
    raises ValueError for a malformed node when broken_ok is not set
    """
    # TODO: unwind the recursion
    words = []
    subtrees = []

    while True:
        token = next(token_iterator, None)
        if token is None:
            # ran out of tokens without seeing the close paren
            return None

        if token is OPEN_PAREN:
            subtrees.append(recursive_open_tree(token_iterator, at_root=False, broken_ok=broken_ok))
            continue
        if token is not CLOSE_PAREN:
            words.append(token)
            continue

        # from here on the token is the close paren for this node
        if len(words) == 0:
            if at_root:
                return Tree(label="ROOT", children=subtrees)
            if broken_ok:
                return Tree(label=None, children=subtrees)
            raise ValueError("Found a tree with no label on a node!  Line number %d" % token_iterator.line_num)

        pieces = " ".join(words).split()
        if len(pieces) == 1:
            return Tree(label=pieces[0], children=subtrees)

        # the assumption here is that a language such as VI may
        # have spaces in the words, but it still represents
        # just one child
        label = pieces[0]
        child_label = " ".join(pieces[1:])
        if len(subtrees) == 0:
            return Tree(label=label, children=Tree(label=child_label))
        if not broken_ok:
            raise ValueError("Found a tree with both text children and bracketed children!  Line number %d" % token_iterator.line_num)
        return Tree(label=label, children=subtrees + [Tree(label=child_label)])
Esempio n. 12
0
def initial_state_from_gold_trees(trees, model):
    """
    Copy the preterminals of each gold tree and hand them on

    returns the result of initial_state_from_preterminals, called with
      the copied preterminals and the gold trees themselves
    """
    preterminal_lists = []
    for tree in trees:
        copied = []
        # reversed so we put the words on the stack backwards
        for pt in tree.yield_reversed_preterminals():
            tag = pt.label
            word_node = Tree(label=pt.children[0].label)
            copied.append(Tree(label=tag, children=word_node))
        preterminal_lists.append(copied)
    return initial_state_from_preterminals(preterminal_lists, model, gold_trees=trees)
Esempio n. 13
0
def test_leaf_preterminal():
    """
    Check is_leaf, is_preterminal, children and str for a leaf, a preterminal and a deeper tree
    """
    def check(tree, leaf, preterminal, num_children, text):
        assert bool(tree.is_leaf()) == leaf
        assert bool(tree.is_preterminal()) == preterminal
        assert len(tree.children) == num_children
        assert str(tree) == text

    foo = Tree(label="foo")
    check(foo, True, False, 0, 'foo')

    # a single Tree passed as children, rather than a list
    bar = Tree(label="bar", children=foo)
    check(bar, False, True, 1, "(bar foo)")

    baz = Tree(label="baz", children=[bar])
    check(baz, False, False, 1, "(baz (bar foo))")
Esempio n. 14
0
 def unary_transform(self, constituents, labels):
     """
     Wrap the top constituent in a chain of unary nodes

     The labels are applied in reverse order, so the first label ends
     up outermost.

     returns the outermost Tree
     """
     wrapped = constituents.value
     for unary_label in reversed(labels):
         wrapped = Tree(label=unary_label, children=[wrapped])
     return wrapped