def test_gh_4176():
    """
    Regression test (per its name, for dask GH issue 4176).

    Builds a low-level blockwise layer that adds a new leading length-1
    axis to every block of a ones array, stores it together with the input
    graph in a ``ShareDict``, wraps it in a ``da.Array`` and checks that
    summing over the new axis computes without error.
    """
    from dask.sharedict import ShareDict

    def prepend_axis(block):
        # Add a new leading axis of length 1 to each block.
        return block[None, ...]

    ones = da.ones(shape=(10, 20, 4), chunks=(2, 5, 4))
    out_name = 'D'

    layer = blockwise(
        prepend_axis, out_name, ("nsrc", "ntime", "nbl", "npol"),
        ones.name, ("ntime", "nbl", "npol"),
        new_axes={"nsrc": 1},
        numblocks={ones.name: ones.numblocks},
    )

    merged_graph = ShareDict()
    merged_graph.update(layer)
    merged_graph.update(ones.__dask_graph__())

    out_chunks = ((1,),) + ones.chunks
    result = da.Array(merged_graph, out_name, out_chunks, dtype=ones.dtype)
    result.sum(axis=0).compute()
def test_gh_4176():
    """
    Regression test (per its name, for dask GH issue 4176).

    Same scenario as the plain variant: a blockwise layer adding a leading
    length-1 axis is stored in a ``ShareDict`` with the input graph and the
    resulting array is reduced over that axis. The ``ShareDict`` import is
    performed with all warnings ignored.
    """
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        from dask.sharedict import ShareDict

    def prepend_axis(block):
        # Add a new leading axis of length 1 to each block.
        return block[None, ...]

    ones = da.ones(shape=(10, 20, 4), chunks=(2, 5, 4))
    out_name = 'D'

    layer = blockwise(
        prepend_axis, out_name, ("nsrc", "ntime", "nbl", "npol"),
        ones.name, ("ntime", "nbl", "npol"),
        new_axes={"nsrc": 1},
        numblocks={ones.name: ones.numblocks},
    )

    merged_graph = ShareDict()
    merged_graph.update(layer)
    merged_graph.update(ones.__dask_graph__())

    out_chunks = ((1,),) + ones.chunks
    result = da.Array(merged_graph, out_name, out_chunks, dtype=ones.dtype)
    result.sum(axis=0).compute()
def _bda_getitem_row_chan(avg, idx, dtype, format, avg_meta, nchan):
    """
    Extract (row, corr) arrays from dask array of tuples

    Parameters
    ----------
    avg : dask array
        Array whose blocks are indexed with ``getitem(block, idx)``;
        per the original description, its blocks hold tuples.
    idx : int
        Index of the field to extract; also used to look up the field
        name in ``BDARowChanAverageOutput._fields``.
    dtype : object
        Not used by this function (kept for interface compatibility).
    format : str
        ``"flat"`` produces a (row, corr) array. ``"ragged"`` produces a
        (row, chan, corr) array via ``_ragged_row_chan_getitem``.
    avg_meta : dask array
        Only used for ``"ragged"``. Its "row" blocks are passed to
        ``_ragged_row_chan_getitem``.
    nchan : int
        Size of the new "chan" axis for ``"ragged"``.

    Returns
    -------
    dask array
        Array with object-dtype meta.

    Raises
    ------
    ValueError
        If ``format`` is neither ``"flat"`` nor ``"ragged"``.
    """
    f = BDARowChanAverageOutput._fields[idx]
    name = "row-chan-average-getitem-%s-%s-" % (f, format)
    name += tokenize(avg, idx)

    if format == "flat":
        dims = ("row", "corr")

        layers = db.blockwise(getitem, name, dims,
                              avg.name, ("row", "corr"),
                              idx, None,
                              numblocks={avg.name: avg.numblocks})

        chunks = avg.chunks
        # ``np.object`` was a deprecated alias of the builtin ``object`` and
        # was removed in NumPy 1.24.
        meta = np.empty((0, 0), dtype=object)
        deps = (avg,)
    elif format == "ragged":
        dims = ("row", "chan", "corr")
        new_axes = {"chan": nchan}

        layers = db.blockwise(_ragged_row_chan_getitem, name, dims,
                              avg.name, ("row", "corr"),
                              idx, None,
                              avg_meta.name, ("row",),
                              new_axes=new_axes,
                              numblocks={
                                  avg.name: avg.numblocks,
                                  avg_meta.name: avg_meta.numblocks,
                              })

        chunks = (avg.chunks[0], (nchan,), avg.chunks[1])
        meta = np.empty((0, 0, 0), dtype=object)
        # The layer reads avg_meta's blocks too, so declare it as a
        # dependency. Merging an already-present graph is harmless.
        deps = (avg, avg_meta)
    else:
        raise ValueError("Invalid format %s" % format)

    graph = HighLevelGraph.from_collections(name, layers, deps)
    return da.Array(graph, name, chunks, meta=meta)
