Exemple #1
0
    def __init__(
        self,
        data: HeteroData,
        num_samples: Union[List[int], Dict[NodeType, List[int]]],
        input_nodes: Union[NodeType, Tuple[NodeType, Optional[Tensor]]],
        transform: Optional[Callable] = None,
        **kwargs,
    ):
        """Configure a loader that samples with ``torch_sparse.hgt_sample``.

        Args:
            data: The heterogeneous graph to sample from.
            num_samples: Per-hop sample counts, either one list shared by
                every node type in ``data.node_types`` or a dictionary keyed
                by node type.
            input_nodes: The seed nodes, given either as a node type (all
                nodes of that type are used) or as a ``(node_type, index)``
                pair. ``index`` may be ``None`` (all nodes of that type), a
                boolean mask, or a tensor of node indices.
            transform: Stored as ``self.transform``; it is not called here.
            **kwargs: Forwarded to the parent constructor (presumably
                ``torch.utils.data.DataLoader`` -- confirm against the class
                header). ``collate_fn`` and ``dataset`` are discarded because
                this class supplies both itself.
        """
        if kwargs.get('num_workers', 0) > 0:
            torch.multiprocessing.set_sharing_strategy('file_system')
            kwargs['persistent_workers'] = True

        # The parent constructor receives the dataset positionally and the
        # collate function by keyword below. A caller-supplied duplicate of
        # either would raise "got multiple values for argument", so both are
        # dropped up front.
        kwargs.pop('collate_fn', None)
        kwargs.pop('dataset', None)

        if isinstance(num_samples, (list, tuple)):
            num_samples = {key: num_samples for key in data.node_types}

        # Normalize `input_nodes` to a `(node_type, index_tensor)` pair.
        if isinstance(input_nodes, str):
            input_nodes = (input_nodes, None)
        assert isinstance(input_nodes, (list, tuple))
        assert len(input_nodes) == 2
        assert isinstance(input_nodes[0], str)
        if input_nodes[1] is None:
            # No explicit selection: use every node of the given type.
            index = torch.arange(data[input_nodes[0]].num_nodes)
            input_nodes = (input_nodes[0], index)
        elif input_nodes[1].dtype == torch.bool:
            # Boolean mask: convert to the indices of the `True` entries.
            index = input_nodes[1].nonzero(as_tuple=False).view(-1)
            input_nodes = (input_nodes[0], index)

        self.data = data
        self.num_samples = num_samples
        self.input_nodes = input_nodes
        self.num_hops = max(len(v) for v in num_samples.values())
        self.transform = transform
        self.sample_fn = torch.ops.torch_sparse.hgt_sample

        # Convert the graph data into a suitable format for sampling.
        # NOTE: Since C++ cannot take dictionaries with tuples as key as
        # input, edge type triplets are converted into single strings.
        self.colptr_dict, self.row_dict, self.perm_dict = to_hetero_csc(data)

        super().__init__(input_nodes[1].tolist(),
                         collate_fn=self.sample,
                         **kwargs)
Exemple #2
0
    def __init__(
        self,
        data: Union[Data, HeteroData],
        num_neighbors: NumNeighbors,
        replace: bool = False,
        directed: bool = True,
        transform: Callable = None,
        input_node_type: Optional[str] = None,
    ):
        """Store the sampling options and a CSC view of ``data``.

        Args:
            data: A ``Data`` or ``HeteroData`` graph.
            num_neighbors: Per-hop neighbor counts. For ``Data`` this must be
                a list or tuple. For ``HeteroData`` it may be a list/tuple
                shared by every edge type or a dictionary keyed by edge type.
            replace: Stored as ``self.replace``; not used here.
            directed: Stored as ``self.directed``; not used here.
            transform: Stored as ``self.transform``; not called here.
            input_node_type: Node type of the seed nodes. Must be a string
                when ``data`` is a ``HeteroData``.

        Raises:
            TypeError: If ``data`` is neither ``Data`` nor ``HeteroData``.
        """
        self.data = data
        self.num_neighbors = num_neighbors
        self.replace = replace
        self.directed = directed
        self.transform = transform

        homogeneous = isinstance(data, Data)
        if not (homogeneous or isinstance(data, HeteroData)):
            raise TypeError(f'NeighborLoader found invalid type: {type(data)}')

        if homogeneous:
            # Convert the graph data into a suitable format for sampling.
            self.colptr, self.row, self.perm = to_csc(data)
            assert isinstance(num_neighbors, (list, tuple))
            return

        # Heterogeneous graph from here on.
        # NOTE: Since C++ cannot take dictionaries with tuples as key as
        # input, edge type triplets are converted into single strings.
        self.colptr_dict, self.row_dict, self.perm_dict = to_hetero_csc(data)
        self.node_types, self.edge_types = data.metadata()

        # A single list/tuple of fan-outs is shared by every edge type.
        fanouts = num_neighbors
        if isinstance(fanouts, (list, tuple)):
            fanouts = dict.fromkeys(self.edge_types, fanouts)
        assert isinstance(fanouts, dict)

        self.num_neighbors = {
            edge_type_to_str(edge_type): per_hop
            for edge_type, per_hop in fanouts.items()
        }
        self.num_hops = max(
            len(per_hop) for per_hop in self.num_neighbors.values())

        assert isinstance(input_node_type, str)
        self.input_node_type = input_node_type
