Exemplo n.º 1
0
def test_graph_conv():
    """Check ``nn.GraphConv`` on a 3-node path graph.

    With ``norm='none'`` the output is compared against the dense reference
    ``_AXWb(adj, h, W, b)``.  With the default normalization only the output
    shape is checked.  Every case also verifies that the layer leaves no
    node or edge data behind on the graph.
    """
    g = dgl.DGLGraph(nx.path_graph(3)).to(F.ctx())
    ctx = F.ctx()
    # Dense adjacency for the reference computation.  The sparse tensor is
    # reordered first, presumably because to_dense validates index ordering.
    adj = tf.sparse.to_dense(
        tf.sparse.reorder(g.adjacency_matrix(transpose=False, ctx=ctx)))
    # Basic (N, in_feats) input plus a "more-dim" (N, extra, in_feats) input.
    input_shapes = ((3, 5), (3, 5, 5))

    # test#1 / test#2: unnormalized conv with bias must match the reference.
    conv = nn.GraphConv(5, 2, norm='none', bias=True)
    print(conv)
    for shape in input_shapes:
        h0 = F.ones(shape)
        h1 = conv(g, h0)
        assert len(g.ndata) == 0
        assert len(g.edata) == 0
        assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))

    # test#3 / test#4: default normalization.  No dense reference is
    # available here, so check that the feature axis is mapped 5 -> 2 and
    # that all leading axes are kept.
    conv = nn.GraphConv(5, 2)
    for shape in input_shapes:
        h0 = F.ones(shape)
        h1 = conv(g, h0)
        assert len(g.ndata) == 0
        assert len(g.edata) == 0
        assert h1.shape == shape[:-1] + (2,)
Exemplo n.º 2
0
def test_graph_conv2(g, norm, weight, bias):
    """Check that ``nn.GraphConv`` yields a ``(num_dst_nodes, 2)`` output.

    When ``weight`` is falsy, the layer is constructed with that setting and
    an external ``(5, 2)`` weight tensor is passed at call time instead.
    """
    conv = nn.GraphConv(5, 2, norm=norm, weight=weight, bias=bias)
    ext_w = F.randn((5, 2))
    if isinstance(g, dgl.DGLGraph):
        num_src = num_dst = g.number_of_nodes()
    else:
        num_src = g.number_of_src_nodes()
        num_dst = g.number_of_dst_nodes()
    feat = F.randn((num_src, 5))
    # Only hand over the external weight when the layer was built without one.
    call_kwargs = {} if weight else {'weight': ext_w}
    out = conv(g, feat, **call_kwargs)
    assert out.shape == (num_dst, 2)
Exemplo n.º 3
0
Arquivo: test_nn.py Projeto: zwwlp/dgl
def test_graph_conv2(idtype, g, norm, weight, bias):
    """Check the ``nn.GraphConv`` output shape for one graph/config combination.

    The graph is cast to ``idtype`` and moved to the test device first.
    When ``weight`` is falsy, the layer is constructed with that setting and
    an external ``(5, 2)`` weight tensor is passed at call time instead.
    The output must have one row per destination node and 2 features.
    """
    g = g.astype(idtype).to(F.ctx())
    conv = nn.GraphConv(5, 2, norm=norm, weight=weight, bias=bias)
    ext_w = F.randn((5, 2))
    nsrc = g.number_of_src_nodes()
    ndst = g.number_of_dst_nodes()
    h = F.randn((nsrc, 5))
    if weight:
        h_out = conv(g, h)
    else:
        h_out = conv(g, h, weight=ext_w)
    assert h_out.shape == (ndst, 2)
Exemplo n.º 4
0
def test_graph_conv2(g, norm, weight, bias):
    """Check the ``nn.GraphConv`` output shape, and pair input on bipartite graphs.

    The output must be ``(num_dst_nodes, 2)``.  When ``g`` is not a
    ``dgl.DGLGraph`` and has exactly two node types (the bipartite case), the
    layer must also accept a ``(src_feat, dst_feat)`` pair and return exactly
    the same result.  When ``weight`` is falsy, an external ``(5, 2)`` weight
    is passed at call time.
    """
    conv = nn.GraphConv(5, 2, norm=norm, weight=weight, bias=bias)
    ext_w = F.randn((5, 2))
    is_dgl_graph = isinstance(g, dgl.DGLGraph)
    if is_dgl_graph:
        num_src = num_dst = g.number_of_nodes()
    else:
        num_src, num_dst = g.number_of_src_nodes(), g.number_of_dst_nodes()
    feat_src = F.randn((num_src, 5))
    feat_dst = F.randn((num_dst, 2))
    # Only hand over the external weight when the layer was built without one.
    call_kwargs = {} if weight else {'weight': ext_w}

    out = conv(g, feat_src, **call_kwargs)
    assert out.shape == (num_dst, 2)

    # Pair-of-tensors input is only exercised on bipartite graphs.
    if is_dgl_graph or len(g.ntypes) != 2:
        return
    out_pair = conv(g, (feat_src, feat_dst), **call_kwargs)
    assert out_pair.shape == (num_dst, 2)
    assert F.array_equal(out, out_pair)
