def _custom_hete_split_link_pred_disjoint(self, graph_train):
    """Separate message-passing edges from objective edges for disjoint training.

    Edges are grouped by message type ``(head_node_type, edge_type,
    tail_node_type)``.  For every message type that appears in
    ``graph_train.disjoint_split`` (the objective edges), the objective
    edges of that type are removed from the message-passing edges.  A
    message type with no objective edges keeps all of its edges as message
    edges and also contributes all of them as objective edges.

    Args:
        graph_train: Graph exposing ``G`` (nodes carry a ``"node_type"``
            attribute, edges an ``"edge_type"`` attribute),
            ``disjoint_split`` (iterable of ``(u, v, data)`` edges) and
            ``negative_edges``.  ``disjoint_split`` is copied, never
            extended in place.

    Returns:
        A new ``HeteroGraph`` whose graph is built from the message edges
        via ``_edge_subgraph_with_isonodes`` and whose link-prediction
        labels are created from the objective edges.
    """
    G = graph_train.G
    node_types = {node: data["node_type"] for node, data in G.nodes(data=True)}

    def _group_by_message_type(edges):
        """Bucket ``(u, v, data)`` edges by message type, in first-seen order."""
        groups = {}
        for edge in edges:
            message_type = (
                node_types[edge[0]],
                edge[-1]["edge_type"],
                node_types[edge[1]],
            )
            groups.setdefault(message_type, []).append(edge)
        return groups

    # Copy: this list is extended below, and extending the attribute itself
    # would grow disjoint_split (possibly shared with other graph objects)
    # on every call.
    objective_edges = list(graph_train.disjoint_split)
    edges_by_type = _group_by_message_type(G.edges(data=True))
    objective_by_type = _group_by_message_type(objective_edges)

    message_edges = []
    for message_type, edges in edges_by_type.items():
        if message_type in objective_by_type:
            # Drop objective edges (matched on their (u, v) endpoints) from
            # message passing.  Filtering the graph's own edge list keeps a
            # deterministic order and reuses the original edge data dicts.
            held_out = {edge[:-1] for edge in objective_by_type[message_type]}
            message_edges.extend(
                edge for edge in edges if edge[:-1] not in held_out)
        else:
            # No supervision edges of this type: every edge is used both for
            # message passing and as an objective edge.
            message_edges.extend(edges)
            objective_edges.extend(edges)
    # NOTE(review): endpoints are matched as ordered (u, v) pairs; on an
    # undirected graph an objective edge stored as (v, u) would not be
    # removed from message passing - confirm disjoint_split uses G's
    # edge orientation.

    message_graph = HeteroGraph(
        graph_train._edge_subgraph_with_isonodes(G, message_edges),
        negative_edges=graph_train.negative_edges,
    )
    message_graph._create_label_link_pred(
        message_graph,
        objective_edges,
        list(message_graph.G.nodes(data=True)),
    )
    return message_graph
def _custom_hete_split_link_pred(self):
    """Create per-split heterogeneous link-prediction graphs from custom splits.

    Each graph in ``self.graphs`` carries ``general_splits``: the train and
    validation edge lists, plus a test edge list when there are exactly
    three splits.  For every input graph:

    * the train graph passes messages over the train edges, is labelled
      with them, and is the only split graph given ``disjoint_split``;
    * the validation graph is a shallow copy of the train graph (taken
      before labelling), labelled with the validation edges;
    * the test graph, when present, passes messages over train + validation
      edges and is labelled with the test edges.

    The number of splits is read from ``self.graphs[0]`` only.  With more
    than three splits only train and validation graphs are produced; the
    remaining inner lists stay empty.

    Returns:
        A list of ``len(self.graphs[0].general_splits)`` lists; the i-th
        holds the i-th split graph of every input graph, in input order.
    """
    num_splits = len(self.graphs[0].general_splits)
    with_test = num_splits == 3
    per_split = [[] for _ in range(num_splits)]

    for graph in self.graphs:
        base = copy.copy(graph)
        train_edges, val_edges = base.general_splits[0], base.general_splits[1]

        train_graph = HeteroGraph(
            base._edge_subgraph_with_isonodes(base.G, train_edges),
            disjoint_split=base.disjoint_split,
            negative_edges=base.negative_edges,
        )
        # NOTE(review): shallow copy taken before labelling; assumes
        # _create_label_link_pred rebinds label attributes rather than
        # mutating objects shared with train_graph - confirm.
        val_graph = copy.copy(train_graph)

        labelled = [(train_graph, train_edges), (val_graph, val_edges)]
        if with_test:
            test_base = copy.copy(graph)
            test_edges = graph.general_splits[2]
            # Test-time message passing sees both train and validation edges.
            test_graph = HeteroGraph(
                test_base._edge_subgraph_with_isonodes(
                    test_base.G, train_edges + val_edges),
                negative_edges=test_base.negative_edges,
            )
            labelled.append((test_graph, test_edges))

        for split_graph, split_edges in labelled:
            split_graph._create_label_link_pred(
                split_graph, split_edges, list(split_graph.G.nodes(data=True)))

        for bucket, (split_graph, _) in zip(per_split, labelled):
            bucket.append(split_graph)

    return per_split