Example #1
from dask.distributed import Client, LocalCluster

from dask_cuda.explicit_comms import comms


def _test_local_cluster(protocol):
    # `my_rank` is a module-level helper (sketched below); CommsContext.run()
    # executes it once per worker, so the 4 ranks must sum to sum(range(4)).
    with LocalCluster(
            protocol=protocol,
            dashboard_address=None,
            n_workers=4,
            threads_per_worker=1,
            processes=True,
    ) as cluster:
        with Client(cluster) as client:
            c = comms.CommsContext(client)
            assert sum(c.run(my_rank)) == sum(range(4))
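`my_rank` is defined elsewhere in the test module and is not part of the snippet: `CommsContext.run` invokes it once on every worker and returns the list of results, which is why the 4 ranks sum to `sum(range(4))`. A minimal sketch of such a helper, assuming the per-worker state dict exposes the rank under a "rank" key:

def my_rank(state, *args):
    # `state` is the per-worker comms state; any extra arguments passed to
    # CommsContext.run() (e.g. the 0 in the later examples) arrive in *args.
    return state["rank"]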
Example #2
import asyncio

from dask.distributed import Client, get_worker

from dask_cuda.explicit_comms import comms


def _test_lock_workers(scheduler_address, ranks):
    async def f(_):
        # Mark a critical section on this worker; with lock_workers=True no
        # other CommsContext.run() call should overlap with it on this worker.
        worker = get_worker()
        if hasattr(worker, "running"):
            assert not worker.running
        worker.running = True
        await asyncio.sleep(0.5)
        assert worker.running
        worker.running = False

    with Client(scheduler_address) as client:
        c = comms.CommsContext(client)
        c.run(f,
              workers=[c.worker_addresses[r] for r in ranks],
              lock_workers=True)
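`lock_workers=True` is meant to serialize `run()` calls that target the same workers, which is what the `worker.running` flag above asserts. A driver test would therefore launch several concurrent callers with overlapping rank sets against one cluster; a minimal sketch (the rank sets, process count, and use of `multiprocessing` are assumptions, not the library's own test):

import multiprocessing as mp


def test_lock_workers(scheduler_address):
    # Overlapping rank sets force the calls to contend for the same workers.
    procs = [
        mp.Process(target=_test_lock_workers, args=(scheduler_address, ranks))
        for ranks in ([0, 1], [1, 2], [2, 3], [3, 0])
    ]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
    # A failed assertion on a worker propagates back and makes the child exit non-zero.
    assert all(p.exitcode == 0 for p in procs)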
Example #3
import dask
from dask.distributed import Client, LocalCluster

from dask_cuda.explicit_comms import comms
from dask_cuda.utils import get_ucx_config


def _test_local_cluster(protocol):
    # Enable TCP-over-UCX in Dask's global configuration before the cluster is created.
    dask.config.update(
        dask.config.global_config,
        {
            "distributed.comm.ucx": get_ucx_config(enable_tcp_over_ucx=True),
        },
        priority="new",
    )

    with LocalCluster(
            protocol=protocol,
            dashboard_address=None,
            n_workers=4,
            threads_per_worker=1,
            processes=True,
    ) as cluster:
        with Client(cluster) as client:
            c = comms.CommsContext(client)
            # `my_rank` is the helper sketched under Example #1; run() forwards the extra 0 to it.
            assert sum(c.run(my_rank, 0)) == sum(range(4))
Example #4

import dask
from dask.distributed import Client, LocalCluster

from dask_cuda.explicit_comms import comms


def _test_local_cluster(protocol):
    # Older variant of Example #3: UCX options were configured under a
    # top-level "ucx" key instead of "distributed.comm.ucx".
    dask.config.update(
        dask.config.global_config,
        {
            "ucx": {
                "tcp": True,
                "cuda_copy": True,
            },
        },
        priority="new",
    )

    with LocalCluster(
            protocol=protocol,
            dashboard_address=None,
            n_workers=4,
            threads_per_worker=1,
            processes=True,
    ) as cluster:
        with Client(cluster) as client:
            c = comms.CommsContext(client)
            assert sum(c.run(my_rank, 0)) == sum(range(4))
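Examples #3 and #4 both push UCX settings into Dask's global configuration before the cluster starts; the newer one writes them under `distributed.comm.ucx` via `get_ucx_config`, while the older one used a top-level `ucx` key. When the settings are only needed temporarily, `dask.config.set` can instead be used as a context manager so the changes revert on exit; a minimal sketch, assuming the newer `distributed.comm.ucx` layout:

import dask

from dask_cuda.utils import get_ucx_config

# Scoped alternative to dask.config.update(): the settings revert when the block exits.
with dask.config.set({"distributed.comm.ucx": get_ucx_config(enable_tcp_over_ucx=True)}):
    pass  # create the LocalCluster / Client here, as in the examples above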