def test_feast_conv():
    """FeaStConv gives matching results for edge_index and SparseTensor
    inputs, eagerly and under TorchScript, on homogeneous and bipartite
    graphs."""
    x_src = torch.randn(4, 16)
    x_dst = torch.randn(2, 16)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    src, dst = edge_index
    adj = SparseTensor(row=src, col=dst, sparse_sizes=(4, 4))
    adj_t = adj.t()

    conv = FeaStConv(16, 32, heads=2)
    assert repr(conv) == 'FeaStConv(16, 32, heads=2)'

    # Homogeneous graph.
    expected = conv(x_src, edge_index)
    assert expected.size() == (4, 32)
    assert torch.allclose(conv(x_src, adj_t), expected, atol=1e-6)

    scripted = torch.jit.script(conv.jittable('(Tensor, Tensor) -> Tensor'))
    assert scripted(x_src, edge_index).tolist() == expected.tolist()
    scripted = torch.jit.script(
        conv.jittable('(Tensor, SparseTensor) -> Tensor'))
    assert torch.allclose(scripted(x_src, adj_t), expected, atol=1e-6)

    # Bipartite graph: 4 source nodes -> 2 target nodes.
    adj_t = adj.sparse_resize((4, 2)).t()
    pair = (x_src, x_dst)
    expected = conv(pair, edge_index)
    assert expected.size() == (2, 32)
    assert torch.allclose(conv(pair, adj_t), expected, atol=1e-6)

    scripted = torch.jit.script(
        conv.jittable('(PairTensor, Tensor) -> Tensor'))
    assert scripted(pair, edge_index).tolist() == expected.tolist()
    scripted = torch.jit.script(
        conv.jittable('(PairTensor, SparseTensor) -> Tensor'))
    assert torch.allclose(scripted(pair, adj_t), expected, atol=1e-6)
def test_heat_conv():
    """HEATConv with ``concat=True`` and ``concat=False``: dense and sparse
    adjacency agree, and (in full-test mode) TorchScript matches eager."""
    x = torch.randn(4, 8)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    src, dst = edge_index
    edge_attr = torch.randn((4, 2))
    adj_t = SparseTensor(row=src, col=dst, value=edge_attr,
                         sparse_sizes=(4, 4)).t()
    node_type = torch.tensor([0, 0, 1, 2])
    edge_type = torch.tensor([0, 2, 1, 2])

    cases = [
        # (concat, expected output width, jit signature, jit call args)
        (True, 32,
         '(Tensor, Tensor, Tensor, Tensor, OptTensor) -> Tensor',
         (x, edge_index, node_type, edge_type, edge_attr)),
        (False, 16,
         '(Tensor, SparseTensor, Tensor, Tensor, OptTensor) -> Tensor',
         (x, adj_t, node_type, edge_type)),
    ]
    for concat, width, signature, jit_args in cases:
        conv = HEATConv(in_channels=8, out_channels=16, num_node_types=3,
                        num_edge_types=3, edge_type_emb_dim=5, edge_dim=2,
                        edge_attr_emb_dim=6, heads=2, concat=concat)
        assert str(conv) == 'HEATConv(8, 16, heads=2)'

        out = conv(x, edge_index, node_type, edge_type, edge_attr)
        assert out.size() == (4, width)
        assert torch.allclose(conv(x, adj_t, node_type, edge_type), out)

        if is_full_test():
            scripted = torch.jit.script(conv.jittable(signature))
            assert torch.allclose(scripted(*jit_args), out)
def test_eg_conv():
    """EGConv: dense/sparse agreement, stable results with ``cached=True``,
    and (in full-test mode) TorchScript parity."""
    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
    src, dst = edge_index
    adj_t = SparseTensor(row=src, col=dst, sparse_sizes=(4, 4)).t()

    conv = EGConv(16, 32)
    assert repr(conv) == "EGConv(16, 32, aggregators=['symnorm'])"
    expected = conv(x, edge_index)
    assert expected.size() == (4, 32)
    assert torch.allclose(conv(x, adj_t), expected, atol=1e-6)

    # With caching enabled, call once to warm the cache, then re-check.
    conv.cached = True
    conv(x, edge_index)
    assert conv(x, edge_index).tolist() == expected.tolist()
    conv(x, adj_t)
    assert torch.allclose(conv(x, adj_t), expected, atol=1e-6)

    if is_full_test():
        scripted = torch.jit.script(
            conv.jittable('(Tensor, Tensor) -> Tensor'))
        assert scripted(x, edge_index).tolist() == expected.tolist()
        scripted = torch.jit.script(
            conv.jittable('(Tensor, SparseTensor) -> Tensor'))
        assert torch.allclose(scripted(x, adj_t), expected, atol=1e-6)
def test_dna_conv():
    """DNAConv with and without edge weights: dense, sparse, TorchScript and
    cached calls all agree."""
    x = torch.randn((4, 3, 32))
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
    src, dst = edge_index
    weight = torch.rand(src.size(0))
    weighted_adj = SparseTensor(row=src, col=dst, value=weight,
                                sparse_sizes=(4, 4))
    plain_adj_t = weighted_adj.set_value(None).t()
    weighted_adj_t = weighted_adj.t()

    conv = DNAConv(32, heads=4, groups=8, dropout=0.0)
    assert repr(conv) == 'DNAConv(32, heads=4, groups=8)'

    unweighted = conv(x, edge_index)
    assert unweighted.size() == (4, 32)
    assert torch.allclose(conv(x, plain_adj_t), unweighted, atol=1e-6)

    weighted = conv(x, edge_index, weight)
    assert weighted.size() == (4, 32)
    assert torch.allclose(conv(x, weighted_adj_t), weighted, atol=1e-6)

    scripted = torch.jit.script(
        conv.jittable('(Tensor, Tensor, OptTensor) -> Tensor'))
    assert scripted(x, edge_index).tolist() == unweighted.tolist()
    assert scripted(x, edge_index, weight).tolist() == weighted.tolist()

    scripted = torch.jit.script(
        conv.jittable('(Tensor, SparseTensor, OptTensor) -> Tensor'))
    assert torch.allclose(scripted(x, plain_adj_t), unweighted, atol=1e-6)
    assert torch.allclose(scripted(x, weighted_adj_t), weighted, atol=1e-6)

    # With caching enabled, call once to warm the cache, then re-check.
    conv.cached = True
    conv(x, edge_index)
    assert conv(x, edge_index).tolist() == unweighted.tolist()
    conv(x, plain_adj_t)
    assert torch.allclose(conv(x, plain_adj_t), unweighted, atol=1e-6)
