def __init__(self, inner_edge_index_i: OptTensor = None,
             x_i: OptTensor = None,
             outer_edge_index_i: OptTensor = None,
             inner_edge_index_j: OptTensor = None,
             x_j: OptTensor = None,
             outer_edge_index_j: OptTensor = None,
             **kwargs):
        r"""Initializes the object from the ``*_i`` / ``*_j`` tensors.

        Args:
            inner_edge_index_i (Tensor, optional): Stored as
                :obj:`inner_edge_index_i`. (default: :obj:`None`)
            x_i (Tensor, optional): Stored as :obj:`x_i`.
                (default: :obj:`None`)
            outer_edge_index_i (Tensor, optional): Stored as
                :obj:`outer_edge_index_i`. (default: :obj:`None`)
            inner_edge_index_j (Tensor, optional): Stored as
                :obj:`inner_edge_index_j`. (default: :obj:`None`)
            x_j (Tensor, optional): Stored as :obj:`x_j`.
                (default: :obj:`None`)
            outer_edge_index_j (Tensor, optional): Stored as
                :obj:`outer_edge_index_j`. (default: :obj:`None`)
            **kwargs: Forwarded to the parent constructor and afterwards
                assigned one by one as attributes of this object.
        """
        super().__init__(**kwargs)

        # Installed directly in the instance `__dict__`, so no `__setattr__`
        # hook is involved. This replaces any `_store` entry that the parent
        # constructor may have put there.
        self.__dict__['_store'] = GlobalStorage(_parent=self)

        # The assignment order below is part of the behavior: node features
        # first, then inner edges, then outer edges. Values are assigned even
        # when they are `None`. `num_nodes` is not assigned here.
        named_tensors = (
            ('x_i', x_i),
            ('x_j', x_j),
            ('inner_edge_index_i', inner_edge_index_i),
            ('inner_edge_index_j', inner_edge_index_j),
            ('outer_edge_index_i', outer_edge_index_i),
            ('outer_edge_index_j', outer_edge_index_j),
        )
        for attr_name, tensor in named_tensors:
            setattr(self, attr_name, tensor)

        # Additional keyword arguments are assigned last.
        for attr_name, extra in kwargs.items():
            setattr(self, attr_name, extra)
示例#2
0
 def __init__(self, x: OptTensor = None, edge_index: OptTensor = None,
              edge_attr: OptTensor = None, y: OptTensor = None,
              pos: OptTensor = None, **kwargs):
     r"""Initializes the object and assigns all given values as attributes.

     Args:
         x (Tensor, optional): Assigned to :obj:`x`. (default: :obj:`None`)
         edge_index (Tensor, optional): Assigned to :obj:`edge_index`.
             (default: :obj:`None`)
         edge_attr (Tensor, optional): Assigned to :obj:`edge_attr`.
             (default: :obj:`None`)
         y (Tensor, optional): Assigned to :obj:`y`. (default: :obj:`None`)
         pos (Tensor, optional): Assigned to :obj:`pos`.
             (default: :obj:`None`)
         **kwargs: Any further attributes, assigned after the five above.
     """
     super().__init__()
     self._store = GlobalStorage(_parent=self)

     # All five standard attributes are assigned unconditionally (`None`
     # included), in this fixed order, followed by the extra keyword
     # arguments.
     standard = (
         ('x', x),
         ('edge_index', edge_index),
         ('edge_attr', edge_attr),
         ('y', y),
         ('pos', pos),
     )
     for attr_name, attr_value in standard:
         setattr(self, attr_name, attr_value)
     for attr_name, attr_value in kwargs.items():
         setattr(self, attr_name, attr_value)
