Example #1
0
def compute_mst(distance_matrix: torch.Tensor,
                tokens: TokenList,
                ignore_punct=True) -> TokenTree:
    """Build a minimum spanning tree over ``tokens`` with Prim's algorithm.

    The tree is grown from an arbitrary starting token (whatever ``set.pop``
    yields). At each step, the not-yet-attached token with the smallest
    ``distance_matrix[token_id - 1, tree_token_id - 1]`` to any token already
    in the tree is attached as a child of that tree token. Ties are broken
    arbitrarily. Infinite and NaN distances are never chosen as edges.

    Args:
        distance_matrix: Pairwise distances indexed by ``id - 1``. The row is
            the token being attached and the column is its new parent.
            Presumably square, with one row/column per token in order and
            1-based integer ids (TODO confirm against callers).
        tokens: Tokens to connect. Each must provide ``'id'`` and ``'form'``.
        ignore_punct: If true, tokens whose ``'form'`` satisfies
            ``is_punctuation`` are left out of the tree.

    Returns:
        The root ``TokenTree``, or ``None`` when no tokens remain after
        filtering.

    Raises:
        ValueError: If some token cannot be attached because all of its
            distances to the tree are infinite or NaN.
    """
    # id_wrap presumably makes tokens hashable by id (NOTE(review): confirm);
    # the set therefore also drops tokens that share an id.
    open_set = {id_wrap(x) for x in tokens}
    if ignore_punct:
        open_set = {w for w in open_set if not is_punctuation(w.obj['form'])}
    if not open_set:
        return None

    root_token = open_set.pop().obj
    root = TokenTree(root_token, [])
    treenodes = {root_token['id']: root}

    # Tokens still outside the tree, keyed by id.
    pending = {w.obj['id']: w.obj for w in open_set}
    # For each pending token, track the cheapest known distance to the tree
    # and the tree token that achieves it. Only edges to the most recently
    # added token can improve these values, so each step is O(n) rather than
    # a full rescan of every (pending, tree) pair.
    best_dist = dict.fromkeys(pending, np.inf)
    best_parent = dict.fromkeys(pending)
    newest = root_token
    while pending:
        newest_index = newest['id'] - 1
        for token_id in pending:
            dist = distance_matrix[token_id - 1, newest_index]
            # Strict '<' means NaN and inf distances never replace the bound.
            if dist < best_dist[token_id]:
                best_dist[token_id] = dist
                best_parent[token_id] = newest
        next_id = min(pending, key=best_dist.__getitem__)
        parent = best_parent[next_id]
        if parent is None:
            # Every remaining token is unreachable (all distances inf/NaN).
            raise ValueError(
                f"cannot attach token {next_id!r}: "
                "no finite distance to the tree")
        newest = pending.pop(next_id)
        node = TokenTree(newest, [])
        treenodes[next_id] = node
        treenodes[parent['id']].children.append(node)
    return root
Example #2
0
 def test_copy(self):
     """Copying a TokenList yields a distinct object that compares equal."""
     rows = [{"id": 1}, {"id": 2}, {"id": 3}]
     original = TokenList(rows, {"meta": "data"})
     duplicate = original.copy()
     # A real copy is a new object rather than the same reference...
     self.assertIsNot(original, duplicate)
     # ...but it still holds equal content.
     self.assertEqual(original, duplicate)