示例#1
0
    def _dask_load(sources,
                   geobox,
                   measurements,
                   dask_chunks,
                   skip_broken_datasets=False):
        """
        Assemble storage for ``measurements`` whose data arrays are built by
        ``_make_dask_array``.

        Every dataset found in ``sources`` is recorded in one shared mapping
        (dataset token -> dataset). Each element of ``sources`` is converted
        into a dict that groups its datasets by the ``geobox`` tile indices
        they overlap. ``Datacube.create_storage`` then calls back once per
        measurement to obtain the array for it.

        :param sources: labelled array of dataset collections; its ``.coords``
            are reused for the output (presumably an xarray.DataArray of grouped
            datasets -- TODO confirm against callers)
        :param geobox: output GeoBox that is split into tiles
        :param measurements: measurements handed to ``Datacube.create_storage``
        :param dask_chunks: requested chunking, resolved by ``_calculate_chunk_sizes``
        :param skip_broken_datasets: forwarded unchanged to ``_make_dask_array``
        :return: whatever ``Datacube.create_storage`` returns
        """
        irregular_chunks, spatial_chunks = _calculate_chunk_sizes(
            sources, geobox, dask_chunks)
        tiles = GeoboxTiles(geobox, spatial_chunks)

        # token -> dataset; the same dict object is passed for every measurement
        dataset_graph = {}

        def _group_by_tile(_, datasets):
            # Called by xr_apply with (index, element); the index is unused.
            by_tile = {}
            for dataset in datasets:
                dataset_graph[_tokenize_dataset(dataset)] = dataset
                for tile_idx in tiles.tiles(dataset.extent):
                    by_tile.setdefault(tile_idx, []).append(dataset)
            return by_tile

        tiled_sources = xr_apply(sources, _group_by_tile, dtype=object)

        def _measurement_array(measurement):
            # Non-spatial chunk sizes first, then the spatial tile sizes.
            all_chunks = irregular_chunks + spatial_chunks
            return _make_dask_array(tiled_sources,
                                    dataset_graph,
                                    tiles,
                                    measurement,
                                    chunks=all_chunks,
                                    skip_broken_datasets=skip_broken_datasets)

        return Datacube.create_storage(sources.coords, geobox, measurements,
                                       _measurement_array)
示例#2
0
def reproject_band(band, geobox, resampling, dims, dask_chunks=None):
    """
    Reproject a single measurement onto ``geobox``.

    If ``band.data`` is not a dask array, or no ``dask_chunks`` are given, the
    data is reprojected immediately with ``reproject_array``. Otherwise a dask
    graph layer is built by hand. The output geobox is cut into tiles of
    ``dask_chunks`` size, and each tile becomes a single task. That task
    reprojects only the padded source window the tile needs, or fills with
    ``band.nodata`` when that window is empty.

    :param band: source band; its ``.data``, ``.nodata``, ``.geobox``,
        ``.name`` and ``.dtype`` are read
    :param geobox: destination GeoBox
    :param resampling: resampling method passed through to ``reproject_array``
    :param dims: forwarded to ``wrap_in_dataarray``
    :param dask_chunks: mapping of geobox dimension name -> chunk size;
        dimensions not listed get a single chunk covering the whole axis
    :return: the result of ``wrap_in_dataarray`` for the reprojected data
    """
    is_dask = hasattr(band.data, 'dask')
    if not is_dask or dask_chunks is None:
        eager = reproject_array(band.data, band.nodata, band.geobox, geobox,
                                resampling)
        return wrap_in_dataarray(eager, band, geobox, dims)

    layer_name = 'warp_{name}-{token}'.format(name=band.name,
                                              token=uuid.uuid4().hex)
    deps = [band.data]

    chunk_shape = tuple(dask_chunks.get(dim, size)
                        for dim, size in zip(geobox.dims, geobox.shape))

    tiles = GeoboxTiles(geobox, chunk_shape)
    layer = {}

    for tile_idx in numpy.ndindex(tiles.shape):
        out_key = (layer_name, ) + tile_idx
        tile_geobox = tiles[tile_idx]

        # Source pixel window covering this output tile (padding=1).
        roi = compute_reproject_roi(band.geobox, tile_geobox, padding=1)
        # Collapse that window into a single source chunk.
        src_window = band[(..., ) + roi.roi_src].chunk(-1)

        if min(src_window.shape) == 0:
            # Empty source window: the task just fills the tile with nodata.
            # NOTE(review): numpy.full with a None fill value fails for integer
            # dtypes -- confirm band.nodata is always set here.
            layer[out_key] = (numpy.full, tile_geobox.shape, band.nodata,
                              band.dtype)
            continue

        deps.append(src_window.data)
        # After .chunk(-1) the window has exactly one block, so the first key
        # is the only key.
        src_key = list(flatten(src_window.data.__dask_keys__()))[0]
        layer[out_key] = (reproject_array, src_key, band.nodata,
                          src_window.geobox, tile_geobox, resampling)

    # The high-level dask.array API only builds regular chunking. The layer is
    # therefore merged into a graph directly and wrapped with the
    # low-level Array constructor.
    graph = band.data.dask.from_collections(layer_name, layer,
                                            dependencies=deps)
    warped = dask.array.Array(graph,
                              layer_name,
                              chunks=chunk_shape,
                              dtype=band.dtype,
                              shape=tiles.base.shape)

    return wrap_in_dataarray(warped, band, geobox, dims)
