def test_to_homogeneous():
    """Check the output of ``HeteroData.to_homogeneous()``.

    Two node types (100 ``paper`` + 200 ``author`` nodes, 128 features each)
    and three edge types (250 + 500 + 1000 edges, each with an
    ``edge_weight`` and a 64-dimensional ``edge_attr``) are flattened into a
    single homogeneous graph.
    """
    data = HeteroData()
    data['paper'].x = torch.randn(100, 128)
    data['author'].x = torch.randn(200, 128)
    data['paper', 'paper'].edge_index = get_edge_index(100, 100, 250)
    data['paper', 'paper'].edge_weight = torch.randn(250, )
    data['paper', 'paper'].edge_attr = torch.randn(250, 64)
    data['paper', 'author'].edge_index = get_edge_index(100, 200, 500)
    data['paper', 'author'].edge_weight = torch.randn(500, )
    data['paper', 'author'].edge_attr = torch.randn(500, 64)
    data['author', 'paper'].edge_index = get_edge_index(200, 100, 1000)
    data['author', 'paper'].edge_weight = torch.randn(1000, )
    data['author', 'paper'].edge_attr = torch.randn(1000, 64)

    # Default conversion adds both a ``node_type`` and an ``edge_type`` vector,
    # giving six entries: x, edge_index, edge_weight, edge_attr, node_type and
    # edge_type. The private ``_*_type_names`` attributes are not counted.
    out = data.to_homogeneous()
    assert len(out) == 6
    assert out.num_nodes == 300
    assert out.num_edges == 1750
    assert out.num_node_features == 128
    assert out.num_edge_features == 64
    assert out.node_type.size() == (300, )
    assert out.node_type.min() == 0
    assert out.node_type.max() == 1
    assert out.edge_type.size() == (1750, )
    assert out.edge_type.min() == 0
    assert out.edge_type.max() == 2
    # Type bookkeeping is kept as one name per node type / edge type.
    assert len(out._node_type_names) == 2
    assert len(out._edge_type_names) == 3

    # Without ``node_type`` only five entries remain: x, edge_index,
    # edge_weight, edge_attr and edge_type.
    out = data.to_homogeneous(add_node_type=False)
    assert len(out) == 5
    assert out.num_nodes == 300
    assert out.num_edges == 1750
def test_to_homogeneous_and_vice_versa():
    """Round-trip a ``HeteroData`` object through ``to_homogeneous`` and back.

    Three scenarios are covered:

    1. The default round trip, where the original node/edge type names are
       restored by ``to_heterogeneous()``.
    2. A round trip driven only by explicit ``node_type`` / ``edge_type``
       vectors (type-name attributes deleted). Here types come back named
       by their integer index as strings: ``'0'`` and ``'1'``.
    3. Node types that carry only ``num_nodes`` and no feature tensors.
    """
    # (source type, destination type, #source nodes, #destination nodes, #edges)
    edge_specs = [
        ('paper', 'paper', 100, 100, 250),
        ('paper', 'author', 100, 200, 500),
        ('author', 'paper', 200, 100, 1000),
    ]

    data = HeteroData()
    data['paper'].x = torch.randn(100, 128)
    data['author'].x = torch.randn(200, 128)
    for src, dst, num_src, num_dst, num_edges in edge_specs:
        store = data[src, dst]
        store.edge_index = get_edge_index(num_src, num_dst, num_edges)
        store.edge_weight = torch.randn(num_edges, )
        store.edge_attr = torch.randn(num_edges, 64)

    def assert_same_edges(expected, actual):
        # Edge connectivity must match exactly; float payloads approximately.
        assert expected.edge_index.tolist() == actual.edge_index.tolist()
        assert torch.allclose(expected.edge_weight, actual.edge_weight)
        assert torch.allclose(expected.edge_attr, actual.edge_attr)

    # Scenario 1: default conversion and back.
    homo = data.to_homogeneous()
    assert len(homo) == 6
    assert homo.num_nodes == 300
    assert homo.num_edges == 1750
    assert homo.num_node_features == 128
    assert homo.num_edge_features == 64
    assert homo.node_type.size() == (300, )
    assert homo.node_type.min() == 0
    assert homo.node_type.max() == 1
    assert homo.edge_type.size() == (1750, )
    assert homo.edge_type.min() == 0
    assert homo.edge_type.max() == 2
    assert len(homo._node_type_names) == 2
    assert len(homo._edge_type_names) == 3

    hetero = homo.to_heterogeneous()
    assert len(hetero) == 4
    for name in ('paper', 'author'):
        assert torch.allclose(data[name].x, hetero[name].x)
    for src, dst, *_ in edge_specs:
        assert_same_edges(data[src, dst], hetero[src, dst])

    # Scenario 2: strip all type metadata and pass the type vectors explicitly.
    homo = data.to_homogeneous()
    node_type = homo.node_type
    edge_type = homo.edge_type
    del homo.node_type
    del homo.edge_type
    del homo._edge_type_names
    del homo._node_type_names
    hetero = homo.to_heterogeneous(node_type, edge_type)
    assert len(hetero) == 4

    # Without names, types are keyed by their index as a string.
    index_of = {'paper': '0', 'author': '1'}
    for name, index in index_of.items():
        assert torch.allclose(data[name].x, hetero[index].x)
    for src, dst, *_ in edge_specs:
        assert_same_edges(data[src, dst], hetero[index_of[src], index_of[dst]])

    # Scenario 3: node types described only by their node counts.
    bare = HeteroData()
    bare['paper'].num_nodes = 100
    bare['author'].num_nodes = 200

    homo = bare.to_homogeneous(add_node_type=False)
    assert len(homo) == 1
    assert homo.num_nodes == 300

    hetero = bare.to_homogeneous().to_heterogeneous()
    assert len(hetero) == 1
    assert hetero['paper'].num_nodes == 100
    assert hetero['author'].num_nodes == 200