Exemplo n.º 1
0
def filter_edge_store_(store: EdgeStorage, out_store: EdgeStorage, row: Tensor,
                       col: Tensor, index: Tensor,
                       perm: OptTensor = None) -> EdgeStorage:
    r"""Copies the edge-level entries of :obj:`store` into :obj:`out_store`,
    keeping only the edges selected by :obj:`index`.

    The connectivity of the filtered graph is given by :obj:`(row, col)`.
    Entries of :obj:`store` that are neither :obj:`'edge_index'`,
    :obj:`'adj_t'` nor an edge attribute are skipped.

    Args:
        store: The edge storage to read from.
        out_store: The edge storage that receives the filtered entries.
        row: Row indices of the edges in the new graph.
        col: Column indices of the edges in the new graph.
        index: Positions of the edges to keep.
        perm: Optional permutation; if given, edge attributes are gathered
            at :obj:`perm[index]` rather than at :obj:`index`.

    Returns:
        The input :obj:`store` (not :obj:`out_store`).
    """
    for name, data in store.items():
        if name == 'edge_index':
            out_store.edge_index = torch.stack([row, col], dim=0)
            continue

        if name == 'adj_t':
            # NOTE: `(row, col)` is expected to be sorted by `col`
            # (CSC layout), which is why `is_sorted=True` is passed below.
            weights = data.storage.value()
            if weights is not None:
                weights = weights[index]
            transposed_sizes = store.size()[::-1]
            out_store.adj_t = SparseTensor(row=col, col=row, value=weights,
                                           sparse_sizes=transposed_sizes,
                                           is_sorted=True)
            continue

        if store.is_edge_attr(name):
            # Without a permutation the edge positions are used directly.
            selector = index if perm is None else perm[index]
            out_store[name] = index_select(data, selector, dim=0)

    return store
Exemplo n.º 2
0
    def _split(self, store: EdgeStorage, index: Tensor, is_undirected: bool,
               rev_edge_type: EdgeType):
        r"""Restricts :obj:`store` in-place to the edges selected by
        :obj:`index` and mirrors the result into the reverse edge store.

        Args:
            store: The edge storage to modify in-place.
            index: Positions of the edges to keep.
            is_undirected: If :obj:`True`, every kept edge is followed by its
                flipped counterpart in :obj:`edge_index`, and every edge
                attribute is duplicated accordingly.
            rev_edge_type: Key of the reverse edge store inside
                :obj:`store._parent()`, or :obj:`None` if there is none.

        Returns:
            The modified :obj:`store`.
        """
        # Edge attributes are filtered first, while `edge_index` still holds
        # its original value; `edge_index` itself is replaced afterwards.
        for key, value in store.items():
            if key == 'edge_index':
                continue

            if store.is_edge_attr(key):
                value = value[index]
                if is_undirected:
                    value = torch.cat([value, value], dim=0)
                store[key] = value

        edge_index = store.edge_index[:, index]
        if is_undirected:
            edge_index = torch.cat([edge_index, edge_index.flip([0])], dim=-1)
        store.edge_index = edge_index

        if rev_edge_type is not None:
            rev_store = store._parent()[rev_edge_type]
            # Iterate over a snapshot of the keys: entries are deleted inside
            # the loop, and removing from a mapping while iterating its live
            # key view raises a `RuntimeError`.
            for key in list(rev_store.keys()):
                if key not in store:
                    del rev_store[key]  # We delete all outdated attributes.
                elif key == 'edge_index':
                    rev_store.edge_index = store.edge_index.flip([0])
                else:
                    rev_store[key] = store[key]

        return store
Exemplo n.º 3
0
    def get_edge_store(self, src: str, rel: str, dst: str) -> EdgeStorage:
        r"""Returns the
        :class:`~torch_geometric.data.storage.EdgeStorage` registered under
        the edge type :obj:`(src, rel, dst)`.
        When no storage is registered for that edge type yet, a fresh
        :class:`torch_geometric.data.storage.EdgeStorage` is created,
        registered under the edge type, and returned.

        .. code-block:: python

            data = HeteroData()
            edge_storage = data.get_edge_store('author', 'writes', 'paper')
        """
        edge_type = (src, rel, dst)
        stores = self._edge_store_dict
        existing = stores.get(edge_type, None)
        if existing is not None:
            return existing

        # First access for this edge type: create and register the storage.
        created = EdgeStorage(_parent=self, _key=edge_type)
        stores[edge_type] = created
        return created
Exemplo n.º 4
0
def filter_edge_store_(store: EdgeStorage,
                       out_store: EdgeStorage,
                       row: Tensor,
                       col: Tensor,
                       index: Tensor,
                       perm: OptTensor = None) -> EdgeStorage:
    r"""Copies the edge-level entries of :obj:`store` into :obj:`out_store`,
    keeping only the edges selected by :obj:`index`.

    The connectivity of the filtered graph is given by :obj:`(row, col)`.
    Index tensors are moved to the device of each stored value before they
    are used. Entries of :obj:`store` that are neither :obj:`'edge_index'`,
    :obj:`'adj_t'` nor an edge attribute are skipped.

    Args:
        store: The edge storage to read from.
        out_store: The edge storage that receives the filtered entries.
        row: Row indices of the edges in the new graph.
        col: Column indices of the edges in the new graph.
        index: Positions of the edges to keep.
        perm: Optional permutation; if given, edge attributes are gathered
            at :obj:`perm[index]` rather than at :obj:`index`.

    Returns:
        The input :obj:`store` (not :obj:`out_store`).
    """
    for name, data in store.items():
        if name == 'edge_index':
            stacked = torch.stack([row, col], dim=0)
            out_store.edge_index = stacked.to(data.device)
            continue

        if name == 'adj_t':
            # NOTE: `(row, col)` is expected to be sorted by `col`
            # (CSC layout).
            row, col = row.to(data.device()), col.to(data.device())
            weights = data.storage.value()
            if weights is not None:
                index = index.to(weights.device)
                weights = weights[index]
            transposed_sizes = out_store.size()[::-1]
            # TODO Currently, we set `is_sorted=False`, see:
            # https://github.com/pyg-team/pytorch_geometric/issues/4346
            out_store.adj_t = SparseTensor(row=col,
                                           col=row,
                                           value=weights,
                                           sparse_sizes=transposed_sizes,
                                           is_sorted=False,
                                           trust_data=True)
            continue

        if not store.is_edge_attr(name):
            continue

        # The parent object decides along which dimension this attribute
        # is indexed.
        cat_dim = store._parent().__cat_dim__(name, data, store)
        if perm is not None:
            perm = perm.to(data.device)
        index = index.to(data.device)
        selector = index if perm is None else perm[index]
        out_store[name] = index_select(data, selector, dim=cat_dim)

    return store