示例#3
0
    def __init__(self, x: OptTensor = None, edge_index: OptTensor = None,
                 edge_attr: OptTensor = None, y: OptTensor = None,
                 pos: OptTensor = None, **kwargs):
        r"""Initializes both parent stores and assigns the given attributes.

        Args:
            x (Tensor, optional): Assigned to :obj:`x` unless :obj:`None`.
                (default: :obj:`None`)
            edge_index (Tensor, optional): Assigned to :obj:`edge_index`
                unless :obj:`None`. (default: :obj:`None`)
            edge_attr (Tensor, optional): Assigned to :obj:`edge_attr` unless
                :obj:`None`. (default: :obj:`None`)
            y (Tensor, optional): Assigned to :obj:`y` unless :obj:`None`.
                (default: :obj:`None`)
            pos (Tensor, optional): Assigned to :obj:`pos` unless
                :obj:`None`. (default: :obj:`None`)
            **kwargs: Any further attributes; these are always assigned, even
                when their value is :obj:`None`.
        """
        # There is no `group_name` concept on `Data`, so a `TensorAttr`
        # variant is passed that does not require it to be set:
        super().__init__(tensor_attr_cls=DataTensorAttr)

        # Likewise there is no `edge_type` on `Data`, so an `EdgeAttr` variant
        # is passed that does not require it to be set:
        GraphStore.__init__(self, edge_attr_cls=DataEdgeAttr)

        # Installed directly in the instance `__dict__`, so no `__setattr__`
        # hook is involved.
        self.__dict__['_store'] = GlobalStorage(_parent=self)

        # Only the standard attributes that were actually supplied get
        # assigned, in this fixed order.
        candidates = (
            ('x', x),
            ('edge_index', edge_index),
            ('edge_attr', edge_attr),
            ('y', y),
            ('pos', pos),
        )
        for attr_name, attr_value in candidates:
            if attr_value is not None:
                setattr(self, attr_name, attr_value)

        for attr_name, attr_value in kwargs.items():
            setattr(self, attr_name, attr_value)
示例#4
0
    def __init__(self, src: Optional[Tensor] = None,
                 dst: Optional[Tensor] = None, t: Optional[Tensor] = None,
                 msg: Optional[Tensor] = None, **kwargs):
        r"""Initializes the object and assigns all given values as attributes.

        Args:
            src (Tensor, optional): Assigned to :obj:`src`.
                (default: :obj:`None`)
            dst (Tensor, optional): Assigned to :obj:`dst`.
                (default: :obj:`None`)
            t (Tensor, optional): Assigned to :obj:`t`.
                (default: :obj:`None`)
            msg (Tensor, optional): Assigned to :obj:`msg`.
                (default: :obj:`None`)
            **kwargs: Any further attributes, assigned after the four above.
        """
        super().__init__()

        # Installed directly in the instance `__dict__`, so no `__setattr__`
        # hook is involved.
        self.__dict__['_store'] = GlobalStorage(_parent=self)

        # The four standard attributes are assigned unconditionally (`None`
        # included), in this fixed order, followed by the extra keyword
        # arguments.
        standard = (('src', src), ('dst', dst), ('t', t), ('msg', msg))
        for attr_name, attr_value in standard:
            setattr(self, attr_name, attr_value)
        for attr_name, attr_value in kwargs.items():
            setattr(self, attr_name, attr_value)
示例#5
0
    def __init__(self,
                 x: OptTensor = None,
                 edge_index: OptTensor = None,
                 edge_attr: OptTensor = None,
                 y: OptTensor = None,
                 pos: OptTensor = None,
                 **kwargs):
        r"""Initializes the object and assigns the given attributes.

        Args:
            x (Tensor, optional): Assigned to :obj:`x` unless :obj:`None`.
                (default: :obj:`None`)
            edge_index (Tensor, optional): Assigned to :obj:`edge_index`
                unless :obj:`None`. (default: :obj:`None`)
            edge_attr (Tensor, optional): Assigned to :obj:`edge_attr` unless
                :obj:`None`. (default: :obj:`None`)
            y (Tensor, optional): Assigned to :obj:`y` unless :obj:`None`.
                (default: :obj:`None`)
            pos (Tensor, optional): Assigned to :obj:`pos` unless
                :obj:`None`. (default: :obj:`None`)
            **kwargs: Any further attributes; these are always assigned, even
                when their value is :obj:`None`.
        """
        super().__init__()

        # Installed directly in the instance `__dict__`, so no `__setattr__`
        # hook is involved.
        self.__dict__['_store'] = GlobalStorage(_parent=self)

        # Standard attributes left at `None` are skipped entirely rather than
        # being assigned as `None`.
        supplied = [
            (attr_name, attr_value)
            for attr_name, attr_value in (
                ('x', x),
                ('edge_index', edge_index),
                ('edge_attr', edge_attr),
                ('y', y),
                ('pos', pos),
            )
            if attr_value is not None
        ]
        for attr_name, attr_value in supplied:
            setattr(self, attr_name, attr_value)

        for attr_name, attr_value in kwargs.items():
            setattr(self, attr_name, attr_value)
