Example #1
    def fit(self, X, y=None):
        """Fit a gradient boosting classifier

        Parameters
        ----------
        X : array-like [n_samples, n_features]
            Feature Matrix. May be a dask.array or dask.dataframe
        y : array-like
            Labels

        Returns
        -------
        self : XGBClassifier

        Notes
        -----
        This differs from the XGBoost version in three ways:

        1. The ``sample_weight``, ``eval_set``, ``eval_metric``,
           ``early_stopping_rounds`` and ``verbose`` fit kwargs are not
           supported.
        2. The labels are not automatically label-encoded.
        3. The ``classes_`` and ``n_classes_`` attributes are not learned.
        """
        client = default_client()
        xgb_options = self.get_xgb_params()
        self._Booster = train(client,
                              xgb_options,
                              X,
                              y,
                              num_boost_round=self.n_estimators)
        return self
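A minimal usage sketch for the fit shown above, assuming the dask-xgboost style ``XGBClassifier`` (import path assumed) and a running ``distributed`` scheduler; the synthetic data is purely illustrative:

import dask.array as da
from dask.distributed import Client
from dask_xgboost import XGBClassifier  # assumed import path

client = Client()  # becomes the default_client() used inside fit()
X = da.random.random((1000, 10), chunks=(100, 10))
y = (da.random.random(1000, chunks=100) > 0.5).astype(int)

clf = XGBClassifier(n_estimators=10)
clf.fit(X, y)  # delegates to train(client, ...) as in the example above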
Example #2
File: dask.py Project: snth/streamz
    def __init__(self, child, n, loop=None):
        client = default_client()
        self.queue = dask.distributed.Queue(maxsize=n, client=client)

        core.Stream.__init__(self, child, loop=loop or client.loop)

        self.loop.add_callback(self.cb)
Example #3
    def fit(self, X, y=None):
        """Fit the gradient boosting model

        Parameters
        ----------
        X : array-like [n_samples, n_features]
        y : array-like

        Returns
        -------
        self : the fitted Regressor

        Notes
        -----
        This differs from the XGBoost version in not supporting the
        ``eval_set``, ``eval_metric``, ``early_stopping_rounds`` and
        ``verbose`` fit kwargs.
        """
        client = default_client()
        xgb_options = self.get_xgb_params()
        self._Booster = train(client,
                              xgb_options,
                              X,
                              y,
                              num_boost_round=self.n_estimators)
        return self
Example #4
def test_from_partitions(axis, index, columns, row_lengths, column_widths):
    num_rows = 2**16
    num_cols = 2**8
    data = np.random.randint(0, 100, size=(num_rows, num_cols))
    df1, df2 = pandas.DataFrame(data), pandas.DataFrame(data)
    expected_df = pandas.concat([df1, df2], axis=1 if axis is None else axis)

    index = expected_df.index if index == "index" else None
    columns = expected_df.columns if columns == "columns" else None
    row_lengths = (None if row_lengths is None else
                   [num_rows, num_rows] if axis == 0 else [num_rows])
    column_widths = (None if column_widths is None else
                     [num_cols] if axis == 0 else [num_cols, num_cols])

    if Engine.get() == "Ray":
        if axis is None:
            futures = [[ray.put(df1), ray.put(df2)]]
        else:
            futures = [ray.put(df1), ray.put(df2)]
    if Engine.get() == "Dask":
        client = default_client()
        if axis is None:
            futures = [client.scatter([df1, df2], hash=False)]
        else:
            futures = client.scatter([df1, df2], hash=False)
    actual_df = from_partitions(
        futures,
        axis,
        index=index,
        columns=columns,
        row_lengths=row_lengths,
        column_widths=column_widths,
    )
    df_equals(expected_df, actual_df)
Example #5
    def deploy(cls, func, *args, num_returns=1, pure=None, **kwargs):
        """
        Deploy a function in a worker process.

        Parameters
        ----------
        func : callable
            Function to be deployed in a worker process.
        *args : list
            Additional positional arguments to be passed to `func`.
        num_returns : int, default: 1
            The number of returned objects.
        pure : bool, optional
            Whether or not `func` is pure. See `Client.submit` for details.
        **kwargs : dict
            Additional keyword arguments to be passed to ``func``.

        Returns
        -------
        list
            The result of ``func`` split into parts in accordance with ``num_returns``.
        """
        client = default_client()
        remote_task_future = client.submit(func, *args, pure=pure, **kwargs)
        if num_returns != 1:
            return [
                client.submit(lambda l, i: l[i], remote_task_future, i)
                for i in range(num_returns)
            ]
        return remote_task_future
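A hedged sketch of the ``num_returns`` behaviour above, reproduced with a plain ``distributed.Client`` instead of the wrapper class; the ``make_pair`` function is invented for illustration:

from distributed import Client

client = Client()

def make_pair(a, b):
    # Returns an indexable result so it can be split into two futures.
    return a + b, a * b

# One future for the call itself...
whole = client.submit(make_pair, 3, 4, pure=False)
# ...then one future per element, as deploy() does for num_returns=2.
parts = [client.submit(lambda l, i: l[i], whole, i) for i in range(2)]
print(client.gather(parts))  # [7, 12]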
Example #6
def wait(collections, client=None, return_when=FIRST_COMPLETED):
    """Calculates collections on client.

    Parameters
    ----------
    collections : list of dask tasks
    client : None or distributed.client.Client
        if None uses distributed.client.default_client()
    return_when : int
        if == FIRST_COMPLETED returns when first task completes
        if == ALL_COMPLETED returns when all tasks completed
        Currently supports only FIRST_COMPLETED.

    Returns
    -------
    tuple : (result, index, unfinished_futures)
    """
    if return_when not in (FIRST_COMPLETED, ALL_COMPLETED):
        raise ValueError(
            "Unknown value for 'return_when'." +
            "Expected {} or {}.".format(FIRST_COMPLETED, ALL_COMPLETED) +
            "Received {}.".format(return_when))
    if return_when == ALL_COMPLETED:
        raise NotImplementedError("Support for ALL_COMPLETED not implemented.")
    client = client or dc.default_client()
    futures = client.compute(collections)
    f = dc.as_completed(futures).__next__()
    i = futures.index(f)
    del futures[i]
    res = f.result()
    del f
    return res, i, futures
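A short usage sketch for wait(), relying on the default FIRST_COMPLETED behaviour and an explicit client; the delayed tasks are illustrative:

import dask
from distributed import Client

client = Client()
tasks = [dask.delayed(pow)(i, 2) for i in range(3)]

# Blocks until the first task finishes, then hands back its result,
# its position in the input list, and the still-pending futures.
result, index, pending = wait(tasks, client=client)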
Example #7
    def mask(self, row_labels, col_labels):
        """
        Lazily create a mask that extracts the indices provided.

        Parameters
        ----------
        row_labels : list-like, slice or label
            The row labels for the rows to extract.
        col_labels : list-like, slice or label
            The column labels for the columns to extract.

        Returns
        -------
        PandasOnDaskDataframePartition
            A new ``PandasOnDaskDataframePartition`` object.
        """
        new_obj = super().mask(row_labels, col_labels)
        client = default_client()
        if isinstance(row_labels, slice) and isinstance(self._length_cache, Future):
            new_obj._length_cache = client.submit(
                compute_sliced_len, row_labels, self._length_cache
            )
        if isinstance(col_labels, slice) and isinstance(self._width_cache, Future):
            new_obj._width_cache = client.submit(
                compute_sliced_len, col_labels, self._width_cache
            )
        return new_obj
Example #8
File: utils.py Project: prutskov/modin
def initialize_dask():
    """Initialize Dask environment."""
    from distributed.client import default_client

    try:
        client = default_client()
    except ValueError:
        from distributed import Client

        # The indentation here is intentional, we want the code to be indented.
        ErrorMessage.not_initialized(
            "Dask",
            """
    from distributed import Client

    client = Client()
""",
        )
        num_cpus = CpuCount.get()
        memory_limit = Memory.get()
        worker_memory_limit = memory_limit // num_cpus if memory_limit else "auto"
        client = Client(n_workers=num_cpus, memory_limit=worker_memory_limit)

    num_cpus = len(client.ncores())
    NPartitions._put(num_cpus)
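A hedged sketch of the reuse path above: if a Client already exists, initialize_dask() picks it up via default_client() instead of creating a new one. The Modin config calls are assumptions about how the surrounding project is configured:

from distributed import Client

client = Client(n_workers=4)       # registered as the default client

import modin.config as cfg
cfg.Engine.put("Dask")             # ask Modin to use the Dask engine

import modin.pandas as pd          # Modin's Dask initialization will reuse `client`
df = pd.DataFrame({"a": range(10)})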
Example #9
File: dask.py Project: kszucs/streams
    def __init__(self, child, limit=10, client=None):
        self.client = client or default_client()
        self.queue = Queue(maxsize=limit)
        self.condition = Condition()

        Stream.__init__(self, child)

        self.client.loop.add_callback(self.cb)
Example #10
    def predict(self, X):
        client = default_client()
        class_probs = predict(client, self._Booster, X)
        if class_probs.ndim > 1:
            cidx = da.argmax(class_probs, axis=1)
        else:
            cidx = (class_probs > 0).astype(np.int64)
        return cidx
Example #11
def from_kafka_batched(topic,
                       consumer_params,
                       poll_interval="1s",
                       npartitions=1,
                       start=False,
                       dask=False,
                       **kwargs):
    """ Get messages from Kafka in batches

    Uses the confluent-kafka library,
    https://docs.confluent.io/current/clients/confluent-kafka-python/

    This source will emit lists of messages for each partition of a single given
    topic per time interval, if there is new data. If using dask, one future
    will be produced per partition per time-step, if there is data.

    Parameters
    ----------
    topic: str
        Kafka topic to consume from
    consumer_params: dict
        Settings to set up the stream, see
        https://docs.confluent.io/current/clients/confluent-kafka-python/#configuration
        https://github.com/edenhill/librdkafka/blob/master/CONFIGURATION.md
        Examples:
        bootstrap.servers: Connection string(s) (host:port) by which to reach Kafka
        group.id: Identity of the consumer. If multiple sources share the same
            group, each message will be passed to only one of them.
    poll_interval: number
        Seconds that elapse between polling Kafka for new messages
    npartitions: int
        Number of partitions in the topic
    start: bool (False)
        Whether to start polling upon instantiation
    dask: bool (False)
        Whether to route batches through a Dask cluster; requires a running
        ``dask.distributed`` client.

    Example
    -------

    >>> source = Stream.from_kafka_batched('mytopic',
    ...           {'bootstrap.servers': 'localhost:9092',
    ...            'group.id': 'rapidz'}, npartitions=4)  # doctest: +SKIP
    """
    if dask:
        from distributed.client import default_client

        kwargs["loop"] = default_client().loop
    source = FromKafkaBatched(topic,
                              consumer_params,
                              poll_interval=poll_interval,
                              npartitions=npartitions,
                              **kwargs)
    if dask:
        source = source.scatter()

    if start:
        source.start()

    return source.starmap(get_message_batch)
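A sketch of the dask=True path, assuming a reachable Kafka broker and a distributed scheduler; the topic and broker address are the placeholders from the docstring example:

from distributed import Client
from streamz import Stream

client = Client()                  # required so default_client() resolves
source = Stream.from_kafka_batched(
    'mytopic',
    {'bootstrap.servers': 'localhost:9092', 'group.id': 'rapidz'},
    npartitions=4,
    dask=True,                     # batches are scattered to the cluster
)
source.start()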
Example #12
    def update(self, x, who=None, metadata=None):
        client = default_client()

        self._retain_refs(metadata)
        future = yield client.scatter(x, asynchronous=True)
        f = yield self._emit(future, metadata=metadata)
        self._release_refs(metadata)

        raise gen.Return(f)
Example #13
    def update(self, x, who=None, metadata=None):
        client = default_client()

        self._retain_refs(metadata)
        result = yield client.gather(x, asynchronous=True)
        result2 = yield self._emit(result, metadata=metadata)
        self._release_refs(metadata)

        raise gen.Return(result2)
Example #14
    def update(self, x, emit_id=None, who=None):
        client = default_client()
        result = client.submit(
            self.func,
            x,
            *self.args,
            **self.kwargs,
            pure=False,
            key=f'{self.func.__name__}--{emit_id}--{str(uuid.uuid4())}')
        return self._emit(result, emit_id=emit_id)
Example #15
    def update(self, x, emit_id=None, who=None):
        try:
            client = default_client()
            future_as_list = yield client.scatter([x],
                                                  asynchronous=True,
                                                  hash=False)
            future = future_as_list[0]
            f = yield self._emit(future, emit_id=emit_id)
            raise gen.Return(f)
        except Exception as e:
            self.error = True
            raise
Example #16
    def update(self, x, sleep_between_gather=0.1, emit_id=None, who=None):
        try:
            client = default_client()
            result = yield client.gather(x, asynchronous=True)
            result2 = yield self._emit(result, emit_id=emit_id)
            raise gen.Return(result2)

            # result1 = yield get_stream(self).wait(sleep_between_gather=sleep_between_gather)
            # raise gen.Return(result1)

        except Exception as e:
            self.error = True
            raise
Example #17
    def update(self, x, who=None):
        if self.state is core.no_default:
            self.state = x
            return self._emit(self.state)
        else:
            client = default_client()
            result = client.submit(self.func, self.state, x, **self.kwargs)
            if self.returns_state:
                state = client.submit(getitem, result, 0)
                result = client.submit(getitem, result, 1)
            else:
                state = result
            self.state = state
            return self._emit(result)
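A small sketch of the contract ``self.func`` must satisfy when ``returns_state`` is True: the submitted function returns a (state, result) tuple, which update() then splits with getitem. The running-mean function is invented for illustration:

from operator import getitem
from distributed import Client

client = Client()

def running_mean(state, x):
    total, count = state
    total, count = total + x, count + 1
    return (total, count), total / count   # (new_state, emitted_value)

future = client.submit(running_mean, (0, 0), 10)
state = client.submit(getitem, future, 0)   # -> (10, 1)
result = client.submit(getitem, future, 1)  # -> 10.0
print(result.result())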
Example #18
def from_kafka_batched_scatter(
    topics: List[str],
    consumer_params: Dict,
    poll_interval: int = 5,
    batch_size: int = 1000,
    dask: bool = False,
    db=None,
    **kwargs,
):
    """
    Parameters
    ----------
    topics: List[str]
        Labels of Kafka topics to consume from.
    consumer_params: Dict
        Settings to set up the stream, see
        https://docs.confluent.io/current/clients/confluent-kafka-python/#configuration
        https://github.com/edenhill/librdkafka/blob/master/CONFIGURATION.md
        Examples:
        bootstrap.servers, Connection string(s) (host:port) by which to reach
        Kafka;
        group.id, Identity of the consumer. If multiple sources share the same
        group, each message will be passed to only one of them.
    poll_interval: float, optional (default=5.0)
        Seconds that elapse between polling Kafka for new messages.
    batch_size: int, optional (default=1000)
        Batch size of each poll.
    dask: bool, optional (default=False)
        If True, events are polled from each partition and distributed among
        Dask workers.
    db: Callable, optional (default=None)
        Persistent layer used to check Kafka offsets and provide once-only
        semantics. If None, will initiate waterhealer.db.expiringdict.Database.
    """
    if dask:
        from distributed.client import default_client

        kwargs['loop'] = default_client().loop
    source = FromKafkaBatched(
        topics=topics,
        consumer_params=consumer_params,
        poll_interval=poll_interval,
        batch_size=batch_size,
        db=db,
        **kwargs,
    )

    if dask:
        source = source.scatter()

    return source.starmap(get_message_batch)
Example #19
    def deploy_func_between_two_axis_partitions(cls, axis, func, num_splits,
                                                len_of_left, other_shape,
                                                kwargs, *partitions):
        """
        Deploy a function along a full axis between two data sets.

        Parameters
        ----------
        axis : {0, 1}
            The axis to perform the function along.
        func : callable
            The function to perform.
        num_splits : int
            The number of splits to return (see `split_result_of_axis_func_pandas`).
        len_of_left : int
            The number of values in `partitions` that belong to the left data set.
        other_shape : np.ndarray
            The shape of the right frame in terms of partitions, i.e.
            (other_shape[i-1], other_shape[i]) indicates the slice to restore the i-1 axis partition.
        kwargs : dict
            Additional keyword arguments to be passed to `func`.
        *partitions : iterable
            All partitions that make up the full axis (row or column) for both data sets.

        Returns
        -------
        list
            A list of distributed.Future.
        """
        client = default_client()
        axis_result = client.submit(
            deploy_dask_func,
            PandasDataframeAxisPartition.
            deploy_func_between_two_axis_partitions,
            axis,
            func,
            num_splits,
            len_of_left,
            other_shape,
            kwargs,
            *partitions,
            pure=False,
        )
        # We have to do this to split it back up. It is already split, but we need to
        # get futures for each.
        return [
            client.submit(lambda l: l[i], axis_result, pure=False)
            for i in range(num_splits * 4)
        ]
Example #20
    def preprocess_func(cls, func):
        """
        Preprocess a function before an ``apply`` call.

        Parameters
        ----------
        func : callable
            The function to preprocess.

        Returns
        -------
        callable
            An object that can be accepted by ``apply``.
        """
        return default_client().scatter(func, hash=False, broadcast=True)
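A hedged sketch of what the broadcast scatter above gives you: the callable is shipped to every worker once, and the returned future can be passed to later submit calls alongside data futures. ``add_one`` and the final lambda are illustrative:

from distributed import Client

client = Client()

def add_one(x):
    return x + 1

# Ship the function to all workers once, without hashing it into a key.
func_future = client.scatter(add_one, hash=False, broadcast=True)

# Later calls can pass the function future instead of the function itself.
out = client.submit(lambda f, x: f(x), func_future, 41)
print(out.result())  # 42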
Example #21
File: dask.py Project: wwoods/streamz
    def update(self, x, who=None, metadata=None):
        client = default_client()

        self._retain_refs(metadata)
        # We need to make sure that x is treated as-is by dask.
        # However, client.scatter works internally differently for
        # lists and dicts. So we always use a list here to be sure
        # we know the format exactly. We do not use a key to avoid
        # issues like https://github.com/python-streamz/streams/issues/397.
        future_as_list = yield client.scatter([x], asynchronous=True, hash=False)
        future = future_as_list[0]
        f = yield self._emit(future, metadata=metadata)
        self._release_refs(metadata)

        raise gen.Return(f)
Example #22
    def get_client(return_exception=False):
        error = 'no error'
        try:
            from distributed.client import default_client

            client = default_client()
        except Exception as e:
            error = str(e)
            logger.error(e)
            client = None

        if return_exception:
            return client, error
        else:
            return client
Example #23
    def deploy_axis_func(cls, axis, func, num_splits, kwargs,
                         maintain_partitioning, *partitions):
        """
        Deploy a function along a full axis.

        Parameters
        ----------
        axis : {0, 1}
            The axis to perform the function along.
        func : callable
            The function to perform.
        num_splits : int
            The number of splits to return (see `split_result_of_axis_func_pandas`).
        kwargs : dict
            Additional keyword arguments to be passed to `func`.
        maintain_partitioning : bool
            If True, keep the old partitioning if possible.
            If False, create a new partition layout.
        *partitions : iterable
            All partitions that make up the full axis (row or column).

        Returns
        -------
        list
            A list of distributed.Future.
        """
        client = default_client()
        axis_result = client.submit(
            deploy_dask_func,
            PandasDataframeAxisPartition.deploy_axis_func,
            axis,
            func,
            num_splits,
            kwargs,
            maintain_partitioning,
            *partitions,
            pure=False,
        )

        lengths = kwargs.get("_lengths", None)
        result_num_splits = len(lengths) if lengths else num_splits

        # We have to do this to split it back up. It is already split, but we need to
        # get futures for each.
        return [
            client.submit(lambda l: l[i], axis_result, pure=False)
            for i in range(result_num_splits * 4)
        ]
Example #24
    def _column_widths(self):
        """
        Compute the column partitions widths if they are not cached.

        Returns
        -------
        list
            A list of column partitions widths.
        """
        client = default_client()
        if self._column_widths_cache is None:
            self._column_widths_cache = client.gather([
                obj.apply(lambda df: len(df.columns)).future
                for obj in self._partitions[0]
            ])
        return self._column_widths_cache
Example #25
    def materialize(cls, future):
        """
        Materialize data matching `future` object.

        Parameters
        ----------
        future : distributed.Future or list
            Future object or list of future objects whose data needs to be materialized.

        Returns
        -------
        Any
            An object(s) from the distributed memory.
        """
        client = default_client()
        return client.gather(future)
Example #26
def _get_files_from_location(input_item, file_format: str = None, **kwargs):
    if file_format == "memory":
        client = default_client()
        return client.get_dataset(input_item, **kwargs)

    if not file_format:
        _, extension = os.path.splitext(input_item)

        file_format = extension.lstrip(".")

    try:
        read_function = getattr(dd, f"read_{file_format}")
    except AttributeError:
        raise AttributeError(f"Can not read files of format {file_format}")

    return read_function(input_item, **kwargs)
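A sketch of how the "memory" format above is typically fed: a collection published on the cluster under a name can later be retrieved by that name. The table name and frame are placeholders:

import pandas as pd
import dask.dataframe as dd
from distributed import Client

client = Client()
ddf = dd.from_pandas(pd.DataFrame({"a": [1, 2, 3]}), npartitions=1)
client.publish_dataset(my_table=ddf)          # register under "my_table"

# The "memory" branch resolves the name through client.get_dataset():
same_ddf = _get_files_from_location("my_table", file_format="memory")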
Example #27
    def _row_lengths(self):
        """
        Compute the row partitions lengths if they are not cached.

        Returns
        -------
        list
            A list of row partitions lengths.
        """
        client = default_client()
        if self._row_lengths_cache is None:
            self._row_lengths_cache = client.gather([
                obj.apply(lambda df: len(df)).future
                for obj in self._partitions.T[0]
            ])
        return self._row_lengths_cache
Example #28
File: dask.py Project: salah93/streamz
    def update(self, x, who=None, metadata=None):
        client = default_client()

        self._retain_refs(metadata)
        # We need to make sure that x is treated as-is by dask.
        # However, client.scatter works internally differently for
        # lists and dicts. So we always use a dict here to be sure
        # we know the format exactly. The key will be taken as the
        # dask identifier of the data.
        tokenized_x = f"{type(x).__name__}-{tokenize(x)}"
        future_as_dict = yield client.scatter({tokenized_x: x}, asynchronous=True)
        future = future_as_dict[tokenized_x]
        f = yield self._emit(future, metadata=metadata)
        self._release_refs(metadata)

        raise gen.Return(f)
Example #29
    def put(cls, obj):
        """
        Put an object into distributed memory and wrap it with partition object.

        Parameters
        ----------
        obj : any
            An object to be put.

        Returns
        -------
        PandasOnDaskDataframePartition
            A new ``PandasOnDaskDataframePartition`` object.
        """
        client = default_client()
        return cls(client.scatter(obj, hash=False))
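A hedged sketch of the underlying call made by put(): the object is scattered without hashing and the resulting Future is what the partition wraps. The frame is illustrative:

import pandas as pd
from distributed import Client

client = Client()
df = pd.DataFrame({"a": [1, 2, 3]})

# What put() does under the hood: scatter without hashing, keep the Future.
future = client.scatter(df, hash=False)
print(future.result().equals(df))  # True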
Example #30
    def to_dc(self, input_item: Any, table_name: str, format: str = None, **kwargs):

        if format == "memory":
            client = default_client()
            return client.get_dataset(input_item, **kwargs)

        if not format:
            _, extension = os.path.splitext(input_item)

            format = extension.lstrip(".")

        try:
            read_function = getattr(dd, f"read_{format}")
        except AttributeError:
            raise AttributeError(f"Can not read files of format {format}")

        return read_function(input_item, **kwargs)
Example #31
File: core.py Project: gameduell/dask
def read_bytes(urlpath, delimiter=None, not_zero=False, blocksize=2**27,
               sample=True, compression=None, **kwargs):
    """ Convert path to a list of delayed values

    The path may be a filename like ``'2015-01-01.csv'`` or a globstring
    like ``'2015-*-*.csv'``.

    The path may be preceded by a protocol, like ``s3://`` or ``hdfs://`` if
    those libraries are installed.

    This cleanly breaks data by a delimiter if given, so that block boundaries
    start directly after a delimiter and end on the delimiter.

    Parameters
    ----------
    urlpath: string
        Absolute or relative filepath, URL (may include protocols like
        ``s3://``), or globstring pointing to data.
    delimiter: bytes
        An optional delimiter, like ``b'\\n'`` on which to split blocks of
        bytes.
    not_zero: bool
        Force seek of start-of-file delimiter, discarding header.
    blocksize: int (=128MB)
        Chunk size
    compression: string or None
        String like 'gzip' or 'xz'.  Must support efficient random access.
    sample: bool or int
        Whether or not to return a header sample. If an integer is given it is
        used as sample size, otherwise the default sample size is 10kB.
    **kwargs: dict
        Extra options that make sense to a particular storage connection, e.g.
        host, port, username, password, etc.

    Examples
    --------
    >>> sample, blocks = read_bytes('2015-*-*.csv', delimiter=b'\\n')  # doctest: +SKIP
    >>> sample, blocks = read_bytes('s3://bucket/2015-*-*.csv', delimiter=b'\\n')  # doctest: +SKIP

    Returns
    -------
    A sample header and list of ``dask.Delayed`` objects or list of lists of
    delayed objects if ``urlpath`` is a globstring.
    """
    fs, paths, myopen = get_fs_paths_myopen(urlpath, compression, 'rb',
                                            None, **kwargs)
    client = None
    if len(paths) == 0:
        raise IOError("%s resolved to no files" % urlpath)

    blocks, lengths, machines = fs.get_block_locations(paths)
    if blocks:
        offsets = blocks
    elif blocksize is None:
        offsets = [[0]] * len(paths)
        lengths = [[None]] * len(offsets)
        machines = [[None]] * len(offsets)
    else:
        offsets = []
        lengths = []
        for path in paths:
            try:
                size = fs.logical_size(path, compression)
            except KeyError:
                raise ValueError('Cannot read compressed files (%s) in byte chunks, '
                                 'use blocksize=None' % infer_compression(urlpath))
            off = list(range(0, size, blocksize))
            length = [blocksize] * len(off)
            if not_zero:
                off[0] = 1
                length[0] -= 1
            offsets.append(off)
            lengths.append(length)
        machines = [[None]] * len(offsets)

    out = []
    for path, offset, length, machine in zip(paths, offsets, lengths, machines):
        ukey = fs.ukey(path)
        keys = ['read-block-%s-%s' %
                (o, tokenize(path, compression, offset, ukey, kwargs, delimiter))
                for o in offset]
        L = [delayed(read_block_from_file)(myopen(path, mode='rb'), o,
                                           l, delimiter, dask_key_name=key)
             for (o, key, l) in zip(offset, keys, length)]
        out.append(L)
        if machine is not None:  # blocks are in preferred locations
            if client is None:
                try:
                    from distributed.client import default_client
                    client = default_client()
                except (ImportError, ValueError):  # no distributed client
                    client = False
            if client:
                restrictions = {key: w for key, w in zip(keys, machine)}
                client._send_to_scheduler({'op': 'update-graph', 'tasks': {},
                                           'dependencies': [], 'keys': [],
                                           'restrictions': restrictions,
                                           'loose_restrictions': list(restrictions),
                                           'client': client.id})

    if sample is not True:
        nbytes = sample
    else:
        nbytes = 10000
    if sample:
        # myopen = OpenFileCreator(urlpath, compression)
        with myopen(paths[0], 'rb') as f:
            sample = read_block(f, 0, nbytes, delimiter)
    return sample, out
Example #32
File: base.py Project: mmngreco/dask
def persist(*args, **kwargs):
    """ Persist multiple Dask collections into memory

    This turns lazy Dask collections into Dask collections with the same
    metadata, but now with their results fully computed or actively computing
    in the background.

    For example a lazy dask.array built up from many lazy calls will now be a
    dask.array of the same shape, dtype, chunks, etc., but now with all of
    those previously lazy tasks either computed in memory as many small :class:`numpy.array`
    (in the single-machine case) or asynchronously running in the
    background on a cluster (in the distributed case).

    This function operates differently if a ``dask.distributed.Client`` exists
    and is connected to a distributed scheduler.  In this case this function
    will return as soon as the task graph has been submitted to the cluster,
    but before the computations have completed.  Computations will continue
    asynchronously in the background.  When using this function with the single
    machine scheduler it blocks until the computations have finished.

    When using Dask on a single machine you should ensure that the dataset fits
    entirely within memory.

    Examples
    --------
    >>> df = dd.read_csv('/path/to/*.csv')  # doctest: +SKIP
    >>> df = df[df.name == 'Alice']  # doctest: +SKIP
    >>> df['in-debt'] = df.balance < 0  # doctest: +SKIP
    >>> df = df.persist()  # triggers computation  # doctest: +SKIP

    >>> df.value().min()  # future computations are now fast  # doctest: +SKIP
    -10
    >>> df.value().max()  # doctest: +SKIP
    100

    >>> from dask import persist  # use persist function on multiple collections
    >>> a, b = persist(a, b)  # doctest: +SKIP

    Parameters
    ----------
    *args: Dask collections
    get : callable, optional
        A scheduler ``get`` function to use. If not provided, the default
        is to check the global settings first, and then fall back to
        the collection defaults.
    optimize_graph : bool, optional
        If True [default], the graph is optimized before computation.
        Otherwise the graph is run as is. This can be useful for debugging.
    **kwargs
        Extra keywords to forward to the scheduler ``get`` function.

    Returns
    -------
    New dask collections backed by in-memory data
    """
    collections = [a for a in args if is_dask_collection(a)]
    if not collections:
        return args

    get = kwargs.pop('get', None) or _globals['get']

    if get is None and getattr(thread_state, 'key', False):
        from distributed.worker import get_worker
        get = get_worker().client.get

    if inspect.ismethod(get):
        try:
            from distributed.client import default_client
        except ImportError:
            pass
        else:
            try:
                client = default_client()
            except ValueError:
                pass
            else:
                if client.get == _globals['get']:
                    collections = client.persist(collections, **kwargs)
                    if isinstance(collections, list):  # distributed is inconsistent here
                        collections = tuple(collections)
                    else:
                        collections = (collections,)
                    results_iter = iter(collections)
                    return tuple(a if not is_dask_collection(a)
                                 else next(results_iter)
                                 for a in args)

    optimize_graph = kwargs.pop('optimize_graph', True)

    if not get:
        get = collections[0].__dask_scheduler__
        if not all(a.__dask_scheduler__ == get for a in collections):
            raise ValueError("Compute called on multiple collections with "
                             "differing default schedulers. Please specify a "
                             "scheduler `get` function using either "
                             "the `get` kwarg or globally with `set_options`.")

    dsk = collections_to_dsk(collections, optimize_graph, **kwargs)

    keys, postpersists = [], []
    for a in args:
        if is_dask_collection(a):
            a_keys = list(flatten(a.__dask_keys__()))
            rebuild, state = a.__dask_postpersist__()
            keys.extend(a_keys)
            postpersists.append((rebuild, a_keys, state))
        else:
            postpersists.append((None, None, a))

    results = get(dsk, keys, **kwargs)
    d = dict(zip(keys, results))
    return tuple(s if r is None else r({k: d[k] for k in ks}, *s)
                 for r, ks, s in postpersists)
Example #33
File: base.py Project: floriango/dask
def persist(*args, **kwargs):
    """ Persist multiple Dask collections into memory

    This turns lazy Dask collections into Dask collections with the same
    metadata, but now with their results fully computed or actively computing
    in the background.

    For example a lazy dask.array built up from many lazy calls will now be a
    dask.array of the same shape, dtype, chunks, etc., but now with all of
    those previously lazy tasks either computed in memory as many small :class:`numpy.array`
    (in the single-machine case) or asynchronously running in the
    background on a cluster (in the distributed case).

    This function operates differently if a ``dask.distributed.Client`` exists
    and is connected to a distributed scheduler.  In this case this function
    will return as soon as the task graph has been submitted to the cluster,
    but before the computations have completed.  Computations will continue
    asynchronously in the background.  When using this function with the single
    machine scheduler it blocks until the computations have finished.

    When using Dask on a single machine you should ensure that the dataset fits
    entirely within memory.

    Examples
    --------
    >>> df = dd.read_csv('/path/to/*.csv')  # doctest: +SKIP
    >>> df = df[df.name == 'Alice']  # doctest: +SKIP
    >>> df['in-debt'] = df.balance < 0  # doctest: +SKIP
    >>> df = df.persist()  # triggers computation  # doctest: +SKIP

    >>> df.value().min()  # future computations are now fast  # doctest: +SKIP
    -10
    >>> df.value().max()  # doctest: +SKIP
    100

    >>> from dask import persist  # use persist function on multiple collections
    >>> a, b = persist(a, b)  # doctest: +SKIP

    Parameters
    ----------
    *args: Dask collections
    scheduler : string, optional
        Which scheduler to use like "threads", "synchronous" or "processes".
        If not provided, the default is to check the global settings first,
        and then fall back to the collection defaults.
    traverse : bool, optional
        By default dask traverses builtin python collections looking for dask
        objects passed to ``persist``. For large collections this can be
        expensive. If none of the arguments contain any dask objects, set
        ``traverse=False`` to avoid doing this traversal.
    optimize_graph : bool, optional
        If True [default], the graph is optimized before computation.
        Otherwise the graph is run as is. This can be useful for debugging.
    **kwargs
        Extra keywords to forward to the scheduler function.

    Returns
    -------
    New dask collections backed by in-memory data
    """
    traverse = kwargs.pop('traverse', True)
    optimize_graph = kwargs.pop('optimize_graph', True)

    collections, repack = unpack_collections(*args, traverse=traverse)
    if not collections:
        return args

    schedule = get_scheduler(get=kwargs.pop('get', None),
                             scheduler=kwargs.pop('scheduler', None),
                             collections=collections)

    if inspect.ismethod(schedule):
        try:
            from distributed.client import default_client
        except ImportError:
            pass
        else:
            try:
                client = default_client()
            except ValueError:
                pass
            else:
                if client.get == schedule:
                    results = client.persist(collections,
                                             optimize_graph=optimize_graph,
                                             **kwargs)
                    return repack(results)

    dsk = collections_to_dsk(collections, optimize_graph, **kwargs)
    keys, postpersists = [], []
    for a in collections:
        a_keys = list(flatten(a.__dask_keys__()))
        rebuild, state = a.__dask_postpersist__()
        keys.extend(a_keys)
        postpersists.append((rebuild, a_keys, state))

    results = schedule(dsk, keys, **kwargs)
    d = dict(zip(keys, results))
    results2 = [r({k: d[k] for k in ks}, *s) for r, ks, s in postpersists]
    return repack(results2)
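A short sketch of the distributed branch above: with a connected Client, persist() returns almost immediately while the work continues on the cluster. The array is illustrative:

import dask.array as da
from dask import persist
from distributed import Client

client = Client()                        # default_client() now resolves here
x = da.random.random((2000, 2000), chunks=(500, 500))

(y,) = persist(x + 1)                    # returns promptly; futures compute in the background
print(type(y))                           # still a dask.array, now backed by futures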