def test_hetero_graph_basics(self):
    """Check per-type counts of a tensor-backed HeteroGraph built from a simple graph."""
    source = HeteroGraph(generate_simple_hete_graph())
    # Rebuild the graph directly from its tensors (directed) so the
    # tensor-based constructor is the one under test.
    hete = HeteroGraph(
        node_feature=source.node_feature,
        node_label=source.node_label,
        edge_feature=source.edge_feature,
        edge_label=source.edge_label,
        edge_index=source.edge_index,
        directed=True,
    )
    e1_type = ('n1', 'e1', 'n1')
    e2_type = ('n1', 'e2', 'n2')

    for node_type, expected in (('n1', 10), ('n2', 12)):
        self.assertEqual(hete.num_node_features(node_type), expected)
    for message_type, expected in ((e1_type, 8), (e2_type, 12)):
        self.assertEqual(hete.num_edge_features(message_type), expected)
    for node_type, expected in (('n1', 4), ('n2', 5)):
        self.assertEqual(hete.num_nodes(node_type), expected)
    self.assertEqual(len(hete.node_types), 2)
    self.assertEqual(len(hete.edge_types), 2)

    message_types = hete.message_types
    self.assertEqual(len(message_types), 7)
    for node_type in ('n1', 'n2'):
        self.assertEqual(hete.num_node_labels(node_type), 2)
    for message_type in (e1_type, e2_type):
        self.assertEqual(hete.num_edge_labels(message_type), 2)
    self.assertEqual(hete.num_edges(message_types[0]), 3)
    self.assertEqual(len(hete.node_label_index), 2)
def test_resample_disjoint_heterogeneous(self):
    """Resampled disjoint train graphs must keep per-type label sizes stable."""
    source = HeteroGraph(generate_dense_hete_dataset())
    hete = HeteroGraph(
        node_feature=source.node_feature,
        node_label=source.node_label,
        edge_feature=source.edge_feature,
        edge_label=source.edge_label,
        edge_index=source.edge_index,
        directed=True,
    )
    dataset = GraphDataset(
        [hete],
        task="link_pred",
        edge_train_mode="disjoint",
        edge_message_ratio=0.8,
        resample_disjoint=True,
        resample_disjoint_period=1,
    )
    dataset_train, _, _ = dataset.split(split_ratio=[0.5, 0.2, 0.3])
    # Two separate accesses to the same index; with resample_disjoint_period=1
    # these are presumably two different resamples — TODO confirm in GraphDataset.
    first = dataset_train[0]
    second = dataset_train[0]
    for message_type in first.edge_index:
        self.assertEqual(
            first.edge_label_index[message_type].shape[1],
            second.edge_label_index[message_type].shape[1],
        )
        self.assertEqual(
            first.edge_label[message_type].shape,
            second.edge_label[message_type].shape,
        )
def test_hetero_graph_batch(self):
    """Batch clones of a tensor-backed HeteroGraph and check batch counts/sizes."""
    num_graphs, batch_size = 30, 3
    source = HeteroGraph(generate_simple_hete_graph())
    hete = HeteroGraph(
        node_feature=source.node_feature,
        node_label=source.node_label,
        edge_feature=source.edge_feature,
        edge_label=source.edge_label,
        edge_index=source.edge_index,
        directed=True,
    )
    graphs = [hete.clone() for _ in range(num_graphs)]
    dataloader = DataLoader(
        graphs,
        collate_fn=Batch.collate(),
        batch_size=batch_size,
        shuffle=True,
    )
    self.assertEqual(len(dataloader), math.ceil(num_graphs / batch_size))
    for data in dataloader:
        # num_graphs is a multiple of batch_size, so every batch is full.
        self.assertEqual(data.num_graphs, batch_size)
def _custom_hete_split_link_pred_disjoint(self, graph_train):
    """Split a training graph into message-passing and objective edges.

    Edges listed in ``graph_train.disjoint_split`` are objective
    (supervision) edges. For each message type that has objective edges,
    the remaining edges of that type are used for message passing. A
    message type with no objective edge contributes all of its edges to
    BOTH the message-passing graph and the objective edges.

    Args:
        graph_train: training ``HeteroGraph``. Its ``G`` nodes carry a
            ``node_type`` attribute and its edges an ``edge_type`` attribute.
            ``disjoint_split`` lists the objective edges; each edge's last
            element is its attribute dict.

    Returns:
        A new ``HeteroGraph`` built on the message-edge subgraph (via
        ``_edge_subgraph_with_isonodes``) with link-prediction labels
        created from the objective edges.
    """
    # Work on a copy: appending to graph_train.disjoint_split itself would
    # make it grow every time this method is called on the same graph.
    objective_edges = list(graph_train.disjoint_split)
    node_types = {
        node: data["node_type"] for node, data in graph_train.G.nodes(data=True)
    }

    def group_by_message_type(edges):
        # Bucket edges by (head node type, edge type, tail node type),
        # keeping first-seen order of message types.
        groups = {}
        for edge in edges:
            edge_type = edge[-1]["edge_type"]
            message_type = (node_types[edge[0]], edge_type, node_types[edge[1]])
            groups.setdefault(message_type, []).append(edge)
        return groups

    edges_by_type = group_by_message_type(graph_train.G.edges(data=True))
    objective_by_type = group_by_message_type(objective_edges)

    message_edges = []
    for message_type, edges in edges_by_type.items():
        if message_type in objective_by_type:
            # Compare edges without their attribute dicts (dicts are unhashable).
            edges_no_info = [edge[:-1] for edge in edges]
            objective_no_info = [
                edge[:-1] for edge in objective_by_type[message_type]
            ]
            message_no_info = set(edges_no_info) - set(objective_no_info)
            # NOTE(review): G.edges[u, v] assumes a simple (non-multi) graph.
            message_edges += [
                (pair[0], pair[1], graph_train.G.edges[pair[0], pair[1]])
                for pair in message_no_info
            ]
        else:
            # No objective edge of this type: use every edge for both roles.
            message_edges += edges
            objective_edges += edges

    graph_train = HeteroGraph(
        graph_train._edge_subgraph_with_isonodes(
            graph_train.G,
            message_edges,
        ),
        negative_edges=graph_train.negative_edges,
    )
    graph_train._create_label_link_pred(
        graph_train, objective_edges, list(graph_train.G.nodes(data=True)))
    return graph_train
def test_hetero_graph_none(self):
    """Without edge types, every message type must carry ``None`` as its edge type."""
    source = HeteroGraph(generate_simple_hete_graph(add_edge_type=False))
    hete = HeteroGraph(
        node_feature=source.node_feature,
        node_label=source.node_label,
        edge_feature=source.edge_feature,
        edge_label=source.edge_label,
        edge_index=source.edge_index,
        directed=True,
    )
    # The middle element of each message type is its edge type.
    for edge_type in (message_type[1] for message_type in hete.message_types):
        self.assertEqual(edge_type, None)
