Exemple #1
0
def reduction(src: SparseTensor, dim: Optional[int] = None,
              reduce: str = 'sum') -> torch.Tensor:
    """Reduce the stored entries of ``src``, optionally along one dimension.

    When ``src`` stores no explicit values, every stored entry is treated as
    an implicit ``1``: ``'sum'``/``'add'`` count entries, while ``'mean'``,
    ``'min'`` and ``'max'`` yield ones.

    Args:
        src: The sparse tensor to reduce.
        dim: Dimension to reduce over. ``None`` reduces everything to a
            scalar tensor. A negative value is offset by ``src.dim()``.
            ``0`` produces one result per column, ``1`` one result per row,
            and larger values reduce dimension ``dim - 1`` of the value
            tensor.
        reduce: Name of the reduction. The branches implemented here accept
            ``'sum'``, ``'add'``, ``'mean'``, ``'min'`` and ``'max'``. For
            ``dim`` 0 and 1 with explicit values the name is forwarded
            unchanged to ``scatter`` / ``segment_csr``.

    Returns:
        The reduced tensor.

    Raises:
        ValueError: If ``reduce`` is not handled by the selected branch, or
            if ``dim`` cannot be reduced (for example a value dimension was
            requested but ``src`` stores no values).
    """
    value = src.storage.value()

    if dim is None:
        if value is not None:
            if reduce == 'sum' or reduce == 'add':
                return value.sum()
            elif reduce == 'mean':
                return value.mean()
            elif reduce == 'min':
                return value.min()
            elif reduce == 'max':
                return value.max()
            else:
                raise ValueError(f"Unsupported reduce operation '{reduce}'.")
        else:
            # Implicit ones: the total is the number of stored entries and
            # any other statistic of all-one data is one.
            if reduce == 'sum' or reduce == 'add':
                return torch.tensor(src.nnz(), dtype=src.dtype(),
                                    device=src.device())
            elif reduce == 'mean' or reduce == 'min' or reduce == 'max':
                return torch.tensor(1, dtype=src.dtype(), device=src.device())
            else:
                raise ValueError(f"Unsupported reduce operation '{reduce}'.")
    else:
        if dim < 0:
            dim = src.dim() + dim

        if dim == 0 and value is not None:
            # Reducing over rows yields one entry per column, so the output
            # is sized by the number of columns, and the requested reduction
            # must be forwarded (scatter would otherwise always sum).
            col = src.storage.col()
            return scatter(value, col, dim=0, dim_size=src.size(1),
                           reduce=reduce)
        elif dim == 1 and value is not None:
            return segment_csr(value, src.storage.rowptr(), None, reduce)
        elif (dim == 0 or dim == 1) and value is None:
            if reduce == 'sum' or reduce == 'add':
                # Summing implicit ones is the per-column / per-row count.
                if dim == 0:
                    count = src.storage.colcount()
                else:
                    count = src.storage.rowcount()
                return count.to(src.dtype())
            elif reduce == 'mean' or reduce == 'min' or reduce == 'max':
                # `1 - dim` selects the dimension that is kept. The result
                # is placed on the same device as `src`.
                return torch.ones(src.size(1 - dim), dtype=src.dtype(),
                                  device=src.device())
            else:
                raise ValueError(f"Unsupported reduce operation '{reduce}'.")
        elif dim > 1 and value is not None:
            # Dense value dimensions: sparse dims 0 and 1 share value dim 0,
            # hence the `dim - 1` offset.
            if reduce == 'sum' or reduce == 'add':
                return value.sum(dim=dim - 1)
            elif reduce == 'mean':
                return value.mean(dim=dim - 1)
            elif reduce == 'min':
                return value.min(dim=dim - 1)[0]
            elif reduce == 'max':
                return value.max(dim=dim - 1)[0]
            else:
                raise ValueError(f"Unsupported reduce operation '{reduce}'.")
        else:
            raise ValueError(
                f'Cannot reduce over dimension {dim} of the given tensor.')
