Exemplo n.º 1
0
def test_sage_conv(aggre_type):
    """Check SAGEConv output shapes for one aggregator on several graph kinds.

    Covers a readonly ``DGLGraph``, a ``dgl.graph``, a random bipartite graph
    with distinct source/destination feature sizes, and an edgeless
    bipartite graph (the latter over a fixed set of aggregators).

    Args:
        aggre_type: aggregator name passed to ``nn.SAGEConv``.
    """
    # Homogeneous 100-node graph through the legacy readonly constructor.
    g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
    sage = nn.SAGEConv(5, 10, aggre_type)
    feat = F.randn((100, 5))
    h = sage(g, feat)
    assert h.shape[-1] == 10

    # Same check through the dgl.graph constructor.
    g = dgl.graph(sp.sparse.random(100, 100, density=0.1))
    sage = nn.SAGEConv(5, 10, aggre_type)
    feat = F.randn((100, 5))
    h = sage(g, feat)
    assert h.shape[-1] == 10

    # Bipartite graph: 100 source nodes, 200 destination nodes.
    g = dgl.bipartite(sp.sparse.random(100, 200, density=0.1))
    # 'gcn' is exercised with equal src/dst feature sizes (presumably required
    # by its aggregation); the other aggregators get a smaller dst size.
    dst_dim = 5 if aggre_type != 'gcn' else 10
    sage = nn.SAGEConv((10, dst_dim), 2, aggre_type)
    feat = (F.randn((100, 10)), F.randn((200, dst_dim)))
    h = sage(g, feat)
    assert h.shape[-1] == 2
    assert h.shape[0] == 200

    # Test the case for graphs without edges
    g = dgl.bipartite([], num_nodes=(5, 3))
    sage = nn.SAGEConv((3, 3), 2, 'gcn')
    feat = (F.randn((5, 3)), F.randn((3, 3)))
    h = sage(g, feat)
    assert h.shape[-1] == 2
    assert h.shape[0] == 3
    # Separate loop name so the ``aggre_type`` parameter is not rebound.
    for empty_aggre_type in ['mean', 'pool', 'lstm']:
        sage = nn.SAGEConv((3, 1), 2, empty_aggre_type)
        feat = (F.randn((5, 3)), F.randn((3, 1)))
        h = sage(g, feat)
        assert h.shape[-1] == 2
        assert h.shape[0] == 3
Exemplo n.º 2
0
def test_sage_conv_bi(idtype, g, aggre_type):
    """Check SAGEConv on a bipartite graph with separate src/dst features.

    Args:
        idtype: index dtype the graph is cast to.
        g: bipartite graph fixture.
        aggre_type: aggregator name passed to ``nn.SAGEConv``.
    """
    g = g.astype(idtype).to(F.ctx())
    # 'gcn' is exercised with equal src/dst feature sizes (presumably required
    # by its aggregation); the other aggregators get a smaller dst size.
    dst_dim = 5 if aggre_type != 'gcn' else 10
    sage = nn.SAGEConv((10, dst_dim), 2, aggre_type)
    feat = (F.randn((g.number_of_src_nodes(), 10)),
            F.randn((g.number_of_dst_nodes(), dst_dim)))
    h = sage(g, feat)
    assert h.shape[-1] == 2
    # One output row per destination node.
    assert h.shape[0] == g.number_of_dst_nodes()
Exemplo n.º 3
0
Arquivo: test_nn.py Projeto: zwwlp/dgl
def test_sage_conv_bi_empty(idtype, aggre_type):
    """Check SAGEConv on an edgeless bipartite graph (5 src, 3 dst nodes).

    NOTE(review): the ``aggre_type`` parameter is not used; this test always
    runs 'gcn' and then a fixed list of other aggregators.

    Args:
        idtype: index dtype the graph is cast to.
        aggre_type: unused (see note above).
    """
    # Test the case for graphs without edges
    g = dgl.bipartite([], num_nodes=(5, 3))
    g = g.astype(idtype).to(F.ctx())
    sage = nn.SAGEConv((3, 3), 2, 'gcn')
    feat = (F.randn((5, 3)), F.randn((3, 3)))
    h = sage(g, feat)
    assert h.shape[-1] == 2
    assert h.shape[0] == 3
    # Separate loop name so the ``aggre_type`` parameter is not rebound.
    for empty_aggre_type in ['mean', 'pool', 'lstm']:
        sage = nn.SAGEConv((3, 1), 2, empty_aggre_type)
        feat = (F.randn((5, 3)), F.randn((3, 1)))
        h = sage(g, feat)
        assert h.shape[-1] == 2
        assert h.shape[0] == 3