def _build_map_layer(
    func: Callable,
    prev_name: str,
    new_name: str,
    collection,
    dependencies: tuple[Delayed, ...] = (),
) -> Layer:
    """Build a layer named ``new_name`` that maps ``func`` over every key of
    layer ``prev_name`` in ``collection``.

    A Blockwise layer is produced when ``_can_apply_blockwise(collection)``
    is true; otherwise a MaterializedLayer is built key by key.

    Parameters
    ----------
    func
        Callable invoked once per graph node.
    prev_name : str
        Name of the source layer. For dask base collections this is the
        collection name. Third-party collections (e.g. xarray.Dataset)
        may expose several names.
    new_name : str
        Name of the layer being created.
    collection
        Any dask collection.
    dependencies
        Zero or more Delayed objects. Their keys are passed to ``func`` as
        extra positional arguments after the collection's chunk.
    """
    if not _can_apply_blockwise(collection):
        # Delayed, bag.Item, dataframe.core.Scalar, or third-party collection;
        # fall back to MaterializedLayer
        extra_args = tuple(dep.key for dep in dependencies)
        tasks = {}
        for key in flatten(collection.__dask_keys__()):
            if get_name_from_key(key) != prev_name:
                continue
            renamed = replace_name_in_key(key, {prev_name: new_name})
            tasks[renamed] = (func, key) + extra_args
        return MaterializedLayer(tasks)

    # Blockwise path: arrays expose ``numblocks``; partitioned collections
    # only have ``npartitions`` and are treated as one-dimensional.
    try:
        numblocks = collection.numblocks
    except AttributeError:
        numblocks = (collection.npartitions,)

    indices = tuple(range(len(numblocks)))

    extra_kwargs = {}
    if dependencies:
        extra_kwargs["_deps"] = [dep.key for dep in dependencies]

    return blockwise(
        func,
        new_name,
        indices,
        prev_name,
        indices,
        numblocks={prev_name: numblocks},
        dependencies=dependencies,
        **extra_kwargs,
    )
def _getitem_row(avg, idx, dtype):
    """
    Pull element ``idx`` out of every block of ``avg`` along the "row" axis.

    ``getitem(block, idx)`` is applied blockwise. The original description
    states the blocks of ``avg`` hold tuples. The result keeps ``avg``'s
    chunks and is given ``dtype``.
    """
    prefix = "row-average-getitem-%d-" % idx
    name = prefix + tokenize(avg, idx)

    row_index = ("row",)
    layer = db.blockwise(getitem, name, row_index,
                         avg.name, row_index,
                         idx, None,
                         numblocks={avg.name: avg.numblocks})

    hlg = HighLevelGraph.from_collections(name, layer, (avg,))
    return da.Array(hlg, name, avg.chunks, dtype=dtype)
def _getitem_chan(avg, idx, dtype):
    """
    Extract channel-like arrays from a dask array of tuples.

    ``getitem(block, idx)`` is applied to every block of ``avg`` along the
    "chan" axis. The result keeps ``avg``'s chunks and dtype ``dtype``.
    When ``PY3`` is true, an empty 1-D ``meta`` of that dtype is also
    supplied.
    """
    name = "chan-average-getitem-%d-" % idx + tokenize(avg, idx)

    chan_index = ("chan",)
    layer = db.blockwise(getitem, name, chan_index,
                         avg.name, chan_index,
                         idx, None,
                         numblocks={avg.name: avg.numblocks})
    hlg = HighLevelGraph.from_collections(name, layer, (avg,))

    extra = {}
    if PY3:
        extra['meta'] = np.empty((0,), dtype=dtype)
    return da.Array(hlg, name, avg.chunks, dtype=dtype, **extra)
def _make_pipeline(pipeline: Pipeline) -> Delayed:
    """Turn ``pipeline`` into a ``Delayed`` backed by a hand-built HighLevelGraph.

    The first layer holds ``pipeline.config``. Each stage then adds either
    one standalone task or, when the stage has a ``mappable``, a blockwise
    map layer followed by a checkpoint task over all of its outputs. Every
    stage depends on the config key and on the previous stage's key.

    Returns a ``Delayed`` pointing at the last stage's key, or at the
    config key when there are no stages.
    """
    token = dask.base.tokenize(pipeline)

    # we are constructing a HighLevelGraph from scratch
    # https://docs.dask.org/en/latest/high-level-graphs.html
    layers = {}  # type: Dict[str, Dict[Union[str, Tuple[str, int]], Any]]
    dependencies = {}  # type: Dict[str, Set[str]]

    # The config is a standalone layer containing only the config object.
    config_key = append_token("config", token)
    layers[config_key] = {config_key: pipeline.config}
    dependencies[config_key] = set()

    prev_key: str = config_key
    for stage in pipeline.stages:
        if stage.mappable is not None:
            func = wrap_map_task(stage.function)
            map_key = append_token(stage.name, token)
            # One block per element of ``stage.mappable``. The dimension name
            # "x" is arbitrary. ``numblocks={}`` is required because the
            # default of None breaks Blockwise.
            per_item = BlockwiseDepDict(
                {(i,): item for i, item in enumerate(stage.mappable)}
            )
            map_layer = blockwise(
                func,
                map_key,
                "x",
                per_item,
                "x",
                config_key,
                None,
                prev_key,
                None,
                numblocks={},
            )
            layers[map_key] = map_layer
            dependencies[map_key] = {config_key, prev_key}

            # Collapse every map output into a single checkpoint task.
            stage_key = f"{stage.name}-checkpoint-{token}"
            layers[stage_key] = {
                stage_key: (checkpoint, *map_layer.get_output_keys())
            }
            dependencies[stage_key] = {map_key}
        else:
            stage_key = append_token(stage.name, token)
            func = wrap_standalone_task(stage.function)
            layers[stage_key] = {stage_key: (func, config_key, prev_key)}
            dependencies[stage_key] = {config_key, prev_key}

        prev_key = stage_key

    return Delayed(prev_key, HighLevelGraph(layers, dependencies))
