예제 #1
0
def test_unpack_collections():
    """``unpack_collections`` must find every Dask collection nested inside builtin
    containers, dataclasses and iterators, and ``repack`` must rebuild the original
    structure around whatever values are substituted for those collections."""
    a = delayed(1) + 5
    b = a + 1
    c = a + 2

    def build(a, b, c, iterator):
        t = (
            a,
            b,  # Top-level collections
            {
                "a": a,  # dict
                a: b,  # collections as keys
                "b": [1, 2, [b]],  # list
                "c": 10,  # other builtins pass through unchanged
                "d": (c, 2),  # tuple
                "e": {a, 2, 3},  # set
                "f": OrderedDict([("a", a)]),  # OrderedDict
            },
            iterator,  # Iterator
        )

        if dataclasses is not None:
            # Use dedicated keys: writing to "f" here would silently replace the
            # OrderedDict above and drop that case from the test.
            t[2]["g"] = ADataClass(a=a)  # dataclass instance
            t[2]["h"] = (ADataClass, a)  # dataclass type inside a tuple

        return t

    args = build(a, b, c, (i for i in [a, b, c]))

    collections, repack = unpack_collections(*args)
    assert len(collections) == 3

    # Replace collections with `'~a'` strings
    result = repack(["~a", "~b", "~c"])
    sol = build("~a", "~b", "~c", ["~a", "~b", "~c"])
    assert result == sol

    # traverse=False
    collections, repack = unpack_collections(*args, traverse=False)
    assert len(collections) == 2  # just a and b
    assert repack(collections) == args

    # No collections
    collections, repack = unpack_collections(1, 2, {"a": 3})
    assert not collections
    assert repack(collections) == (1, 2, {"a": 3})

    # Result that looks like a task
    def fail(*args):
        raise ValueError("Shouldn't have been called")  # pragma: nocover

    collections, repack = unpack_collections(
        a, (fail, 1), [(fail, 2, 3)], traverse=False
    )
    repack(collections)  # Smoketest task literals
    repack([(fail, 1)])  # Smoketest results that look like tasks
예제 #2
0
파일: test_base.py 프로젝트: yooerzf/dask
def test_unpack_collections():
    """``unpack_collections`` must find every Dask collection nested inside builtin
    containers, dataclasses and iterators, and ``repack`` must rebuild the original
    structure around whatever values are substituted for those collections."""
    a = delayed(1) + 5
    b = a + 1
    c = a + 2

    def build(a, b, c, iterator):
        t = (
            a,
            b,  # Top-level collections
            {
                'a': a,  # dict
                a: b,  # collections as keys
                'b': [1, 2, [b]],  # list
                'c': 10,  # other builtins pass through unchanged
                'd': (c, 2),  # tuple
                'e': {a, 2, 3},  # set
                'f': OrderedDict([('a', a)])  # OrderedDict
            },
            iterator)  # Iterator

        if dataclasses is not None:
            # Use a dedicated key: writing to 'f' would silently replace the
            # OrderedDict above and drop that case from the test.
            t[2]['g'] = ADataClass(a=a)

        return t

    args = build(a, b, c, (i for i in [a, b, c]))

    collections, repack = unpack_collections(*args)
    assert len(collections) == 3

    # Replace collections with `'~a'` strings
    result = repack(['~a', '~b', '~c'])
    sol = build('~a', '~b', '~c', ['~a', '~b', '~c'])
    assert result == sol

    # traverse=False
    collections, repack = unpack_collections(*args, traverse=False)
    assert len(collections) == 2  # just a and b
    assert repack(collections) == args

    # No collections
    collections, repack = unpack_collections(1, 2, {'a': 3})
    assert not collections
    assert repack(collections) == (1, 2, {'a': 3})

    # Result that looks like a task
    def fail(*args):
        raise ValueError("Shouldn't have been called")  # pragma: nocover

    collections, repack = unpack_collections(a, (fail, 1), [(fail, 2, 3)],
                                             traverse=False)
    repack(collections)  # Smoketest task literals
    repack([(fail, 1)])  # Smoketest results that look like tasks
