def test_print_with_children(self):
    """Compare the captured print_tree output of a root with two children to the expected text."""
    # The root and both children carry identical token fields.
    tree = TokenTree(token={
        "id": "X", "deprel": "Y", "test": "data"
    }, children=[
        TokenTree(token={
            "id": "X", "deprel": "Y", "test": "data"
        }, children=[]),
        TokenTree(token={
            "id": "X", "deprel": "Y", "test": "data"
        }, children=[]),
    ])
    # The bound method itself (not its result) is handed to capture_print;
    # presumably the helper calls it and returns what was printed -- confirm.
    result = capture_print(tree.print_tree)

    self.assertEqual(
        result,
        # NOTE(review): this literal was laid out again from a flattened source
        # in which every whitespace run had been collapsed. One output line per
        # node is assumed; any leading indentation of the two child lines could
        # not be recovered -- confirm against print_tree's real output.
        dedent("""\
        (deprel:Y) test:data [X]
        (deprel:Y) test:data [X]
        (deprel:Y) test:data [X]
        """))
def compute_mst(distance_matrix: torch.Tensor, tokens: TokenList, ignore_punct=True) -> TokenTree:
    """Grow a spanning tree over ``tokens`` by repeatedly attaching the closest outside token.

    The first token taken from the working set (whichever ``set.pop`` yields)
    becomes the root. After that, every step scans all (unattached, attached)
    token pairs, picks the pair with the smallest ``distance_matrix`` entry and
    hangs the unattached token below the attached one.

    Args:
        distance_matrix: Indexed as ``distance_matrix[a['id'] - 1, b['id'] - 1]``
            for an unattached token ``a`` and an attached token ``b``; token
            ids are therefore used as 1-based integer indices.
        tokens: Tokens to connect. Each token is read through ``['id']`` and
            ``['form']``.
        ignore_punct: When true, tokens whose ``form`` satisfies
            ``is_punctuation`` are left out of the tree.

    Returns:
        The root ``TokenTree``, or ``None`` when no token remains to build from.

    Raises:
        ValueError: If no unattached token has a distance that compares below
            infinity to any attached token (e.g. NaN or infinite entries).
            Previously this surfaced as a ``TypeError`` from subscripting
            ``None``.
    """
    # id_wrap results are stored in sets and expose the wrapped token as ``.obj``.
    open_set = {id_wrap(token) for token in tokens.copy()}
    closed_set = set()
    if ignore_punct:
        open_set = {
            wrapped for wrapped in open_set
            if not is_punctuation(wrapped.obj['form'])
        }

    treenodes = {}  # token id -> TokenTree already placed in the tree
    root = None

    while open_set:
        if not closed_set:
            # Nothing attached yet: seed the tree with an arbitrary token.
            seed = open_set.pop().obj
            treenodes[seed['id']] = root = TokenTree(seed, [])
            closed_set.add(id_wrap(seed))
            continue

        best_dist = np.inf
        best_parent = None
        best_child = None
        for open_wrapped in open_set:
            outside = open_wrapped.obj
            for closed_wrapped in closed_set:
                inside = closed_wrapped.obj
                dist = distance_matrix[outside['id'] - 1, inside['id'] - 1]
                # Strict comparison: the first pair seen wins ties.
                if dist < best_dist:
                    best_dist = dist
                    best_parent = inside
                    best_child = outside

        if best_child is None:
            # Every candidate distance was NaN or not below infinity.
            raise ValueError(
                "compute_mst: no finite distance links the remaining tokens "
                "to the partial tree"
            )

        treenodes[best_child['id']] = node = TokenTree(best_child, [])
        treenodes[best_parent['id']].children.append(node)
        closed_set.add(id_wrap(best_child))
        open_set.remove(id_wrap(best_child))

    return root
def test_simple_tree(self):
    """to_tree() on a head/dependent pair should give a root node with a single child."""

    def head_token():
        return Token([("id", 2), ("form", "dog"), ("head", 0)])

    def dependent_token():
        return Token([("id", 1), ("form", "a"), ("head", 2)])

    # The factories return fresh Token instances, so the input list and the
    # expected tree never share objects.
    tokenlist = TokenList([head_token(), dependent_token()])

    leaf = TokenTree(token=dependent_token(), children=[])
    expected = TokenTree(token=head_token(), children=[leaf])

    self.assertTreeEqual(tokenlist.to_tree(), expected)
def test_to_tree(self):
    """to_tree() on OrderedDict tokens should equal a root node holding one child."""
    head_fields = (("id", 2), ("form", "dog"), ("head", 0))
    dependent_fields = (("id", 1), ("form", "a"), ("head", 2))

    # Each OrderedDict(...) call builds a new mapping from the shared field tuples.
    tokenlist = TokenList([
        OrderedDict(head_fields),
        OrderedDict(dependent_fields),
    ])

    leaf = TokenTree(token=OrderedDict(dependent_fields), children=[])
    expected = TokenTree(token=OrderedDict(head_fields), children=[leaf])

    self.assertEqual(tokenlist.to_tree(), expected)
def test_flatten(self):
    """Compare serialize() output for two parent/child trees whose ids are in opposite order."""
    # Case 1: the parent has the larger id (2) and the child the smaller (1);
    # the expected text lists id 1 before id 2.
    tree = TokenTree(token=OrderedDict([("id", 2), ("form", "dog")]), children=[
        TokenTree(token=OrderedDict([("id", 1), ("form", "a")]), children=[])
    ])
    self.assertEqual(
        tree.serialize(),
        # NOTE(review): both expected literals in this test were laid out again
        # from a flattened source in which whitespace runs had been collapsed.
        # One token per line is assumed; blank lines inside the literal (for
        # example a trailing empty line) could not be recovered -- confirm
        # against serialize()'s real output.
        dedent("""\
        1\ta
        2\tdog
        """))

    # Case 2: the parent has the smaller id (1); the expected text is again
    # listed in ascending id order.
    tree = TokenTree(token=OrderedDict([("id", 1), ("form", "dog")]), children=[
        TokenTree(token=OrderedDict([("id", 2), ("form", "a")]), children=[])
    ])
    self.assertEqual(
        tree.serialize(),
        dedent("""\
        1\tdog
        2\ta
        """))