def _bda_getitem_row(avg, idx, array, dims, meta, format="flat"):
    """
    Extract row-like arrays from a dask array of tuples

    Parameters
    ----------
    avg : dask array
        Array whose "row" blocks are indexed with element ``idx``.
    idx : int
        Element to extract from each block.
    array : array-like
        Supplies the dtype and the sizes of any non-row dimensions
        (``array.shape[1:]``) of the output.
    dims : tuple of str
        Output dimension names. The first must be ``"row"``.
    meta : dask array
        Only used for ``format="ragged"``. Its "row" blocks are passed to
        ``_ragged_row_getitem``.
    format : {"flat", "ragged"}
        Selects ``getitem`` or ``_ragged_row_getitem`` as the block function.

    Returns
    -------
    dask array
        Chunked like ``avg`` along "row", with a single chunk in every
        other dimension.

    Raises
    ------
    ValueError
        If ``format`` is not recognised.
    """
    assert dims[0] == "row"

    name = "row-average-getitem-%s-" % idx
    name += tokenize(avg, idx)
    new_axes = dict(zip(dims[1:], array.shape[1:]))
    numblocks = {avg.name: avg.numblocks}

    if format == "flat":
        layers = db.blockwise(getitem, name, dims,
                              avg.name, ("row",),
                              idx, None,
                              new_axes=new_axes,
                              numblocks=numblocks)
        deps = (avg,)
    elif format == "ragged":
        numblocks[meta.name] = meta.numblocks
        layers = db.blockwise(_ragged_row_getitem, name, dims,
                              avg.name, ("row",),
                              idx, None,
                              meta.name, ("row",),
                              new_axes=new_axes,
                              numblocks=numblocks)
        # The layer also reads ``meta``'s blocks, so declare it as a
        # dependency. Merging an already-present graph is harmless.
        deps = (avg, meta)
    else:
        raise ValueError("Invalid format %s" % format)

    graph = HighLevelGraph.from_collections(name, layers, deps)
    chunks = avg.chunks + tuple((s,) for s in array.shape[1:])
    # Separate name so the ``meta`` argument isn't shadowed.
    out_meta = np.empty((0,) * len(dims), dtype=array.dtype)
    return da.Array(graph, name, chunks, meta=out_meta)
def _getitem_row_chan(avg, idx, dtype):
    """
    Pull element ``idx`` out of every (row, chan, corr) block of ``avg``.

    ``getitem(block, idx)`` is applied blockwise with identical input and
    output indices. The result keeps ``avg``'s chunks and is given ``dtype``.
    """
    out_name = "row-chan-average-getitem-%d-" % idx + tokenize(avg, idx)
    index = ("row", "chan", "corr")

    layer = db.blockwise(getitem, out_name, index,
                         avg.name, index,
                         idx, None,
                         numblocks={avg.name: avg.numblocks})

    hlg = HighLevelGraph.from_collections(out_name, layer, (avg,))
    return da.Array(hlg, out_name, avg.chunks, dtype=dtype)
def uvcontsub_flagger(vis, flag, **kwargs):
    """
    Dask wrapper for :func:`~tricolour.uvcontsub_flagger`

    ``np_uvcontsub_flagger`` is applied block-by-block to ``vis`` and
    ``flag`` using ``_WINDOW_SCHEMA`` for both inputs and the output. Extra
    keyword arguments are forwarded to the block function. The result is
    chunked like ``vis`` and has ``flag``'s dtype.
    """
    inputs = (vis, flag)
    name = 'uvcontsub-flagger-' + da.core.tokenize(vis, flag, **kwargs)

    layers = db.blockwise(np_uvcontsub_flagger, name, _WINDOW_SCHEMA,
                          vis.name, _WINDOW_SCHEMA,
                          flag.name, _WINDOW_SCHEMA,
                          numblocks={arr.name: arr.numblocks for arr in inputs},
                          **kwargs)

    # Merge the input graphs with the new layer
    graph = HighLevelGraph.from_collections(name, layers, inputs)
    return da.Array(graph, name, vis.chunks, dtype=flag.dtype)
def _getitem_row(avg, idx, array, dims):
    """
    Extract row-like arrays from a dask array of tuples

    Parameters
    ----------
    avg : dask array
        Array whose "row" blocks are indexed with element ``idx``.
    idx : int
        Element to extract from each block.
    array : array-like
        Supplies the output dtype and the sizes of the non-row dimensions
        (``array.shape[1:]``).
    dims : tuple of str
        Output dimension names. The first must be ``"row"``.

    Returns
    -------
    dask array
        Chunked like ``avg`` along "row", with one chunk per other dimension.
    """
    assert dims[0] == "row"

    name = ("row-average-getitem-%d-" % idx) + tokenize(avg, idx)
    layers = db.blockwise(getitem, name, dims,
                          avg.name, ("row",),
                          idx, None,
                          new_axes=dict(zip(dims[1:], array.shape[1:])),
                          numblocks={avg.name: avg.numblocks})

    graph = HighLevelGraph.from_collections(name, layers, (avg,))
    chunks = avg.chunks + tuple((s,) for s in array.shape[1:])

    # dask py2 doesn't understand meta.
    # ``np.object`` was a deprecated alias of the builtin ``object`` and was
    # removed in NumPy 1.24, so the builtin is used directly.
    kw = {'meta': np.empty((0,) * len(dims), dtype=object)} if PY3 else {}
    return da.Array(graph, name, chunks, dtype=array.dtype, **kw)