Exemplo n.º 4
0
def test_sage_conv_bi_empty(idtype, aggre_type, out_dim):
    """Check SAGEConv output shapes on an edgeless heterograph (5 '_U', 3 '_V').

    NOTE(review): the ``aggre_type`` parameter is not used; this test always
    runs 'gcn' and then a fixed list of other aggregators.

    Args:
        idtype: index dtype the graph is cast to.
        aggre_type: unused (see note above).
        out_dim: output feature size expected from every SAGEConv layer.
    """
    # Test the case for graphs without edges
    g = dgl.heterograph({('_U', '_E', '_V'): ([], [])}, {'_U': 5, '_V': 3})
    g = g.astype(idtype).to(F.ctx())
    sage = nn.SAGEConv((3, 3), out_dim, 'gcn')
    feat = (F.randn((5, 3)), F.randn((3, 3)))
    h = sage(g, feat)
    assert h.shape[-1] == out_dim
    assert h.shape[0] == 3
    # Separate loop name so the ``aggre_type`` parameter is not rebound.
    for empty_aggre_type in ['mean', 'pool', 'lstm']:
        sage = nn.SAGEConv((3, 1), out_dim, empty_aggre_type)
        feat = (F.randn((5, 3)), F.randn((3, 1)))
        h = sage(g, feat)
        assert h.shape[-1] == out_dim
        assert h.shape[0] == 3
Exemplo n.º 5
0
def test_sage_conv():
    """Run SAGEConv with each aggregator on a random 100-node graph.

    Only the output feature width (10) is checked.
    """
    aggregators = ('mean', 'pool', 'gcn', 'lstm')
    for agg_name in aggregators:
        graph = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1),
                             readonly=True)
        layer = nn.SAGEConv(5, 10, agg_name)
        out = layer(graph, F.randn((100, 5)))
        assert out.shape[-1] == 10
Exemplo n.º 6
0
def test_sage_conv(aggre_type):
    """Check SAGEConv output shapes for one aggregator on three graph kinds.

    Args:
        aggre_type: aggregator name passed to ``nn.SAGEConv``.
    """
    # Homogeneous 100-node graph through the legacy readonly constructor.
    g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
    sage = nn.SAGEConv(5, 10, aggre_type)
    feat = F.randn((100, 5))
    h = sage(g, feat)
    assert h.shape[-1] == 10

    # Same check through the dgl.graph constructor.
    g = dgl.graph(sp.sparse.random(100, 100, density=0.1))
    sage = nn.SAGEConv(5, 10, aggre_type)
    feat = F.randn((100, 5))
    h = sage(g, feat)
    assert h.shape[-1] == 10

    # Bipartite graph: 100 source nodes, 200 destination nodes.
    g = dgl.bipartite(sp.sparse.random(100, 200, density=0.1))
    # 'gcn' is exercised with equal src/dst feature sizes (presumably required
    # by its aggregation); the other aggregators get a smaller dst size.
    dst_dim = 5 if aggre_type != 'gcn' else 10
    sage = nn.SAGEConv((10, dst_dim), 2, aggre_type)
    feat = (F.randn((100, 10)), F.randn((200, dst_dim)))
    h = sage(g, feat)
    assert h.shape[-1] == 2
    assert h.shape[0] == 200
