def test_hetero_graph_split(self):
    """Check split sizes of a rebuilt HeteroGraph for every split task.

    A directed and then an undirected dense heterogeneous graph are each
    wrapped in ``HeteroGraph`` and rebuilt from their node/edge features,
    labels and ``edge_index``. For each rebuilt graph this splits by node,
    by node restricted to ``split_types``, by edge, by edge restricted to
    ``split_types``, and for ``link_pred`` with ratio ``[0.5, 0.3, 0.2]``.
    It then asserts the size of every resulting split for every type.

    Node and edge expectations assume ``split()`` defaults to a
    0.8 / 0.1 / 0.1 ratio. The first two shares are truncated with ``int``
    and the third split receives the remainder. Types left out of
    ``split_types`` are expected to keep their full label index in all
    three splits.
    """
    node_split_types = ['n1']
    edge_split_types = [('n1', 'e1', 'n1'), ('n1', 'e2', 'n2')]

    def expected_sizes(total, first_ratio, second_ratio):
        # Truncate the first two shares; the last split takes what is left.
        first = int(total * first_ratio)
        second = int(total * second_ratio)
        return first, second, total - first - second

    def rebuild(source, **extra_kwargs):
        # Construct a fresh HeteroGraph from the tensors held by `source`.
        return HeteroGraph(
            node_feature=source.node_feature,
            node_label=source.node_label,
            edge_feature=source.edge_feature,
            edge_label=source.edge_label,
            edge_index=source.edge_index,
            **extra_kwargs,
        )

    def check_node_split(graph, split_types=None):
        if split_types is None:
            splits = graph.split()
        else:
            splits = graph.split(split_types=split_types)
        for node_type in graph.node_label_index:
            total = len(graph.node_label_index[node_type])
            if split_types is None or node_type in split_types:
                sizes = expected_sizes(total, 0.8, 0.1)
            else:
                # Unselected types are expected to stay whole in each split.
                sizes = (total, total, total)
            for position, size in enumerate(sizes):
                self.assertEqual(
                    len(splits[position].node_label_index[node_type]),
                    size,
                )

    def check_edge_split(graph, count_edges, split_types=None):
        if split_types is None:
            splits = graph.split(task='edge')
        else:
            splits = graph.split(task='edge', split_types=split_types)
        for edge_type in graph.edge_label_index:
            if split_types is None or edge_type in split_types:
                sizes = expected_sizes(
                    int(count_edges(graph, edge_type)), 0.8, 0.1)
            else:
                # Unselected types are expected to stay whole in each split.
                full = graph.edge_label_index[edge_type].shape[1]
                sizes = (full, full, full)
            for position, size in enumerate(sizes):
                self.assertEqual(
                    splits[position].edge_label_index[edge_type].shape[1],
                    size,
                )

    def check_link_pred_split(graph, undirected):
        splits = graph.split(task='link_pred', split_ratio=[0.5, 0.3, 0.2])
        for key, val in graph.edge_label_index.items():
            if undirected:
                # Undirected expectations are computed on half of the label
                # columns and then doubled (presumably each edge is stored in
                # both directions -- TODO confirm against HeteroGraph).
                sizes = tuple(
                    2 * size
                    for size in expected_sizes(int(val.shape[1] / 2), 0.5, 0.3)
                )
            else:
                sizes = expected_sizes(val.shape[1], 0.5, 0.3)
            for position, size in enumerate(sizes):
                self.assertEqual(splits[position].edge_label[key].shape[0], size)

    def run_checks(graph, count_edges, undirected):
        check_node_split(graph)
        check_node_split(graph, node_split_types)
        check_edge_split(graph, count_edges)
        check_edge_split(graph, count_edges, edge_split_types)
        check_link_pred_split(graph, undirected)

    # Directed graph: per-type edge totals are read from edge_label_index.
    directed_graph = rebuild(HeteroGraph(generate_dense_hete_graph()))
    run_checks(
        directed_graph,
        lambda g, edge_type: g.edge_label_index[edge_type].shape[1],
        undirected=False,
    )

    # Undirected graph: per-type edge totals come from HeteroGraph.num_edges.
    undirected_graph = rebuild(
        HeteroGraph(generate_dense_hete_graph(directed=False)),
        directed=False,
    )
    run_checks(
        undirected_graph,
        lambda g, edge_type: g.num_edges(edge_type),
        undirected=True,
    )
