def _test_local_cluster(protocol):
    """Start a 4-worker process-based LocalCluster and check comms ranks.

    A ``comms.CommsContext`` is built on a client of the cluster and
    ``my_rank`` is executed through ``CommsContext.run``; the returned values
    must sum to ``sum(range(4))``.

    :param protocol: communication protocol passed straight to ``LocalCluster``
    """
    worker_count = 4
    cluster_kwargs = dict(
        protocol=protocol,
        dashboard_address=None,
        n_workers=worker_count,
        threads_per_worker=1,
        processes=True,
    )
    with LocalCluster(**cluster_kwargs) as cluster, Client(cluster) as client:
        ctx = comms.CommsContext(client)
        ranks = ctx.run(my_rank)
        assert sum(ranks) == sum(range(worker_count))
def _test_lock_workers(scheduler_address, ranks):
    """Run a guarded task on selected workers with ``lock_workers=True``.

    The task marks its worker as busy (``worker.running``) for half a second
    and asserts that no other invocation was in flight on the same worker
    when it started, and that the busy flag is still set when it finishes.

    :param scheduler_address: address the ``Client`` connects to
    :param ranks: indices into ``CommsContext.worker_addresses`` selecting
        the workers the task is run on
    """

    async def _guarded_task(_):
        current = get_worker()
        # A flag still set to True means another invocation is in progress.
        if hasattr(current, "running"):
            assert not current.running
        current.running = True
        await asyncio.sleep(0.5)
        # The flag must not have been cleared by anything else while sleeping.
        assert current.running
        current.running = False

    with Client(scheduler_address) as client:
        ctx = comms.CommsContext(client)
        targets = [ctx.worker_addresses[rank] for rank in ranks]
        ctx.run(_guarded_task, workers=targets, lock_workers=True)
def _test_local_cluster(protocol):
    """Start a 4-worker LocalCluster with TCP-over-UCX config and check ranks.

    The UCX settings returned by ``get_ucx_config(enable_tcp_over_ucx=True)``
    are installed under ``distributed.comm.ucx`` for the lifetime of the
    cluster. A ``comms.CommsContext`` then runs ``my_rank`` (with an extra
    argument ``0``, whose meaning is defined by ``my_rank``) and the results
    must sum to ``sum(range(4))``.

    :param protocol: communication protocol passed straight to ``LocalCluster``
    """
    # ``dask.config.update`` merges nested mappings but does NOT split dotted
    # keys, so passing {"distributed.comm.ucx": ...} to it stores a literal
    # top-level "distributed.comm.ucx" entry that ``dask.config.get`` never
    # consults. ``dask.config.set`` expands dotted keys into the nested
    # structure and, used as a context manager, restores the previous config
    # afterwards instead of leaking the override into the global config.
    ucx_config = get_ucx_config(enable_tcp_over_ucx=True)
    with dask.config.set({"distributed.comm.ucx": ucx_config}):
        with LocalCluster(
            protocol=protocol,
            dashboard_address=None,
            n_workers=4,
            threads_per_worker=1,
            processes=True,
        ) as cluster, Client(cluster) as client:
            c = comms.CommsContext(client)
            assert sum(c.run(my_rank, 0)) == sum(range(4))
def _test_local_cluster(protocol):
    """Enable UCX tcp/cuda_copy in the global config, then check comms ranks.

    The ``"ucx"`` section of ``dask.config.global_config`` is updated with
    ``tcp=True`` and ``cuda_copy=True`` (new values take priority). A 4-worker
    process-based LocalCluster is started, a ``comms.CommsContext`` runs
    ``my_rank`` with an extra argument ``0`` (meaning defined by ``my_rank``),
    and the results must sum to ``sum(range(4))``.

    :param protocol: communication protocol passed straight to ``LocalCluster``
    """
    ucx_overrides = {
        "ucx": {
            "tcp": True,
            "cuda_copy": True,
        },
    }
    dask.config.update(dask.config.global_config, ucx_overrides, priority="new")

    worker_count = 4
    cluster_kwargs = dict(
        protocol=protocol,
        dashboard_address=None,
        n_workers=worker_count,
        threads_per_worker=1,
        processes=True,
    )
    with LocalCluster(**cluster_kwargs) as cluster, Client(cluster) as client:
        comms_ctx = comms.CommsContext(client)
        rank_values = comms_ctx.run(my_rank, 0)
        assert sum(rank_values) == sum(range(worker_count))