示例#1
0
def test_dense_sage_conv():
    """Check that DenseSAGEConv reproduces SAGEConv('gcn') on a random graph.

    A read-only DGLGraph is built from a 100x100 random sparse matrix
    (density 0.1). The dense layer receives the sparse layer's
    ``fc_neigh`` weight and bias, so both layers apply the same
    parameters; their outputs on the same features must then agree
    according to ``F.allclose``.
    """
    device = F.ctx()

    # Random graph plus its dense adjacency matrix on the test device.
    random_adj = sp.sparse.random(100, 100, density=0.1)
    graph = dgl.DGLGraph(random_adj, readonly=True)
    dense_adj = graph.adjacency_matrix(ctx=device).to_dense()

    # Build both layers and share the parameters so outputs are comparable.
    sparse_conv = nn.SAGEConv(5, 2, 'gcn')
    dense_conv = nn.DenseSAGEConv(5, 2)
    dense_conv.fc.weight.data = sparse_conv.fc_neigh.weight.data
    dense_conv.fc.bias.data = sparse_conv.fc_neigh.bias.data

    node_feat = F.randn((100, 5))
    sparse_conv = sparse_conv.to(device)
    dense_conv = dense_conv.to(device)

    sparse_out = sparse_conv(graph, node_feat)
    dense_out = dense_conv(dense_adj, node_feat)
    assert F.allclose(sparse_out, dense_out)
示例#2
0
文件: test_nn.py 项目: weibao918/dgl
def test_dense_sage_conv(g):
    """Check that DenseSAGEConv reproduces SAGEConv('gcn') on graph ``g``.

    The dense layer receives the sparse layer's ``fc_neigh`` weight and
    bias, so both layers apply the same parameters. When ``g`` has exactly
    two node types, a ``(source, destination)`` pair of feature tensors is
    used as input; otherwise a single tensor covering all nodes is used.
    On mismatch the assertion reports ``g``.
    """
    device = F.ctx()
    dense_adj = g.adjacency_matrix(ctx=device).to_dense()

    # Build both layers and share the parameters so outputs are comparable.
    sparse_conv = nn.SAGEConv(5, 2, 'gcn')
    dense_conv = nn.DenseSAGEConv(5, 2)
    dense_conv.fc.weight.data = sparse_conv.fc_neigh.weight.data
    dense_conv.fc.bias.data = sparse_conv.fc_neigh.bias.data

    if len(g.ntypes) != 2:
        inputs = F.randn((g.number_of_nodes(), 5))
    else:
        # Two node types: separate features for source and destination nodes.
        src_feat = F.randn((g.number_of_src_nodes(), 5))
        dst_feat = F.randn((g.number_of_dst_nodes(), 5))
        inputs = (src_feat, dst_feat)

    sparse_conv = sparse_conv.to(device)
    dense_conv = dense_conv.to(device)

    sparse_out = sparse_conv(g, inputs)
    dense_out = dense_conv(dense_adj, inputs)
    assert F.allclose(sparse_out, dense_out), g