Exemple #3
0
    def __init__(
        self,
        data: Union[Data, HeteroData],
        num_neighbors: Union[List[int], Dict[EdgeType, List[int]]],
        input_nodes: Union[Optional[Tensor], NodeType,
                           Tuple[NodeType, Optional[Tensor]]] = None,
        replace: bool = False,
        directed: bool = True,
        transform: Callable = None,
        **kwargs,
    ):
        """Build a neighbor-sampling loader over ``data``.

        Args:
            data: The graph to sample from. A ``Data`` instance takes the
                homogeneous path; anything else is handled as ``HeteroData``.
            num_neighbors: Per-hop neighbor counts. Must be a list/tuple for
                ``Data``. For the heterogeneous path it may be a list/tuple
                shared by every edge type or a dictionary keyed by edge type.
            input_nodes: Seed nodes. For ``Data``: ``None`` (all nodes), a
                boolean mask, or an index tensor. For the heterogeneous path:
                a node type string or a ``(node_type, index)`` pair, where
                ``index`` follows the same three forms.
            replace: Stored as ``self.replace``; not used here.
            directed: Stored as ``self.directed``; not used here.
            transform: Stored as ``self.transform``; not called here.
            **kwargs: Forwarded to the parent constructor after removing
                ``collate_fn`` and ``dataset``, which this class supplies.
        """
        if kwargs.get('num_workers', 0) > 0:
            torch.multiprocessing.set_sharing_strategy('file_system')
            kwargs['persistent_workers'] = True

        # Both of these are provided by this constructor further below.
        for reserved in ('collate_fn', 'dataset'):
            kwargs.pop(reserved, None)

        self.data = data
        self.num_neighbors = num_neighbors
        self.input_nodes = input_nodes
        self.replace = replace
        self.directed = directed
        self.transform = transform

        if isinstance(data, Data):
            self.sample_fn = torch.ops.torch_sparse.neighbor_sample
            # Convert the graph data into a suitable format for sampling.
            self.colptr, self.row, self.perm = to_csc(data)
            assert isinstance(num_neighbors, (list, tuple))
            assert input_nodes is None or isinstance(input_nodes, Tensor)
            if input_nodes is None:
                # No selection given: every node is a seed.
                self.input_nodes = torch.arange(data.num_nodes)
            elif input_nodes.dtype == torch.bool:
                # Boolean mask: keep the indices of the `True` entries.
                self.input_nodes = input_nodes.nonzero(as_tuple=False).view(-1)
            super().__init__(self.input_nodes.tolist(),
                             collate_fn=self.sample,
                             **kwargs)
            return

        # `HeteroData`:
        self.node_types, self.edge_types = data.metadata()
        self.sample_fn = torch.ops.torch_sparse.hetero_neighbor_sample
        # Convert the graph data into a suitable format for sampling.
        # NOTE: Since C++ cannot take dictionaries with tuples as key as
        # input, edge type triplets are converted into single strings.
        self.colptr_dict, self.row_dict, self.perm_dict = to_hetero_csc(data)

        # A single list/tuple of fan-outs is shared by every edge type.
        fanouts = num_neighbors
        if isinstance(fanouts, (list, tuple)):
            fanouts = dict.fromkeys(self.edge_types, fanouts)
        self.num_neighbors = {
            edge_type_to_str(edge_type): per_hop
            for edge_type, per_hop in fanouts.items()
        }
        self.num_hops = max(
            len(per_hop) for per_hop in self.num_neighbors.values())

        # Normalize the seed specification to `(node_type, index)`.
        seeds = input_nodes
        if isinstance(seeds, str):
            seeds = (seeds, None)
        assert isinstance(seeds, (list, tuple))
        assert len(seeds) == 2
        node_type, node_index = seeds
        assert isinstance(node_type, str)

        if node_index is None:
            all_nodes = torch.arange(data[node_type].num_nodes)
            self.input_nodes = (node_type, all_nodes)
        elif node_index.dtype == torch.bool:
            selected = node_index.nonzero(as_tuple=False).view(-1)
            self.input_nodes = (node_type, selected)

        super().__init__(self.input_nodes[1].tolist(),
                         collate_fn=self.hetero_sample,
                         **kwargs)
