Пример #1
0
def test_dynamic_addition():
    """Exercise send/recv on a graph that grows between message-passing calls.

    The test interleaves ``add_nodes`` / ``add_edge`` / ``add_edges`` with
    partial ``send`` / ``recv`` calls, pops the resulting node feature 'h',
    runs one full ``send()`` / ``recv()`` round, and asserts that the freshly
    computed 'h' is close to the popped one.
    """
    N = 3  # number of nodes that get explicit 'h1' / 'h2' features
    D = 1  # width of every feature vector used in this test

    g = DGLGraph()

    # Random initializer: draws `shape` values, casts them to `dtype` and
    # copies them to `ctx`. The `ids` argument is accepted but not used.
    def _init(shape, dtype, ctx, ids):
        return F.copy_to(F.astype(F.randn(shape), dtype), ctx)

    # NOTE(review): both initializers set here are replaced by
    # dgl.init.zero_initializer further down, before any node or edge is
    # added. Presumably this only checks that a custom callable is
    # accepted - confirm whether it is still needed.
    g.set_n_initializer(_init)
    g.set_e_initializer(_init)

    # Message 'm' is the sum of the source node's 'h1', the destination
    # node's 'h2' and the edge's own 'h1' and 'h2'.
    def _message(edges):
        return {
            'm':
            edges.src['h1'] + edges.dst['h2'] + edges.data['h1'] +
            edges.data['h2']
        }

    # Sum the mailbox along axis 1 (presumably the per-node message axis -
    # TODO confirm against the mailbox layout).
    def _reduce(nodes):
        return {'h': F.sum(nodes.mailbox['m'], 1)}

    # Apply step that hands back the reduced 'h' unchanged.
    def _apply(nodes):
        return {'h': nodes.data['h']}

    g.register_message_func(_message)
    g.register_reduce_func(_reduce)
    g.register_apply_node_func(_apply)
    # These calls overwrite the `_init` initializers installed above.
    g.set_n_initializer(dgl.init.zero_initializer)
    g.set_e_initializer(dgl.init.zero_initializer)

    # add nodes and edges
    g.add_nodes(N)
    g.ndata.update({'h1': F.randn((N, D)), 'h2': F.randn((N, D))})
    # Three more nodes are added without explicit features (presumably
    # filled in by the zero initializer - verify).
    g.add_nodes(3)
    g.add_edge(0, 1)
    g.add_edge(1, 0)
    g.edata.update({'h1': F.randn((2, D)), 'h2': F.randn((2, D))})
    # send() without arguments; the assertion below expects the message
    # index to hold a 1 for every edge currently in the graph.
    g.send()
    expected = F.copy_to(F.ones((g.number_of_edges(), ), dtype=F.int64),
                         F.cpu())
    assert F.array_equal(g._get_msg_index().tousertensor(), expected)

    # add more edges
    # Only 'h1' is supplied for the two new edges; 'h2' is not given here.
    g.add_edges([0, 2], [2, 0], {'h1': F.randn((2, D))})
    g.send(([0, 2], [2, 0]))
    g.recv(0)

    g.add_edge(1, 2)
    # Index 4 is the fifth edge added in this test (after 0->1, 1->0, 0->2,
    # 2->0), assuming edge ids follow insertion order - TODO confirm.
    g.edges[4].data['h1'] = F.randn((1, D))
    g.send((1, 2))
    g.recv([1, 2])

    # Remove 'h' from the graph and keep it for the final comparison.
    h = g.ndata.pop('h')

    # a complete round of send and recv
    g.send()
    g.recv()
    # The recomputed 'h' must be close to the one built up incrementally.
    assert F.allclose(h, g.ndata['h'])
Пример #2
0
def generate_graph(grad=False):
    """Build a 10-node DGLGraph with random PyTorch node and edge features.

    Node 0 points at each of nodes 1..8, and each of those points at node 9,
    which gives 16 edges. Node data 'h' has shape (10, D) and edge data 'w'
    has shape (16, D), where ``D`` is read from the enclosing module scope.

    Parameters
    ----------
    grad : bool
        Forwarded as ``requires_grad`` for both feature tensors.

    Returns
    -------
    DGLGraph
        The populated graph, with zero initializers installed for nodes
        and edges.
    """
    graph = DGLGraph()
    graph.add_nodes(10)

    # For every middle node add the pair 0 -> mid, mid -> 9, in that order,
    # so edge insertion order is identical to a flat two-call loop body.
    for mid in range(1, 9):
        for src, dst in ((0, mid), (mid, 9)):
            graph.add_edge(src, dst)

    # Node features are drawn before edge features.
    node_feat = Variable(th.randn(10, D), requires_grad=grad)
    edge_feat = Variable(th.randn(16, D), requires_grad=grad)

    # Initializers are installed before the feature columns are assigned.
    graph.set_n_initializer(dgl.init.zero_initializer)
    graph.set_e_initializer(dgl.init.zero_initializer)
    graph.ndata['h'] = node_feat
    graph.edata['w'] = edge_feat
    return graph
