def __init__(self, _mapping: Optional[Dict[str, Any]] = None, **kwargs):
    """Set up the empty stores, then populate them from the given entries.

    Entries come from ``_mapping`` first (if given and non-empty) and from
    ``kwargs`` afterwards. A :class:`Mapping` value is merged via
    ``self[key].update(value)``; for such values a key containing ``'__'``
    is first split into a tuple. Every other value is assigned with
    ``setattr(self, key, value)``.
    """
    self._global_store = BaseStorage(_parent=self)
    self._node_stores_dict = {}
    self._edge_stores_dict = {}

    initial = _mapping if _mapping else {}
    for name, payload in chain(initial.items(), kwargs.items()):
        # Evaluated for every entry, exactly like the combined condition did.
        has_separator = '__' in name

        if not isinstance(payload, Mapping):
            setattr(self, name, payload)
            continue

        # `a__b__c={...}` addresses the store keyed by `('a', 'b', 'c')`.
        target = tuple(name.split('__')) if has_separator else name
        self[target].update(payload)
def test_base_storage():
    """Check attribute handling, construction and copying of `BaseStorage`."""
    store = BaseStorage()
    store.x = torch.zeros(1)
    store.y = torch.ones(1)
    assert len(store) == 2
    assert store.x is not None
    assert store.y is not None

    # Only 'x' and 'y' were assigned; 'z' is requested but never set.
    requested = ('x', 'y', 'z')
    assert len(list(store.keys(*requested))) == 2
    assert len(list(store.keys(*requested))) == 2
    assert len(list(store.values(*requested))) == 2
    assert len(list(store.items(*requested))) == 2

    del store.y
    assert len(store) == 1
    assert store.x is not None

    # Construction from a positional mapping.
    store = BaseStorage({'x': torch.zeros(1)})
    assert len(store) == 1
    assert store.x is not None

    # Construction from keyword arguments.
    store = BaseStorage(x=torch.zeros(1))
    assert len(store) == 1
    assert store.x is not None

    # Shallow copy: a distinct object whose tensor shares the same memory.
    original = BaseStorage(x=torch.zeros(1))
    shallow = copy.copy(original)
    assert original == shallow
    assert original is not shallow
    assert original.x.data_ptr() == shallow.x.data_ptr()
    assert int(original.x) == 0
    assert int(shallow.x) == 0

    # Deep copy: a distinct object whose tensor lives in different memory.
    deep = copy.deepcopy(original)
    assert original == deep
    assert original is not deep
    assert original.x.data_ptr() != deep.x.data_ptr()
    assert int(original.x) == 0
    assert int(deep.x) == 0