Exemplo n.º 5
0
Arquivo: test_nn.py Projeto: zwwlp/dgl
def test_hetero_conv(agg, idtype):
    """Exercise ``nn.HeteroGraphConv`` on a small user/game/store graph.

    The test covers three kinds of input:

    * dict inputs that contain only some of the source node types;
    * ``(src, dst)`` pair inputs with ``SAGEConv`` sub-modules;
    * per-relation argument forwarding via ``mod_args`` / ``mod_kwargs``.

    ``agg`` is the cross-relation aggregator.  With ``'stack'`` each output
    gains a middle axis with one slice per contributing relation.
    """
    g = dgl.heterograph(
        {
            ('user', 'follows', 'user'): [(0, 1), (0, 2), (2, 1), (1, 3)],
            ('user', 'plays', 'game'): [(0, 0), (0, 2), (0, 3), (1, 0),
                                        (2, 2)],
            ('store', 'sells', 'game'): [(0, 0), (0, 3), (1, 1), (1, 2)]
        },
        idtype=idtype,
        device=F.ctx())

    def check_shapes(h, expected):
        # Check the output keys and shapes in one place.
        # ``expected`` maps each output node type to
        # (num_nodes, num_contributing_relations, out_feats).  The middle
        # entry is only part of the shape when ``agg == 'stack'``.
        assert set(h.keys()) == set(expected)
        for ntype, (num_nodes, num_rels, out_feats) in expected.items():
            if agg == 'stack':
                assert h[ntype].shape == (num_nodes, num_rels, out_feats)
            else:
                assert h[ntype].shape == (num_nodes, out_feats)

    conv = nn.HeteroGraphConv(
        {
            'follows': nn.GraphConv(2, 3),
            'plays': nn.GraphConv(2, 4),
            'sells': nn.GraphConv(3, 4)
        }, agg)
    uf = F.randn((4, 2))
    gf = F.randn((4, 4))
    sf = F.randn((2, 3))

    # Only 'user' sources: 'follows' and 'plays' run, 'sells' does not.
    h = conv(g, {'user': uf})
    check_shapes(h, {'user': (4, 1, 3), 'game': (4, 1, 4)})

    # Both source types: 'game' now receives both 'plays' and 'sells'.
    h = conv(g, {'user': uf, 'store': sf})
    check_shapes(h, {'user': (4, 1, 3), 'game': (4, 2, 4)})

    # Only 'store' sources: just 'sells' runs, so only 'game' is produced.
    h = conv(g, {'store': sf})
    check_shapes(h, {'game': (4, 1, 4)})

    # test with pair input
    conv = nn.HeteroGraphConv(
        {
            'follows': nn.SAGEConv(2, 3, 'mean'),
            'plays': nn.SAGEConv((2, 4), 4, 'mean'),
            'sells': nn.SAGEConv(3, 4, 'mean')
        }, agg)

    h = conv(g, ({'user': uf}, {'user': uf, 'game': gf}))
    check_shapes(h, {'user': (4, 1, 3), 'game': (4, 1, 4)})

    # With pair input, a relation runs only when its source type is in the
    # src dict AND its destination type is in the dst dict.  'follows' has
    # no dst 'user' features here, so only 'game' (via 'plays') is produced.
    h = conv(g, ({'user': uf}, {'game': gf}))
    check_shapes(h, {'game': (4, 1, 4)})

    # test with mod args
    class MyMod(tf.keras.layers.Layer):
        # Stub module: counts how often it receives the optional positional
        # ``arg1`` and keyword-only ``arg2``, and returns zeros of width s2.
        def __init__(self, s1, s2):
            super(MyMod, self).__init__()
            self.carg1 = 0
            self.carg2 = 0
            self.s1 = s1
            self.s2 = s2

        def call(self, g, h, arg1=None, *, arg2=None):
            if arg1 is not None:
                self.carg1 += 1
            if arg2 is not None:
                self.carg2 += 1
            return tf.zeros((g.number_of_dst_nodes(), self.s2))

    mod1 = MyMod(2, 3)
    mod2 = MyMod(2, 4)
    mod3 = MyMod(3, 4)
    conv = nn.HeteroGraphConv({
        'follows': mod1,
        'plays': mod2,
        'sells': mod3
    }, agg)
    # Positional args go to 'follows' and 'plays'; a keyword arg goes to 'sells'.
    mod_args = {'follows': (1, ), 'plays': (1, )}
    mod_kwargs = {'sells': {'arg2': 'abc'}}
    h = conv(g, {
        'user': uf,
        'store': sf
    },
             mod_args=mod_args,
             mod_kwargs=mod_kwargs)
    assert mod1.carg1 == 1
    assert mod1.carg2 == 0
    assert mod2.carg1 == 1
    assert mod2.carg2 == 0
    assert mod3.carg1 == 0
    assert mod3.carg2 == 1
