Пример #1
0
def _get_datatype_from_inputs(data):

    """
    Gets the datatype from a distributed data input.

    Parameters
    ----------

    data : dask.DataFrame, dask.Series, dask.Array, or
           Sequence containing any of those.

    Returns
    -------

    datatype : str {'cupy', 'cudf'}
        'cudf' when the (first) input is a dask / dask_cudf frame or
        series, 'cupy' otherwise.
    multiple : bool
        True when ``data`` is a Sequence of distributed objects.
    """

    multiple = isinstance(data, Sequence)

    # Only the first element is inspected: a Sequence input is assumed
    # to be homogeneous in type.
    sample = first(data) if multiple else data

    if isinstance(sample,
                  (daskSeries, daskDataFrame, dcDataFrame, dcSeries)):
        datatype = 'cudf'
    else:
        datatype = 'cupy'
        # Array inputs must satisfy the chunking constraints enforced
        # by validate_dask_array.
        if multiple:
            for d in data:
                validate_dask_array(d)
        else:
            validate_dask_array(data)

    return datatype, multiple
Пример #2
0
def test_validate_dask_array(nrows, ncols, n_parts, col_chunking, n_col_chunks,
                             client):
    """
    Exercise validate_dask_array: row-only chunked arrays must pass,
    while column-chunked 2-D arrays must raise.
    """
    if ncols > 1:
        X = cp.random.standard_normal((nrows, ncols))
        # Use floor division: chunk sizes must be integers, and true
        # division would hand dask a float.
        X = dask.array.from_array(X, chunks=(nrows // n_parts, -1))
        if col_chunking:
            X = X.rechunk((nrows // n_parts, ncols // n_col_chunks))
    else:
        X = cp.random.standard_normal(nrows)
        X = dask.array.from_array(X, chunks=(nrows // n_parts))

    if col_chunking and ncols > 1:
        # Column-wise chunking is unsupported and must be rejected.
        with pytest.raises(Exception):
            validate_dask_array(X, client)
    else:
        # Must complete without raising for row-only chunking.
        validate_dask_array(X, client)
Пример #3
0
    def create(cls, data, client=None):
        """
        Build a distributed data handler for the given distributed
        data set(s).

        Parameters
        ----------

        data : dask.array, dask.dataframe, or unbounded Sequence of
               dask.array or dask.dataframe.

        client : dask.distributed.Client, optional
            Resolved through ``cls.get_client`` when not supplied.
        """

        client = cls.get_client(client)

        multiple = isinstance(data, Sequence)

        # A Sequence is assumed homogeneous: inspect its first element.
        sample = first(data) if multiple else data

        if isinstance(sample, (dcDataFrame, daskSeries)):
            datatype = 'cudf'
        else:
            datatype = 'cupy'
            # Every array input must satisfy the chunking constraints.
            for arr in (data if multiple else [data]):
                validate_dask_array(arr)

        gpu_futures = client.sync(_extract_partitions, data, client)

        # Deduplicate the worker addresses owning the partitions.
        workers = tuple({worker for worker, _ in gpu_futures})

        return DistributedDataHandler(gpu_futures=gpu_futures,
                                      workers=workers,
                                      datatype=datatype,
                                      multiple=multiple,
                                      client=client)