def test_get_context_using_python3_posix():
    """get_context() honours the ``multiprocessing.context`` config key.

    If the default context ever changes, this test must change too.
    """
    # No override: must match multiprocessing's own default resolution.
    assert get_context() is multiprocessing.get_context(None)

    # Each explicit override must be forwarded verbatim to multiprocessing.
    for method in ("forkserver", "spawn"):
        with dask.config.set({"multiprocessing.context": method}):
            assert get_context() is multiprocessing.get_context(method)
def test_register_backend_entrypoint(tmp_path):
    """A sizeof plugin exposed via a dist-info entry point is loaded in a subprocess."""
    # Plugin module: lazily registers a sizeof implementation for ``class_impl``.
    (tmp_path / "impl_sizeof.py").write_bytes(
        b"def sizeof_plugin(sizeof):\n"
        b'    print("REG")\n'
        b'    @sizeof.register_lazy("class_impl")\n'
        b"    def register_impl():\n"
        b"        import class_impl\n"
        b"        @sizeof.register(class_impl.Impl)\n"
        b"        def sizeof_impl(obj):\n"
        b"            return obj.size \n"
    )
    # Dummy class carrying a ``size`` attribute for the plugin to report.
    (tmp_path / "class_impl.py").write_bytes(
        b"class Impl:\n    def __init__(self, size):\n        self.size = size"
    )
    # Minimal dist-info metadata advertising the plugin under [dask.sizeof].
    meta_dir = tmp_path / "impl_sizeof-0.0.0.dist-info"
    meta_dir.mkdir()
    (meta_dir / "entry_points.txt").write_bytes(
        b"[dask.sizeof]\nimpl = impl_sizeof:sizeof_plugin\n"
    )

    # Probe from a fresh worker process so entry-point discovery runs there.
    expected = 3_14159265
    with get_context().Pool(1) as pool:
        assert pool.apply(_get_sizeof_on_path, args=(tmp_path, expected)) == expected
    pool.join()
def start(self) -> Iterator: """Context manager for initializing execution.""" # import dask here to reduce prefect import times import dask.config from dask.callbacks import Callback from dask.system import CPU_COUNT class PrefectCallback(Callback): def __init__(self): # type: ignore self.cache = {} def _start(self, dsk): # type: ignore overlap = set(dsk) & set(self.cache) for key in overlap: dsk[key] = self.cache[key] def _posttask(self, key, value, dsk, state, id): # type: ignore self.cache[key] = value with PrefectCallback(), dask.config.set(**self.dask_config): if self.scheduler == "synchronous": self._pool = None else: num_workers = dask.config.get("num_workers", None) or CPU_COUNT if self.scheduler == "threads": from multiprocessing.pool import ThreadPool self._pool = ThreadPool(num_workers) else: from dask.multiprocessing import get_context context = get_context() self._pool = context.Pool( num_workers, initializer=_multiprocessing_pool_initializer ) try: exiting_early = False yield except BaseException: exiting_early = True raise finally: if self._pool is not None: if exiting_early: self._interrupt_pool() else: self._pool.close() self._pool.join() self._pool = None
def test_get_context_always_default():
    """On Python 2 / Windows, get_context() always returns the same context."""
    assert get_context() is multiprocessing
    # Requesting a different context is ignored and only emits a warning.
    with pytest.warns(UserWarning), dask.config.set(
        {"multiprocessing.context": "forkserver"}
    ):
        assert get_context() is multiprocessing