def test_res_gated_graph_conv():
    """ResGatedGraphConv: exact agreement between edge_index, SparseTensor
    and TorchScript paths for homogeneous and bipartite inputs."""
    x_src = torch.randn(4, 8)
    x_dst = torch.randn(2, 32)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    src, dst = edge_index
    adj = SparseTensor(row=src, col=dst, sparse_sizes=(4, 4))
    adj_t = adj.t()

    conv = ResGatedGraphConv(8, 32)
    assert repr(conv) == 'ResGatedGraphConv(8, 32)'
    expected = conv(x_src, edge_index)
    assert expected.size() == (4, 32)
    assert conv(x_src, adj_t).tolist() == expected.tolist()

    scripted = torch.jit.script(conv.jittable('(Tensor, Tensor) -> Tensor'))
    assert scripted(x_src, edge_index).tolist() == expected.tolist()
    scripted = torch.jit.script(
        conv.jittable('(Tensor, SparseTensor) -> Tensor'))
    assert scripted(x_src, adj_t).tolist() == expected.tolist()

    # Bipartite graph: 4 source nodes -> 2 target nodes.
    adj_t = adj.sparse_resize((4, 2)).t()
    conv = ResGatedGraphConv((8, 32), 32)
    assert repr(conv) == 'ResGatedGraphConv((8, 32), 32)'
    pair = (x_src, x_dst)
    expected = conv(pair, edge_index)
    assert expected.size() == (2, 32)
    assert conv(pair, adj_t).tolist() == expected.tolist()

    scripted = torch.jit.script(
        conv.jittable('(PairTensor, Tensor) -> Tensor'))
    assert scripted(pair, edge_index).tolist() == expected.tolist()
    scripted = torch.jit.script(
        conv.jittable('(PairTensor, SparseTensor) -> Tensor'))
    assert scripted(pair, adj_t).tolist() == expected.tolist()
def test_gated_graph_conv():
    """GatedGraphConv with and without edge weights: dense, sparse and
    TorchScript calls agree."""
    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
    src, dst = edge_index
    weight = torch.rand(src.size(0))
    weighted_adj = SparseTensor(row=src, col=dst, value=weight,
                                sparse_sizes=(4, 4))
    plain_adj_t = weighted_adj.set_value(None).t()
    weighted_adj_t = weighted_adj.t()

    conv = GatedGraphConv(32, num_layers=3)
    assert repr(conv) == 'GatedGraphConv(32, num_layers=3)'

    unweighted = conv(x, edge_index)
    assert unweighted.size() == (4, 32)
    assert torch.allclose(conv(x, plain_adj_t), unweighted, atol=1e-6)

    weighted = conv(x, edge_index, weight)
    assert weighted.size() == (4, 32)
    assert torch.allclose(conv(x, weighted_adj_t), weighted, atol=1e-6)

    scripted = torch.jit.script(
        conv.jittable('(Tensor, Tensor, OptTensor) -> Tensor'))
    assert scripted(x, edge_index).tolist() == unweighted.tolist()
    assert scripted(x, edge_index, weight).tolist() == weighted.tolist()

    scripted = torch.jit.script(
        conv.jittable('(Tensor, SparseTensor, OptTensor) -> Tensor'))
    assert torch.allclose(scripted(x, plain_adj_t), unweighted, atol=1e-6)
    assert torch.allclose(scripted(x, weighted_adj_t), weighted, atol=1e-6)
def test_gcn_conv():
    """GCNConv with and without edge weights: dense/sparse agreement,
    TorchScript parity in full-test mode, and stable cached results."""
    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
    src, dst = edge_index
    weight = torch.rand(src.size(0))
    weighted_adj = SparseTensor(row=src, col=dst, value=weight,
                                sparse_sizes=(4, 4))
    plain_adj_t = weighted_adj.set_value(None).t()
    weighted_adj_t = weighted_adj.t()

    conv = GCNConv(16, 32)
    assert repr(conv) == 'GCNConv(16, 32)'

    unweighted = conv(x, edge_index)
    assert unweighted.size() == (4, 32)
    assert torch.allclose(conv(x, plain_adj_t), unweighted, atol=1e-6)

    weighted = conv(x, edge_index, weight)
    assert weighted.size() == (4, 32)
    assert torch.allclose(conv(x, weighted_adj_t), weighted, atol=1e-6)

    if is_full_test():
        scripted = torch.jit.script(
            conv.jittable('(Tensor, Tensor, OptTensor) -> Tensor'))
        assert scripted(x, edge_index).tolist() == unweighted.tolist()
        assert scripted(x, edge_index, weight).tolist() == weighted.tolist()

        scripted = torch.jit.script(
            conv.jittable('(Tensor, SparseTensor, OptTensor) -> Tensor'))
        assert torch.allclose(scripted(x, plain_adj_t), unweighted,
                              atol=1e-6)
        assert torch.allclose(scripted(x, weighted_adj_t), weighted,
                              atol=1e-6)

    # With caching enabled, call once to warm the cache, then re-check.
    conv.cached = True
    conv(x, edge_index)
    assert conv(x, edge_index).tolist() == unweighted.tolist()
    conv(x, plain_adj_t)
    assert torch.allclose(conv(x, plain_adj_t), unweighted, atol=1e-6)
def test_pna_conv():
    """PNAConv with 3-dim edge features: dense, sparse and TorchScript paths
    agree within tolerance."""
    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
    src, dst = edge_index
    edge_feat = torch.rand(src.size(0), 3)
    adj_t = SparseTensor(row=src, col=dst, value=edge_feat,
                         sparse_sizes=(4, 4)).t()

    conv = PNAConv(16, 32, aggregators, scalers,
                   deg=torch.tensor([0, 3, 0, 1]), edge_dim=3, towers=4)
    assert repr(conv) == 'PNAConv(16, 32, towers=4, edge_dim=3)'

    expected = conv(x, edge_index, edge_feat)
    assert expected.size() == (4, 32)
    assert torch.allclose(conv(x, adj_t), expected, atol=1e-6)

    scripted = torch.jit.script(
        conv.jittable('(Tensor, Tensor, OptTensor) -> Tensor'))
    assert torch.allclose(scripted(x, edge_index, edge_feat), expected,
                          atol=1e-6)
    scripted = torch.jit.script(
        conv.jittable('(Tensor, SparseTensor, OptTensor) -> Tensor'))
    assert torch.allclose(scripted(x, adj_t), expected, atol=1e-6)
