예제 #1
0
    def test_print_with_children(self):
        # A root with two leaf children: the expected output shows the root
        # line first and each child line indented by four extra spaces.
        # Every node receives its own (equal-valued) token dict.
        leaves = [
            TokenTree(
                token={"id": "X", "deprel": "Y", "test": "data"},
                children=[],
            )
            for _ in range(2)
        ]
        root = TokenTree(
            token={"id": "X", "deprel": "Y", "test": "data"},
            children=leaves,
        )

        printed = capture_print(root.print_tree)
        expected = dedent("""\
            (deprel:Y) test:data [X]
                (deprel:Y) test:data [X]
                (deprel:Y) test:data [X]
        """)

        self.assertEqual(printed, expected)
예제 #2
0
파일: parse.py 프로젝트: daemon/vizbert
def compute_mst(distance_matrix: torch.Tensor,
                tokens: TokenList,
                ignore_punct=True) -> TokenTree:
    """Greedily grow a tree over ``tokens`` using pairwise distances.

    One token is popped from the candidate set to become the root.  After
    that, every round scans all (unattached, attached) token pairs, picks
    the pair with the smallest ``distance_matrix`` entry, and hangs the
    unattached token beneath the attached one.

    Args:
        distance_matrix: Indexed as ``[id_a - 1, id_b - 1]`` using the
            ``'id'`` fields of two tokens.
        tokens: Tokens to connect; each is read via ``['id']`` and, when
            ``ignore_punct`` is set, ``['form']``.
        ignore_punct: Drop tokens whose ``'form'`` satisfies
            ``is_punctuation`` before building the tree.

    Returns:
        The root ``TokenTree``, or ``None`` when no tokens remain to connect.
    """
    pending = {id_wrap(tok) for tok in tokens.copy()}
    if ignore_punct:
        punctuation = {
            wrapped for wrapped in pending
            if is_punctuation(wrapped.obj['form'])
        }
        pending = pending - punctuation

    if not pending:
        return None

    attached = set()
    treenodes = {}

    # Seed the tree: whichever token the set hands back first is the root.
    seed = pending.pop().obj
    root = TokenTree(seed, [])
    treenodes[seed['id']] = root
    attached.add(id_wrap(seed))

    while pending:
        best_dist = np.inf
        parent_tok = None
        child_tok = None
        # Strict "<" keeps the first minimal pair encountered in scan order.
        for candidate in pending:
            free_tok = candidate.obj
            for member in attached:
                tree_tok = member.obj
                dist = distance_matrix[free_tok['id'] - 1, tree_tok['id'] - 1]
                if dist < best_dist:
                    best_dist = dist
                    parent_tok = tree_tok
                    child_tok = free_tok

        child_node = TokenTree(child_tok, [])
        treenodes[child_tok['id']] = child_node
        treenodes[parent_tok['id']].children.append(child_node)
        attached.add(id_wrap(child_tok))
        pending.remove(id_wrap(child_tok))
    return root
예제 #3
0
 def test_simple_tree(self):
     # "a" (head 2) is expected to end up as the only child of "dog"
     # (head 0) after converting the flat list to a tree.
     tokens = TokenList([
         Token([("id", 2), ("form", "dog"), ("head", 0)]),
         Token([("id", 1), ("form", "a"), ("head", 2)]),
     ])

     leaf = TokenTree(
         token=Token([("id", 1), ("form", "a"), ("head", 2)]),
         children=[],
     )
     expected = TokenTree(
         token=Token([("id", 2), ("form", "dog"), ("head", 0)]),
         children=[leaf],
     )

     self.assertTreeEqual(tokens.to_tree(), expected)
예제 #4
0
 def test_to_tree(self):
     # Same shape as a plain two-token sentence, but with OrderedDict
     # tokens: "a" (head 2) is expected beneath "dog" (head 0).
     tokens = TokenList([
         OrderedDict([("id", 2), ("form", "dog"), ("head", 0)]),
         OrderedDict([("id", 1), ("form", "a"), ("head", 2)]),
     ])

     leaf = TokenTree(
         token=OrderedDict([("id", 1), ("form", "a"), ("head", 2)]),
         children=[],
     )
     expected = TokenTree(
         token=OrderedDict([("id", 2), ("form", "dog"), ("head", 0)]),
         children=[leaf],
     )

     self.assertEqual(tokens.to_tree(), expected)
예제 #5
0
    def test_flatten(self):
        # Case 1: the child carries the lower id; the expected output lists
        # it before its parent.
        low_child = TokenTree(
            token=OrderedDict([("id", 1), ("form", "a")]),
            children=[],
        )
        tree = TokenTree(
            token=OrderedDict([("id", 2), ("form", "dog")]),
            children=[low_child],
        )
        self.assertEqual(
            tree.serialize(),
            dedent("""\
                1\ta
                2\tdog

            """))

        # Case 2: the parent carries the lower id; the expected output lists
        # the parent first.
        high_child = TokenTree(
            token=OrderedDict([("id", 2), ("form", "a")]),
            children=[],
        )
        tree = TokenTree(
            token=OrderedDict([("id", 1), ("form", "dog")]),
            children=[high_child],
        )
        self.assertEqual(
            tree.serialize(),
            dedent("""\
                1\tdog
                2\ta

            """))
