def __init__(self, neighbor_sizes, edge_index_dict, num_nodes_dict, node_types, head_node_type):
    """Flatten a heterogeneous graph into one homogeneous graph and set up a neighbor sampler on it.

    Args:
        neighbor_sizes: Forwarded as ``sizes`` to the sampler constructor
            (presumably the number of neighbors sampled per hop — confirm).
        edge_index_dict: Mapping from edge-type key to edge index, forwarded to
            ``group_hetero_graph``; its keys are the edge types kept in
            ``self.int2edge_type``.
        num_nodes_dict: Mapping from node type to number of nodes; only the
            entries for ``node_types`` are read.
        node_types: Node types to include. Their order sets the order of the
            node types passed to ``group_hetero_graph``, except that
            ``head_node_type`` is moved to the front when it is present.
        head_node_type: Node type stored on ``self.head_node_type`` and placed
            first in the grouped graph.
    """
    self.head_node_type = head_node_type

    # Ensure head_node_type is the first item in num_nodes_dict, since (per the
    # original note) NeighborSampler.sample() only takes indices of the first
    # node type. Reorder explicitly instead of relying on the caller's
    # `node_types` order; if head_node_type is absent, keep the given order.
    if head_node_type in node_types:
        ordered_types = [head_node_type] + [ntype for ntype in node_types if ntype != head_node_type]
    else:
        ordered_types = list(node_types)
    num_nodes_dict = OrderedDict([(ntype, num_nodes_dict[ntype]) for ntype in ordered_types])

    (self.edge_index, self.edge_type, self.node_type,
     self.local_node_idx, self.local2global, self.key2int) = group_hetero_graph(edge_index_dict, num_nodes_dict)

    # key2int is split into two reverse lookups: one restricted to the
    # requested node types, one restricted to the edge-type keys.
    self.int2node_type = {
        type_int: node_type for node_type, type_int in self.key2int.items() if node_type in node_types
    }
    self.int2edge_type = {
        type_int: edge_type for edge_type, type_int in self.key2int.items() if edge_type in edge_index_dict
    }

    # NOTE(review): batch_size=128 and shuffle=True are hard-coded. Also, if
    # HeteroNeighborSampler here is the class whose constructor takes
    # (dataset, neighbor_sizes, ...), these keyword arguments will not match —
    # confirm which sampler class is intended.
    self.neighbor_sampler = HeteroNeighborSampler(self.edge_index, node_idx=None, sizes=neighbor_sizes,
                                                  batch_size=128, shuffle=True)
def __init__(self, dataset, neighbor_sizes, node_types=None, metapaths=None, head_node_type=None, directed=True,
             resample_train=None, add_reverse_metapaths=True, inductive=False):
    """Initialize the base dataset, then build a homogeneous graph and a NeighborSampler over it.

    Args:
        dataset: Source dataset, forwarded to the parent constructor.
        neighbor_sizes: Stored on ``self.neighbor_sizes`` and passed as ``sizes``
            to ``NeighborSampler``.
        node_types, metapaths, head_node_type, directed, resample_train,
        add_reverse_metapaths, inductive: Forwarded positionally to the parent
            constructor, which presumably sets ``self.node_types``,
            ``self.head_node_type``, ``self.edge_index_dict``,
            ``self.num_nodes_dict``, ``self.use_reverse`` and
            ``self.training_idx`` used below — confirm.
    """
    self.neighbor_sizes = neighbor_sizes
    super(HeteroNeighborSampler, self).__init__(dataset, node_types, metapaths, head_node_type, directed,
                                                resample_train, add_reverse_metapaths, inductive)

    if self.use_reverse:
        self.add_reverse_edge_index(self.edge_index_dict)

    # Ensure head_node_type is the first item in num_nodes_dict: training_idx is
    # handed to NeighborSampler as node_idx, and (per the original note)
    # NeighborSampler.sample() only takes indices of the first node type.
    # Reorder explicitly rather than relying on the order of self.node_types;
    # if the head type is absent, keep the existing order.
    head = self.head_node_type
    if head in self.node_types:
        ordered_types = [head] + [ntype for ntype in self.node_types if ntype != head]
    else:
        ordered_types = list(self.node_types)
    num_nodes_dict = OrderedDict([(ntype, self.num_nodes_dict[ntype]) for ntype in ordered_types])

    (self.edge_index, self.edge_type, self.node_type,
     self.local_node_idx, self.local2global, self.key2int) = group_hetero_graph(self.edge_index_dict,
                                                                                num_nodes_dict)

    # key2int is split into two reverse lookups: one restricted to the
    # dataset's node types, one restricted to the edge-type keys.
    self.int2node_type = {
        type_int: node_type for node_type, type_int in self.key2int.items() if node_type in self.node_types
    }
    self.int2edge_type = {
        type_int: edge_type for edge_type, type_int in self.key2int.items() if edge_type in self.edge_index_dict
    }

    # NOTE(review): batch_size=128 and shuffle=True are hard-coded.
    self.neighbor_sampler = NeighborSampler(self.edge_index, node_idx=self.training_idx,
                                            sizes=self.neighbor_sizes, batch_size=128, shuffle=True)
# Convert to undirected paper <-> paper relation.
edge_index = to_undirected(edge_index_dict[("paper", "cites", "paper")])
edge_index_dict[("paper", "cites", "paper")] = edge_index

# We convert the individual graphs into a single big one, so that sampling
# neighbors does not need to care about different edge types.
# This will return the following:
# * `edge_index`: The new global edge connectivity.
# * `edge_type`: The edge type for each edge.
# * `node_type`: The node type for each node.
# * `local_node_idx`: The original index for each node.
# * `local2global`: A dictionary mapping original (local) node indices of
#   type `key` to global ones.
# * `key2int`: A dictionary that maps original keys to their new canonical type.
# Group the same dict that was updated above, so the undirected
# paper <-> paper edges are used even if `edge_index_dict` is not an alias
# of `data.edge_index_dict`.
out = group_hetero_graph(edge_index_dict, data.num_nodes_dict)
edge_index, edge_type, node_type, local_node_idx, local2global, key2int = out

num_nodes = node_type.size(0)
homo_data = Data(
    edge_index=edge_index,
    edge_attr=edge_type,
    node_type=node_type,
    local_node_idx=local_node_idx,
    num_nodes=num_nodes,
)

# Only paper nodes receive labels; every other node keeps the placeholder -1.
homo_data.y = node_type.new_full((num_nodes, 1), -1)
homo_data.y[local2global["paper"]] = data.y_dict["paper"]

# Mark the global indices of the training paper nodes.
homo_data.train_mask = torch.zeros(num_nodes, dtype=torch.bool)
homo_data.train_mask[local2global["paper"][split_idx["train"]["paper"]]] = True