def test_sage_conv():
    """Check SAGEConv output shapes for the 'mean', 'pool' and 'gcn' aggregators.

    For each aggregator the layer is run on two homogeneous 100-node random
    graphs and then on a 100 -> 200 bipartite random graph:

    * one homogeneous graph is built with the readonly ``DGLGraph``
      constructor, the other with ``dgl.graph``;
    * only output widths (and, for the bipartite case, the row count) are
      asserted.
    """
    homogeneous_builders = (
        lambda adj: dgl.DGLGraph(adj, readonly=True),
        dgl.graph,
    )
    for aggre_type in ['mean', 'pool', 'gcn']:
        ctx = F.ctx()

        for build in homogeneous_builders:
            graph = build(sp.sparse.random(100, 100, density=0.1))
            conv = nn.SAGEConv(5, 10, aggre_type)
            node_feat = F.randn((100, 5))
            conv.initialize(ctx=ctx)
            out = conv(graph, node_feat)
            assert out.shape[-1] == 10

        # Source features are 10-wide. 'gcn' is given equally wide (10)
        # destination features; the other aggregators get 5-wide ones.
        bipartite = dgl.bipartite(sp.sparse.random(100, 200, density=0.1))
        dst_dim = 10 if aggre_type == 'gcn' else 5
        conv = nn.SAGEConv((10, dst_dim), 2, aggre_type)
        pair_feat = (F.randn((100, 10)), F.randn((200, dst_dim)))
        conv.initialize(ctx=ctx)
        out = conv(bipartite, pair_feat)
        assert out.shape[-1] == 2
        assert out.shape[0] == 200
def test_sage_conv(idtype, g, aggre_type, out_dim):
    """Run SAGEConv on a graph fixture and check the output feature width.

    ``idtype``, ``g``, ``aggre_type`` and ``out_dim`` are presumably supplied by
    pytest parametrization/fixtures defined outside this view. The graph is
    cast to ``idtype`` and moved to the test context before use.
    """
    ctx = F.ctx()
    graph = g.astype(idtype).to(ctx)

    conv = nn.SAGEConv(5, out_dim, aggre_type)
    # One 5-wide feature row per source node.
    node_feat = F.randn((graph.number_of_src_nodes(), 5))
    conv.initialize(ctx=ctx)

    assert conv(graph, node_feat).shape[-1] == out_dim
def test_sage_conv_bi2(idtype, aggre_type, out_dim):
    """Run SAGEConv with the parametrized aggregator on an edgeless bipartite graph.

    The graph has 5 source ('_U') and 3 destination ('_V') nodes and no edges.
    The output must have one row per destination node and ``out_dim``
    columns.

    The parametrized ``aggre_type`` is used directly, rather than being
    shadowed by a hard-coded loop, so each parametrized case tests exactly
    one aggregator.
    """
    # Test the case for graphs without edges
    g = dgl.heterograph({('_U', '_E', '_V'): ([], [])}, {'_U': 5, '_V': 3})
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()

    # 'gcn' is exercised with equal source/destination widths (3, 3); the
    # other aggregators with a 1-wide destination feature, matching the
    # width pairs this test has always used.
    dst_dim = 3 if aggre_type == 'gcn' else 1
    sage = nn.SAGEConv((3, dst_dim), out_dim, aggre_type)
    feat = (F.randn((5, 3)), F.randn((3, dst_dim)))
    sage.initialize(ctx=ctx)
    h = sage(g, feat)
    assert h.shape[-1] == out_dim
    assert h.shape[0] == 3
def test_sage_conv_bi2(idtype, aggre_type):
    """Run SAGEConv with the parametrized aggregator on an edgeless bipartite graph.

    The graph has 5 source and 3 destination nodes and no edges. The output
    must have one row per destination node and 2 columns.

    The parametrized ``aggre_type`` is used directly, rather than being
    shadowed by a hard-coded loop, so each parametrized case tests exactly
    one aggregator.
    """
    # Test the case for graphs without edges
    g = dgl.bipartite([], num_nodes=(5, 3))
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()

    # 'gcn' is exercised with equal source/destination widths (3, 3); the
    # other aggregators with a 1-wide destination feature, matching the
    # width pairs this test has always used.
    dst_dim = 3 if aggre_type == 'gcn' else 1
    sage = nn.SAGEConv((3, dst_dim), 2, aggre_type)
    feat = (F.randn((5, 3)), F.randn((3, dst_dim)))
    sage.initialize(ctx=ctx)
    h = sage(g, feat)
    assert h.shape[-1] == 2
    assert h.shape[0] == 3