Пример #3
0
def generate_graph(grad=False):
    """Build a 10-node DGLGraph with random features from the ``F`` backend.

    Node 0 points at each of nodes 1..8, and each of those points at node 9,
    which gives 16 edges. Node data 'h' has shape (10, D) and edge data 'w'
    has shape (16, D), where ``D`` is read from the enclosing module scope.

    Parameters
    ----------
    grad : bool
        When true, both feature tensors are passed through
        ``F.attach_grad`` before being stored on the graph.

    Returns
    -------
    DGLGraph
        The populated graph, with zero initializers installed for nodes
        and edges.
    """
    graph = DGLGraph()
    graph.add_nodes(10)

    # For every middle node add the pair 0 -> mid, mid -> 9, in that order.
    for mid in range(1, 9):
        for src, dst in ((0, mid), (mid, 9)):
            graph.add_edge(src, dst)

    # Tuple elements evaluate left to right, so node features are still
    # drawn (and have gradients attached) before edge features.
    node_feat, edge_feat = F.randn((10, D)), F.randn((16, D))
    if grad:
        node_feat, edge_feat = (F.attach_grad(node_feat),
                                F.attach_grad(edge_feat))

    # Initializers are installed before the feature columns are assigned.
    graph.set_n_initializer(dgl.init.zero_initializer)
    graph.set_e_initializer(dgl.init.zero_initializer)
    graph.ndata['h'] = node_feat
    graph.edata['w'] = edge_feat
    return graph
Пример #4
0
def generate_graph(grad=False, readonly=False):
    """Build a 10-node DGLGraph with random MXNet node and edge features.

    Topology: node 0 points at each of nodes 1..8, each of those points at
    node 9, and a single back edge runs from 9 to 0 (17 edges in total).
    Node data 'h' and edge data 'w' have ``D`` columns, where ``D`` is read
    from the enclosing module scope.

    Parameters
    ----------
    grad : bool
        When true, ``attach_grad()`` is called on both feature arrays
        before they are stored on the graph.
    readonly : bool
        When true, the graph is constructed from a SciPy CSR adjacency
        matrix with ``readonly=True``; otherwise it is built by adding
        nodes and edges one at a time.

    Returns
    -------
    DGLGraph
        The populated graph, with zero initializers installed for nodes
        and edges.
    """
    num_nodes = 10

    # Single description of the topology shared by both construction paths,
    # listed in the order the mutable path inserts the edges.
    edge_list = [pair for i in range(1, 9) for pair in ((0, i), (i, 9))]
    # add a back flow from 9 to 0
    edge_list.append((9, 0))

    if readonly:
        row_idx = [src for src, _ in edge_list]
        col_idx = [dst for _, dst in edge_list]
        ones = np.ones(shape=(len(row_idx)))
        csr = spsp.csr_matrix((ones, (row_idx, col_idx)),
                              shape=(num_nodes, num_nodes))
        g = DGLGraph(csr, readonly=True)
    else:
        g = DGLGraph()
        g.add_nodes(num_nodes)
        for src, dst in edge_list:
            g.add_edge(src, dst)

    # Size the features from the graph itself so they cannot drift out of
    # sync with the topology defined above.
    ncol = mx.nd.random.normal(shape=(g.number_of_nodes(), D))
    ecol = mx.nd.random.normal(shape=(g.number_of_edges(), D))
    if grad:
        ncol.attach_grad()
        ecol.attach_grad()

    # Features are assigned first and the zero initializers installed
    # afterwards; this order is the same for both construction paths.
    g.ndata['h'] = ncol
    g.edata['w'] = ecol
    g.set_n_initializer(dgl.init.zero_initializer)
    g.set_e_initializer(dgl.init.zero_initializer)
    return g