def sum_threshold_flagger(vis, flag, **kwargs):
    """
    Dask wrapper for :func:`~tricolour.flagging.sum_threshold_flagger`

    ``np_sum_threshold_flagger`` is applied block-by-block to ``vis`` and
    ``flag`` using ``_WINDOW_SCHEMA`` for both inputs and the output. Extra
    keyword arguments are forwarded to the block function and included in
    the token. The result is chunked like ``vis`` with ``flag``'s dtype.
    """
    # dask.blockwise.blockwise is used rather than dask.array.blockwise
    # because inputs can share the same number of chunks while having
    # different chunk sizes.
    # NOTE(review): the original comment named ant1/ant2 as the
    # differently-sized inputs, but this function only takes vis and flag.
    inputs = (vis, flag)
    name = 'sum-threshold-flagger-' + da.core.tokenize(vis, flag, kwargs)

    layers = db.blockwise(np_sum_threshold_flagger, name, _WINDOW_SCHEMA,
                          vis.name, _WINDOW_SCHEMA,
                          flag.name, _WINDOW_SCHEMA,
                          numblocks={arr.name: arr.numblocks for arr in inputs},
                          **kwargs)

    # Merge the input graphs with the new layer
    graph = HighLevelGraph.from_collections(name, layers, inputs)
    return da.Array(graph, name, vis.chunks, dtype=flag.dtype)
def generate_table_getcols(table_name, column, shape, dtype,
                           table_proxy, row_runs, row_resorts=None):
    """
    Generates a :class:`dask.array.Array` representing ``column`` in
    ``table_name``, read through a per-chunk getter function.

    Parameters
    ----------
    table_name : str
        CASA table filename path
    column : str
        Name of the column to generate
    shape : tuple
        Shape of the array. Axis 0 is the row axis. The remaining sizes
        become single-chunk axes of the output.
    dtype : np.dtype or object
        Data type of the array. An ``np.dtype`` instance selects
        ``_chunk_getcols_np``; anything else selects
        ``_chunk_getcols_object``.
    table_proxy : :class:`TableProxy`
        Table proxy object
    row_runs : dask collection
        Row runs for each chunk. Its ``.name``, ``.numblocks`` and
        ``.chunks`` are read, and its row chunking becomes the output's.
    row_resorts : dask collection
        Argsort indices applied per row run. Its ``.name`` and
        ``.numblocks`` are read.
        NOTE(review): the default of ``None`` is not handled here and would
        fail on ``row_resorts.name``.

    Returns
    -------
    :class:`dask.array.Array`
        Dask array representing the column
    """
    token = dask.base.tokenize(table_name, column, row_runs, row_resorts)
    name = '-'.join((short_table_name(table_name), "getcol",
                     column.lower(), token))

    # Integer dimension schema; axis 0 is 'row'.
    schema = tuple(range(len(shape)))
    row_schema = schema[0:1]

    if isinstance(dtype, np.dtype):
        getter = _chunk_getcols_np
    else:
        getter = _chunk_getcols_object

    # Every non-row axis is introduced as a new single-chunk axis.
    extra_axes = {axis: size for axis, size in enumerate(shape[1:], start=1)}

    layers = db.blockwise(getter, name, schema,
                          table_proxy, None,
                          column, None,
                          shape[1:], None,
                          dtype, None,
                          row_runs.name, row_schema,
                          row_resorts.name, row_schema,
                          new_axes=extra_axes,
                          numblocks={
                              row_runs.name: row_runs.numblocks,
                              row_resorts.name: row_resorts.numblocks,
                          })

    graph = HighLevelGraph.from_collections(name, layers,
                                            (row_runs, row_resorts))
    chunks = row_runs.chunks + tuple((d,) for d in shape[1:])
    return da.Array(graph, name, chunks, dtype=dtype)