def test_lg_conv():
    """LGConv with and without edge weights: dense, sparse and TorchScript
    calls agree."""
    x = torch.randn(4, 8)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    src, dst = edge_index
    weight = torch.rand(src.size(0))
    weighted_adj = SparseTensor(row=src, col=dst, value=weight,
                                sparse_sizes=(4, 4))
    plain_adj_t = weighted_adj.set_value(None).t()
    weighted_adj_t = weighted_adj.t()

    conv = LGConv()
    assert str(conv) == 'LGConv()'

    unweighted = conv(x, edge_index)
    assert unweighted.size() == (4, 8)
    assert torch.allclose(conv(x, plain_adj_t), unweighted)

    weighted = conv(x, edge_index, weight)
    assert weighted.size() == (4, 8)
    assert torch.allclose(conv(x, weighted_adj_t), weighted)

    scripted = torch.jit.script(
        conv.jittable('(Tensor, Tensor, OptTensor) -> Tensor'))
    assert torch.allclose(scripted(x, edge_index), unweighted)
    assert torch.allclose(scripted(x, edge_index, weight), weighted)

    scripted = torch.jit.script(
        conv.jittable('(Tensor, SparseTensor, OptTensor) -> Tensor'))
    assert torch.allclose(scripted(x, plain_adj_t), unweighted)
    assert torch.allclose(scripted(x, weighted_adj_t), weighted)
def test_ppf_conv():
    """PPFConv: repr, dense/sparse/TorchScript agreement on homogeneous and
    bipartite graphs (including a ``None`` target feature)."""
    x1 = torch.randn(4, 16)
    pos1 = torch.randn(4, 3)
    pos2 = torch.randn(2, 3)
    n1 = F.normalize(torch.rand(4, 3), dim=-1)
    n2 = F.normalize(torch.rand(2, 3), dim=-1)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    row, col = edge_index
    adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))

    # local_nn input is 16 node features + 4 point-pair features.
    local_nn = Seq(Lin(16 + 4, 32), ReLU(), Lin(32, 32))
    global_nn = Seq(Lin(32, 32))
    conv = PPFConv(local_nn, global_nn)
    # torch.nn.Module.__repr__ indents child modules by two spaces.
    assert conv.__repr__() == (
        'PPFConv(local_nn=Sequential(\n'
        '  (0): Linear(in_features=20, out_features=32, bias=True)\n'
        '  (1): ReLU()\n'
        '  (2): Linear(in_features=32, out_features=32, bias=True)\n'
        '), global_nn=Sequential(\n'
        '  (0): Linear(in_features=32, out_features=32, bias=True)\n'
        '))')

    out = conv(x1, pos1, n1, edge_index)
    assert out.size() == (4, 32)
    assert torch.allclose(conv(x1, pos1, n1, adj.t()), out, atol=1e-6)

    t = '(OptTensor, Tensor, Tensor, Tensor) -> Tensor'
    jit = torch.jit.script(conv.jittable(t))
    assert jit(x1, pos1, n1, edge_index).tolist() == out.tolist()

    t = '(OptTensor, Tensor, Tensor, SparseTensor) -> Tensor'
    jit = torch.jit.script(conv.jittable(t))
    assert torch.allclose(jit(x1, pos1, n1, adj.t()), out, atol=1e-6)

    # Bipartite graph: 4 source points -> 2 target points.
    adj = adj.sparse_resize((4, 2))
    out = conv(x1, (pos1, pos2), (n1, n2), edge_index)
    assert out.size() == (2, 32)
    assert conv((x1, None), (pos1, pos2), (n1, n2),
                edge_index).tolist() == out.tolist()
    assert torch.allclose(conv(x1, (pos1, pos2), (n1, n2), adj.t()), out,
                          atol=1e-6)
    assert torch.allclose(conv((x1, None), (pos1, pos2), (n1, n2), adj.t()),
                          out, atol=1e-6)

    t = '(PairOptTensor, PairTensor, PairTensor, Tensor) -> Tensor'
    jit = torch.jit.script(conv.jittable(t))
    assert jit((x1, None), (pos1, pos2), (n1, n2),
               edge_index).tolist() == out.tolist()

    t = '(PairOptTensor, PairTensor, PairTensor, SparseTensor) -> Tensor'
    jit = torch.jit.script(conv.jittable(t))
    assert torch.allclose(jit((x1, None), (pos1, pos2), (n1, n2), adj.t()),
                          out, atol=1e-6)
def test_cg_conv():
    """CGConv without edge features: edge_index (with and without explicit
    size), SparseTensor and TorchScript outputs match exactly for
    homogeneous and bipartite inputs, and with ``batch_norm=True``."""
    x1 = torch.randn(4, 8)
    x2 = torch.randn(2, 16)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    row, col = edge_index
    adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))

    conv = CGConv(8)
    assert conv.__repr__() == 'CGConv(8, dim=0)'
    out = conv(x1, edge_index)
    assert out.size() == (4, 8)
    assert conv(x1, edge_index, size=(4, 4)).tolist() == out.tolist()
    assert conv(x1, adj.t()).tolist() == out.tolist()

    t = '(Tensor, Tensor, OptTensor, Size) -> Tensor'
    jit = torch.jit.script(conv.jittable(t))
    assert jit(x1, edge_index).tolist() == out.tolist()
    assert jit(x1, edge_index, size=(4, 4)).tolist() == out.tolist()

    t = '(Tensor, SparseTensor, OptTensor, Size) -> Tensor'
    jit = torch.jit.script(conv.jittable(t))
    assert jit(x1, adj.t()).tolist() == out.tolist()

    # Bipartite graph: 4 source nodes -> 2 target nodes.
    adj = adj.sparse_resize((4, 2))
    conv = CGConv((8, 16))
    assert conv.__repr__() == 'CGConv((8, 16), dim=0)'
    out = conv((x1, x2), edge_index)
    assert out.size() == (2, 16)
    assert conv((x1, x2), edge_index, size=(4, 2)).tolist() == out.tolist()
    assert conv((x1, x2), adj.t()).tolist() == out.tolist()

    t = '(PairTensor, Tensor, OptTensor, Size) -> Tensor'
    jit = torch.jit.script(conv.jittable(t))
    assert jit((x1, x2), edge_index).tolist() == out.tolist()
    assert jit((x1, x2), edge_index, size=(4, 2)).tolist() == out.tolist()

    t = '(PairTensor, SparseTensor, OptTensor, Size) -> Tensor'
    jit = torch.jit.script(conv.jittable(t))
    assert jit((x1, x2), adj.t()).tolist() == out.tolist()

    # Test batch_norm true:
    adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
    conv = CGConv(8, batch_norm=True)
    assert conv.__repr__() == 'CGConv(8, dim=0)'
    out = conv(x1, edge_index)
    assert out.size() == (4, 8)
    assert conv(x1, edge_index, size=(4, 4)).tolist() == out.tolist()
    assert conv(x1, adj.t()).tolist() == out.tolist()

    t = '(Tensor, Tensor, OptTensor, Size) -> Tensor'
    jit = torch.jit.script(conv.jittable(t))
    assert jit(x1, edge_index).tolist() == out.tolist()
    assert jit(x1, edge_index, size=(4, 4)).tolist() == out.tolist()

    # Also cover the SparseTensor TorchScript path, as the other
    # configurations above do.
    t = '(Tensor, SparseTensor, OptTensor, Size) -> Tensor'
    jit = torch.jit.script(conv.jittable(t))
    assert jit(x1, adj.t()).tolist() == out.tolist()
