def reduction(src: SparseTensor, dim: Optional[int] = None,
              reduce: str = 'sum') -> torch.Tensor:
    """Reduce the stored values of ``src`` globally or along one dimension.

    When ``src`` carries no value tensor (``src.storage.value()`` is
    ``None``), ``'sum'``/``'add'`` return entry counts (``nnz``, per-column
    or per-row counts) and ``'mean'``/``'min'``/``'max'`` return ones.

    Args:
        src: The sparse tensor to reduce.
        dim: Dimension to reduce over. ``None`` reduces over all stored
            values. Negative values are offset by ``src.dim()``.
        reduce: One of ``'sum'``, ``'add'``, ``'mean'``, ``'min'``,
            ``'max'``.

    Returns:
        A tensor holding the reduced result.

    Raises:
        ValueError: If ``reduce`` is not one of the supported names in a
            branch handled here, or if ``dim`` matches no supported case
            (e.g. ``dim > 1`` without a value tensor).
    """
    value = src.storage.value()

    if dim is None:
        # Global reduction over every stored value.
        if value is not None:
            if reduce == 'sum' or reduce == 'add':
                return value.sum()
            elif reduce == 'mean':
                return value.mean()
            elif reduce == 'min':
                return value.min()
            elif reduce == 'max':
                return value.max()
            else:
                raise ValueError
        else:
            if reduce == 'sum' or reduce == 'add':
                return torch.tensor(src.nnz(), dtype=src.dtype(),
                                    device=src.device())
            elif reduce == 'mean' or reduce == 'min' or reduce == 'max':
                return torch.tensor(1, dtype=src.dtype(),
                                    device=src.device())
            else:
                raise ValueError
    else:
        if dim < 0:
            dim = src.dim() + dim

        if dim == 0 and value is not None:
            # Reducing over dim 0 yields one result per column, so the
            # output length is ``src.size(1)`` (matching the value-less
            # branch below), and the requested reduction must be forwarded
            # instead of relying on scatter's default.
            col = src.storage.col()
            return scatter(value, col, dim=0, dim_size=src.size(1),
                           reduce=reduce)
        elif dim == 0 and value is None:
            if reduce == 'sum' or reduce == 'add':
                return src.storage.colcount().to(src.dtype())
            elif reduce == 'mean' or reduce == 'min' or reduce == 'max':
                # Keep the result on the same device as ``src``, as the
                # global fallbacks above do.
                return torch.ones(src.size(1), dtype=src.dtype(),
                                  device=src.device())
            else:
                raise ValueError
        elif dim == 1 and value is not None:
            # Invalid ``reduce`` names are left to ``segment_csr`` here.
            return segment_csr(value, src.storage.rowptr(), None, reduce)
        elif dim == 1 and value is None:
            if reduce == 'sum' or reduce == 'add':
                return src.storage.rowcount().to(src.dtype())
            elif reduce == 'mean' or reduce == 'min' or reduce == 'max':
                return torch.ones(src.size(0), dtype=src.dtype(),
                                  device=src.device())
            else:
                raise ValueError
        elif dim > 1 and value is not None:
            # The value tensor is reduced along ``dim - 1``.
            # NOTE(review): presumably the two sparse dimensions share the
            # leading axis of ``value`` - confirm against SparseTensor.
            if reduce == 'sum' or reduce == 'add':
                return value.sum(dim=dim - 1)
            elif reduce == 'mean':
                return value.mean(dim=dim - 1)
            elif reduce == 'min':
                return value.min(dim=dim - 1)[0]
            elif reduce == 'max':
                return value.max(dim=dim - 1)[0]
            else:
                raise ValueError
        else:
            raise ValueError
def test_cat(device):
    """Check ``cat`` along dim 0, dim 1, both dims, and the value dim.

    For each concatenation the dense result and/or the state of the
    output storage (presence of row / rowptr / rowcount and the number of
    cached keys) is asserted.
    """
    index_a = tensor([[0, 0, 1], [0, 1, 2]], torch.long, device)
    lhs = SparseTensor(row=index_a[0], col=index_a[1])
    lhs.fill_cache_()

    index_b = tensor([[0, 0, 1, 2], [0, 1, 1, 0]], torch.long, device)
    rhs = SparseTensor(row=index_b[0], col=index_b[1])
    rhs.fill_cache_()

    # Concatenation along dim 0.
    expected_dim0 = [
        [1, 1, 0],
        [0, 0, 1],
        [1, 1, 0],
        [0, 1, 0],
        [1, 0, 0],
    ]
    result = cat([lhs, rhs], dim=0)
    assert result.to_dense().tolist() == expected_dim0
    storage = result.storage
    assert storage.has_row()
    assert storage.has_rowptr()
    assert storage.has_rowcount()
    assert storage.num_cached_keys() == 1

    # Concatenation along dim 1.
    expected_dim1 = [
        [1, 1, 0, 1, 1],
        [0, 0, 1, 0, 1],
        [0, 0, 0, 1, 0],
    ]
    result = cat([lhs, rhs], dim=1)
    assert result.to_dense().tolist() == expected_dim1
    storage = result.storage
    assert storage.has_row()
    assert not storage.has_rowptr()
    assert storage.num_cached_keys() == 2

    # Concatenation along both dims at once.
    expected_both = [
        [1, 1, 0, 0, 0],
        [0, 0, 1, 0, 0],
        [0, 0, 0, 1, 1],
        [0, 0, 0, 0, 1],
        [0, 0, 0, 1, 0],
    ]
    result = cat([lhs, rhs], dim=(0, 1))
    assert result.to_dense().tolist() == expected_both
    storage = result.storage
    assert storage.has_row()
    assert storage.has_rowptr()
    assert storage.num_cached_keys() == 5

    # Attach 4 features per entry, then concatenate along the last dim.
    features = torch.randn((lhs.nnz(), 4), device=device)
    valued = lhs.set_value_(features, layout='coo')
    result = cat([valued, valued], dim=-1)
    storage = result.storage
    assert storage.value().size() == (valued.nnz(), 8)
    assert storage.has_row()
    assert storage.has_rowptr()
    assert storage.num_cached_keys() == 5