コード例 #1
0
def test_gat_conv():
    g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
    gat = nn.GATConv(5, 2, 4)
    feat = F.randn((100, 5))
    h = gat(g, feat)
    assert h.shape == (100, 4, 2)

    g = dgl.bipartite(sp.sparse.random(100, 200, density=0.1))
    gat = nn.GATConv((5, 10), 2, 4)
    feat = (F.randn((100, 5)), F.randn((200, 10)))
    h = gat(g, feat)
コード例 #2
0
ファイル: test_nn.py プロジェクト: jermainewang/dgl
def test_gat_conv(g, idtype, out_dim, num_heads):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gat = nn.GATConv(5, out_dim, num_heads)
    feat = F.randn((g.number_of_src_nodes(), 5))
    h = gat(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, a = gat(g, feat, get_attention=True)
    assert a.shape == (g.number_of_edges(), num_heads, 1)

    # test residual connection
    gat = nn.GATConv(5, out_dim, num_heads, residual=True)
    h = gat(g, feat)
コード例 #3
0
ファイル: test_nn.py プロジェクト: zwwlp/dgl
def test_gat_conv(g, idtype):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gat = nn.GATConv(5, 2, 4)
    feat = F.randn((g.number_of_nodes(), 5))
    h = gat(g, feat)
    assert h.shape == (g.number_of_nodes(), 4, 2)
コード例 #4
0
def test_gat_conv_bi(g, idtype):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gat = nn.GATConv((5, 10), 2, 4)
    feat = (F.randn(
        (g.number_of_src_nodes(), 5)), F.randn((g.number_of_dst_nodes(), 10)))
    h = gat(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), 4, 2)
コード例 #5
0
ファイル: test_nn.py プロジェクト: zubair-ahmed-ai/dgl
def test_gat_conv(g, idtype):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gat = nn.GATConv(5, 2, 4)
    feat = F.randn((g.number_of_nodes(), 5))
    h = gat(g, feat)
    assert h.shape == (g.number_of_nodes(), 4, 2)
    _, a = gat(g, feat, get_attention=True)
    assert a.shape == (g.number_of_edges(), 4, 1)
コード例 #6
0
ファイル: test_nn.py プロジェクト: weibao918/dgl
def test_gat_conv():
    g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
    gat = nn.GATConv(5, 2, 4)
    feat = F.randn((100, 5))
    h = gat(g, feat)
    assert h.shape == (100, 4, 2)

    g = dgl.bipartite(sp.sparse.random(100, 200, density=0.1))
    gat = nn.GATConv((5, 10), 2, 4)
    feat = (F.randn((100, 5)), F.randn((200, 10)))
    h = gat(g, feat)

    g = dgl.graph(sp.sparse.random(100, 100, density=0.001))
    seed_nodes = np.unique(g.edges()[1].numpy())
    block = dgl.to_block(g, seed_nodes)
    gat = nn.GATConv(5, 2, 4)
    feat = F.randn((block.number_of_src_nodes(), 5))
    h = gat(block, feat)
    assert h.shape == (block.number_of_dst_nodes(), 4, 2)
コード例 #7
0
ファイル: test_nn.py プロジェクト: jjhu94/dgl-1
def test_gat_conv():
    g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
    gat = nn.GATConv(5, 2, 4)
    feat = F.randn((100, 5))
    h = gat(g, feat)
    assert h.shape[-1] == 2 and h.shape[-2] == 4