def test_gmm_conv():
    """Smoke-test ``nn.GMMConv`` output shapes on three kinds of graph.

    Covers a random homogeneous graph built with ``dgl.DGLGraph``, the same
    kind of graph built with ``dgl.graph``, and a bipartite graph whose
    source and destination node features have different widths.
    """
    ctx = F.ctx()

    # Homogeneous cases: one pass per graph constructor, in the original order.
    for build_graph in (dgl.DGLGraph, dgl.graph):
        graph = build_graph(nx.erdos_renyi_graph(20, 0.3))
        conv = nn.GMMConv(5, 2, 5, 3, 'max')
        conv.initialize(ctx=ctx)
        node_feat = F.randn((graph.number_of_nodes(), 5))
        edge_pseudo = F.randn((graph.number_of_edges(), 5))
        out = conv(graph, node_feat, edge_pseudo)
        # Every node is expected to receive one 2-wide output row.
        assert out.shape == (graph.number_of_nodes(), 2)

    # Bipartite case: (5, 4) pairs the source-feature width (5) with the
    # destination-feature width (4) of the tensors built below.
    bi_graph = dgl.bipartite(sp.sparse.random(20, 10, 0.1))
    conv = nn.GMMConv((5, 4), 2, 5, 3, 'max')
    conv.initialize(ctx=ctx)
    src_feat = F.randn((bi_graph.number_of_src_nodes(), 5))
    dst_feat = F.randn((bi_graph.number_of_dst_nodes(), 4))
    edge_pseudo = F.randn((bi_graph.number_of_edges(), 5))
    out = conv(bi_graph, (src_feat, dst_feat), edge_pseudo)
    # Output rows correspond to destination nodes only.
    assert out.shape == (bi_graph.number_of_dst_nodes(), 2)
def test_gmm_conv(g, idtype):
    """Check the ``nn.GMMConv`` output shape on a supplied graph ``g``.

    ``g`` is first cast to ``idtype`` and moved to the test context. The
    convolution is fed 5-wide source-node features and 5-wide per-edge
    pseudo coordinates, and must return one 2-wide row per destination node.
    """
    graph = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    conv = nn.GMMConv(5, 2, 5, 3, 'max')
    conv.initialize(ctx=ctx)
    src_feat = F.randn((graph.number_of_src_nodes(), 5))
    edge_pseudo = F.randn((graph.number_of_edges(), 5))
    result = conv(graph, src_feat, edge_pseudo)
    expected_shape = (graph.number_of_dst_nodes(), 2)
    assert result.shape == expected_shape
def test_gmm_conv():
    """Check ``nn.GMMConv`` output shapes on several graph kinds.

    Cases, in order:
      * random homogeneous graph built with ``dgl.DGLGraph``;
      * random homogeneous graph built with ``dgl.graph``;
      * bipartite graph with different source/destination feature widths;
      * a block produced by ``dgl.to_block`` from a sparse random graph.

    In every case the output must be exactly ``(num_output_nodes, 2)``.
    """
    ctx = F.ctx()

    def _check_homo(g):
        # Square graph: input features and outputs both cover every node.
        # Widths 5 / 2 / 5 match the node features, the output and the
        # pseudo coordinates below; 3 is presumably the number of kernels.
        gmm_conv = nn.GMMConv(5, 2, 5, 3, 'max')
        gmm_conv.initialize(ctx=ctx)
        h0 = F.randn((g.number_of_nodes(), 5))
        pseudo = F.randn((g.number_of_edges(), 5))
        h1 = gmm_conv(g, h0, pseudo)
        assert h1.shape == (g.number_of_nodes(), 2)

    # Homogeneous graphs via the legacy and the current constructor.
    _check_homo(dgl.DGLGraph(nx.erdos_renyi_graph(20, 0.3)))
    _check_homo(dgl.graph(nx.erdos_renyi_graph(20, 0.3)))

    # Bipartite graph: (5, 4) pairs the source-feature width (5) with the
    # destination-feature width (4); outputs exist only for destinations.
    g = dgl.bipartite(sp.sparse.random(20, 10, 0.1))
    gmm_conv = nn.GMMConv((5, 4), 2, 5, 3, 'max')
    gmm_conv.initialize(ctx=ctx)
    h0 = F.randn((g.number_of_src_nodes(), 5))
    hd = F.randn((g.number_of_dst_nodes(), 4))
    pseudo = F.randn((g.number_of_edges(), 5))
    h1 = gmm_conv(g, (h0, hd), pseudo)
    assert h1.shape == (g.number_of_dst_nodes(), 2)

    # Block: seed nodes are the unique edge destinations of a sparse graph,
    # so the block has distinct source and destination node sets.
    g = dgl.graph(sp.sparse.random(100, 100, density=0.001))
    seed_nodes = np.unique(g.edges()[1].asnumpy())
    block = dgl.to_block(g, seed_nodes)
    gmm_conv = nn.GMMConv(5, 2, 5, 3, 'mean')
    gmm_conv.initialize(ctx=ctx)
    h0 = F.randn((block.number_of_src_nodes(), 5))
    pseudo = F.randn((block.number_of_edges(), 5))
    h = gmm_conv(block, h0, pseudo)
    # Compare the full shape (as the other cases do) so that an output
    # with unexpected extra dimensions is also caught.
    assert h.shape == (block.number_of_dst_nodes(), 2)
def test_gmm_conv():
    """Smoke-test ``nn.GMMConv`` on a random homogeneous graph.

    The module is printed after initialization, which exercises its string
    representation, then a forward pass must yield one 2-wide row per node.
    """
    graph = dgl.DGLGraph(nx.erdos_renyi_graph(20, 0.3))
    ctx = F.ctx()
    conv = nn.GMMConv(5, 2, 5, 3, 'max')
    conv.initialize(ctx=ctx)
    print(conv)

    # 5-wide node features and 5-wide per-edge pseudo coordinates.
    node_feat = F.randn((graph.number_of_nodes(), 5))
    edge_pseudo = F.randn((graph.number_of_edges(), 5))
    out = conv(graph, node_feat, edge_pseudo)
    assert out.shape == (graph.number_of_nodes(), 2)