def spmm(
    rows: torch.Tensor,
    cols: torch.Tensor,
    vals: torch.Tensor,
    size: torch.Size,
    mat: torch.Tensor,
    return_num_nonzero: bool = False,
    cuda_spmm_alg: int = 1,
):
    """Multiply a sparse matrix given in COO form by the dense matrix ``mat``.

    The sparse operand is described by the parallel tensors ``rows``,
    ``cols`` and ``vals`` together with its two-dimensional ``size``.

    Args:
        rows: Row index of every stored entry.
        cols: Column index of every stored entry; same length as ``rows``.
        vals: Value of every stored entry; same length as ``rows`` and the
            same dtype and device as ``mat``.
        size: Shape of the sparse matrix; ``size[0]`` and ``size[1]`` are
            used as its row and column counts.
        mat: Dense right-hand operand.
        return_num_nonzero: When true, a second tensor is returned alongside
            the product. On the non-CUDA path it is the product of the sparse
            matrix with an all-ones column of shape ``(size[1], 1)``; on the
            CUDA path it is whatever ``MEB.coo_spmm_int32`` returns as its
            second output.
        cuda_spmm_alg: Algorithm selector forwarded unchanged to
            ``MEB.coo_spmm_int32``; unused on the non-CUDA path.

    Returns:
        ``result`` alone, or the tuple ``(result, num_nonzero)`` when
        ``return_num_nonzero`` is true.

    Raises:
        AssertionError: If the index/value lengths differ, if ``vals`` and
            ``mat`` disagree on dtype or device, or if ``mat`` is on CUDA
            while any of ``rows``/``cols``/``vals`` is not.
        ValueError: On the non-CUDA path when ``vals`` is neither float32
            nor float64.
    """
    assert len(rows) == len(cols), "Invalid length"
    assert len(rows) == len(vals), "Invalid length"
    assert vals.dtype == mat.dtype, "dtype mismatch"
    assert vals.device == mat.device, "device mismatch"
    if mat.is_cuda:
        assert (
            rows.is_cuda and cols.is_cuda and vals.is_cuda
        ), "All inputs must be on cuda"
        # The int32 kernel is used because coosort only supports int32.
        # WARNING: TODO: the int64 variant (MEB.coo_spmm_int64) does not sort
        # the vals and should not be used for generic SPMM, so it is not
        # called here.
        rows = rows.int()
        cols = cols.int()
        result, num_nonzero = MEB.coo_spmm_int32(
            rows, cols, vals, size[0], size[1], mat, cuda_spmm_alg, return_num_nonzero
        )
    else:
        # Sparse COO indices are a (2, nnz) int64 tensor.
        indices = torch.stack((rows, cols), 0).long()
        if vals.dtype not in (torch.float32, torch.float64):
            raise ValueError(f"Unsupported data type: {vals.dtype}")

        # torch.sparse_coo_tensor replaces the deprecated typed constructors
        # torch.sparse.FloatTensor / torch.sparse.DoubleTensor; dtype and
        # device are taken from ``vals``.
        sp = torch.sparse_coo_tensor(indices, vals, size)
        result = sp.matmul(mat)
        if return_num_nonzero:
            ones = torch.ones((size[1], 1), dtype=vals.dtype, device=vals.device)
            num_nonzero = sp.matmul(ones)

    if return_num_nonzero:
        return result, num_nonzero
    return result
def spmm(
    rows: torch.Tensor,
    cols: torch.Tensor,
    vals: torch.Tensor,
    size: torch.Size,
    mat: torch.Tensor,
    cuda_spmm_alg: int = 1,
):
    """Multiply a sparse matrix given in COO form by the dense matrix ``mat``.

    The sparse operand is described by the parallel tensors ``rows``,
    ``cols`` and ``vals`` together with its two-dimensional ``size``.

    NOTE(review): this file contains an earlier top-level ``spmm`` that also
    accepts ``return_num_nonzero``; this later definition rebinds the name,
    so that parameter is not available through ``spmm`` - confirm which
    variant is intended to survive.

    Args:
        rows: Row index of every stored entry.
        cols: Column index of every stored entry; same length as ``rows``.
        vals: Value of every stored entry; same length as ``rows`` and the
            same dtype and device as ``mat``.
        size: Shape of the sparse matrix; ``size[0]`` and ``size[1]`` are
            used as its row and column counts.
        mat: Dense right-hand operand.
        cuda_spmm_alg: Algorithm selector forwarded unchanged to
            ``MEB.coo_spmm_int32``; unused on the non-CUDA path.

    Returns:
        The product of the sparse matrix with ``mat``. On the CUDA path this
        is whatever ``MEB.coo_spmm_int32`` returns.

    Raises:
        AssertionError: If the index/value lengths differ, if ``vals`` and
            ``mat`` disagree on dtype or device, or if ``mat`` is on CUDA
            while any of ``rows``/``cols``/``vals`` is not.
        ValueError: On the non-CUDA path when ``vals`` is neither float32
            nor float64.
    """
    assert len(rows) == len(cols), "Invalid length"
    assert len(rows) == len(vals), "Invalid length"
    assert vals.dtype == mat.dtype, "dtype mismatch"
    assert vals.device == mat.device, "device mismatch"
    if mat.is_cuda:
        assert rows.is_cuda and cols.is_cuda and vals.is_cuda
        # The int32 kernel is used because coosort only supports int32.
        # WARNING: TODO: the int64 variant (MEB.coo_spmm_int64) does not sort
        # the vals and should not be used for generic SPMM, so it is not
        # called here.
        rows = rows.int()
        cols = cols.int()
        return MEB.coo_spmm_int32(
            rows, cols, vals, size[0], size[1], mat, cuda_spmm_alg
        )

    # Sparse COO indices are a (2, nnz) int64 tensor.
    indices = torch.stack((rows, cols), 0).long()
    if vals.dtype not in (torch.float32, torch.float64):
        raise ValueError(f"Unsupported data type: {vals.dtype}")

    # torch.sparse_coo_tensor replaces the deprecated typed constructors
    # torch.sparse.FloatTensor / torch.sparse.DoubleTensor; dtype and device
    # are taken from ``vals``.
    sp = torch.sparse_coo_tensor(indices, vals, size)
    return sp.matmul(mat)