예제 #1
0
def run_rabit_ops(client, n_workers):
    """Check rabit allreduce across every worker of a dask cluster.

    Obtains rabit arguments for the cluster's workers, then runs one task per
    worker inside a ``RabitContext`` and checks that a SUM and a MAX allreduce
    produce the values expected for ``n_workers`` participants.

    Parameters
    ----------
    client :
        Connected dask client used to look up workers and submit tasks.
    n_workers :
        Number of workers the caller expects; must equal the number of
        workers reported by dask.
    """
    import numpy as np

    from test_with_dask import _get_client_workers
    from xgboost.dask import RabitContext, _get_rabit_args
    from xgboost import rabit

    workers = _get_client_workers(client)
    rabit_args = client.sync(_get_rabit_args, len(workers), None, client)
    # Rabit must not already be initialised in this (client) process.
    assert not rabit.is_distributed()
    n_workers_from_dask = len(workers)
    assert n_workers == n_workers_from_dask

    def local_test(worker_id):
        """Run on one worker: verify SUM and MAX allreduce, then return 1."""
        with RabitContext(rabit_args):
            assert rabit.is_distributed()

            # Every worker contributes 1, so the sum equals the worker count.
            reduced = rabit.allreduce(np.array([1]), rabit.Op.SUM)
            assert reduced[0] == n_workers

            # Task ids are 0..n_workers-1 (see enumerate below), so max is n-1.
            reduced = rabit.allreduce(np.array([worker_id]), rabit.Op.MAX)
            assert reduced[0] == n_workers - 1

            return 1

    # ``client.map(..., workers=workers)`` only restricts tasks to that set of
    # workers; it does not place exactly one task on each of them. Rabit needs
    # one participant per worker process, so pin task i to worker i.
    futures = [
        client.submit(
            local_test,
            worker_id,
            workers=[address],
            allow_other_workers=False,
        )
        for worker_id, address in enumerate(workers)
    ]
    results = client.gather(futures)
    assert sum(results) == n_workers
예제 #2
0
    def local_test(worker_id):
        """Verify SUM and MAX rabit allreduce from inside one task.

        Uses ``rabit_args``, ``RabitContext``, ``rabit``, ``np`` and
        ``n_workers`` from the enclosing scope (not visible here).

        Parameters
        ----------
        worker_id :
            Index of this task; the MAX check assumes the ids used across all
            tasks are exactly 0..n_workers-1.

        Returns
        -------
        int
            Always 1 (presumably summed by the caller to count finished tasks).
        """
        with RabitContext(rabit_args):
            assert rabit.is_distributed()

            # Every participant contributes 1, so the sum is the participant count.
            reduced = rabit.allreduce(np.array([1]), rabit.Op.SUM)
            assert reduced[0] == n_workers

            # Compare the scalar element rather than relying on the truth value
            # of a one-element array, matching the SUM check above.
            reduced = rabit.allreduce(np.array([worker_id]), rabit.Op.MAX)
            assert reduced[0] == n_workers - 1

            return 1