示例#6
0
class Data(BaseData):
    r"""A plain old Python object modeling a single graph with various
    (optional) attributes:

    Args:
        x (Tensor, optional): Node feature matrix with shape :obj:`[num_nodes,
            num_node_features]`. (default: :obj:`None`)
        edge_index (LongTensor, optional): Graph connectivity in COO format
            with shape :obj:`[2, num_edges]`. (default: :obj:`None`)
        edge_attr (Tensor, optional): Edge feature matrix with shape
            :obj:`[num_edges, num_edge_features]`. (default: :obj:`None`)
        y (Tensor, optional): Graph-level or node-level ground-truth labels
            with arbitrary shape. (default: :obj:`None`)
        pos (Tensor, optional): Node position matrix with shape
            :obj:`[num_nodes, num_dimensions]`. (default: :obj:`None`)

    The data object is not restricted to these attributes and can be extended
    by any other additional data.

    Example::

        data = Data(x=x, edge_index=edge_index)
        data.train_idx = torch.tensor([...], dtype=torch.long)
        data.test_mask = torch.tensor([...], dtype=torch.bool)
    """
    def __init__(self,
                 x: OptTensor = None,
                 edge_index: OptTensor = None,
                 edge_attr: OptTensor = None,
                 y: OptTensor = None,
                 pos: OptTensor = None,
                 **kwargs):
        super().__init__()
        # The leading underscore makes `__setattr__` place the storage in the
        # instance `__dict__`; every public attribute assigned afterwards is
        # forwarded to this storage.
        self._store = GlobalStorage(_parent=self)
        self.x = x
        self.edge_index = edge_index
        self.edge_attr = edge_attr
        self.y = y
        self.pos = pos
        for key, value in kwargs.items():
            setattr(self, key, value)

    def __getattr__(self, key: str) -> Any:
        r"""Forwards attribute lookups that failed on the object itself to
        the underlying storage."""
        # `__getattr__` only runs after the regular lookup has failed, so
        # being asked for `_store` means it was never set (e.g. the instance
        # was created without running `__init__`). Reading `self._store`
        # below would then re-enter this method without bound, so fail with a
        # regular `AttributeError` instead.
        if key == '_store':
            raise AttributeError("'{}' object has no attribute '_store'".format(
                self.__class__.__name__))
        return getattr(self._store, key)

    def __setattr__(self, key: str, value: Any):
        r"""Sets the attribute :obj:`key` to :obj:`value`."""
        # `data._* = ...` => Link to the private `__dict__` store.
        # `data.* = ...` => Link to the `_store`.
        if key[:1] == '_':
            self.__dict__[key] = value
        else:
            setattr(self._store, key, value)

    def __delattr__(self, key: str):
        r"""Deletes the attribute :obj:`key`."""
        # `del data._*` => Link to the private `__dict__` store.
        # `del data.*` => Link to the `_store`.
        if key[:1] == '_':
            del self.__dict__[key]
        else:
            delattr(self._store, key)

    def __getitem__(self, key: str) -> Any:
        r"""Gets the data of the attribute :obj:`key`."""
        return self._store[key]

    def __setitem__(self, key: str, value: Any):
        r"""Sets the attribute :obj:`key` to :obj:`value`."""
        self._store[key] = value

    def __delitem__(self, key: str):
        r"""Deletes the attribute :obj:`key`."""
        # Must address `_store`, exactly like `__getitem__`/`__setitem__`; a
        # plain `self.store` would be forwarded by `__getattr__` to the
        # storage as a lookup of an attribute named `store`.
        del self._store[key]

    def __copy__(self):
        r"""Returns a shallow copy whose storage is a shallow copy of this
        object's storage, re-parented to the new object."""
        out = self.__class__()
        for key, value in self.__dict__.items():
            if key != '_store':
                out.__dict__[key] = value
        out._store = copy.copy(self._store)
        out._store._parent = out
        return out

    def __deepcopy__(self, memo):
        r"""Returns a deep copy of all private entries (including the
        storage), with the copied storage re-parented to the new object."""
        out = self.__class__()
        for key, value in self.__dict__.items():
            out.__dict__[key] = copy.deepcopy(value, memo)
        out._store._parent = out
        return out

    def __repr__(self) -> str:
        cls = self.__class__.__name__
        has_dict = any(isinstance(v, Mapping) for v in self._store.values())

        # Switch to an indented, one-entry-per-line layout as soon as any
        # stored value is itself a mapping.
        if not has_dict:
            info = [size_repr(k, v) for k, v in self._store.items()]
            return '{}({})'.format(cls, ', '.join(info))
        else:
            info = [size_repr(k, v, indent=2) for k, v in self._store.items()]
            return '{}(\n{}\n)'.format(cls, ',\n'.join(info))

    @property
    def stores(self) -> List[BaseStorage]:
        r"""Returns a list holding the single storage of this object."""
        return [self._store]

    @property
    def node_stores(self) -> List[NodeStorage]:
        r"""Returns a list holding the single storage of this object."""
        return [self._store]

    @property
    def edge_stores(self) -> List[EdgeStorage]:
        r"""Returns a list holding the single storage of this object."""
        return [self._store]

    def to_dict(self) -> Dict[str, Any]:
        r"""Returns a dictionary of stored key/value pairs."""
        return self._store.to_dict()

    def to_namedtuple(self) -> NamedTuple:
        r"""Returns a :obj:`NamedTuple` of stored key/value pairs."""
        return self._store.to_namedtuple()

    def __cat_dim__(self, key: str, value: Any, *args, **kwargs) -> Any:
        r"""Returns the dimension for which :obj:`value` of attribute
        :obj:`key` will get concatenated when creating mini-batches.

        Returns :obj:`(0, 1)` for :obj:`SparseTensor` values, :obj:`-1` for
        keys containing :obj:`"index"` or :obj:`"face"`, and :obj:`0`
        otherwise.

        .. note::

            This method is for internal use only, and should only be overridden
            in case the mini-batch creation process is corrupted for a specific
            attribute.
        """
        if isinstance(value, SparseTensor):
            return (0, 1)
        elif 'index' in key or 'face' in key:
            return -1
        else:
            return 0

    def __inc__(self, key: str, value: Any, *args, **kwargs) -> Any:
        r"""Returns the incremental count to cumulatively increase the value
        of the next attribute of :obj:`key` when creating mini-batches.

        Returns :obj:`self.num_nodes` for keys containing :obj:`"index"` or
        :obj:`"face"`, and :obj:`0` otherwise.

        .. note::

            This method is for internal use only, and should only be overridden
            in case the mini-batch creation process is corrupted for a specific
            attribute.
        """
        if 'index' in key or 'face' in key:
            return self.num_nodes
        else:
            return 0

    def debug(self):
        r"""Placeholder; currently does nothing."""
        pass  # TODO

    ###########################################################################

    @classmethod
    def from_dict(cls, mapping: Dict[str, Any]):
        r"""Creates a :class:`~torch_geometric.data.Data` object from a Python
        dictionary."""
        return cls(**mapping)

    @property
    def num_node_features(self) -> int:
        r"""Returns the number of features per node in the graph."""
        return self._store.num_node_features

    @property
    def num_features(self) -> int:
        r"""Alias for :py:attr:`~num_node_features`."""
        return self._store.num_features

    @property
    def num_edge_features(self) -> int:
        r"""Returns the number of features per edge in the graph."""
        return self._store.num_edge_features

    def __iter__(self) -> Iterable:
        r"""Iterates over all attributes in the data, yielding their attribute
        names and values."""
        for key, value in self._store.items():
            yield key, value

    def __call__(self, *args: str) -> Iterable:
        r"""Iterates over all attributes :obj:`*args` in the data, yielding
        their attribute names and values.
        If :obj:`*args` is not given, will iterate over all attributes."""
        for key, value in self._store.items(*args):
            yield key, value

    @property
    def x(self) -> Any:
        r"""Returns the stored :obj:`x`, or :obj:`None` if it is not set."""
        return self['x'] if 'x' in self._store else None

    @property
    def edge_index(self) -> Any:
        r"""Returns the stored :obj:`edge_index`, or :obj:`None` if it is not
        set."""
        return self['edge_index'] if 'edge_index' in self._store else None

    @property
    def edge_attr(self) -> Any:
        r"""Returns the stored :obj:`edge_attr`, or :obj:`None` if it is not
        set."""
        return self['edge_attr'] if 'edge_attr' in self._store else None

    @property
    def y(self) -> Any:
        r"""Returns the stored :obj:`y`, or :obj:`None` if it is not set."""
        return self['y'] if 'y' in self._store else None

    @property
    def pos(self) -> Any:
        r"""Returns the stored :obj:`pos`, or :obj:`None` if it is not set."""
        return self['pos'] if 'pos' in self._store else None

    # Deprecated functions ####################################################

    @property
    @deprecated(details="use 'data.face.size(-1)' instead")
    def num_faces(self) -> Optional[int]:
        r"""Returns the number of faces in the mesh, or :obj:`None` if no
        :obj:`face` attribute is stored."""
        if 'face' in self._store:
            return self.face.size(self.__cat_dim__('face', self.face))
        return None