class HeteroData(BaseData):
    r"""A Python object modeling a single heterogeneous graph.
    There exist a few ways to create a heterogeneous graph data, *e.g.*:

    * To initialize a node of type :obj:`"paper"` holding a node feature
      matrix :obj:`x_paper` named :obj:`x`:

      .. code-block:: python

        data = HeteroData()
        data['paper'].x = x_paper

        data = HeteroData(paper={ 'x': x_paper })

        data = HeteroData({'paper': { 'x': x_paper }})

    * To initialize an edge from source node type :obj:`"author"` to
      destination node type :obj:`"paper"` with relation type :obj:`"writes"`
      holding a graph representation :obj:`edge_index_author_paper` named
      :obj:`edge_index`:

      .. code-block:: python

        data = HeteroData()
        data['author', 'writes', 'paper'].edge_index = edge_index_author_paper

        data = HeteroData(author__writes__paper={
            'edge_index': edge_index_author_paper
        })

        data = HeteroData({
            ('author', 'writes', 'paper'):
            { 'edge_index': edge_index_author_paper }
        })
    """
    def __init__(self, _mapping: Optional[Dict[str, Any]] = None, **kwargs):
        self._global_store = BaseStorage(_parent=self)
        self._node_stores_dict = {}
        self._edge_stores_dict = {}

        for key, value in chain((_mapping or {}).items(), kwargs.items()):
            # `a__b__c={...}` is shorthand for the key `('a', 'b', 'c')`.
            if '__' in key and isinstance(value, Mapping):
                key = tuple(key.split('__'))

            if isinstance(value, Mapping):
                self[key].update(value)
            else:
                setattr(self, key, value)

    def __getattr__(self, key: str) -> Any:
        # `data.*_dict` => Link to node and edge stores.
        # `data.*` => Link to the `_global_store`.
        # It is the same as using `collect` to collect nodes and edges features
        # and use `attribute` to get graph attribute.

        # `__getattr__` only runs after regular lookup has failed, so a
        # request for one of the private stores means it is not in
        # `__dict__` (e.g. the instance was created without `__init__`).
        # The code below reads those very attributes and would re-enter
        # `__getattr__` without end, so fail with a regular `AttributeError`.
        if key in ('_global_store', '_node_stores_dict', '_edge_stores_dict'):
            raise AttributeError(
                f"'{self.__class__.__name__}' object has no attribute "
                f"'{key}'")

        if key.endswith('_dict'):
            out = self.collect(key[:-5])
            if len(out) > 0:
                return out
        return getattr(self._global_store, key)

    def __setattr__(self, key: str, value: Any):
        """"""
        # `data._* = ...` => Link to the private `__dict__` store.
        # `data.* = ...` => Link to the `_global_store`.
        # NOTE: We aim to prevent duplicates in node or edge keys.
        if key[:1] == '_':
            self.__dict__[key] = value
        else:
            if key in self.node_types:
                raise AttributeError(
                    f"'{key}' is already present as a node type")
            elif key in self.edge_types:
                raise AttributeError(
                    f"'{key}' is already present as an edge type")
            setattr(self._global_store, key, value)

    def __delattr__(self, key: str):
        """"""
        # `del data._*` => Link to the private `__dict__` store.
        # `del data.*` => Link to the `_global_store`.
        if key[:1] == '_':
            del self.__dict__[key]
        else:
            delattr(self._global_store, key)

    def __getitem__(self, *args: Tuple[QueryType]) -> Any:
        # `data[*]` => Link to either `_global_store`, _node_stores_dict` or
        # `_edge_stores_dict`.
        # If neither is present, we create a new `Storage` object for the given
        # node/edge-type.
        key = self._to_canonical(*args)

        out = self._global_store.get(key, None)
        if out is not None:
            return out

        if isinstance(key, tuple):
            return self.get_edge_store(*key)
        else:
            return self.get_node_store(key)

    def __setitem__(self, key: str, value: Any):
        # `data[key] = value` => Link to the `_global_store`, unless `key`
        # is already in use as a node or edge type.
        if key in chain(self.node_types, self.edge_types):
            raise AttributeError(
                f"'{key}' is already present as a node/edge-type")
        self._global_store[key] = value

    def __delitem__(self, *args: Tuple[QueryType]):
        # `del data[*]` => Link to `_node_stores_dict` or `_edge_stores_dict`.
        # Keys that match neither an edge type nor a node type are ignored.
        key = self._to_canonical(*args)
        if isinstance(key, tuple) and key in self.edge_types:
            del self._edge_stores_dict[key]
        elif key in self.node_types:
            del self._node_stores_dict[key]

    def __copy__(self):
        out = self.__class__()
        for key, value in self.__dict__.items():
            out.__dict__[key] = value
        # Every store is copied individually and re-parented to the copy.
        out._global_store = copy.copy(self._global_store)
        out._global_store._parent = out
        out._node_stores_dict = {}
        for key, store in self._node_stores_dict.items():
            out._node_stores_dict[key] = copy.copy(store)
            out._node_stores_dict[key]._parent = out
        out._edge_stores_dict = {}
        for key, store in self._edge_stores_dict.items():
            out._edge_stores_dict[key] = copy.copy(store)
            out._edge_stores_dict[key]._parent = out
        return out

    def __deepcopy__(self, memo):
        out = self.__class__()
        # The node/edge store dictionaries are rebuilt entry by entry below.
        for key, value in self.__dict__.items():
            if key not in ['_node_stores_dict', '_edge_stores_dict']:
                out.__dict__[key] = copy.deepcopy(value, memo)
        out._global_store._parent = out
        out._node_stores_dict = {}
        for key, store in self._node_stores_dict.items():
            out._node_stores_dict[key] = copy.deepcopy(store, memo)
            out._node_stores_dict[key]._parent = out
        out._edge_stores_dict = {}
        for key, store in self._edge_stores_dict.items():
            out._edge_stores_dict[key] = copy.deepcopy(store, memo)
            out._edge_stores_dict[key]._parent = out
        return out

    def __repr__(self) -> str:
        # One entry per global attribute, node store and edge store.
        info1 = [size_repr(k, v, 2) for k, v in self._global_store.items()]
        info2 = [size_repr(k, v, 2) for k, v in self._node_stores_dict.items()]
        info3 = [size_repr(k, v, 2) for k, v in self._edge_stores_dict.items()]
        body = ',\n'.join(info1 + info2 + info3)
        return f'{self.__class__.__name__}(\n{body}\n)'

    def _all_nodes_and_edges(self):
        # Returns all node storage items and edge storage items.
        return chain(self._node_stores_dict.items(),
                     self._edge_stores_dict.items())

    @property
    def stores(self) -> List[BaseStorage]:
        # Return a list of all storages of the graph.
        return ([self._global_store] + list(self.node_stores) +
                list(self.edge_stores))

    @property
    def node_types(self) -> List[NodeType]:
        # Return a list of all node types of the graph.
        return list(self._node_stores_dict.keys())

    @property
    def node_stores(self) -> List[NodeStorage]:
        # Return a list of all node storages of the graph.
        return list(self._node_stores_dict.values())

    @property
    def edge_types(self) -> List[EdgeType]:
        # Return a list of all edge types of the graph.
        return list(self._edge_stores_dict.keys())

    @property
    def edge_stores(self) -> List[EdgeStorage]:
        # Return a list of all edge storages of the graph.
        return list(self._edge_stores_dict.values())

    def to_dict(self) -> Dict[str, Any]:
        # Global attributes at the top level, plus one nested dictionary per
        # node and edge type.
        out = self._global_store.to_dict()
        for key, store in self._all_nodes_and_edges():
            out[key] = store.to_dict()
        return out

    def to_namedtuple(self) -> NamedTuple:
        field_names = list(self._global_store.keys())
        field_values = list(self._global_store.values())
        # Edge-type tuples are joined with '__' to form a single field name.
        field_names += [
            '__'.join(key) if isinstance(key, tuple) else key
            for key in self.node_types + self.edge_types
        ]
        field_values += [
            store.to_namedtuple()
            for store in self.node_stores + self.edge_stores
        ]
        DataTuple = namedtuple('DataTuple', field_names)
        return DataTuple(*field_values)

    def __cat_dim__(self, key: str, value: Any,
                    store: Optional[NodeOrEdgeStorage] = None) -> Any:
        if isinstance(value, SparseTensor):
            return (0, 1)
        elif 'index' in key or 'face' in key:
            return -1
        return 0

    def __inc__(self, key: str, value: Any,
                store: Optional[NodeOrEdgeStorage] = None) -> Any:
        if isinstance(store, EdgeStorage) and 'index' in key:
            return torch.tensor(store.size()).view(2, 1)
        else:
            return 0

    def debug(self):
        pass  # TODO

    ###########################################################################

    def _to_canonical(self, *args: Tuple[QueryType]) -> NodeOrEdgeType:
        # Converts a given `QueryType` to its "canonical type":
        # 1. `relation_type` will get mapped to the unique
        #    `(src_node_type, relation_type, dst_node_type)` tuple.
        # 2. `(src_node_type, dst_node_type)` will get mapped to the unique
        #    `(src_node_type, *, dst_node_type)` tuple, and
        #    `(src_node_type, '_', dst_node_type)` otherwise.
        if len(args) == 1:
            args = args[0]

        if isinstance(args, str):
            # Try to map to edge type based on unique relation type:
            edge_types = [key for key in self.metadata()[1] if key[1] == args]
            if len(edge_types) == 1:
                args = edge_types[0]

        elif len(args) == 2:
            # Try to find the unique source/destination node tuple:
            edge_types = [
                key for key in self.metadata()[1]
                if key[0] == args[0] and key[-1] == args[-1]
            ]
            if len(edge_types) == 1:
                args = edge_types[0]
            else:
                args = (args[0], '_', args[1])

        return args

    def metadata(self) -> Tuple[List[NodeType], List[EdgeType]]:
        # Returns the heterogeneous meta-data, i.e. its node and edge types.
        return self.node_types, self.edge_types

    def collect(self, key: str) -> Dict[NodeOrEdgeType, Any]:
        r"""Collects the attribute :attr:`key` from all node and edge types."""
        mapping = {}
        for subtype, store in self._all_nodes_and_edges():
            if key in store:
                mapping[subtype] = store[key]
        return mapping

    def attribute(self, key: str) -> Any:
        # Get the attribute `key` from `_global_store`.
        return getattr(self._global_store, key)

    def get_node_store(self, key: NodeType) -> NodeStorage:
        r"""Gets the :class:`~torch_geometric.data.storage.NodeStorage` object
        of a particular node type :attr:`key`.
        If the storage is not present yet, will create a new
        :class:`~torch_geometric.data.storage.NodeStorage` object for the given
        node type.

        .. code-block:: python

            data = HeteroData()
            node_storage = data.get_node_store('paper')
        """
        out = self._node_stores_dict.get(key, None)
        if out is None:
            out = NodeStorage(_parent=self, _key=key)
            self._node_stores_dict[key] = out
        return out

    def get_edge_store(self, src: str, rel: str, dst: str) -> EdgeStorage:
        r"""Gets the :class:`~torch_geometric.data.storage.EdgeStorage` object
        of a particular edge type given by the tuple :obj:`(src, rel, dst)`.
        If the storage is not present yet, will create a new
        :class:`~torch_geometric.data.storage.EdgeStorage` object for the given
        edge type.

        .. code-block:: python

            data = HeteroData()
            edge_storage = data.get_edge_store('author', 'writes', 'paper')
        """
        key = (src, rel, dst)
        out = self._edge_stores_dict.get(key, None)
        if out is None:
            out = EdgeStorage(_parent=self, _key=key)
            self._edge_stores_dict[key] = out
        return out