コード例 #1
0
def read_tree_no_labels(parents):
    """Build an unlabeled tree from a parent-pointer list.

    Args:
        parents: sequence where ``parents[i]`` is the index of node ``i``'s
            parent (used as-is, i.e. 0-based), or ``-1`` if node ``i`` is
            the root.

    Returns:
        A ``(root, max_degree)`` tuple. ``root`` is the ``tree_rnn.Node``
        whose parent entry is ``-1``; ``max_degree`` is the largest number
        of children attached to any node, and is never reported below 1.

    Raises:
        ValueError: if no entry of ``parents`` is ``-1`` (this includes
            empty input), so no root could be identified.
    """
    nodes = {}
    root = None
    max_degree = 1
    for start in range(len(parents)):
        if start in nodes:
            continue

        # Walk upwards from `start`, creating every missing ancestor, until
        # we either reach an already-built node or the root.
        idx = start
        prev = None
        while True:
            # For now, val is just the node index.
            node = tree_rnn.Node(val=idx, origin_idx=idx)
            if prev is not None:
                assert prev.val != node.val
                node.add_child(prev)
            nodes[idx] = node

            parent = parents[idx]
            if parent in nodes:
                # A parent may have any number of children in this setting,
                # so track the widest fan-out seen so far.
                nodes[parent].add_child(node)
                max_degree = max(max_degree, len(nodes[parent].children))
                break
            elif parent == -1:
                root = node
                break

            prev = node
            idx = parent

    if root is None:
        raise ValueError("parents contains no root entry (-1)")

    return root, max_degree
コード例 #2
0
def read_tree(list, vocab):
    """Build a tree from whitespace-separated token lines.

    Each line is split on whitespace; column 1 is the word and column 6 is
    the 1-based position of the token's head, with 0 marking the root
    (presumably a CoNLL-style dependency row -- confirm against the caller).

    Args:
        list: sequence of token lines, one string per token.
        vocab: object whose ``index(word)`` gives the value stored in the node.

    Returns:
        The node whose head column is 0, or ``None`` if there is none.
    """
    columns_per_token = []
    nodes = []
    for line in list:
        columns = line.split()
        columns_per_token.append(columns)
        nodes.append(tree_rnn.Node(vocab.index(columns[1])))

    root = None
    for node, columns in zip(nodes, columns_per_token):
        head = int(columns[6]) - 1  # convert the 1-based head to a list index
        if head >= 0:
            nodes[head].add_child(node)
        elif head == -1:
            root = node
    return root
def read_tree(parents, labels):
    """Build a labeled tree from a 1-indexed parent-pointer list.

    Args:
        parents: sequence where ``parents[i]`` is the 1-based index of node
            ``i``'s parent; ``0`` marks the root.
        labels: sequence indexed like ``parents``; ``labels[i]`` is stored
            on node ``i`` as ``node.label``.

    Returns:
        A ``(max_degree, root)`` tuple, where ``max_degree`` is the maximum
        number of children of any node. After construction, internal nodes
        have ``val = None`` and leaves have ``val`` set to consecutive
        integers starting at 0, assigned in ascending node-index order.

    Raises:
        AssertionError: if the number of parentless nodes is not exactly
            one (e.g. empty input or a disconnected structure).
    """
    nodes = {}
    parents = [p - 1 for p in parents]  # 1-indexed -> 0-indexed, root -> -1

    for start in range(len(parents)):
        if start in nodes:
            continue

        # Walk upwards from `start`, creating every missing ancestor, until
        # we either reach an already-built node or the root.
        idx = start
        prev = None
        while True:
            node = tree_rnn.Node(val=idx)  # for now, val is just idx
            if prev is not None:
                assert prev.val != node.val
                node.add_child(prev)

            node.label = labels[idx]
            nodes[idx] = node

            parent = parents[idx]
            if parent in nodes:
                nodes[parent].add_child(node)
                break
            elif parent == -1:
                root = node
                break

            prev = node
            idx = parent

    # Ensure the tree is connected by confirming that there is only one root.
    num_roots = sum(node.parent is None for node in nodes.values())
    assert num_roots == 1, num_roots

    # Overwrite vals to match sentence indices: only leaves correspond to
    # sentence tokens. Iterate by ascending node index so the numbering does
    # not depend on dict iteration order (which differs between Python
    # versions).
    leaf_idx = 0
    for idx in sorted(nodes):
        node = nodes[idx]
        if node.children:
            node.val = None
        else:
            node.val = leaf_idx
            leaf_idx += 1

    # max_degree is the maximum number of children of a node.
    max_degree = max(len(node.children) for node in nodes.values())

    return max_degree, root