예제 #6
0
 def test_removes_negative_nodes(self):
     # The third token has head -1 and does not appear anywhere in the
     # expected tree; only "dog" and its child "a" remain.
     tokens = TokenList([
         Token([("id", 2), ("form", "dog"), ("head", 0)]),
         Token([("id", 1), ("form", "a"), ("head", 2)]),
         Token([("id", 3), ("form", "😍"), ("head", -1)]),
     ])

     leaf = TokenTree(
         token=Token([("id", 1), ("form", "a"), ("head", 2)]),
         children=[],
     )
     expected = TokenTree(
         token=Token([("id", 2), ("form", "dog"), ("head", 0)]),
         children=[leaf],
     )

     self.assertTreeEqual(tokens.to_tree(), expected)
예제 #7
0
    def test_eq(self):
        # The candidate starts without metadata or children and is brought
        # in line with the reference step by step; the trees are expected to
        # compare equal only after the final step.
        shared_meta = {"meta": "data"}

        reference = TokenTree(
            token={"id": 1},
            children=[TokenTree(token={"id": 2}, children=[])],
        )
        reference.metadata = shared_meta

        # Step 0: neither metadata nor children match.
        candidate = TokenTree(token={"id": 1}, children=[])
        self.assertNotEqual(reference, candidate)

        # Step 1: metadata matches, children still differ.
        candidate.metadata = shared_meta
        self.assertNotEqual(reference, candidate)

        # Step 2: children match as well.
        candidate.children = [TokenTree(token={"id": 2}, children=[])]
        self.assertEqual(reference, candidate)
예제 #8
0
 def test_print_simple(self):
     # A childless node is expected to print as exactly one line.
     token = {"id": "X", "deprel": "Y", "test": "data"}
     leaf = TokenTree(token=token, children=[])

     printed = capture_print(leaf.print_tree)

     self.assertEqual(printed, "(deprel:Y) test:data [X]\n")
예제 #9
0
 def test_multiple_root_nodes(self):
     # Three tokens have head 0. The expected tree is topped by a
     # placeholder node (id 0, form "_", deprel "root") that holds each
     # of them, in list order, as a child.
     tokens = TokenList([
         Token([("id", 1), ("form", "To"), ("head", 0)]),
         Token([("id", 2), ("form", "appear"), ("head", 1)]),
         Token([("id", 4), ("form", "EMNLP"), ("head", 0)]),
         Token([("id", 5), ("form", "2014"), ("head", 4)]),
         Token([("id", 6), ("form", "Yay!"), ("head", 0)]),
     ])

     to_subtree = TokenTree(
         token=Token([("id", 1), ("form", "To"), ("head", 0)]),
         children=[
             TokenTree(
                 token=Token([("id", 2), ("form", "appear"), ("head", 1)]),
                 children=[],
             ),
         ],
     )
     emnlp_subtree = TokenTree(
         token=Token([("id", 4), ("form", "EMNLP"), ("head", 0)]),
         children=[
             TokenTree(
                 token=Token([("id", 5), ("form", "2014"), ("head", 4)]),
                 children=[],
             ),
         ],
     )
     yay_subtree = TokenTree(
         token=Token([("id", 6), ("form", "Yay!"), ("head", 0)]),
         children=[],
     )
     expected = TokenTree(
         token=Token([("id", 0), ("form", "_"), ("deprel", "root")]),
         children=[to_subtree, emnlp_subtree, yay_subtree],
     )

     self.assertTreeEqual(tokens.to_tree(), expected)
예제 #10
0
    def test_metadata(self):
        # Metadata attached through set_metadata() is readable back from
        # the .metadata attribute.
        node = TokenTree(token={"id": 1, "form": "hej"}, children=[])
        expected_meta = {"meta": "data"}
        node.set_metadata(expected_meta)
        self.assertEqual(node.metadata, expected_meta)

        # Metadata passed to the constructor (as a separate, equal dict) is
        # exposed the same way.
        node = TokenTree(
            token={"id": 1, "form": "hej"},
            children=[],
            metadata={"meta": "data"},
        )
        self.assertEqual(node.metadata, expected_meta)
예제 #11
0
 def test_tree_without_id(self):
     # The token has no "id" key; printing is expected to raise.
     token_without_id = {"form": "hej", "deprel": "nmod"}
     node = TokenTree(token=token_without_id, children=[])

     with self.assertRaises(ParseException):
         capture_print(node.print_tree)
예제 #12
0
 def test_print_empty_list(self):
     # A node built with None as its token is expected to raise on print.
     empty_node = TokenTree(None, [])

     with self.assertRaises(ParseException):
         capture_print(empty_node.print_tree)
예제 #13
0
 def test_missing_id(self):
     # The token has no "id" key; serializing is expected to raise.
     token_without_id = {"form": "hej"}
     node = TokenTree(token=token_without_id, children=[])

     with self.assertRaises(ParseException):
         node.serialize()
예제 #14
0
파일: __init__.py 프로젝트: zoharai/conllu
 def _create_tree(head_to_token_mapping, id_=0):
     return [
         TokenTree(child, _create_tree(head_to_token_mapping, child["id"]))
         for child in head_to_token_mapping[id_]
     ]