def filter_edge_store_(store: EdgeStorage, out_store: EdgeStorage, row: Tensor,
                       col: Tensor, index: Tensor,
                       perm: OptTensor = None) -> EdgeStorage:
    r"""Fills :obj:`out_store` with the edges of :obj:`store` selected by
    :obj:`index`, where the filtered graph's connectivity is given by
    :obj:`(row, col)`.

    NOTE(review): `filter_edge_store_` is defined more than once in this
    file; the later definition shadows this one at import time. Consider
    deleting the duplicate.

    Args:
        store (EdgeStorage): The source edge storage. It is only read here.
        out_store (EdgeStorage): The storage that receives the filtered
            `edge_index` / `adj_t` / edge attributes (mutated in place).
        row (Tensor): First row of the filtered `edge_index`.
        col (Tensor): Second row of the filtered `edge_index`.
        index (Tensor): Positions of the kept edges, used to index the
            edge attributes (through :obj:`perm` when it is given).
        perm (Tensor, optional): If set, edge attributes are gathered at
            :obj:`perm[index]` instead of :obj:`index`.

    Returns:
        EdgeStorage: The input :obj:`store` (not :obj:`out_store`).
    """
    for key, value in store.items():
        if key == 'edge_index':
            edge_index = torch.stack([row, col], dim=0)
            out_store.edge_index = edge_index.to(value.device)

        elif key == 'adj_t':
            # NOTE: We expect `(row, col)` to be sorted by `col` (CSC layout).
            # `device()` is called as a method on the sparse adjacency here.
            device = value.device()
            # Branch-local names keep the device move from leaking into
            # later loop iterations.
            adj_row, adj_col = row.to(device), col.to(device)
            edge_attr = value.storage.value()
            if edge_attr is not None:
                edge_attr = edge_attr[index.to(edge_attr.device)]
            # `adj_t` is the transposed adjacency: its sizes come from the
            # *filtered* storage, reversed, and `row`/`col` are swapped.
            sparse_sizes = out_store.size()[::-1]
            # TODO Currently, we set `is_sorted=False`, see:
            # https://github.com/pyg-team/pytorch_geometric/issues/4346
            out_store.adj_t = SparseTensor(row=adj_col, col=adj_row,
                                           value=edge_attr,
                                           sparse_sizes=sparse_sizes,
                                           is_sorted=False, trust_data=True)

        elif store.is_edge_attr(key):
            # Select along the dimension the parent concatenates this key on,
            # rather than assuming dim 0.
            dim = store._parent().__cat_dim__(key, value, store)
            # Only tensors are guaranteed to expose `.device`; for any other
            # edge-attribute type keep the index on the CPU instead of
            # failing with an AttributeError.
            device = value.device if isinstance(value, Tensor) else 'cpu'
            attr_index = index.to(device)
            if perm is not None:
                attr_index = perm.to(device)[attr_index]
            out_store[key] = index_select(value, attr_index, dim=dim)

    return store
def filter_edge_store_(store: EdgeStorage, out_store: EdgeStorage, row: Tensor,
                       col: Tensor, index: Tensor,
                       perm: OptTensor = None) -> EdgeStorage:
    r"""Fills :obj:`out_store` with the edges of :obj:`store` selected by
    :obj:`index`, where the filtered graph's connectivity is given by
    :obj:`(row, col)`.

    Args:
        store (EdgeStorage): The source edge storage. It is only read here.
        out_store (EdgeStorage): The storage that receives the filtered
            `edge_index` / `adj_t` / edge attributes (mutated in place).
        row (Tensor): First row of the filtered `edge_index`.
        col (Tensor): Second row of the filtered `edge_index`.
        index (Tensor): Positions of the kept edges, used to index the
            edge attributes (through :obj:`perm` when it is given).
        perm (Tensor, optional): If set, edge attributes are gathered at
            :obj:`perm[index]` instead of :obj:`index`.

    Returns:
        EdgeStorage: The input :obj:`store` (not :obj:`out_store`).
    """
    for key, value in store.items():
        if key == 'edge_index':
            edge_index = torch.stack([row, col], dim=0)
            out_store.edge_index = edge_index.to(value.device)

        elif key == 'adj_t':
            # NOTE: We expect `(row, col)` to be sorted by `col` (CSC layout).
            # `device()` is called as a method on the sparse adjacency here.
            device = value.device()
            # Branch-local names keep the device move from leaking into
            # later loop iterations (the parameters are left untouched).
            adj_row, adj_col = row.to(device), col.to(device)
            edge_attr = value.storage.value()
            if edge_attr is not None:
                edge_attr = edge_attr[index.to(edge_attr.device)]
            # `adj_t` is the transposed adjacency: its sizes come from the
            # *filtered* storage, reversed, and `row`/`col` are swapped.
            sparse_sizes = out_store.size()[::-1]
            # TODO Currently, we set `is_sorted=False`, see:
            # https://github.com/pyg-team/pytorch_geometric/issues/4346
            out_store.adj_t = SparseTensor(row=adj_col, col=adj_row,
                                           value=edge_attr,
                                           sparse_sizes=sparse_sizes,
                                           is_sorted=False, trust_data=True)

        elif store.is_edge_attr(key):
            # Select along the dimension the parent concatenates this key on.
            dim = store._parent().__cat_dim__(key, value, store)
            # Only tensors are guaranteed to expose `.device`; for any other
            # edge-attribute type keep the index on the CPU instead of
            # failing with an AttributeError.
            device = value.device if isinstance(value, Tensor) else 'cpu'
            attr_index = index.to(device)
            if perm is not None:
                attr_index = perm.to(device)[attr_index]
            out_store[key] = index_select(value, attr_index, dim=dim)

    return store