def filter_edge_store_(store: EdgeStorage, out_store: EdgeStorage, row: Tensor,
                       col: Tensor, index: Tensor,
                       perm: OptTensor = None) -> EdgeStorage:
    r"""Fills ``out_store`` with the edges of ``store`` selected by ``index``.

    The connectivity of the filtered graph is given by ``(row, col)``;
    ``store`` itself is only read.

    Args:
        store: The source edge storage.
        out_store: The edge storage receiving the filtered attributes.
        row: Source node indices of the filtered graph.
        col: Destination node indices of the filtered graph. For ``adj_t``,
            ``(row, col)`` is expected to be sorted by ``col`` (CSC layout).
        index: Positions of the kept edges.
        perm: Optional permutation. If given, edge attributes are gathered
            at ``perm[index]`` instead of ``index``.

    Returns:
        The populated ``out_store``.
    """
    for key, value in store.items():
        if key == 'edge_index':
            edge_index = torch.stack([row, col], dim=0)
            out_store.edge_index = edge_index.to(value.device)

        elif key == 'adj_t':
            # NOTE: We expect `(row, col)` to be sorted by `col` (CSC layout).
            device = value.device()  # called as a method on `SparseTensor`
            adj_row = row.to(device)
            adj_col = col.to(device)
            edge_attr = value.storage.value()
            if edge_attr is not None:
                edge_attr = edge_attr[index.to(edge_attr.device)]
            # The filtered graph may have different node counts than the
            # original one, so the new adjacency is sized from `out_store`.
            sparse_sizes = out_store.size()[::-1]
            # `(col, row)` is only guaranteed to be sorted by its first entry,
            # not lexicographically, so let `SparseTensor` sort it itself.
            out_store.adj_t = SparseTensor(row=adj_col, col=adj_row,
                                           value=edge_attr,
                                           sparse_sizes=sparse_sizes,
                                           is_sorted=False)

        elif store.is_edge_attr(key):
            # Select along the attribute's concatenation dimension, which is
            # not necessarily 0.
            dim = store._parent().__cat_dim__(key, value, store)
            selection = index.to(value.device)
            if perm is not None:
                selection = perm.to(value.device)[selection]
            out_store[key] = index_select(value, selection, dim=dim)

    return out_store
def _split(self, store: EdgeStorage, index: Tensor, is_undirected: bool,
           rev_edge_type: EdgeType):
    r"""Restricts ``store`` in place to the edges selected by ``index``.

    Args:
        store: The edge storage to filter in place.
        index: Positions of the edges to keep.
        is_undirected: If ``True``, every kept edge is duplicated in reverse
            direction. Its edge attributes are duplicated to match.
        rev_edge_type: Optional edge type of the reverse relation. If given,
            that storage is synchronized with the filtered ``store``: its
            ``edge_index`` becomes the flipped one, keys also present in
            ``store`` are copied over, and all other keys are deleted.

    Returns:
        The (modified) ``store``.
    """
    for key, value in store.items():
        if key == 'edge_index':
            continue

        if store.is_edge_attr(key):
            value = value[index]
            if is_undirected:
                value = torch.cat([value, value], dim=0)
            store[key] = value

    edge_index = store.edge_index[:, index]
    if is_undirected:
        edge_index = torch.cat([edge_index, edge_index.flip([0])], dim=-1)
    store.edge_index = edge_index

    if rev_edge_type is not None:
        rev_store = store._parent()[rev_edge_type]
        # `keys()` may be a live view over the underlying mapping; iterate
        # over a snapshot so the `del` below cannot invalidate the iteration
        # ("dictionary changed size during iteration").
        for key in list(rev_store.keys()):
            if key not in store:
                del rev_store[key]  # We delete all outdated attributes.
            elif key == 'edge_index':
                rev_store.edge_index = store.edge_index.flip([0])
            else:
                rev_store[key] = store[key]

    return store
def get_edge_store(self, src: str, rel: str, dst: str) -> EdgeStorage:
    r"""Returns the :class:`~torch_geometric.data.storage.EdgeStorage` for the
    edge type :obj:`(src, rel, dst)`.

    If no storage exists for this edge type yet, a new
    :class:`torch_geometric.data.storage.EdgeStorage` object is created,
    registered under that edge type and returned.

    .. code-block:: python

        data = HeteroData()
        edge_storage = data.get_edge_store('author', 'writes', 'paper')
    """
    edge_type = (src, rel, dst)
    existing = self._edge_store_dict.get(edge_type)
    if existing is not None:
        return existing

    # First access for this edge type: create and register the storage.
    created = EdgeStorage(_parent=self, _key=edge_type)
    self._edge_store_dict[edge_type] = created
    return created
def filter_edge_store_(store: EdgeStorage, out_store: EdgeStorage, row: Tensor,
                       col: Tensor, index: Tensor,
                       perm: OptTensor = None) -> EdgeStorage:
    r"""Fills ``out_store`` with the edges of ``store`` selected by ``index``.

    The connectivity of the filtered graph is given by ``(row, col)``;
    ``store`` itself is only read. Index tensors are moved to the device of
    each attribute before being used on it.

    Args:
        store: The source edge storage.
        out_store: The edge storage receiving the filtered attributes.
        row: Source node indices of the filtered graph.
        col: Destination node indices of the filtered graph. For ``adj_t``,
            ``(row, col)`` is expected to be sorted by ``col`` (CSC layout).
        index: Positions of the kept edges.
        perm: Optional permutation. If given, edge attributes are gathered
            at ``perm[index]`` instead of ``index``.

    Returns:
        The populated ``out_store``. (Previously the unfiltered input
        ``store`` was returned.)
    """
    for key, value in store.items():
        if key == 'edge_index':
            edge_index = torch.stack([row, col], dim=0)
            out_store.edge_index = edge_index.to(value.device)

        elif key == 'adj_t':
            # NOTE: We expect `(row, col)` to be sorted by `col` (CSC layout).
            device = value.device()  # called as a method on `SparseTensor`
            # Local copies: the function arguments are left untouched so that
            # later loop iterations see the original tensors.
            adj_row = row.to(device)
            adj_col = col.to(device)
            edge_attr = value.storage.value()
            if edge_attr is not None:
                edge_attr = edge_attr[index.to(edge_attr.device)]
            sparse_sizes = out_store.size()[::-1]
            # TODO Currently, we set `is_sorted=False`, see:
            # https://github.com/pyg-team/pytorch_geometric/issues/4346
            out_store.adj_t = SparseTensor(row=adj_col, col=adj_row,
                                           value=edge_attr,
                                           sparse_sizes=sparse_sizes,
                                           is_sorted=False, trust_data=True)

        elif store.is_edge_attr(key):
            # Select along the attribute's concatenation dimension.
            dim = store._parent().__cat_dim__(key, value, store)
            selection = index.to(value.device)
            if perm is not None:
                selection = perm.to(value.device)[selection]
            out_store[key] = index_select(value, selection, dim=dim)

    return out_store