Exemplo n.º 1
0
def test_sage_conv():
    """Check SAGEConv output shapes for 'mean', 'pool' and 'gcn'.

    Each aggregator is run on graphs built from random sparse matrices: two
    built from 100x100 matrices (via ``dgl.DGLGraph`` and ``dgl.graph``) and
    one built from a 100x200 matrix via ``dgl.bipartite``. Only the shape of
    the layer output is asserted.
    """
    # Two ways of turning a 100x100 sparse matrix into a graph.
    square_builders = (
        lambda mat: dgl.DGLGraph(mat, readonly=True),
        lambda mat: dgl.graph(mat),
    )
    for aggre_type in ('mean', 'pool', 'gcn'):
        ctx = F.ctx()

        for build in square_builders:
            graph = build(sp.sparse.random(100, 100, density=0.1))
            layer = nn.SAGEConv(5, 10, aggre_type)
            x = F.randn((100, 5))
            layer.initialize(ctx=ctx)
            out = layer(graph, x)
            assert out.shape[-1] == 10

        # Source and destination features are passed as a pair; the
        # destination feature size depends on the aggregator.
        graph = dgl.bipartite(sp.sparse.random(100, 200, density=0.1))
        dst_dim = 10 if aggre_type == 'gcn' else 5
        layer = nn.SAGEConv((10, dst_dim), 2, aggre_type)
        src_x = F.randn((100, 10))
        dst_x = F.randn((200, dst_dim))
        layer.initialize(ctx=ctx)
        out = layer(graph, (src_x, dst_x))
        assert out.shape[-1] == 2
        assert out.shape[0] == 200
Exemplo n.º 2
0
def test_sage_conv(idtype, g, aggre_type, out_dim):
    """Check that SAGEConv produces ``out_dim`` output columns on ``g``.

    The graph is cast to ``idtype`` and moved to the backend context, then
    the layer is applied to random 5-column features with one row per
    source node.
    """
    graph = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    in_dim = 5
    layer = nn.SAGEConv(in_dim, out_dim, aggre_type)
    x = F.randn((graph.number_of_src_nodes(), in_dim))
    layer.initialize(ctx=ctx)
    out = layer(graph, x)
    assert out.shape[-1] == out_dim
Exemplo n.º 3
0
def test_sage_conv_bi2(idtype, aggre_type, out_dim):
    """Run SAGEConv on a bipartite graph that has no edges.

    The graph has 5 '_U' nodes and 3 '_V' nodes. 'gcn' is run with 3-column
    destination features, then 'mean' and 'pool' with 1-column destination
    features. The ``aggre_type`` argument is not read; the aggregators are
    fixed in the body.
    """
    # Test the case for graphs without edges
    empty = dgl.heterograph({('_U', '_E', '_V'): ([], [])}, {'_U': 5, '_V': 3})
    empty = empty.astype(idtype).to(F.ctx())
    ctx = F.ctx()

    # (aggregator, destination feature size), in the order they are run.
    cases = (('gcn', 3), ('mean', 1), ('pool', 1))
    for agg, dst_dim in cases:
        layer = nn.SAGEConv((3, dst_dim), out_dim, agg)
        src_x = F.randn((5, 3))
        dst_x = F.randn((3, dst_dim))
        layer.initialize(ctx=ctx)
        out = layer(empty, (src_x, dst_x))
        assert out.shape[-1] == out_dim
        assert out.shape[0] == 3
Exemplo n.º 4
0
def test_sage_conv_bi2(idtype, aggre_type):
    """Run SAGEConv on an edgeless bipartite graph with (5, 3) nodes.

    'gcn' is run with 3-column destination features, then 'mean' and 'pool'
    with 1-column destination features; the output is expected to have 3
    rows and 2 columns each time. The ``aggre_type`` argument is not read;
    the aggregators are fixed in the body.
    """
    # Test the case for graphs without edges
    empty = dgl.bipartite([], num_nodes=(5, 3))
    empty = empty.astype(idtype).to(F.ctx())
    ctx = F.ctx()

    # (aggregator, destination feature size), in the order they are run.
    cases = (('gcn', 3), ('mean', 1), ('pool', 1))
    for agg, dst_dim in cases:
        layer = nn.SAGEConv((3, dst_dim), 2, agg)
        src_x = F.randn((5, 3))
        dst_x = F.randn((3, dst_dim))
        layer.initialize(ctx=ctx)
        out = layer(empty, (src_x, dst_x))
        assert out.shape[-1] == 2
        assert out.shape[0] == 3
