Ejemplo n.º 1
0
def test_dotgat_conv_bi(g, idtype):
    """Check DotGatConv output and attention shapes with paired node features.

    The layer is built with a pair of input sizes and is fed a
    ``(source_features, destination_features)`` tuple; only the shapes of the
    returned tensors are verified.

    Args:
        g: Graph under test; converted to ``idtype`` and moved to ``F.ctx()``.
            Presumably a bipartite/block graph supplied by a fixture -- confirm
            against the parametrization, which is not visible here.
        idtype: Index dtype passed to ``g.astype``.
    """
    in_feats, out_feats, num_heads = 5, 2, 4

    g = g.astype(idtype).to(F.ctx())
    device = F.ctx()

    # Build the layer before drawing the random features so the random draws
    # happen in the same sequence as before.
    layer = nn.DotGatConv((in_feats, in_feats), out_feats, num_heads)
    src_feat = F.randn((g.number_of_src_nodes(), in_feats))
    dst_feat = F.randn((g.number_of_dst_nodes(), in_feats))
    layer = layer.to(device)

    # Plain forward pass: one row per destination node, per head.
    out = layer(g, (src_feat, dst_feat))
    assert out.shape == (g.number_of_dst_nodes(), num_heads, out_feats)

    # With get_attention=True the second returned value holds one attention
    # entry per edge, per head.
    _, attn = layer(g, (src_feat, dst_feat), get_attention=True)
    assert attn.shape == (g.number_of_edges(), num_heads, 1)
Ejemplo n.º 2
0
def test_dotgat_conv(g, idtype, out_dim, num_heads):
    """Check DotGatConv output and attention shapes with one feature tensor.

    Only the shapes of the returned tensors are verified.

    Args:
        g: Graph under test; converted to ``idtype`` and moved to ``F.ctx()``.
        idtype: Index dtype passed to ``g.astype``.
        out_dim: Output feature size handed to the layer.
        num_heads: Number of attention heads handed to the layer.
    """
    in_feats = 5

    g = g.astype(idtype).to(F.ctx())
    device = F.ctx()

    # Build the layer before drawing the random features so the random draws
    # happen in the same sequence as before.
    layer = nn.DotGatConv(in_feats, out_dim, num_heads)
    node_feat = F.randn((g.number_of_nodes(), in_feats))
    layer = layer.to(device)

    # Plain forward pass: one row per node, per head.
    out = layer(g, node_feat)
    assert out.shape == (g.number_of_nodes(), num_heads, out_dim)

    # With get_attention=True the second returned value holds one attention
    # entry per edge, per head.
    _, attn = layer(g, node_feat, get_attention=True)
    assert attn.shape == (g.number_of_edges(), num_heads, 1)