Ejemplo n.º 1
0
    def _custom_hete_split_link_pred_disjoint(self, graph_train):
        """Split a heterogeneous training graph into message and objective edges.

        The edges in ``graph_train.disjoint_split`` are the objective
        (supervision) edges. Every edge of ``graph_train.G`` is grouped by its
        message type ``(head node type, edge type, tail node type)``:

        * for a message type that has at least one objective edge, the graph
          edges of that type whose ``(head, tail)`` pair is NOT an objective
          edge become message-passing edges;
        * for a message type with no objective edge at all, every edge of
          that type is used both for message passing and as an objective edge.

        ``graph_train.disjoint_split`` itself is left unmodified.

        Args:
            graph_train: Heterogeneous training graph. Every node of
                ``graph_train.G`` must carry a ``"node_type"`` attribute. Graph
                edges and ``disjoint_split`` entries are expected to be
                ``(head, tail, data)`` tuples whose ``data`` dict has an
                ``"edge_type"`` key.

        Returns:
            A new ``HeteroGraph`` built from the message edges (with
            ``graph_train.negative_edges`` carried over), whose link-prediction
            labels are created from the objective edges.
        """
        node_types = {
            node: data["node_type"]
            for node, data in graph_train.G.nodes(data=True)
        }

        def message_type_of(edge):
            # (head node type, edge type, tail node type)
            return (node_types[edge[0]], edge[-1]["edge_type"],
                    node_types[edge[1]])

        edges_by_type = {}
        for edge in graph_train.G.edges(data=True):
            edges_by_type.setdefault(message_type_of(edge), []).append(edge)

        objective_by_type = {}
        for edge in graph_train.disjoint_split:
            objective_by_type.setdefault(message_type_of(edge),
                                         []).append(edge)

        # Work on a copy: extending the attribute directly would mutate
        # graph_train.disjoint_split in place and leak edges across calls.
        objective_edges = list(graph_train.disjoint_split)
        message_edges = []
        for message_type, edges in edges_by_type.items():
            objectives = objective_by_type.get(message_type)
            if objectives is None:
                # No supervision edge of this type: use every edge of the
                # type for both message passing and supervision.
                message_edges.extend(edges)
                objective_edges.extend(edges)
                continue
            # Objective edges are matched on their (head, tail) pair only;
            # edge attributes are ignored for the comparison.
            objective_pairs = {edge[:-1] for edge in objectives}
            # Filter in graph order rather than iterating a set, so the
            # message-edge order is reproducible across runs.
            message_edges.extend(
                edge for edge in edges if edge[:-1] not in objective_pairs)

        # _edge_subgraph_with_isonodes presumably keeps isolated nodes (per its
        # name) so node sets stay aligned with the original graph.
        message_graph = HeteroGraph(
            graph_train._edge_subgraph_with_isonodes(
                graph_train.G,
                message_edges,
            ),
            negative_edges=graph_train.negative_edges,
        )

        message_graph._create_label_link_pred(
            message_graph, objective_edges,
            list(message_graph.G.nodes(data=True)))

        return message_graph
Ejemplo n.º 2
0
    def _custom_hete_split_link_pred(self):
        """Build train/val(/test) heterogeneous graphs from custom edge splits.

        For every graph in ``self.graphs``, ``general_splits[0]`` and
        ``general_splits[1]`` are the train and validation edges, and
        ``general_splits[2]`` (if present) the test edges. The number of
        splits is read from the first graph only.

        * Train graph: edge subgraph of the train edges, keeping
          ``disjoint_split`` and ``negative_edges``; labelled with the train
          edges.
        * Validation graph: shallow copy of the train graph (same message
          structure); labelled with the validation edges.
        * Test graph (only when there are exactly three splits): edge subgraph
          of train + validation edges, keeping ``negative_edges``; labelled
          with the test edges.

        Returns:
            A list with one inner list per split. Inner list ``k`` holds the
            ``k``-th split graph of every input graph, in ``self.graphs``
            order. Only indices 0, 1 and (with exactly three splits) 2 are
            filled.
        """
        num_splits = len(self.graphs[0].general_splits)
        has_test = num_splits == 3
        split_graphs = [[] for _ in range(num_splits)]

        for graph in self.graphs:
            source = copy.copy(graph)
            train_edges = source.general_splits[0]
            val_edges = source.general_splits[1]

            graph_train = HeteroGraph(
                source._edge_subgraph_with_isonodes(source.G, train_edges),
                disjoint_split=source.disjoint_split,
                negative_edges=source.negative_edges,
            )
            # Validation reuses the training message structure; only its
            # labels differ.
            graph_val = copy.copy(graph_train)

            labelled = [(graph_train, train_edges), (graph_val, val_edges)]
            if has_test:
                test_source = copy.copy(graph)
                test_edges = graph.general_splits[2]
                # Test-time message passing may use train + validation edges.
                graph_test = HeteroGraph(
                    test_source._edge_subgraph_with_isonodes(
                        test_source.G, train_edges + val_edges),
                    negative_edges=test_source.negative_edges,
                )
                labelled.append((graph_test, test_edges))

            for split_idx, (split_graph, label_edges) in enumerate(labelled):
                split_graph._create_label_link_pred(
                    split_graph,
                    label_edges,
                    list(split_graph.G.nodes(data=True)),
                )
                split_graphs[split_idx].append(split_graph)

        return split_graphs