Exemplo n.º 7
0
def test_sage_conv(aggre_type):
    """Check SAGEConv output shapes for one aggregator across graph kinds.

    Covers a readonly ``DGLGraph``, a ``dgl.graph``, a random bipartite graph,
    a block produced by ``dgl.to_block``, and an edgeless bipartite graph
    (the latter over a fixed set of aggregators).

    Args:
        aggre_type: aggregator name passed to ``nn.SAGEConv``.
    """
    # Homogeneous 100-node graph through the legacy readonly constructor.
    g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
    sage = nn.SAGEConv(5, 10, aggre_type)
    feat = F.randn((100, 5))
    h = sage(g, feat)
    assert h.shape[-1] == 10

    # Same check through the dgl.graph constructor.
    g = dgl.graph(sp.sparse.random(100, 100, density=0.1))
    sage = nn.SAGEConv(5, 10, aggre_type)
    feat = F.randn((100, 5))
    h = sage(g, feat)
    assert h.shape[-1] == 10

    # Bipartite graph: 100 source nodes, 200 destination nodes.
    g = dgl.bipartite(sp.sparse.random(100, 200, density=0.1))
    # 'gcn' is exercised with equal src/dst feature sizes (presumably required
    # by its aggregation); the other aggregators get a smaller dst size.
    dst_dim = 5 if aggre_type != 'gcn' else 10
    sage = nn.SAGEConv((10, dst_dim), 2, aggre_type)
    feat = (F.randn((100, 10)), F.randn((200, dst_dim)))
    h = sage(g, feat)
    assert h.shape[-1] == 2
    assert h.shape[0] == 200

    # Sparse graph turned into a block whose destination nodes are the
    # distinct edge targets; features are given for the block's src nodes.
    # NOTE(review): ``.numpy()`` assumes a backend tensor exposing it on host.
    g = dgl.graph(sp.sparse.random(100, 100, density=0.001))
    seed_nodes = np.unique(g.edges()[1].numpy())
    block = dgl.to_block(g, seed_nodes)
    sage = nn.SAGEConv(5, 10, aggre_type)
    feat = F.randn((block.number_of_src_nodes(), 5))
    h = sage(block, feat)
    assert h.shape[0] == block.number_of_dst_nodes()
    assert h.shape[-1] == 10

    # Test the case for graphs without edges
    g = dgl.bipartite([], num_nodes=(5, 3))
    sage = nn.SAGEConv((3, 3), 2, 'gcn')
    feat = (F.randn((5, 3)), F.randn((3, 3)))
    h = sage(g, feat)
    assert h.shape[-1] == 2
    assert h.shape[0] == 3
    # Separate loop name so the ``aggre_type`` parameter is not rebound.
    for empty_aggre_type in ['mean', 'pool', 'lstm']:
        sage = nn.SAGEConv((3, 1), 2, empty_aggre_type)
        feat = (F.randn((5, 3)), F.randn((3, 1)))
        h = sage(g, feat)
        assert h.shape[-1] == 2
        assert h.shape[0] == 3
