예제 #1
0
파일: dataset.py 프로젝트: aakgun/deepsnap
 def _custom_split_link_pred_disjoint(self, graph_train):
     r"""
     Rebuild a training graph for the disjoint link prediction setting.

     The edges listed in ``graph_train.custom_disjoint_split`` are used as
     the objective edges handed to ``_create_label_link_pred``; every other
     edge of ``graph_train.G`` is kept as an edge of the rebuilt graph.

     Args:
         graph_train: graph providing ``G``, ``custom_disjoint_split`` and
             ``_edge_subgraph_with_isonodes``.

     Returns:
         A new ``Graph`` built from the non-objective edges, after
         ``_create_label_link_pred`` has been called on it.
     """
     supervision_edges = graph_train.custom_disjoint_split
     nx_graph = graph_train.G

     # Everything that is not an objective edge stays in the graph.
     remaining_edges = set(nx_graph.edges) - set(supervision_edges)
     message_subgraph = graph_train._edge_subgraph_with_isonodes(
         nx_graph,
         list(remaining_edges),
     )

     rebuilt = Graph(message_subgraph)
     rebuilt._create_label_link_pred(rebuilt, supervision_edges)
     return rebuilt
예제 #2
0
파일: dataset.py 프로젝트: aakgun/deepsnap
    def _custom_split_link_pred(self):
        r"""
        Build per-split link prediction graphs from each graph's custom splits.

        The number of splits is read from ``custom_splits`` of the first
        graph in ``self.graphs``; a third (test) split is produced only when
        that number is 3. For every graph, the train and validation graphs
        are built from the train edges, and the test graph from the train
        plus validation edges.

        Returns:
            list: one list of graphs per split, ordered train, validation
            (and test).
        """
        split_num = len(self.graphs[0].custom_splits)
        has_test = split_num == 3
        split_graphs = [[] for _ in range(split_num)]

        for graph in self.graphs:
            source = copy.copy(graph)
            edges_train = source.custom_splits[0]
            edges_val = source.custom_splits[1]

            graph_train = Graph(
                source._edge_subgraph_with_isonodes(source.G, edges_train),
                custom_disjoint_split=(source.custom_disjoint_split),
            )
            # The validation graph starts as a shallow copy of the train graph.
            graph_val = copy.copy(graph_train)

            if has_test:
                test_source = copy.copy(graph)
                edges_test = graph.custom_splits[2]
                graph_test = Graph(
                    test_source._edge_subgraph_with_isonodes(
                        test_source.G,
                        edges_train + edges_val,
                    )
                )

            graph_train._create_label_link_pred(graph_train, edges_train)
            graph_val._create_label_link_pred(graph_val, edges_val)
            if has_test:
                graph_test._create_label_link_pred(graph_test, edges_test)

            produced = [graph_train, graph_val]
            if has_test:
                produced.append(graph_test)
            for bucket, split_graph in zip(split_graphs, produced):
                bucket.append(split_graph)

        return split_graphs
예제 #3
0
 def _custom_split_link_pred_disjoint(self, graph_train):
     r"""
     Rebuild a training graph for the disjoint link prediction setting.

     The entries of ``graph_train.disjoint_split`` are the objective edges.
     Their last element is dropped before they are compared against
     ``graph_train.G.edges``; every graph edge that is not matched this way
     is kept, together with its attribute mapping looked up from
     ``graph_train.G.edges``, as an edge of the rebuilt graph.

     Args:
         graph_train: graph providing ``G``, ``disjoint_split`` and
             ``_edge_subgraph_with_isonodes``.

     Returns:
         A new ``Graph`` built from the non-objective edges, after
         ``_create_label_link_pred`` has been called on it with the
         objective edges.

     Raises:
         ValueError: if the first remaining edge has neither 2 nor 3
             indices.
     """
     objective_edges = graph_train.disjoint_split
     nx_graph = graph_train.G

     objective_edges_no_info = [edge[:-1] for edge in objective_edges]
     message_edges_no_info = (
         list(set(nx_graph.edges) - set(objective_edges_no_info)))

     if not message_edges_no_info:
         # Every edge is an objective edge; indexing [0] below would raise
         # an opaque IndexError, so fall through with no message edges.
         # NOTE(review): presumes _edge_subgraph_with_isonodes accepts an
         # empty edge list -- confirm.
         message_edges = []
     elif len(message_edges_no_info[0]) in (2, 3):
         # ``edges[u, v]`` and ``edges[(u, v)]`` are the same subscript, so
         # one comprehension covers both the 2- and 3-index edge forms.
         message_edges = [(*edge, nx_graph.edges[edge])
                          for edge in message_edges_no_info]
     else:
         raise ValueError("Each edge has more than 3 indices.")

     graph_train = Graph(
         graph_train._edge_subgraph_with_isonodes(
             nx_graph,
             message_edges,
         ))
     graph_train._create_label_link_pred(graph_train, objective_edges)
     return graph_train
