Example #1
def test_lock_sync(client):
    def f(x):
        with Lock('x') as lock:
            # Only one task at a time may hold the lock, so the 'locked'
            # metadata flag must never be observed as True on entry.
            client = get_client()
            assert client.get_metadata('locked') is False
            client.set_metadata('locked', True)
            sleep(0.05)
            assert client.get_metadata('locked') is True
            client.set_metadata('locked', False)

    client.set_metadata('locked', False)
    futures = client.map(f, range(10))
    client.gather(futures)
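
The test above relies on the distributed test suite's `client` fixture. A minimal standalone sketch of the same pattern, assuming only a throwaway in-process cluster (the `critical` helper and the 'locked' metadata key are illustrative, not part of the library), might look like this:

from time import sleep

from distributed import Client, Lock, get_client


def critical(x):
    # Only one task can hold the lock named 'x' at a time, so the metadata
    # flag should never be observed as True inside the block.
    with Lock('x'):
        c = get_client()
        assert c.get_metadata('locked') is False
        c.set_metadata('locked', True)
        sleep(0.05)
        c.set_metadata('locked', False)
    return x


if __name__ == '__main__':
    with Client(processes=False) as client:  # in-process cluster, just for the sketch
        client.set_metadata('locked', False)
        client.gather(client.map(critical, range(10)))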
Example #3
def test_text_progressbar(capsys, client):
    futures = client.map(inc, range(10))
    p = TextProgressBar(futures, interval=0.01, complete=True)
    client.gather(futures)

    start = time()
    while p.status != 'finished':
        sleep(0.01)
        assert time() - start < 5

    check_bar_completed(capsys)
    assert p._last_response == {'all': 10,
                                'remaining': 0,
                                'status': 'finished'}
    assert p.comm.closed()
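
Outside the test suite, the same text progress bar is usually driven through the `progress` helper rather than by instantiating `TextProgressBar` directly; a minimal sketch, assuming a local in-process client (the `inc` helper is defined inline here because the test imports it from its own utilities):

from dask.distributed import Client, progress


def inc(x):
    return x + 1


if __name__ == '__main__':
    with Client(processes=False) as client:
        futures = client.map(inc, range(10))
        progress(futures)               # renders the text progress bar in the console
        print(client.gather(futures))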
Example #4
def test_worker_dies():
    with cluster(config={
            "distributed.scheduler.locks.lease-timeout": "0.1s",
    }) as (scheduler, workers):
        with Client(scheduler["address"]) as client:
            sem = Semaphore(name="x", max_leases=1)

            def f(x, sem, kill_address):
                # One task kills its own worker while holding a lease; the
                # lease-timeout lets the remaining tasks acquire it afterwards.
                with sem:
                    from distributed.worker import get_worker

                    worker = get_worker()
                    if worker.address == kill_address:
                        import os

                        os.kill(os.getpid(), 15)
                    return x

            futures = client.map(f,
                                 range(10),
                                 sem=sem,
                                 kill_address=workers[0]["address"])
            results = client.gather(futures)

            assert sorted(results) == list(range(10))
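
Stripped of the failure-injection logic, the basic `Semaphore` pattern the test exercises is a distributed context manager that limits concurrency; a short sketch, assuming an in-process cluster and an illustrative `limited` task:

from distributed import Client, Semaphore


def limited(i, sem):
    # The scheduler hands out at most max_leases leases, so no more than two
    # of these calls are inside the block at any one time.
    with sem:
        return i * 2


if __name__ == '__main__':
    with Client(processes=False) as client:
        sem = Semaphore(max_leases=2, name="db")
        print(client.gather(client.map(limited, range(10), sem=sem)))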
Example #5
    def test_data_initialization(self) -> None:
        '''Assert each worker has the correct amount of data, and DMatrix initialization doesn't
        generate unnecessary copies of data.

        '''
        with LocalCluster(n_workers=2) as cluster:
            with Client(cluster) as client:
                X, y, _ = generate_array()
                n_partitions = X.npartitions
                m = xgb.dask.DaskDMatrix(client, X, y)
                workers = list(_get_client_workers(client).keys())
                rabit_args = client.sync(xgb.dask._get_rabit_args,
                                         len(workers), client)
                n_workers = len(workers)

                def worker_fn(worker_addr: str, data_ref: Dict) -> None:
                    with xgb.dask.RabitContext(rabit_args):
                        local_dtrain = xgb.dask._dmatrix_from_list_of_parts(
                            **data_ref)
                        total = np.array([local_dtrain.num_row()])
                        total = xgb.rabit.allreduce(total, xgb.rabit.Op.SUM)
                        assert total[0] == kRows

                futures = []
                for i in range(len(workers)):
                    futures.append(
                        client.submit(worker_fn,
                                      workers[i],
                                      m.create_fn_args(workers[i]),
                                      pure=False,
                                      workers=[workers[i]]))
                client.gather(futures)

                has_what = client.has_what()
                cnt = 0
                data = set()
                for k, v in has_what.items():
                    for d in v:
                        cnt += 1
                        data.add(d)

                assert len(data) == cnt
                # Subtract the on disk resource from each worker
                assert cnt - n_workers == n_partitions
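
The test pokes at private helpers (`_dmatrix_from_list_of_parts`, `_get_rabit_args`); the supported way to consume a `DaskDMatrix` is through the public training entry point. A minimal sketch of that path, assuming an xgboost build with Dask support and illustrative random data:

import xgboost as xgb
from dask import array as da
from dask.distributed import Client, LocalCluster


if __name__ == '__main__':
    with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
        # Random regression data, chunked so each worker holds several partitions.
        X = da.random.random((1000, 10), chunks=(100, 10))
        y = da.random.random(1000, chunks=100)
        dtrain = xgb.dask.DaskDMatrix(client, X, y)
        output = xgb.dask.train(client, {"tree_method": "hist"}, dtrain,
                                num_boost_round=10)
        booster = output["booster"]      # trained xgboost.Booster
        history = output["history"]      # per-round metrics (empty without evals)
        print(booster.num_boosted_rounds())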
Example #6
def test_event_sync(client):
    # Assert that we call the client.sync correctly
    def wait_for_it_failing(x):
        event = Event("x")

        # Event is not set in another task so far
        assert not event.wait(timeout=0.05)
        assert not event.is_set()

    def wait_for_it_ok(x):
        event = Event("x")

        # Event is set in another task
        assert event.wait(timeout=0.5)
        assert event.is_set()

    def set_it():
        event = Event("x")
        event.set()

    wait_futures = client.map(wait_for_it_failing, range(10))
    client.gather(wait_futures)

    set_future = client.submit(set_it)
    client.gather(set_future)

    wait_futures = client.map(wait_for_it_ok, range(10))
    client.gather(wait_futures)
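
For reference, the same `Event` API can be exercised without the test fixtures; a small sketch, assuming an in-process cluster (the event name "go" and the `waiter` helper are illustrative):

from distributed import Client, Event


def waiter(_):
    # Blocks until some other task or the client sets the event named "go".
    return Event("go").wait(timeout=5)


if __name__ == '__main__':
    with Client(processes=False) as client:
        futures = client.map(waiter, range(4))
        Event("go").set()                # releases every waiting task at once
        print(client.gather(futures))    # expected: [True, True, True, True]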
Example #7
    def test_data_initialization(self):
        '''Assert each worker has the correct amount of data, and DMatrix initialization doesn't
        generate unnecessary copies of data.

        '''
        with LocalCluster(n_workers=2) as cluster:
            with Client(cluster) as client:
                X, y = generate_array()
                n_partitions = X.npartitions
                m = xgb.dask.DaskDMatrix(client, X, y)
                workers = list(xgb.dask._get_client_workers(client).keys())
                rabit_args = client.sync(xgb.dask._get_rabit_args, workers,
                                         client)
                n_workers = len(workers)

                def worker_fn(worker_addr, data_ref):
                    with xgb.dask.RabitContext(rabit_args):
                        local_dtrain = xgb.dask._dmatrix_from_worker_map(
                            **data_ref)
                        assert local_dtrain.num_row() == kRows / n_workers

                futures = client.map(worker_fn,
                                     workers,
                                     [m.create_fn_args()] * len(workers),
                                     pure=False,
                                     workers=workers)
                client.gather(futures)

                has_what = client.has_what()
                cnt = 0
                data = set()
                for k, v in has_what.items():
                    for d in v:
                        cnt += 1
                        data.add(d)

                assert len(data) == cnt
                # Subtract the on disk resource from each worker
                assert cnt - n_workers == n_partitions
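
Both versions of the test finish by counting keys from `Client.has_what`, which maps each worker address to the keys that worker currently holds. A self-contained sketch of that bookkeeping, assuming an in-process cluster and an inline `inc` helper:

from dask.distributed import Client


def inc(x):
    return x + 1


if __name__ == '__main__':
    with Client(processes=False) as client:
        futures = client.map(inc, range(10))
        client.gather(futures)
        has_what = client.has_what()     # {worker_address: keys held by that worker}
        total = sum(len(keys) for keys in has_what.values())
        distinct = {key for keys in has_what.values() for key in keys}
        # If no result is replicated across workers, the two counts agree.
        print(total, len(distinct))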
Example #8
def test_threadpoolworkers_pick_correct_ioloop(cleanup):
    # gh4057

    # About picking appropriate values for the various timings
    # * Sleep time in `access_limited` impacts test runtime but is arbitrary
    # * `lease-timeout` should be smaller than the sleep time. This is what the
    #   test builds on: assuming the leases cannot be refreshed, e.g. because
    #   the wrong event loop was picked or the PeriodicCallback was never
    #   scheduled, the semaphore would become oversubscribed and
    #   len(protected_resource) would become non-zero. This should also trigger
    #   a log message about "unknown leases" and fail the test.
    # * `lease-validation-interval` should be the smallest quantity. How often
    #   leases are checked for staleness is currently hard-coded to a fifth of
    #   the `lease-timeout`. Accounting for this and some jitter, it should be
    #   sufficiently small to ensure smooth operation.

    with dask.config.set({
            "distributed.scheduler.locks.lease-validation-interval":
            0.01,
            "distributed.scheduler.locks.lease-timeout":
            0.1,
    }):
        with Client(processes=False, threads_per_worker=4) as client:
            sem = Semaphore(max_leases=1, name="database")
            protected_resource = []

            def access_limited(val, sem):
                import time

                with sem:
                    assert len(protected_resource) == 0
                    protected_resource.append(val)
                    # Interact with the DB
                    time.sleep(0.2)
                    protected_resource.remove(val)

            client.gather(client.map(access_limited, range(10), sem=sem))
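
The relationship spelled out in the comment (validation interval < lease timeout < sleep time) can be checked directly against the configuration; a small sketch using the same keys as the test, assuming they carry plain numeric values as set here:

import dask

with dask.config.set({
        "distributed.scheduler.locks.lease-validation-interval": 0.01,
        "distributed.scheduler.locks.lease-timeout": 0.1,
}):
    interval = dask.config.get("distributed.scheduler.locks.lease-validation-interval")
    timeout = dask.config.get("distributed.scheduler.locks.lease-timeout")
    # The test relies on interval < timeout < the 0.2s sleep in access_limited.
    assert interval < timeout < 0.2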
Example #9
def test_threadpoolworkers_pick_correct_ioloop(cleanup):
    # gh4057

    with dask.config.set({
            "distributed.scheduler.locks.lease-validation-interval":
            0.01,
            "distributed.scheduler.locks.lease-timeout":
            0.05,
    }):
        with Client(processes=False, threads_per_worker=4) as client:
            sem = Semaphore(max_leases=1, name="database")
            protected_resource = []

            def access_limited(val, sem):
                import time

                with sem:
                    assert len(protected_resource) == 0
                    protected_resource.append(val)
                    # Interact with the DB
                    time.sleep(0.1)
                    protected_resource.remove(val)

            client.gather(client.map(access_limited, range(10), sem=sem))
Example #10
    def run_quantile(self, name: str) -> None:
        if sys.platform.startswith("win"):
            pytest.skip("Skipping dask tests on Windows")

        exe: Optional[str] = None
        for possible_path in {
                './testxgboost', './build/testxgboost', '../build/testxgboost',
                '../cpu-build/testxgboost'
        }:
            if os.path.exists(possible_path):
                exe = possible_path
        if exe is None:
            return

        test = "--gtest_filter=Quantile." + name

        def runit(worker_addr: str,
                  rabit_args: List[bytes]) -> subprocess.CompletedProcess:
            port_env = ''
            # setup environment for running the c++ part.
            for arg in rabit_args:
                if arg.decode('utf-8').startswith('DMLC_TRACKER_PORT'):
                    port_env = arg.decode('utf-8')
            port = port_env.split('=')
            env = os.environ.copy()
            env[port[0]] = port[1]
            return subprocess.run([str(exe), test],
                                  env=env,
                                  capture_output=True)

        with LocalCluster(n_workers=4) as cluster:
            with Client(cluster) as client:
                workers = list(_get_client_workers(client).keys())
                rabit_args = client.sync(xgb.dask._get_rabit_args,
                                         len(workers), client)
                futures = client.map(runit,
                                     workers,
                                     pure=False,
                                     workers=workers,
                                     rabit_args=rabit_args)
                results = client.gather(futures)

                for ret in results:
                    msg = ret.stdout.decode('utf-8')
                    assert msg.find('1 test from Quantile') != -1, msg
                    assert ret.returncode == 0, msg
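
The environment plumbing in `runit` reduces to copying one `KEY=value` pair from the rabit args into the child process environment. A stripped-down sketch of just that step, with made-up tracker arguments standing in for what the scheduler would actually produce:

import os
import subprocess
import sys

# Hypothetical rabit args; only the tracker port matters for the C++ binary.
rabit_args = [b'DMLC_TRACKER_URI=127.0.0.1', b'DMLC_TRACKER_PORT=9091']

env = os.environ.copy()
for arg in rabit_args:
    text = arg.decode('utf-8')
    if text.startswith('DMLC_TRACKER_PORT'):
        key, value = text.split('=', 1)
        env[key] = value

# Stand-in for the ./testxgboost invocation: confirm the child process sees
# the variable that was copied into its environment.
out = subprocess.run(
    [sys.executable, '-c', "import os; print(os.environ['DMLC_TRACKER_PORT'])"],
    env=env, capture_output=True, text=True)
assert out.stdout.strip() == '9091'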