Exemple #2
0
def partition(
        src: SparseTensor,
        num_parts: int,
        recursive: bool = False,
        weighted=False) -> Tuple[SparseTensor, torch.Tensor, torch.Tensor]:
    """Cluster the rows of ``src`` into ``num_parts`` parts and reorder it.

    Args:
        src: The sparse tensor to partition.
        num_parts: Number of parts; must be at least 1.
        recursive: Forwarded to ``torch.ops.torch_sparse.partition``.
        weighted: If truthy and ``src`` stores values, the flattened values
            are passed to the partitioner as weights (floating-point values
            are first converted with ``weight2metis``).

    Returns:
        A tuple ``(out, partptr, perm)``: ``out`` is ``src`` permuted by
        ``perm``, ``perm`` is the ordering that sorts the cluster ids, and
        ``partptr`` is built by ``ind2ptr`` from the sorted cluster ids.
    """
    assert num_parts >= 1

    device = src.device()

    if num_parts == 1:
        # A single part keeps the original order.
        partptr = torch.tensor([0, src.size(0)], device=device)
        perm = torch.arange(src.size(0), device=device)
        return src, partptr, perm

    rowptr, col, value = src.csr()
    # The partition op is fed CPU tensors.
    rowptr = rowptr.cpu()
    col = col.cpu()

    edge_weight = None
    if weighted and value is not None:
        assert value.numel() == col.numel()
        edge_weight = value.view(-1).detach().cpu()
        if edge_weight.is_floating_point():
            edge_weight = weight2metis(edge_weight)

    assignment = torch.ops.torch_sparse.partition(rowptr, col, edge_weight,
                                                  num_parts, recursive)

    # Sorting the cluster ids yields the permutation grouping rows by part.
    sorted_assignment, perm = assignment.to(device).sort()
    out = permute(src, perm)
    partptr = torch.ops.torch_sparse.ind2ptr(sorted_assignment, num_parts)

    return out, partptr, perm
Exemple #3
0
def set_diag(src: SparseTensor,
             values: Optional[torch.Tensor] = None,
             k: int = 0) -> SparseTensor:
    """Return a copy of ``src`` whose ``k``-th diagonal is fully populated.

    Existing entries on the ``k``-th diagonal are first dropped via
    ``remove_diag``; a fresh diagonal is then merged into the remaining
    entries.

    Args:
        src: The sparse tensor to modify.
        values: Values written to the new diagonal entries. Only used when
            ``src`` stores values; if omitted, ones are written instead.
        k: Diagonal offset. New entries sit at ``(i, i + k)``; for negative
            ``k`` the first row is ``-k``, otherwise it is ``0``.

    Returns:
        A new sparse tensor built from a fresh ``SparseStorage``.
    """
    src = remove_diag(src, k=k)
    row, col, value = src.coo()

    # `mask` has one slot per entry of the merged result: True slots receive
    # the existing entries, False slots receive the new diagonal entries
    # (see the indexed assignments below).
    # NOTE(review): exact semantics of `non_diag_mask` are not visible here.
    mask = torch.ops.torch_sparse.non_diag_mask(row, col, src.size(0),
                                                src.size(1), k)
    inv_mask = ~mask

    # First diagonal row and the number of diagonal entries to insert.
    start, num_diag = -k if k < 0 else 0, mask.numel() - row.numel()
    diag = torch.arange(start, start + num_diag, device=row.device)

    new_row = row.new_empty(mask.size(0))
    new_row[mask] = row
    new_row[inv_mask] = diag

    new_col = col.new_empty(mask.size(0))
    new_col[mask] = col
    # In-place shift turns the diagonal row indices into column indices;
    # this must happen after `new_row` has been filled from `diag`.
    new_col[inv_mask] = diag.add_(k)

    # NOTE(review): when `src` stores no values, `new_value` stays None and
    # a user-supplied `values` is not used — confirm this is intended.
    new_value: Optional[torch.Tensor] = None
    if value is not None:
        new_value = value.new_empty((mask.size(0), ) + value.size()[1:])
        new_value[mask] = value
        if values is not None:
            new_value[inv_mask] = values
        else:
            # NOTE(review): the fill is 1-D of length `num_diag`; presumably
            # relies on broadcasting if `value` has trailing dims — verify.
            new_value[inv_mask] = torch.ones((num_diag, ),
                                             dtype=value.dtype,
                                             device=value.device)

    # Cached per-row counts, if present, gain one entry for each row that
    # received a diagonal element.
    rowcount = src.storage._rowcount
    if rowcount is not None:
        rowcount = rowcount.clone()
        rowcount[start:start + num_diag] += 1

    # Same for cached per-column counts, shifted by `k`.
    colcount = src.storage._colcount
    if colcount is not None:
        colcount = colcount.clone()
        colcount[start + k:start + num_diag + k] += 1

    # Pointer and permutation caches are not carried over (passed as None).
    storage = SparseStorage(row=new_row,
                            rowptr=None,
                            col=new_col,
                            value=new_value,
                            sparse_sizes=src.sparse_sizes(),
                            rowcount=rowcount,
                            colptr=None,
                            colcount=colcount,
                            csr2csc=None,
                            csc2csr=None,
                            is_sorted=True)
    return src.from_storage(storage)