Exemplo n.º 5
0
def test_sage_conv_bi(idtype, g, aggre_type, out_dim):
    """Check SAGEConv with separate source and destination features.

    Source features have 10 columns; destination features have 10 columns
    for 'gcn' and 5 otherwise. The output must have ``out_dim`` columns and
    one row per destination node of ``g``.
    """
    graph = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    src_dim = 10
    dst_dim = 10 if aggre_type == 'gcn' else 5
    layer = nn.SAGEConv((src_dim, dst_dim), out_dim, aggre_type)
    src_x = F.randn((graph.number_of_src_nodes(), src_dim))
    dst_x = F.randn((graph.number_of_dst_nodes(), dst_dim))
    layer.initialize(ctx=ctx)
    out = layer(graph, (src_x, dst_x))
    assert out.shape[-1] == out_dim
    assert out.shape[0] == graph.number_of_dst_nodes()
Exemplo n.º 6
0
def test_sage_conv():
    """Basic SAGEConv shape check on a 20-node Erdos-Renyi graph.

    A layer mapping 10 input features to 20 output features is initialized,
    printed, and applied to random features; the output shape must be
    (20, 20).
    """
    num_nodes, in_dim, out_dim = 20, 10, 20
    graph = dgl.DGLGraph(nx.erdos_renyi_graph(num_nodes, 0.3))
    ctx = F.ctx()

    layer = nn.SAGEConv(in_dim, out_dim)
    layer.initialize(ctx=ctx)
    print(layer)

    # test#1: basic
    x = F.randn((num_nodes, in_dim))
    out = layer(graph, x)
    assert out.shape == (num_nodes, out_dim)
Exemplo n.º 7
0
def test_sage_conv(aggre_type):
    """Check SAGEConv output shapes for ``aggre_type`` on several graphs.

    Covers two graphs built from random 100x100 sparse matrices (via
    ``dgl.DGLGraph`` and ``dgl.graph``) and one built from a random 100x200
    matrix via ``dgl.bipartite``. The final section runs on an edgeless
    bipartite graph with fixed aggregators ('gcn', then 'mean' and 'pool')
    regardless of ``aggre_type``.
    """
    ctx = F.ctx()

    # Two ways of turning a 100x100 sparse matrix into a graph.
    square_builders = (
        lambda mat: dgl.DGLGraph(mat, readonly=True),
        lambda mat: dgl.graph(mat),
    )
    for build in square_builders:
        graph = build(sp.sparse.random(100, 100, density=0.1))
        layer = nn.SAGEConv(5, 10, aggre_type)
        x = F.randn((100, 5))
        layer.initialize(ctx=ctx)
        out = layer(graph, x)
        assert out.shape[-1] == 10

    # Source and destination features are passed as a pair; the
    # destination feature size depends on the aggregator.
    graph = dgl.bipartite(sp.sparse.random(100, 200, density=0.1))
    dst_dim = 10 if aggre_type == 'gcn' else 5
    layer = nn.SAGEConv((10, dst_dim), 2, aggre_type)
    src_x = F.randn((100, 10))
    dst_x = F.randn((200, dst_dim))
    layer.initialize(ctx=ctx)
    out = layer(graph, (src_x, dst_x))
    assert out.shape[-1] == 2
    assert out.shape[0] == 200

    # Test the case for graphs without edges
    empty = dgl.bipartite([], num_nodes=(5, 3))
    # (aggregator, destination feature size), in the order they are run.
    for agg, empty_dst_dim in (('gcn', 3), ('mean', 1), ('pool', 1)):
        layer = nn.SAGEConv((3, empty_dst_dim), 2, agg)
        src_x = F.randn((5, 3))
        dst_x = F.randn((3, empty_dst_dim))
        layer.initialize(ctx=ctx)
        out = layer(empty, (src_x, dst_x))
        assert out.shape[-1] == 2
        assert out.shape[0] == 3
Exemplo n.º 8
0
def test_dense_sage_conv():
    """Compare DenseSAGEConv with 'gcn' SAGEConv on a random graph.

    The dense layer receives the sparse layer's ``fc_neigh`` weight and
    bias, so both are expected to produce close outputs when given the same
    features and the graph / its adjacency matrix respectively.
    """
    ctx = F.ctx()
    graph = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
    dense_adj = graph.adjacency_matrix(ctx=ctx).tostype('default')

    sparse_layer = nn.SAGEConv(5, 2, 'gcn')
    dense_layer = nn.DenseSAGEConv(5, 2)
    sparse_layer.initialize(ctx=ctx)
    dense_layer.initialize(ctx=ctx)

    # Share parameters so the two layers compute with the same weights.
    neigh_fc = sparse_layer.fc_neigh
    dense_layer.fc.weight.set_data(neigh_fc.weight.data())
    dense_layer.fc.bias.set_data(neigh_fc.bias.data())

    x = F.randn((100, 5))
    expected = sparse_layer(graph, x)
    actual = dense_layer(dense_adj, x)
    assert F.allclose(expected, actual)
