def test_sgc_conv(g, idtype):
    """Exercise ``nn.SGConv`` on graph ``g`` with and without caching.

    Args:
        g: graph case supplied by the test parametrization; it is cast to
            ``idtype`` and moved to the backend's test context.
        idtype: index dtype the graph is converted to.
    """
    device = F.ctx()
    graph = g.astype(idtype).to(device)
    in_dim, out_dim, hops = 5, 10, 3

    # Uncached layer: only the output feature width is checked.
    conv = nn.SGConv(in_dim, out_dim, hops)
    inputs = F.randn((graph.number_of_nodes(), in_dim))
    out = conv(graph, inputs)
    assert out.shape[-1] == out_dim

    # Cached layer (4th positional argument is True). The test expects the
    # second call to reproduce the first call's result even though the
    # input has been shifted by +1.
    conv = nn.SGConv(in_dim, out_dim, hops, True)
    first = conv(graph, inputs)
    second = conv(graph, inputs + 1)
    assert F.allclose(first, second)
    assert first.shape[-1] == out_dim
def test_sgc_conv():
    """Exercise ``nn.SGConv`` on a random sparse graph, uncached and cached.

    NOTE(review): this function reuses the name ``test_sgc_conv``. If a
    parametrized ``test_sgc_conv(g, idtype)`` is defined earlier in the same
    module, this later definition replaces it at import time and only this
    one is collected. Rename one of them if both are meant to run.
    """
    # Single source of truth for the node count: the graph and the feature
    # matrix must agree on it.
    num_nodes = 100
    g = dgl.DGLGraph(
        sp.sparse.random(num_nodes, num_nodes, density=0.1), readonly=True)

    # Uncached layer: only the output feature width is checked.
    sgc = nn.SGConv(5, 10, 3)
    feat = F.randn((num_nodes, 5))
    h = sgc(g, feat)
    assert h.shape[-1] == 10

    # Cached layer (4th positional argument is True). The test expects the
    # second call to reproduce the first call's result even though the
    # input has been shifted by +1.
    sgc = nn.SGConv(5, 10, 3, True)
    h_0 = sgc(g, feat)
    h_1 = sgc(g, feat + 1)
    assert F.allclose(h_0, h_1)
    assert h_0.shape[-1] == 10