Example #1
    async def setup(self):
        keys = self.keys

        while not keys.issubset(self.scheduler.tasks):
            await asyncio.sleep(0.05)

        tasks = [self.scheduler.tasks[k] for k in keys]

        self.keys = None

        self.scheduler.add_plugin(self)  # subtle race condition here
        self.all_keys, errors = dependent_keys(tasks, complete=self.complete)
        if not self.complete:
            self.keys = self.all_keys.copy()
        else:
            self.keys, _ = dependent_keys(tasks, complete=False)
        self.all_keys.update(keys)
        self.keys |= errors & self.all_keys

        if not self.keys:
            self.stop(exception=None, key=None)

        # Group keys by func name
        self.keys = valmap(set, groupby(self.func, self.keys))
        self.all_keys = valmap(set, groupby(self.func, self.all_keys))
        for k in self.all_keys:
            if k not in self.keys:
                self.keys[k] = set()

        for k in errors:
            self.transition(k, None, "erred", exception=True)
        logger.debug("Set up Progress keys")
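The tail of setup() buckets task keys by the group name returned by self.func, using toolz.groupby followed by valmap(set, ...). A minimal sketch of that pattern, with a hypothetical prefix function standing in for self.func:

from toolz import groupby, valmap

def prefix(key):
    # e.g. "x-1" -> "x" (stand-in for self.func, not distributed's own helper)
    return key.split("-")[0]

keys = {"x-1", "x-2", "y-1", "e"}
grouped = valmap(set, groupby(prefix, keys))
# {'x': {'x-1', 'x-2'}, 'y': {'y-1'}, 'e': {'e'}}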
Example #2
def counts(scheduler, allprogress):
    return merge(
        {"all": valmap(len, allprogress.all), "nbytes": allprogress.nbytes},
        {
            state: valmap(len, allprogress.state[state])
            for state in ["memory", "erred", "released", "processing"]
        },
    )
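counts simply merges several valmap(len, ...) summaries into one dict. A small sketch of the same shape on plain dicts (illustrative data, not the scheduler's real state; the real function also merges in allprogress.nbytes, omitted here):

from toolz import merge, valmap

all_keys = {"x": ["x-1", "x-2"], "y": ["y-1"]}
state = {"memory": {"x": ["x-1"]}, "erred": {}, "released": {}, "processing": {"y": ["y-1"]}}

summary = merge(
    {"all": valmap(len, all_keys)},
    {s: valmap(len, groups) for s, groups in state.items()},
)
# {'all': {'x': 2, 'y': 1}, 'memory': {'x': 1}, 'erred': {}, 'released': {}, 'processing': {'y': 1}}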
Example #3
File: blockwise.py Project: stromal/dask
def broadcast_dimensions(argpairs,
                         numblocks,
                         sentinels=(1, (1, )),
                         consolidate=None):
    """Find block dimensions from arguments

    Parameters
    ----------
    argpairs: iterable
        name, ijk index pairs
    numblocks: dict
        maps {name: number of blocks}
    sentinels: iterable (optional)
        values for singleton dimensions
    consolidate: func (optional)
        use this to reduce each set of common blocks into a smaller set

    Examples
    --------
    >>> argpairs = [('x', 'ij'), ('y', 'ji')]
    >>> numblocks = {'x': (2, 3), 'y': (3, 2)}
    >>> broadcast_dimensions(argpairs, numblocks)
    {'i': 2, 'j': 3}

    Supports numpy broadcasting rules

    >>> argpairs = [('x', 'ij'), ('y', 'ij')]
    >>> numblocks = {'x': (2, 1), 'y': (1, 3)}
    >>> broadcast_dimensions(argpairs, numblocks)
    {'i': 2, 'j': 3}

    Works in other contexts too

    >>> argpairs = [('x', 'ij'), ('y', 'ij')]
    >>> d = {'x': ('Hello', 1), 'y': (1, (2, 3))}
    >>> broadcast_dimensions(argpairs, d)
    {'i': 'Hello', 'j': (2, 3)}
    """
    # List like [('i', 2), ('j', 1), ('i', 1), ('j', 2)]
    argpairs2 = [(a, ind) for a, ind in argpairs if ind is not None]
    L = toolz.concat([
        zip(inds, dims) for (x, inds), (x, dims) in toolz.join(
            toolz.first, argpairs2, toolz.first, numblocks.items())
    ])

    g = toolz.groupby(0, L)
    g = dict((k, set([d for i, d in v])) for k, v in g.items())

    g2 = dict(
        (k, v - set(sentinels) if len(v) > 1 else v) for k, v in g.items())

    if consolidate:
        return toolz.valmap(consolidate, g2)

    if g2 and not set(map(len, g2.values())) == set([1]):
        raise ValueError("Shapes do not align %s" % g)

    return toolz.valmap(toolz.first, g2)
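The toolz.join comprehension is the densest part of the function. A trace of the intermediates for the first doctest above (argpairs = [('x', 'ij'), ('y', 'ji')], numblocks = {'x': (2, 3), 'y': (3, 2)}):

import toolz

argpairs2 = [('x', 'ij'), ('y', 'ji')]
numblocks = {'x': (2, 3), 'y': (3, 2)}

# Inner join on the name: pairs each (name, index) with its (name, numblocks) entry
joined = list(toolz.join(toolz.first, argpairs2, toolz.first, numblocks.items()))
# [(('x', 'ij'), ('x', (2, 3))), (('y', 'ji'), ('y', (3, 2)))]

# Pair each index letter with its block count
L = list(toolz.concat(zip(inds, dims) for (_, inds), (_, dims) in joined))
# [('i', 2), ('j', 3), ('j', 3), ('i', 2)]

# Group by letter and collapse to the set of block counts seen
g = toolz.valmap(lambda v: {d for _, d in v}, toolz.groupby(0, L))
# {'i': {2}, 'j': {3}}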
Example #4
def function(scheduler, p):
    result = {
        "all": valmap(len, p.all_keys),
        "remaining": valmap(len, p.keys),
        "status": p.status,
    }
    if p.status == "error":
        result.update(p.extra)
    return result
Example #5
File: shuffle.py Project: sighingnow/dask
    def __dask_distributed_unpack__(cls, state, dsk, dependencies,
                                    annotations):
        from distributed.worker import dumps_task

        # msgpack will convert lists into tuples, here
        # we convert them back to lists
        if isinstance(state["column"], tuple):
            state["column"] = list(state["column"])
        if "inputs" in state:
            state["inputs"] = list(state["inputs"])

        # Materialize the layer
        raw = dict(cls(**state))

        # Convert all keys to strings and dump tasks
        raw = {
            stringify(k): stringify_collection_keys(v)
            for k, v in raw.items()
        }
        dsk.update(toolz.valmap(dumps_task, raw))

        # TODO: use shuffle-knowledge to calculate dependencies more efficiently
        dependencies.update(
            {k: keys_in_tasks(dsk, [v], as_list=True)
             for k, v in raw.items()})

        if state["annotations"]:
            cls.unpack_annotations(annotations, state["annotations"],
                                   raw.keys())
Example #6
def test_multibar_complete(s, a, b):
    s.update_graph(
        tasks=valmap(
            dumps_task,
            {
                "x-1": (inc, 1),
                "x-2": (inc, "x-1"),
                "x-3": (inc, "x-2"),
                "y-1": (dec, "x-3"),
                "y-2": (dec, "y-1"),
                "e": (throws, "y-2"),
                "other": (inc, 123),
            },
        ),
        keys=["e"],
        dependencies={
            "x-2": {"x-1"},
            "x-3": {"x-2"},
            "y-1": {"x-3"},
            "y-2": {"y-1"},
            "e": {"y-2"},
        },
    )

    p = MultiProgressWidget(["e"], scheduler=s.address, complete=True)
    yield p.listen()

    assert p._last_response["all"] == {"x": 3, "y": 2, "e": 1}
    assert all(b.value == 1.0 for k, b in p.bars.items() if k != "e")
    assert "3 / 3" in p.bar_texts["x"].value
    assert "2 / 2" in p.bar_texts["y"].value
Example #7
def container_copy(c):
    typ = type(c)
    if typ is list:
        return list(map(container_copy, c))
    if typ is dict:
        return valmap(container_copy, c)
    return c
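A short usage sketch for container_copy (assumes valmap has been imported from toolz, as in the surrounding examples):

original = {"a": [1, 2, {"b": 3}], "c": (4, 5)}
duplicate = container_copy(original)

assert duplicate == original
assert duplicate["a"] is not original["a"]          # lists are copied recursively
assert duplicate["a"][2] is not original["a"][2]    # nested dicts too
assert duplicate["c"] is original["c"]              # other types pass through untouched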
Example #8
def test_multi_progressbar_widget_after_close(s, a, b):
    s.update_graph(
        tasks=valmap(
            dumps_task,
            {
                "x-1": (inc, 1),
                "x-2": (inc, "x-1"),
                "x-3": (inc, "x-2"),
                "y-1": (dec, "x-3"),
                "y-2": (dec, "y-1"),
                "e": (throws, "y-2"),
                "other": (inc, 123),
            },
        ),
        keys=["e"],
        dependencies={
            "x-2": {"x-1"},
            "x-3": {"x-2"},
            "y-1": {"x-3"},
            "y-2": {"y-1"},
            "e": {"y-2"},
        },
    )

    p = MultiProgressWidget(["x-1", "x-2", "x-3"], scheduler=s.address)
    yield p.listen()

    assert "x" in p.bars
Example #9
File: test_s3.py Project: astrojuanlu/dask
def test_compression(s3, fmt, blocksize, s3so):
    if fmt not in compress:
        pytest.skip("compression function not provided")
    s3._cache.clear()
    with s3_context("compress", valmap(compress[fmt], files)):
        if fmt and blocksize:
            with pytest.raises(ValueError):
                read_bytes(
                    "s3://compress/test/accounts.*",
                    compression=fmt,
                    blocksize=blocksize,
                    **s3so
                )
            return
        sample, values = read_bytes(
            "s3://compress/test/accounts.*",
            compression=fmt,
            blocksize=blocksize,
            **s3so
        )
        assert sample.startswith(files[sorted(files)[0]][:10])
        assert sample.endswith(b"\n")

        results = compute(*concat(values))
        assert b"".join(results) == b"".join([files[k] for k in sorted(files)])
Example #10
File: test_local.py Project: bigmpc/dask
def test_compression(fmt, blocksize):
    if fmt not in compress:
        pytest.skip("compression function not provided")
    files2 = valmap(compress[fmt], files)
    with filetexts(files2, mode="b"):
        if fmt and blocksize:
            with pytest.raises(ValueError):
                read_bytes(
                    ".test.accounts.*.json",
                    blocksize=blocksize,
                    delimiter=b"\n",
                    compression=fmt,
                )
            return
        sample, values = read_bytes(
            ".test.accounts.*.json",
            blocksize=blocksize,
            delimiter=b"\n",
            compression=fmt,
        )
        assert sample[:5] == files[sorted(files)[0]][:5]
        assert sample.endswith(b"\n")

        results = compute(*concat(values))
        assert b"".join(results) == b"".join([files[k] for k in sorted(files)])
Example #11
async def test_nanny_process_failure(c, s):
    n = await Nanny(s.address, nthreads=2)
    first_dir = n.worker_dir

    assert os.path.exists(first_dir)

    ww = rpc(n.worker_address)
    await ww.update_data(data=valmap(dumps, {"x": 1, "y": 2}))
    pid = n.pid
    assert pid is not None
    with suppress(CommClosedError):
        await c.run(os._exit, 0, workers=[n.worker_address])

    while n.pid == pid:  # wait while process dies and comes back
        await asyncio.sleep(0.01)

    await asyncio.sleep(1)
    while not n.is_alive():  # wait while process comes back
        await asyncio.sleep(0.01)

    # assert n.worker_address != original_address  # most likely

    while n.worker_address not in s.nthreads or n.worker_dir is None:
        await asyncio.sleep(0.01)

    second_dir = n.worker_dir

    await n.close()
    assert not os.path.exists(second_dir)
    assert not os.path.exists(first_dir)
    assert first_dir != n.worker_dir
    await ww.close_rpc()
    s.stop()
Example #12
def test_groupby_with_indexer():
    b = db.from_sequence([[1, 2, 3], [1, 4, 9], [2, 3, 4]])
    result = dict(b.groupby(0))
    assert valmap(sorted, result) == {
        1: [[1, 2, 3], [1, 4, 9]],
        2: [[2, 3, 4]]
    }
Example #13
def _materialized_layer_pack(
    layer: Layer,
    all_keys,
    known_key_dependencies,
    client,
    client_keys,
):
    from ..client import Future

    dsk = dict(layer)

    # Find aliases not in `client_keys` and substitute all matching keys
    # with its Future
    values = {
        k: v
        for k, v in dsk.items()
        if isinstance(v, Future) and k not in client_keys
    }
    if values:
        dsk = subs_multiple(dsk, values)

    # Unpack remote data and record its dependencies
    dsk = {k: unpack_remotedata(v, byte_keys=True) for k, v in dsk.items()}
    unpacked_futures = set.union(*[v[1]
                                   for v in dsk.values()]) if dsk else set()
    for future in unpacked_futures:
        if future.client is not client:
            raise ValueError(
                "Inputs contain futures that were created by another client.")
        if stringify(future.key) not in client.futures:
            raise CancelledError(stringify(future.key))
    unpacked_futures_deps = {}
    for k, v in dsk.items():
        if len(v[1]):
            unpacked_futures_deps[k] = {f.key for f in v[1]}
    dsk = {k: v[0] for k, v in dsk.items()}

    # Calculate dependencies without re-calculating already known dependencies
    missing_keys = set(dsk.keys()).difference(known_key_dependencies.keys())
    dependencies = {
        k: keys_in_tasks(all_keys, [dsk[k]], as_list=False)
        for k in missing_keys
    }
    for k, v in unpacked_futures_deps.items():
        dependencies[k] = set(dependencies.get(k, ())) | v

    # The scheduler expects all keys to be strings
    dependencies = {
        stringify(k): [stringify(dep) for dep in deps]
        for k, deps in dependencies.items()
    }
    all_keys = all_keys.union(dsk)
    dsk = {
        stringify(k): stringify(v, exclusive=all_keys)
        for k, v in dsk.items()
    }
    dsk = valmap(dumps_task, dsk)
    return {"dsk": dsk, "dependencies": dependencies}
Example #14
def test_modification_time_read_bytes():
    with s3_context("compress", files):
        _, a = read_bytes("s3://compress/test/accounts.*", anon=True)
        _, b = read_bytes("s3://compress/test/accounts.*", anon=True)

        assert [aa._key for aa in concat(a)] == [bb._key for bb in concat(b)]

    with s3_context("compress", valmap(double, files)):
        _, c = read_bytes("s3://compress/test/accounts.*", anon=True)

    assert [aa._key for aa in concat(a)] != [cc._key for cc in concat(c)]
Example #15
File: test_local.py Project: bigmpc/dask
def test_open_files_compression(mode, fmt):
    if fmt not in compress:
        pytest.skip("compression function not provided")
    files2 = valmap(compress[fmt], files)
    with filetexts(files2, mode="b"):
        myfiles = open_files(".test.accounts.*", mode=mode, compression=fmt)
        data = []
        for file in myfiles:
            with file as f:
                data.append(f.read())
        sol = [files[k] for k in sorted(files)]
        if mode == "rt":
            sol = [b.decode() for b in sol]
        assert list(data) == sol
Example #16
def test_read_csv_compression(fmt, blocksize):
    if fmt not in compress:
        pytest.skip("compress function not provided for %s" % fmt)
    files2 = valmap(compress[fmt], csv_files)
    with filetexts(files2, mode="b"):
        if fmt and blocksize:
            with pytest.warns(UserWarning):
                df = dd.read_csv("2014-01-*.csv", compression=fmt, blocksize=blocksize)
        else:
            df = dd.read_csv("2014-01-*.csv", compression=fmt, blocksize=blocksize)
        assert_eq(
            df.compute(scheduler="sync").reset_index(drop=True),
            expected.reset_index(drop=True),
            check_dtype=False,
        )
Example #17
File: layers.py Project: nils-braun/dask
    def __dask_distributed_unpack__(cls, state, dsk, dependencies):
        from distributed.worker import dumps_task

        # Expand merge_kwargs
        merge_kwargs = state.pop("merge_kwargs", {})
        state.update(merge_kwargs)

        # Materialize the layer
        raw = cls(**state)._construct_graph(deserializing=True)

        # Convert all keys to strings and dump tasks
        raw = {stringify(k): stringify_collection_keys(v) for k, v in raw.items()}
        keys = raw.keys() | dsk.keys()
        deps = {k: keys_in_tasks(keys, [v]) for k, v in raw.items()}

        return {"dsk": toolz.valmap(dumps_task, raw), "deps": deps}
Example #18
def test_read_csv_compression(fmt, blocksize):
    if fmt and fmt not in compress:
        pytest.skip("compress function not provided for %s" % fmt)
    suffix = {"gzip": ".gz", "bz2": ".bz2", "zip": ".zip", "xz": ".xz"}.get(fmt, "")
    files2 = valmap(compress[fmt], csv_files) if fmt else csv_files
    renamed_files = {k + suffix: v for k, v in files2.items()}
    with filetexts(renamed_files, mode="b"):
        # This test is using `compression="infer"` (the default) for
        # read_csv.  The paths must have the appropriate extension.
        if fmt and blocksize:
            with pytest.warns(UserWarning):
                df = dd.read_csv("2014-01-*.csv" + suffix, blocksize=blocksize)
        else:
            df = dd.read_csv("2014-01-*.csv" + suffix, blocksize=blocksize)
        assert_eq(
            df.compute(scheduler="sync").reset_index(drop=True),
            expected.reset_index(drop=True),
            check_dtype=False,
        )
Example #19
def test_warn_non_seekable_files():
    files2 = valmap(compress["gzip"], csv_files)
    with filetexts(files2, mode="b"):

        with pytest.warns(UserWarning) as w:
            df = dd.read_csv("2014-01-*.csv", compression="gzip")
            assert df.npartitions == 3

        assert len(w) == 1
        msg = str(w[0].message)
        assert "gzip" in msg
        assert "blocksize=None" in msg

        with pytest.warns(None) as w:
            df = dd.read_csv("2014-01-*.csv", compression="gzip", blocksize=None)
        assert len(w) == 0

        with pytest.raises(NotImplementedError):
            with pytest.warns(UserWarning):  # needed for pytest
                df = dd.read_csv("2014-01-*.csv", compression="foo")
Example #20
def test_nanny_process_failure(c, s):
    n = yield Nanny(s.address, nthreads=2, loop=s.loop)
    first_dir = n.worker_dir

    assert os.path.exists(first_dir)

    original_address = n.worker_address
    ww = rpc(n.worker_address)
    yield ww.update_data(data=valmap(dumps, {"x": 1, "y": 2}))
    pid = n.pid
    assert pid is not None
    with ignoring(CommClosedError):
        yield c.run(os._exit, 0, workers=[n.worker_address])

    start = time()
    while n.pid == pid:  # wait while process dies and comes back
        yield gen.sleep(0.01)
        assert time() - start < 5

    start = time()
    yield gen.sleep(1)
    while not n.is_alive():  # wait while process comes back
        yield gen.sleep(0.01)
        assert time() - start < 5

    # assert n.worker_address != original_address  # most likely

    start = time()
    while n.worker_address not in s.nthreads or n.worker_dir is None:
        yield gen.sleep(0.01)
        assert time() - start < 5

    second_dir = n.worker_dir

    yield n.close()
    assert not os.path.exists(second_dir)
    assert not os.path.exists(first_dir)
    assert first_dir != n.worker_dir
    yield ww.close_rpc()
    s.stop()
Example #21
File: layers.py Project: nils-braun/dask
    def __dask_distributed_unpack__(cls, state, dsk, dependencies):
        from distributed.worker import dumps_task

        # msgpack will convert lists into tuples, here
        # we convert them back to lists
        if isinstance(state["column"], tuple):
            state["column"] = list(state["column"])
        if "inputs" in state:
            state["inputs"] = list(state["inputs"])

        # Materialize the layer
        layer_dsk = cls(**state)._construct_graph(deserializing=True)

        # Convert all keys to strings and dump tasks
        layer_dsk = {
            stringify(k): stringify_collection_keys(v) for k, v in layer_dsk.items()
        }
        keys = layer_dsk.keys() | dsk.keys()

        # TODO: use shuffle-knowledge to calculate dependencies more efficiently
        deps = {k: keys_in_tasks(keys, [v]) for k, v in layer_dsk.items()}

        return {"dsk": toolz.valmap(dumps_task, layer_dsk), "deps": deps}
Example #22
def test_repeated_groupby():
    b = db.range(10, npartitions=4)
    c = b.groupby(lambda x: x % 3)
    assert valmap(len, dict(c)) == valmap(len, dict(c))
Example #23
    def nunique(self) -> pd.Series:
        return pd.Series(tlz.valmap(len, self._unique()))
Example #24
    def nunique(self) -> pd.Series:
        """Return a series of the number of unique values for each column in the catalog."""
        return pd.Series(tlz.valmap(len, self._unique()))
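Both nunique variants map len over whatever _unique() returns and wrap the result in a Series. A sketch of that last step, assuming _unique() yields a mapping from column name to an array-like of unique values (an assumption about this catalog's internals):

import pandas as pd
import toolz as tlz

uniques = {"color": ["red", "blue"], "size": ["S", "M", "L"]}
print(pd.Series(tlz.valmap(len, uniques)))
# color    2
# size     3
# dtype: int64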
Example #25
    def __dask_distributed_pack__(
        self,
        all_hlg_keys: Iterable[Hashable],
        known_key_dependencies: Mapping[Hashable, Set],
        client,
        client_keys: Iterable[Hashable],
    ) -> Any:
        """Pack the layer for scheduler communication in Distributed

        This method should pack its current state and is called by the Client when
        communicating with the Scheduler.
        The Scheduler will then use .__dask_distributed_unpack__(data, ...) to unpack
        the state, materialize the layer, and merge it into the global task graph.

        The returned state must be compatible with Distributed's scheduler, which
        means it must obey the following:
          - Serializable by msgpack (notice, msgpack converts lists to tuples)
          - All remote data must be unpacked (see unpack_remotedata())
          - All keys must be converted to strings now or when unpacking
          - All tasks must be serialized (see dumps_task())

        The default implementation materializes the layer; layers such as Blockwise
        and ShuffleLayer should therefore implement specialized pack and unpack
        functions in order to avoid materialization.

        Parameters
        ----------
        all_hlg_keys: Iterable[Hashable]
            All keys in the high level graph
        known_key_dependencies: Mapping[Hashable, Set]
            Already known dependencies
        client: distributed.Client
            The client calling this function.
        client_keys : Iterable[Hashable]
            List of keys requested by the client.

        Returns
        -------
        state: Object serializable by msgpack
            Scheduler compatible state of the layer
        """
        from distributed.client import Future
        from distributed.utils import CancelledError
        from distributed.utils_comm import subs_multiple, unpack_remotedata
        from distributed.worker import dumps_task

        dsk = dict(self)

        # Find aliases not in `client_keys` and substitute all matching keys
        # with its Future
        future_aliases = {
            k: v
            for k, v in dsk.items()
            if isinstance(v, Future) and k not in client_keys
        }
        if future_aliases:
            dsk = subs_multiple(dsk, future_aliases)

        # Remove `Future` objects from graph and note any future dependencies
        dsk2 = {}
        fut_deps = {}
        for k, v in dsk.items():
            dsk2[k], futs = unpack_remotedata(v, byte_keys=True)
            if futs:
                fut_deps[k] = futs
        dsk = dsk2

        # Check that any collected futures are valid
        unpacked_futures = set.union(*fut_deps.values()) if fut_deps else set()
        for future in unpacked_futures:
            if future.client is not client:
                raise ValueError(
                    "Inputs contain futures that were created by another client."
                )
            if stringify(future.key) not in client.futures:
                raise CancelledError(stringify(future.key))

        # Calculate dependencies without re-calculating already known dependencies
        # - Start with known dependencies
        dependencies = ensure_dict(known_key_dependencies, copy=True)
        # - Remove aliases for any tasks that depend on both an alias and a future.
        #   These can only be found in the known_key_dependencies cache, since
        #   any dependencies computed in this method would have already had the
        #   aliases removed.
        if future_aliases:
            alias_keys = set(future_aliases)
            dependencies = {k: v - alias_keys for k, v in dependencies.items()}
        # - Add in deps for any missing keys
        missing_keys = dsk.keys() - dependencies.keys()

        dependencies.update(
            (k, keys_in_tasks(all_hlg_keys, [dsk[k]], as_list=False))
            for k in missing_keys)
        # - Add in deps for any tasks that depend on futures
        for k, futures in fut_deps.items():
            if futures:
                d = ensure_set(dependencies[k], copy=True)
                d.update(f.key for f in futures)
                dependencies[k] = d

        # The scheduler expects all keys to be strings
        dependencies = {
            stringify(k): {stringify(dep)
                           for dep in deps}
            for k, deps in dependencies.items()
        }

        merged_hlg_keys = all_hlg_keys | dsk.keys()
        dsk = {
            stringify(k): stringify(v, exclusive=merged_hlg_keys)
            for k, v in dsk.items()
        }
        dsk = toolz.valmap(dumps_task, dsk)
        return {"dsk": dsk, "dependencies": dependencies}
Example #26
    def __dask_distributed_pack__(
        self,
        all_hlg_keys: Iterable[Hashable],
        known_key_dependencies: Mapping[Hashable, set],
        client,
        client_keys: Iterable[Hashable],
    ) -> Any:
        """Pack the layer for scheduler communication in Distributed

        This method should pack its current state and is called by the Client when
        communicating with the Scheduler.
        The Scheduler will then use .__dask_distributed_unpack__(data, ...) to unpack
        the state, materialize the layer, and merge it into the global task graph.

        The returned state must be compatible with Distributed's scheduler, which
        means it must obey the following:
          - Serializable by msgpack (notice, msgpack converts lists to tuples)
          - All remote data must be unpacked (see unpack_remotedata())
          - All keys must be converted to strings now or when unpacking
          - All tasks must be serialized (see dumps_task())

        The default implementation materializes the layer; layers such as Blockwise
        and ShuffleLayer should therefore implement specialized pack and unpack
        functions in order to avoid materialization.

        Parameters
        ----------
        all_hlg_keys: Iterable[Hashable]
            All keys in the high level graph
        known_key_dependencies: Mapping[Hashable, set]
            Already known dependencies
        client: distributed.Client
            The client calling this function.
        client_keys : Iterable[Hashable]
            List of keys requested by the client.

        Returns
        -------
        state: Object serializable by msgpack
            Scheduler compatible state of the layer
        """
        from distributed.client import Future
        from distributed.utils import CancelledError
        from distributed.utils_comm import subs_multiple, unpack_remotedata
        from distributed.worker import dumps_task

        dsk = dict(self)

        # Find aliases not in `client_keys` and substitute all matching keys
        # with its Future
        values = {
            k: v
            for k, v in dsk.items()
            if isinstance(v, Future) and k not in client_keys
        }
        if values:
            dsk = subs_multiple(dsk, values)

        # Unpack remote data and record its dependencies
        dsk = {k: unpack_remotedata(v, byte_keys=True) for k, v in dsk.items()}
        unpacked_futures = set.union(*[v[1] for v in dsk.values()]) if dsk else set()
        for future in unpacked_futures:
            if future.client is not client:
                raise ValueError(
                    "Inputs contain futures that were created by another client."
                )
            if stringify(future.key) not in client.futures:
                raise CancelledError(stringify(future.key))
        unpacked_futures_deps = {}
        for k, v in dsk.items():
            if len(v[1]):
                unpacked_futures_deps[k] = {f.key for f in v[1]}
        dsk = {k: v[0] for k, v in dsk.items()}

        # Calculate dependencies without re-calculating already known dependencies
        missing_keys = dsk.keys() - known_key_dependencies.keys()
        dependencies = {
            k: keys_in_tasks(all_hlg_keys, [dsk[k]], as_list=False)
            for k in missing_keys
        }
        for k, v in unpacked_futures_deps.items():
            dependencies[k] = set(dependencies.get(k, ())) | v
        dependencies.update(known_key_dependencies)

        # The scheduler expects all keys to be strings
        dependencies = {
            stringify(k): {stringify(dep) for dep in deps}
            for k, deps in dependencies.items()
        }

        merged_hlg_keys = all_hlg_keys | dsk.keys()
        dsk = {
            stringify(k): stringify(v, exclusive=merged_hlg_keys)
            for k, v in dsk.items()
        }
        dsk = toolz.valmap(dumps_task, dsk)
        return {"dsk": dsk, "dependencies": dependencies}
Example #27
def main(scheduler, host, worker_port, listen_address, contact_address,
         nanny_port, nthreads, nprocs, nanny, name, pid_file, resources,
         dashboard, bokeh, bokeh_port, scheduler_file, dashboard_prefix,
         tls_ca_file, tls_cert, tls_key, dashboard_address, worker_class,
         preload_nanny, **kwargs):
    # https://github.com/dask/distributed/issues/1653
    g0, g1, g2 = gc.get_threshold()
    gc.set_threshold(g0 * 3, g1 * 3, g2 * 3)

    enable_proctitle_on_current()
    enable_proctitle_on_children()

    if bokeh_port is not None:
        warnings.warn(
            "The --bokeh-port flag has been renamed to --dashboard-address. "
            "Consider adding ``--dashboard-address :%d`` " % bokeh_port)
        dashboard_address = bokeh_port
    if bokeh is not None:
        warnings.warn(
            "The --bokeh/--no-bokeh flag has been renamed to --dashboard/--no-dashboard. "
        )
        dashboard = bokeh

    sec = {
        k: v
        for k, v in [
            ("tls_ca_file", tls_ca_file),
            ("tls_worker_cert", tls_cert),
            ("tls_worker_key", tls_key),
        ] if v is not None
    }

    if nprocs < 0:
        nprocs = CPU_COUNT + 1 + nprocs

    if nprocs <= 0:
        logger.error(
            "Failed to launch worker. Must specify --nprocs so that there's at least one process."
        )
        sys.exit(1)

    if nprocs > 1 and not nanny:
        logger.error(
            "Failed to launch worker.  You cannot use the --no-nanny argument when nprocs > 1."
        )
        sys.exit(1)

    if contact_address and not listen_address:
        logger.error(
            "Failed to launch worker. "
            "Must specify --listen-address when --contact-address is given")
        sys.exit(1)

    if nprocs > 1 and listen_address:
        logger.error("Failed to launch worker. "
                     "You cannot specify --listen-address when nprocs > 1.")
        sys.exit(1)

    if (worker_port or host) and listen_address:
        logger.error(
            "Failed to launch worker. "
            "You cannot specify --listen-address when --worker-port or --host is given."
        )
        sys.exit(1)

    try:
        if listen_address:
            (host, worker_port) = get_address_host_port(listen_address,
                                                        strict=True)

        if contact_address:
            # we only need this to verify it is getting parsed
            (_, _) = get_address_host_port(contact_address, strict=True)
        else:
            # if contact address is not present we use the listen_address for contact
            contact_address = listen_address
    except ValueError as e:
        logger.error("Failed to launch worker. " + str(e))
        sys.exit(1)

    if nanny:
        port = nanny_port
    else:
        port = worker_port

    if not nthreads:
        nthreads = CPU_COUNT // nprocs

    if pid_file:
        with open(pid_file, "w") as f:
            f.write(str(os.getpid()))

        def del_pid_file():
            if os.path.exists(pid_file):
                os.remove(pid_file)

        atexit.register(del_pid_file)

    if resources:
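        # e.g. --resources "GPU=2,MEMORY=10e9" becomes {"GPU": 2.0, "MEMORY": 10e9}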
        resources = resources.replace(",", " ").split()
        resources = dict(pair.split("=") for pair in resources)
        resources = valmap(float, resources)
    else:
        resources = None

    loop = IOLoop.current()

    worker_class = import_term(worker_class)
    if nanny:
        kwargs["worker_class"] = worker_class
        kwargs["preload_nanny"] = preload_nanny

    if nanny:
        kwargs.update({
            "worker_port": worker_port,
            "listen_address": listen_address
        })
        t = Nanny
    else:
        if nanny_port:
            kwargs["service_ports"] = {"nanny": nanny_port}
        t = worker_class

    if (not scheduler and not scheduler_file
            and dask.config.get("scheduler-address", None) is None):
        raise ValueError("Need to provide scheduler address like\n"
                         "dask-worker SCHEDULER_ADDRESS:8786")

    with suppress(TypeError, ValueError):
        name = int(name)

    if "DASK_INTERNAL_INHERIT_CONFIG" in os.environ:
        config = deserialize_for_cli(
            os.environ["DASK_INTERNAL_INHERIT_CONFIG"])
        # Update the global config given priority to the existing global config
        dask.config.update(dask.config.global_config, config, priority="old")

    nannies = [
        t(scheduler,
          scheduler_file=scheduler_file,
          nthreads=nthreads,
          loop=loop,
          resources=resources,
          security=sec,
          contact_address=contact_address,
          host=host,
          port=port,
          dashboard=dashboard,
          dashboard_address=dashboard_address,
          name=name if nprocs == 1 or name is None or name == "" else
          str(name) + "-" + str(i),
          **kwargs) for i in range(nprocs)
    ]

    async def close_all():
        # Unregister all workers from scheduler
        if nanny:
            await asyncio.gather(*[n.close(timeout=2) for n in nannies])

    signal_fired = False

    def on_signal(signum):
        nonlocal signal_fired
        signal_fired = True
        if signum != signal.SIGINT:
            logger.info("Exiting on signal %d", signum)
        return asyncio.ensure_future(close_all())

    async def run():
        await asyncio.gather(*nannies)
        await asyncio.gather(*[n.finished() for n in nannies])

    install_signal_handlers(loop, cleanup=on_signal)

    try:
        loop.run_sync(run)
    except TimeoutError:
        # We already log the exception in nanny / worker. Don't do it again.
        if not signal_fired:
            logger.info("Timed out starting worker")
        sys.exit(1)
    except KeyboardInterrupt:
        pass
    finally:
        logger.info("End worker")