def test_hetero_multigraph_split(self):
    """Check HeteroGraph.split sizes for node, edge and link_pred on a multigraph."""
    hete = HeteroGraph(generate_dense_hete_multigraph())

    def expected_sizes(total, first_ratio, second_ratio):
        # Expected formula used throughout this test: one item is reserved
        # for each of the three splits, the ratios apply to the remaining
        # ``total - 3`` items, and the last split takes what is left.
        remaining = total - 3
        first = 1 + int(remaining * first_ratio)
        second = 1 + int(remaining * second_ratio)
        return first, second, total - first - second

    # node
    node_splits = hete.split(task='node')
    for node_type, label_index in hete.node_label_index.items():
        sizes = expected_sizes(len(label_index), 0.8, 0.1)
        for i, size in enumerate(sizes):
            self.assertEqual(
                len(node_splits[i].node_label_index[node_type]), size)

    # edge
    edge_splits = hete.split(task='edge')
    for edge_type, label_index in hete.edge_label_index.items():
        sizes = expected_sizes(int(label_index.shape[1]), 0.8, 0.1)
        for i, size in enumerate(sizes):
            self.assertEqual(
                edge_splits[i].edge_label_index[edge_type].shape[1], size)

    # link prediction: expected label counts are summed over message types
    link_splits = hete.split(task='link_pred', split_ratio=[0.5, 0.3, 0.2])
    expected_totals = [0, 0, 0]
    for label_index in hete.edge_label_index.values():
        sizes = expected_sizes(label_index.shape[1], 0.5, 0.3)
        for i, size in enumerate(sizes):
            expected_totals[i] += size
    for i, total in enumerate(expected_totals):
        self.assertEqual(len(link_splits[i].edge_label), total)
def test_hetero_graph_batch(self):
    """Batch clones of a HeteroGraph and check the number and size of batches."""
    num_graphs, batch_size = 30, 3
    hete = HeteroGraph(generate_simple_hete_graph())
    graphs = [hete.clone() for _ in range(num_graphs)]
    dataloader = DataLoader(
        graphs,
        collate_fn=Batch.collate(),
        batch_size=batch_size,
        shuffle=True,
    )
    self.assertEqual(len(dataloader), math.ceil(num_graphs / batch_size))
    for data in dataloader:
        # num_graphs is a multiple of batch_size, so every batch is full.
        self.assertEqual(data.num_graphs, batch_size)
def main():
    """Train a heterogeneous link-prediction model on a pickled multigraph.

    Loads the graph from ``args.data_path`` and converts it with
    ``WN_transform``. It then splits the graph transductively 80/10/10 using
    the edge train mode given by ``args.mode`` and trains ``HeteroNet``.
    """
    args = arg_parse()
    edge_train_mode = args.mode
    print('edge train mode: {}'.format(edge_train_mode))

    G = nx.read_gpickle(args.data_path)
    print(G.number_of_edges())
    print('Each node has node ID (n_id). Example: ', G.nodes[0])
    print(
        'Each edge has edge ID (id) and categorical label (e_label). Example: ',
        G[0][5871])

    # find num edge types
    # Distinct labels are collected with an equality-based list membership
    # test. NOTE(review): e_label values may be tensors, whose hash is
    # identity-based, so a set could over-count — confirm before changing.
    labels = []
    for u, v, edge_key in G.edges:
        label = G[u][v][edge_key]['e_label']
        if label not in labels:
            labels.append(label)
    # labels are consecutive (0-17)
    num_edge_types = len(labels)

    H = WN_transform(G, num_edge_types)
    # The nodes in the graph have the features: node_feature and node_type (just one node type "n1" here)
    for node in H.nodes(data=True):
        print(node)
        break
    # The edges in the graph have the features: edge_feature and edge_type ("0" - "17" here)
    for edge in H.edges(data=True):
        print(edge)
        break

    hete = HeteroGraph(H)
    # Pass the edge train mode printed above so that it actually takes effect.
    dataset = GraphDataset(
        [hete], task='link_pred', edge_train_mode=edge_train_mode)
    dataset_train, dataset_val, dataset_test = dataset.split(
        transductive=True, split_ratio=[0.8, 0.1, 0.1])
    train_loader = DataLoader(dataset_train, collate_fn=Batch.collate(),
                              batch_size=1)
    val_loader = DataLoader(dataset_val, collate_fn=Batch.collate(),
                            batch_size=1)
    test_loader = DataLoader(dataset_test, collate_fn=Batch.collate(),
                             batch_size=1)
    dataloaders = {
        'train': train_loader,
        'val': val_loader,
        'test': test_loader,
    }

    hidden_size = 32
    model = HeteroNet(hete, hidden_size, 0.2).to(args.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001,
                                 weight_decay=5e-4)
    train(model, dataloaders, optimizer, args)
def _custom_hete_split_link_pred(self):
    """Build per-split link-prediction graphs from precomputed ``general_splits``.

    For each graph in ``self.graphs``:
      * train: subgraph on the train edges, labelled with the train edges;
      * val: a shallow copy of the train graph, labelled with the val edges;
      * test (only with three splits): subgraph on the train + val edges,
        labelled with the test edges.

    Returns:
        A list holding one list of graphs per split, in train/val/test order.
    """
    split_num = len(self.graphs[0].general_splits)
    has_test = split_num == 3
    split_graphs = [[] for _ in range(split_num)]

    for graph in self.graphs:
        train_source = copy.copy(graph)
        train_edges = train_source.general_splits[0]
        val_edges = train_source.general_splits[1]
        graph_train = HeteroGraph(
            train_source._edge_subgraph_with_isonodes(
                train_source.G,
                train_edges,
            ),
            disjoint_split=train_source.disjoint_split,
            negative_edges=train_source.negative_edges,
        )
        graph_val = copy.copy(graph_train)

        if has_test:
            test_source = copy.copy(graph)
            test_edges = graph.general_splits[2]
            graph_test = HeteroGraph(
                test_source._edge_subgraph_with_isonodes(
                    test_source.G,
                    train_edges + val_edges,
                ),
                negative_edges=test_source.negative_edges,
            )

        graph_train._create_label_link_pred(
            graph_train, train_edges, list(graph_train.G.nodes(data=True)))
        graph_val._create_label_link_pred(
            graph_val, val_edges, list(graph_val.G.nodes(data=True)))
        if has_test:
            graph_test._create_label_link_pred(
                graph_test, test_edges, list(graph_test.G.nodes(data=True)))

        split_graphs[0].append(graph_train)
        split_graphs[1].append(graph_val)
        if has_test:
            split_graphs[2].append(graph_test)

    return split_graphs
