def _custom_split_link_pred_disjoint(self, graph_train):
    r"""
    Build a disjoint-mode training graph from a custom objective-edge split.

    NOTE(review): a second ``def _custom_split_link_pred_disjoint`` appears
    further down in this class (it reads ``graph_train.disjoint_split``
    instead of ``custom_disjoint_split``). The later definition replaces this
    one when the class body executes, so this version is never reachable --
    confirm and delete it.

    Args:
        graph_train: Graph whose ``custom_disjoint_split`` attribute holds the
            objective (supervision) edges.

    Returns:
        A new :class:`Graph` built from the edges of ``graph_train.G`` that
        are not objective edges, with ``_create_label_link_pred`` applied to
        the objective edges.
    """
    objective_edges = graph_train.custom_disjoint_split
    # Message edges: every edge of G that is not an objective edge.
    # NOTE(review): ``set(objective_edges)`` requires hashable edge tuples; if
    # the objective edges carry a trailing attribute dict (as the later
    # definition assumes by stripping ``edge[-1]``) this raises TypeError.
    # The set difference also discards G's edge order.
    message_edges = list(set(graph_train.G.edges) - set(objective_edges))
    # Presumably keeps isolated nodes, per the helper's name -- confirm.
    graph_train = Graph(
        graph_train._edge_subgraph_with_isonodes(
            graph_train.G,
            message_edges,
        ))
    graph_train._create_label_link_pred(graph_train, objective_edges)
    return graph_train
def _custom_split_link_pred(self):
    r"""
    Split each graph in ``self.graphs`` into link-prediction train,
    validation (and test) graphs according to its ``custom_splits``.

    ``custom_splits[0]``, ``[1]`` and ``[2]`` hold the train, validation and
    test objective edges. The number of splits is read from the first graph;
    a test graph is produced only when that number is exactly 3.

    Message edges per split:
      * train and validation graphs are built from the train edges;
      * the test graph is built from the train plus validation edges.

    Returns:
        list: One list per custom split, in the order of ``self.graphs``.
        Index 0 holds train graphs, index 1 validation graphs and, for a
        3-way split, index 2 test graphs.
    """
    num_splits = len(self.graphs[0].custom_splits)
    has_test_split = num_splits == 3
    result = [[] for _ in range(num_splits)]

    for source_graph in self.graphs:
        train_source = copy.copy(source_graph)
        train_edges = train_source.custom_splits[0]
        val_edges = train_source.custom_splits[1]

        train_message_graph = train_source._edge_subgraph_with_isonodes(
            train_source.G,
            train_edges,
        )
        train_graph = Graph(
            train_message_graph,
            custom_disjoint_split=train_source.custom_disjoint_split,
        )
        # Copied before ``_create_label_link_pred`` runs on the train graph.
        val_graph = copy.copy(train_graph)

        if has_test_split:
            test_source = copy.copy(source_graph)
            test_edges = source_graph.custom_splits[2]
            test_graph = Graph(
                test_source._edge_subgraph_with_isonodes(
                    test_source.G,
                    train_edges + val_edges,
                )
            )

        train_graph._create_label_link_pred(train_graph, train_edges)
        val_graph._create_label_link_pred(val_graph, val_edges)
        produced = [train_graph, val_graph]
        if has_test_split:
            test_graph._create_label_link_pred(test_graph, test_edges)
            produced.append(test_graph)

        # Route each produced graph into the bucket for its split.
        for bucket, split_graph in zip(result, produced):
            bucket.append(split_graph)

    return result
def _custom_split_link_pred_disjoint(self, graph_train):
    r"""
    Build the disjoint-mode training graph from a custom objective-edge split.

    The entries of ``graph_train.disjoint_split`` are the objective
    (supervision) edges. Every other edge of ``graph_train.G`` becomes a
    message edge. A new :class:`Graph` is built from the message edges, and
    link-prediction labels are created for the objective edges.

    Args:
        graph_train: Graph whose ``disjoint_split`` attribute lists the
            objective edges. Each entry is an edge tuple whose last element
            is dropped before comparison with ``graph_train.G.edges``
            (presumably the edge-attribute dict -- confirm).

    Returns:
        Graph: the message-edge graph, labelled with the objective edges.
        If every edge is an objective edge, the message-edge list passed to
        ``_edge_subgraph_with_isonodes`` is empty.

    Raises:
        ValueError: if an edge of ``graph_train.G`` does not have 2 or 3
            indices.
    """
    objective_edges = graph_train.disjoint_split
    # Drop the trailing element so objective edges compare equal to the bare
    # index tuples yielded by ``G.edges``.
    objective_edges_no_info = {edge[:-1] for edge in objective_edges}
    graph_edges = graph_train.G.edges
    # Filter G.edges directly (rather than a set difference) so the message
    # edges keep G's deterministic edge order.
    message_edges_no_info = [
        edge for edge in graph_edges if edge not in objective_edges_no_info
    ]

    message_edges = []
    for edge in message_edges_no_info:
        if len(edge) not in (2, 3):
            raise ValueError(
                "Each edge must have 2 or 3 indices, got {}.".format(len(edge))
            )
        # Re-attach the edge data. ``graph_edges[edge]`` is the same lookup
        # as ``graph_edges[edge[0], edge[1](, edge[2])]``.
        message_edges.append((*edge, graph_edges[edge]))

    graph_train = Graph(
        graph_train._edge_subgraph_with_isonodes(
            graph_train.G,
            message_edges,
        )
    )
    graph_train._create_label_link_pred(graph_train, objective_edges)
    return graph_train