def test_point_transformer_conv():
    """PointTransformerConv: default and custom sub-networks, homogeneous and
    bipartite inputs, dense/sparse agreement and (full-test) TorchScript."""
    x_src = torch.rand(4, 16)
    x_dst = torch.randn(2, 8)
    pos_src = torch.rand(4, 3)
    pos_dst = torch.randn(2, 3)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    src, dst = edge_index
    adj = SparseTensor(row=src, col=dst, sparse_sizes=(4, 4))
    adj_t = adj.t()

    conv = PointTransformerConv(in_channels=16, out_channels=32)
    assert str(conv) == 'PointTransformerConv(16, 32)'
    expected = conv(x_src, pos_src, edge_index)
    assert expected.size() == (4, 32)
    assert torch.allclose(conv(x_src, pos_src, adj_t), expected, atol=1e-6)

    if is_full_test():
        scripted = torch.jit.script(
            conv.jittable('(Tensor, Tensor, Tensor) -> Tensor'))
        assert scripted(x_src, pos_src,
                        edge_index).tolist() == expected.tolist()
        scripted = torch.jit.script(
            conv.jittable('(Tensor, Tensor, SparseTensor) -> Tensor'))
        assert torch.allclose(scripted(x_src, pos_src, adj_t), expected,
                              atol=1e-6)

    # Custom positional and attention sub-networks.
    pos_nn = Sequential(Linear(3, 16), ReLU(), Linear(16, 32))
    attn_nn = Sequential(Linear(32, 32), ReLU(), Linear(32, 32))
    conv = PointTransformerConv(16, 32, pos_nn, attn_nn)
    expected = conv(x_src, pos_src, edge_index)
    assert expected.size() == (4, 32)
    assert torch.allclose(conv(x_src, pos_src, adj_t), expected, atol=1e-6)

    # Bipartite graph: 4 source points -> 2 target points.
    conv = PointTransformerConv((16, 8), 32)
    adj_t = adj.sparse_resize((4, 2)).t()
    x_pair = (x_src, x_dst)
    pos_pair = (pos_src, pos_dst)
    expected = conv(x_pair, pos_pair, edge_index)
    assert expected.size() == (2, 32)
    assert torch.allclose(conv(x_pair, pos_pair, adj_t), expected,
                          atol=1e-6)

    if is_full_test():
        scripted = torch.jit.script(
            conv.jittable('(PairTensor, PairTensor, Tensor) -> Tensor'))
        assert scripted(x_pair, pos_pair,
                        edge_index).tolist() == expected.tolist()
        scripted = torch.jit.script(
            conv.jittable('(PairTensor, PairTensor, SparseTensor) -> Tensor'))
        assert torch.allclose(scripted(x_pair, pos_pair, adj_t), expected,
                              atol=1e-6)
def test_fa_conv():
    """FAConv: dense/sparse/TorchScript agreement, plus the shapes returned
    with ``return_attention_weights=True`` and that ``_alpha`` is cleared."""
    x = torch.randn(4, 16)
    x_0 = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
    src, dst = edge_index
    adj_t = SparseTensor(row=src, col=dst, sparse_sizes=(4, 4)).t()

    conv = FAConv(16, eps=1.0)
    assert repr(conv) == 'FAConv(16, eps=1.0)'
    expected = conv(x, x_0, edge_index)
    assert expected.size() == (4, 16)
    assert torch.allclose(conv(x, x_0, adj_t), expected, atol=1e-6)

    scripted = torch.jit.script(conv.jittable(
        '(Tensor, Tensor, Tensor, OptTensor, NoneType) -> Tensor'))
    assert scripted(x, x_0, edge_index).tolist() == expected.tolist()
    scripted = torch.jit.script(conv.jittable(
        '(Tensor, Tensor, SparseTensor, OptTensor, NoneType) -> Tensor'))
    assert torch.allclose(scripted(x, x_0, adj_t), expected, atol=1e-6)

    # Test `return_attention_weights`. The 10 attention entries are
    # presumably the 6 input edges plus 4 added self-loops.
    out, (att_index, att_weight) = conv(x, x_0, edge_index,
                                        return_attention_weights=True)
    assert out.tolist() == expected.tolist()
    assert att_index.size() == (2, 10)
    assert att_weight.size() == (10, )
    assert conv._alpha is None

    out, att_adj = conv(x, x_0, adj_t, return_attention_weights=True)
    assert torch.allclose(out, expected, atol=1e-6)
    assert att_adj.sizes() == [4, 4] and att_adj.nnz() == 10
    assert conv._alpha is None

    scripted = torch.jit.script(conv.jittable(
        '(Tensor, Tensor, Tensor, OptTensor, bool) '
        '-> Tuple[Tensor, Tuple[Tensor, Tensor]]'))
    out, (att_index, att_weight) = scripted(x, x_0, edge_index,
                                            return_attention_weights=True)
    assert out.tolist() == expected.tolist()
    assert att_index.size() == (2, 10)
    assert att_weight.size() == (10, )
    assert conv._alpha is None

    scripted = torch.jit.script(conv.jittable(
        '(Tensor, Tensor, SparseTensor, OptTensor, bool) '
        '-> Tuple[Tensor, SparseTensor]'))
    out, att_adj = scripted(x, x_0, adj_t, return_attention_weights=True)
    assert torch.allclose(out, expected, atol=1e-6)
    assert att_adj.sizes() == [4, 4] and att_adj.nnz() == 10
    assert conv._alpha is None
def test_sparse_tensor_spspmm(dtype, device):
    """Multiplying this sparse matrix by its transpose yields the 10x10
    identity, both as sparse @ dense and as sparse @ sparse."""
    rows = [0, 1, 1, 1, 2, 3, 4, 5, 5, 6, 6, 7, 7, 7, 8, 8, 9, 9]
    cols = [0, 5, 10, 15, 1, 2, 3, 7, 13, 6, 9, 5, 10, 15, 11, 14, 5, 15]
    vals = [
        1, 3**-0.5, 3**-0.5, 3**-0.5, 1, 1, 1, -2**-0.5, -2**-0.5,
        -2**-0.5, -2**-0.5, 6**-0.5, -6**0.5 / 3, 6**-0.5, -2**-0.5,
        -2**-0.5, 2**-0.5, -2**-0.5
    ]
    mat = SparseTensor(
        row=torch.tensor(rows, device=device),
        col=torch.tensor(cols, device=device),
        value=torch.tensor(vals, dtype=dtype, device=device),
    )
    identity = torch.eye(10, dtype=dtype, device=device)

    # sparse @ dense
    assert torch.allclose(mat @ mat.to_dense().t(), identity, atol=1e-7)
    # sparse @ sparse, densified for comparison
    assert torch.allclose((mat @ mat.t()).to_dense(), identity, atol=1e-7)
