def _dask_load(sources, geobox, measurements, dask_chunks, skip_broken_datasets=False):
    """
    Assemble storage for ``measurements`` whose data comes from ``_make_dask_array``.

    Every dataset found in ``sources`` is registered in a shared graph dict
    under its ``_tokenize_dataset`` key and binned by the ``geobox`` tiles
    that ``GeoboxTiles.tiles`` reports for the dataset extent.

    :param sources: collection of datasets, passed to ``xr_apply``; its
        ``coords`` are forwarded to ``Datacube.create_storage``
    :param geobox: output geobox, split into tiles of the computed grid chunk size
    :param measurements: measurements forwarded to ``Datacube.create_storage``
    :param dask_chunks: chunking request, interpreted by ``_calculate_chunk_sizes``
    :param skip_broken_datasets: forwarded unchanged to ``_make_dask_array``
    :return: the result of ``Datacube.create_storage``
    """
    irregular_chunks, tile_chunks = _calculate_chunk_sizes(sources, geobox, dask_chunks)
    tiles = GeoboxTiles(geobox, tile_chunks)

    # Shared between the binning pass below and every later data_func call.
    dataset_graph = {}

    def bin_by_tile(_, datasets):
        # Map tile index -> datasets overlapping that tile, registering each
        # dataset in the shared graph dict along the way.
        per_tile = {}
        for dataset in datasets:
            dataset_graph[_tokenize_dataset(dataset)] = dataset
            for tile_idx in tiles.tiles(dataset.extent):
                per_tile.setdefault(tile_idx, []).append(dataset)
        return per_tile

    chunked_srcs = xr_apply(sources, bin_by_tile, dtype=object)

    def data_func(measurement):
        return _make_dask_array(
            chunked_srcs,
            dataset_graph,
            tiles,
            measurement,
            chunks=irregular_chunks + tile_chunks,
            skip_broken_datasets=skip_broken_datasets,
        )

    return Datacube.create_storage(sources.coords, geobox, measurements, data_func)
def reproject_band(band, geobox, resampling, dims, dask_chunks=None):
    """
    Reproject a single measurement to the geobox.

    When ``band.data`` has no ``dask`` attribute, or ``dask_chunks`` is
    ``None``, the whole array is reprojected in one ``reproject_array`` call.
    Otherwise a new dask graph layer is built with one task per output tile.

    :param band: measurement to reproject; must expose ``data``, ``nodata``,
        ``geobox``, ``name`` and ``dtype``
    :param geobox: destination geobox
    :param resampling: resampling setting forwarded to ``reproject_array``
    :param dims: forwarded to ``wrap_in_dataarray``
    :param dask_chunks: mapping from dimension name to chunk size; dimensions
        missing from it use the full extent of ``geobox`` along that axis
    :return: the result of ``wrap_in_dataarray`` for the reprojected data
    """
    is_lazy = hasattr(band.data, 'dask')
    if not (is_lazy and dask_chunks is not None):
        eager = reproject_array(band.data, band.nodata, band.geobox, geobox, resampling)
        return wrap_in_dataarray(eager, band, geobox, dims)

    dask_name = 'warp_{name}-{token}'.format(name=band.name, token=uuid.uuid4().hex)
    dependencies = [band.data]

    full_extent = geobox.shape
    spatial_chunks = tuple(
        dask_chunks.get(dim, full_extent[axis])
        for axis, dim in enumerate(geobox.dims)
    )

    gt = GeoboxTiles(geobox, spatial_chunks)
    new_layer = {}

    for tile_index in numpy.ndindex(gt.shape):
        task_key = (dask_name, *tile_index)
        sub_geobox = gt[tile_index]

        # Source window needed to fill this output tile.
        roi = compute_reproject_roi(band.geobox, sub_geobox, padding=1)
        # Cut that window out of the input and collapse it to one chunk,
        # so that it is addressed by a single graph key.
        subset_band = band[(Ellipsis,) + roi.roi_src].chunk(-1)

        if min(subset_band.shape) == 0:
            # Nothing to read for this tile: emit a nodata-filled block.
            new_layer[task_key] = (numpy.full, sub_geobox.shape, band.nodata, band.dtype)
            continue

        # The cropped input becomes a dependency of the new layer.
        dependencies.append(subset_band.data)
        # Key of the input block that `reproject_array` will receive.
        band_key = list(flatten(subset_band.data.__dask_keys__()))[0]
        # Task that reprojects the cropped input onto this tile.
        new_layer[task_key] = (
            reproject_array,
            band_key,
            band.nodata,
            subset_band.geobox,
            sub_geobox,
            resampling,
        )

    # The new layer is attached to the graph by hand and then wrapped in a
    # dask.array: the higher-level dask.array interface only allows regular
    # chunking, so manipulating the graph is the easiest way to obtain an
    # array with irregular chunks after reprojection.
    graph = band.data.dask.from_collections(dask_name, new_layer, dependencies=dependencies)
    data = dask.array.Array(
        graph,
        dask_name,
        chunks=spatial_chunks,
        dtype=band.dtype,
        shape=gt.base.shape,
    )

    return wrap_in_dataarray(data, band, geobox, dims)