def _split_transductive(
    self,
    split_ratio: List[float],
    split_types: List[str] = None,
) -> List[Graph]:
    r"""
    Split the dataset assuming training process is transductive.

    Args:
        split_ratio: Fractions of the data assigned to the train, validation
            (and test) sets. Only used when ``self.general_split_mode`` is
            ``"random"``.
        split_types: Forwarded to :class:`HeteroGraph` splitting and negative
            sampling. Not used for homogeneous graphs.

    Returns:
        list: A list of 3 (2) shallow copies of this dataset. Their
        ``graphs`` attributes hold the train, validation (and test) split
        graphs respectively. The first (train) dataset has
        ``_resample_negatives`` set to ``True``.

    Raises:
        ValueError: If ``self.task`` is ``"graph"``, or if
            ``self.general_split_mode`` is neither ``"custom"`` nor
            ``"random"``.
        TypeError: If a graph is not a :class:`Graph` instance.
    """
    if self.task == "graph":
        raise ValueError('Graph prediction task cannot be transductive')

    # split_graphs[k] is the list of graphs for split k
    # (e.g. [[train graphs], [val graphs], [test graphs]])
    if self.general_split_mode == "custom":
        # Copy the nested lists so the entries replaced below do not overwrite
        # ``self.split_graphs``. Calling this method again then still starts
        # from the original custom split graphs.
        split_graphs = [list(graphs) for graphs in self.split_graphs]
        if self.task == "link_pred":
            # TODO: handle heterogeneous graph in the future
            split_num = len(split_graphs)
            for i in range(len(split_graphs[0])):
                graph_train = split_graphs[0][i]
                graph_val = split_graphs[1][i]
                edges_train = graph_train.custom_split_index
                edges_val = graph_val.custom_split_index
                # Train and validation share the message graph built from
                # the train edges.
                graph_train = Graph(
                    graph_train._edge_subgraph_with_isonodes(
                        graph_train.G,
                        edges_train,
                    )
                )
                graph_val = copy.copy(graph_train)
                if split_num == 3:
                    graph_test = split_graphs[2][i]
                    edges_test = graph_test.custom_split_index
                    # The test graph passes messages over train + val edges.
                    graph_test = Graph(
                        graph_test._edge_subgraph_with_isonodes(
                            graph_test.G,
                            edges_train + edges_val,
                        )
                    )
                graph_train._create_label_link_pred(graph_train, edges_train)
                graph_val._create_label_link_pred(graph_val, edges_val)
                if split_num == 3:
                    graph_test._create_label_link_pred(
                        graph_test, edges_test
                    )
                split_graphs[0][i] = graph_train
                split_graphs[1][i] = graph_val
                # Only write a test graph when a test split exists
                # (a 2-way split has no ``split_graphs[2]``).
                if split_num == 3:
                    split_graphs[2][i] = graph_test
        else:
            for graphs in split_graphs:
                for graph in graphs:
                    if self.task == "node":
                        graph.node_label_index = graph.custom_split_index
                    elif self.task == "edge":
                        graph.edge_label_index = graph._edge_to_index(
                            graph.custom_split_index
                        )
    elif self.general_split_mode == "random":
        split_graphs = []
        for graph in self.graphs:
            if not isinstance(graph, Graph):
                raise TypeError("element in self.graphs of unexpected type")
            if isinstance(graph, HeteroGraph):
                split_graph = graph.split(
                    task=self.task,
                    split_types=split_types,
                    split_ratio=split_ratio,
                    edge_split_mode=self.edge_split_mode,
                )
            else:
                split_graph = graph.split(self.task, split_ratio)
            split_graphs.append(split_graph)
        # Transpose [graph][split] -> [split][graph].
        split_graphs = list(map(list, zip(*split_graphs)))
    else:
        raise ValueError(
            "general_split_mode must be 'custom' or 'random', "
            "got {!r}".format(self.general_split_mode)
        )

    # Disjoint training mode: replace every train graph with the second
    # graph returned by ``split_link_pred`` at ``edge_message_ratio``.
    if self.task == "link_pred" and self.edge_train_mode == "disjoint":
        for i, graph in enumerate(split_graphs[0]):
            if not isinstance(graph, Graph):
                raise TypeError("element in self.graphs of unexpected type")
            if isinstance(graph, HeteroGraph):
                graph = graph.split_link_pred(
                    split_types=split_types,
                    split_ratio=self.edge_message_ratio,
                    edge_split_mode=self.edge_split_mode,
                )[1]
            else:
                graph = graph.split_link_pred(self.edge_message_ratio)[1]
            split_graphs[0][i] = graph

    # list of num_splits datasets
    # (e.g. [train dataset, val dataset, test dataset])
    dataset_return = []
    for graphs in split_graphs:
        dataset_current = copy.copy(self)
        dataset_current.graphs = graphs
        if self.task == "link_pred":
            for graph_temp in dataset_current.graphs:
                if not isinstance(graph_temp, Graph):
                    raise TypeError(
                        "element in self.graphs of unexpected type"
                    )
                if isinstance(graph_temp, HeteroGraph):
                    graph_temp._create_neg_sampling(
                        negative_sampling_ratio=(
                            self.edge_negative_sampling_ratio
                        ),
                        split_types=split_types,
                    )
                else:
                    graph_temp._create_neg_sampling(
                        self.edge_negative_sampling_ratio
                    )
        dataset_return.append(dataset_current)

    # resample negatives for train split (only for link prediction)
    dataset_return[0]._resample_negatives = True
    return dataset_return