コード例 #1
0
def reduction(src: SparseTensor, dim: Optional[int] = None,
              reduce: str = 'sum') -> torch.Tensor:
    """Reduce the stored entries of a sparse matrix globally or along a dim.

    Args:
        src: Sparse matrix whose stored entries are reduced. If its storage
            holds no value tensor, every stored entry is treated as an
            implicit ``1``.
        dim: Dimension to reduce. ``None`` reduces all stored entries to a
            scalar; ``0`` reduces over rows, giving one result per column;
            ``1`` reduces over columns, giving one result per row; ``dim > 1``
            reduces dimension ``dim - 1`` of the value tensor. Negative values
            are counted from ``src.dim()``.
        reduce: Reduction to apply: ``'sum'`` (alias ``'add'``), ``'mean'``,
            ``'min'`` or ``'max'``. For ``dim`` 0/1 on a matrix with values,
            the name is forwarded to ``scatter`` / ``segment_csr``, which
            validate it themselves.

    Returns:
        The reduced dense tensor.

    Raises:
        ValueError: If ``reduce`` is not supported by a branch handled here,
            or ``dim > 1`` is requested for a matrix without values.
    """
    value = src.storage.value()

    if dim is None:
        if value is not None:
            if reduce == 'sum' or reduce == 'add':
                return value.sum()
            elif reduce == 'mean':
                return value.mean()
            elif reduce == 'min':
                return value.min()
            elif reduce == 'max':
                return value.max()
            else:
                raise ValueError(f"Unsupported reduce '{reduce}'")
        else:
            # Without explicit values every stored entry counts as ``1``:
            # the sum is the number of stored entries, mean/min/max are 1.
            if reduce == 'sum' or reduce == 'add':
                return torch.tensor(src.nnz(), dtype=src.dtype(),
                                    device=src.device())
            elif reduce == 'mean' or reduce == 'min' or reduce == 'max':
                return torch.tensor(1, dtype=src.dtype(), device=src.device())
            else:
                raise ValueError(f"Unsupported reduce '{reduce}'")
    else:
        if dim < 0:
            dim = src.dim() + dim

        if dim == 0 and value is not None:
            # Reducing over rows groups entries by their column index, so the
            # output needs one slot per column (``src.size(1)``), and the
            # requested reduction must be forwarded rather than defaulting
            # to a sum.
            col = src.storage.col()
            return scatter(value, col, dim=0, dim_size=src.size(1),
                           reduce=reduce)
        elif dim == 0 and value is None:
            if reduce == 'sum' or reduce == 'add':
                return src.storage.colcount().to(src.dtype())
            elif reduce == 'mean' or reduce == 'min' or reduce == 'max':
                # Create on the matrix's device, like the scalar branches.
                return torch.ones(src.size(1), dtype=src.dtype(),
                                  device=src.device())
            else:
                raise ValueError(f"Unsupported reduce '{reduce}'")
        elif dim == 1 and value is not None:
            # Entries are grouped per row through the CSR row pointer.
            return segment_csr(value, src.storage.rowptr(), None, reduce)
        elif dim == 1 and value is None:
            if reduce == 'sum' or reduce == 'add':
                return src.storage.rowcount().to(src.dtype())
            elif reduce == 'mean' or reduce == 'min' or reduce == 'max':
                # Create on the matrix's device, like the scalar branches.
                return torch.ones(src.size(0), dtype=src.dtype(),
                                  device=src.device())
            else:
                raise ValueError(f"Unsupported reduce '{reduce}'")
        elif dim > 1 and value is not None:
            # Dims beyond the two sparse ones index into the value tensor,
            # whose dim 0 enumerates the stored entries; hence ``dim - 1``.
            if reduce == 'sum' or reduce == 'add':
                return value.sum(dim=dim - 1)
            elif reduce == 'mean':
                return value.mean(dim=dim - 1)
            elif reduce == 'min':
                # ``min``/``max`` along a dim return (values, indices).
                return value.min(dim=dim - 1)[0]
            elif reduce == 'max':
                return value.max(dim=dim - 1)[0]
            else:
                raise ValueError(f"Unsupported reduce '{reduce}'")
        else:
            raise ValueError(
                f"Cannot reduce dim {dim} of a sparse matrix "
                f"{'with' if value is not None else 'without'} values")
コード例 #2
0
ファイル: test_cat.py プロジェクト: zeta1999/pytorch_sparse
def test_cat(device):
    """Check ``cat`` for row, column, block-diagonal and value concatenation.

    Both inputs have their caches filled first; every case verifies the
    dense result plus which storage fields and how many cached keys the
    concatenated matrix ends up with.
    """
    # Stored (row, col) pairs: (0, 0), (0, 1), (1, 2).
    a_row, a_col = tensor([[0, 0, 1], [0, 1, 2]], torch.long, device)
    first = SparseTensor(row=a_row, col=a_col)
    first.fill_cache_()

    # Stored (row, col) pairs: (0, 0), (0, 1), (1, 1), (2, 0).
    b_row, b_col = tensor([[0, 0, 1, 2], [0, 1, 1, 0]], torch.long, device)
    second = SparseTensor(row=b_row, col=b_col)
    second.fill_cache_()

    # dim=0: rows of ``second`` are appended below ``first``.
    stacked = cat([first, second], dim=0)
    expected = [[1, 1, 0], [0, 0, 1], [1, 1, 0], [0, 1, 0], [1, 0, 0]]
    assert stacked.to_dense().tolist() == expected
    storage = stacked.storage
    assert storage.has_row()
    assert storage.has_rowptr()
    assert storage.has_rowcount()
    assert storage.num_cached_keys() == 1

    # dim=1: columns of ``second`` are appended to the right of ``first``.
    side_by_side = cat([first, second], dim=1)
    expected = [[1, 1, 0, 1, 1], [0, 0, 1, 0, 1], [0, 0, 0, 1, 0]]
    assert side_by_side.to_dense().tolist() == expected
    storage = side_by_side.storage
    assert storage.has_row()
    assert not storage.has_rowptr()
    assert storage.num_cached_keys() == 2

    # dim=(0, 1): block-diagonal concatenation.
    diagonal = cat([first, second], dim=(0, 1))
    expected = [
        [1, 1, 0, 0, 0],
        [0, 0, 1, 0, 0],
        [0, 0, 0, 1, 1],
        [0, 0, 0, 0, 1],
        [0, 0, 0, 1, 0],
    ]
    assert diagonal.to_dense().tolist() == expected
    storage = diagonal.storage
    assert storage.has_row()
    assert storage.has_rowptr()
    assert storage.num_cached_keys() == 5

    # dim=-1: with 4 features per stored entry, concatenating ``first`` with
    # itself widens the value tensor to 8 features while keeping the nnz.
    features = torch.randn((first.nnz(), 4), device=device)
    first = first.set_value_(features, layout='coo')
    widened = cat([first, first], dim=-1)
    assert widened.storage.value().size() == (first.nnz(), 8)
    storage = widened.storage
    assert storage.has_row()
    assert storage.has_rowptr()
    assert storage.num_cached_keys() == 5