def dask_reproject(
    src: da.Array,
    src_geobox: GeoBox,
    dst_geobox: GeoBox,
    resampling: str = "nearest",
    chunks: Optional[Tuple[int, int]] = None,
    src_nodata: Optional[NodataType] = None,
    dst_nodata: Optional[NodataType] = None,
    axis: int = 0,
    name: str = "reproject",
) -> da.Array:
    """
    Reproject a dask array onto another GeoBox as a lazy dask operation.

    One task is emitted per destination tile reported by
    ``GeoboxTiles.tiles(src_geobox.extent)``. Every remaining block key is
    filled by the helper returned from ``empty_maker``, built with
    ``dst_nodata`` (or ``0`` when that is ``None``) as fill value.

    Dimensions after Y,X are not supported: an assertion fails if ``src``
    has any axes past ``axis + 1``.

    :param src       : Input src[(time,) y,x (, band)]
    :param src_geobox: GeoBox of the source array
    :param dst_geobox: GeoBox of the destination
    :param resampling: Resampling strategy as a string: nearest, bilinear, average, mode ...
    :param chunks    : In Y,X dimensions only, default is to use same input chunk size
    :param axis      : Index of Y axis (default is 0)
    :param src_nodata: nodata marker for source image
    :param dst_nodata: nodata marker for dst image, defaults to ``src_nodata``
    :param name      : Dask graph name, "reproject" is the default
    """
    yx = slice(axis, axis + 2)

    if chunks is None:
        chunks = src.chunksize[yx]

    dst_nodata = src_nodata if dst_nodata is None else dst_nodata

    assert src.shape[yx] == src_geobox.shape
    yx_shape = dst_geobox.shape
    yx_chunks = unpack_chunks(chunks, yx_shape)

    leading_chunks = src.chunks[:axis]
    trailing_chunks = src.chunks[axis + 2:]
    dst_chunks = leading_chunks + yx_chunks + trailing_chunks
    dst_shape = src.shape[:axis] + yx_shape + src.shape[axis + 2:]

    # Block counts on each side of Y,X: (*dims1, y, x, *dims2)
    dims1 = tuple(len(axis_chunks) for axis_chunks in dst_chunks[:axis])
    dims2 = tuple(len(axis_chunks) for axis_chunks in dst_chunks[axis + 2:])
    assert dims2 == ()
    deps = [src]

    tile_shape = (yx_chunks[0][0], yx_chunks[1][0])
    gbt = GeoboxTiles(dst_geobox, tile_shape)
    tiles_with_data = list(gbt.tiles(src_geobox.extent))

    graph_name = randomize(name)
    dsk: Dict[Any, Any] = {}

    if src.dtype == "bool":
        block_impl = _reproject_block_bool_impl
    else:
        block_impl = _reproject_block_impl

    for yx_idx in tiles_with_data:
        tile_geobox = gbt[yx_idx]
        roi = compute_reproject_roi(src_geobox, tile_geobox)
        cropped = crop_2d_dense(src, roi.roi_src, axis=axis)
        cropped_geobox = src_geobox[roi.roi_src]

        deps.append(cropped)

        for lead_idx in np.ndindex(dims1):
            # TODO: band dims
            src_key = (cropped.name, *lead_idx, 0, 0)
            dsk[(graph_name, *lead_idx, *yx_idx)] = (
                block_impl,
                src_key,
                cropped_geobox,
                tile_geobox,
                resampling,
                src_nodata,
                dst_nodata,
                axis,
            )

    fill_value = 0 if dst_nodata is None else dst_nodata
    shape_in_blocks = tuple(len(axis_chunks) for axis_chunks in dst_chunks)

    # The helper receives the task dict itself, so it must be the same object
    # that is turned into the graph below.
    mk_empty = empty_maker(fill_value, src.dtype, dsk)

    for block_idx in np.ndindex(shape_in_blocks):
        # TODO: other dims
        key = (graph_name, *block_idx)
        if key in dsk:
            continue
        block_shape = tuple(
            axis_chunks[i] for axis_chunks, i in zip(dst_chunks, block_idx)
        )
        dsk[key] = mk_empty(block_shape)

    graph = HighLevelGraph.from_collections(graph_name, dsk, dependencies=deps)

    return da.Array(graph, graph_name, chunks=dst_chunks, dtype=src.dtype, shape=dst_shape)