Exemple #1
0
def test_gmm_conv():
    """Smoke-test ``nn.GMMConv`` output shapes on three kinds of graph.

    Covers a ``dgl.DGLGraph`` and a ``dgl.graph`` built from the same kind of
    random NetworkX graph, plus a ``dgl.bipartite`` graph that is fed a
    ``(source, destination)`` feature pair.  Only the shape of each result is
    asserted; the output values themselves are never inspected.
    """
    ctx = F.ctx()

    def _check_single_feature(graph):
        # One feature tensor shared by all nodes; expect one 2-wide output
        # row per node of the graph.
        conv = nn.GMMConv(5, 2, 5, 3, 'max')
        conv.initialize(ctx=ctx)
        node_feat = F.randn((graph.number_of_nodes(), 5))
        edge_pseudo = F.randn((graph.number_of_edges(), 5))
        out = conv(graph, node_feat, edge_pseudo)
        assert out.shape == (graph.number_of_nodes(), 2)

    # Case 1: graph created through the DGLGraph constructor.
    _check_single_feature(dgl.DGLGraph(nx.erdos_renyi_graph(20, 0.3)))

    # Case 2: same recipe, created through dgl.graph instead.
    _check_single_feature(dgl.graph(nx.erdos_renyi_graph(20, 0.3)))

    # Case 3: bipartite graph; source and destination nodes carry features
    # of different widths (5 and 4), passed to the layer as a pair.
    bigraph = dgl.bipartite(sp.sparse.random(20, 10, 0.1))
    pair_conv = nn.GMMConv((5, 4), 2, 5, 3, 'max')
    pair_conv.initialize(ctx=ctx)
    src_feat = F.randn((bigraph.number_of_src_nodes(), 5))
    dst_feat = F.randn((bigraph.number_of_dst_nodes(), 4))
    edge_pseudo = F.randn((bigraph.number_of_edges(), 5))
    out = pair_conv(bigraph, (src_feat, dst_feat), edge_pseudo)
    # The output is indexed by destination nodes only.
    assert out.shape == (bigraph.number_of_dst_nodes(), 2)
Exemple #2
0
def test_gmm_conv(g, idtype):
    """Check the ``nn.GMMConv`` output shape on a caller-supplied graph.

    ``g`` is first converted with ``astype(idtype)`` and moved to the
    backend context returned by ``F.ctx()``.  The layer is then run on random
    source-node features and random per-edge inputs, and the result must have
    one 2-wide row per destination node.
    """
    graph = g.astype(idtype).to(F.ctx())
    device = F.ctx()

    conv = nn.GMMConv(5, 2, 5, 3, 'max')
    conv.initialize(ctx=device)

    # Inputs are sized by source nodes / edges; the output by destination nodes.
    src_feat = F.randn((graph.number_of_src_nodes(), 5))
    edge_pseudo = F.randn((graph.number_of_edges(), 5))
    out = conv(graph, src_feat, edge_pseudo)

    assert out.shape == (graph.number_of_dst_nodes(), 2)
Exemple #3
0
def test_gmm_conv():
    """Smoke-test ``nn.GMMConv`` output shapes on four graph inputs.

    Covers a ``dgl.DGLGraph`` and a ``dgl.graph`` built from the same kind of
    random NetworkX graph, a ``dgl.bipartite`` graph fed a
    ``(source, destination)`` feature pair, and a block produced by
    ``dgl.to_block``.  Only output shapes are asserted; the output values
    themselves are never inspected.
    """
    ctx = F.ctx()

    def _check_single_feature(graph):
        # One feature tensor shared by all nodes; expect one 2-wide output
        # row per node of the graph.
        conv = nn.GMMConv(5, 2, 5, 3, 'max')
        conv.initialize(ctx=ctx)
        node_feat = F.randn((graph.number_of_nodes(), 5))
        edge_pseudo = F.randn((graph.number_of_edges(), 5))
        out = conv(graph, node_feat, edge_pseudo)
        assert out.shape == (graph.number_of_nodes(), 2)

    # Case 1: graph created through the DGLGraph constructor.
    _check_single_feature(dgl.DGLGraph(nx.erdos_renyi_graph(20, 0.3)))

    # Case 2: same recipe, created through dgl.graph instead.
    _check_single_feature(dgl.graph(nx.erdos_renyi_graph(20, 0.3)))

    # Case 3: bipartite graph; source and destination nodes carry features
    # of different widths (5 and 4), passed to the layer as a pair.
    bigraph = dgl.bipartite(sp.sparse.random(20, 10, 0.1))
    pair_conv = nn.GMMConv((5, 4), 2, 5, 3, 'max')
    pair_conv.initialize(ctx=ctx)
    src_feat = F.randn((bigraph.number_of_src_nodes(), 5))
    dst_feat = F.randn((bigraph.number_of_dst_nodes(), 4))
    edge_pseudo = F.randn((bigraph.number_of_edges(), 5))
    out = pair_conv(bigraph, (src_feat, dst_feat), edge_pseudo)
    # The output is indexed by destination nodes only.
    assert out.shape == (bigraph.number_of_dst_nodes(), 2)

    # Case 4: block built from a sparse random graph, seeded with every node
    # that appears as an edge destination; this one uses the 'mean' aggregator.
    sparse_graph = dgl.graph(sp.sparse.random(100, 100, density=0.001))
    dst_ids = np.unique(sparse_graph.edges()[1].asnumpy())
    blk = dgl.to_block(sparse_graph, dst_ids)
    mean_conv = nn.GMMConv(5, 2, 5, 3, 'mean')
    mean_conv.initialize(ctx=ctx)
    blk_feat = F.randn((blk.number_of_src_nodes(), 5))
    blk_pseudo = F.randn((blk.number_of_edges(), 5))
    blk_out = mean_conv(blk, blk_feat, blk_pseudo)
    # Only the leading and trailing dimensions are pinned for the block case.
    assert blk_out.shape[0] == blk.number_of_dst_nodes()
    assert blk_out.shape[-1] == 2
Exemple #4
0
def test_gmm_conv():
    """Check the ``nn.GMMConv`` output shape on a random NetworkX-based graph.

    Builds a ``dgl.DGLGraph`` from a 20-node Erdos-Renyi graph, prints the
    initialized layer, then runs it on random node features and random
    per-edge inputs.  The result must have one 2-wide row per node.
    """
    graph = dgl.DGLGraph(nx.erdos_renyi_graph(20, 0.3))
    device = F.ctx()

    conv = nn.GMMConv(5, 2, 5, 3, 'max')
    conv.initialize(ctx=device)
    # Printing exercises the layer's string representation as well.
    print(conv)

    # Basic forward pass: one feature row per node, one input row per edge.
    node_feat = F.randn((graph.number_of_nodes(), 5))
    edge_pseudo = F.randn((graph.number_of_edges(), 5))
    out = conv(graph, node_feat, edge_pseudo)
    assert out.shape == (graph.number_of_nodes(), 2)