def test_gcn_dense(self, device, gcn, adj, node_feat_in, expected):
    """Check the GCN layer when it is fed a dense adjacency matrix.

    ``adj`` is densified and the layer is run twice: once on the raw
    dense matrix with ``normalize_adj=True`` and once with
    ``normalize_adj=False`` on a matrix that was passed through
    ``normalize_adj`` beforehand.  The first output must match
    ``expected`` and both outputs must match each other.
    """
    # Densify by multiplying with an identity whose size comes from
    # ``adj`` itself instead of a hard-coded node count.
    # NOTE(review): presumes ``adj`` is square -- confirm against fixture.
    num_nodes = adj.shape[0]
    dense_adj = torch.sparse.mm(adj, torch.eye(num_nodes, device=device))

    # Path 1: the layer is asked to normalize the adjacency itself.
    node_feat_out = gcn(node_feat_in, dense_adj, normalize_adj=True)
    assert torch.allclose(node_feat_out, expected)

    # Path 2: normalize up front, then tell the layer to skip that step.
    dense_adj = normalize_adj(dense_adj)
    node_feat_out_2 = gcn(node_feat_in, dense_adj, normalize_adj=False)
    assert torch.allclose(node_feat_out, node_feat_out_2)
def test_normalize_adj_sparse(self, device, adj):
    """Check ``normalize_adj`` on the ``adj`` fixture without densifying it first.

    The reference is built by hand: the densified ``adj`` is divided
    row-wise by its row sums.  The result of ``normalize_adj`` is
    densified the same way and compared against that reference.
    """
    result = normalize_adj(adj)

    # Size every helper tensor from ``adj`` rather than hard-coding the
    # node count, so the identity agrees with the row-sum vector below.
    # NOTE(review): presumes ``adj`` is square -- confirm against fixture.
    num_nodes = adj.shape[0]
    identity = torch.eye(num_nodes, device=device)

    # Multiplying by a column of ones yields the per-row sums of ``adj``.
    norm = torch.sparse.mm(adj, torch.ones((num_nodes, 1), device=device))
    # Multiplying by the identity densifies; dividing broadcasts per row.
    expected = torch.sparse.mm(adj, identity) / norm

    assert torch.allclose(torch.sparse.mm(result, identity), expected)
def test_gcn_sparse(self, device, gcn, adj, node_feat_in, expected):
    """Check the GCN layer on the ``adj`` fixture as-is (no densification).

    Two things are verified:

    * with ``normalize_adj=True`` the layer output is close to
      ``expected`` (tolerance 1e-3);
    * running ``normalize_adj`` beforehand and calling the layer with
      ``normalize_adj=False`` gives an output close to the first one
      (tolerance 1e-4).
    """
    # Layer performs the normalization internally.
    out_internal_norm = gcn(node_feat_in, adj, normalize_adj=True)
    matches_reference = torch.allclose(
        out_internal_norm, expected, rtol=1e-3, atol=1e-3
    )
    assert matches_reference

    # Normalization applied by the caller; layer told to skip it.
    pre_normalized_adj = normalize_adj(adj)
    out_external_norm = gcn(
        node_feat_in, pre_normalized_adj, normalize_adj=False
    )
    assert torch.allclose(
        out_internal_norm, out_external_norm, rtol=1e-4, atol=1e-4
    )
def test_normalize_adj_dense(self, device, adj):
    """Check ``normalize_adj`` on a densified copy of ``adj``.

    The outcome is compared against a fixed 3x3 reference matrix whose
    rows each sum to one.
    """
    # Hand-written reference values for the 3x3 case.
    reference_rows = [
        [0.0, 0.14285714, 0.85714285],
        [0.4, 0.0, 0.6],
        [0.55555555, 0.44444444, 0.0],
    ]
    expected = torch.tensor(reference_rows, device=device)

    # Multiplying by the identity turns ``adj`` into a dense tensor.
    identity = torch.eye(3, device=device)
    adj_as_dense = torch.sparse.mm(adj, identity)

    normalized = normalize_adj(adj_as_dense)
    assert torch.allclose(normalized, expected)