def __init__(self, edge_index: torch.Tensor, sizes: List[int],
             node_idx: Optional[torch.Tensor] = None,
             num_nodes: Optional[int] = None,
             flow: str = "source_to_target", **kwargs):
    """Build the (CPU) adjacency used for sampling and set up the loader.

    Args:
        edge_index: Connectivity; row 0 and row 1 are used as the
            SparseTensor ``row``/``col``.
        sizes: Stored as ``self.sizes`` (consumed elsewhere, e.g. by
            ``self.sample`` — not visible here).
        node_idx: Seed nodes. ``None`` means all ``N`` nodes; a boolean
            mask is converted to the indices of its ``True`` entries.
        num_nodes: Node count ``N``; inferred as ``max(edge_index) + 1``
            when omitted.
        flow: ``'source_to_target'`` (adjacency is transposed) or
            ``'target_to_source'``. Any other value fails the assertion.
        **kwargs: Forwarded to the parent loader's ``__init__``.
    """
    if num_nodes is not None:
        N = num_nodes
    else:
        N = int(edge_index.max() + 1)

    # Each edge's value is its position in edge_index (an edge id),
    # presumably so sampled edges can be traced back — confirm in sample().
    edge_ids = torch.arange(edge_index.size(1))
    adj = SparseTensor(row=edge_index[0], col=edge_index[1], value=edge_ids,
                       sparse_sizes=(N, N), is_sorted=False)
    if flow == 'source_to_target':
        adj = adj.t()
    self.adj = adj.to('cpu')

    if node_idx is None:
        node_idx = torch.arange(N)
    elif node_idx.dtype == torch.bool:
        node_idx = node_idx.nonzero().view(-1)

    self.sizes = sizes
    self.flow = flow
    assert self.flow in ['source_to_target', 'target_to_source']

    super(NeighborSampler, self).__init__(
        node_idx.tolist(), collate_fn=self.sample, **kwargs)
def test_rgat_conv_jittable():
    """RGATConv: dense/sparse agreement and TorchScript parity with eager."""
    x = torch.randn(4, 8)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    src, dst = edge_index
    edge_attr = torch.randn((4, 8))
    adj_t = SparseTensor(row=src, col=dst, value=edge_attr,
                         sparse_sizes=(4, 4)).t()
    edge_type = torch.tensor([0, 2, 1, 2])

    conv = RGATConv(8, 20, num_relations=4, num_bases=4, mod='additive',
                    attention_mechanism='across-relation',
                    attention_mode='additive-self-attention', heads=2,
                    dim=1, edge_dim=8, bias=False)

    out = conv(x, edge_index, edge_type, edge_attr)
    assert out.size() == (4, 40)
    assert torch.allclose(conv(x, adj_t, edge_type), out)

    signature = ('(Tensor, Tensor, OptTensor, OptTensor, Size, NoneType) '
                 '-> Tensor')
    scripted = torch.jit.script(conv.jittable(signature))
    scripted_out = scripted(x, edge_index, edge_type)
    eager_out = conv(x, edge_index, edge_type)
    assert torch.allclose(scripted_out, eager_out)
def forward(self, x, edge_index, edge_weight=None):
    """Normalize node features, run the GCN layer stack, then project.

    Args:
        x: Node feature matrix; ``x.shape[0]`` is used as the node count
            when building the sparse adjacency.
        edge_index: Connectivity whose rows 0 and 1 are used as source /
            target indices.
        edge_weight: Optional per-edge weights. They are forwarded to every
            layer and, when ``self.use_sparse`` is set, also stored as the
            SparseTensor values.

    Returns:
        ``self.project`` applied to the output of the last layer.
    """
    x = self.bn(x)

    if self.use_sparse:
        num_nodes = x.shape[0]
        adj = SparseTensor(row=edge_index[0], col=edge_index[1],
                           value=edge_weight,
                           sparse_sizes=(num_nodes, num_nodes))
        # The transpose is loop-invariant: build it once rather than on
        # every layer.
        graph = adj.t()
    else:
        graph = edge_index

    for layer in self.gcn_layer:
        x = layer(x=x, edge_index=graph, edge_weight=edge_weight)
        x = F.relu(x, inplace=True)

    return self.project(x)
def forward(self, x, edge_index, edge_weight=None, batch=None):
    """Cluster, score, select and coarsen the input graph.

    Each node's neighbourhood (including its self-loop) acts as a cluster.
    A per-edge attention score weights how much each member contributes to
    the cluster representation. Clusters are then scored with
    ``self.gnn_score``, the top ones are kept via ``topk``, and the
    adjacency is coarsened as ``S^T A S``.

    Args:
        x: Node features; a 1-D input is treated as a single feature column.
        edge_index: Connectivity; messages are gathered from
            ``edge_index[0]`` and aggregated into ``edge_index[1]``.
        edge_weight: Optional per-edge weights. Missing self-loops are added
            with weight 1.
        batch: Optional graph-assignment vector; defaults to all zeros
            (a single graph).

    Returns:
        Tuple ``(x, edge_index, edge_weight, batch, perm)`` describing the
        pooled graph, where ``perm`` indexes the kept clusters.
    """
    N = x.size(0)

    # Guarantee every node has a self-loop so it belongs to its own cluster
    # (and appears in edge_index[1], which the scatters below rely on).
    edge_index, edge_weight = add_remaining_self_loops(
        edge_index, edge_weight, fill_value=1, num_nodes=N)

    if batch is None:
        batch = edge_index.new_zeros(x.size(0))

    x = x.unsqueeze(-1) if x.dim() == 1 else x

    # Optional intra-cluster GNN to refine features before attention.
    x_pool = x
    if self.GNN is not None:
        x_pool = self.gnn_intra_cluster(x=x, edge_index=edge_index,
                                        edge_weight=edge_weight)

    # Cluster query: max over each cluster's members, linearly transformed
    # and broadcast back to every edge of that cluster.
    x_pool_j = x_pool[edge_index[0]]
    x_q = scatter(x_pool_j, edge_index[1], dim=0, reduce='max')
    x_q = self.lin(x_q)[edge_index[1]]

    # Attention score per (cluster, member) edge, normalized per cluster.
    score = self.att(torch.cat([x_q, x_pool_j], dim=-1)).view(-1)
    score = F.leaky_relu(score, self.negative_slope)
    score = softmax(score, edge_index[1], num_nodes=N)

    # Sample attention coefficients stochastically.
    score = F.dropout(score, p=self.dropout, training=self.training)

    # Cluster representation: attention-weighted sum of member features.
    v_j = x[edge_index[0]] * score.view(-1, 1)
    x = scatter(v_j, edge_index[1], dim=0, reduce='add')

    # Cluster selection.
    # NOTE(review): topk is defined elsewhere; presumably keeps the top
    # `self.ratio` clusters per graph in `batch` — confirm.
    fitness = self.gnn_score(x, edge_index).sigmoid().view(-1)
    perm = topk(fitness, self.ratio, batch)
    x = x[perm] * fitness[perm].view(-1, 1)
    batch = batch[perm]

    # Graph coarsening.
    # A: weighted adjacency; S: node-to-cluster assignment (attention
    # scores), restricted to the selected cluster columns.
    row, col = edge_index
    A = SparseTensor(row=row, col=col, value=edge_weight,
                     sparse_sizes=(N, N))
    S = SparseTensor(row=row, col=col, value=score, sparse_sizes=(N, N))
    S = S[:, perm]

    A = S.t() @ A @ S

    if self.add_self_loops:
        A = A.fill_diag(1.)
    else:
        A = A.remove_diag()

    # Convert the coarsened SparseTensor back to edge_index / edge_weight.
    row, col, edge_weight = A.coo()
    edge_index = torch.stack([row, col], dim=0)

    return x, edge_index, edge_weight, batch, perm
