def test_gatv2_conv_bi(g, idtype, out_dim, num_heads):
    """Check GATv2Conv output and attention shapes when fed a (src, dst) feature pair.

    A random 5-dim feature row is drawn for every source node and every
    destination node. The forward pass must yield one
    ``(num_heads, out_dim)`` slice per destination node. The attention
    scores must have one ``(num_heads, 1)`` slice per edge.
    """
    graph = g.astype(idtype).to(F.ctx())
    device = F.ctx()
    in_size = 5

    layer = nn.GATv2Conv(in_size, out_dim, num_heads)
    src_feat = F.randn((graph.number_of_src_nodes(), in_size))
    dst_feat = F.randn((graph.number_of_dst_nodes(), in_size))
    pair = (src_feat, dst_feat)
    layer = layer.to(device)

    out = layer(graph, pair)
    assert out.shape == (graph.number_of_dst_nodes(), num_heads, out_dim)

    # Second pass only to inspect the attention tensor; the output is discarded.
    _, attn = layer(graph, pair, get_attention=True)
    assert attn.shape == (graph.number_of_edges(), num_heads, 1)
def test_gatv2_conv(g, idtype, out_dim, num_heads):
    """Check GATv2Conv shapes, that it can be saved with ``th.save``, and its residual variant.

    A single 5-dim feature tensor (one row per source node) is used. The
    forward output must have one ``(num_heads, out_dim)`` slice per
    destination node. The attention scores must have one ``(num_heads, 1)``
    slice per edge. The same shape contract is then checked for a
    ``residual=True`` GATv2Conv.
    """
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gat = nn.GATv2Conv(5, out_dim, num_heads)
    feat = F.randn((g.number_of_src_nodes(), 5))
    gat = gat.to(ctx)
    h = gat(g, feat)

    # test pickle: the module must be serializable with th.save.
    # NOTE(review): `tmp_buffer` is not defined in this function; presumably a
    # module-level writable buffer (e.g. io.BytesIO) -- confirm.
    th.save(gat, tmp_buffer)

    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, a = gat(g, feat, get_attention=True)
    assert a.shape == (g.number_of_edges(), num_heads, 1)

    # test residual connection: construct the layer under test (GATv2Conv, not
    # GATConv) and check that the residual path preserves the output shape.
    gat = nn.GATv2Conv(5, out_dim, num_heads, residual=True)
    gat = gat.to(ctx)
    h = gat(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)