예제 #3
0
def test_unpack_collections():
    """Collections buried in dicts, lists, tuples, sets, dataclasses and iterators are
    all discovered by ``unpack_collections`` and restored in place by ``repack``."""
    a = delayed(1) + 5
    b = a + 1
    c = a + 2

    def make_args(first, second, third, items):
        # Collections appear as values, as dict keys, and inside nested containers;
        # the plain int under 'c' must pass through untouched.
        nested = {
            'a': first,
            first: second,
            'b': [1, 2, [second]],
            'c': 10,
            'd': (third, 2),
            'e': {first, 2, 3},
        }
        if dataclasses is not None:
            nested['f'] = ADataClass(a=first)
        return (first, second, nested, items)

    args = make_args(a, b, c, (x for x in [a, b, c]))

    # Full traversal reaches the three distinct delayed objects
    found, repack = unpack_collections(*args)
    assert len(found) == 3

    # Substituting placeholders must rebuild the same structure (iterator -> list)
    assert repack(['~a', '~b', '~c']) == make_args('~a', '~b', '~c', ['~a', '~b', '~c'])

    # Without traversal only the top-level a and b are picked up
    found, repack = unpack_collections(*args, traverse=False)
    assert len(found) == 2
    assert repack(found) == args

    # Inputs without any collection round-trip unchanged
    found, repack = unpack_collections(1, 2, {'a': 3})
    assert not found
    assert repack(found) == (1, 2, {'a': 3})

    # Literal values shaped like tasks must never be executed
    def fail(*args):
        raise ValueError("Shouldn't have been called")

    found, repack = unpack_collections(a, (fail, 1), [(fail, 2, 3)], traverse=False)
    repack(found)
    repack([(fail, 1)])
예제 #4
0
def checkpoint(
    *collections,
    split_every: float | Literal[False] | None = None,
) -> Delayed:
    """Return a :doc:`delayed` that evaluates to None, and only does so after every
    chunk of every input collection has been computed.

    Parameters
    ----------
    collections
        Zero or more Dask collections, or nested data structures that contain zero or
        more Dask collections
    split_every: int >= 2 or False, optional
        Controls the depth of the recursive aggregation. When it is smaller than the
        number of input keys, the aggregation is carried out in several steps and the
        aggregation graph is :math:`log_{split_every}(input keys)` levels deep. A low
        value can reduce cache size and network transfers, in exchange for more CPU and
        a larger dask graph.

        Set to False to disable. Defaults to 8.

    Returns
    -------
    :doc:`delayed` yielding None
    """
    # Normalise split_every into a local: None -> 8, False kept, anything else int >= 2
    if split_every is None:
        fan_in = 8
    elif split_every is False:
        fan_in = False
    else:
        fan_in = int(split_every)
        if fan_in < 2:
            raise ValueError("split_every must be False, None, or >= 2")

    found, _ = unpack_collections(*collections)
    if len(found) == 1:
        return _checkpoint_one(found[0], fan_in)

    # Several (or zero) collections: gather their individual checkpoints into one
    gather = delayed(chunks.checkpoint)
    return gather(*[_checkpoint_one(coll, fan_in) for coll in found])
예제 #5
0
def wait_on(
    *collections,
    split_every: float | Literal[False] | None = None,
):
    """Make every chunk of the input collections wait until all chunks of all input
    collections have been computed, before any of their dependents can run.

    Below, ``u`` behaves exactly like ``x`` in a computation, except that it does not
    proceed until every chunk of ``x`` has been computed:

    >>> import dask.array as da
    >>> x = da.ones(10, chunks=5)
    >>> u = wait_on(x)

    With several inputs, ``u`` and ``v`` match ``x`` and ``y`` respectively, but neither
    proceeds until all chunks of both ``x`` and ``y`` have been computed:

    >>> x = da.ones(10, chunks=5)
    >>> y = da.zeros(10, chunks=5)
    >>> u, v = wait_on(x, y)

    Parameters
    ----------
    collections
        Zero or more Dask collections, or nested structures of Dask collections
    split_every
        See :func:`checkpoint`

    Returns
    -------
    Same as ``collections``
        A single input returns a collection of the same type that computes to the same
        value; otherwise a nested structure mirroring the input in which each
        collection has been replaced by its blocked counterpart.
    """
    blocker = checkpoint(*collections, split_every=split_every)

    def block_one(coll):
        # Re-emit each named layer of ``coll`` under a fresh name that also depends
        # on ``blocker``, then rebuild the collection with the renamed graph.
        token = tokenize(coll, blocker)
        renames = {}
        graphs = []
        for old_name in get_collection_names(coll):
            fresh_name = "wait_on-" + tokenize(old_name, token)
            renames[old_name] = fresh_name
            bound_layer = _build_map_layer(
                chunks.bind, old_name, fresh_name, coll, dependencies=(blocker,)
            )
            graphs.append(
                HighLevelGraph.from_collections(
                    fresh_name, bound_layer, dependencies=(coll, blocker)
                )
            )
        merged = HighLevelGraph.merge(*graphs)
        rebuild, rebuild_args = coll.__dask_postpersist__()
        return rebuild(merged, *rebuild_args, rename=renames)

    found, repack = unpack_collections(*collections)
    blocked = repack([block_one(coll) for coll in found])
    if len(collections) == 1:
        return blocked[0]
    return blocked