def test_gen_conv(aggr):
    """GENConv with the given aggregation. All call paths agree exactly,
    with and without edge features, on homogeneous and bipartite graphs:
    edge_index with and without an explicit size, SparseTensor, and
    TorchScript."""
    x_src = torch.randn(4, 16)
    x_dst = torch.randn(2, 16)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    src, dst = edge_index
    edge_feat = torch.randn(src.size(0), 16)
    adj_plain = SparseTensor(row=src, col=dst, sparse_sizes=(4, 4))
    adj_feat = SparseTensor(row=src, col=dst, value=edge_feat,
                            sparse_sizes=(4, 4))
    adj_plain_t, adj_feat_t = adj_plain.t(), adj_feat.t()

    conv = GENConv(16, 32, aggr)
    assert repr(conv) == f'GENConv(16, 32, aggr={aggr})'

    # Homogeneous graph, without (out11) and with (out12) edge features.
    out11 = conv(x_src, edge_index)
    assert out11.size() == (4, 32)
    assert conv(x_src, edge_index, size=(4, 4)).tolist() == out11.tolist()
    assert conv(x_src, adj_plain_t).tolist() == out11.tolist()

    out12 = conv(x_src, edge_index, edge_feat)
    assert out12.size() == (4, 32)
    assert conv(x_src, edge_index, edge_feat,
                (4, 4)).tolist() == out12.tolist()
    assert conv(x_src, adj_feat_t).tolist() == out12.tolist()

    scripted = torch.jit.script(
        conv.jittable('(Tensor, Tensor, OptTensor, Size) -> Tensor'))
    assert scripted(x_src, edge_index).tolist() == out11.tolist()
    assert scripted(x_src, edge_index,
                    size=(4, 4)).tolist() == out11.tolist()
    assert scripted(x_src, edge_index, edge_feat).tolist() == out12.tolist()
    assert scripted(x_src, edge_index, edge_feat,
                    size=(4, 4)).tolist() == out12.tolist()

    scripted = torch.jit.script(
        conv.jittable('(Tensor, SparseTensor, OptTensor, Size) -> Tensor'))
    assert scripted(x_src, adj_plain_t).tolist() == out11.tolist()
    assert scripted(x_src, adj_feat_t).tolist() == out12.tolist()

    # Bipartite graph: 4 source nodes -> 2 target nodes.
    adj_plain_t = adj_plain.sparse_resize((4, 2)).t()
    adj_feat_t = adj_feat.sparse_resize((4, 2)).t()
    pair = (x_src, x_dst)

    out21 = conv(pair, edge_index)
    assert out21.size() == (2, 32)
    assert conv(pair, edge_index, size=(4, 2)).tolist() == out21.tolist()
    assert conv(pair, adj_plain_t).tolist() == out21.tolist()

    out22 = conv(pair, edge_index, edge_feat)
    assert out22.size() == (2, 32)
    assert conv(pair, edge_index, edge_feat,
                (4, 2)).tolist() == out22.tolist()
    assert conv(pair, adj_feat_t).tolist() == out22.tolist()

    scripted = torch.jit.script(
        conv.jittable('(OptPairTensor, Tensor, OptTensor, Size) -> Tensor'))
    assert scripted(pair, edge_index).tolist() == out21.tolist()
    assert scripted(pair, edge_index, size=(4, 2)).tolist() == out21.tolist()
    assert scripted(pair, edge_index, edge_feat).tolist() == out22.tolist()
    assert scripted(pair, edge_index, edge_feat,
                    (4, 2)).tolist() == out22.tolist()

    scripted = torch.jit.script(conv.jittable(
        '(OptPairTensor, SparseTensor, OptTensor, Size) -> Tensor'))
    assert scripted(pair, adj_plain_t).tolist() == out21.tolist()
    assert scripted(pair, adj_feat_t).tolist() == out22.tolist()
def test_my_default_arg_conv():
    """MyDefaultArgConv yields all-zero output for both edge_index and
    SparseTensor inputs."""
    x = torch.randn(4, 1)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    src, dst = edge_index
    adj_t = SparseTensor(row=src, col=dst, sparse_sizes=(4, 4)).t()

    conv = MyDefaultArgConv()
    zeros = [0, 0, 0, 0]
    for graph in (edge_index, adj_t):
        assert conv(x, graph).view(-1).tolist() == zeros
def test_agnn_conv(requires_grad):
    """AGNNConv (parametrized by ``requires_grad``): dense, sparse and
    TorchScript paths agree."""
    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
    src, dst = edge_index
    adj_t = SparseTensor(row=src, col=dst, sparse_sizes=(4, 4)).t()

    conv = AGNNConv(requires_grad=requires_grad)
    assert repr(conv) == 'AGNNConv()'
    expected = conv(x, edge_index)
    assert expected.size() == (4, 16)
    assert torch.allclose(conv(x, adj_t), expected, atol=1e-6)

    scripted = torch.jit.script(conv.jittable('(Tensor, Tensor) -> Tensor'))
    assert scripted(x, edge_index).tolist() == expected.tolist()
    scripted = torch.jit.script(
        conv.jittable('(Tensor, SparseTensor) -> Tensor'))
    assert torch.allclose(scripted(x, adj_t), expected, atol=1e-6)