def test_hetero_graph_basics(self):
    """Check per-type counts of a HeteroGraph built from a simple graph (get_* API)."""
    hete = HeteroGraph(generate_simple_hete_graph())

    for node_type, expected in (('n1', 10), ('n2', 12)):
        self.assertEqual(hete.get_num_node_features(node_type), expected)
    for edge_type, expected in (('e1', 8), ('e2', 12)):
        self.assertEqual(hete.get_num_edge_features(edge_type), expected)
    for node_type, expected in (('n1', 4), ('n2', 5)):
        self.assertEqual(hete.get_num_nodes(node_type), expected)
    self.assertEqual(len(hete.node_types), 2)
    self.assertEqual(len(hete.edge_types), 2)

    message_types = hete.message_types
    self.assertEqual(len(message_types), 7)
    for node_type in ('n1', 'n2'):
        self.assertEqual(hete.get_num_node_labels(node_type), 2)
    for edge_type in ('e1', 'e2'):
        self.assertEqual(hete.get_num_edge_labels(edge_type), 3)
    self.assertEqual(hete.get_num_edges(message_types[0]), 3)
    self.assertEqual(len(hete.node_label_index), 2)
def test_dataset_hetero_graph_split(self):
    """Check GraphDataset.split sizes on a tensor-backed heterogeneous graph.

    Covers the node, edge and link_pred tasks, with and without
    ``split_types``, the disjoint edge-train mode and
    ``edge_split_mode="approximate"``. Expected sizes floor the train and
    val shares and give the remainder to test. link_pred label counts are
    expected to be doubled (presumably one negative per positive edge).
    """
    G = generate_dense_hete_dataset()
    hete = HeteroGraph(G)
    hete = HeteroGraph(
        node_feature=hete.node_feature,
        node_label=hete.node_label,
        edge_feature=hete.edge_feature,
        edge_label=hete.edge_label,
        edge_index=hete.edge_index,
        directed=True,
    )
    split_types = [("n1", "e1", "n1"), ("n1", "e2", "n2")]

    def split_sizes(total, train_ratio=0.8, val_ratio=0.1):
        # Expected (train, val, test): floor train/val, remainder to test.
        train = int(total * train_ratio)
        val = int(total * val_ratio)
        return train, val, total - train - val

    def doubled(sizes):
        return tuple(2 * size for size in sizes)

    def assert_node_sizes(split_res, node_type, expected):
        for i, size in enumerate(expected):
            self.assertEqual(
                len(split_res[i][0].node_label_index[node_type]), size)

    def assert_edge_sizes(split_res, edge_type, expected):
        for i, size in enumerate(expected):
            self.assertEqual(
                split_res[i][0].edge_label_index[edge_type].shape[1], size)

    def summed_edge_sizes(split_res):
        # Total label edges per split (train, val, test), summed over the
        # message types of ``hete`` that are present in that split.
        totals = []
        for i in range(3):
            label_index = split_res[i][0].edge_label_index
            totals.append(sum(
                label_index[edge_type].shape[1]
                for edge_type in hete.edge_label_index
                if edge_type in label_index))
        return totals

    # node
    dataset = GraphDataset([hete], task="node")
    split_res = dataset.split()
    for node_type, label_index in hete.node_label_index.items():
        assert_node_sizes(split_res, node_type, split_sizes(len(label_index)))

    # node with specified split types, given as a list and as a single string
    for node_split_types in (["n1"], "n1"):
        dataset = GraphDataset([hete], task="node")
        split_res = dataset.split(split_types=node_split_types)
        # Normalize the string form so membership is an exact type match
        # rather than a substring test.
        if isinstance(node_split_types, str):
            selected = [node_split_types]
        else:
            selected = node_split_types
        for node_type, label_index in hete.node_label_index.items():
            num_nodes = len(label_index)
            if node_type in selected:
                expected = split_sizes(num_nodes)
            else:
                # Unselected types keep all their nodes in every split.
                expected = (num_nodes,) * 3
            assert_node_sizes(split_res, node_type, expected)

    # edge
    dataset = GraphDataset([hete], task="edge")
    split_res = dataset.split()
    for edge_type, label_index in hete.edge_label_index.items():
        assert_edge_sizes(
            split_res, edge_type, split_sizes(label_index.shape[1]))

    # edge with specified split types
    dataset = GraphDataset([hete], task="edge")
    split_res = dataset.split(split_types=split_types)
    for edge_type, label_index in hete.edge_label_index.items():
        num_edges = label_index.shape[1]
        if edge_type in split_types:
            expected = split_sizes(num_edges)
        else:
            expected = (num_edges,) * 3
        assert_edge_sizes(split_res, edge_type, expected)

    # link_pred
    dataset = GraphDataset([hete], task="link_pred")
    split_res = dataset.split(transductive=True)
    for edge_type, label_index in hete.edge_label_index.items():
        assert_edge_sizes(
            split_res, edge_type, doubled(split_sizes(label_index.shape[1])))

    # link_pred with specified split types (unselected types are not doubled)
    dataset = GraphDataset([hete], task="link_pred")
    split_res = dataset.split(transductive=True, split_types=split_types)
    for edge_type, label_index in hete.edge_label_index.items():
        num_edges = label_index.shape[1]
        if edge_type in split_types:
            expected = doubled(split_sizes(num_edges))
        else:
            expected = (num_edges,) * 3
        assert_edge_sizes(split_res, edge_type, expected)

    # link_pred + disjoint
    dataset = GraphDataset(
        [hete],
        task="link_pred",
        edge_train_mode="disjoint",
        edge_message_ratio=0.5,
    )
    split_res = dataset.split(transductive=True, split_ratio=[0.6, 0.2, 0.2])
    for edge_type, label_index in hete.edge_label_index.items():
        train, val, test = split_sizes(label_index.shape[1], 0.6, 0.2)
        # With edge_message_ratio=0.5, the expected train supervision edges
        # are what remains after int(0.5 * train) message edges.
        train -= int(0.5 * train)
        assert_edge_sizes(split_res, edge_type, doubled((train, val, test)))

    # link_pred with edge_split_mode="approximate": sizes are checked on
    # totals summed over all message types, not per type.
    dataset = GraphDataset(
        [hete], task="link_pred", edge_split_mode="approximate")
    split_res = dataset.split(transductive=True)
    num_edges = sum(
        label_index.shape[1] for label_index in hete.edge_label_index.values())
    expected = doubled(split_sizes(num_edges))
    for actual, size in zip(summed_edge_sizes(split_res), expected):
        self.assertEqual(actual, size)

    # link_pred with specified split types and edge_split_mode="approximate"
    dataset = GraphDataset(
        [hete],
        task="link_pred",
        edge_split_mode="approximate",
    )
    split_res = dataset.split(transductive=True, split_types=split_types)
    num_split_type_edges = 0
    num_non_split_type_edges = 0
    for edge_type, label_index in hete.edge_label_index.items():
        if edge_type in split_types:
            num_split_type_edges += label_index.shape[1]
        else:
            num_non_split_type_edges += label_index.shape[1]
    # Unselected types appear whole (not doubled) in every split.
    expected = [
        size + num_non_split_type_edges
        for size in doubled(split_sizes(num_split_type_edges))
    ]
    for actual, size in zip(summed_edge_sizes(split_res), expected):
        self.assertEqual(actual, size)
