Example #1
def test_Future_knows_status_immediately(c, s, a, b):
    x = yield c.scatter(123)
    q = yield Queue('q')
    yield q.put(x)

    c2 = yield Client(s.address, asynchronous=True)
    q2 = yield Queue('q', client=c2)
    future = yield q2.get()
    assert future.status == 'finished'

    x = c.submit(div, 1, 0)
    yield wait(x)
    yield q.put(x)

    future2 = yield q2.get()
    assert future2.status == 'error'
    with pytest.raises(Exception):
        yield future2

    start = time()
    while True:  # we learn about the true error eventually
        try:
            yield future2
        except ZeroDivisionError:
            break
        except Exception:
            assert time() < start + 5
            yield gen.sleep(0.05)

    yield c2.close()
Example #2
def test_queue(c, s, a, b):
    x = yield Queue('x')
    y = yield Queue('y')
    xx = yield Queue('x')
    assert x.client is c

    future = c.submit(inc, 1)

    yield x.put(future)
    yield y.put(future)
    future2 = yield xx.get()
    assert future.key == future2.key

    with pytest.raises(gen.TimeoutError):
        yield x.get(timeout=0.1)

    del future, future2

    yield gen.sleep(0.1)
    assert s.tasks  # future still present in y's queue
    yield y.get()  # burn future

    start = time()
    while s.tasks:
        yield gen.sleep(0.01)
        assert time() < start + 5
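
As a hedged aside, the named-queue API this test exercises also works from a plain synchronous client. A minimal sketch, assuming a reachable scheduler (or letting Client start a local cluster):

from distributed import Client, Queue

client = Client()    # connects to a cluster, or starts a local one
q = Queue('q')       # queues are shared by name across clients
q.put(1)             # plain msgpack-serializable data round-trips fine
assert q.get() == 1
client.close()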
Example #3
def test_channel_scheduler(c, s, a, b):
    chan = c.channel('chan', maxlen=5)

    x = c.submit(inc, 1)
    key = x.key
    chan.append(x)
    del x

    while not len(chan):
        yield gen.sleep(0.01)

    assert 'streaming-chan' in s.who_wants[key]
    assert s.wants_what['streaming-chan'] == {key}

    while len(s.who_wants[key]) < 2:
        yield gen.sleep(0.01)

    assert s.wants_what[c.id] == {key}

    for i in range(10):
        chan.append(c.submit(inc, i))

    start = time()
    while True:
        if len(chan) == len(s.task_state) == 5:
            break
        else:
            assert time() < start + 2
            yield gen.sleep(0.01)

    results = yield c._gather(list(chan.futures))
    assert results == [6, 7, 8, 9, 10]
Example #4
def test_timeout_sync(client):
    v = Variable('v')
    start = time()
    with pytest.raises(gen.TimeoutError):
        v.get(timeout=0.1)
    stop = time()
    assert 0.1 < stop - start < 2.0
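
Example #4 only checks the timeout path. For contrast, a hedged sketch of the happy path with the same distributed.Variable API (assuming a reachable cluster):

from distributed import Client, Variable

client = Client()
v = Variable('v')
v.set(42)             # stores plain data (or a Future) under the name 'v'
assert v.get() == 42  # any client on the cluster can read it back by name
v.delete()
client.close()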
Example #5
def test_dont_steal_long_running_tasks(c, s, a, b):
    def long(delay):
        with worker_client() as c:
            sleep(delay)

    yield c.submit(long, 0.1)  # learn duration
    yield c.submit(inc, 1)  # learn duration

    long_tasks = c.map(long, [0.5, 0.6], workers=a.address,
                       allow_other_workers=True)
    while sum(map(len, s.processing.values())) < 2:  # let them start
        yield gen.sleep(0.01)

    start = time()
    while any(t.key in s.extensions['stealing'].key_stealable for t in long_tasks):
        yield gen.sleep(0.01)
        assert time() < start + 1

    na = len(a.executing)
    nb = len(b.executing)

    incs = c.map(inc, range(100), workers=a.address, allow_other_workers=True)

    yield gen.sleep(0.2)

    yield wait(long_tasks)

    for t in long_tasks:
        assert (sum(log[1] == 'executing' for log in a.story(t)) +
                sum(log[1] == 'executing' for log in b.story(t))) <= 1
Example #6
def test_failed_worker_without_warning(c, s, a, b):
    L = c.map(inc, range(10))
    yield _wait(L)

    original_process = a.process
    a.process.terminate()
    start = time()
    while a.process is original_process:  # wait for the nanny to spawn a replacement
        yield gen.sleep(0.01)
        assert time() - start < 10

    yield gen.sleep(0.5)

    start = time()
    while len(s.ncores) < 2:
        yield gen.sleep(0.01)
        assert time() - start < 10

    yield _wait(L)

    L2 = c.map(inc, range(10, 20))
    yield _wait(L2)
    assert all(len(keys) > 0 for keys in s.has_what.values())
    ncores2 = s.ncores.copy()

    yield c._restart()

    L = c.map(inc, range(10))
    yield _wait(L)
    assert all(len(keys) > 0 for keys in s.has_what.values())

    assert not (set(ncores2) & set(s.ncores))  # no overlap
Example #7
def test_timeouts(c, s, a, b):
    sub = Sub('a', client=c, worker=None)
    start = time()
    with pytest.raises(TimeoutError):
        yield sub.get(timeout=0.1)
    stop = time()
    assert stop - start < 1
Example #8
def test_closing_scheduler_closes_workers(s, a, b):
    yield s.close()

    start = time()
    while a.status != 'closed' or b.status != 'closed':
        yield gen.sleep(0.01)
        assert time() < start + 2
Example #9
def test_speed(c, s, a, b):
    """
    This tests how quickly we can move messages back and forth

    This is mostly a test of latency.

    Interestingly this runs 10x slower on Python 2
    """
    def pingpong(a, b, start=False, n=1000, msg=1):
        sub = Sub(a)
        pub = Pub(b)

        while not pub.subscribers:
            sleep(0.01)

        if start:
            pub.put(msg)  # other sub may not have started yet

        for i in range(n):
            msg = next(sub)
            pub.put(msg)
            # if i % 100 == 0:
            #     print(a, b, i)
        return n

    import numpy as np
    x = np.random.random(1000)

    x = c.submit(pingpong, 'a', 'b', start=True, msg=x, n=100)
    y = c.submit(pingpong, 'b', 'a', n=100)

    start = time()
    yield c.gather([x, y])
    stop = time()
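
For orientation, a hedged single-process sketch of the Pub/Sub API this ping-pong stresses; it assumes, as distributed's own docstrings suggest, that a Sub created before the put receives the message:

from distributed import Client, Pub, Sub

client = Client()
sub = Sub('my-topic')  # subscribe first so the publisher can see us
pub = Pub('my-topic')
pub.put('hello')       # fire-and-forget broadcast to all subscribers
assert sub.get(timeout=5) == 'hello'
client.close()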
Example #10
def test_close_nanny(c, s, a, b):
    assert len(s.workers) == 2

    assert a.process.is_alive()
    a_worker_address = a.worker_address
    start = time()
    yield s.close_worker(worker=a_worker_address)

    assert len(s.workers) == 1
    assert a_worker_address not in s.workers

    start = time()
    while a.is_alive():
        yield gen.sleep(0.1)
        assert time() < start + 5

    assert a.pid is None

    for i in range(10):
        yield gen.sleep(0.1)
        assert len(s.workers) == 1
        assert not a.is_alive()
        assert a.pid is None

    while a.status != 'closed':
        yield gen.sleep(0.05)
        assert time() < start + 10
Example #11
def test_scheduler_as_center():
    s = Scheduler(validate=True)
    done = s.start(0)
    a = Worker('127.0.0.1', s.port, ip='127.0.0.1', ncores=1)
    a.data.update({'x': 1, 'y': 2})
    b = Worker('127.0.0.1', s.port, ip='127.0.0.1', ncores=2)
    b.data.update({'y': 2, 'z': 3})
    c = Worker('127.0.0.1', s.port, ip='127.0.0.1', ncores=3)
    yield [w._start(0) for w in [a, b, c]]

    assert s.ncores == {w.address: w.ncores for w in [a, b, c]}
    assert not s.who_has

    s.update_graph(tasks={'a': dumps_task((inc, 1))},
                   keys=['a'],
                   dependencies={'a': []})
    start = time()
    while 'a' not in s.who_has:
        assert time() - start < 5
        yield gen.sleep(0.01)
    assert 'a' in a.data or 'a' in b.data or 'a' in c.data

    with ignoring(StreamClosedError):
        yield [w._close() for w in [a, b, c]]

    assert s.ncores == {}
    assert s.who_has == {}

    yield s.close()
Example #12
def test_race(c, s, *workers):
    NITERS = 50

    def f(i):
        with worker_client() as c:
            v = Variable('x', client=c)
            for _ in range(NITERS):
                future = v.get()
                x = future.result()
                y = c.submit(inc, x)
                v.set(y)
                sleep(0.01 * random.random())
            result = v.get().result()
            sleep(0.1)  # allow fire-and-forget messages to clear
            return result

    v = Variable('x', client=c)
    x = yield c.scatter(1)
    yield v.set(x)

    futures = c.map(f, range(15))
    results = yield c.gather(futures)
    assert all(r > NITERS * 0.8 for r in results)

    start = time()
    while len(s.wants_what['variable-x']) != 1:
        yield gen.sleep(0.01)
        assert time() - start < 2
Example #13
def test_counters(c, s, a, b):
    pytest.importorskip('crick')
    while 'tick-duration' not in a.digests:
        yield gen.sleep(0.01)
    aa = Counters(a)

    aa.update()
    yield gen.sleep(0.1)
    aa.update()

    start = time()
    while not len(aa.digest_sources['tick-duration'][0].data['x']):
        yield gen.sleep(1)
        assert time() < start + 5

    a.digests['foo'].add(1)
    a.digests['foo'].add(2)
    aa.add_digest_figure('foo')

    a.counters['bar'].add(1)
    a.counters['bar'].add(2)
    a.counters['bar'].add(2)
    aa.add_counter_figure('bar')

    for x in [aa.counter_sources.values(), aa.digest_sources.values()]:
        for y in x:
            for z in y.values():
                assert len(set(map(len, z.data.values()))) == 1
Example #14
def test_compute_sync(client):
    @dask.delayed
    def f(n, counter):
        assert isinstance(counter, Actor), type(counter)
        for i in range(n):
            counter.increment().result()

    @dask.delayed
    def check(counter, blanks):
        return counter.n

    counter = dask.delayed(Counter)()
    values = [f(i, counter) for i in range(5)]
    final = check(counter, values)

    result = final.compute(actors=counter)
    assert result == 0 + 1 + 2 + 3 + 4

    def check(dask_worker):
        return len(dask_worker.data) + len(dask_worker.actors)

    start = time()
    while any(client.run(check).values()):
        sleep(0.01)
        assert time() < start + 2
Example #15
def test_adapt_then_manual(loop):
    """ We can revert from adaptive, back to manual """
    with LocalCluster(scheduler_port=0, silence_logs=False, loop=loop,
                      diagnostics_port=False, processes=False, n_workers=8) as cluster:
        sleep(0.1)
        cluster.adapt(minimum=0, maximum=4, interval='10ms')

        start = time()
        while cluster.scheduler.workers or cluster.workers:
            sleep(0.1)
            assert time() < start + 5

        assert not cluster.workers

        with Client(cluster) as client:

            futures = client.map(slowinc, range(1000), delay=0.1)
            sleep(0.2)

            cluster._adaptive.stop()
            sleep(0.2)

            cluster.scale(2)

            start = time()
            while len(cluster.scheduler.workers) != 2:
                sleep(0.1)
                assert time() < start + 5
Example #16
def test_close_on_disconnect(s, w):
    yield s.close()

    start = time()
    while w.status != 'closed':
        yield gen.sleep(0.01)
        assert time() < start + 5
Example #17
def test_adaptive_local_cluster_multi_workers():
    loop = IOLoop.current()
    cluster = LocalCluster(0, scheduler_port=0, silence_logs=False, nanny=False,
                           diagnostics_port=None, loop=loop, start=False)
    cluster.scheduler.allowed_failures = 1000
    alc = Adaptive(cluster.scheduler, cluster, interval=100)
    c = Client(cluster, start=False, loop=loop)
    yield c._start()

    futures = c.map(slowinc, range(100), delay=0.01)

    start = time()
    while not cluster.scheduler.worker_info:
        yield gen.sleep(0.01)
        assert time() < start + 15

    yield c._gather(futures)
    del futures

    start = time()
    while cluster.workers:
        yield gen.sleep(0.01)
        assert time() < start + 5

    assert not cluster.workers
    yield gen.sleep(0.2)
    assert not cluster.workers

    futures = c.map(slowinc, range(100), delay=0.01)
    yield c._gather(futures)

    yield c._shutdown()
    yield cluster._close()
Example #18
def test_feed_setup_teardown(s, a, b):
    def setup(scheduler):
        return 1

    def func(scheduler, state):
        assert state == 1
        return 'OK'

    def teardown(scheduler, state):
        scheduler.flag = 'done'

    stream = yield connect(s.ip, s.port)
    yield write(stream, {'op': 'feed',
                         'function': dumps(func),
                         'setup': dumps(setup),
                         'teardown': dumps(teardown),
                         'interval': 0.01})

    for i in range(5):
        response = yield read(stream)
        assert response == 'OK'

    close(stream)
    start = time()
    while not hasattr(s, 'flag'):
        yield gen.sleep(0.01)
        assert time() - start < 5
Example #19
def test_failed_worker_without_warning(c, s, a, b):
    L = c.map(inc, range(10))
    yield wait(L)

    original_pid = a.pid
    with ignoring(CommClosedError):
        yield c._run(os._exit, 1, workers=[a.worker_address])
    start = time()
    while a.pid == original_pid:
        yield gen.sleep(0.01)
        assert time() - start < 10

    yield gen.sleep(0.5)

    start = time()
    while len(s.ncores) < 2:
        yield gen.sleep(0.01)
        assert time() - start < 10

    yield wait(L)

    L2 = c.map(inc, range(10, 20))
    yield wait(L2)
    assert all(len(keys) > 0 for keys in s.has_what.values())
    ncores2 = dict(s.ncores)

    yield c._restart()

    L = c.map(inc, range(10))
    yield wait(L)
    assert all(len(keys) > 0 for keys in s.has_what.values())

    assert not (set(ncores2) & set(s.ncores))  # no overlap
Example #20
def test_cleanup_repeated_tasks(c, s, a, b):
    class Foo(object):
        pass

    s.extensions['stealing']._pc.callback_time = 20
    yield c.submit(slowidentity, -1, delay=0.1)
    objects = [c.submit(Foo, pure=False, workers=a.address) for _ in range(50)]

    x = c.map(slowidentity, objects, workers=a.address, allow_other_workers=True,
              delay=0.05)
    del objects
    yield wait(x)
    assert a.data and b.data
    assert len(a.data) + len(b.data) > 10
    ws = weakref.WeakSet()
    ws.update(a.data.values())
    ws.update(b.data.values())
    del x

    start = time()
    while a.data or b.data:
        yield gen.sleep(0.01)
        assert time() < start + 1

    assert not s.who_has
    assert not any(s.has_what.values())

    assert not list(ws)
Example #21
def test_scale_retires_workers():
    class MyCluster(LocalCluster):
        def scale_down(self, *args, **kwargs):
            pass

    loop = IOLoop.current()
    cluster = yield MyCluster(0, scheduler_port=0, processes=False,
                              silence_logs=False, diagnostics_port=None,
                              loop=loop, asynchronous=True)
    c = yield Client(cluster, loop=loop, asynchronous=True)

    assert not cluster.workers

    yield cluster.scale(2)

    start = time()
    while len(cluster.scheduler.workers) != 2:
        yield gen.sleep(0.01)
        assert time() < start + 3

    yield cluster.scale(1)

    start = time()
    while len(cluster.scheduler.workers) != 1:
        yield gen.sleep(0.01)
        assert time() < start + 3

    yield c.close()
    yield cluster.close()
Example #22
def test_worker_who_has_clears_after_failed_connection(c, s, a, b):
    n = Nanny(s.ip, s.port, ncores=2, loop=s.loop)
    n.start(0)

    start = time()
    while len(s.ncores) < 3:
        yield gen.sleep(0.01)
        assert time() < start + 5

    futures = c.map(slowinc, range(20), delay=0.01,
                    key=['f%d' % i for i in range(20)])
    yield wait(futures)

    result = yield c.submit(sum, futures, workers=a.address)
    for dep in set(a.dep_state) - set(a.task_state):
        a.release_dep(dep, report=True)

    n_worker_address = n.worker_address
    with ignoring(CommClosedError):
        yield c._run(os._exit, 1, workers=[n_worker_address])

    while len(s.workers) > 2:
        yield gen.sleep(0.01)

    total = c.submit(sum, futures, workers=a.address)
    yield total

    assert not a.has_what.get(n_worker_address)
    assert not any(n_worker_address in s for s in a.who_has.values())

    yield n._close()
Example #23
def test_secede_rejoin_busy():
    with ThreadPoolExecutor(2) as e:

        def f():
            assert threading.current_thread() in e._threads
            secede()
            sleep(0.1)
            assert threading.current_thread() not in e._threads
            rejoin()
            assert len(e._threads) == 2
            assert threading.current_thread() in e._threads
            return threading.current_thread()

        future = e.submit(f)
        L = [e.submit(sleep, 0.2) for i in range(10)]
        start = time()
        special_thread = future.result()
        stop = time()

        assert 0.1 < stop - start < 0.3

        assert len(e._threads) == 2
        assert special_thread in e._threads

        def f():
            sleep(0.01)
            return threading.current_thread()

        futures = [e.submit(f) for _ in range(10)]
        assert special_thread in {future.result() for future in futures}
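
The test above drives distributed's ThreadPoolExecutor directly. In user code the same dance is usually spelled with the public secede()/rejoin() helpers; a sketch under that assumption:

from distributed import Client, secede, rejoin

def waiter():
    secede()   # give up our slot in the worker's thread pool while we block
    # ... wait on other futures here without starving the pool ...
    rejoin()   # take a slot back before doing real work again
    return 'done'

client = Client()
assert client.submit(waiter).result() == 'done'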
Example #24
def test_broken_worker_during_computation(c, s, a, b):
    n = Nanny(s.ip, s.port, ncores=2, loop=s.loop)
    n.start(0)

    start = time()
    while len(s.ncores) < 3:
        yield gen.sleep(0.01)
        assert time() < start + 5

    L = c.map(inc, range(256))
    for i in range(8):
        L = c.map(add, *zip(*partition_all(2, L)))

    from random import random
    yield gen.sleep(random() / 2)
    with ignoring(OSError):
        n.process.terminate()
    yield gen.sleep(random() / 2)
    with ignoring(OSError):
        n.process.terminate()

    result = yield c._gather(L)
    assert isinstance(result[0], int)

    yield n._close()
Example #25
def test_timeout(c, s, a, b):
    v = Variable('v')

    start = time()
    with pytest.raises(gen.TimeoutError):
        yield v.get(timeout=0.1)
    stop = time()
    assert 0.1 < stop - start < 2.0
Example #26
def test_heartbeats(c, s, a, b):
    x = s.workers[a.address].last_seen
    start = time()
    yield gen.sleep(a.periodic_callbacks['heartbeat'].callback_time / 1000 + 0.1)
    while s.workers[a.address].last_seen == x:
        yield gen.sleep(0.01)
        assert time() < start + 2
    assert a.periodic_callbacks['heartbeat'].callback_time < 1000
Example #27
def f(block, ps=None):
    start = time()
    params = ps.get_data(separate_thread=False).result()
    stop = time()
    update = (block - params).mean(axis=0)
    ps.update(update, separate_thread=False)
    print(format_time(stop - start))
    return np.array([[stop - start]])
Example #28
def test_worker_doesnt_await_task_completion(loop):
    with cluster(nanny=True, nworkers=1) as (s, [w]):
        with Client(s['address'], loop=loop) as c:
            future = c.submit(sleep, 100)
            sleep(0.1)
            start = time()
            c.restart()
            stop = time()
            assert stop - start < 5
Example #29
def test_scheduler_file(loop, nanny):
    with tmpfile() as fn:
        with popen(['dask-scheduler', '--no-bokeh', '--scheduler-file', fn]) as sched:
            with popen(['dask-worker', '--scheduler-file', fn, nanny, '--no-bokeh']):
                with Client(scheduler_file=fn, loop=loop) as c:
                    start = time()
                    while not c.scheduler_info()['workers']:
                        sleep(0.1)
                        assert time() < start + 10
Example #30
def test_multiple_workers(loop):
    with popen(['dask-scheduler', '--no-bokeh']) as s:
        with popen(['dask-worker', 'localhost:8786', '--no-bokeh']) as a:
            with popen(['dask-worker', 'localhost:8786', '--no-bokeh']) as b:
                with Client('127.0.0.1:%d' % Scheduler.default_port, loop=loop) as c:
                    start = time()
                    while len(c.ncores()) < 2:
                        sleep(0.1)
                        assert time() < start + 10
Example #31
async def create_and_destroy_worker(delay):
    start = time()
    while time() < start + 5:
        async with Nanny(s.address, nthreads=2):
            await asyncio.sleep(delay)
        print("Killed nanny")
Example #32
def time_left():
    deadline = start + timeout
    return max(0, deadline - time())
Example #33
def test_simple():
    to_child = mp_context.Queue()
    from_child = mp_context.Queue()

    proc = AsyncProcess(target=feed, args=(to_child, from_child))
    assert not proc.is_alive()
    assert proc.pid is None
    assert proc.exitcode is None
    assert not proc.daemon
    proc.daemon = True
    assert proc.daemon

    wr1 = weakref.ref(proc)
    wr2 = weakref.ref(proc._process)

    # join() before start()
    with pytest.raises(AssertionError):
        yield proc.join()

    yield proc.start()
    assert proc.is_alive()
    assert proc.pid is not None
    assert proc.exitcode is None

    t1 = time()
    yield proc.join(timeout=0.02)
    dt = time() - t1
    assert 0.2 >= dt >= 0.01
    assert proc.is_alive()
    assert proc.pid is not None
    assert proc.exitcode is None

    # setting daemon attribute after start()
    with pytest.raises(AssertionError):
        proc.daemon = False

    to_child.put(5)
    assert from_child.get() == 5

    # child should be stopping now
    t1 = time()
    yield proc.join(timeout=10)
    dt = time() - t1
    assert dt <= 1.0
    assert not proc.is_alive()
    assert proc.pid is not None
    assert proc.exitcode == 0

    # join() again
    t1 = time()
    yield proc.join()
    dt = time() - t1
    assert dt <= 0.6

    del proc
    gc.collect()
    start = time()
    while wr1() is not None and time() < start + 1:
        # Perhaps the GIL switched before _watch_process() exit,
        # help it a little
        sleep(0.001)
        gc.collect()
    if wr1() is not None:
        # Help diagnosing
        from types import FrameType

        p = wr1()
        if p is not None:
            rc = sys.getrefcount(p)
            refs = gc.get_referrers(p)
            del p
            print("refs to proc:", rc, refs)
            frames = [r for r in refs if isinstance(r, FrameType)]
            for i, f in enumerate(frames):
                print(
                    "frames #%d:" % i,
                    f.f_code.co_name,
                    f.f_code.co_filename,
                    sorted(f.f_locals),
                )
        pytest.fail("AsyncProcess should have been destroyed")
    t1 = time()
    while wr2() is not None:
        yield gen.sleep(0.01)
        gc.collect()
        dt = time() - t1
        assert dt < 2.0
Example #34
    def update(self):
        with self.proc.oneshot():
            cpu = self.proc.cpu_percent()
            memory = self.get_process_memory()
        now = time()

        self.cpu.append(cpu)
        self.memory.append(memory)
        self.time.append(now)
        self.count += 1

        result = {
            "cpu": cpu,
            "memory": memory,
            "time": now,
            "count": self.count
        }

        if self._collect_net_io_counters:
            try:
                ioc = psutil.net_io_counters()
            except Exception:
                pass
            else:
                last = self._last_io_counters
                duration = now - self.last_time
                read_bytes = (ioc.bytes_recv - last.bytes_recv) / (duration
                                                                   or 0.5)
                write_bytes = (ioc.bytes_sent - last.bytes_sent) / (duration
                                                                    or 0.5)
                self.last_time = now
                self._last_io_counters = ioc
                self.read_bytes.append(read_bytes)
                self.write_bytes.append(write_bytes)
                result["read_bytes"] = read_bytes
                result["write_bytes"] = write_bytes

        if self._collect_disk_io_counters:
            try:
                disk_ioc = psutil.disk_io_counters()
            except Exception:
                pass
            else:
                last_disk = self._last_disk_io_counters
                duration_disk = now - self.last_time_disk
                read_bytes_disk = (disk_ioc.read_bytes - last_disk.read_bytes
                                   ) / (duration_disk or 0.5)
                write_bytes_disk = (disk_ioc.write_bytes -
                                    last_disk.write_bytes) / (duration_disk
                                                              or 0.5)
                self.last_time_disk = now
                self._last_disk_io_counters = disk_ioc
                self.read_bytes_disk.append(read_bytes_disk)
                self.write_bytes_disk.append(write_bytes_disk)
                result["read_bytes_disk"] = read_bytes_disk
                result["write_bytes_disk"] = write_bytes_disk

        if not WINDOWS:
            num_fds = self.proc.num_fds()
            self.num_fds.append(num_fds)
            result["num_fds"] = num_fds

        if nvml.device_get_count() > 0:
            gpu_metrics = nvml.real_time()
            self.gpu_utilization.append(gpu_metrics["utilization"])
            self.gpu_memory_used.append(gpu_metrics["memory-used"])
            result["gpu_utilization"] = gpu_metrics["utilization"]
            result["gpu_memory_used"] = gpu_metrics["memory-used"]

        return result
Example #35
    def _cycle_ticks(self):
        if not self._tick_counter:
            return
        last, self._tick_count_last = self._tick_count_last, time()
        count, self._tick_counter = self._tick_counter, 0
        self._tick_interval_observed = (time() - last) / (count or 1)
Example #36
    def __init__(
        self,
        handlers,
        blocked_handlers=None,
        stream_handlers=None,
        connection_limit=512,
        deserialize=True,
        serializers=None,
        deserializers=None,
        connection_args=None,
        timeout=None,
        io_loop=None,
    ):
        self.handlers = {
            "identity": self.identity,
            "echo": self.echo,
            "connection_stream": self.handle_stream,
            "dump_state": self._to_dict,
        }
        self.handlers.update(handlers)
        if blocked_handlers is None:
            blocked_handlers = dask.config.get(
                "distributed.%s.blocked-handlers" % type(self).__name__.lower(), []
            )
        self.blocked_handlers = blocked_handlers
        self.stream_handlers = {}
        self.stream_handlers.update(stream_handlers or {})

        self.id = type(self).__name__ + "-" + str(uuid.uuid4())
        self._address = None
        self._listen_address = None
        self._port = None
        self._comms = {}
        self.deserialize = deserialize
        self.monitor = SystemMonitor()
        self.counters = None
        self.digests = None
        self._ongoing_coroutines = weakref.WeakSet()
        self._event_finished = asyncio.Event()

        self.listeners = []
        self.io_loop = io_loop or IOLoop.current()
        self.loop = self.io_loop

        if not hasattr(self.io_loop, "profile"):
            ref = weakref.ref(self.io_loop)

            def stop():
                loop = ref()
                return loop is None or loop.asyncio_loop.is_closed()

            self.io_loop.profile = profile.watch(
                omit=("profile.py", "selectors.py"),
                interval=dask.config.get("distributed.worker.profile.interval"),
                cycle=dask.config.get("distributed.worker.profile.cycle"),
                stop=stop,
            )

        # Statistics counters for various events
        with suppress(ImportError):
            from distributed.counter import Digest

            self.digests = defaultdict(partial(Digest, loop=self.io_loop))

        from distributed.counter import Counter

        self.counters = defaultdict(partial(Counter, loop=self.io_loop))

        self.periodic_callbacks = dict()

        pc = PeriodicCallback(
            self.monitor.update,
            parse_timedelta(
                dask.config.get("distributed.admin.system-monitor.interval")
            )
            * 1000,
        )
        self.periodic_callbacks["monitor"] = pc

        self._last_tick = time()
        self._tick_counter = 0
        self._tick_count = 0
        self._tick_count_last = time()
        self._tick_interval = parse_timedelta(
            dask.config.get("distributed.admin.tick.interval"), default="ms"
        )
        self._tick_interval_observed = self._tick_interval
        self.periodic_callbacks["tick"] = PeriodicCallback(
            self._measure_tick, self._tick_interval * 1000
        )
        self.periodic_callbacks["ticks"] = PeriodicCallback(
            self._cycle_ticks,
            parse_timedelta(dask.config.get("distributed.admin.tick.cycle")) * 1000,
        )

        self.thread_id = 0

        def set_thread_ident():
            self.thread_id = threading.get_ident()

        self.io_loop.add_callback(set_thread_ident)
        self._startup_lock = asyncio.Lock()
        self.status = Status.undefined

        self.rpc = ConnectionPool(
            limit=connection_limit,
            deserialize=deserialize,
            serializers=serializers,
            deserializers=deserializers,
            connection_args=connection_args,
            timeout=timeout,
            server=self,
        )

        self.__stopped = False
Example #37
async def check_connect_timeout(addr):
    t1 = time()
    with pytest.raises(IOError):
        await connect(addr, timeout=0.15)
    dt = time() - t1
    assert 1 >= dt >= 0.1
Example #38
def check_connect_timeout(addr):
    t1 = time()
    with pytest.raises(IOError):
        yield connect(addr, timeout=0.15)
    dt = time() - t1
    assert 0.5 >= dt >= 0.1
Example #39
async def connect(
    addr, timeout=None, deserialize=True, handshake_overrides=None, **connection_args
):
    """
    Connect to the given address (a URI such as ``tcp://127.0.0.1:1234``)
    and yield a ``Comm`` object.  If the connection attempt fails, it is
    retried until the *timeout* is expired.
    """
    if timeout is None:
        timeout = dask.config.get("distributed.comm.timeouts.connect")
    timeout = parse_timedelta(timeout, default="seconds")

    scheme, loc = parse_address(addr)
    backend = registry.get_backend(scheme)
    connector = backend.get_connector()
    comm = None

    start = time()

    def time_left():
        deadline = start + timeout
        return max(0, deadline - time())

    backoff_base = 0.01
    attempt = 0

    # Prefer multiple small attempts over one long attempt. This should protect
    # primarily from DNS race conditions
    # gh3104, gh4176, gh4167
    intermediate_cap = timeout / 5
    active_exception = None
    while time_left() > 0:
        try:
            comm = await asyncio.wait_for(
                connector.connect(loc, deserialize=deserialize, **connection_args),
                timeout=min(intermediate_cap, time_left()),
            )
            break
        except FatalCommClosedError:
            raise
        # Note: CommClosed inherits from OSError
        except (asyncio.TimeoutError, OSError) as exc:
            active_exception = exc

            # As described above, the intermediate timeout is used to distribute
            # initial, bulk connect attempts homogeneously. In particular with
            # the jitter upon retries we should not be worried about overloading
            # any more DNS servers
            intermediate_cap = timeout
            # FullJitter see https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/

            upper_cap = min(time_left(), backoff_base * (2**attempt))
            backoff = random.uniform(0, upper_cap)
            attempt += 1
            logger.debug(
                "Could not connect to %s, waiting for %s before retrying", loc, backoff
            )
            await asyncio.sleep(backoff)
    else:
        raise OSError(
            f"Timed out trying to connect to {addr} after {timeout} s"
        ) from active_exception

    local_info = {
        **comm.handshake_info(),
        **(handshake_overrides or {}),
    }
    try:
        # This would be better, but connections leak if worker is closed quickly
        # write, handshake = await asyncio.gather(comm.write(local_info), comm.read())
        handshake = await asyncio.wait_for(comm.read(), time_left())
        await asyncio.wait_for(comm.write(local_info), time_left())
    except Exception as exc:
        with suppress(Exception):
            await comm.close()
        raise OSError(
            f"Timed out during handshake while connecting to {addr} after {timeout} s"
        ) from exc

    comm.remote_info = handshake
    comm.remote_info["address"] = comm._peer_addr
    comm.local_info = local_info
    comm.local_info["address"] = comm._local_addr

    comm.handshake_options = comm.handshake_configuration(
        comm.local_info, comm.remote_info
    )
    return comm
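
The retry loop above implements "full jitter" exponential backoff (the AWS post linked in the comments). A self-contained sketch of just that policy; the names here are illustrative, not part of distributed's API:

import random

def full_jitter_delays(base=0.01, attempts=6, cap=5.0):
    # One randomized delay per retry: uniform over [0, min(cap, base * 2**attempt)]
    for attempt in range(attempts):
        yield random.uniform(0, min(cap, base * (2 ** attempt)))

for i, delay in enumerate(full_jitter_delays()):
    print("attempt %d: sleep for up to %.3f s" % (i, delay))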
Example #40
def main(scheduler, host, worker_port, http_port, nanny_port, nthreads, nprocs,
         nanny, name, memory_limit, pid_file, temp_filename, reconnect,
         resources, bokeh, bokeh_port, local_directory, scheduler_file,
         death_timeout):
    if nanny:
        port = nanny_port
    else:
        port = worker_port

    if nprocs > 1 and worker_port != 0:
        logger.error("Failed to launch worker.  You cannot use the --port argument when nprocs > 1.")
        exit(1)

    if nprocs > 1 and name:
        logger.error("Failed to launch worker.  You cannot use the --name argument when nprocs > 1.")
        exit(1)

    if not nthreads:
        nthreads = _ncores // nprocs

    if pid_file:
        with open(pid_file, 'w') as f:
            f.write(str(os.getpid()))

        def del_pid_file():
            if os.path.exists(pid_file):
                os.remove(pid_file)
        atexit.register(del_pid_file)

    services = {('http', http_port): HTTPWorker}

    if bokeh:
        try:
            from distributed.bokeh.worker import BokehWorker
        except ImportError:
            pass
        else:
            services[('bokeh', bokeh_port)] = BokehWorker

    if resources:
        resources = resources.replace(',', ' ').split()
        resources = dict(pair.split('=') for pair in resources)
        resources = valmap(float, resources)
    else:
        resources = None

    loop = IOLoop.current()

    if nanny:
        kwargs = {'worker_port': worker_port}
        t = Nanny
    else:
        kwargs = {}
        if nanny_port:
            kwargs['service_ports'] = {'nanny': nanny_port}
        t = Worker

    if scheduler_file:
        while not os.path.exists(scheduler_file):
            sleep(0.01)
        for i in range(10):
            try:
                with open(scheduler_file) as f:
                    cfg = json.load(f)
                scheduler = cfg['address']
                break
            except (ValueError, KeyError):  # race with scheduler on file
                sleep(0.01)

    if not scheduler:
        raise ValueError("Need to provide scheduler address like\n"
                         "dask-worker SCHEDULER_ADDRESS:8786")

    nannies = [t(scheduler, ncores=nthreads,
                 services=services, name=name, loop=loop, resources=resources,
                 memory_limit=memory_limit, reconnect=reconnect,
                 local_dir=local_directory, death_timeout=death_timeout,
                 **kwargs)
               for i in range(nprocs)]

    for n in nannies:
        if host:
            n.start((host, port))
        else:
            n.start(port)
        if t is Nanny:
            global_nannies.append(n)

    if temp_filename:
        @gen.coroutine
        def f():
            while nannies[0].status != 'running':
                yield gen.sleep(0.01)
            import json
            msg = {'port': nannies[0].port,
                   'local_directory': nannies[0].local_dir}
            with open(temp_filename, 'w') as f:
                json.dump(msg, f)
        loop.add_callback(f)

    @gen.coroutine
    def run():
        while all(n.status != 'closed' for n in nannies):
            yield gen.sleep(0.2)

    try:
        loop.run_sync(run)
    except (KeyboardInterrupt, TimeoutError):
        pass
    finally:
        logger.info("End worker")
        loop.close()

    # Clean exit: unregister all workers from scheduler

    loop2 = IOLoop()

    @gen.coroutine
    def f():
        with rpc(nannies[0].scheduler.address) as scheduler:
            if nanny:
                yield gen.with_timeout(
                        timeout=timedelta(seconds=2),
                        future=All([scheduler.unregister(address=n.worker_address, close=True)
                                   for n in nannies if n.process and n.worker_address]),
                        io_loop=loop2)

    loop2.run_sync(f)

    if nanny:
        for n in nannies:
            if isalive(n.process):
                n.process.terminate()

    if nanny:
        start = time()
        while (any(isalive(n.process) for n in nannies)
                and time() < start + 1):
            sleep(0.1)

    for nanny in nannies:
        nanny.stop()
Example #41
from time import sleep
from distributed import Client
from distributed.metrics import time

from dask_mpi import initialize

initialize()

with Client() as c:

    start = time()
    while len(c.scheduler_info()["workers"]) != 2:
        assert time() < start + 10
        sleep(0.2)

    assert c.submit(lambda x: x + 1, 10).result() == 11
    assert c.submit(lambda x: x + 1, 20, workers=2).result() == 21
Example #42
def test_heartbeats(c, s, a, b):
    pytest.importorskip('psutil')
    start = time()
    while not all(s.worker_info[w].get('memory-rss') for w in s.workers):
        yield gen.sleep(0.01)
        assert time() < start + 2
Example #43
def test_dask_cuda_worker_ucx_net_devices(loop):  # noqa: F811
    net_devices = _get_dgx_net_devices()

    sched_env = os.environ.copy()
    sched_env["UCX_TLS"] = "rc,sockcm,tcp,cuda_copy"
    sched_env["UCX_SOCKADDR_TLS_PRIORITY"] = "sockcm"

    with subprocess.Popen(
        [
            "dask-scheduler",
            "--protocol",
            "ucx",
            "--host",
            "127.0.0.1",
            "--port",
            "9379",
            "--no-dashboard",
        ],
            env=sched_env,
    ) as sched_proc:
        # Scheduler with UCX will take a few seconds to fully start
        sleep(5)

        with subprocess.Popen([
                "dask-cuda-worker",
                "ucx://127.0.0.1:9379",
                "--host",
                "127.0.0.1",
                "--enable-infiniband",
                "--net-devices",
                "auto",
                "--no-dashboard",
        ], ) as worker_proc:
            with Client("ucx://127.0.0.1:9379", loop=loop) as client:

                start = time()
                while True:
                    if len(client.scheduler_info()
                           ["workers"]) == get_gpu_count():
                        break
                    else:
                        assert time() - start < 10
                        sleep(0.1)

                worker_net_devices = client.run(
                    lambda: ucp.get_config()["NET_DEVICES"])
                cuda_visible_devices = client.run(
                    lambda: os.environ["CUDA_VISIBLE_DEVICES"])

                for i, v in enumerate(
                        zip(worker_net_devices.values(),
                            cuda_visible_devices.values())):
                    net_dev = v[0]
                    dev_idx = int(v[1].split(",")[0])
                    assert net_dev == net_devices[dev_idx]

            # A dask-worker with UCX protocol will not close until some work
            # is dispatched, therefore we kill the worker and scheduler to
            # ensure timely closing.
            worker_proc.kill()
            sched_proc.kill()
Example #44
@contextmanager
def warn_on_duration(duration, msg):
    start = time()
    yield
    stop = time()
    if stop - start > _parse_timedelta(duration):
        warnings.warn(msg, stacklevel=2)
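
A usage sketch for the context manager above, assuming it is importable from distributed.utils as in recent versions; the sleep is a stand-in for real work:

from time import sleep
from distributed.utils import warn_on_duration  # assumed import location

with warn_on_duration("100ms", "this block took longer than 100ms"):
    sleep(0.2)  # deliberately overruns, so the UserWarning fires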
Example #45
def wait_for_cores(c, ncores=1):
    start = time()
    while len(c.ncores()) < ncores:
        sleep(0.1)
        assert time() < start + 10
Example #46
def stop():
    return time() > start + 0.500
Example #47
def stop():
    return metrics.time() > start + 0.500
Example #48
def test_idle_timeout(loop):
    start = time()
    runner = CliRunner()
    runner.invoke(distributed.cli.dask_scheduler.main, ["--idle-timeout", "1s"])
    stop = time()
    assert 1 < stop - start < 10
Example #49
def async_ssh(cmd_dict):
    import paramiko
    from paramiko.buffered_pipe import PipeTimeout
    from paramiko.ssh_exception import PasswordRequiredException, SSHException

    ssh = paramiko.SSHClient()
    ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())

    retries = 0
    while True:  # Be robust to transient SSH failures.
        try:
            # Set paramiko logging to WARN or higher to squelch INFO messages.
            logging.getLogger("paramiko").setLevel(logging.WARN)

            ssh.connect(
                hostname=cmd_dict["address"],
                username=cmd_dict["ssh_username"],
                port=cmd_dict["ssh_port"],
                key_filename=cmd_dict["ssh_private_key"],
                compress=True,
                timeout=30,
                banner_timeout=30,
            )  # Helps prevent timeouts when many concurrent ssh connections are opened.
            # Connection successful, break out of while loop
            break

        except (SSHException, PasswordRequiredException) as e:

            print(
                "[ dask-ssh ] : "
                + bcolors.FAIL
                + "SSH connection error when connecting to {addr}:{port} "
                "to run '{cmd}'".format(
                    addr=cmd_dict["address"],
                    port=cmd_dict["ssh_port"],
                    cmd=cmd_dict["cmd"],
                )
                + bcolors.ENDC
            )

            print(
                bcolors.FAIL
                + "               SSH reported this exception: "
                + str(e)
                + bcolors.ENDC
            )

            # Print an exception traceback
            traceback.print_exc()

            # Transient SSH errors can occur when many SSH connections are
            # simultaneously opened to the same server. This makes a few
            # attempts to retry.
            retries += 1
            if retries >= 3:
                print(
                    "[ dask-ssh ] : "
                    + bcolors.FAIL
                    + "SSH connection failed after 3 retries. Exiting."
                    + bcolors.ENDC
                )

                # Connection failed after multiple attempts.  Terminate this thread.
                os._exit(1)

            # Wait a moment before retrying
            print(
                "               "
                + bcolors.FAIL
                + f"Retrying... (attempt {retries}/3)"
                + bcolors.ENDC
            )

            sleep(1)

    # Execute the command, and grab file handles for stdout and stderr. Note
    # that we run the command using the user's default shell, but force it to
    # run in an interactive login shell, which hopefully ensures that all of the
    # user's normal environment variables (via the dot files) have been loaded
    # before the command is run. This should help to ensure that important
    # aspects of the environment like PATH and PYTHONPATH are configured.

    print("[ {label} ] : {cmd}".format(label=cmd_dict["label"], cmd=cmd_dict["cmd"]))
    stdin, stdout, stderr = ssh.exec_command(
        "$SHELL -i -c '" + cmd_dict["cmd"] + "'", get_pty=True
    )

    # Set up channel timeout (which we rely on below to make readline() non-blocking)
    channel = stdout.channel
    channel.settimeout(0.1)

    def read_from_stdout():
        """
        Read stdout stream, time out if necessary.
        """
        try:
            line = stdout.readline()
            while len(line) > 0:  # Loops until a timeout exception occurs
                line = line.rstrip()
                logger.debug("stdout from ssh channel: %s", line)
                cmd_dict["output_queue"].put(
                    "[ {label} ] : {output}".format(
                        label=cmd_dict["label"], output=line
                    )
                )
                line = stdout.readline()
        except (PipeTimeout, socket.timeout):
            pass

    def read_from_stderr():
        """
        Read stderr stream, time out if necessary.
        """
        try:
            line = stderr.readline()
            while len(line) > 0:
                line = line.rstrip()
                logger.debug("stderr from ssh channel: %s", line)
                cmd_dict["output_queue"].put(
                    "[ {label} ] : ".format(label=cmd_dict["label"])
                    + bcolors.FAIL
                    + line
                    + bcolors.ENDC
                )
                line = stderr.readline()
        except (PipeTimeout, socket.timeout):
            pass

    def communicate():
        """
        Communicate a little bit, without blocking too long.
        Return True if the command ended.
        """
        read_from_stdout()
        read_from_stderr()

        # Check to see if the process has exited. If it has, we let this thread
        # terminate.
        if channel.exit_status_ready():
            exit_status = channel.recv_exit_status()
            cmd_dict["output_queue"].put(
                "[ {label} ] : ".format(label=cmd_dict["label"])
                + bcolors.FAIL
                + "remote process exited with exit status "
                + str(exit_status)
                + bcolors.ENDC
            )
            return True

    # Get transport to current SSH client
    transport = ssh.get_transport()

    # Wait for a message on the input_queue. Any message received signals this
    # thread to shut itself down.
    while cmd_dict["input_queue"].empty():
        # Kill some time so that this thread does not hog the CPU.
        sleep(1.0)
        # Send noise down the pipe to keep connection active
        transport.send_ignore()
        if communicate():
            break

    # Ctrl-C the executing command and wait a bit for command to end cleanly
    start = time()
    while time() < start + 5.0:
        channel.send(b"\x03")  # Ctrl-C
        if communicate():
            break
        sleep(1.0)

    # Shutdown the channel, and close the SSH connection
    channel.close()
    ssh.close()
Example #50
    with pytest.raises(AssertionError):
        assert_worker_story(story, [("foo", ), ("bar", )], strict=True)
    with pytest.raises(AssertionError):
        assert_worker_story(story, [("foo", ), ("baz", {1: 2})], strict=True)
    with pytest.raises(AssertionError):
        assert_worker_story(story, [], strict=True)


@pytest.mark.parametrize(
    "story_factory",
    [
        pytest.param(lambda: [()], id="Missing payload, stimulus_id, ts"),
        pytest.param(lambda: [("foo", )], id="Missing (stimulus_id, ts)"),
        pytest.param(lambda: [("foo", "bar")], id="Missing ts"),
        pytest.param(lambda: [("foo", "bar", "baz")], id="ts is not a float"),
        pytest.param(lambda: [("foo", "bar", time() + 3600)],
                     id="ts is in the future"),
        pytest.param(lambda: [("foo", "bar", time() - 7200)],
                     id="ts is too old"),
        pytest.param(lambda: [("foo", 123, time())],
                     id="stimulus_id is not a str"),
        pytest.param(lambda: [("foo", "", time())],
                     id="stimulus_id is an empty str"),
        pytest.param(lambda: [("", time())], id="no payload"),
        pytest.param(
            lambda: [("foo", "id", time()), ("foo", "id", time() - 10)],
            id="timestamps out of order",
        ),
    ],
)
def test_assert_worker_story_malformed_story(story_factory):
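
Reading the failure cases backwards gives the shape of a well-formed story entry. A sketch inferred from the parametrize ids above, not from a documented schema:

from time import time

# (*payload, stimulus_id, ts): non-empty payload, a non-empty str stimulus_id,
# and a float timestamp that is recent and non-decreasing across entries.
event = ("compute-task", "x-123", "stimulus-1", time())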
Example #51
def main(scheduler, host, worker_port, http_port, nanny_port, nthreads, nprocs,
         nanny, name, memory_limit, pid_file, temp_filename, reconnect,
         resources, bokeh, bokeh_port):
    if nanny:
        port = nanny_port
    else:
        port = worker_port

    try:
        scheduler_host, scheduler_port = scheduler.split(':')
        scheduler_ip = socket.gethostbyname(scheduler_host)
        scheduler_port = int(scheduler_port)
    except ValueError:
        logger.info("Usage:  dask-worker scheduler_host:scheduler_port")

    if nprocs > 1 and worker_port != 0:
        logger.error(
            "Failed to launch worker.  You cannot use the --port argument when nprocs > 1."
        )
        exit(1)

    if nprocs > 1 and name:
        logger.error(
            "Failed to launch worker.  You cannot use the --name argument when nprocs > 1."
        )
        exit(1)

    if not nthreads:
        nthreads = _ncores // nprocs

    if pid_file:
        with open(pid_file, 'w') as f:
            f.write(str(os.getpid()))

        def del_pid_file():
            if os.path.exists(pid_file):
                os.remove(pid_file)

        atexit.register(del_pid_file)

    services = {('http', http_port): HTTPWorker}

    if bokeh:
        try:
            from distributed.bokeh.worker import BokehWorker
        except ImportError:
            pass
        else:
            services[('bokeh', bokeh_port)] = BokehWorker

    if resources:
        resources = resources.replace(',', ' ').split()
        resources = dict(pair.split('=') for pair in resources)
        resources = valmap(float, resources)
    else:
        resources = None

    loop = IOLoop.current()

    if nanny:
        kwargs = {'worker_port': worker_port}
        t = Nanny
    else:
        kwargs = {}
        if nanny_port:
            kwargs['service_ports'] = {'nanny': nanny_port}
        t = Worker

    if host is not None:
        ip = socket.gethostbyname(host)
    else:
        # lookup the ip address of a local interface on a network that
        # reach the scheduler
        ip = get_ip(scheduler_ip, scheduler_port)
    nannies = [
        t(scheduler_ip,
          scheduler_port,
          ncores=nthreads,
          ip=ip,
          services=services,
          name=name,
          loop=loop,
          resources=resources,
          memory_limit=memory_limit,
          reconnect=reconnect,
          **kwargs) for i in range(nprocs)
    ]

    for n in nannies:
        n.start(port)
        if t is Nanny:
            global_nannies.append(n)

    if temp_filename:

        @gen.coroutine
        def f():
            while nannies[0].status != 'running':
                yield gen.sleep(0.01)
            import json
            msg = {
                'port': nannies[0].port,
                'local_directory': nannies[0].local_dir
            }
            with open(temp_filename, 'w') as f:
                json.dump(msg, f)

        loop.add_callback(f)

    @gen.coroutine
    def run():
        while all(n.status != 'closed' for n in nannies):
            yield gen.sleep(0.2)

    try:
        loop.run_sync(run)
    except (KeyboardInterrupt, TimeoutError):
        pass
    finally:
        logger.info("End worker")
        loop.close()

    loop2 = IOLoop()

    @gen.coroutine
    def f():
        scheduler = rpc(ip=nannies[0].scheduler.ip,
                        port=nannies[0].scheduler.port)
        if nanny:
            yield gen.with_timeout(timedelta(seconds=2),
                                   All([
                                       scheduler.unregister(
                                           address=n.worker_address,
                                           close=True) for n in nannies
                                       if n.process and n.worker_port
                                   ]),
                                   io_loop=loop2)

    loop2.run_sync(f)

    if nanny:
        for n in nannies:
            if isalive(n.process):
                n.process.terminate()

    if nanny:
        start = time()
        while (any(isalive(n.process) for n in nannies)
               and time() < start + 1):
            sleep(0.1)

    for nanny in nannies:
        nanny.stop()