Exemplo n.º 9
0
def test_dense_sage_conv(g):
    """Compare DenseSAGEConv with 'gcn' SAGEConv on the given graph ``g``.

    The dense layer receives the sparse layer's ``fc_neigh`` weight and
    bias. When ``g`` has two node types a (source, destination) feature
    pair is used; otherwise a single feature matrix over all nodes.
    """
    ctx = F.ctx()
    dense_adj = g.adjacency_matrix(ctx=ctx).tostype('default')

    sparse_layer = nn.SAGEConv(5, 2, 'gcn')
    dense_layer = nn.DenseSAGEConv(5, 2)
    sparse_layer.initialize(ctx=ctx)
    dense_layer.initialize(ctx=ctx)

    # Share parameters so the two layers compute with the same weights.
    neigh_fc = sparse_layer.fc_neigh
    dense_layer.fc.weight.set_data(neigh_fc.weight.data())
    dense_layer.fc.bias.set_data(neigh_fc.bias.data())

    two_node_types = len(g.ntypes) == 2
    if not two_node_types:
        x = F.randn((g.number_of_nodes(), 5))
    else:
        src_x = F.randn((g.number_of_src_nodes(), 5))
        dst_x = F.randn((g.number_of_dst_nodes(), 5))
        x = (src_x, dst_x)

    expected = sparse_layer(g, x)
    actual = dense_layer(dense_adj, x)
    assert F.allclose(expected, actual)
Exemplo n.º 10
0
def test_sage_conv(aggre_type):
    """Check SAGEConv output shapes for ``aggre_type`` on several graphs.

    Covers two graphs built from random 100x100 sparse matrices (via
    ``dgl.DGLGraph`` and ``dgl.graph``), one built from a random 100x200
    matrix via ``dgl.bipartite``, and a block produced by ``dgl.to_block``.
    The final section runs on an edgeless bipartite graph with fixed
    aggregators ('gcn', then 'mean' and 'pool') regardless of
    ``aggre_type``.
    """
    ctx = F.ctx()

    # Two ways of turning a 100x100 sparse matrix into a graph.
    square_builders = (
        lambda mat: dgl.DGLGraph(mat, readonly=True),
        lambda mat: dgl.graph(mat),
    )
    for build in square_builders:
        graph = build(sp.sparse.random(100, 100, density=0.1))
        layer = nn.SAGEConv(5, 10, aggre_type)
        x = F.randn((100, 5))
        layer.initialize(ctx=ctx)
        out = layer(graph, x)
        assert out.shape[-1] == 10

    # Source and destination features are passed as a pair; the
    # destination feature size depends on the aggregator.
    graph = dgl.bipartite(sp.sparse.random(100, 200, density=0.1))
    dst_dim = 10 if aggre_type == 'gcn' else 5
    layer = nn.SAGEConv((10, dst_dim), 2, aggre_type)
    src_x = F.randn((100, 10))
    dst_x = F.randn((200, dst_dim))
    layer.initialize(ctx=ctx)
    out = layer(graph, (src_x, dst_x))
    assert out.shape[-1] == 2
    assert out.shape[0] == 200

    # Block input: seeds are the distinct destination ids of the edges.
    graph = dgl.graph(sp.sparse.random(100, 100, density=0.001))
    edge_dst = graph.edges()[1]
    seed_nodes = np.unique(edge_dst.asnumpy())
    block = dgl.to_block(graph, seed_nodes)
    layer = nn.SAGEConv(5, 10, aggre_type)
    x = F.randn((block.number_of_src_nodes(), 5))
    layer.initialize(ctx=ctx)
    out = layer(block, x)
    assert out.shape[0] == block.number_of_dst_nodes()
    assert out.shape[-1] == 10

    # Test the case for graphs without edges
    empty = dgl.bipartite([], num_nodes=(5, 3))
    # (aggregator, destination feature size), in the order they are run.
    for agg, empty_dst_dim in (('gcn', 3), ('mean', 1), ('pool', 1)):
        layer = nn.SAGEConv((3, empty_dst_dim), 2, agg)
        src_x = F.randn((5, 3))
        dst_x = F.randn((3, empty_dst_dim))
        layer.initialize(ctx=ctx)
        out = layer(empty, (src_x, dst_x))
        assert out.shape[-1] == 2
        assert out.shape[0] == 3