Exemple #4
0
    def __init__(
        self,
        data: Union[Data, HeteroData],
        num_neighbors: NumNeighbors,
        replace: bool = False,
        directed: bool = True,
        input_type: Optional[Any] = None,
        time_attr: Optional[str] = None,
        is_sorted: bool = False,
        share_memory: bool = False,
    ):
        """Store the sampling options and a CPU CSC view of ``data``.

        Args:
            data: A ``Data`` or ``HeteroData`` graph.
            num_neighbors: Per-hop neighbor counts. Must be a list/tuple for
                ``Data``. For ``HeteroData`` it may be a list/tuple shared by
                every edge type or a dictionary keyed by edge type.
            replace: Stored as ``self.replace``; not used here.
            directed: Stored as ``self.directed``; not used here.
            input_type: Type of the seed entities. Required (not ``None``)
                for ``HeteroData``.
            time_attr: Name of the attribute gathered via
                ``data.collect(time_attr)`` into ``self.node_time_dict``.
                Only supported for ``HeteroData``.
            is_sorted: Passed through to the CSC conversion helper.
            share_memory: Passed through to the CSC conversion helper.

        Raises:
            ValueError: If ``time_attr`` is given together with a ``Data``.
            TypeError: If ``data`` is neither ``Data`` nor ``HeteroData``.
        """
        self.data_cls = data.__class__
        self.num_neighbors = num_neighbors
        self.replace = replace
        self.directed = directed
        self.node_time = None

        homogeneous = isinstance(data, Data)
        if not (homogeneous or isinstance(data, HeteroData)):
            raise TypeError(f'NeighborLoader found invalid type: {type(data)}')

        if homogeneous:
            if time_attr is not None:
                # TODO `time_attr` support for homogeneous graphs
                raise ValueError(
                    f"'time_attr' attribute not yet supported for "
                    f"'{data.__class__.__name__}' object")

            # Convert the graph data into a suitable format for sampling.
            self.colptr, self.row, self.perm = to_csc(
                data,
                device='cpu',
                share_memory=share_memory,
                is_sorted=is_sorted,
            )
            assert isinstance(num_neighbors, (list, tuple))
            return

        # Heterogeneous graph from here on.
        self.node_time_dict = (data.collect(time_attr)
                               if time_attr is not None else None)

        # Convert the graph data into a suitable format for sampling.
        # NOTE: Since C++ cannot take dictionaries with tuples as key as
        # input, edge type triplets are converted into single strings.
        self.colptr_dict, self.row_dict, self.perm_dict = to_hetero_csc(
            data,
            device='cpu',
            share_memory=share_memory,
            is_sorted=is_sorted,
        )

        self.node_types, self.edge_types = data.metadata()

        # A single list/tuple of fan-outs is shared by every edge type.
        fanouts = num_neighbors
        if isinstance(fanouts, (list, tuple)):
            fanouts = dict.fromkeys(self.edge_types, fanouts)
        assert isinstance(fanouts, dict)
        self.num_neighbors = {
            edge_type_to_str(edge_type): per_hop
            for edge_type, per_hop in fanouts.items()
        }

        self.num_hops = max(
            len(per_hop) for per_hop in self.num_neighbors.values())

        assert input_type is not None
        self.input_type = input_type
    def __init__(
        self,
        data: Union[Data, HeteroData, Tuple[FeatureStore, GraphStore]],
        num_neighbors: NumNeighbors,
        replace: bool = False,
        directed: bool = True,
        input_type: Optional[Any] = None,
        time_attr: Optional[str] = None,
        is_sorted: bool = False,
        share_memory: bool = False,
    ):
        """Store the sampling options and a CSC view of the graph.

        Args:
            data: A ``Data`` graph, a ``HeteroData`` graph, or a
                ``(feature_store, graph_store)`` tuple.
            num_neighbors: Per-hop neighbor counts. Must be a list/tuple for
                ``Data``; otherwise it is handed to
                ``self._set_num_neighbors_and_num_hops`` (defined outside
                this block).
            replace: Stored as ``self.replace``; not used here.
            directed: Stored as ``self.directed``; not used here.
            input_type: Type of the seed entities. Required (not ``None``)
                for every input other than ``Data``.
            time_attr: Name of the time attribute collected into
                ``self.node_time_dict``. Not supported for ``Data``.
            is_sorted: Passed to the CSC conversion helper for ``Data`` and
                ``HeteroData`` inputs.
            share_memory: Passed to the CSC conversion helper for ``Data``
                and ``HeteroData`` inputs.

        Raises:
            ValueError: If ``time_attr`` is given together with a ``Data``.
            TypeError: If ``data`` is none of the supported kinds.
        """
        is_graph_object = isinstance(data, (Data, HeteroData))
        self.data_cls = data.__class__ if is_graph_object else 'custom'
        self.num_neighbors = num_neighbors
        self.replace = replace
        self.directed = directed
        self.node_time = None

        # TODO Unify the following conditionals behind the `FeatureStore`
        # and `GraphStore` API

        # `Data`: convert the edge_index to CSC and store it.
        if isinstance(data, Data):
            if time_attr is not None:
                # TODO `time_attr` support for homogeneous graphs
                raise ValueError(
                    f"'time_attr' attribute not yet supported for "
                    f"'{data.__class__.__name__}' object")

            # Convert the graph data into a suitable format for sampling.
            self.colptr, self.row, self.perm = to_csc(
                data, device='cpu', share_memory=share_memory,
                is_sorted=is_sorted)
            assert isinstance(num_neighbors, (list, tuple))
            return

        # `HeteroData`: convert each edge type's edge_index to CSC and
        # store it.
        if isinstance(data, HeteroData):
            self.node_time_dict = (data.collect(time_attr)
                                   if time_attr is not None else None)

            # Convert the graph data into a suitable format for sampling.
            # NOTE: Since C++ cannot take dictionaries with tuples as key as
            # input, edge type triplets are converted into single strings.
            self.colptr_dict, self.row_dict, self.perm_dict = to_hetero_csc(
                data, device='cpu', share_memory=share_memory,
                is_sorted=is_sorted)

            self.node_types, self.edge_types = data.metadata()
            self._set_num_neighbors_and_num_hops(num_neighbors)

            assert input_type is not None
            self.input_type = input_type
            return

        if not isinstance(data, tuple):
            raise TypeError(f'NeighborLoader found invalid type: {type(data)}')

        # `Tuple[FeatureStore, GraphStore]`: obtain edges from the graph
        # store and keep the resulting CSC representations.
        # TODO support `FeatureStore` with no edge types (e.g. `Data`)
        feature_store, graph_store = data

        # TODO support `collect` on `FeatureStore`
        self.node_time_dict = None
        if time_attr is not None:
            # Fetch every feature whose `attr_name` equals `time_attr` with
            # an explicit feature store GET and key the results by group.
            matching_attrs = [
                attr for attr in feature_store.get_all_tensor_attrs()
                if attr.attr_name == time_attr
            ]
            # The index is cleared on each attribute before the fetch.
            for attr in matching_attrs:
                attr.index = None
            fetched = feature_store.multi_get_tensor(matching_attrs)
            self.node_time_dict = {
                attr.group_name: tensor
                for attr, tensor in zip(matching_attrs, fetched)
            }

        # Obtain all node and edge metadata:
        node_attrs = feature_store.get_all_tensor_attrs()
        edge_attrs = graph_store.get_all_edge_attrs()

        self.node_types = list({attr.group_name for attr in node_attrs})
        self.edge_types = list({attr.edge_type for attr in edge_attrs})

        # Set other required parameters:
        self._set_num_neighbors_and_num_hops(num_neighbors)

        assert input_type is not None
        self.input_type = input_type

        def _with_string_keys(mapping):
            # Re-key a per-edge-type mapping by its string form.
            return {
                edge_type_to_str(edge_type): value
                for edge_type, value in mapping.items()
            }

        # Obtain CSC representations for in-memory sampling:
        rows, colptrs, perms = graph_store.csc()
        self.row_dict = _with_string_keys(rows)
        self.colptr_dict = _with_string_keys(colptrs)
        self.perm_dict = _with_string_keys(perms)