コード例 #4
0
ファイル: data_utils.py プロジェクト: ibab/tree_rnn
def read_tree(parents, labels):
    """Build a labeled, strictly binary tree from a 1-indexed parent list.

    Args:
        parents: sequence where ``parents[i]`` is the 1-based index of node
            ``i``'s parent; ``0`` marks the root.
        labels: sequence indexed like ``parents``; ``labels[i]`` is stored
            on node ``i`` as ``node.label``.

    Returns:
        The root ``tree_rnn.Node``. After construction, internal nodes have
        ``val = None`` and leaves have ``val`` set to consecutive integers
        starting at 0, assigned in ascending node-index order.

    Raises:
        AssertionError: if attaching a node would give an already-built
            parent more than two children, or if any internal node does not
            end up with exactly two children.
        ValueError: if no entry of ``parents`` marks a root (this includes
            empty input).
    """
    nodes = {}
    root = None
    parents = [p - 1 for p in parents]  # 1-indexed -> 0-indexed, root -> -1

    for start in range(len(parents)):
        if start in nodes:
            continue

        # Walk upwards from `start`, creating every missing ancestor, until
        # we either reach an already-built node or the root.
        idx = start
        prev = None
        while True:
            node = tree_rnn.Node(val=idx)  # for now, val is just idx
            if prev is not None:
                assert prev.val != node.val
                node.add_child(prev)

            node.label = labels[idx]
            nodes[idx] = node

            parent = parents[idx]
            if parent in nodes:
                assert len(nodes[parent].children) < 2
                nodes[parent].add_child(node)
                break
            elif parent == -1:
                root = node
                break

            prev = node
            idx = parent

    # Ensure the tree is completely binary.
    for node in nodes.values():
        if not node.children:
            continue
        assert len(node.children) == 2

    # Overwrite vals to match sentence indices: only leaves correspond to
    # sentence tokens. Iterate by ascending node index so the numbering does
    # not depend on dict iteration order (which differs between Python
    # versions).
    leaf_idx = 0
    for idx in sorted(nodes):
        node = nodes[idx]
        if node.children:
            node.val = None
        else:
            node.val = leaf_idx
            leaf_idx += 1

    if root is None:
        raise ValueError("parents contains no root entry (0)")

    return root
コード例 #5
0
def test_irregular_tree():
    """Evaluate a tree whose nodes have differing numbers of children.

    The root has four children; the first child heads a single-child chain
    two levels deep. The expected value, as written here, is the root
    embedding plus the product of its children's values, where the chain
    contributes the sum of its three embeddings.
    """
    model = DummyTreeRNN(8, 2, 2, 1, degree=4, irregular_tree=True)
    emb = model.embeddings.get_value()

    root = tree_rnn.Node(3)
    # Six further nodes carrying vals 1..6, in that creation order.
    descendants = [tree_rnn.Node(val) for val in range(1, 7)]
    chain_head, chain_mid, chain_tail = (
        descendants[0], descendants[4], descendants[5])

    root.add_children(descendants[:4])
    chain_head.add_children([chain_mid])
    chain_mid.add_children([chain_tail])

    actual = model.evaluate(root)
    chain_value = emb[1] + emb[5] + emb[6]
    expected = emb[3] + emb[2] * emb[3] * emb[4] * chain_value
    assert_array_almost_equal(expected, actual)