def test_sage_conv_bi(idtype, g, aggre_type, out_dim):
    """Run SAGEConv on a bipartite graph fixture with a (src, dst) feature pair.

    The test asserts that the output has ``out_dim`` columns and one row per
    destination node. The parameters are presumably supplied by pytest
    parametrization/fixtures outside this view.
    """
    ctx = F.ctx()
    graph = g.astype(idtype).to(ctx)

    # Source features are 10-wide. 'gcn' gets equally wide (10) destination
    # features; every other aggregator gets 5-wide ones.
    dst_dim = 10 if aggre_type == 'gcn' else 5
    conv = nn.SAGEConv((10, dst_dim), out_dim, aggre_type)

    num_src = graph.number_of_src_nodes()
    num_dst = graph.number_of_dst_nodes()
    src_feat = F.randn((num_src, 10))
    dst_feat = F.randn((num_dst, dst_dim))
    conv.initialize(ctx=ctx)

    out = conv(graph, (src_feat, dst_feat))
    assert out.shape[-1] == out_dim
    assert out.shape[0] == num_dst
def test_sage_conv():
    """Smoke-test a SAGEConv(10, 20) on a random 20-node Erdos-Renyi graph.

    The layer is built without an explicit aggregator argument. The test
    prints the layer and checks that 10-wide input features map to a
    (20, 20) output.
    """
    graph = dgl.DGLGraph(nx.erdos_renyi_graph(20, 0.3))
    ctx = F.ctx()
    layer = nn.SAGEConv(10, 20)
    layer.initialize(ctx=ctx)
    print(layer)

    # test#1: basic
    out = layer(graph, F.randn((20, 10)))
    assert out.shape == (20, 20)
def test_sage_conv(aggre_type):
    """Check SAGEConv output shapes for one aggregator, plus edgeless graphs.

    The parametrized ``aggre_type`` is run on:

    * two homogeneous 100-node random graphs (readonly ``DGLGraph`` and
      ``dgl.graph``);
    * a 100 -> 200 bipartite random graph.

    The final section runs 'gcn', 'mean' and 'pool' on a 5 -> 3 bipartite
    graph with no edges, independently of ``aggre_type``.
    """
    ctx = F.ctx()

    homogeneous_builders = (
        lambda adj: dgl.DGLGraph(adj, readonly=True),
        dgl.graph,
    )
    for build in homogeneous_builders:
        graph = build(sp.sparse.random(100, 100, density=0.1))
        conv = nn.SAGEConv(5, 10, aggre_type)
        node_feat = F.randn((100, 5))
        conv.initialize(ctx=ctx)
        assert conv(graph, node_feat).shape[-1] == 10

    # Source features are 10-wide; 'gcn' gets equally wide destination
    # features, the other aggregators 5-wide ones.
    bipartite = dgl.bipartite(sp.sparse.random(100, 200, density=0.1))
    dst_dim = 10 if aggre_type == 'gcn' else 5
    conv = nn.SAGEConv((10, dst_dim), 2, aggre_type)
    pair_feat = (F.randn((100, 10)), F.randn((200, dst_dim)))
    conv.initialize(ctx=ctx)
    out = conv(bipartite, pair_feat)
    assert out.shape[-1] == 2
    assert out.shape[0] == 200

    # Test the case for graphs without edges
    empty = dgl.bipartite([], num_nodes=(5, 3))
    # (aggregator, destination width) pairs, run in this order.
    for edgeless_agg, dst_width in (('gcn', 3), ('mean', 1), ('pool', 1)):
        conv = nn.SAGEConv((3, dst_width), 2, edgeless_agg)
        pair_feat = (F.randn((5, 3)), F.randn((3, dst_width)))
        conv.initialize(ctx=ctx)
        out = conv(empty, pair_feat)
        assert out.shape[-1] == 2
        assert out.shape[0] == 3
def test_dense_sage_conv():
    """Check DenseSAGEConv agrees with SAGEConv('gcn') when weights are shared.

    A random 100-node graph is converted to a dense adjacency matrix for the
    dense layer. The dense layer's ``fc`` parameters are overwritten with the
    sparse layer's ``fc_neigh`` parameters. Both outputs on the same features
    must then be ``allclose``.
    """
    ctx = F.ctx()
    g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
    dense_adj = g.adjacency_matrix(ctx=ctx).tostype('default')

    sparse_conv = nn.SAGEConv(5, 2, 'gcn')
    dense_conv = nn.DenseSAGEConv(5, 2)
    sparse_conv.initialize(ctx=ctx)
    dense_conv.initialize(ctx=ctx)

    # Give the dense layer exactly the sparse layer's neighbour projection so
    # the two outputs are comparable.
    for param_name in ('weight', 'bias'):
        source_param = getattr(sparse_conv.fc_neigh, param_name)
        getattr(dense_conv.fc, param_name).set_data(source_param.data())

    feat = F.randn((100, 5))
    assert F.allclose(sparse_conv(g, feat), dense_conv(dense_adj, feat))
