def _run_in_parallel(fn: Callable, conn: csr_matrix, **kwargs) -> Any:
    """Dispatch ``fn`` over the rows of ``conn`` through ``parallelize``.

    Parameters
    ----------
    fn
        Worker callable. Its ``__name__`` selects the row ordering and the
        progress unit.
    conn
        Sparse matrix whose rows are distributed. Its ``indices`` and
        ``indptr`` arrays are added to ``kwargs`` before the keyword
        arguments are filtered for ``parallelize`` and for ``fn``.
    **kwargs
        Keyword arguments split between ``parallelize`` and ``fn`` by
        ``_filter_kwargs``.

    Returns
    -------
    Whatever the extractor yields: the per-worker results concatenated along
    the last axis and handed to ``_reconstruct_one`` together with ``conn``
    and the row order that was used.

    Raises
    ------
    RuntimeError
        If ``fn`` is named ``"_run_stochastic"`` and ``_HAS_JAX`` is falsy.
    """
    fn_name = fn.__name__
    is_stochastic = fn_name == "_run_stochastic"

    if is_stochastic and not _HAS_JAX:
        raise RuntimeError("Install `jax` and `jaxlib` as `pip install jax jaxlib`.")

    if is_stochastic:
        # Rows with the most non-zero entries come first
        # (presumably so the heaviest rows are scheduled early - confirm).
        nonzeros_per_row = np.array((conn != 0).sum(1)).ravel()
        row_order = np.argsort(nonzeros_per_row)[::-1]
    else:
        # Any other worker gets the rows in a random order.
        row_order = np.arange(conn.shape[0])
        np.random.shuffle(row_order)

    if fn_name == "_run_mc" and kwargs.get("n_samples", 1) > 1:
        unit = "sample"
    else:
        unit = "cell"

    kwargs["indices"] = conn.indices
    kwargs["indptr"] = conn.indptr

    def _extract(res):
        # Join the per-worker outputs and undo the row reordering.
        return _reconstruct_one(np.concatenate(res, axis=-1), conn, row_order)

    runner = parallelize(
        fn,
        row_order,
        as_array=False,
        extractor=_extract,
        unit=unit,
        **_filter_kwargs(parallelize, **kwargs),
    )
    return runner(**_filter_kwargs(fn, **kwargs))
def test_reconstruct_one(self, seed: int, shuffle: bool):
    """Check that ``_reconstruct_one`` rebuilds two matrices from stacked data.

    Two sparse matrices sharing one sparsity pattern are created, their
    ``data`` arrays are stacked row-wise (optionally after permuting the
    matrix rows), and the result of ``_reconstruct_one`` is compared
    against the original matrices.

    Parameters
    ----------
    seed
        Seed for every random draw in this test.
    shuffle
        Whether the rows are permuted before the data is stacked; the
        permutation is then passed to ``_reconstruct_one`` as ``ixs``.
    """
    # A single seeded generator drives all randomness so that a failing case
    # can be reproduced from the test parameters alone.
    rng = np.random.RandomState(seed)

    m1 = random(100, 10, random_state=rng, density=0.5, format="csr")
    # Every row gets a non-zero entry, so the row sums below are never zero.
    m1[:, 0] = 0.1
    m1 /= m1.sum(1)  # row-normalize
    m1 = csr_matrix(m1)

    # Second matrix: same sparsity pattern as `m1`, different values.
    m2_data = rng.normal(size=m1.nnz)
    m2 = csr_matrix((m2_data, m1.indices, m1.indptr))

    if shuffle:
        ixs = np.arange(100)
        rng.shuffle(ixs)
        data = np.c_[m1[ixs, :].data, m2[ixs, :].data].T
    else:
        ixs = None
        data = np.c_[m1.data, m2.data].T

    r1, r2 = _reconstruct_one(data, m1, ixs=ixs)

    # `.toarray()` instead of `.A`: the `.A` attribute of sparse matrices was
    # removed in SciPy 1.14.
    np.testing.assert_array_equal(r1.toarray(), m1.toarray())
    np.testing.assert_array_equal(r2.toarray(), m2.toarray())