def apply_dies(time_index, antenna1, antenna2,
               die1_jones, base_vis, die2_jones,
               predict_check_tup, out_dtype):
    """
    Apply any Direction-Independent Effects and Base Visibilities

    Builds one low-level blockwise layer calling ``_predict_dies_wrapper``
    over ("row", "chan", *corr) blocks. The wrapper receives 9 positional
    (array-name, index) pairs in this order:
    time_index, antenna1, antenna2, dde1_jones, source_coh, dde2_jones,
    die1_jones, base_vis, die2_jones.
    Absent terms are passed as ``None, None``. The DDE/coherency slots are
    always ``None`` here.

    Parameters
    ----------
    time_index, antenna1, antenna2 : dask arrays
        Row-indexed inputs, always present.
    die1_jones, die2_jones : dask arrays
        Used only when both DIE flags in ``predict_check_tup`` are true.
        Indexed as ("row", "ant", "chan", *corr).
    base_vis : dask array
        Used when ``have_bvis`` is true. Indexed as ("row", "chan", *corr).
    predict_check_tup : tuple
        Six presence flags: (have_ddes1, have_coh, have_ddes2,
        have_dies1, have_bvis, have_dies2).
    out_dtype : dtype
        dtype of the returned array.

    Raises
    ------
    ValueError
        If neither the DIE terms nor the base visibilities are present.
    """
    # Now apply any Direction Independent Effect Terms
    (have_ddes1, have_coh, have_ddes2,
     have_dies1, have_bvis, have_dies2) = predict_check_tup

    # Both DIE terms must be present for either to be applied.
    have_dies = have_dies1 and have_dies2

    # Generate strings for the correlation dimensions
    # This also has the effect of checking that we have all valid inputs.
    # Correlation dims are whatever trails (row, ant, chan) for the jones
    # terms, or (row, chan) for base_vis.
    if have_dies:
        cdims = tuple("corr-%d" % i for i in range(len(die1_jones.shape[3:])))
    elif have_bvis:
        cdims = tuple("corr-%d" % i for i in range(len(base_vis.shape[2:])))
    else:
        raise ValueError("Missing both antenna and baseline jones terms")

    # In the case of predict_vis, the "row" and "time" dimensions
    # are intimately related -- a contiguous series of rows
    # are related to a contiguous series of timesteps.
    # This means that the number of chunks of these
    # two dimensions must match even though the chunk sizes may not.
    # blockwise insists on matching chunk sizes.
    # For this reason, we use the lower level blockwise and
    # substitute "row" for "time" in arrays such as dde1_jones
    # and die1_jones.
    gjones_dims = ("row", "ant", "chan") + cdims

    # Setup
    # 1. Optional blockwise arguments
    # 2. Optional numblocks kwarg
    # 3. HighLevelGraph dependencies
    bw_args = [time_index.name, ("row",),
               antenna1.name, ("row",),
               antenna2.name, ("row",)]
    numblocks = {
        time_index.name: time_index.numblocks,
        antenna1.name: antenna1.numblocks,
        antenna2.name: antenna2.numblocks
    }

    deps = [time_index, antenna1, antenna2]

    # dde1_jones, source_coh and dde2_jones not present
    # these are already applied into sum_coherencies
    bw_args.extend([None, None, None, None, None, None])

    # other_chunks holds the (chan, *corr) chunks of whichever term was
    # handled last. The ValueError above guarantees at least one branch
    # below binds it.
    if have_dies:
        bw_args.extend([die1_jones.name, gjones_dims])
        numblocks[die1_jones.name] = die1_jones.numblocks
        deps.append(die1_jones)
        other_chunks = die1_jones.chunks[2:]
    else:
        bw_args.extend([None, None])

    if have_bvis:
        bw_args.extend([base_vis.name, ("row", "chan") + cdims])
        numblocks[base_vis.name] = base_vis.numblocks
        deps.append(base_vis)
        other_chunks = base_vis.chunks[1:]
    else:
        bw_args.extend([None, None])

    if have_dies:
        bw_args.extend([die2_jones.name, gjones_dims])
        numblocks[die2_jones.name] = die2_jones.numblocks
        deps.append(die2_jones)
        other_chunks = die2_jones.chunks[2:]
    else:
        bw_args.extend([None, None])

    # Exactly 9 (array, index) pairs; the positional order above must match
    # _predict_dies_wrapper's parameters (assumed; the wrapper is defined
    # elsewhere).
    assert len(bw_args) // 2 == 9

    token = da.core.tokenize(time_index, antenna1, antenna2,
                             die1_jones, base_vis, die2_jones)
    name = '-'.join(("predict-vis-apply-dies", token))
    layer = blockwise(_predict_dies_wrapper, name, ("row", "chan") + cdims,
                      *bw_args, numblocks=numblocks)

    graph = HighLevelGraph.from_collections(name, layer, deps)
    # Output rows follow time_index's row chunks.
    chunks = (time_index.chunks[0],) + other_chunks

    return da.Array(graph, name, chunks, dtype=out_dtype)