Exemplo n.º 8
0
Arquivo: test_nn.py Projeto: zwwlp/dgl
def test_hetero_conv(agg, idtype):
    """Check HeteroGraphConv output keys/shapes and per-relation argument routing.

    The graph has three relations (user-follows-user, user-plays-game,
    store-sells-game) over 4 users, 4 games and 2 stores. As the assertions
    below expect, with ``agg == 'stack'`` each destination type's output
    gains a middle axis whose size is the number of contributing relations;
    other aggregators yield 2-D outputs.

    Args:
        agg: cross-relation aggregator name passed to ``nn.HeteroGraphConv``.
        idtype: index dtype of the heterograph.
    """
    g = dgl.heterograph(
        {
            ('user', 'follows', 'user'): [(0, 1), (0, 2), (2, 1), (1, 3)],
            ('user', 'plays', 'game'): [(0, 0), (0, 2), (0, 3), (1, 0),
                                        (2, 2)],
            ('store', 'sells', 'game'): [(0, 0), (0, 3), (1, 1), (1, 2)]
        },
        idtype=idtype,
        device=F.ctx())
    conv = nn.HeteroGraphConv(
        {
            'follows': nn.GraphConv(2, 3),
            'plays': nn.GraphConv(2, 4),
            'sells': nn.GraphConv(3, 4)
        }, agg)
    uf = F.randn((4, 2))  # user features
    gf = F.randn((4, 4))  # game features
    sf = F.randn((2, 3))  # store features

    # Only user input: 'follows' and 'plays' contribute, 'sells' does not.
    h = conv(g, {'user': uf})
    assert set(h.keys()) == {'user', 'game'}
    if agg != 'stack':
        assert h['user'].shape == (4, 3)
        assert h['game'].shape == (4, 4)
    else:
        assert h['user'].shape == (4, 1, 3)
        assert h['game'].shape == (4, 1, 4)

    # User and store input: 'game' now receives both 'plays' and 'sells'.
    h = conv(g, {'user': uf, 'store': sf})
    assert set(h.keys()) == {'user', 'game'}
    if agg != 'stack':
        assert h['user'].shape == (4, 3)
        assert h['game'].shape == (4, 4)
    else:
        assert h['user'].shape == (4, 1, 3)
        assert h['game'].shape == (4, 2, 4)

    # Only store input: only 'sells' contributes, so only 'game' is produced.
    h = conv(g, {'store': sf})
    assert set(h.keys()) == {'game'}
    if agg != 'stack':
        assert h['game'].shape == (4, 4)
    else:
        assert h['game'].shape == (4, 1, 4)

    # test with pair input
    conv = nn.HeteroGraphConv(
        {
            'follows': nn.SAGEConv(2, 3, 'mean'),
            'plays': nn.SAGEConv((2, 4), 4, 'mean'),
            'sells': nn.SAGEConv(3, 4, 'mean')
        }, agg)

    h = conv(g, ({'user': uf}, {'user': uf, 'game': gf}))
    assert set(h.keys()) == {'user', 'game'}
    if agg != 'stack':
        assert h['user'].shape == (4, 3)
        assert h['game'].shape == (4, 4)
    else:
        assert h['user'].shape == (4, 1, 3)
        assert h['game'].shape == (4, 1, 4)

    # pair input requires both src and dst type features to be provided
    h = conv(g, ({'user': uf}, {'game': gf}))
    assert set(h.keys()) == {'game'}
    if agg != 'stack':
        assert h['game'].shape == (4, 4)
    else:
        assert h['game'].shape == (4, 1, 4)

    # test with mod args
    class MyMod(tf.keras.layers.Layer):
        """Stub relation module that counts which extra arguments it receives."""

        def __init__(self, s1, s2):
            super().__init__()
            self.carg1 = 0  # number of calls that received positional arg1
            self.carg2 = 0  # number of calls that received keyword arg2
            self.s1 = s1
            self.s2 = s2

        def call(self, g, h, arg1=None, *, arg2=None):
            if arg1 is not None:
                self.carg1 += 1
            if arg2 is not None:
                self.carg2 += 1
            return tf.zeros((g.number_of_dst_nodes(), self.s2))

    mod1 = MyMod(2, 3)
    mod2 = MyMod(2, 4)
    mod3 = MyMod(3, 4)
    conv = nn.HeteroGraphConv({
        'follows': mod1,
        'plays': mod2,
        'sells': mod3
    }, agg)
    # Positional extras go to 'follows' and 'plays'; a keyword extra to 'sells'.
    mod_args = {'follows': (1, ), 'plays': (1, )}
    mod_kwargs = {'sells': {'arg2': 'abc'}}
    h = conv(g, {'user': uf, 'store': sf},
             mod_args=mod_args,
             mod_kwargs=mod_kwargs)
    assert set(h.keys()) == {'user', 'game'}
    assert mod1.carg1 == 1
    assert mod1.carg2 == 0
    assert mod2.carg1 == 1
    assert mod2.carg2 == 0
    assert mod3.carg1 == 0
    assert mod3.carg2 == 1
Exemplo n.º 9
0
Arquivo: test_nn.py Projeto: zwwlp/dgl
def test_sage_conv(idtype, g, aggre_type):
    """Apply SAGEConv to a graph fixture cast to ``idtype``; check output width.

    Args:
        idtype: index dtype the graph is cast to.
        g: graph fixture.
        aggre_type: aggregator name passed to ``nn.SAGEConv``.
    """
    graph = g.astype(idtype).to(F.ctx())
    layer = nn.SAGEConv(5, 10, aggre_type)
    node_feats = F.randn((graph.number_of_nodes(), 5))
    out = layer(graph, node_feats)
    assert out.shape[-1] == 10