Exemplo n.º 6
0
def test_hetero_conv(agg, idtype):
    """Exercise ``nn.HeteroGraphConv`` on full graphs, blocks and edge-less graphs.

    The user/game/store graph is fed in three ways: as a graph with dict
    input, as a block with ``(src, dst)`` pair input, and as a block with
    plain dict input.  Per-relation ``mod_args`` / ``mod_kwargs`` forwarding
    is then checked with stub modules.  Finally the same inputs are run on a
    copy of the graph with every edge removed.  With ``agg == 'stack'``
    outputs carry a middle axis with one slice per contributing relation.
    """
    g = dgl.heterograph(
        {
            ('user', 'follows', 'user'): ([0, 0, 2, 1], [1, 2, 1, 3]),
            ('user', 'plays', 'game'): ([0, 0, 0, 1, 2], [0, 2, 3, 0, 2]),
            ('store', 'sells', 'game'): ([0, 0, 1, 1], [0, 3, 1, 2])
        },
        idtype=idtype,
        device=F.ctx())
    conv = nn.HeteroGraphConv(
        {
            'follows': nn.GraphConv(2, 3, allow_zero_in_degree=True),
            'plays': nn.GraphConv(2, 4, allow_zero_in_degree=True),
            'sells': nn.GraphConv(3, 4, allow_zero_in_degree=True)
        }, agg)
    uf = F.randn((4, 2))
    gf = F.randn((4, 4))
    sf = F.randn((2, 3))

    def check_full_output(h):
        # 'user' is fed by 'follows' only; 'game' by both 'plays' and 'sells'.
        assert set(h.keys()) == {'user', 'game'}
        if agg == 'stack':
            assert h['user'].shape == (4, 1, 3)
            assert h['game'].shape == (4, 2, 4)
        else:
            assert h['user'].shape == (4, 3)
            assert h['game'].shape == (4, 4)

    def to_test_block(graph):
        # The block's destination nodes are all users and games, no stores.
        return dgl.to_block(graph.to(F.cpu()), {
            'user': [0, 1, 2, 3],
            'game': [0, 1, 2, 3],
            'store': []
        }).to(F.ctx())

    def block_pair_inputs():
        # Source features for every node type.  The destination side gets
        # an empty 'store' slice, matching the block's zero store dst nodes.
        src_feats = {'user': uf, 'game': gf, 'store': sf}
        dst_feats = {'user': uf, 'game': gf, 'store': sf[0:0]}
        return src_feats, dst_feats

    check_full_output(conv(g, {'user': uf, 'store': sf, 'game': gf}))

    block = to_test_block(g)
    check_full_output(conv(block, block_pair_inputs()))
    check_full_output(conv(block, {'user': uf, 'game': gf, 'store': sf}))

    # test with mod args
    class MyMod(tf.keras.layers.Layer):
        # Stub module: counts how often it receives the optional positional
        # ``arg1`` and keyword-only ``arg2``, and returns zeros of width s2.
        def __init__(self, s1, s2):
            super(MyMod, self).__init__()
            self.carg1 = 0
            self.carg2 = 0
            self.s1 = s1
            self.s2 = s2

        def call(self, g, h, arg1=None, *, arg2=None):
            if arg1 is not None:
                self.carg1 += 1
            if arg2 is not None:
                self.carg2 += 1
            return tf.zeros((g.number_of_dst_nodes(), self.s2))

    mod1, mod2, mod3 = MyMod(2, 3), MyMod(2, 4), MyMod(3, 4)
    conv = nn.HeteroGraphConv({
        'follows': mod1,
        'plays': mod2,
        'sells': mod3
    }, agg)
    h = conv(g, {'user': uf, 'game': gf, 'store': sf},
             mod_args={'follows': (1, ), 'plays': (1, )},
             mod_kwargs={'sells': {'arg2': 'abc'}})
    # (carg1, carg2) per stub: positional arg reached follows/plays only,
    # keyword arg reached sells only.
    assert (mod1.carg1, mod1.carg2) == (1, 0)
    assert (mod2.carg1, mod2.carg2) == (1, 0)
    assert (mod3.carg1, mod3.carg2) == (0, 1)

    # conv on graph without any edges.
    # NOTE(review): ``conv`` here is still the MyMod-based module, not the
    # GraphConv one above -- confirm that is intended.
    edgeless = g
    for etype in g.etypes:
        edgeless = dgl.remove_edges(
            edgeless, edgeless.edges(form='eid', etype=etype), etype=etype)
    assert edgeless.num_edges() == 0
    h = conv(edgeless, {'user': uf, 'game': gf, 'store': sf})
    assert set(h.keys()) == {'user', 'game'}

    empty_block = to_test_block(edgeless)
    h = conv(empty_block, block_pair_inputs())
    assert set(h.keys()) == {'user', 'game'}