node_feature["citeseer_node"] = citeseer_x node_label = {} node_label["cora_node"] = cora_y node_label["citeseer_node"] = citeseer_y # prepare undirected edge_indx for message_type in edge_index: edge_index[message_type] = torch.cat([ edge_index[message_type], torch.flip(edge_index[message_type], [0]) ], dim=1) hete = HeteroGraph(node_feature=node_feature, node_label=node_label, edge_index=edge_index, directed=False) print( f"Heterogeneous graph {hete.num_nodes()} nodes, {hete.num_edges()} edges" ) dataset = GraphDataset([hete], task='node') dataset_train, dataset_val, dataset_test = dataset.split( transductive=True, split_ratio=[0.8, 0.1, 0.1]) train_loader = DataLoader(dataset_train, collate_fn=Batch.collate(), batch_size=16) val_loader = DataLoader(dataset_val, collate_fn=Batch.collate(), batch_size=16) test_loader = DataLoader(dataset_test,
pred = logits[node_type][node_idx] pred = pred.max(1)[1] acc += pred.eq(batch.node_label[node_type][node_idx].to(device)).sum().item() total += pred.size(0) acc /= total accs.append(acc) if accs[1] > best_val: best_val = accs[1] best_model = copy.deepcopy(model) return accs if __name__ == "__main__": cora_pyg = Planetoid('./cora', 'Cora') citeseer_pyg = Planetoid('./citeseer', 'CiteSeer') G = concatenate_citeseer_cora(cora_pyg[0], citeseer_pyg[0]) hete = HeteroGraph(G) print("Heterogeneous graph {} nodes, {} edges".format(hete.num_nodes, hete.num_edges)) dataset = GraphDataset([hete], task='node') dataset_train, dataset_val, dataset_test = dataset.split(transductive=True, split_ratio=[0.8, 0.1, 0.1]) train_loader = DataLoader(dataset_train, collate_fn=Batch.collate(), batch_size=16) val_loader = DataLoader(dataset_val, collate_fn=Batch.collate(), batch_size=16) test_loader = DataLoader(dataset_test, collate_fn=Batch.collate(), batch_size=16) loaders = [train_loader, val_loader, test_loader] hidden_size = 32 model = HeteroNet(hete, hidden_size, 0.5).to(device)
def test_dataset_hetero_graph_split(self):
    """Check GraphDataset.split sizes on a heterogeneous graph built from networkx.

    Covers the node, edge and link_pred tasks, with and without
    ``split_types``, the disjoint edge-train mode and
    ``edge_split_mode="approximate"``. The expected sizes reserve one item
    per split before applying the ratios. link_pred label counts are
    expected to be doubled (presumably one negative per positive edge).
    """
    G = generate_dense_hete_dataset()
    hete = HeteroGraph(G)
    split_types = [('n1', 'e1', 'n1'), ('n1', 'e2', 'n2')]

    def split_sizes(total, train_ratio=0.8, val_ratio=0.1):
        # One item is reserved for each of the three splits, the ratios
        # apply to the remaining ``total - 3`` items, and test takes the rest.
        remaining = total - 3
        train = 1 + int(remaining * train_ratio)
        val = 1 + int(remaining * val_ratio)
        return train, val, total - train - val

    def doubled(sizes):
        return tuple(2 * size for size in sizes)

    def assert_node_sizes(split_res, node_type, expected):
        for i, size in enumerate(expected):
            self.assertEqual(
                len(split_res[i][0].node_label_index[node_type]), size)

    def assert_edge_sizes(split_res, edge_type, expected):
        for i, size in enumerate(expected):
            self.assertEqual(
                split_res[i][0].edge_label_index[edge_type].shape[1], size)

    def summed_edge_sizes(split_res):
        # Total label edges per split (train, val, test), summed over the
        # message types of ``hete`` that are present in that split.
        totals = []
        for i in range(3):
            label_index = split_res[i][0].edge_label_index
            totals.append(sum(
                label_index[edge_type].shape[1]
                for edge_type in hete.edge_label_index
                if edge_type in label_index))
        return totals

    # node
    dataset = GraphDataset([hete], task='node')
    split_res = dataset.split()
    for node_type, label_index in hete.node_label_index.items():
        assert_node_sizes(split_res, node_type, split_sizes(len(label_index)))

    # node with specified split types, given as a list and as a single string
    for node_split_types in (['n1'], 'n1'):
        dataset = GraphDataset([hete], task='node')
        split_res = dataset.split(split_types=node_split_types)
        # Normalize the string form so membership is an exact type match
        # rather than a substring test.
        if isinstance(node_split_types, str):
            selected = [node_split_types]
        else:
            selected = node_split_types
        for node_type, label_index in hete.node_label_index.items():
            num_nodes = len(label_index)
            if node_type in selected:
                expected = split_sizes(num_nodes)
            else:
                # Unselected types keep all their nodes in every split.
                expected = (num_nodes,) * 3
            assert_node_sizes(split_res, node_type, expected)

    # edge
    dataset = GraphDataset([hete], task='edge')
    split_res = dataset.split()
    for edge_type, label_index in hete.edge_label_index.items():
        assert_edge_sizes(
            split_res, edge_type, split_sizes(label_index.shape[1]))

    # edge with specified split types
    dataset = GraphDataset([hete], task='edge')
    split_res = dataset.split(split_types=split_types)
    for edge_type, label_index in hete.edge_label_index.items():
        num_edges = label_index.shape[1]
        if edge_type in split_types:
            expected = split_sizes(num_edges)
        else:
            expected = (num_edges,) * 3
        assert_edge_sizes(split_res, edge_type, expected)

    # link_pred
    dataset = GraphDataset([hete], task='link_pred')
    split_res = dataset.split(transductive=True)
    for edge_type, label_index in hete.edge_label_index.items():
        assert_edge_sizes(
            split_res, edge_type, doubled(split_sizes(label_index.shape[1])))

    # link_pred with specified split types (unselected types are not doubled)
    dataset = GraphDataset([hete], task='link_pred')
    split_res = dataset.split(transductive=True, split_types=split_types)
    for edge_type, label_index in hete.edge_label_index.items():
        num_edges = label_index.shape[1]
        if edge_type in split_types:
            expected = doubled(split_sizes(num_edges))
        else:
            expected = (num_edges,) * 3
        assert_edge_sizes(split_res, edge_type, expected)

    # link_pred + disjoint
    dataset = GraphDataset([hete], task='link_pred',
                           edge_train_mode='disjoint',
                           edge_message_ratio=0.5)
    split_res = dataset.split(transductive=True, split_ratio=[0.6, 0.2, 0.2])
    for edge_type, label_index in hete.edge_label_index.items():
        train, val, test = split_sizes(label_index.shape[1], 0.6, 0.2)
        # With edge_message_ratio=0.5, the expected number of train message
        # edges is 1 + int(0.5 * (train - 2)); the rest supervise.
        train -= 1 + int(0.5 * (train - 2))
        assert_edge_sizes(split_res, edge_type, doubled((train, val, test)))

    # link_pred with edge_split_mode="approximate": sizes are checked on
    # totals summed over all message types, not per type.
    dataset = GraphDataset([hete], task='link_pred',
                           edge_split_mode="approximate")
    split_res = dataset.split(transductive=True)
    num_edges = sum(
        label_index.shape[1] for label_index in hete.edge_label_index.values())
    expected = doubled(split_sizes(num_edges))
    for actual, size in zip(summed_edge_sizes(split_res), expected):
        self.assertEqual(actual, size)

    # link_pred with specified split types and edge_split_mode="approximate"
    dataset = GraphDataset([hete], task='link_pred',
                           edge_split_mode="approximate")
    split_res = dataset.split(transductive=True, split_types=split_types)
    num_split_type_edges = 0
    num_non_split_type_edges = 0
    for edge_type, label_index in hete.edge_label_index.items():
        if edge_type in split_types:
            num_split_type_edges += label_index.shape[1]
        else:
            num_non_split_type_edges += label_index.shape[1]
    # Unselected types appear whole (not doubled) in every split.
    expected = [
        size + num_non_split_type_edges
        for size in doubled(split_sizes(num_split_type_edges))
    ]
    for actual, size in zip(summed_edge_sizes(split_res), expected):
        self.assertEqual(actual, size)
