Esempio n. 1
0
def filter_edge_store_(store: EdgeStorage,
                       out_store: EdgeStorage,
                       row: Tensor,
                       col: Tensor,
                       index: Tensor,
                       perm: OptTensor = None) -> EdgeStorage:
    """Write the edges selected by `index` from `store` into `out_store`.

    `out_store` is filled in place; `(row, col)` denotes the connectivity of
    the new graph. Entries of `store` are handled by key:

    * ``'edge_index'``: replaced by `row` and `col` stacked along dim 0.
    * ``'adj_t'``: rebuilt as a `SparseTensor` with `col` as its rows and
      `row` as its columns; stored values (if any) are gathered by `index`.
    * any key for which `store.is_edge_attr(key)` holds: gathered along the
      dimension reported by the parent's `__cat_dim__`, using `index`, or
      `perm[index]` when `perm` is given.

    All other keys are skipped.

    Args:
        store: Edge storage to read from.
        out_store: Edge storage that receives the filtered entries.
        row: Row indices of the new graph.
        col: Column indices of the new graph.
        index: Indices of the edges to keep.
        perm: Optional permutation applied to `index` before gathering edge
            attributes.

    Returns:
        The input `store`.
        NOTE(review): `out_store` is the object being filled, yet `store` is
        returned -- confirm this is intended before relying on the result.
    """
    # `row`, `col`, `index` and `perm` are rebound to their device-moved
    # versions below; the rebinding carries over to later loop iterations.
    for name, item in store.items():
        if name == 'edge_index':
            stacked = torch.stack([row, col], dim=0)
            out_store.edge_index = stacked.to(item.device)
            continue

        if name == 'adj_t':
            # NOTE: `(row, col)` is expected to be sorted by `col`
            # (CSC layout).
            # `device` is invoked as a method on the `adj_t` value here.
            target = item.device()
            row = row.to(target)
            col = col.to(target)

            weights = item.storage.value()
            if weights is not None:
                index = index.to(weights.device)
                weights = weights[index]

            # Sizes are taken from `out_store` and reversed for `adj_t`.
            reversed_sizes = out_store.size()[::-1]
            # TODO Currently, we set `is_sorted=False`, see:
            # https://github.com/pyg-team/pytorch_geometric/issues/4346
            out_store.adj_t = SparseTensor(
                row=col,
                col=row,
                value=weights,
                sparse_sizes=reversed_sizes,
                is_sorted=False,
                trust_data=True,
            )
            continue

        if not store.is_edge_attr(name):
            continue

        cat_dim = store._parent().__cat_dim__(name, item, store)
        # Move `perm` first, then `index`, so both live on the attribute's
        # device before they are combined.
        if perm is not None:
            perm = perm.to(item.device)
        index = index.to(item.device)
        selection = index if perm is None else perm[index]
        out_store[name] = index_select(item, selection, dim=cat_dim)

    return store
Esempio n. 2
0
def filter_edge_store_(store: EdgeStorage,
                       out_store: EdgeStorage,
                       row: Tensor,
                       col: Tensor,
                       index: Tensor,
                       perm: OptTensor = None) -> EdgeStorage:
    """Write the edges selected by `index` from `store` into `out_store`.

    `out_store` is filled in place; `(row, col)` denotes the connectivity of
    the new graph. Entries of `store` are handled by key:

    * ``'edge_index'``: replaced by `row` and `col` stacked along dim 0.
    * ``'adj_t'``: rebuilt as a `SparseTensor` with `col` as its rows and
      `row` as its columns; stored values (if any) are gathered by `index`.
    * any key for which `store.is_edge_attr(key)` holds: gathered along
      dim 0, using `index`, or `perm[index]` when `perm` is given.

    All other keys are skipped.

    Args:
        store: Edge storage to read from.
        out_store: Edge storage that receives the filtered entries.
        row: Row indices of the new graph.
        col: Column indices of the new graph.
        index: Indices of the edges to keep.
        perm: Optional permutation applied to `index` before gathering edge
            attributes.

    Returns:
        The input `store`.
        NOTE(review): `out_store` is the object being filled, yet `store` is
        returned -- confirm this is intended before relying on the result.
    """
    # `row`, `col`, `index` and `perm` are rebound to their device-moved
    # versions below; the rebinding carries over to later loop iterations.
    for name, item in store.items():
        if name == 'edge_index':
            stacked = torch.stack([row, col], dim=0)
            out_store.edge_index = stacked.to(item.device)
            continue

        if name == 'adj_t':
            # NOTE: `(row, col)` is expected to be sorted by `col`
            # (CSC layout), which is why `is_sorted=True` is passed below.
            # `device` is invoked as a method on the `adj_t` value here.
            target = item.device()
            row = row.to(target)
            col = col.to(target)

            weights = item.storage.value()
            if weights is not None:
                index = index.to(weights.device)
                weights = weights[index]

            # Sizes are taken from the input `store` and reversed for
            # `adj_t`.
            reversed_sizes = store.size()[::-1]
            out_store.adj_t = SparseTensor(
                row=col,
                col=row,
                value=weights,
                sparse_sizes=reversed_sizes,
                is_sorted=True,
            )
            continue

        if not store.is_edge_attr(name):
            continue

        # Move `perm` first, then `index`, so both live on the attribute's
        # device before they are combined.
        if perm is not None:
            perm = perm.to(item.device)
        index = index.to(item.device)
        selection = index if perm is None else perm[index]
        out_store[name] = index_select(item, selection, dim=0)

    return store