def test_edge_conv(g, idtype, out_dim):
    """Smoke-test ``nn.EdgeConv`` on the graph ``g``.

    Casts ``g`` to ``idtype`` and moves it to the test context. It then runs an
    EdgeConv layer on random 5-dimensional source-node features. Finally it
    checks that the output has one ``out_dim``-wide row per destination node.
    """
    device = F.ctx()
    graph = g.astype(idtype).to(device)

    layer = nn.EdgeConv(5, out_dim)
    layer.initialize(ctx=device)
    print(layer)

    # test #1: basic
    feats = F.randn((graph.number_of_src_nodes(), 5))
    out = layer(graph, feats)
    assert out.shape == (graph.number_of_dst_nodes(), out_dim)
def test_edge_conv():
    """Smoke-test ``nn.EdgeConv`` on a random Erdos-Renyi graph.

    Builds a 20-node graph with edge probability 0.3 via networkx and wraps it
    in a ``dgl.DGLGraph``. It applies an EdgeConv layer to random 5-dimensional
    node features and checks that the output has one 2-wide row per node.
    """
    nx_graph = nx.erdos_renyi_graph(20, 0.3)
    graph = dgl.DGLGraph(nx_graph)
    device = F.ctx()

    layer = nn.EdgeConv(5, 2)
    layer.initialize(ctx=device)
    print(layer)

    # test #1: basic
    feats = F.randn((graph.number_of_nodes(), 5))
    out = layer(graph, feats)
    assert out.shape == (graph.number_of_nodes(), 2)
def test_edge_conv(g):
    """Check the ``nn.EdgeConv`` output shape for the graph ``g``.

    If ``g`` is homogeneous, the layer receives a single feature tensor sized
    by the source-node count. Otherwise (the bipartite case), it receives a
    ``(src_feats, dst_feats)`` pair. In both cases the result must have one
    2-wide row per destination node.
    """
    ctx = F.ctx()
    edge_conv = nn.EdgeConv(5, 2)
    edge_conv.initialize(ctx=ctx)
    print(edge_conv)

    # test #1: basic
    h0 = F.randn((g.number_of_src_nodes(), 5))
    if not g.is_homograph():
        # bipartite: the destination features are sized from the graph itself.
        # Previously this was a hard-coded ``h0[:10]``, which only matched
        # fixtures with exactly 10 destination nodes (and at least 10 sources).
        x0 = F.randn((g.number_of_dst_nodes(), 5))
        h1 = edge_conv(g, (h0, x0))
    else:
        h1 = edge_conv(g, h0)
    assert h1.shape == (g.number_of_dst_nodes(), 2)