def test_glob_att_pool():
    """Check GlobalAttentionPooling output shapes on a single and a batched graph.

    NOTE(review): another function named ``test_glob_att_pool`` is defined
    later in this module and shadows this one, so pytest only collects the
    later definition. This copy has been brought in line with the current
    API (``dgl.from_networkx``, explicit ``initialize``, ``(graph, feat)``
    call order, 2-D ``(batch, out_dim)`` readout); consider deleting the
    duplicate entirely.
    """
    ctx = F.ctx()
    # dgl.DGLGraph(nx_graph) is the deprecated constructor; from_networkx
    # is the supported conversion path.
    g = dgl.from_networkx(nx.path_graph(10)).to(ctx)
    gap = nn.GlobalAttentionPooling(gluon.nn.Dense(1), gluon.nn.Dense(10))
    # Gluon blocks must be initialized before the first forward pass.
    gap.initialize(ctx=ctx)
    print(gap)

    # test#1: basic -- a single graph yields one pooled row of width 10.
    h0 = F.randn((g.number_of_nodes(), 5))
    h1 = gap(g, h0)
    assert h1.shape[0] == 1 and h1.shape[1] == 10 and h1.ndim == 2

    # test#2: batched graph -- one pooled row per member graph.
    bg = dgl.batch([g, g, g, g])
    h0 = F.randn((bg.number_of_nodes(), 5))
    h1 = gap(bg, h0)
    assert h1.shape[0] == 4 and h1.shape[1] == 10 and h1.ndim == 2
def test_glob_att_pool():
    """Exercise GlobalAttentionPooling on one graph and on a batch of four.

    The gate network is ``Dense(1)`` and the feature network is
    ``Dense(10)``, so every pooled graph becomes a row of width 10.
    The output is checked to be 2-D with one row per input graph.
    """
    device = F.ctx()
    path = dgl.from_networkx(nx.path_graph(10)).to(device)

    pooling = nn.GlobalAttentionPooling(gluon.nn.Dense(1), gluon.nn.Dense(10))
    pooling.initialize(ctx=device)
    print(pooling)

    def _check_readout(graph, expected_rows):
        # Random 5-dim node features; only the output shape is verified.
        feats = F.randn((graph.number_of_nodes(), 5))
        pooled = pooling(graph, feats)
        assert pooled.shape[0] == expected_rows and pooled.shape[1] == 10 and pooled.ndim == 2

    # test#1: basic
    _check_readout(path, 1)

    # test#2: batched graph
    _check_readout(dgl.batch([path] * 4), 4)