def test_removes_negative_nodes(self):
    """The token whose head is -1 should be absent from the tree built by to_tree()."""
    # (id, form, head) for every input token; the last one has a negative head.
    rows = [(2, "dog", 0), (1, "a", 2), (3, "😍", -1)]
    tokenlist = TokenList([
        Token([("id", idx), ("form", form), ("head", head)])
        for idx, form, head in rows
    ])

    leaf = TokenTree(
        token=Token([("id", 1), ("form", "a"), ("head", 2)]),
        children=[],
    )
    expected = TokenTree(
        token=Token([("id", 2), ("form", "dog"), ("head", 0)]),
        children=[leaf],
    )

    self.assertTreeEqual(tokenlist.to_tree(), expected)
def test_eq(self):
    """Two trees stay unequal while their children differ and are equal once everything matches."""
    shared_meta = {"meta": "data"}

    child = TokenTree(token={"id": 2}, children=[])
    left = TokenTree(token={"id": 1}, children=[child])
    left.metadata = shared_meta

    # Same root token, but no metadata assigned and no children yet.
    right = TokenTree(token={"id": 1}, children=[])
    self.assertNotEqual(left, right)

    # Metadata now matches; the children still differ.
    right.metadata = shared_meta
    self.assertNotEqual(left, right)

    # Give the second tree an equivalent child as well.
    right.children = [TokenTree(token={"id": 2}, children=[])]
    self.assertEqual(left, right)
def test_print_simple(self):
    """A childless tree should print as exactly one line."""
    token = {"id": "X", "deprel": "Y", "test": "data"}
    tree = TokenTree(token=token, children=[])

    printed = capture_print(tree.print_tree)

    expected = "(deprel:Y) test:data [X]\n"
    self.assertEqual(printed, expected)
def test_multiple_root_nodes(self):
    """Several head-0 tokens should end up under one extra root token (id 0, form "_")."""
    # (id, form, head) for every input token; three of them have head 0.
    rows = [
        (1, 'To', 0),
        (2, 'appear', 1),
        (4, 'EMNLP', 0),
        (5, '2014', 4),
        (6, 'Yay!', 0),
    ]
    tokenlist = TokenList([
        Token([('id', idx), ('form', form), ('head', head)])
        for idx, form, head in rows
    ])

    def node(fields, *kids):
        # Build a TokenTree from raw token fields and already-built child trees.
        return TokenTree(token=Token(fields), children=list(kids))

    expected = node(
        [("id", 0), ("form", "_"), ("deprel", "root")],
        node(
            [("id", 1), ("form", "To"), ("head", 0)],
            node([("id", 2), ("form", "appear"), ("head", 1)]),
        ),
        node(
            [("id", 4), ("form", "EMNLP"), ("head", 0)],
            node([("id", 5), ("form", "2014"), ("head", 4)]),
        ),
        node([("id", 6), ("form", "Yay!"), ("head", 0)]),
    )

    self.assertTreeEqual(tokenlist.to_tree(), expected)
def test_metadata(self):
    """Metadata should be readable whether set via set_metadata() or the constructor."""
    expected = {"meta": "data"}

    # Route 1: attach the metadata after construction.
    via_setter = TokenTree(token={"id": 1, "form": "hej"}, children=[])
    via_setter.set_metadata(expected)
    self.assertEqual(via_setter.metadata, expected)

    # Route 2: pass an equal (but separate) dict straight to the constructor.
    via_constructor = TokenTree(
        token={"id": 1, "form": "hej"},
        children=[],
        metadata={"meta": "data"},
    )
    self.assertEqual(via_constructor.metadata, expected)
def test_tree_without_id(self):
    """Printing a tree whose token has no "id" field is expected to raise ParseException."""
    token = {"form": "hej", "deprel": "nmod"}
    tree = TokenTree(token=token, children=[])

    # Callable form of assertRaises: runs capture_print(tree.print_tree).
    self.assertRaises(ParseException, capture_print, tree.print_tree)
def test_print_empty_list(self):
    """Printing a tree built with a None token is expected to raise ParseException."""
    tokenless_tree = TokenTree(None, [])

    # Callable form of assertRaises: runs capture_print(tokenless_tree.print_tree).
    self.assertRaises(ParseException, capture_print, tokenless_tree.print_tree)
def test_missing_id(self):
    """Serializing a tree whose token has no "id" field is expected to raise ParseException."""
    token = {"form": "hej"}
    tree = TokenTree(token=token, children=[])

    # Callable form of assertRaises: runs tree.serialize().
    self.assertRaises(ParseException, tree.serialize)
def _create_tree(head_to_token_mapping, id_=0):
    """Recursively build the TokenTree nodes for the tokens listed under ``id_``.

    Args:
        head_to_token_mapping: Subscriptable by id; each entry is an iterable
            of tokens, and each token must support ``token["id"]``.
        id_: Key whose tokens become the returned nodes (defaults to 0).

    Returns:
        A list with one TokenTree per token found under ``id_``, each holding
        the nodes built recursively from that token's own id.
    """
    subtrees = []
    for token in head_to_token_mapping[id_]:
        # Build the descendants first, then wrap them together with the token.
        descendants = _create_tree(head_to_token_mapping, token["id"])
        subtrees.append(TokenTree(token, descendants))
    return subtrees