def test_sg_conv():
    """Smoke-test ``nn.SGConv`` output shape on a random graph.

    Builds a 20-node Erdos-Renyi graph, applies ``SGConv(5, 2, 2)`` to random
    5-dimensional node features and checks that the result has one 2-wide
    row per node.
    """
    ctx = F.ctx()
    # dgl.from_networkx replaces the deprecated DGLGraph(nx_graph) constructor;
    # the graph is moved to the same context the module is initialized on.
    g = dgl.from_networkx(nx.erdos_renyi_graph(20, 0.3)).to(ctx)
    # A random graph may contain isolated nodes, and SGConv rejects zero
    # in-degree nodes by default; self loops guarantee in-degree >= 1.
    g = dgl.add_self_loop(g)
    sgc = nn.SGConv(5, 2, 2)
    sgc.initialize(ctx=ctx)
    print(sgc)

    # test #1: basic
    h0 = F.randn((g.number_of_nodes(), 5))
    h1 = sgc(g, h0)
    assert h1.shape == (g.number_of_nodes(), 2)
def test_sg_conv(out_dim):
    """Check that ``nn.SGConv`` maps 5-dim node features to ``out_dim`` columns.

    ``out_dim`` is supplied by the caller (presumably a pytest parametrization
    outside this view -- TODO confirm). The random graph gets self loops so
    that no node has zero in-degree.
    """
    ctx = F.ctx()
    nx_graph = nx.erdos_renyi_graph(20, 0.3)
    graph = dgl.add_self_loop(dgl.from_networkx(nx_graph).to(ctx))

    conv = nn.SGConv(5, out_dim, 2)
    conv.initialize(ctx=ctx)
    print(conv)

    # test #1: basic
    num_nodes = graph.number_of_nodes()
    feat = F.randn((num_nodes, 5))
    out = conv(graph, feat)
    assert out.shape == (num_nodes, out_dim)