def __init__(self, inner_edge_index_i: OptTensor = None,
             x_i: OptTensor = None,
             outer_edge_index_i: OptTensor = None,
             inner_edge_index_j: OptTensor = None,
             x_j: OptTensor = None,
             outer_edge_index_j: OptTensor = None, **kwargs):
    """Initialize a container holding two graphs, ``i`` and ``j``.

    Each side carries a node-feature tensor (``x_*``) and two edge-index
    tensors (``inner_edge_index_*`` and ``outer_edge_index_*``). Every
    named argument is assigned as an attribute, even when it is ``None``.
    Any extra keyword arguments are forwarded to the parent initializer and
    are afterwards assigned as attributes as well.
    """
    super(BipartitePairData, self).__init__(**kwargs)
    # Install a fresh store. Writing straight into `__dict__` bypasses any
    # custom `__setattr__`. This replaces whatever store the parent
    # initializer may have attached.
    self.__dict__['_store'] = GlobalStorage(_parent=self)

    # Assignment order: features of both sides, then the inner edge
    # indices, then the outer edge indices.
    paired_fields = (
        ('x_i', x_i),
        ('x_j', x_j),
        ('inner_edge_index_i', inner_edge_index_i),
        ('inner_edge_index_j', inner_edge_index_j),
        ('outer_edge_index_i', outer_edge_index_i),
        ('outer_edge_index_j', outer_edge_index_j),
    )
    for field_name, field_value in paired_fields:
        setattr(self, field_name, field_value)

    # NOTE(review): kwargs were already handed to the parent initializer
    # above. Because the store was just replaced, they are applied again
    # here so they end up on the new store.
    for field_name, field_value in kwargs.items():
        setattr(self, field_name, field_value)
def __init__(self, x: OptTensor = None, edge_index: OptTensor = None,
             edge_attr: OptTensor = None, y: OptTensor = None,
             pos: OptTensor = None, **kwargs):
    """Build a graph object from five optional standard attributes.

    ``x``, ``edge_index``, ``edge_attr``, ``y`` and ``pos`` are always
    assigned, including when they are ``None``. Every entry of ``kwargs``
    is assigned afterwards as an additional attribute.
    """
    super().__init__()
    # Assigned through normal attribute syntax, so the class's own
    # `__setattr__` (if any) decides where `_store` lives.
    self._store = GlobalStorage(_parent=self)

    standard_attrs = (
        ('x', x),
        ('edge_index', edge_index),
        ('edge_attr', edge_attr),
        ('y', y),
        ('pos', pos),
    )
    for attr_name, attr_value in standard_attrs:
        setattr(self, attr_name, attr_value)

    for attr_name, attr_value in kwargs.items():
        setattr(self, attr_name, attr_value)
def __init__(self, x: OptTensor = None, edge_index: OptTensor = None,
             edge_attr: OptTensor = None, y: OptTensor = None,
             pos: OptTensor = None, **kwargs):
    """Build a graph object that is also set up as a feature/graph store.

    Each of ``x``, ``edge_index``, ``edge_attr``, ``y`` and ``pos`` is
    assigned only when it is not ``None``. Every entry of ``kwargs`` is
    assigned unconditionally afterwards.
    """
    # `Data` has no `group_name`, so a `TensorAttr` variant that does not
    # require one is passed to the parent initializer.
    super().__init__(tensor_attr_cls=DataTensorAttr)
    # `Data` has no `edge_type` either, so the `EdgeAttr` variant used by
    # the graph-store side must not require one.
    GraphStore.__init__(self, edge_attr_cls=DataEdgeAttr)
    # Written directly into `__dict__`, bypassing any custom `__setattr__`.
    self.__dict__['_store'] = GlobalStorage(_parent=self)

    optional_attrs = {
        'x': x,
        'edge_index': edge_index,
        'edge_attr': edge_attr,
        'y': y,
        'pos': pos,
    }
    for key, value in optional_attrs.items():
        if value is None:
            continue
        setattr(self, key, value)

    for key, value in kwargs.items():
        setattr(self, key, value)
def __init__(
    self,
    src: Optional[Tensor] = None,
    dst: Optional[Tensor] = None,
    t: Optional[Tensor] = None,
    msg: Optional[Tensor] = None,
    **kwargs,
):
    """Build an event-style data object from ``src``, ``dst``, ``t``, ``msg``.

    NOTE(review): by naming, these are presumably source nodes, destination
    nodes, timestamps and message features; confirm against callers. All
    four are assigned even when ``None``. Extra keyword arguments are
    assigned afterwards as additional attributes.
    """
    super().__init__()
    # Written directly into `__dict__`, bypassing any custom `__setattr__`.
    self.__dict__['_store'] = GlobalStorage(_parent=self)

    for name, tensor in (('src', src), ('dst', dst), ('t', t), ('msg', msg)):
        setattr(self, name, tensor)

    for name, extra in kwargs.items():
        setattr(self, name, extra)