Exemple #4
0
def get_diag(src: SparseTensor) -> Tensor:
    """Return the main diagonal of ``src`` as a dense tensor.

    The result has ``min(src.size(0), src.size(1))`` entries along its first
    dimension (trailing value dimensions are kept). Diagonal positions that
    are not stored in ``src`` are zero. If ``src`` stores no values, stored
    entries count as ones.

    Args:
        src: The sparse tensor to read the diagonal from.

    Returns:
        The dense diagonal.
    """
    row, col, value = src.coo()

    if value is None:
        # Implicit ones. They must live on the same device as the indices,
        # otherwise the masked lookup and indexed assignment below mix
        # devices.
        value = torch.ones(row.size(0), device=row.device)

    sizes = list(value.size())
    sizes[0] = min(src.size(0), src.size(1))

    out = value.new_zeros(sizes)

    # Stored entries with equal row and column index lie on the diagonal.
    mask = row == col
    out[row[mask]] = value[mask]
    return out
Exemple #5
0
def add_(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
    """Add a row-wise or column-wise dense tensor to the entries of ``src``.

    Args:
        src: The sparse tensor whose value is replaced via ``set_value_``.
        other: Either of size ``(src.size(0), 1, ...)`` (one addend per
            row) or ``(1, src.size(1), ...)`` (one addend per column).

    Returns:
        The result of ``src.set_value_`` with the updated values.

    Raises:
        ValueError: If ``other`` matches neither accepted size.
    """
    rowptr, col, value = src.csr()
    num_rows, num_cols = src.size(0), src.size(1)

    if other.size(0) == num_rows and other.size(1) == 1:
        # Row-wise: expand one addend per row to one addend per entry.
        addend = gather_csr(other.squeeze(1), rowptr)
    elif other.size(0) == 1 and other.size(1) == num_cols:
        # Col-wise: look up each entry's addend by its column index.
        addend = other.squeeze(0)[col]
    else:
        raise ValueError(
            f'Size mismatch: Expected size ({num_rows}, 1, ...) or '
            f'(1, {num_cols}, ...), but got size {other.size()}.')

    if value is None:
        # Entries without explicit values count as ones.
        updated = addend.add_(1)
    else:
        updated = value.add_(addend.to(value.dtype))

    return src.set_value_(updated, layout='coo')
Exemple #6
0
def correctness(dataset):
    """Check ``SparseTensor @ dense`` against PyTorch's sparse COO matmul.

    The matrix is loaded from ``'{name}.mat'`` and multiplied with a random
    dense matrix for every width in the module-level ``sizes``.

    Args:
        dataset: A ``(group, name)`` pair; only ``name`` is used here.

    Raises:
        AssertionError: If both products differ beyond ``atol=1e-4``.
        RuntimeError: Any runtime error whose message does not contain
            ``'out of memory'``; out-of-memory errors are skipped after
            emptying the CUDA cache.
    """
    _, name = dataset
    mat_scipy = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()

    # Convert to COO once instead of once per index array.
    coo = mat_scipy.tocoo()
    row = torch.from_numpy(coo.row).to(args.device, torch.long)
    col = torch.from_numpy(coo.col).to(args.device, torch.long)

    mat = SparseTensor(row=row, col=col, sparse_sizes=mat_scipy.shape)
    mat.fill_cache_()
    mat_pytorch = mat.to_torch_sparse_coo_tensor().coalesce()

    for size in sizes:
        try:
            x = torch.randn((mat.size(1), size), device=args.device)

            out1 = mat @ x
            out2 = mat_pytorch @ x

            # Explicit raise so the check still runs under `python -O`.
            if not torch.allclose(out1, out2, atol=1e-4):
                raise AssertionError(
                    f'SPMM mismatch for {name} with dense width {size}.')

        except RuntimeError as e:
            if 'out of memory' not in str(e):
                # Re-raise as-is to keep the original type and traceback.
                raise
            torch.cuda.empty_cache()
Exemple #7
0
def timing(dataset):
    """Benchmark four sparse-dense matmul variants and print a table.

    The matrix is loaded from ``'{name}.mat'``. For every dense width in the
    module-level ``sizes`` each variant is timed with ``time_func``; the
    fastest variant per width is highlighted via ``bold``.

    Args:
        dataset: A ``(group, name)`` pair used to locate and label the
            matrix.

    Raises:
        RuntimeError: Any runtime error whose message does not contain
            ``'out of memory'``. Out-of-memory errors mark that width as
            ``inf`` for all variants.
    """
    group, name = dataset
    mat_scipy = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()

    # Convert to COO once instead of once per index array.
    coo = mat_scipy.tocoo()
    row = torch.from_numpy(coo.row).to(args.device, torch.long)
    col = torch.from_numpy(coo.col).to(args.device, torch.long)

    mat = SparseTensor(row=row, col=col, sparse_sizes=mat_scipy.shape)
    mat.fill_cache_()
    mat_pytorch = mat.to_torch_sparse_coo_tensor().coalesce()
    mat_scipy = mat.to_scipy(layout='csr')

    def scatter(x):
        return scatter_add(x[col], row, dim=0, dim_size=mat_scipy.shape[0])

    def spmm_scipy(x):
        # SciPy cannot consume CUDA tensors; report it like an OOM so the
        # entry shows up as `inf`.
        if x.is_cuda:
            raise RuntimeError('out of memory')
        return mat_scipy @ x

    def spmm_pytorch(x):
        return mat_pytorch @ x

    def spmm(x):
        return mat @ x

    benchmarks = (
        ('Scatter     ', scatter),
        ('SPMM SciPy  ', spmm_scipy),
        ('SPMM PyTorch', spmm_pytorch),
        ('SPMM Own    ', spmm),
    )
    times = [[] for _ in benchmarks]

    for size in sizes:
        try:
            x = torch.randn((mat.size(1), size), device=args.device)
            measured = [time_func(func, x) for _, func in benchmarks]
            del x

        except RuntimeError as e:
            if 'out of memory' not in str(e):
                # Re-raise as-is to keep the original type and traceback.
                raise
            torch.cuda.empty_cache()
            measured = [float('inf')] * len(benchmarks)

        # Append exactly one value per variant, so a failure part-way
        # through a width can never leave the series with unequal lengths.
        for series, elapsed in zip(times, measured):
            series.append(elapsed)

    ts = torch.tensor(times)
    winner = torch.zeros_like(ts, dtype=torch.bool)
    winner[ts.argmin(dim=0), torch.arange(len(sizes))] = 1
    winner = winner.tolist()

    name = f'{group}/{name}'
    print(f'{bold(name)} (avg row length: {mat.avg_row_length():.2f}):')
    print('\t'.join(['            '] + [f'{size:>5}' for size in sizes]))
    for (label, _), series, flags in zip(benchmarks, times, winner):
        print('\t'.join([bold(label)] +
                        [bold(f'{t:.5f}', f) for t, f in zip(series, flags)]))
    print()
Exemple #8
0
def narrow(src: SparseTensor, dim: int, start: int,
           length: int) -> SparseTensor:
    """Restrict ``src`` to ``length`` consecutive indices along ``dim``.

    Args:
        src: The sparse tensor to slice.
        dim: Dimension to narrow; a negative value is offset by
            ``src.dim()``. ``0`` selects rows, ``1`` selects columns, and
            larger values narrow dimension ``dim - 1`` of the value tensor.
        start: First index to keep; a negative value is offset by
            ``src.size(dim)``.
        length: Number of indices to keep.

    Returns:
        The narrowed sparse tensor.

    Raises:
        ValueError: If a value dimension is requested but ``src`` stores no
            values.
    """
    if dim < 0:
        dim += src.dim()

    if start < 0:
        start += src.size(dim)

    if dim == 0:
        full_rowptr, full_col, full_value = src.csr()

        # Keep the `length + 1` pointers delimiting the selected rows and
        # rebase them so that the first one is zero.
        rowptr = full_rowptr.narrow(0, start=start, length=length + 1)
        first_entry = rowptr[0]
        rowptr = rowptr - first_entry
        num_entries = rowptr[-1]

        # Only slice the row indices if they are already cached.
        row = src.storage._row
        if row is not None:
            row = row.narrow(0, first_entry, num_entries) - start

        col = full_col.narrow(0, first_entry, num_entries)

        value = full_value
        if value is not None:
            value = value.narrow(0, first_entry, num_entries)

        rowcount = src.storage._rowcount
        if rowcount is not None:
            rowcount = rowcount.narrow(0, start=start, length=length)

        return src.from_storage(
            SparseStorage(row=row,
                          rowptr=rowptr,
                          col=col,
                          value=value,
                          sparse_sizes=(length, src.sparse_size(1)),
                          rowcount=rowcount,
                          colptr=None,
                          colcount=None,
                          csr2csc=None,
                          csc2csr=None,
                          is_sorted=True))

    if dim == 1:
        # This is faster than accessing `csc()` contrary to the `dim=0` case.
        full_row, full_col, full_value = src.coo()
        stop = start + length
        keep = (full_col >= start) & (full_col < stop)

        row = full_row[keep]
        col = full_col[keep] - start

        value = full_value
        if value is not None:
            value = value[keep]

        # Cached column pointers / counts are sliced when available.
        colptr = src.storage._colptr
        if colptr is not None:
            colptr = colptr.narrow(0, start=start, length=length + 1)
            colptr = colptr - colptr[0]

        colcount = src.storage._colcount
        if colcount is not None:
            colcount = colcount.narrow(0, start=start, length=length)

        return src.from_storage(
            SparseStorage(row=row,
                          rowptr=None,
                          col=col,
                          value=value,
                          sparse_sizes=(src.sparse_size(0), length),
                          rowcount=None,
                          colptr=colptr,
                          colcount=colcount,
                          csr2csc=None,
                          csc2csr=None,
                          is_sorted=True))

    # Remaining dimensions live in the value tensor, whose dimension 0
    # covers both sparse dimensions, hence the `dim - 1` offset.
    dense_value = src.storage.value()
    if dense_value is None:
        raise ValueError
    return src.set_value(dense_value.narrow(dim - 1, start, length),
                         layout='coo')