コード例 #1
0
ファイル: test_initialize.py プロジェクト: lenskit/csr
def test_empty(nrows, ncols):
    """
    Check that ``CSR.empty(nrows, ncols)`` reports the requested shape,
    stores no entries, and has ``nrows + 1`` row pointers that are all zero.
    """
    mat = CSR.empty(nrows, ncols)

    # shape is reported back unchanged
    assert mat.nrows == nrows
    assert mat.ncols == ncols

    # nothing is stored
    assert mat.nnz == 0
    assert all(mat.rowptrs == 0)

    # array lengths match an empty matrix of this shape
    assert len(mat.rowptrs) == nrows + 1
    assert len(mat.colinds) == 0
コード例 #2
0
    def _compute_similarities(self, rmat):
        """
        Build the similarity matrix from ``rmat``.

        The transpose of ``rmat`` is cut into row blocks, the blocks are
        handed to ``_sim_blocks``, and the per-block results are copied into a
        single square CSR matrix, which is then passed to ``_sort_nbrs``.

        Args:
            rmat: the input matrix; rows of its transpose are treated as items
                (per the log messages, its entries are ratings).

        Returns:
            The assembled ``nitems`` x ``nitems`` similarity matrix.
        """
        item_mat = rmat.transpose()
        nitems = item_mat.nrows

        # a missing or negative neighbor limit is passed along as 0
        max_nbrs = self.save_nbrs
        if max_nbrs is None or max_nbrs < 0:
            max_nbrs = 0

        # cut the item rows into blocks (block size argument: 1000)
        block_bounds = _make_blocks(nitems, 1000)
        _logger.info('[%s] splitting %d items (%d ratings) into %d blocks',
                     self._timer, nitems, item_mat.nnz, len(block_bounds))
        row_blocks = [item_mat.subset_rows(lo, hi)
                      for (lo, hi) in block_bounds]

        _logger.info('[%s] computing similarities', self._timer)
        bound_arg = List(block_bounds)
        block_arg = List(row_blocks)
        if not block_arg:
            # List() came back empty/falsy -- per the original author's note,
            # this happens in non-JIT mode where List does not actually build
            # the list -- so fall back to the plain Python lists.
            block_arg, bound_arg = row_blocks, block_bounds
        sim_blocks = _sim_blocks(item_mat, block_arg, bound_arg,
                                 self.min_sim, max_nbrs)

        total_nnz = sum(blk.nnz for blk in sim_blocks)
        total_rows = sum(blk.nrows for blk in sim_blocks)
        _logger.info('[%s] computed %d similarities for %d items in %d blocks',
                     self._timer, total_nnz, total_rows, len(sim_blocks))

        # per-row entry counts across all blocks, in block order
        row_nnzs = np.concatenate([blk.row_nnzs() for blk in sim_blocks])
        assert len(row_nnzs) == nitems, \
            'only have {} rows for {} items'.format(len(row_nnzs), nitems)

        # allocate the full matrix, then copy each block into its row range
        smat = CSR.empty(nitems, nitems, row_nnzs)
        row_lo = 0
        for idx, blk in enumerate(sim_blocks):
            row_hi = row_lo + blk.nrows
            # the block's entries occupy rowptrs[row_lo]:rowptrs[row_hi]
            val_lo = smat.rowptrs[row_lo]
            val_hi = smat.rowptrs[row_hi]
            _logger.debug('block %d (%d:%d) has %d entries, storing in %d:%d',
                          idx, row_lo, row_hi, blk.nnz, val_lo, val_hi)
            smat.colinds[val_lo:val_hi] = blk.colinds
            smat.values[val_lo:val_hi] = blk.values
            row_lo = row_hi

        _logger.info('[%s] sorting similarity matrix with %d entries',
                     self._timer, smat.nnz)
        _sort_nbrs(smat)

        return smat
コード例 #3
0
ファイル: conftest.py プロジェクト: lenskit/csr
def kernel(request):
    """
    Parameterized fixture that supplies a CSR kernel to a test.

    A test function that takes ``kernel`` as its first parameter is run once
    for every kernel under active test.  Kernels listed in
    ``DISABLED_KERNELS`` are skipped.  The kernel is yielded from inside the
    ``use_kernel`` context, after one handle round-trip to warm it up.
    """
    if request.param in DISABLED_KERNELS:
        skip(f'kernel {request.param} is disabled')

    with use_kernel(request.param):
        active = get_kernel()

        # warm-up: create and release a handle for a 1x1 empty matrix,
        # then drop both references before handing the kernel to the test
        probe = CSR.empty(1, 1)
        handle = active.to_handle(probe)
        active.release_handle(handle)
        del handle, probe

        yield active
コード例 #4
0
ファイル: test_initialize.py プロジェクト: lenskit/csr
def test_empty_csr(data, nrows, ncols, vdt):
    """
    Check ``CSR.empty`` when it is given per-row entry counts.

    Row sizes are drawn as int32 values in ``[0, ncols]``.  The resulting
    matrix must report the requested shape, hold ``sum(sizes)`` entries, and
    have a values array only when ``vdt`` is truthy.
    """
    sizes = data.draw(
        nph.arrays(np.int32, nrows, elements=st.integers(0, ncols))
    )
    mat = CSR.empty(nrows, ncols, sizes, values=vdt)
    expected_nnz = np.sum(sizes)

    # shape and entry count
    assert mat.nrows == nrows
    assert mat.ncols == ncols
    assert mat.nnz == expected_nnz

    # structure arrays
    assert len(mat.rowptrs) == nrows + 1
    assert mat.rowptrs.dtype == np.int32
    assert all(mat.row_nnzs() == sizes)
    assert len(mat.colinds) == expected_nnz

    # values are absent unless requested
    if not vdt:
        assert mat.values is None
        return

    assert mat.values is not None
    if vdt is not True:
        # a concrete dtype was requested rather than a bare True
        assert mat.values.dtype == vdt
    assert mat.values.shape == (mat.nnz,)
コード例 #5
0
ファイル: test_initialize.py プロジェクト: lenskit/csr
def test_large_empty():
    """
    Check that ``CSR.empty`` uses 64-bit row pointers when the total entry
    count does not fit in 32 bits.  Skipped if the allocation raises
    ``MemoryError``.
    """
    # 10M rows * 250 entries per row = 2.5B entries >= INT_MAX
    per_row = 250
    nrows = 10000000
    ncols = 500
    expected_nnz = nrows * per_row

    row_nnzs = np.full(nrows, per_row, dtype='i4')

    try:
        big = CSR.empty(nrows, ncols, row_nnzs=row_nnzs, values=False)
    except MemoryError:
        pytest.skip('insufficient memory')

    assert big.nrows == nrows
    assert big.ncols == ncols
    assert big.nnz == expected_nnz

    # row pointers must be 64-bit, non-negative, and evenly spaced
    assert big.rowptrs.dtype == np.dtype('i8')
    assert np.all(big.rowptrs >= 0)
    assert np.all(np.diff(big.rowptrs) == per_row)