예제 #1
0
def run_rabit_ops(client, n_workers):
    """Check rabit allreduce operations across the workers of a dask cluster.

    Obtains rabit arguments for the workers attached to ``client`` and maps
    ``len(workers)`` tasks (restricted to those workers). Each task enters a
    ``RabitContext`` and performs two allreduce operations:

    * SUM of ``1`` from every participant, which must equal ``n_workers``;
    * MAX of the task index (``0 .. len(workers) - 1``), which must equal
      ``n_workers - 1``.

    Parameters
    ----------
    client :
        Presumably a connected ``distributed.Client`` (it is used with
        ``sync``/``map``/``gather``) -- confirm against callers.
    n_workers : int
        Expected number of workers; asserted against the cluster.
    """
    from xgboost.dask import RabitContext, _get_rabit_args, _get_client_workers
    from xgboost import rabit

    workers = list(_get_client_workers(client).keys())
    rabit_args = client.sync(_get_rabit_args, workers, client)
    # Outside of a RabitContext, rabit must not report a distributed setup.
    assert not rabit.is_distributed()
    assert n_workers == len(workers)

    def local_test(worker_id):
        with RabitContext(rabit_args):
            assert rabit.is_distributed()

            # Every participant contributes 1, so the sum is the ring size.
            ones = np.array([1])
            reduced = rabit.allreduce(ones, rabit.Op.SUM)
            assert reduced[0] == n_workers

            # Task indices run 0 .. n_workers - 1, so the max is the last one.
            ids = np.array([worker_id])
            reduced = rabit.allreduce(ids, rabit.Op.MAX)
            # Index the single element instead of relying on the truth value
            # of an ndarray comparison (ambiguous for size > 1).
            assert reduced[0] == n_workers - 1

            return 1

    futures = client.map(local_test, range(len(workers)), workers=workers)
    results = client.gather(futures)
    # Each successful task returns 1.
    assert sum(results) == n_workers
예제 #2
0
    def run_quantile(self, name):
        """Run one ``GPUQuantile`` gtest case on every worker of a LocalCUDACluster.

        Locates the ``testxgboost`` C++ test binary, obtains rabit arguments
        via ``dxgb._get_rabit_args`` and launches the binary with
        ``--gtest_filter=GPUQuantile.<name>`` once per worker, forwarding the
        ``DMLC_TRACKER_PORT`` rabit argument through the environment. Each
        run's stdout must contain ``'1 test from GPUQuantile'`` and the
        process must exit with status 0.

        Skipped on Windows.

        Parameters
        ----------
        name : str
            Name of the test case inside the ``GPUQuantile`` gtest suite.
        """
        if sys.platform.startswith("win"):
            pytest.skip("Skipping dask tests on Windows")

        # Candidate locations in order of preference. A tuple (not a set)
        # keeps the search deterministic; the first existing path wins.
        candidates = (
            './testxgboost', './build/testxgboost', '../build/testxgboost',
            '../gpu-build/testxgboost',
        )
        exe = next((path for path in candidates if os.path.exists(path)), None)
        assert exe, 'No testxgboost executable found.'
        test = "--gtest_filter=GPUQuantile." + name

        def runit(worker_addr, rabit_args):
            # ``worker_addr`` is unused; it is the item ``client.map`` fans
            # out over, so one call is issued per worker address.
            port = None
            # Set up the environment for the C++ part: forward the tracker
            # port found among the ``KEY=VALUE`` byte-string rabit args.
            for arg in rabit_args:
                decoded = arg.decode('utf-8')
                if decoded.startswith('DMLC_TRACKER_PORT'):
                    port = decoded
            if port is None:
                raise ValueError(
                    'DMLC_TRACKER_PORT not found in rabit args: %r'
                    % (rabit_args,))
            key, _, value = port.partition('=')
            env = os.environ.copy()
            env[key] = value
            return subprocess.run([exe, test], env=env, stdout=subprocess.PIPE)

        with LocalCUDACluster() as cluster:
            with Client(cluster) as client:
                workers = list(dxgb._get_client_workers(client).keys())
                rabit_args = dxgb._get_rabit_args(workers, client)
                futures = client.map(runit,
                                     workers,
                                     pure=False,
                                     workers=workers,
                                     rabit_args=rabit_args)
                results = client.gather(futures)
                for ret in results:
                    msg = ret.stdout.decode('utf-8')
                    assert '1 test from GPUQuantile' in msg, msg
                    assert ret.returncode == 0, msg