def test_cluster_gcn_conv():
    """ClusterGCNConv: dense, sparse and TorchScript paths agree."""
    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    src, dst = edge_index
    adj_t = SparseTensor(row=src, col=dst, sparse_sizes=(4, 4)).t()

    conv = ClusterGCNConv(16, 32, diag_lambda=1.)
    assert repr(conv) == 'ClusterGCNConv(16, 32, diag_lambda=1.0)'
    expected = conv(x, edge_index)
    assert expected.size() == (4, 32)
    assert torch.allclose(conv(x, adj_t), expected, atol=1e-5)

    scripted = torch.jit.script(conv.jittable('(Tensor, Tensor) -> Tensor'))
    assert scripted(x, edge_index).tolist() == expected.tolist()
    scripted = torch.jit.script(
        conv.jittable('(Tensor, SparseTensor) -> Tensor'))
    assert torch.allclose(scripted(x, adj_t), expected, atol=1e-5)
def test_arma_conv():
    """ARMAConv: edge_index, SparseTensor and TorchScript paths agree."""
    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
    row, col = edge_index
    adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))

    conv = ARMAConv(16, 32, num_stacks=8, num_layers=4)
    assert conv.__repr__() == 'ARMAConv(16, 32, num_stacks=8, num_layers=4)'
    out = conv(x, edge_index)
    assert out.size() == (4, 32)
    assert conv(x, adj.t()).tolist() == out.tolist()

    t = '(Tensor, Tensor, OptTensor) -> Tensor'
    jit = torch.jit.script(conv.jittable(t))
    assert jit(x, edge_index).tolist() == out.tolist()

    t = '(Tensor, SparseTensor, OptTensor) -> Tensor'
    jit = torch.jit.script(conv.jittable(t))
    # Exercise the scripted module. The previous check called the eager
    # `conv`, so the scripted SparseTensor path was never tested.
    assert torch.allclose(jit(x, adj.t()), out, atol=1e-6)
def test_appnp():
    """APPNP propagation: dense, sparse and TorchScript paths agree."""
    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
    src, dst = edge_index
    adj_t = SparseTensor(row=src, col=dst, sparse_sizes=(4, 4)).t()

    conv = APPNP(K=10, alpha=0.1)
    assert repr(conv) == 'APPNP(K=10, alpha=0.1)'
    expected = conv(x, edge_index)
    assert expected.size() == (4, 16)
    assert torch.allclose(conv(x, adj_t), expected, atol=1e-6)

    scripted = torch.jit.script(
        conv.jittable('(Tensor, Tensor, OptTensor) -> Tensor'))
    assert scripted(x, edge_index).tolist() == expected.tolist()
    scripted = torch.jit.script(
        conv.jittable('(Tensor, SparseTensor, OptTensor) -> Tensor'))
    assert torch.allclose(scripted(x, adj_t), expected, atol=1e-6)
def test_message_passing_with_aggr_module(aggr_module):
    """A conv built with an aggregation module exposes it as an
    ``aggr.Aggregation`` and gives matching dense/sparse results."""
    x = torch.randn(4, 8)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    src, dst = edge_index
    adj_t = SparseTensor(row=src, col=dst, sparse_sizes=(4, 4)).t()

    conv = MyAggregatorConv(aggr=aggr_module)
    assert isinstance(conv.aggr_module, aggr.Aggregation)

    result = conv(x, edge_index)
    # Width may be 8 or 16 depending on the aggregation module.
    assert result.size(0) == 4 and result.size(1) in {8, 16}
    assert torch.allclose(conv(x, adj_t), result)
def test_pdn_conv():
    """PDNConv with 8-dim edge features: dense, sparse and TorchScript paths
    agree."""
    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
    src, dst = edge_index
    edge_attr = torch.randn(6, 8)
    adj_t = SparseTensor(row=src, col=dst, value=edge_attr,
                         sparse_sizes=(4, 4)).t()

    conv = PDNConv(16, 32, edge_dim=8, hidden_channels=128)
    assert str(conv) == "PDNConv(16, 32)"
    expected = conv(x, edge_index, edge_attr)
    assert expected.size() == (4, 32)
    assert torch.allclose(conv(x, adj_t), expected, atol=1e-6)

    scripted = torch.jit.script(
        conv.jittable('(Tensor, Tensor, OptTensor) -> Tensor'))
    assert torch.allclose(scripted(x, edge_index, edge_attr), expected)
    scripted = torch.jit.script(
        conv.jittable('(Tensor, SparseTensor, OptTensor) -> Tensor'))
    assert torch.allclose(scripted(x, adj_t), expected, atol=1e-6)
def test_my_edge_conv():
    """MyEdgeConv sums ``x[row] - x[col]`` into each ``col`` node; checked
    against a reference scatter for dense, sparse and TorchScript calls."""
    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    src, dst = edge_index
    adj_t = SparseTensor(row=src, col=dst, sparse_sizes=(4, 4)).t()

    reference = scatter(x[src] - x[dst], dst, dim=0, dim_size=4,
                        reduce='add')

    conv = MyEdgeConv()
    result = conv(x, edge_index)
    assert result.size() == (4, 16)
    assert torch.allclose(result, reference)
    assert torch.allclose(conv(x, adj_t), result)

    scripted = torch.jit.script(conv.jittable('(Tensor, Tensor) -> Tensor'))
    assert torch.allclose(scripted(x, edge_index), reference)
    scripted = torch.jit.script(
        conv.jittable('(Tensor, SparseTensor) -> Tensor'))
    assert torch.allclose(scripted(x, adj_t), reference)
def test_my_conv():
    """MyConv: explicit size, SparseTensor, and ``fuse=False`` all reproduce
    the edge_index result, for homogeneous and bipartite inputs."""
    x_src = torch.randn(4, 8)
    x_dst = torch.randn(2, 16)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    src, dst = edge_index
    weight = torch.randn(src.size(0))
    adj = SparseTensor(row=src, col=dst, value=weight, sparse_sizes=(4, 4))
    adj_t = adj.t()

    conv = MyConv(8, 32)
    expected = conv(x_src, edge_index, weight)
    assert expected.size() == (4, 32)
    assert conv(x_src, edge_index, weight,
                (4, 4)).tolist() == expected.tolist()
    assert conv(x_src, adj_t).tolist() == expected.tolist()
    # Repeat the sparse call with `fuse` disabled, then restore it.
    conv.fuse = False
    assert conv(x_src, adj_t).tolist() == expected.tolist()
    conv.fuse = True

    # Bipartite graph: 4 source nodes -> 2 target nodes.
    adj_t = adj.sparse_resize((4, 2)).t()
    conv = MyConv((8, 16), 32)
    with_dst = conv((x_src, x_dst), edge_index, weight)
    without_dst = conv((x_src, None), edge_index, weight, (4, 2))
    assert with_dst.size() == (2, 32)
    assert without_dst.size() == (2, 32)
    assert conv((x_src, x_dst), edge_index, weight,
                (4, 2)).tolist() == with_dst.tolist()
    assert conv((x_src, x_dst), adj_t).tolist() == with_dst.tolist()
    assert conv((x_src, None), adj_t).tolist() == without_dst.tolist()
    conv.fuse = False
    assert conv((x_src, x_dst), adj_t).tolist() == with_dst.tolist()
    assert conv((x_src, None), adj_t).tolist() == without_dst.tolist()
    conv.fuse = True