def test_hetero_graph_split(self):
    """Check HeteroGraph.split sizes on a directed dense heterogeneous graph.

    The graph is split by node, by node restricted to ``split_types``, by
    edge, by edge restricted to ``split_types``, and for ``link_pred`` with
    ratio ``[0.5, 0.3, 0.2]``. The size of every resulting split is asserted
    per node/edge type.

    Expected sizes for the first two splits are
    ``1 + int((n - 3) * ratio)``, and the third split takes the remainder.
    Node and edge expectations assume ``split()`` defaults to
    0.8 / 0.1 / 0.1. Types left out of ``split_types`` are expected to keep
    their full label index in all three splits.
    """
    node_split_types = ['n1']
    edge_split_types = [('n1', 'e1', 'n1'), ('n1', 'e2', 'n2')]

    def expected_sizes(total, first_ratio, second_ratio):
        # Share total - 3 by ratio, add one to each of the first two splits,
        # and give the last split whatever remains.
        reduced = total - 3
        first = 1 + int(reduced * first_ratio)
        second = 1 + int(reduced * second_ratio)
        return first, second, total - first - second

    def check_node_split(graph, split_types=None):
        if split_types is None:
            splits = graph.split()
        else:
            splits = graph.split(split_types=split_types)
        for node_type in graph.node_label_index:
            total = len(graph.node_label_index[node_type])
            if split_types is None or node_type in split_types:
                sizes = expected_sizes(total, 0.8, 0.1)
            else:
                # Unselected types are expected to stay whole in each split.
                sizes = (total, total, total)
            for position, size in enumerate(sizes):
                self.assertEqual(
                    len(splits[position].node_label_index[node_type]), size)

    def check_edge_split(graph, split_types=None):
        if split_types is None:
            splits = graph.split(task='edge')
        else:
            splits = graph.split(task='edge', split_types=split_types)
        for edge_type in graph.edge_label_index:
            full = graph.edge_label_index[edge_type].shape[1]
            if split_types is None or edge_type in split_types:
                sizes = expected_sizes(int(full), 0.8, 0.1)
            else:
                # Unselected types are expected to stay whole in each split.
                sizes = (full, full, full)
            for position, size in enumerate(sizes):
                self.assertEqual(
                    splits[position].edge_label_index[edge_type].shape[1],
                    size)

    def check_link_pred_split(graph):
        splits = graph.split(task='link_pred', split_ratio=[0.5, 0.3, 0.2])
        # edge_label is indexed per message type, so compare sizes per type;
        # len() of the whole mapping would count message types, not edges.
        for key, val in graph.edge_label_index.items():
            sizes = expected_sizes(val.shape[1], 0.5, 0.3)
            for position, size in enumerate(sizes):
                self.assertEqual(
                    splits[position].edge_label[key].shape[0], size)

    hete = HeteroGraph(generate_dense_hete_graph())
    check_node_split(hete)
    check_node_split(hete, node_split_types)
    check_edge_split(hete)
    check_edge_split(hete, edge_split_types)
    check_link_pred_split(hete)
def test_hetero_graph_split(self):
    """Check HeteroGraph.split sizes on directed and undirected dense graphs.

    For a directed graph and then an undirected one, this splits by node, by
    node restricted to ``split_types``, by edge, by edge restricted to
    ``split_types``, and for ``link_pred`` with ratio ``[0.5, 0.3, 0.2]``.
    The size of each resulting split is asserted per type.

    Expected sizes for the first two splits are
    ``1 + int((n - 3) * ratio)``, and the third split takes the remainder.
    Node and edge expectations assume ``split()`` defaults to
    0.8 / 0.1 / 0.1. Types left out of ``split_types`` are expected to keep
    their full label index in all three splits.
    """
    node_split_types = ['n1']
    edge_split_types = [('n1', 'e1', 'n1'), ('n1', 'e2', 'n2')]

    def expected_sizes(total, first_ratio, second_ratio):
        # Share total - 3 by ratio, add one to each of the first two splits,
        # and give the last split whatever remains.
        reduced = total - 3
        first = 1 + int(reduced * first_ratio)
        second = 1 + int(reduced * second_ratio)
        return first, second, total - first - second

    def check_node_split(graph, split_types=None):
        if split_types is None:
            splits = graph.split()
        else:
            splits = graph.split(split_types=split_types)
        for node_type in graph.node_label_index:
            total = len(graph.node_label_index[node_type])
            if split_types is None or node_type in split_types:
                sizes = expected_sizes(total, 0.8, 0.1)
            else:
                # Unselected types are expected to stay whole in each split.
                sizes = (total, total, total)
            for position, size in enumerate(sizes):
                self.assertEqual(
                    len(splits[position].node_label_index[node_type]),
                    size,
                )

    def check_edge_split(graph, split_types=None):
        if split_types is None:
            splits = graph.split(task='edge')
        else:
            splits = graph.split(task='edge', split_types=split_types)
        for edge_type in graph.edge_label_index:
            full = graph.edge_label_index[edge_type].shape[1]
            if split_types is None or edge_type in split_types:
                sizes = expected_sizes(int(full), 0.8, 0.1)
            else:
                # Unselected types are expected to stay whole in each split.
                sizes = (full, full, full)
            for position, size in enumerate(sizes):
                self.assertEqual(
                    splits[position].edge_label_index[edge_type].shape[1],
                    size,
                )

    def check_link_pred_split(graph, undirected):
        splits = graph.split(task='link_pred', split_ratio=[0.5, 0.3, 0.2])
        for key, val in graph.edge_label_index.items():
            total = val.shape[1]
            if undirected:
                # Train/val sizes are computed on half of the label columns
                # and doubled; test takes every remaining column (presumably
                # both edge directions are stored -- TODO confirm).
                half_train, half_val, _ = expected_sizes(
                    int(total / 2), 0.5, 0.3)
                train_size = 2 * half_train
                val_size = 2 * half_val
                sizes = (train_size, val_size, total - train_size - val_size)
            else:
                sizes = expected_sizes(total, 0.5, 0.3)
            for position, size in enumerate(sizes):
                self.assertEqual(splits[position].edge_label[key].shape[0], size)

    def run_checks(graph, undirected):
        check_node_split(graph)
        check_node_split(graph, node_split_types)
        check_edge_split(graph)
        check_edge_split(graph, edge_split_types)
        check_link_pred_split(graph, undirected)

    # Directed graph.
    run_checks(HeteroGraph(generate_dense_hete_graph()), undirected=False)

    # Undirected graph.
    run_checks(
        HeteroGraph(generate_dense_hete_graph(directed=False)),
        undirected=True,
    )