def test_dense_sage_conv(g):
    """Check DenseSAGEConv agrees with SAGEConv('gcn') on a graph fixture.

    The dense layer takes ``g``'s adjacency matrix in dense form and the
    sparse layer's ``fc_neigh`` weights. Its output must be ``allclose`` to
    the sparse layer's. A graph with two node types receives a
    (source, destination) feature pair; any other graph receives a single
    per-node tensor.
    """
    ctx = F.ctx()
    dense_adj = g.adjacency_matrix(ctx=ctx).tostype('default')

    sparse_conv = nn.SAGEConv(5, 2, 'gcn')
    dense_conv = nn.DenseSAGEConv(5, 2)
    sparse_conv.initialize(ctx=ctx)
    dense_conv.initialize(ctx=ctx)
    dense_conv.fc.weight.set_data(sparse_conv.fc_neigh.weight.data())
    dense_conv.fc.bias.set_data(sparse_conv.fc_neigh.bias.data())

    has_two_ntypes = len(g.ntypes) == 2
    if not has_two_ntypes:
        feat = F.randn((g.number_of_nodes(), 5))
    else:
        src_feat = F.randn((g.number_of_src_nodes(), 5))
        dst_feat = F.randn((g.number_of_dst_nodes(), 5))
        feat = (src_feat, dst_feat)

    sparse_out = sparse_conv(g, feat)
    dense_out = dense_conv(dense_adj, feat)
    assert F.allclose(sparse_out, dense_out)
def test_sage_conv(aggre_type):
    """Check SAGEConv output shapes for one aggregator across graph kinds.

    The parametrized ``aggre_type`` is run on:

    * two homogeneous 100-node random graphs (readonly ``DGLGraph`` and
      ``dgl.graph``);
    * a 100 -> 200 bipartite random graph;
    * a block built with ``dgl.to_block`` from a very sparse graph.

    The final section runs 'gcn', 'mean' and 'pool' on a 5 -> 3 bipartite
    graph with no edges, independently of ``aggre_type``.
    """
    ctx = F.ctx()

    homogeneous_builders = (
        lambda adj: dgl.DGLGraph(adj, readonly=True),
        dgl.graph,
    )
    for build in homogeneous_builders:
        graph = build(sp.sparse.random(100, 100, density=0.1))
        conv = nn.SAGEConv(5, 10, aggre_type)
        node_feat = F.randn((100, 5))
        conv.initialize(ctx=ctx)
        assert conv(graph, node_feat).shape[-1] == 10

    # Source features are 10-wide; 'gcn' gets equally wide destination
    # features, the other aggregators 5-wide ones.
    bipartite = dgl.bipartite(sp.sparse.random(100, 200, density=0.1))
    dst_dim = 10 if aggre_type == 'gcn' else 5
    conv = nn.SAGEConv((10, dst_dim), 2, aggre_type)
    pair_feat = (F.randn((100, 10)), F.randn((200, dst_dim)))
    conv.initialize(ctx=ctx)
    out = conv(bipartite, pair_feat)
    assert out.shape[-1] == 2
    assert out.shape[0] == 200

    # Block case: the seed nodes are the unique destination endpoints
    # (edges()[1]) of a very sparse graph. The output must have one row per
    # block destination node.
    sparse_graph = dgl.graph(sp.sparse.random(100, 100, density=0.001))
    seed_nodes = np.unique(sparse_graph.edges()[1].asnumpy())
    block = dgl.to_block(sparse_graph, seed_nodes)
    conv = nn.SAGEConv(5, 10, aggre_type)
    block_feat = F.randn((block.number_of_src_nodes(), 5))
    conv.initialize(ctx=ctx)
    out = conv(block, block_feat)
    assert out.shape[0] == block.number_of_dst_nodes()
    assert out.shape[-1] == 10

    # Test the case for graphs without edges
    empty = dgl.bipartite([], num_nodes=(5, 3))
    # (aggregator, destination width) pairs, run in this order.
    for edgeless_agg, dst_width in (('gcn', 3), ('mean', 1), ('pool', 1)):
        conv = nn.SAGEConv((3, dst_width), 2, edgeless_agg)
        pair_feat = (F.randn((5, 3)), F.randn((3, dst_width)))
        conv.initialize(ctx=ctx)
        out = conv(empty, pair_feat)
        assert out.shape[-1] == 2
        assert out.shape[0] == 3