def main():
    """Train a HeteroGNN for link prediction on a pickled WordNet multigraph.

    Reads the graph from ``args.data_path``, turns every distinct edge label
    into its own edge type via ``WN_transform``, builds a transductive
    80/10/10 link-prediction split and hands the loaders to ``train``.
    """
    args = arg_parse()
    edge_train_mode = args.mode
    print('edge train mode: {}'.format(edge_train_mode))

    # NOTE(review): nx.read_gpickle was removed in NetworkX 3.0; this script
    # needs networkx<3 (or an explicit pickle.load on a trusted file).
    G = nx.read_gpickle(args.data_path)
    print(G.number_of_edges())
    print('Each node has node ID (n_id). Example: ', G.nodes[0])
    print(
        'Each edge has edge ID (id) and categorical label (e_label). Example: ',
        G[0][5871])

    # find num edge types: the edge view is unpacked as (u, v, key) triples,
    # so look each edge up by key and collect the distinct labels in a set.
    labels = {G[u][v][edge_key]['e_label'] for u, v, edge_key in G.edges}
    # labels are consecutive (0-17)
    num_edge_types = len(labels)
    H = WN_transform(G, num_edge_types)

    # The nodes in the graph have the features: node_feature and node_type
    # (just one node type "n1" here)
    for node in H.nodes(data=True):
        print(node)
        break
    # The edges in the graph have the features: edge_feature and edge_type
    # ("0" - "17" here)
    for edge in H.edges(data=True):
        print(edge)
        break

    hetero = HeteroGraph(H)
    # Rebuild from the extracted edge_index / features / direction only;
    # no labels are passed to the rebuilt graph.
    hetero = HeteroGraph(
        edge_index=hetero.edge_index,
        edge_feature=hetero.edge_feature,
        node_feature=hetero.node_feature,
        directed=hetero.is_directed())

    if edge_train_mode == "disjoint":
        dataset = GraphDataset(
            [hetero],
            task='link_pred',
            edge_train_mode=edge_train_mode,
            edge_message_ratio=args.edge_message_ratio)
    else:
        dataset = GraphDataset(
            [hetero],
            task='link_pred',
            edge_train_mode=edge_train_mode,
        )

    dataset_train, dataset_val, dataset_test = dataset.split(
        transductive=True, split_ratio=[0.8, 0.1, 0.1])
    train_loader = DataLoader(
        dataset_train, collate_fn=Batch.collate(), batch_size=1)
    val_loader = DataLoader(
        dataset_val, collate_fn=Batch.collate(), batch_size=1)
    test_loader = DataLoader(
        dataset_test, collate_fn=Batch.collate(), batch_size=1)
    dataloaders = {
        'train': train_loader,
        'val': val_loader,
        'test': test_loader,
    }

    hidden_size = args.hidden_dim
    conv1, conv2 = generate_2convs_link_pred_layers(
        hetero, HeteroSAGEConv, hidden_size)
    model = HeteroGNN(conv1, conv2, hetero, hidden_size).to(args.device)
    optimizer = torch.optim.Adam(
        model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    t_accu, v_accu, e_accu = train(model, dataloaders, optimizer, args)
def main():
    """Build custom link-prediction splits for CMap edges and train ``Net``.

    For every knockout node, its out-edges with ``edge_type == 1`` are assigned
    as a group using one ``np.random.rand()`` draw per node. Draws below
    ``disjoint_split_ratio`` send them to the disjoint supervision set and to
    the training edges. Draws below ``disjoint_split_ratio + val_ratio`` send
    them to validation. All other draws send them to training. The splits are
    handed to ``HeteroGraph`` via ``custom``, transformed with
    ``cmap_transform``, and used to train ``Net``, with metrics logged to a
    ``SummaryWriter``.
    """
    writer = SummaryWriter()
    args = arg_parse()
    edge_train_mode = args.mode
    print('edge train mode: {}'.format(edge_train_mode))

    ppi_graph = read_ppi_data(args.ppi_path)

    mode = 'mixed'
    if mode == 'ppi':
        message_passing_graph = ppi_graph
        cmap_graph, knockout_nodes = read_cmap_data(args.data_path)
    elif mode == 'mixed':
        message_passing_graph, knockout_nodes = (
            read_cmap_data(args.data_path, ppi_graph)
        )

    print('Each node has gene ID. Example: ', message_passing_graph.nodes['ADPGK'])
    print('Each edge has de direction. Example', message_passing_graph['ADPGK']['IL1B'])
    print('Total num edges: ', message_passing_graph.number_of_edges())

    # disjoint edge label
    disjoint_split_ratio = 0.1
    val_ratio = 0.1
    disjoint_edge_label_index = []
    val_edges = []
    train_edges = []

    for u in knockout_nodes:
        rand_num = np.random.rand()
        # Out-edges of u with edge_type == 1 (cmap only), computed once per
        # node instead of being rebuilt in every branch below.
        cmap_edges = [
            (u, v, edge_key)
            for v in message_passing_graph.successors(u)
            for edge_key in message_passing_graph[u][v]
            if message_passing_graph[u][v][edge_key]['edge_type'] == 1
        ]
        if rand_num < disjoint_split_ratio:
            # add all edges (cmap only) into edge label index
            # cmap is not a multigraph
            # These supervision edges are also kept in the training edges.
            disjoint_edge_label_index.extend(cmap_edges)
            train_edges.extend(cmap_edges)
        elif rand_num < disjoint_split_ratio + val_ratio:
            val_edges.extend(cmap_edges)
        else:
            train_edges.extend(cmap_edges)

    # add default node types for message_passing_graph
    for node in message_passing_graph.nodes:
        message_passing_graph.nodes[node]['node_type'] = 0

    print('Num edges to predict: ', len(disjoint_edge_label_index))
    print('Num edges in val: ', len(val_edges))
    print('Num edges in train: ', len(train_edges))

    graph = HeteroGraph(
        message_passing_graph,
        custom={
            "general_splits": [
                train_edges,
                val_edges
            ],
            "disjoint_split": disjoint_edge_label_index,
            "task": "link_pred"
        }
    )

    graphs = [graph]
    graphDataset = GraphDataset(
        graphs,
        task="link_pred",
        edge_train_mode="disjoint"
    )

    # Transform dataset
    # de direction (currently using homogeneous graph)
    num_edge_types = 2

    graphDataset = graphDataset.apply_transform(
        cmap_transform, num_edge_types=num_edge_types, deep_copy=False
    )
    print('Number of node features: ', graphDataset.num_node_features())

    # split dataset
    dataset = {}
    dataset['train'], dataset['val'] = graphDataset.split(transductive=True)

    # sanity check
    print(f"dataset['train'][0].edge_label_index.keys(): {dataset['train'][0].edge_label_index.keys()}")
    print(f"dataset['train'][0].edge_label_index[(0, 1, 0)].shape[1]: {dataset['train'][0].edge_label_index[(0, 1, 0)].shape[1]}")
    print(f"dataset['val'][0].edge_label_index.keys(): {dataset['val'][0].edge_label_index.keys()}")
    print(f"dataset['val'][0].edge_label_index[(0, 1, 0)].shape[1]: {dataset['val'][0].edge_label_index[(0, 1, 0)].shape[1]}")
    print(f"len(list(dataset['train'][0].G.edges)): {len(list(dataset['train'][0].G.edges))}")
    print(f"len(list(dataset['val'][0].G.edges)): {len(list(dataset['val'][0].G.edges))}")
    print(f"list(dataset['train'][0].G.edges)[:10]: {list(dataset['train'][0].G.edges)[:10]}")
    print(f"list(dataset['val'][0].G.edges)[:10]: {list(dataset['val'][0].G.edges)[:10]}")

    # node feature dimension
    input_dim = dataset['train'].num_node_features()
    edge_feat_dim = dataset['train'].num_edge_features()
    num_classes = dataset['train'].num_edge_labels()
    print(
        'Node feature dim: {}; edge feature dim: {}; num classes: {}.'.format(
            input_dim, edge_feat_dim, num_classes
        )
    )

    # relation type is both used for edge features and edge labels
    model = Net(input_dim, edge_feat_dim, num_classes, args).to(args.device)
    optimizer = torch.optim.Adam(
        model.parameters(), lr=0.001, weight_decay=5e-3
    )
    follow_batch = []  # e.g., follow_batch = ['edge_index']

    dataloaders = {
        split: DataLoader(
            ds, collate_fn=Batch.collate(follow_batch),
            batch_size=1, shuffle=(split == 'train')
        )
        for split, ds in dataset.items()
    }
    print('Graphs after split: ')
    for key, dataloader in dataloaders.items():
        for batch in dataloader:
            print(key, ': ', batch)

    train(model, dataloaders, optimizer, args, writer=writer)
def test_hetero_multigraph_split(self):
    """Split a dense directed heterogeneous multigraph by node, edge and link.

    Every split but the last is expected to take ``int(total * ratio)`` items;
    the last split receives whatever remains.
    """
    def shares(total, ratios):
        # Truncated share for each leading split, remainder for the last one.
        leading = [int(total * ratio) for ratio in ratios[:-1]]
        return leading + [total - sum(leading)]

    source = HeteroGraph(generate_dense_hete_multigraph())
    # Rebuild the graph from its extracted attributes.
    hete = HeteroGraph(
        node_feature=source.node_feature,
        node_label=source.node_label,
        edge_feature=source.edge_feature,
        edge_label=source.edge_label,
        edge_index=source.edge_index,
        directed=True,
    )

    # node split (the test expects ratios 0.8 / 0.1 / 0.1)
    node_parts = hete.split(task='node')
    for node_type in hete.node_label_index:
        sizes = shares(len(hete.node_label_index[node_type]), (0.8, 0.1, 0.1))
        for idx in range(3):
            self.assertEqual(
                len(node_parts[idx].node_label_index[node_type]), sizes[idx])

    # edge split
    edge_parts = hete.split(task='edge')
    for edge_type in hete.edge_label_index:
        sizes = shares(
            int(hete.edge_label_index[edge_type].shape[1]), (0.8, 0.1, 0.1))
        for idx in range(3):
            self.assertEqual(
                edge_parts[idx].edge_label_index[edge_type].shape[1],
                sizes[idx])

    # link prediction: compare edge counts summed over all message types
    link_parts = hete.split(task='link_pred', split_ratio=[0.5, 0.3, 0.2])
    expected_totals = [0, 0, 0]
    for label_index in hete.edge_label_index.values():
        for idx, size in enumerate(
                shares(label_index.shape[1], (0.5, 0.3, 0.2))):
            expected_totals[idx] += size

    actual_totals = [
        sum(label.shape[0] for label in link_parts[idx].edge_label.values())
        for idx in range(3)
    ]
    for actual, wanted in zip(actual_totals, expected_totals):
        self.assertEqual(actual, wanted)
def test_hetero_graph_split(self):
    """Check node, edge and link_pred split sizes on directed and undirected graphs.

    Every split but the last takes ``int(total * ratio)`` items and the last
    takes the remainder. Types that are not listed in ``split_types`` are
    expected to stay whole in every split.
    """
    node_split_types = ['n1']
    edge_split_types = [('n1', 'e1', 'n1'), ('n1', 'e2', 'n2')]

    def shares(total, ratios=(0.8, 0.1, 0.1)):
        # Truncated share for each leading split, remainder for the last one.
        leading = [int(total * ratio) for ratio in ratios[:-1]]
        return leading + [total - sum(leading)]

    def rebuild(graph, **extra):
        # Re-create the HeteroGraph from its extracted attributes.
        return HeteroGraph(
            node_feature=graph.node_feature,
            node_label=graph.node_label,
            edge_feature=graph.edge_feature,
            edge_label=graph.edge_label,
            edge_index=graph.edge_index,
            **extra,
        )

    def check_nodes(graph, parts, split_types=None):
        for node_type in graph.node_label_index:
            total = len(graph.node_label_index[node_type])
            if split_types is None or node_type in split_types:
                sizes = shares(total)
            else:
                sizes = [total] * 3
            for idx in range(3):
                self.assertEqual(
                    len(parts[idx].node_label_index[node_type]), sizes[idx])

    def check_edges(graph, parts, use_num_edges, split_types=None):
        # use_num_edges selects graph.num_edges(...) instead of the label
        # index width as the total for the split types.
        for edge_type in graph.edge_label_index:
            if split_types is None or edge_type in split_types:
                if use_num_edges:
                    total = graph.num_edges(edge_type)
                else:
                    total = graph.edge_label_index[edge_type].shape[1]
                sizes = shares(int(total))
            else:
                sizes = [graph.edge_label_index[edge_type].shape[1]] * 3
            for idx in range(3):
                self.assertEqual(
                    parts[idx].edge_label_index[edge_type].shape[1],
                    sizes[idx])

    def check_links(graph, parts, pair_factor):
        # pair_factor is 2 for the undirected graph, where the test counts
        # label-index columns in pairs; 1 for the directed graph.
        for key, label_index in graph.edge_label_index.items():
            units = label_index.shape[1] // pair_factor
            sizes = [pair_factor * size
                     for size in shares(units, (0.5, 0.3, 0.2))]
            for idx in range(3):
                self.assertEqual(parts[idx].edge_label[key].shape[0], sizes[idx])

    def run(graph, use_num_edges, pair_factor):
        check_nodes(graph, graph.split())
        check_nodes(graph, graph.split(split_types=node_split_types),
                    node_split_types)
        check_edges(graph, graph.split(task='edge'), use_num_edges)
        check_edges(graph,
                    graph.split(task='edge', split_types=edge_split_types),
                    use_num_edges, edge_split_types)
        check_links(graph,
                    graph.split(task='link_pred', split_ratio=[0.5, 0.3, 0.2]),
                    pair_factor)

    # directed G
    run(rebuild(HeteroGraph(generate_dense_hete_graph())),
        use_num_edges=False, pair_factor=1)

    # undirected G
    run(rebuild(HeteroGraph(generate_dense_hete_graph(directed=False)),
                directed=False),
        use_num_edges=True, pair_factor=2)
def test_hetero_multigraph_split(self):
    """Split a dense directed heterogeneous multigraph by node, edge and link.

    Each leading split is expected to take ``1 + int((total - 3) * ratio)``
    items and the last split receives the remainder.
    """
    def shares(total, ratios):
        reduced = total - 3
        leading = [1 + int(reduced * ratio) for ratio in ratios[:-1]]
        return leading + [total - sum(leading)]

    source = HeteroGraph(generate_dense_hete_multigraph())
    # Rebuild the graph from its extracted attributes.
    hete = HeteroGraph(
        node_feature=source.node_feature,
        node_label=source.node_label,
        edge_feature=source.edge_feature,
        edge_label=source.edge_label,
        edge_index=source.edge_index,
        directed=True,
    )

    # node split (the test expects ratios 0.8 / 0.1 / 0.1)
    node_parts = hete.split(task='node')
    for node_type in hete.node_label_index:
        sizes = shares(len(hete.node_label_index[node_type]), (0.8, 0.1, 0.1))
        for idx in range(3):
            self.assertEqual(
                len(node_parts[idx].node_label_index[node_type]), sizes[idx])

    # edge split
    edge_parts = hete.split(task='edge')
    for edge_type in hete.edge_label_index:
        sizes = shares(
            int(hete.edge_label_index[edge_type].shape[1]), (0.8, 0.1, 0.1))
        for idx in range(3):
            self.assertEqual(
                edge_parts[idx].edge_label_index[edge_type].shape[1],
                sizes[idx])

    # link prediction: compare edge counts summed over all message types
    link_parts = hete.split(task='link_pred', split_ratio=[0.5, 0.3, 0.2])
    expected_totals = [0, 0, 0]
    for label_index in hete.edge_label_index.values():
        for idx, size in enumerate(
                shares(label_index.shape[1], (0.5, 0.3, 0.2))):
            expected_totals[idx] += size

    actual_totals = [
        sum(label.shape[0] for label in link_parts[idx].edge_label.values())
        for idx in range(3)
    ]
    for actual, wanted in zip(actual_totals, expected_totals):
        self.assertEqual(actual, wanted)
def test_hetero_graph_none(self):
    """A graph generated without edge types has ``None`` as every message type's edge type."""
    G = generate_simple_hete_graph(no_edge_type=True)
    hete = HeteroGraph(G)
    message_types = hete.message_types
    # Without this guard the loop below would pass vacuously on a graph
    # that has no message types at all.
    self.assertGreater(len(message_types), 0)
    for message_type in message_types:
        # message types are (src_type, edge_type, dst_type) triples
        self.assertIsNone(message_type[1])
def test_hetero_graph_split(self):
    """Check node, edge and link_pred split sizes on a dense HeteroGraph built from G.

    Each leading split is expected to take ``1 + int((total - 3) * ratio)``
    items and the last split takes the remainder. Types that are not listed
    in ``split_types`` are expected to stay whole in every split.
    """
    node_split_types = ['n1']
    edge_split_types = [('n1', 'e1', 'n1'), ('n1', 'e2', 'n2')]

    def shares(total, ratios=(0.8, 0.1, 0.1)):
        reduced = total - 3
        leading = [1 + int(reduced * ratio) for ratio in ratios[:-1]]
        return leading + [total - sum(leading)]

    def check_nodes(parts, split_types=None):
        for node_type in hete.node_label_index:
            total = len(hete.node_label_index[node_type])
            if split_types is None or node_type in split_types:
                sizes = shares(total)
            else:
                sizes = [total] * 3
            for idx in range(3):
                self.assertEqual(
                    len(parts[idx].node_label_index[node_type]), sizes[idx])

    def check_edges(parts, split_types=None):
        for edge_type in hete.edge_label_index:
            total = int(hete.edge_label_index[edge_type].shape[1])
            if split_types is None or edge_type in split_types:
                sizes = shares(total)
            else:
                sizes = [total] * 3
            for idx in range(3):
                self.assertEqual(
                    parts[idx].edge_label_index[edge_type].shape[1],
                    sizes[idx])

    G = generate_dense_hete_graph()
    hete = HeteroGraph(G)

    # node
    check_nodes(hete.split())
    # node with specified split type
    check_nodes(hete.split(split_types=node_split_types), node_split_types)
    # edge
    check_edges(hete.split(task='edge'))
    # edge with specified split type
    check_edges(hete.split(task='edge', split_types=edge_split_types),
                edge_split_types)

    # link_pred: expected totals are summed over all message types
    hete_link = hete.split(task='link_pred', split_ratio=[0.5, 0.3, 0.2])
    expected_totals = [0, 0, 0]
    for label_index in hete.edge_label_index.values():
        for idx, size in enumerate(
                shares(label_index.shape[1], (0.5, 0.3, 0.2))):
            expected_totals[idx] += size
    for idx in range(3):
        # edge_label maps message type -> labels, so count the labels of all
        # message types; len() of the dict would only count the types.
        actual = sum(
            label.shape[0] for label in hete_link[idx].edge_label.values())
        self.assertEqual(actual, expected_totals[idx])
def test_hetero_graph_split(self):
    """Check node, edge and link_pred split sizes on directed and undirected graphs.

    Each leading split is expected to take ``1 + int((total - 3) * ratio)``
    items and the last split takes the remainder. Types that are not listed
    in ``split_types`` are expected to stay whole in every split. For the
    undirected graph, link_pred sizes are computed on half the label-index
    width and doubled.
    """
    node_split_types = ['n1']
    edge_split_types = [('n1', 'e1', 'n1'), ('n1', 'e2', 'n2')]

    def shares(total):
        reduced = total - 3
        leading = [1 + int(reduced * ratio) for ratio in (0.8, 0.1)]
        return leading + [total - sum(leading)]

    def link_sizes(total, pair_factor):
        units = total // pair_factor
        leading = [pair_factor * (1 + int((units - 3) * ratio))
                   for ratio in (0.5, 0.3)]
        return leading + [total - sum(leading)]

    def check_nodes(graph, parts, split_types=None):
        for node_type in graph.node_label_index:
            total = len(graph.node_label_index[node_type])
            if split_types is None or node_type in split_types:
                sizes = shares(total)
            else:
                sizes = [total] * 3
            for idx in range(3):
                self.assertEqual(
                    len(parts[idx].node_label_index[node_type]), sizes[idx])

    def check_edges(graph, parts, split_types=None):
        for edge_type in graph.edge_label_index:
            total = int(graph.edge_label_index[edge_type].shape[1])
            if split_types is None or edge_type in split_types:
                sizes = shares(total)
            else:
                sizes = [total] * 3
            for idx in range(3):
                self.assertEqual(
                    parts[idx].edge_label_index[edge_type].shape[1],
                    sizes[idx])

    def check_links(graph, parts, pair_factor):
        for key, label_index in graph.edge_label_index.items():
            sizes = link_sizes(label_index.shape[1], pair_factor)
            for idx in range(3):
                self.assertEqual(parts[idx].edge_label[key].shape[0], sizes[idx])

    def run(graph, pair_factor):
        check_nodes(graph, graph.split())
        check_nodes(graph, graph.split(split_types=node_split_types),
                    node_split_types)
        check_edges(graph, graph.split(task='edge'))
        check_edges(graph,
                    graph.split(task='edge', split_types=edge_split_types),
                    edge_split_types)
        check_links(graph,
                    graph.split(task='link_pred', split_ratio=[0.5, 0.3, 0.2]),
                    pair_factor)

    # directed G
    run(HeteroGraph(generate_dense_hete_graph()), pair_factor=1)

    # undirected G
    run(HeteroGraph(generate_dense_hete_graph(directed=False)), pair_factor=2)
def test_secure_split_heterogeneous(self):
    """Check dataset-level node, edge and link_pred splits of a small hetero graph.

    The first two splits take ``1 + int((total - 3) * ratio)`` items for
    ratios 0.8 and 0.1, and the last takes the remainder. For link_pred every
    expected count is doubled. Presumably this counts positive plus sampled
    negative labels; that is inferred from the formula only.
    """
    def secure_sizes(total, scale=1):
        leading = [scale * (1 + int((total - 3) * ratio)) for ratio in (0.8, 0.1)]
        return leading + [scale * total - sum(leading)]

    source = HeteroGraph(generate_simple_small_hete_graph())
    graph = HeteroGraph(
        node_label=source.node_label,
        edge_index=source.edge_index,
        edge_label=source.edge_label,
        directed=True,
    )
    graphs = [graph]

    # node task
    node_res = GraphDataset(graphs, task="node").split()
    for node_type in graph.node_label_index:
        sizes = secure_sizes(graph.node_label_index[node_type].shape[0])
        for idx, size in enumerate(sizes):
            self.assertEqual(
                node_res[idx][0].node_label_index[node_type].shape[0], size)
            self.assertEqual(
                node_res[idx][0].node_label[node_type].shape[0], size)

    # edge task
    edge_res = GraphDataset(graphs, task="edge").split()
    for message_type in graph.edge_label_index:
        sizes = secure_sizes(graph.edge_label_index[message_type].shape[1])
        for idx, size in enumerate(sizes):
            self.assertEqual(
                edge_res[idx][0].edge_label_index[message_type].shape[1], size)
            self.assertEqual(
                edge_res[idx][0].edge_label[message_type].shape[0], size)

    # link_pred task
    link_res = GraphDataset(graphs, task="link_pred").split()
    for message_type in graph.edge_label_index:
        sizes = secure_sizes(
            graph.edge_label_index[message_type].shape[1], scale=2)
        for idx, size in enumerate(sizes):
            self.assertEqual(
                link_res[idx][0].edge_label_index[message_type].shape[1], size)
            self.assertEqual(
                link_res[idx][0].edge_label[message_type].shape[0], size)
# Dictionary of node features node_feature = {} node_feature["paper"] = data['feature'] # Dictionary of node labels node_label = {} node_label["paper"] = data['label'] # Load the train, validation and test indices train_idx = {"paper": data['train_idx'].to(args.device)} val_idx = {"paper": data['val_idx'].to(args.device)} test_idx = {"paper": data['test_idx'].to(args.device)} # Construct a deepsnap tensor backend HeteroGraph hetero_graph = HeteroGraph(node_feature=node_feature, node_label=node_label, edge_index=edge_index, directed=True) print( f"ACM heterogeneous graph: {hetero_graph.num_nodes()} nodes, {hetero_graph.num_edges()} edges" ) # Node feature and node label to device for key in hetero_graph.node_feature: hetero_graph.node_feature[key] = hetero_graph.node_feature[key].to( args.device) for key in hetero_graph.node_label: hetero_graph.node_label[key] = hetero_graph.node_label[key].to( args.device) # Edge_index to sparse tensor and to device