def test_dense_sage_conv():
    """Compare SAGEConv ('gcn') with DenseSAGEConv on a random 100-node graph.

    The two layers share linear parameters. The sparse layer consumes the
    DGL graph and the dense layer consumes its dense adjacency matrix. Both
    receive the same node features, and their outputs must agree under
    ``F.allclose``.

    NOTE(review): ``test_dense_sage_conv`` is defined a second time later in
    this module, and that later ``def`` rebinds the name. As a result this
    version is never collected or run. Rename it or delete it.
    """
    num_nodes, in_feats, out_feats = 100, 5, 2
    device = F.ctx()

    # A random sparse num_nodes x num_nodes matrix (density 0.1) provides
    # the graph structure.
    graph = dgl.DGLGraph(
        sp.sparse.random(num_nodes, num_nodes, density=0.1), readonly=True
    )
    dense_adj = graph.adjacency_matrix(ctx=device).to_dense()

    sparse_conv = nn.SAGEConv(in_feats, out_feats, 'gcn')
    dense_conv = nn.DenseSAGEConv(in_feats, out_feats)
    # Point the dense layer's `fc` at the sparse layer's `fc_neigh`
    # parameters, so both layers compute with identical weights and bias.
    dense_conv.fc.weight.data = sparse_conv.fc_neigh.weight.data
    dense_conv.fc.bias.data = sparse_conv.fc_neigh.bias.data

    features = F.randn((num_nodes, in_feats))
    sparse_conv = sparse_conv.to(device)
    dense_conv = dense_conv.to(device)

    expected = sparse_conv(graph, features)
    actual = dense_conv(dense_adj, features)
    assert F.allclose(expected, actual)
def test_dense_sage_conv(g):
    """Check that DenseSAGEConv agrees with SAGEConv ('gcn') on graph ``g``.

    The dense layer is given ``g``'s dense adjacency matrix and the sparse
    layer is given ``g`` itself. Both layers use the same linear parameters.

    When ``g`` has exactly two node types, features are passed as a
    ``(source, destination)`` pair of tensors. Otherwise a single tensor
    with one row per node is used. On failure, ``g`` is attached to the
    assertion message.

    NOTE(review): ``g`` is presumably supplied by a pytest parametrization
    outside this view — TODO confirm.
    """
    in_feats, out_feats = 5, 2
    device = F.ctx()
    dense_adj = g.adjacency_matrix(ctx=device).to_dense()

    sparse_conv = nn.SAGEConv(in_feats, out_feats, 'gcn')
    dense_conv = nn.DenseSAGEConv(in_feats, out_feats)
    # Point the dense layer's `fc` at the sparse layer's `fc_neigh`
    # parameters, so both layers compute with identical weights and bias.
    dense_conv.fc.weight.data = sparse_conv.fc_neigh.weight.data
    dense_conv.fc.bias.data = sparse_conv.fc_neigh.bias.data

    if len(g.ntypes) != 2:
        features = F.randn((g.number_of_nodes(), in_feats))
    else:
        # Source features are drawn before destination features.
        src_features = F.randn((g.number_of_src_nodes(), in_feats))
        dst_features = F.randn((g.number_of_dst_nodes(), in_feats))
        features = (src_features, dst_features)

    sparse_conv = sparse_conv.to(device)
    dense_conv = dense_conv.to(device)

    expected = sparse_conv(g, features)
    actual = dense_conv(dense_adj, features)
    assert F.allclose(expected, actual), g