コード例 #1
0
def _run_in_parallel(fn: Callable, conn: csr_matrix, **kwargs) -> Any:
    """Run ``fn`` over the rows of ``conn`` in parallel and reassemble the output.

    The row order handed to ``parallelize`` depends on ``fn.__name__``:

    * ``_run_stochastic`` (requires ``jax``): rows sorted by descending number of
      non-zero entries in ``conn``.
    * anything else: rows randomly permuted with NumPy's global random state.

    ``conn.indices`` and ``conn.indptr`` are added to ``kwargs``. The merged
    ``kwargs`` are then filtered separately for ``parallelize`` and for ``fn``.

    Parameters
    ----------
    fn
        Worker function; its ``__name__`` selects the row-ordering strategy and
        the progress unit.
    conn
        Sparse CSR matrix whose rows are distributed across workers.
    kwargs
        Extra keyword arguments forwarded (after filtering) to ``parallelize``
        and to ``fn``.

    Returns
    -------
    Whatever ``_reconstruct_one`` returns for the per-chunk results concatenated
    along the last axis.

    Raises
    ------
    RuntimeError
        If ``fn`` is ``_run_stochastic`` and ``jax`` is not available.
    """
    name = fn.__name__

    if name != "_run_stochastic":
        # Random permutation of the rows; draws from NumPy's global RNG.
        order = np.arange(conn.shape[0])
        np.random.shuffle(order)
    elif _HAS_JAX:
        nonzero_per_row = np.array((conn != 0).sum(1)).ravel()
        # Rows with the most non-zero entries come first.
        order = np.argsort(nonzero_per_row)[::-1]
    else:
        raise RuntimeError(
            "Install `jax` and `jaxlib` as `pip install jax jaxlib`.")

    # Progress is counted per sample only for multi-sample `_run_mc` runs.
    many_samples = name == "_run_mc" and kwargs.get("n_samples", 1) > 1
    unit = "sample" if many_samples else "cell"

    kwargs.update(indices=conn.indices, indptr=conn.indptr)

    def _extract(chunks):
        # Stitch the per-chunk outputs along the last axis, then let
        # `_reconstruct_one` rebuild matrices shaped like `conn` using `order`.
        return _reconstruct_one(np.concatenate(chunks, axis=-1), conn, order)

    runner = parallelize(
        fn,
        order,
        as_array=False,
        extractor=_extract,
        unit=unit,
        **_filter_kwargs(parallelize, **kwargs),
    )
    return runner(**_filter_kwargs(fn, **kwargs))
コード例 #2
0
ファイル: test_utils.py プロジェクト: theislab/cellrank
    def test_reconstruct_one(self, seed: int, shuffle: bool):
        """Check that ``_reconstruct_one`` rebuilds two matrices from stacked data.

        ``m1`` is a random row-normalized CSR matrix and ``m2`` shares its exact
        sparsity pattern with random values. Their ``.data`` arrays (taken from a
        row-permuted view when ``shuffle`` is true) are stacked into a
        ``(2, nnz)`` array. ``_reconstruct_one`` must return matrices equal to
        ``m1`` and ``m2``.

        All randomness is derived from ``seed`` so failures are reproducible.
        """
        rng = np.random.default_rng(seed)

        m1 = random(100, 10, random_state=seed, density=0.5, format="csr")
        # Give every row a non-zero entry before row-normalizing.
        m1[:, 0] = 0.1
        m1 /= m1.sum(1)
        m1 = csr_matrix(m1)

        # Reuse m1's sparsity pattern. Pass the shape explicitly; otherwise it
        # is inferred from the largest column index present and can be too small.
        m2_data = rng.normal(size=m1.nnz)
        m2 = csr_matrix((m2_data, m1.indices, m1.indptr), shape=m1.shape)

        if shuffle:
            ixs = rng.permutation(m1.shape[0])
            data = np.c_[m1[ixs, :].data, m2[ixs, :].data].T
        else:
            ixs = None
            data = np.c_[m1.data, m2.data].T

        r1, r2 = _reconstruct_one(data, m1, ixs=ixs)

        # `.toarray()` instead of `.A`, which newer SciPy no longer provides.
        np.testing.assert_array_equal(r1.toarray(), m1.toarray())
        np.testing.assert_array_equal(r2.toarray(), m2.toarray())