예제 #1
0
def test_dense_graph_conv(g, norm_type):
    """Check that DenseGraphConv on a dense adjacency matches GraphConv on *g*.

    Both layers are built with the same ``norm`` setting and a bias term.
    The dense layer then receives a copy of the sparse layer's weight and
    bias, so the two outputs should agree (``F.allclose``).

    Parameters
    ----------
    g
        Graph passed directly to GraphConv. Its adjacency matrix, converted
        to a dense array with ``tostype('default')``, is passed to
        DenseGraphConv.
    norm_type
        Forwarded unchanged as the ``norm`` argument of both layers.
    """
    device = F.ctx()
    in_dim, out_dim = 5, 2
    dense_adj = g.adjacency_matrix(ctx=device).tostype('default')

    sparse_layer = nn.GraphConv(in_dim, out_dim, norm=norm_type, bias=True)
    dense_layer = nn.DenseGraphConv(in_dim, out_dim, norm=norm_type, bias=True)
    for layer in (sparse_layer, dense_layer):
        layer.initialize(ctx=device)

    # Share parameters so both layers compute the same function; any
    # mismatch must then come from the convolution implementations.
    for param_name in ('weight', 'bias'):
        source_value = getattr(sparse_layer, param_name).data()
        getattr(dense_layer, param_name).set_data(source_value)

    # One random feature row per source node of g.
    features = F.randn((g.number_of_src_nodes(), in_dim))

    expected = sparse_layer(g, features)
    actual = dense_layer(dense_adj, features)
    assert F.allclose(expected, actual)
예제 #2
0
def test_dense_graph_conv():
    """Compare GraphConv and DenseGraphConv with normalization turned off.

    A random 100x100 sparse matrix (density 0.3) is wrapped in a read-only
    DGLGraph. GraphConv runs on the graph and DenseGraphConv runs on its
    dense adjacency matrix. With shared weight and bias, the two outputs
    are expected to be close.
    """
    device = F.ctx()
    num_nodes, in_dim, out_dim = 100, 5, 2

    random_matrix = sp.sparse.random(num_nodes, num_nodes, density=0.3)
    graph = dgl.DGLGraph(random_matrix, readonly=True)
    dense_adj = graph.adjacency_matrix(ctx=device).tostype('default')

    # GraphConv is given norm='none' and DenseGraphConv norm=False; both are
    # presumably the "no normalization" spelling for their respective
    # layer APIs in this DGL version -- confirm against the layer docs.
    sparse_layer = nn.GraphConv(in_dim, out_dim, norm='none', bias=True)
    dense_layer = nn.DenseGraphConv(in_dim, out_dim, norm=False, bias=True)
    for layer in (sparse_layer, dense_layer):
        layer.initialize(ctx=device)

    # Copy weight then bias from the sparse layer into the dense layer.
    for param_name in ('weight', 'bias'):
        source_value = getattr(sparse_layer, param_name).data()
        getattr(dense_layer, param_name).set_data(source_value)

    features = F.randn((num_nodes, in_dim))

    expected = sparse_layer(graph, features)
    actual = dense_layer(dense_adj, features)
    assert F.allclose(expected, actual)