def test_set2set():
    """Check Set2Set readout shapes on a single graph and on a batched graph.

    NOTE(review): another ``test_set2set`` is defined later in this module and
    rebinds this name, so pytest only collects that later definition. This
    copy has been brought in line with the current API (``from_networkx``,
    explicit ``initialize``, ``(graph, feat)`` call order) but is otherwise
    dead code and is a candidate for deletion.
    """
    ctx = F.ctx()
    g = dgl.from_networkx(nx.path_graph(10)).to(ctx)
    s2s = nn.Set2Set(5, 3, 3)  # hidden size 5, 3 iters, 3 layers
    # Gluon blocks must have their parameters initialized before forward.
    s2s.initialize(ctx=ctx)
    print(s2s)

    # test#1: basic -- a single graph still yields a 2-D (1, 10) result.
    h0 = F.randn((g.number_of_nodes(), 5))
    h1 = s2s(g, h0)
    assert h1.shape[0] == 1 and h1.shape[1] == 10 and h1.ndim == 2

    # test#2: batched graph -- one output row per member graph.
    bg = dgl.batch([g, g, g])
    h0 = F.randn((bg.number_of_nodes(), 5))
    h1 = s2s(bg, h0)
    assert h1.shape[0] == 3 and h1.shape[1] == 10 and h1.ndim == 2
def test_set2set():
    """Set2Set readout over a 10-node path graph, alone and batched three times.

    A lone graph is expected to produce a (1, 10) output and the three-graph
    batch a (3, 10) output.
    """
    ctx = F.ctx()
    path = dgl.from_networkx(nx.path_graph(10)).to(ctx)

    readout = nn.Set2Set(5, 3, 3)  # hidden size 5, 3 iters, 3 layers
    readout.initialize(ctx=ctx)
    print(readout)

    # Case 1: a single graph.
    feats = F.randn((path.number_of_nodes(), 5))
    out = readout(path, feats)
    assert out.shape[0] == 1 and out.shape[1] == 10 and out.ndim == 2

    # Case 2: three copies of the same graph batched together.
    batched = dgl.batch([path] * 3)
    feats = F.randn((batched.number_of_nodes(), 5))
    out = readout(batched, feats)
    assert out.shape[0] == 3 and out.shape[1] == 10 and out.ndim == 2