Exemplo n.º 11
0
def test_hetero_conv(agg):
    """Exercise HeteroGraphConv with the cross-relation aggregator ``agg``.

    The graph has three relations: user-follows-user, user-plays-game and
    store-sells-game. Three things are checked:

    1. GraphConv sub-modules with a single feature dict, for different
       subsets of available input node types.
    2. SAGEConv sub-modules with a (source dict, destination dict) pair.
    3. Per-relation positional arguments (``mod_args``) reach only the
       sub-modules they are addressed to.

    For every forward pass the set of returned node types and the shape of
    each output are asserted. With ``agg == 'stack'`` outputs carry an extra
    middle axis whose expected size is spelled out per case below.
    """
    def check_output(out, flat, stacked):
        # `flat` / `stacked` map node type -> expected shape for the
        # non-'stack' and 'stack' aggregators respectively.
        expected = stacked if agg == 'stack' else flat
        assert set(out.keys()) == set(expected)
        for ntype, shape in expected.items():
            assert out[ntype].shape == shape

    g = dgl.heterograph({
        ('user', 'follows', 'user'): [(0, 1), (0, 2), (2, 1), (1, 3)],
        ('user', 'plays', 'game'): [(0, 0), (0, 2), (0, 3), (1, 0), (2, 2)],
        ('store', 'sells', 'game'): [(0, 0), (0, 3), (1, 1), (1, 2)]
    })
    conv = nn.HeteroGraphConv(
        {
            'follows': nn.GraphConv(2, 3),
            'plays': nn.GraphConv(2, 4),
            'sells': nn.GraphConv(3, 4)
        }, agg)
    conv.initialize(ctx=F.ctx())
    print(conv)
    uf = F.randn((4, 2))
    gf = F.randn((4, 4))
    sf = F.randn((2, 3))

    # Only user features: 'follows' and 'plays' contribute.
    h = conv(g, {'user': uf})
    check_output(
        h,
        flat={'user': (4, 3), 'game': (4, 4)},
        stacked={'user': (4, 1, 3), 'game': (4, 1, 4)})

    # User and store features: 'game' receives two relations.
    h = conv(g, {'user': uf, 'store': sf})
    check_output(
        h,
        flat={'user': (4, 3), 'game': (4, 4)},
        stacked={'user': (4, 1, 3), 'game': (4, 2, 4)})

    # Only store features: only 'sells' contributes.
    h = conv(g, {'store': sf})
    check_output(
        h,
        flat={'game': (4, 4)},
        stacked={'game': (4, 1, 4)})

    # test with pair input
    conv = nn.HeteroGraphConv(
        {
            'follows': nn.SAGEConv(2, 3, 'mean'),
            'plays': nn.SAGEConv((2, 4), 4, 'mean'),
            'sells': nn.SAGEConv(3, 4, 'mean')
        }, agg)
    conv.initialize(ctx=F.ctx())

    h = conv(g, ({'user': uf}, {'user': uf, 'game': gf}))
    check_output(
        h,
        flat={'user': (4, 3), 'game': (4, 4)},
        stacked={'user': (4, 1, 3), 'game': (4, 1, 4)})

    # pair input requires both src and dst type features to be provided
    h = conv(g, ({'user': uf}, {'game': gf}))
    check_output(
        h,
        flat={'game': (4, 4)},
        stacked={'game': (4, 1, 4)})

    # test with mod args
    class MyMod(mx.gluon.nn.Block):
        """Stub module that counts the calls in which ``arg1`` was given."""

        def __init__(self, s1, s2):
            super(MyMod, self).__init__()
            self.carg1 = 0
            self.s1 = s1
            self.s2 = s2

        def forward(self, g, h, arg1=None):  # mxnet does not support kwargs
            if arg1 is not None:
                self.carg1 += 1
            return F.zeros((g.number_of_dst_nodes(), self.s2))

    mod1 = MyMod(2, 3)
    mod2 = MyMod(2, 4)
    mod3 = MyMod(3, 4)
    conv = nn.HeteroGraphConv({
        'follows': mod1,
        'plays': mod2,
        'sells': mod3
    }, agg)
    conv.initialize(ctx=F.ctx())
    # Extra positional args are supplied for 'follows' and 'plays' only, so
    # the 'sells' module must never see arg1.
    mod_args = {'follows': (1, ), 'plays': (1, )}
    conv(g, {'user': uf, 'store': sf}, mod_args)
    assert mod1.carg1 == 1
    assert mod2.carg1 == 1
    assert mod3.carg1 == 0