def __init__(self, x: OptTensor = None, edge_index: OptTensor = None,
             edge_attr: OptTensor = None, y: OptTensor = None,
             pos: OptTensor = None, **kwargs):
    """Build a graph object from optional standard attributes plus extras.

    Each of ``x``, ``edge_index``, ``edge_attr``, ``y`` and ``pos`` is
    stored only when it is not ``None``. Every entry of ``kwargs`` is
    stored unconditionally afterwards.
    """
    super().__init__()
    # Written directly into `__dict__`, bypassing any custom `__setattr__`.
    self.__dict__['_store'] = GlobalStorage(_parent=self)

    candidates = {
        'x': x,
        'edge_index': edge_index,
        'edge_attr': edge_attr,
        'y': y,
        'pos': pos,
    }
    for attr_name, attr_value in candidates.items():
        if attr_value is not None:
            setattr(self, attr_name, attr_value)

    for attr_name, attr_value in kwargs.items():
        setattr(self, attr_name, attr_value)
class Data(BaseData):
    r"""A plain old Python object modeling a single graph with various
    (optional) attributes:

    Args:
        x (Tensor, optional): Node feature matrix with shape :obj:`[num_nodes,
            num_node_features]`. (default: :obj:`None`)
        edge_index (LongTensor, optional): Graph connectivity in COO format
            with shape :obj:`[2, num_edges]`. (default: :obj:`None`)
        edge_attr (Tensor, optional): Edge feature matrix with shape
            :obj:`[num_edges, num_edge_features]`. (default: :obj:`None`)
        y (Tensor, optional): Graph-level or node-level ground-truth labels
            with arbitrary shape. (default: :obj:`None`)
        pos (Tensor, optional): Node position matrix with shape
            :obj:`[num_nodes, num_dimensions]`. (default: :obj:`None`)

    The data object is not restricted to these attributes and can be extended
    by any other additional data.

    Example::

        data = Data(x=x, edge_index=edge_index)
        data.train_idx = torch.tensor([...], dtype=torch.long)
        data.test_mask = torch.tensor([...], dtype=torch.bool)
    """
    def __init__(self, x: OptTensor = None, edge_index: OptTensor = None,
                 edge_attr: OptTensor = None, y: OptTensor = None,
                 pos: OptTensor = None, **kwargs):
        super().__init__()
        # The leading underscore routes this through `__setattr__` into the
        # instance `__dict__` (see `__setattr__` below), not into the store.
        self._store = GlobalStorage(_parent=self)
        self.x = x
        self.edge_index = edge_index
        self.edge_attr = edge_attr
        self.y = y
        self.pos = pos
        for key, value in kwargs.items():
            setattr(self, key, value)

    def __getattr__(self, key: str) -> Any:
        # Only reached when normal attribute lookup fails; delegates to the
        # store. If `_store` has not been attached yet (e.g. an instance
        # created via `__new__` before its `__dict__` is restored, as happens
        # during unpickling), evaluating `self._store` would re-enter
        # `__getattr__` and recurse without bound. Raise `AttributeError`
        # instead so `hasattr`/`getattr(obj, name, default)` behave correctly.
        if '_store' not in self.__dict__:
            raise AttributeError("'{}' object has no attribute '{}'".format(
                self.__class__.__name__, key))
        return getattr(self._store, key)

    def __setattr__(self, key: str, value: Any):
        # `data._* = ...` => Link to the private `__dict__` store.
        # `data.* = ...` => Link to the `_store`.
        if key[:1] == '_':
            self.__dict__[key] = value
        else:
            setattr(self._store, key, value)

    def __delattr__(self, key: str):
        # `del data._*` => Link to the private `__dict__` store.
        # `del data.*` => Link to the `_store`.
        if key[:1] == '_':
            del self.__dict__[key]
        else:
            delattr(self._store, key)

    def __getitem__(self, key: str) -> Any:
        r"""Gets the data of the attribute :obj:`key`."""
        return self._store[key]

    def __setitem__(self, key: str, value: Any):
        r"""Sets the attribute :obj:`key` to :obj:`value`."""
        self._store[key] = value

    def __delitem__(self, key: str):
        r"""Deletes the attribute :obj:`key`."""
        # Previously `self.store`, which does not exist and resolved via
        # `__getattr__` to `getattr(self._store, 'store')`.
        del self._store[key]

    def __copy__(self):
        out = self.__class__()
        # Shallow-copy every private attribute except the store, which gets
        # its own shallow copy re-parented to the new object below.
        for key, value in self.__dict__.items():
            if key != '_store':
                out.__dict__[key] = value
        out._store = copy.copy(self._store)
        out._store._parent = out
        return out

    def __deepcopy__(self, memo):
        out = self.__class__()
        for key, value in self.__dict__.items():
            out.__dict__[key] = copy.deepcopy(value, memo)
        # The deep-copied store must point back at the new object.
        out._store._parent = out
        return out

    def __repr__(self) -> str:
        cls = self.__class__.__name__
        has_dict = any(isinstance(v, Mapping) for v in self._store.values())

        if not has_dict:
            info = [size_repr(k, v) for k, v in self._store.items()]
            return '{}({})'.format(cls, ', '.join(info))
        else:
            info = [size_repr(k, v, indent=2) for k, v in self._store.items()]
            return '{}(\n{}\n)'.format(cls, ',\n'.join(info))

    @property
    def stores(self) -> List[BaseStorage]:
        return [self._store]

    @property
    def node_stores(self) -> List[NodeStorage]:
        return [self._store]

    @property
    def edge_stores(self) -> List[EdgeStorage]:
        return [self._store]

    def to_dict(self) -> Dict[str, Any]:
        r"""Returns a dictionary of stored key/value pairs."""
        return self._store.to_dict()

    def to_namedtuple(self) -> NamedTuple:
        r"""Returns a :obj:`NamedTuple` of stored key/value pairs."""
        return self._store.to_namedtuple()

    def __cat_dim__(self, key: str, value: Any, *args, **kwargs) -> Any:
        r"""Returns the dimension for which :obj:`value` of attribute
        :obj:`key` will get concatenated when creating mini-batches.

        .. note::

            This method is for internal use only, and should only be overridden
            in case the mini-batch creation process is corrupted for a specific
            attribute.
        """
        if isinstance(value, SparseTensor):
            return (0, 1)
        elif 'index' in key or 'face' in key:
            return -1
        else:
            return 0

    def __inc__(self, key: str, value: Any, *args, **kwargs) -> Any:
        r"""Returns the incremental count to cumulatively increase the value
        of the next attribute of :obj:`key` when creating mini-batches.

        .. note::

            This method is for internal use only, and should only be overridden
            in case the mini-batch creation process is corrupted for a specific
            attribute.
        """
        if 'index' in key or 'face' in key:
            return self.num_nodes
        else:
            return 0

    def debug(self):
        pass  # TODO

    ###########################################################################

    @classmethod
    def from_dict(cls, mapping: Dict[str, Any]):
        r"""Creates a :class:`~torch_geometric.data.Data` object from a Python
        dictionary."""
        return cls(**mapping)

    @property
    def num_node_features(self) -> int:
        r"""Returns the number of features per node in the graph."""
        return self._store.num_node_features

    @property
    def num_features(self):
        r"""Alias for :py:attr:`~num_node_features`."""
        return self._store.num_features

    @property
    def num_edge_features(self) -> int:
        r"""Returns the number of features per edge in the graph."""
        return self._store.num_edge_features

    def __iter__(self) -> Iterable:
        r"""Iterates over all attributes in the data, yielding their attribute
        names and values."""
        for key, value in self._store.items():
            yield key, value

    def __call__(self, *args: List[str]) -> Iterable:
        r"""Iterates over all attributes :obj:`*args` in the data, yielding
        their attribute names and values.
        If :obj:`*args` is not given, will iterate over all attributes."""
        for key, value in self._store.items(*args):
            yield key, value

    @property
    def x(self) -> Any:
        return self['x'] if 'x' in self._store else None

    @property
    def edge_index(self) -> Any:
        return self['edge_index'] if 'edge_index' in self._store else None

    @property
    def edge_attr(self) -> Any:
        return self['edge_attr'] if 'edge_attr' in self._store else None

    @property
    def y(self) -> Any:
        return self['y'] if 'y' in self._store else None

    @property
    def pos(self) -> Any:
        return self['pos'] if 'pos' in self._store else None

    # Deprecated functions ####################################################

    @property
    @deprecated(details="use 'data.face.size(-1)' instead")
    def num_faces(self) -> Optional[int]:
        r"""Returns the number of faces in the mesh."""
        if 'face' in self._store:
            return self.face.size(self.__cat_dim__('face', self.face))
        return None