Example #1
def rectify_dataset(source_ds: xr.Dataset,
                    *,
                    var_names: Optional[Union[str, Sequence[str]]] = None,
                    source_gm: Optional[GridMapping] = None,
                    xy_var_names: Optional[Tuple[str, str]] = None,
                    target_gm: Optional[GridMapping] = None,
                    tile_size: Optional[Union[int, Tuple[int, int]]] = None,
                    is_j_axis_up: Optional[bool] = None,
                    output_ij_names: Optional[Tuple[str, str]] = None,
                    compute_subset: bool = True,
                    uv_delta: float = 1e-3) -> Optional[xr.Dataset]:
    """
    Reproject dataset *source_ds* using its per-pixel
    x,y coordinates or the given *source_gm*.

    The function expects *source_ds* or the given
    *source_gm* to have either one- or two-dimensional
    coordinate variables that provide spatial x,y coordinates
    for every data variable with the same spatial dimensions.

    For example, a dataset may comprise variables with
    spatial dimensions ``var(..., y_dim, x_dim)``;
    the function then expects coordinates to be provided
    in one of two forms:

    1. One-dimensional ``x_var(x_dim)``
       and ``y_var(y_dim)`` (coordinate) variables.
    2. Two-dimensional ``x_var(y_dim, x_dim)``
       and ``y_var(y_dim, x_dim)`` (coordinate) variables.

    If *target_gm* is given and it defines a tile size
    or *tile_size* is given, and the number of tiles is
    greater than one in the output's x- or y-direction, then the
    returned dataset will be composed of lazy, chunked dask
    arrays. Otherwise the returned dataset will be composed
    of ordinary numpy arrays.

    :param source_ds: Source dataset.
    :param var_names: Optional variable name or sequence of
        variable names.
    :param source_gm: Optional source grid mapping. If not given,
        it is derived from *source_ds*.
    :param xy_var_names: Optional tuple of the x- and y-coordinate
        variables in *source_ds*. Ignored if *source_gm* is given.
    :param target_gm: Optional target grid mapping. If not given,
        it will be computed to spatially fit *source_ds*
        and to retain its spatial resolution.
    :param tile_size: Optional tile size for the output.
    :param is_j_axis_up: Whether y coordinates are increasing with
        positive image j axis.
    :param output_ij_names: If given, a tuple of variable names in
        which to store the computed source pixel coordinates in
        the returned output.
    :param compute_subset: Whether to compute a spatial subset
        of *source_ds* using *target_gm*. If set, the function
        may return ``None`` in case there is no overlap.
    :param uv_delta: A normalized value that is used to determine
        whether x,y coordinates in the output are contained
        in the triangles defined by the input x,y coordinates.
        The higher this value, the more inaccurate the rectification
        will be.
    :return: A reprojected dataset, or ``None`` if the requested
        target does not intersect with *source_ds*.
    """
    if source_gm is None:
        source_gm = GridMapping.from_dataset(source_ds,
                                             xy_var_names=xy_var_names)

    src_attrs = dict(source_ds.attrs)

    if target_gm is None:
        target_gm = source_gm.to_regular(tile_size=tile_size)
    elif compute_subset:
        source_ds_subset = select_spatial_subset(
            source_ds,
            xy_bbox=target_gm.xy_bbox,
            ij_border=1,
            xy_border=0.5 * (target_gm.x_res + target_gm.y_res),
            grid_mapping=source_gm)
        if source_ds_subset is None:
            return None
        if source_ds_subset is not source_ds:
            # TODO: GridMapping.from_dataset() may be expensive.
            #   Find a more effective way.
            source_gm = GridMapping.from_dataset(source_ds_subset)
            source_ds = source_ds_subset

    # if src_geo_coding.xy_var_names != output_geom.xy_var_names:
    #     output_geom = output_geom.derive(
    #           xy_var_names=src_geo_coding.xy_var_names
    #     )
    # if src_geo_coding.xy_dim_names != output_geom.xy_dim_names:
    #     output_geom = output_geom.derive(
    #           xy_dim_names=src_geo_coding.xy_dim_names
    #     )

    if tile_size is not None or is_j_axis_up is not None:
        target_gm = target_gm.derive(tile_size=tile_size,
                                     is_j_axis_up=is_j_axis_up)

    src_vars = _select_variables(source_ds, source_gm, var_names)

    if target_gm.is_tiled:
        compute_dst_src_ij_images = _compute_ij_images_xarray_dask
        compute_dst_var_image = _compute_var_image_xarray_dask
    else:
        compute_dst_src_ij_images = _compute_ij_images_xarray_numpy
        compute_dst_var_image = _compute_var_image_xarray_numpy

    dst_src_ij_array = compute_dst_src_ij_images(source_gm, target_gm,
                                                 uv_delta)

    dst_x_dim, dst_y_dim = target_gm.xy_dim_names
    dst_dims = dst_y_dim, dst_x_dim
    dst_ds_coords = target_gm.to_coords()
    dst_vars = dict()
    for src_var_name, src_var in src_vars.items():
        dst_var_dims = src_var.dims[0:-2] + dst_dims
        dst_var_coords = {
            d: src_var.coords[d]
            for d in dst_var_dims if d in src_var.coords
        }
        dst_var_coords.update(
            {d: dst_ds_coords[d]
             for d in dst_var_dims if d in dst_ds_coords})
        dst_var_array = compute_dst_var_image(src_var,
                                              dst_src_ij_array,
                                              fill_value=np.nan)
        dst_var = xr.DataArray(dst_var_array,
                               dims=dst_var_dims,
                               coords=dst_var_coords,
                               attrs=src_var.attrs)
        dst_vars[src_var_name] = dst_var

    if output_ij_names:
        output_i_name, output_j_name = output_ij_names
        dst_ij_coords = {
            d: dst_ds_coords[d]
            for d in dst_dims if d in dst_ds_coords
        }
        dst_vars[output_i_name] = xr.DataArray(dst_src_ij_array[0],
                                               dims=dst_dims,
                                               coords=dst_ij_coords)
        dst_vars[output_j_name] = xr.DataArray(dst_src_ij_array[1],
                                               dims=dst_dims,
                                               coords=dst_ij_coords)

    return xr.Dataset(dst_vars, coords=dst_ds_coords, attrs=src_attrs)
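
# --- Usage sketch (not from the original source) --------------------
# A minimal illustration of the call pattern documented above. The
# file name, variable names, and tile size are hypothetical; it is
# assumed that rectify_dataset and GridMapping are importable from
# the surrounding module.
import xarray as xr

# Hypothetical swath scene with 2-D "lon"/"lat" coordinate variables.
source_ds = xr.open_dataset("swath_scene.nc")

# Derive the source grid mapping explicitly; rectify_dataset() would
# do the same internally if source_gm were omitted.
source_gm = GridMapping.from_dataset(source_ds,
                                     xy_var_names=("lon", "lat"))

# Rectify onto a regular target grid derived from the source.
# Passing a tile size makes the target grid mapping tiled, so the
# result is composed of lazy, chunked dask arrays.
target_ds = rectify_dataset(source_ds,
                            var_names=["rad_1", "rad_2"],
                            source_gm=source_gm,
                            tile_size=512,
                            output_ij_names=("src_i", "src_j"))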
Example #2
def rechunk_cube(cube: xr.Dataset,
                 gm: GridMapping,
                 chunks: Optional[Dict[str, int]] = None,
                 tile_size: Optional[Tuple[int, int]] = None) \
        -> Tuple[xr.Dataset, GridMapping]:
    """
    Re-chunk data variables of *cube* so they all share the same chunk
    sizes for their dimensions.

    This function re-chunks *cube* for maximum compatibility with
    the Zarr format. To that end, it removes the "chunks" encoding
    from all variables.

    :param cube: A data cube
    :param gm: The cube's grid mapping
    :param chunks: Optional mapping of dimension names to chunk sizes
    :param tile_size: Optional tile size, i.e. the chunk sizes of the
        spatial dimensions, given as (width, height)
    :return: A potentially rechunked *cube* and adjusted grid mapping.
    """

    # get initial, common cube chunk sizes from given cube
    cube_chunks = get_dataset_chunks(cube)

    # Given chunks will overwrite initial values
    if chunks:
        for dim_name, size in chunks.items():
            cube_chunks[dim_name] = size

    # Given tile size will overwrite spatial dims
    x_dim_name, y_dim_name = gm.xy_dim_names
    if tile_size is not None:
        cube_chunks[x_dim_name] = tile_size[0]
        cube_chunks[y_dim_name] = tile_size[1]

    # The grid mapping's tile size is used for the spatial dims
    # only where they are still missing
    if gm.tile_size is not None:
        if x_dim_name not in cube_chunks:
            cube_chunks[x_dim_name] = gm.tile_size[0]
        if y_dim_name not in cube_chunks:
            cube_chunks[y_dim_name] = gm.tile_size[1]

    # If no chunking is required, return the inputs unchanged
    if not cube_chunks:
        return cube, gm

    chunked_cube = xr.Dataset(attrs=cube.attrs)

    # Coordinate variables are always
    # chunked automatically
    chunked_cube = chunked_cube.assign_coords(
        coords={
            var_name: var.chunk({dim_name: 'auto'
                                 for dim_name in var.dims})
            for var_name, var in cube.coords.items()
        })

    # Data variables are chunked according to cube_chunks,
    # or if not specified, by the dimension size.
    chunked_cube = chunked_cube.assign(
        variables={
            var_name: var.chunk({
                dim_name: cube_chunks.get(dim_name, cube.dims[dim_name])
                for dim_name in var.dims
            })
            for var_name, var in cube.data_vars.items()
        })

    # Update chunks encoding for Zarr
    for var_name, var in chunked_cube.variables.items():
        if 'chunks' in var.encoding:
            del var.encoding['chunks']
        # if var.chunks is not None:
        #     # sizes[0] is the first of
        #     # e.g. sizes = (512, 512, 71)
        #     var.encoding.update(chunks=[
        #         sizes[0] for sizes in var.chunks
        #     ])
        # elif 'chunks' in var.encoding:
        #     del var.encoding['chunks']
        # print(f"--> {var_name}: encoding={var.encoding.get('chunks')!r}, chunks={var.chunks!r}")

    # Test whether tile size has changed after re-chunking.
    # If so, we will change the grid mapping too.
    tile_width = cube_chunks.get(x_dim_name)
    tile_height = cube_chunks.get(y_dim_name)
    assert tile_width is not None
    assert tile_height is not None
    tile_size = (tile_width, tile_height)
    if tile_size != gm.tile_size:
        # Note, changing grid mapping tile size may
        # rechunk (2D) coordinates in chunked_cube too
        gm = gm.derive(tile_size=tile_size)

    return chunked_cube, gm
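
# --- Usage sketch (not from the original source) --------------------
# The store path, dimension names, and chunk sizes are hypothetical;
# it is assumed that rechunk_cube, GridMapping, and xarray (xr) are
# available as in the listing above.
cube = xr.open_zarr("cube.zarr")
gm = GridMapping.from_dataset(cube)

# Re-chunk so that all data variables share one chunking scheme:
# one time step per chunk and 1024 x 1024 spatial tiles.
chunked_cube, chunked_gm = rechunk_cube(cube,
                                        gm,
                                        chunks={"time": 1},
                                        tile_size=(1024, 1024))

# Because the "chunks" encoding was removed, Zarr derives the on-disk
# chunking from the dask chunks when the cube is written.
chunked_cube.to_zarr("cube_rechunked.zarr")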
Example #3
    def process(self,
                dataset: xr.Dataset,
                geo_coding: GridMapping,
                output_geom: GridMapping,
                output_resampling: str,
                include_non_spatial_vars=False) -> xr.Dataset:
        """
        Perform reprojection using tie-points / ground control points.
        """
        reprojection_info = self.get_reprojection_info(dataset)

        warn_prefix = 'unsupported argument in np-GCP rectification mode'
        if reprojection_info.xy_crs is not None:
            warnings.warn(
                f'{warn_prefix}: ignoring '
                f'reprojection_info.xy_crs = {reprojection_info.xy_crs!r}')
        if reprojection_info.xy_tp_names is not None:
            warnings.warn(
                f'{warn_prefix}: ignoring '
                f'reprojection_info.xy_tp_names = {reprojection_info.xy_tp_names!r}'
            )
        if reprojection_info.xy_gcp_step is not None:
            warnings.warn(
                f'{warn_prefix}: ignoring '
                f'reprojection_info.xy_gcp_step = {reprojection_info.xy_gcp_step!r}'
            )
        if reprojection_info.xy_tp_gcp_step is not None:
            warnings.warn(
                f'{warn_prefix}: ignoring '
                f'reprojection_info.xy_tp_gcp_step = {reprojection_info.xy_tp_gcp_step!r}'
            )
        if output_resampling != 'Nearest':
            warnings.warn(f'{warn_prefix}: ignoring '
                          f'dst_resampling = {output_resampling!r}')
        if include_non_spatial_vars:
            warnings.warn(
                f'{warn_prefix}: ignoring '
                f'include_non_spatial_vars = {include_non_spatial_vars!r}')

        geo_coding = geo_coding.derive(
            xy_var_names=(reprojection_info.xy_names[0],
                          reprojection_info.xy_names[1]))

        dataset = rectify_dataset(dataset,
                                  compute_subset=False,
                                  source_gm=geo_coding,
                                  target_gm=output_geom)
        if output_geom.is_tiled:
            # The following condition may become true if
            # rectify_dataset(input, ..., is_y_reverse=True) was used.
            # In that case the y-chunk sizes are reversed too, so the
            # first chunk is smaller than the others. Zarr rejects
            # such datasets when they are written.
            lat_chunks = dataset.chunks.get('lat')
            if lat_chunks is not None and lat_chunks[0] < lat_chunks[-1]:
                dataset = dataset.chunk({
                    'lat': output_geom.tile_height,
                    'lon': output_geom.tile_width
                })
        if dataset is not None \
                and geo_coding.crs.is_geographic \
                and geo_coding.xy_var_names != ('lon', 'lat'):
            dataset = dataset.rename({
                geo_coding.xy_var_names[0]: 'lon',
                geo_coding.xy_var_names[1]: 'lat'
            })

        return dataset
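
# --- Why the "lat" chunk-order guard above is needed (sketch) --------
# Not from the original source; a minimal, self-contained illustration
# using dask and xarray only. With 1037 rows and a tile height of 512,
# normal chunking along "lat" is (512, 512, 13). Flipping the y-axis
# reverses that to (13, 512, 512), which Zarr rejects on write because
# only the last chunk may be smaller; re-chunking restores the layout.
import dask.array as da
import xarray as xr

arr = xr.DataArray(da.zeros((1037, 2048), chunks=(512, 512)),
                   dims=("lat", "lon"))
flipped = arr.isel(lat=slice(None, None, -1))

print(arr.chunks[0])      # (512, 512, 13)
print(flipped.chunks[0])  # (13, 512, 512)

fixed = flipped.chunk({"lat": 512})
print(fixed.chunks[0])    # (512, 512, 13)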