示例#1
0
    def __init__(self, neighbor_sizes, edge_index_dict, num_nodes_dict,
                 node_types, head_node_type):
        """Group a heterogeneous graph into one homogeneous graph and build a sampler on it.

        Args:
            neighbor_sizes: Per-layer neighbor sample sizes, passed as ``sizes``
                to the neighbor sampler.
            edge_index_dict: Mapping from edge-type key to edge index. It is
                grouped by ``group_hetero_graph``, and its keys identify which
                ``key2int`` entries are edge types.
            num_nodes_dict: Mapping from node type to node count. Only the
                types listed in ``node_types`` are passed on.
            node_types: Node types to include. Their order is kept, except that
                ``head_node_type`` (if present) is moved to the front.
            head_node_type: Node type that must come first in the node-count
                mapping handed to ``group_hetero_graph``.
        """
        self.head_node_type = head_node_type

        # NeighborSampler.sample() takes indices of the first node type only, so
        # head_node_type must be the first item in num_nodes_dict. Enforce that
        # here instead of relying on the caller to list it first in node_types;
        # the relative order of the remaining types is unchanged.
        ordered_types = ([t for t in node_types if t == head_node_type] +
                         [t for t in node_types if t != head_node_type])
        num_nodes_dict = OrderedDict([(node_type, num_nodes_dict[node_type])
                                      for node_type in ordered_types])

        self.edge_index, self.edge_type, self.node_type, self.local_node_idx, self.local2global, self.key2int = \
            group_hetero_graph(edge_index_dict, num_nodes_dict)

        # key2int contains both node-type keys and edge-type keys; split it into
        # two integer -> key lookups by membership.
        self.int2node_type = {
            type_int: node_type
            for node_type, type_int in self.key2int.items()
            if node_type in node_types
        }
        self.int2edge_type = {
            type_int: edge_type
            for edge_type, type_int in self.key2int.items()
            if edge_type in edge_index_dict
        }

        # NOTE(review): this call uses the (edge_index, node_idx=, sizes=,
        # batch_size=, shuffle=) signature that NeighborSampler is used with
        # elsewhere, while HeteroNeighborSampler is elsewhere constructed from a
        # dataset. Confirm which sampler class is intended here.
        # node_idx=None presumably means "seed from all nodes" — confirm.
        self.neighbor_sampler = HeteroNeighborSampler(self.edge_index,
                                                      node_idx=None,
                                                      sizes=neighbor_sizes,
                                                      batch_size=128,
                                                      shuffle=True)
示例#2
0
    def __init__(self,
                 dataset,
                 neighbor_sizes,
                 node_types=None,
                 metapaths=None,
                 head_node_type=None,
                 directed=True,
                 resample_train=None,
                 add_reverse_metapaths=True,
                 inductive=False):
        """Initialise the heterogeneous dataset and a neighbor sampler over its grouped graph.

        Args:
            dataset: Source dataset, forwarded to the parent constructor.
            neighbor_sizes: Per-layer neighbor sample sizes, stored on the
                instance and passed as ``sizes`` to ``NeighborSampler``.
            node_types, metapaths, head_node_type, directed, resample_train,
            add_reverse_metapaths, inductive: Forwarded unchanged, in this
                order, to the parent constructor.
        """
        # Set before the parent constructor runs, as in the original ordering.
        self.neighbor_sizes = neighbor_sizes
        super().__init__(dataset, node_types, metapaths, head_node_type,
                         directed, resample_train, add_reverse_metapaths,
                         inductive)

        if self.use_reverse:
            self.add_reverse_edge_index(self.edge_index_dict)

        # NeighborSampler.sample() takes indices of the first node type only, so
        # the node counts are laid out in self.node_types order. This relies on
        # self.node_types listing the head node type first (presumably arranged
        # by the parent class — confirm).
        ordered_counts = OrderedDict(
            (ntype, self.num_nodes_dict[ntype]) for ntype in self.node_types)

        grouped = group_hetero_graph(self.edge_index_dict, ordered_counts)
        (self.edge_index, self.edge_type, self.node_type, self.local_node_idx,
         self.local2global, self.key2int) = grouped

        # key2int mixes node-type and edge-type keys; invert it once per
        # membership collection to get integer -> key lookups.
        def _invert_for(members):
            return {
                code: key
                for key, code in self.key2int.items()
                if key in members
            }

        self.int2node_type = _invert_for(self.node_types)
        self.int2edge_type = _invert_for(self.edge_index_dict)

        self.neighbor_sampler = NeighborSampler(
            self.edge_index,
            node_idx=self.training_idx,
            sizes=self.neighbor_sizes,
            batch_size=128,
            shuffle=True,
        )
示例#3
0
# Convert to undirected paper <-> paper relation.
edge_index = to_undirected(edge_index_dict[("paper", "cites", "paper")])
edge_index_dict[("paper", "cites", "paper")] = edge_index

# We convert the individual graphs into a single big one, so that sampling
# neighbors does not need to care about different edge types.
# This will return the following:
# * `edge_index`: The new global edge connectivity.
# * `edge_type`: The edge type for each edge.
# * `node_type`: The node type for each node.
# * `local_node_idx`: The original index for each node.
# * `local2global`: A dictionary mapping original (local) node indices of
#    type `key` to global ones.
# * `key2int`: A dictionary that maps original keys to their new canonical type.
#
# Group `edge_index_dict` (the dict updated just above), not
# `data.edge_index_dict`: if the two are separate objects, grouping the latter
# would silently drop the undirected paper <-> paper conversion.
out = group_hetero_graph(edge_index_dict, data.num_nodes_dict)
edge_index, edge_type, node_type, local_node_idx, local2global, key2int = out

homo_data = Data(
    edge_index=edge_index,
    edge_attr=edge_type,
    node_type=node_type,
    local_node_idx=local_node_idx,
    num_nodes=node_type.size(0),
)

# Every node starts with label -1; only paper nodes receive labels from
# data.y_dict, written at their global indices.
homo_data.y = node_type.new_full((node_type.size(0), 1), -1)
homo_data.y[local2global["paper"]] = data.y_dict["paper"]

# Training mask over all global nodes: the local paper indices from the train
# split are mapped to global indices before being set.
homo_data.train_mask = torch.zeros(node_type.size(0), dtype=torch.bool)
homo_data.train_mask[local2global["paper"][split_idx["train"]["paper"]]] = True