示例#3
0
def dask_reproject(
    src: da.Array,
    src_geobox: GeoBox,
    dst_geobox: GeoBox,
    resampling: str = "nearest",
    chunks: Optional[Tuple[int, int]] = None,
    src_nodata: Optional[NodataType] = None,
    dst_nodata: Optional[NodataType] = None,
    axis: int = 0,
    name: str = "reproject",
) -> da.Array:
    """
    Reproject to GeoBox as dask operation

    Only destination tiles that overlap ``src_geobox.extent`` get a
    reprojection task. Every other output block is a constant block filled
    with ``dst_nodata``, or ``0`` when no nodata value is known.

    :param src       : Input src[(time,) y,x]; Y,X must be the last two axes
                       (trailing band axes are not supported yet)
    :param src_geobox: GeoBox of the source array
    :param dst_geobox: GeoBox of the destination
    :param resampling: Resampling strategy as a string: nearest, bilinear, average, mode ...
    :param chunks    : In Y,X dimensions only, default is to use same input chunk size
    :param axis      : Index of Y axis (default is 0)
    :param src_nodata: nodata marker for source image
    :param dst_nodata: nodata marker for dst image, defaults to ``src_nodata``
    :param name      : Dask graph name, "reproject" is the default
    :raises ValueError: if the Y,X shape of ``src`` does not match ``src_geobox``
    :raises NotImplementedError: if ``src`` has axes after the X axis
    """
    # Raise real exceptions instead of using ``assert``. Asserts are stripped
    # under ``python -O``, and invalid input would then silently produce a
    # graph with mismatched chunks.
    src_yx_shape = src.shape[axis:axis + 2]
    if src_yx_shape != src_geobox.shape:
        raise ValueError(
            f"Y,X shape of src {src_yx_shape} does not match "
            f"src_geobox shape {src_geobox.shape} (axis={axis})"
        )
    if src.ndim > axis + 2:
        raise NotImplementedError(
            f"Axes after X are not supported yet (src.ndim={src.ndim}, axis={axis})"
        )

    if chunks is None:
        chunks = src.chunksize[axis:axis + 2]

    if dst_nodata is None:
        dst_nodata = src_nodata

    yx_shape = dst_geobox.shape
    yx_chunks = unpack_chunks(chunks, yx_shape)

    # Only the Y,X axes change. Leading axes (e.g. time) keep the source
    # chunking. The trailing slices are empty after the validation above.
    dst_chunks = src.chunks[:axis] + yx_chunks + src.chunks[axis + 2:]
    dst_shape = src.shape[:axis] + yx_shape + src.shape[axis + 2:]

    # Number of blocks along each leading (non-spatial) axis.
    dims1 = tuple(map(len, dst_chunks[:axis]))
    deps = [src]

    # The first Y and X chunk sizes are used as the tile size. This assumes
    # unpack_chunks yields regular chunking -- NOTE(review): confirm.
    tile_shape = (yx_chunks[0][0], yx_chunks[1][0])
    gbt = GeoboxTiles(dst_geobox, tile_shape)
    # (y, x) block indices of destination tiles overlapping the source footprint
    yx_tiles_with_data = list(gbt.tiles(src_geobox.extent))

    name = randomize(name)
    dsk: Dict[Any, Any] = {}

    # bool arrays are dispatched to a dedicated block implementation
    block_impl = (_reproject_block_bool_impl
                  if src.dtype == "bool" else _reproject_block_impl)

    for idx in yx_tiles_with_data:
        _dst_geobox = gbt[idx]
        rr = compute_reproject_roi(src_geobox, _dst_geobox)
        # Source crop for this tile; its Y,X block is addressed as (0, 0) below.
        _src = crop_2d_dense(src, rr.roi_src, axis=axis)
        _src_geobox = src_geobox[rr.roi_src]

        deps.append(_src)

        for ii1 in np.ndindex(dims1):
            # TODO: band dims
            dsk[(name, *ii1, *idx)] = (
                block_impl,
                (_src.name, *ii1, 0, 0),
                _src_geobox,
                _dst_geobox,
                resampling,
                src_nodata,
                dst_nodata,
                axis,
            )

    fill_value = 0 if dst_nodata is None else dst_nodata
    shape_in_blocks = tuple(map(len, dst_chunks))

    # empty_maker is handed dsk, so it may register shared fill tasks there.
    mk_empty = empty_maker(fill_value, src.dtype, dsk)

    # Every output block without a reprojection task becomes a constant block.
    for idx in np.ndindex(shape_in_blocks):
        # TODO: other dims
        k = (name, *idx)
        if k not in dsk:
            bshape = tuple(ch[i] for ch, i in zip(dst_chunks, idx))
            dsk[k] = mk_empty(bshape)

    # Keep the plain task dict and the HighLevelGraph under separate names
    # so the Dict annotation on dsk stays truthful.
    graph = HighLevelGraph.from_collections(name, dsk, dependencies=deps)

    return da.Array(graph,
                    name,
                    chunks=dst_chunks,
                    dtype=src.dtype,
                    shape=dst_shape)