コード例 #6
0
def read_tree(list, vocab):
    """Build a tree of word/tag nodes from whitespace-separated token lines.

    Each line is split on whitespace; column 1 is the word, column 3 the
    tag and column 6 the 1-based position of the token's head, with 0
    marking the root (presumably a CoNLL-style dependency row -- confirm
    against the caller).

    Args:
        list: sequence of token lines, one string per token.
        vocab: object providing ``index(word)`` and ``indexoftag(tag)``, or
            ``None`` to store the raw word as the node value with tag
            index 0.

    Returns:
        The node whose head column is 0, or ``None`` if there is none.
    """
    columns_per_token = []
    nodes = []
    for position, line in enumerate(list):
        columns = line.split()
        columns_per_token.append(columns)
        word = columns[1]
        tag = columns[3]
        if vocab is not None:
            val = vocab.index(word)
            tag_idx = vocab.indexoftag(tag)
        else:
            # No vocabulary: keep the surface word and a placeholder tag.
            val = word
            tag_idx = 0
        nodes.append(tree_rnn.Node(val, position, tag_idx))

    root = None
    for node, columns in zip(nodes, columns_per_token):
        head = int(columns[6]) - 1  # convert the 1-based head to a list index
        if head >= 0:
            nodes[head].add_child(node)
        elif head == -1:
            root = node
    return root
コード例 #7
0
def test_tree_rnn():
    """Check root values on progressively deeper binary and ternary trees.

    The expected values, as written here, treat a leaf as its embedding and
    an internal node as its embedding plus the product of its children's
    values. A training step is also run on each finished tree to confirm it
    executes without error.
    """
    # ---- degree 2 ----
    model = DummyTreeRNN(8, 2, 2, 1, degree=2)
    emb = model.embeddings.get_value()

    root = tree_rnn.Node(3)
    left = tree_rnn.Node(1)
    right = tree_rnn.Node(2)
    root.add_children([left, right])

    actual = model.evaluate(root)
    assert_array_almost_equal(emb[3] + emb[1] * emb[2], actual)

    # Grow the right child by one level.
    right_left = tree_rnn.Node(5)
    right_right = tree_rnn.Node(2)
    right.add_children([right_left, right_right])

    actual = model.evaluate(root)
    right_value = emb[2] + emb[5] * emb[2]
    assert_array_almost_equal(emb[3] + right_value * emb[1], actual)

    # Grow the newest left-hand node by another level.
    deep_left = tree_rnn.Node(5)
    deep_right = tree_rnn.Node(4)
    right_left.add_children([deep_left, deep_right])

    actual = model.evaluate(root)
    right_left_value = emb[5] + emb[5] * emb[4]
    right_value = emb[2] + right_left_value * emb[2]
    assert_array_almost_equal(emb[3] + right_value * emb[1], actual)

    # check step works without error
    target = np.array([0]).astype(theano.config.floatX)
    model.train_step(root, target)

    # ---- degree > 2 ----
    model = DummyTreeRNN(10, 2, 2, 1, degree=3)
    emb = model.embeddings.get_value()

    root = tree_rnn.Node(0)
    children = [tree_rnn.Node(val) for val in (1, 2, 3)]
    root.add_children(children)

    # Nine grandchildren with vals 1..9, three under each child.
    grandchildren = [tree_rnn.Node(val) for val in range(1, 10)]
    children[0].add_children(grandchildren[0:3])
    children[1].add_children(grandchildren[3:6])
    children[2].add_children(grandchildren[6:9])

    actual = model.evaluate(root)
    first_value = emb[1] + emb[1] * emb[2] * emb[3]
    second_value = emb[2] + emb[4] * emb[5] * emb[6]
    third_value = emb[3] + emb[7] * emb[8] * emb[9]
    expected = emb[0] + (first_value * second_value * third_value)
    assert_array_almost_equal(expected, actual)

    # check step works without error
    target = np.array([0]).astype(theano.config.floatX)
    model.train_step(root, target)