def fan_reduction(time_index, antenna1, antenna2,
                  dde1_jones, source_coh, dde2_jones,
                  predict_check_tup, out_dtype):
    """
    Does a standard dask tree reduction over source coherencies

    Builds one low-level blockwise layer calling ``_predict_coh_wrapper``
    over ("src", "row", "chan", *corr) blocks. The layer is wrapped as an
    array with one element per source block along axis 0, and then summed
    over that axis.

    The wrapper receives 9 positional (array-name, index) pairs in this
    order: time_index, antenna1, antenna2, dde1_jones, source_coh,
    dde2_jones, die1_jones, base_vis, die2_jones. Absent terms are passed
    as ``None, None``. The DIE/base_vis slots are always ``None`` here.

    Parameters
    ----------
    time_index, antenna1, antenna2 : dask arrays
        Row-indexed inputs, always present.
    dde1_jones, dde2_jones : dask arrays
        Used only when both DDE flags are true. Indexed as
        ("src", "row", "ant", "chan", *corr).
    source_coh : dask array
        Used when ``have_coh`` is true. Indexed as
        ("src", "row", "chan", *corr).
    predict_check_tup : tuple
        Six presence flags: (have_ddes1, have_coh, have_ddes2,
        have_dies1, have_bvis, have_dies2).
    out_dtype : dtype
        dtype of the intermediate per-source-block array.

    Raises
    ------
    ValueError
        If neither the DDE terms nor source coherencies are present.
    """
    (have_ddes1, have_coh, have_ddes2,
     have_dies1, have_bvis, have_dies2) = predict_check_tup

    # Both DDE terms must be present for either to be applied.
    have_ddes = have_ddes1 and have_ddes2

    # Correlation dims are whatever trails (src, row, ant, chan) for the
    # jones terms, or (src, row, chan) for source_coh.
    if have_ddes:
        cdims = tuple("corr-%d" % i for i in range(len(dde1_jones.shape[4:])))
    elif have_coh:
        cdims = tuple("corr-%d" % i for i in range(len(source_coh.shape[3:])))
    else:
        raise ValueError("need ddes or source coherencies")

    ajones_dims = ("src", "row", "ant", "chan") + cdims

    # Setup
    # 1. Optional blockwise arguments
    # 2. Optional numblocks kwarg
    # 3. HighLevelGraph dependencies
    bw_args = [time_index.name, ("row",),
               antenna1.name, ("row",),
               antenna2.name, ("row",)]
    numblocks = {
        time_index.name: time_index.numblocks,
        antenna1.name: antenna1.numblocks,
        antenna2.name: antenna2.numblocks
    }

    # Dependencies
    deps = [time_index, antenna1, antenna2]

    # other_chunks / src_chunks come from whichever term was handled last.
    # The ValueError above guarantees at least one branch binds them.

    # Handle presence/absence of dde1_jones
    if have_ddes:
        bw_args.extend([dde1_jones.name, ajones_dims])
        numblocks[dde1_jones.name] = dde1_jones.numblocks
        deps.append(dde1_jones)
        other_chunks = dde1_jones.chunks[3:]
        src_chunks = dde1_jones.chunks[0]
    else:
        bw_args.extend([None, None])

    # Handle presence/absence of source_coh
    if have_coh:
        bw_args.extend([source_coh.name, ("src", "row", "chan") + cdims])
        numblocks[source_coh.name] = source_coh.numblocks
        deps.append(source_coh)
        other_chunks = source_coh.chunks[2:]
        src_chunks = source_coh.chunks[0]
    else:
        bw_args.extend([None, None])

    # Handle presence/absence of dde2_jones
    if have_ddes:
        bw_args.extend([dde2_jones.name, ajones_dims])
        numblocks[dde2_jones.name] = dde2_jones.numblocks
        deps.append(dde2_jones)
        other_chunks = dde2_jones.chunks[3:]
        src_chunks = dde2_jones.chunks[0]
    else:
        bw_args.extend([None, None])

    # die1_jones, base_vis and die2_jones absent for this part of the graph
    bw_args.extend([None, None, None, None, None, None])

    # Exactly 9 (array, index) pairs; the positional order must match
    # _predict_coh_wrapper's parameters (assumed; defined elsewhere).
    assert len(bw_args) // 2 == 9, len(bw_args) // 2

    token = da.core.tokenize(time_index, antenna1, antenna2,
                             dde1_jones, source_coh, dde2_jones)
    name = "-".join(("predict-vis-sum-coh", token))
    layer = blockwise(_predict_coh_wrapper, name, ("src", "row", "chan") + cdims,
                      *bw_args, numblocks=numblocks)

    graph = HighLevelGraph.from_collections(name, layer, deps)

    # Output chunks: one element per source block (presumably the wrapper
    # reduces each source block to size 1 — TODO confirm), time_index's
    # row chunks, then the (chan, *corr) chunks of the last present term.
    chunks = ((1,)*len(src_chunks), time_index.chunks[0],) + other_chunks

    # Create array
    sum_coherencies = da.Array(graph, name, chunks, dtype=out_dtype)

    # Reduce source axis
    return sum_coherencies.sum(axis=0)