예제 #6
0
def bind(
    children: T,
    parents,
    *,
    omit=None,
    seed: Hashable | None = None,
    assume_layers: bool = True,
    split_every: float | Literal[False] | None = None,
) -> T:
    """
    Make ``children`` collection(s), optionally omitting sub-collections, dependent on
    ``parents`` collection(s). Two examples follow.

    The first example creates an array ``b2`` whose computation first computes an array
    ``a`` completely and then computes ``b`` completely, recomputing ``a`` in the
    process:

    >>> import dask
    >>> import dask.array as da
    >>> a = da.ones(4, chunks=2)
    >>> b = a + 1
    >>> b2 = bind(b, a)
    >>> len(b2.dask)
    9
    >>> b2.compute()
    array([2., 2., 2., 2.])

    The second example creates arrays ``b3`` and ``c3``, whose computation first
    computes an array ``a`` and then computes the additions, this time not
    recomputing ``a`` in the process:

    >>> c = a + 2
    >>> b3, c3 = bind((b, c), a, omit=a)
    >>> len(b3.dask), len(c3.dask)
    (7, 7)
    >>> dask.compute(b3, c3)
    (array([2., 2., 2., 2.]), array([3., 3., 3., 3.]))

    Parameters
    ----------
    children
        Dask collection or nested structure of Dask collections
    parents
        Dask collection or nested structure of Dask collections
    omit
        Dask collection or nested structure of Dask collections
    seed
        Hashable used to seed the key regeneration. Omit (or pass None) to default to
        random bytes from :func:`uuid.uuid4`, which produce different keys at every
        call.
    assume_layers
        True
            Use a fast algorithm that works at layer level, which assumes that all
            collections in ``children`` and ``omit``

            #. use :class:`~dask.highlevelgraph.HighLevelGraph`,
            #. define the ``__dask_layers__()`` method, and
            #. never had their graphs squashed and rebuilt between the creation of the
               ``omit`` collections and the ``children`` collections; in other words if
               the keys of the ``omit`` collections can be found among the keys of the
               ``children`` collections, then the same must also hold true for the
               layers.
        False
            Use a slower algorithm that works at keys level, which makes none of the
            above assumptions.
    split_every
        See :func:`checkpoint`

    Returns
    -------
    Same as ``children``
        Dask collection or structure of dask collection equivalent to ``children``,
        which compute to the same values. All keys of ``children`` will be regenerated,
        up to and excluding the keys of ``omit``. Nodes immediately above ``omit``, or
        the leaf nodes if the collections in ``omit`` are not found, are prevented from
        computing until all collections in ``parents`` have been fully computed.
    """
    if seed is None:
        # Random seed: every call regenerates a different set of keys
        seed = uuid.uuid4().bytes

    # parents=None is a special case invoked by the one-liner wrapper clone() below
    blocker = (checkpoint(parents, split_every=split_every)
               if parents is not None else None)

    omit, _ = unpack_collections(omit)
    if assume_layers:
        # Set of all the top-level layers of the collections in omit
        omit_layers = {
            layer
            for coll in omit for layer in coll.__dask_layers__()
        }
        omit_keys = set()
    else:
        omit_layers = set()
        # Set of *all* the keys, not just the top-level ones, of the collections in omit
        omit_keys = {key for coll in omit for key in coll.__dask_graph__()}

    # ``children`` is unpacked as a single argument, so repack returns a 1-tuple
    unpacked_children, repack = unpack_collections(children)
    return repack([
        _bind_one(child, blocker, omit_layers, omit_keys, seed)
        for child in unpacked_children
    ])[0]
예제 #7
0
def _get_dsk(node):
    """Build a task graph for ``node`` via its ``todelayed()`` representation.

    The collections found at the top level of ``node.todelayed()`` (nested containers
    are not traversed) are handed to ``collections_to_dsk``.
    """
    d = node.todelayed()
    # Only the collections are needed; the repack callable is discarded.
    collections, _ = unpack_collections(d, traverse=False)
    # NOTE(review): the positional True is presumably ``optimize_graph`` — confirm
    return collections_to_dsk(collections, True)