def test_cg_conv_with_edge_features():
    """CGConv with 3-dim edge features: edge_index (with and without size),
    SparseTensor and TorchScript agree exactly, homogeneous and bipartite."""
    x_src = torch.randn(4, 8)
    x_dst = torch.randn(2, 16)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    src, dst = edge_index
    edge_feat = torch.rand(src.size(0), 3)
    adj = SparseTensor(row=src, col=dst, value=edge_feat,
                       sparse_sizes=(4, 4))
    adj_t = adj.t()

    conv = CGConv(8, dim=3)
    assert repr(conv) == 'CGConv(8, dim=3)'
    expected = conv(x_src, edge_index, edge_feat)
    assert expected.size() == (4, 8)
    assert conv(x_src, edge_index, edge_feat,
                size=(4, 4)).tolist() == expected.tolist()
    assert conv(x_src, adj_t).tolist() == expected.tolist()

    scripted = torch.jit.script(
        conv.jittable('(Tensor, Tensor, OptTensor, Size) -> Tensor'))
    assert scripted(x_src, edge_index,
                    edge_feat).tolist() == expected.tolist()
    assert scripted(x_src, edge_index, edge_feat,
                    size=(4, 4)).tolist() == expected.tolist()
    scripted = torch.jit.script(
        conv.jittable('(Tensor, SparseTensor, OptTensor, Size) -> Tensor'))
    assert scripted(x_src, adj_t).tolist() == expected.tolist()

    # Bipartite graph: 4 source nodes -> 2 target nodes.
    adj_t = adj.sparse_resize((4, 2)).t()
    conv = CGConv((8, 16), dim=3)
    assert repr(conv) == 'CGConv((8, 16), dim=3)'
    pair = (x_src, x_dst)
    expected = conv(pair, edge_index, edge_feat)
    assert expected.size() == (2, 16)
    assert conv(pair, edge_index, edge_feat,
                (4, 2)).tolist() == expected.tolist()
    assert conv(pair, adj_t).tolist() == expected.tolist()

    scripted = torch.jit.script(
        conv.jittable('(PairTensor, Tensor, OptTensor, Size) -> Tensor'))
    assert scripted(pair, edge_index, edge_feat).tolist() == expected.tolist()
    assert scripted(pair, edge_index, edge_feat,
                    (4, 2)).tolist() == expected.tolist()
    scripted = torch.jit.script(conv.jittable(
        '(PairTensor, SparseTensor, OptTensor, Size) -> Tensor'))
    assert scripted(pair, adj_t).tolist() == expected.tolist()
def test_nn_conv():
    """NNConv: repr, and exact agreement between edge_index, SparseTensor and
    (in full-test mode) TorchScript paths, homogeneous and bipartite."""
    x1 = torch.randn(4, 8)
    x2 = torch.randn(2, 16)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    row, col = edge_index
    value = torch.rand(row.size(0), 3)
    adj = SparseTensor(row=row, col=col, value=value, sparse_sizes=(4, 4))

    # Edge network: 3-dim edge features -> 8 * 32 outputs
    # (in_channels * out_channels). Named `edge_nn` to avoid shadowing the
    # conventional `torch.nn` alias.
    edge_nn = Seq(Lin(3, 32), ReLU(), Lin(32, 8 * 32))
    conv = NNConv(8, 32, nn=edge_nn)
    # torch.nn.Module.__repr__ indents child modules by two spaces.
    assert conv.__repr__() == (
        'NNConv(8, 32, aggr=add, nn=Sequential(\n'
        '  (0): Linear(in_features=3, out_features=32, bias=True)\n'
        '  (1): ReLU()\n'
        '  (2): Linear(in_features=32, out_features=256, bias=True)\n'
        '))')
    out = conv(x1, edge_index, value)
    assert out.size() == (4, 32)
    assert conv(x1, edge_index, value, size=(4, 4)).tolist() == out.tolist()
    assert conv(x1, adj.t()).tolist() == out.tolist()

    if is_full_test():
        t = '(Tensor, Tensor, OptTensor, Size) -> Tensor'
        jit = torch.jit.script(conv.jittable(t))
        assert jit(x1, edge_index, value).tolist() == out.tolist()
        assert jit(x1, edge_index, value,
                   size=(4, 4)).tolist() == out.tolist()

        t = '(Tensor, SparseTensor, OptTensor, Size) -> Tensor'
        jit = torch.jit.script(conv.jittable(t))
        assert jit(x1, adj.t()).tolist() == out.tolist()

    # Bipartite graph: 4 source nodes -> 2 target nodes.
    adj = adj.sparse_resize((4, 2))
    conv = NNConv((8, 16), 32, nn=edge_nn)
    assert conv.__repr__() == (
        'NNConv((8, 16), 32, aggr=add, nn=Sequential(\n'
        '  (0): Linear(in_features=3, out_features=32, bias=True)\n'
        '  (1): ReLU()\n'
        '  (2): Linear(in_features=32, out_features=256, bias=True)\n'
        '))')
    out1 = conv((x1, x2), edge_index, value)
    out2 = conv((x1, None), edge_index, value, (4, 2))
    assert out1.size() == (2, 32)
    assert out2.size() == (2, 32)
    assert conv((x1, x2), edge_index, value,
                (4, 2)).tolist() == out1.tolist()
    assert conv((x1, x2), adj.t()).tolist() == out1.tolist()
    assert conv((x1, None), adj.t()).tolist() == out2.tolist()

    if is_full_test():
        t = '(OptPairTensor, Tensor, OptTensor, Size) -> Tensor'
        jit = torch.jit.script(conv.jittable(t))
        assert jit((x1, x2), edge_index, value).tolist() == out1.tolist()
        assert jit((x1, x2), edge_index, value,
                   size=(4, 2)).tolist() == out1.tolist()
        assert jit((x1, None), edge_index, value,
                   size=(4, 2)).tolist() == out2.tolist()

        t = '(OptPairTensor, SparseTensor, OptTensor, Size) -> Tensor'
        jit = torch.jit.script(conv.jittable(t))
        assert jit((x1, x2), adj.t()).tolist() == out1.tolist()
        assert jit((x1, None), adj.t()).tolist() == out2.tolist()