def test_hetero_conv(agg):
    """Check HeteroGraphConv output keys/shapes and per-relation module args.

    ``agg`` is the cross-relation aggregator passed to ``nn.HeteroGraphConv``.
    When it is ``'stack'``, every output carries an extra middle axis of
    stacked relation results. For any other value, that axis is absent.
    The test covers:

    * plain dict input;
    * (src, dst) pair input;
    * forwarding of positional ``mod_args`` to individual relation modules.
    """
    g = dgl.heterograph({
        ('user', 'follows', 'user'): [(0, 1), (0, 2), (2, 1), (1, 3)],
        ('user', 'plays', 'game'): [(0, 0), (0, 2), (0, 3), (1, 0), (2, 2)],
        ('store', 'sells', 'game'): [(0, 0), (0, 3), (1, 1), (1, 2)]
    })

    def check_outputs(h, stack_shapes):
        """Assert ``h`` has exactly the node types of ``stack_shapes`` and their shapes.

        ``stack_shapes`` maps node type -> (num_nodes, num_stacked, dim), the
        shape expected with ``agg == 'stack'``. For any other aggregator,
        the expected shape is (num_nodes, dim).
        """
        assert set(h.keys()) == set(stack_shapes)
        for ntype, (num_nodes, num_stacked, dim) in stack_shapes.items():
            if agg == 'stack':
                expected = (num_nodes, num_stacked, dim)
            else:
                expected = (num_nodes, dim)
            assert h[ntype].shape == expected

    conv = nn.HeteroGraphConv(
        {
            'follows': nn.GraphConv(2, 3),
            'plays': nn.GraphConv(2, 4),
            'sells': nn.GraphConv(3, 4)
        }, agg)
    conv.initialize(ctx=F.ctx())
    print(conv)
    uf = F.randn((4, 2))
    gf = F.randn((4, 4))
    sf = F.randn((2, 3))

    h = conv(g, {'user': uf})
    check_outputs(h, {'user': (4, 1, 3), 'game': (4, 1, 4)})

    # Both 'plays' (from user) and 'sells' (from store) now feed 'game'.
    h = conv(g, {'user': uf, 'store': sf})
    check_outputs(h, {'user': (4, 1, 3), 'game': (4, 2, 4)})

    h = conv(g, {'store': sf})
    check_outputs(h, {'game': (4, 1, 4)})

    # test with pair input
    conv = nn.HeteroGraphConv(
        {
            'follows': nn.SAGEConv(2, 3, 'mean'),
            'plays': nn.SAGEConv((2, 4), 4, 'mean'),
            'sells': nn.SAGEConv(3, 4, 'mean')
        }, agg)
    conv.initialize(ctx=F.ctx())

    h = conv(g, ({'user': uf}, {'user': uf, 'game': gf}))
    check_outputs(h, {'user': (4, 1, 3), 'game': (4, 1, 4)})

    # pair input requires both src and dst type features to be provided
    h = conv(g, ({'user': uf}, {'game': gf}))
    check_outputs(h, {'game': (4, 1, 4)})

    # test with mod args
    class MyMod(mx.gluon.nn.Block):
        def __init__(self, s1, s2):
            super(MyMod, self).__init__()
            self.carg1 = 0  # counts forward calls that received arg1
            self.s1 = s1
            self.s2 = s2

        def forward(self, g, h, arg1=None):  # mxnet does not support kwargs
            if arg1 is not None:
                self.carg1 += 1
            return F.zeros((g.number_of_dst_nodes(), self.s2))

    mod1 = MyMod(2, 3)
    mod2 = MyMod(2, 4)
    mod3 = MyMod(3, 4)
    conv = nn.HeteroGraphConv({
        'follows': mod1,
        'plays': mod2,
        'sells': mod3
    }, agg)
    conv.initialize(ctx=F.ctx())
    # Only 'follows' and 'plays' receive the extra positional argument.
    mod_args = {'follows': (1, ), 'plays': (1, )}
    h = conv(g, {'user': uf, 'store': sf}, mod_args)
    assert mod1.carg1 == 1
    assert mod2.carg1 == 1
    assert mod3.carg1 == 0