def from_dask_array(x, columns=None, index=None, meta=None):
    """Create a Dask DataFrame from a Dask Array.

    Converts a 2d array into a DataFrame and a 1d array into a Series.

    Parameters
    ----------
    x : da.Array
    columns : list or string
        list of column names if DataFrame, single string if Series
    index : dask.dataframe.Index, optional
        An optional *dask* Index to use for the output Series or DataFrame.

        The default output index depends on whether `x` has any unknown
        chunks. If there are any unknown chunks, the output has ``None``
        for all the divisions (one per chunk). If all the chunks are known,
        a default index with known divisions is created.

        Specifying `index` can be useful if you're conforming a Dask Array
        to an existing dask Series or DataFrame, and you would like the
        indices to match.
    meta : object, optional
        An optional `meta` parameter can be passed for dask
        to specify the concrete dataframe type to be returned.
        By default, pandas DataFrame is used.

    Raises
    ------
    ValueError
        If `index` has a different number of partitions than `x` has
        blocks along its first axis.

    Examples
    --------
    >>> import dask.array as da
    >>> import dask.dataframe as dd
    >>> x = da.ones((4, 2), chunks=(2, 2))
    >>> df = dd.io.from_dask_array(x, columns=['a', 'b'])
    >>> df.compute()
         a    b
    0  1.0  1.0
    1  1.0  1.0
    2  1.0  1.0
    3  1.0  1.0

    See Also
    --------
    dask.bag.to_dataframe: from dask.bag
    dask.dataframe._Frame.values: Reverse conversion
    dask.dataframe._Frame.to_records: Reverse conversion
    """
    meta = _meta_from_array(x, columns, index, meta=meta)

    name = "from-dask-array-" + tokenize(x, columns)
    graph_dependencies = [x]
    arrays_and_indices = [x.name, "ij" if x.ndim == 2 else "i"]
    numblocks = {x.name: x.numblocks}

    if index is not None:
        # An index is explicitly given by the caller, so we can pass it through to the
        # initializer after a few checks.
        if index.npartitions != x.numblocks[0]:
            msg = ("The index and array have different numbers of blocks. "
                   "({} != {})".format(index.npartitions, x.numblocks[0]))
            raise ValueError(msg)
        divisions = index.divisions
        graph_dependencies.append(index)
        arrays_and_indices.extend([index._name, "i"])
        numblocks[index._name] = (index.npartitions,)

    elif np.isnan(sum(x.shape)):
        # The shape of the incoming array is not known in at least one dimension. As
        # such, we can't create an index for the entire output DataFrame and we set
        # the divisions to None to represent that.
        divisions = [None] * (len(x.chunks[0]) + 1)
    else:
        # The shape of the incoming array is known and we don't have an explicit index.
        # Create a mapping of chunk number in the incoming array to
        # (start row, stop row) tuples. These tuples will be used to create a sequential
        # RangeIndex later on that is continuous over the whole DataFrame.
        n_rows = sum(x.chunks[0])
        divisions = [0]
        stop = 0
        index_mapping = {}
        for i, increment in enumerate(x.chunks[0]):
            stop += increment
            index_mapping[(i,)] = (divisions[i], stop)

            # Divisions are inclusive, so the last one must be the final row
            # label (n_rows - 1). Clamp as soon as the last row is reached
            # rather than only decrementing the final entry. Otherwise,
            # trailing empty chunks would give non-monotonic divisions,
            # e.g. chunks (2, 2, 0) -> [0, 2, 4, 3].
            if stop == n_rows:
                stop -= 1

            divisions.append(stop)

        arrays_and_indices.extend(
            [BlockwiseDepDict(mapping=index_mapping), "i"])

    if is_series_like(meta):
        kwargs = {
            "dtype": x.dtype,
            "name": meta.name,
            "initializer": type(meta)
        }
    else:
        kwargs = {"columns": meta.columns, "initializer": type(meta)}

    blk = blockwise(
        _partition_from_array,
        name,
        "i",
        *arrays_and_indices,
        numblocks=numblocks,
        concatenate=True,
        # kwargs passed through to the DataFrame/Series initializer
        **kwargs,
    )

    graph = HighLevelGraph.from_collections(name, blk,
                                            dependencies=graph_dependencies)
    return new_dd_object(graph, name, meta, divisions)
def store_inplace(sources, targets, safe=True, **kwargs):
    """Evaluate a dask computation and store results in the original numpy arrays.

    Dask is designed to operate on immutable data: the key for a node in the
    graph is intended to uniquely identify the value. It's possible to create
    tasks that modify the backing storage, but it can potentially create race
    conditions where a value might be replaced either before or after it is
    used. This function provides safety checks that will raise an exception if
    there is a risk of this happening.

    Despite the safety checks, it still requires some user care to be used
    safely:

    - The arrays in `targets` must be backed by numpy arrays, with no
      computations other than slicing. Thus, the dask functions
      :func:`~dask.array.asarray`, :func:`~dask.array.from_array`,
      :func:`~dask.array.concatenate` and :func:`~dask.array.stack` are safe.
    - The target keys must be backed by *distinct* numpy arrays. This is not
      currently checked (although duplicate keys will be detected).
    - When creating a target array with :func:`~dask.array.from_array`,
      ensure that the array has a unique name (e.g., by passing
      ``name=False``).
    - The safety check only applies to the sources and targets passed to this
      function. Any simultaneous use of objects based on the targets is
      invalid, and afterwards any dask objects based on the targets will be
      computed with the overwritten values.

    The safety check is conservative i.e., there may be cases where it will
    throw an exception even though the operation can be proven to be safe.

    Each source is rechunked to match the chunks of the target. In cases where
    the target is backed by a single large numpy array, it may be more
    efficient to construct a new dask wrapper of that numpy array whose
    chunking matches the source.

    Parameters
    ----------
    sources : iterable of :class:`dask.array.Array`
        Values to compute.
    targets : iterable of :class:`dask.array.Array`
        Destinations in which to store the results of computing `sources`, with
        the same length and matching shapes (the dtypes need not match, as long
        as they are assignable).
    safe : bool, optional
        If true (default), raise an exception if the operation is potentially
        unsafe. This can be an expensive operation (quadratic in the number of
        chunks).
    kwargs : dict
        Extra arguments are passed to the scheduler

    Raises
    ------
    UnsafeInplaceError
        if a data hazard is detected
    ValueError
        if the sources and targets have the wrong type or don't match
        (including a different number of sources and targets)
    """
    if isinstance(sources, da.Array):
        sources = [sources]
        targets = [targets]
    # Materialise both iterables: they are traversed several times below.
    # A generator would otherwise be exhausted by the type checks, leaving
    # nothing to store.
    sources = list(sources)
    targets = list(targets)
    if any(not isinstance(s, da.Array) for s in sources):
        raise ValueError('All sources must be instances of da.Array')
    if any(not isinstance(t, da.Array) for t in targets):
        raise ValueError('All targets must be instances of da.Array')
    # zip() would otherwise silently drop the surplus entries.
    if len(sources) != len(targets):
        raise ValueError('sources and targets must have the same length '
                         '({} != {})'.format(len(sources), len(targets)))

    chunked_sources = [source.rechunk(target.chunks)
                       for source, target in zip(sources, targets)]
    if safe:
        _safe_inplace(chunked_sources, targets)

    def store(target, source):
        target[:] = source

    out_keys = []
    layers = {}
    dependencies = {}
    store_layers = []
    for source, target in zip(chunked_sources, targets):
        name = 'store-' + source.name + '-' + target.name
        store_layers.append(name)
        indices = tuple(range(target.ndim))
        layer = blockwise(store, name, indices,
                          target.name, indices,
                          source.name, indices,
                          numblocks={
                              source.name: source.numblocks,
                              target.name: target.numblocks
                          })
        # Replicate behaviour of HighLevelGraph.from_collections
        layers[name] = layer
        dependencies[name] = set()
        for collection in source, target:
            graph = collection.__dask_graph__()
            layers.update(graph.layers)
            dependencies.update(graph.dependencies)
            dependencies[name].update(collection.__dask_layers__())
        out_keys.extend(layer.keys())

    # We don't have any outputs from storing, so to form a dask collection
    # we'll gather up all the output keys into one "root" key and form a
    # Delayed collection from it. This is similar to what da.store does.
    root_key = 'store-root-' + str(uuid.uuid4())
    layers[root_key] = {root_key: out_keys}
    dependencies[root_key] = set(store_layers)
    graph = HighLevelGraph(layers, dependencies)
    # Ensure that array-appropriate optimizations are performed.
    graph = da.Array.__dask_optimize__(graph, [root_key])
    result = Delayed(root_key, graph)
    result.compute(optimize_graph=False)