예제 #4
0
    def _split_transductive(
        self,
        split_ratio: List[float],
        split_types: List[str] = None,
    ) -> List[Graph]:
        r"""
        Split the dataset assuming training process is transductive.

        Args:
            split_ratio: number of data splitted into train, validation
                (and test) set.

        Returns:
            list: A list of 3 (2) lists of :class:`deepsnap.graph.Graph` object corresponding
            to train, validation (and test) set.
        """
        if self.task == "graph":
            raise ValueError('Graph prediction task cannot be transductive')

        # a list of split graphs
        # (e.g. [[train graph, val graph, test graph], ... ])
        if self.general_split_mode == "custom":
            split_graphs = self.split_graphs
            if self.task == "link_pred":
                # TODO: handle heterogeneous graph in the future
                split_num = len(split_graphs)
                for i in range(len(split_graphs[0])):
                    graph_train = split_graphs[0][i]
                    graph_val = split_graphs[1][i]

                    edges_train = graph_train.custom_split_index
                    edges_val = graph_val.custom_split_index

                    graph_train = Graph(
                        graph_train._edge_subgraph_with_isonodes(
                            graph_train.G,
                            edges_train,
                        )
                    )
                    graph_val = copy.copy(graph_train)
                    if split_num == 3:
                        graph_test = split_graphs[2][i]
                        edges_test = graph_test.custom_split_index
                        graph_test = Graph(
                            graph_test._edge_subgraph_with_isonodes(
                                graph_test.G,
                                edges_train + edges_val
                            )
                        )

                    graph_train._create_label_link_pred(
                        graph_train, edges_train
                    )
                    graph_val._create_label_link_pred(
                        graph_val, edges_val
                    )
                    if split_num == 3:
                        graph_test._create_label_link_pred(
                            graph_test, edges_test
                        )

                    split_graphs[0][i] = graph_train
                    split_graphs[1][i] = graph_val
                    split_graphs[2][i] = graph_test

            else:
                for graphs in split_graphs:
                    for graph in graphs:
                        if self.task == "node":
                            graph.node_label_index = graph.custom_split_index
                        if self.task == "edge":
                            graph.edge_label_index = (
                                graph._edge_to_index(graph.custom_split_index)
                            )

        elif self.general_split_mode == "random":
            split_graphs = []
            for graph in self.graphs:
                if isinstance(graph, Graph):
                    if isinstance(graph, HeteroGraph):
                        split_graph = graph.split(
                            task=self.task,
                            split_types=split_types,
                            split_ratio=split_ratio,
                            edge_split_mode=self.edge_split_mode
                        )
                    else:
                        split_graph = graph.split(self.task, split_ratio)
                else:
                    raise TypeError(
                        "element in self.graphs of unexpected type"
                    )
                split_graphs.append(split_graph)
            split_graphs = list(map(list, zip(*split_graphs)))

        for i, graph in enumerate(split_graphs[0]):
            if (
                self.task == "link_pred"
                and self.edge_train_mode == "disjoint"
            ):
                if isinstance(graph, Graph):
                    if isinstance(graph, HeteroGraph):
                        graph = graph.split_link_pred(
                            split_types=split_types,
                            split_ratio=self.edge_message_ratio,
                            edge_split_mode=self.edge_split_mode
                        )[1]
                    else:
                        graph = graph.split_link_pred(
                            self.edge_message_ratio
                        )[1]
                    split_graphs[0][i] = graph
                else:
                    raise TypeError(
                        "element in self.graphs of unexpected type"
                    )

        # list of num_splits datasets
        # (e.g. [train dataset, val dataset, test dataset])
        dataset_return = []
        for x in split_graphs:
            dataset_current = copy.copy(self)
            dataset_current.graphs = x
            if self.task == "link_pred":
                for graph_temp in dataset_current.graphs:
                    if isinstance(graph_temp, Graph):
                        if isinstance(graph_temp, HeteroGraph):
                            graph_temp._create_neg_sampling(
                                negative_sampling_ratio=(
                                    self.edge_negative_sampling_ratio
                                ),
                                split_types=split_types
                            )
                        else:
                            graph_temp._create_neg_sampling(
                                self.edge_negative_sampling_ratio
                            )
                    else:
                        raise TypeError(
                            "element in self.graphs of unexpected type"
                        )
            dataset_return.append(dataset_current)

        # resample negatives for train split (only for link prediction)
        dataset_return[0]._resample_negatives = True
        return dataset_return