def xds_to_table(xds, table_name, columns=None, **kwargs):
    """
    Generates a dask array which writes the specified columns from an
    :class:`xarray.Dataset` into the CASA table specified by ``table_name``
    when the :meth:`dask.array.Array.compute` method is called.

    Parameters
    ----------
    xds : :class:`xarray.Dataset`
        dataset containing the specified columns.
    table_name : str
        CASA table path
    columns : tuple or list or str, optional
        list of column names to write to the table.
        If ``None`` all columns will be written.
    **kwargs
        ``min_frag_level`` (default ``False``) is forwarded to
        ``get_row_runs``.

    Returns
    -------
    :class:`dask.array.Array`
        dask array representing the write to the dataset: the flattened
        per-block write results of every column, concatenated.

    Raises
    ------
    ValueError
        If a column's first dimension is not ``'row'``, or if a column is
        chunked along any dimension other than ``'row'``.
    """
    rows = xds.table_row.values
    min_frag_level = kwargs.get('min_frag_level', False)
    table_proxy = TableProxy(table_name)

    if columns is None:
        columns = xds.data_vars.keys()
    elif isinstance(columns, string_types):
        columns = [columns]

    # Get the DataArrays for each column
    col_arrays = [getattr(xds, column) for column in columns]

    writes = []

    # Generate the graph for each column
    for column, data_array in zip(columns, col_arrays):
        dask_array = data_array.data
        dims = data_array.dims
        chunks = dask_array.chunks

        if dims[0] != 'row':
            raise ValueError("xds.%s.dims[0] != 'row'" % column)

        multiple = [(dim, chunk) for dim, chunk
                    in zip(dims[1:], chunks[1:])
                    if len(chunk) != 1]

        if len(multiple) > 0:
            raise ValueError("Column '%s' has multiple chunks in the "
                             "following dimensions '%s'. Only chunking "
                             "in 'row' is currently supported. "
                             "Use 'rechunk' so that the mentioned "
                             "dimensions contain a single chunk."
                             % (column, multiple))

        # Get row runs for the row chunks
        row_runs, row_resorts = get_row_runs(rows, chunks,
                                             min_frag_level=min_frag_level,
                                             sort_dir="write")

        # Integer dimension schema. 'row' == 0
        schema = tuple(range(len(dask_array.shape)))

        # Tokenize putcol on the dask arrays
        token = dask.base.tokenize(table_name, column, dask_array)
        name = '-'.join((short_table_name(table_name), "putcol",
                         column, token))

        layers = db.blockwise(_chunk_putcols_np, name, schema,
                              table_proxy, None,
                              column, None,
                              dask_array.name, schema,
                              row_runs.name, schema[0:1],
                              row_resorts.name, schema[0:1],
                              numblocks={
                                  dask_array.name: dask_array.numblocks,
                                  row_runs.name: row_runs.numblocks,
                                  row_resorts.name: row_resorts.numblocks
                              })

        # Add the arrays graph to dependencies
        deps = [dask_array, row_runs, row_resorts]
        graph = HighLevelGraph.from_collections(name, layers, deps)
        # One element per input block in every dimension.
        chunks = tuple(tuple(1 for c in dc) for dc in dask_array.chunks)
        # ``np.bool`` was a deprecated alias of the builtin ``bool`` and was
        # removed in NumPy 1.24; the builtin gives the same dtype.
        write_array = da.Array(graph, name, chunks, dtype=bool)

        writes.append(write_array)

    return da.concatenate([w.ravel() for w in writes])