Example #1
0
def _get_crs_from_attrs(obj):
    """ Find the value of an attribute named `crs` on an xarray object.

        Lookup order:
        1. attrs of the spatial coordinates
        2. attrs of the data variable
        3. attrs of the object that was passed in

        For a Dataset the first data variable is the one inspected; a Dataset
        with no data variables is answered from its own attrs only.

        Returns
        =======
        Whatever is stored under `.attrs['crs']` (usually a string)
        None if no `crs` attribute is found in any of the places above

        Raises
        ======
        ValueError if the spatial coordinates do not all carry the same
        `crs` attribute value (a missing attribute counts as the value None)
    """
    if not isinstance(obj, xarray.Dataset):
        data_array = obj
    elif len(obj.data_vars) > 0:
        data_array = next(iter(obj.data_vars.values()))
    else:
        # nothing to inspect except the dataset-level attributes
        return obj.attrs.get('crs', None)

    crs = None
    sdims = spatial_dims(data_array, relaxed=True)
    if sdims is not None:
        candidates = {data_array[dim].attrs.get('crs', None) for dim in sdims}
        if len(candidates) > 1:
            raise ValueError('Spatial dimensions have different crs.')
        if candidates:
            crs = candidates.pop()

    if crs is not None:
        return crs

    # coordinates gave no answer: try the data variable, then the container
    return data_array.attrs.get('crs', None) or obj.attrs.get('crs', None)
def _xarray_affine_impl(obj):
    """ Build an affine transform from the spatial coordinates of `obj`.

        Returns
        =======
        (None, None) when `spatial_dims` reports no spatial dimensions
        (transform, sdims) otherwise, where `transform` is the result of
        `affine_from_axis` and `sdims` is the pair of dimension names
    """
    dims = spatial_dims(obj, relaxed=True)
    if dims is None:
        return None, None

    # spatial dims are unpacked as (y, x)
    y_coord, x_coord = (obj[name] for name in dims)

    # per-axis `resolution` attrs in (x, y) order, handed over as a lazy iterable
    resolutions = (axis.attrs.get('resolution', None)
                   for axis in (x_coord, y_coord))

    transform = affine_from_axis(x_coord.values, y_coord.values, resolutions)
    return transform, dims
Example #3
0
def xr_reproject_array(
    src: xr.DataArray,
    geobox: GeoBox,
    resampling: str = "nearest",
    chunks: Optional[Tuple[int, int]] = None,
    dst_nodata: Optional[NodataType] = None,
) -> xr.DataArray:
    """
    Reproject a DataArray onto the pixel grid described by a GeoBox

    The two spatial dimensions of ``src`` are replaced by ``geobox.dims``;
    every other dimension and its coordinate is carried over unchanged.
    Dask-backed inputs go through ``dask_reproject``, anything else through
    ``_reproject_block_impl``. The result keeps the name of ``src`` and
    carries a single ``nodata`` attribute when a destination nodata value
    is known.

    :param src       : Input src[(time,) y,x (, band)]
    :param geobox    : GeoBox of the destination
    :param resampling: Resampling strategy as a string: nearest, bilinear, average, mode ...
    :param chunks    : In Y,X dimensions only, default is to use input chunk size
    :param dst_nodata: nodata marker for dst image (default is to use src.nodata)
    """
    src_nodata = getattr(src, "nodata", None)
    dst_nodata = src_nodata if dst_nodata is None else dst_nodata

    src_geobox = src.geobox
    assert src_geobox is not None

    # position of the first spatial dimension; the second is taken to follow it
    src_dims = tuple(src.dims)
    axis = src_dims.index(spatial_dims(src)[0])
    leading_dims = src_dims[:axis]
    trailing_dims = src_dims[axis + 2:]
    dst_dims = leading_dims + geobox.dims + trailing_dims

    # spatial coords come from the destination, the rest are copied from src
    coords = geobox.xr_coords(with_crs=True)
    for dim in leading_dims + trailing_dims:
        if dim not in coords:
            coords[dim] = src.coords[dim]

    attrs = {} if dst_nodata is None else {"nodata": dst_nodata}

    # keyword arguments understood by both reprojection back ends
    shared_kwargs = {
        "resampling": resampling,
        "src_nodata": src_nodata,
        "dst_nodata": dst_nodata,
        "axis": axis,
    }

    if is_dask_collection(src):
        data = dask_reproject(src.data, src_geobox, geobox,
                              chunks=chunks, **shared_kwargs)
    else:
        data = _reproject_block_impl(src.data, src_geobox, geobox,
                                     **shared_kwargs)

    return xr.DataArray(data,
                        name=src.name,
                        coords=coords,
                        dims=dst_dims,
                        attrs=attrs)
Example #4
0
def rgb(ds,
        bands=['nbart_red', 'nbart_green', 'nbart_blue'],
        index=None,
        index_dim='time',
        robust=True,
        percentile_stretch=None,
        col_wrap=4,
        size=6,
        aspect=None,
        savefig_path=None,
        savefig_kwargs={},
        **kwargs):
    """
    Takes an xarray dataset and plots RGB images using three imagery 
    bands (e.g ['nbart_red', 'nbart_green', 'nbart_blue']). The `index` 
    parameter allows easily selecting individual or multiple images for 
    RGB plotting. Images can be saved to file by specifying an output 
    path using `savefig_path`.
    
    This function was designed to work as an easier-to-use wrapper 
    around xarray's `.plot.imshow()` functionality.
    
    Parameters
    ----------  
    ds : xarray Dataset
        A two-dimensional or multi-dimensional array to plot as an RGB 
        image. If the array has more than two dimensions (e.g. multiple 
        observations along a 'time' dimension), either use `index` to 
        select one (`index=0`) or multiple observations 
        (`index=[0, 1]`), or create a custom faceted plot using e.g. 
        `col="time"`.       
    bands : list of strings, optional
        A list of three strings giving the band names to plot. Defaults 
        to '['nbart_red', 'nbart_green', 'nbart_blue']'.
    index : integer or list of integers, optional
        `index` can be used to select one (`index=0`) or multiple 
        observations (`index=[0, 1]`) from the input dataset for 
        plotting. If multiple images are requested these will be plotted
        as a faceted plot.
    index_dim : string, optional
        The dimension along which observations should be plotted if 
        multiple observations are requested using `index`. Defaults to 
        `time`.
    robust : bool, optional
        Produces an enhanced image where the colormap range is computed 
        with 2nd and 98th percentiles instead of the extreme values. 
        Defaults to True.
    percentile_stretch : tuple of floats
        A tuple of two floats (between 0.00 and 1.00) that can be used 
        to clip the colormap range to manually specified percentiles to 
        get more control over the brightness and contrast of the image. 
        The default is None; '(0.02, 0.98)' is equivalent to 
        `robust=True`. If this parameter is used, `robust` will have no 
        effect.
    col_wrap : integer, optional
        The number of columns allowed in faceted plots. Defaults to 4.
    size : integer, optional
        The height (in inches) of each plot. Defaults to 6.
    aspect : integer, optional
        Aspect ratio of each facet in the plot, so that aspect * size 
        gives width of each facet in inches. Defaults to None, which 
        will calculate the aspect based on the x and y dimensions of 
        the input data.
    savefig_path : string, optional
        Path to export image file for the RGB plot. Defaults to None, 
        which does not export an image file.
    savefig_kwargs : dict, optional
        A dict of keyword arguments to pass to 
        `matplotlib.pyplot.savefig` when exporting an image file. For 
        all available options, see: 
        https://matplotlib.org/api/_as_gen/matplotlib.pyplot.savefig.html        
    **kwargs : optional
        Additional keyword arguments to pass to `xarray.plot.imshow()`.
        For example, the function can be used to plot into an existing
        matplotlib axes object by passing an `ax` keyword argument.
        For more options, see:
        http://xarray.pydata.org/en/stable/generated/xarray.plot.imshow.html  
        
    Returns
    -------
    None. The RGB plot of one or multiple observations is drawn as a 
    side effect, and optionally an image file is written to 
    `savefig_path`.

    Raises
    ------
    TypeError
        If `index` is supplied as a float.
    ValueError
        If both `index` and `col` are supplied, or if `ds` has more than
        two dimensions with several observations along `index_dim` and
        neither `index` nor `col` was given.
    
    """

    # NOTE: the list/dict defaults in the signature are only ever read,
    # never mutated, so sharing them between calls is harmless here.

    # Get names of x and y dims
    # TODO: remove geobox and try/except once datacube 1.8 is default
    try:
        y_dim, x_dim = ds.geobox.dimensions
    except AttributeError:
        from datacube.utils import spatial_dims
        y_dim, x_dim = spatial_dims(ds)

    # If ax is supplied via kwargs, ignore aspect and size
    if 'ax' in kwargs:

        # Create empty aspect size kwarg that will be passed to imshow
        aspect_size_kwarg = {}
    else:
        # Compute image aspect
        if not aspect:
            aspect = image_aspect(ds)

        # Populate aspect size kwarg with aspect and size data
        aspect_size_kwarg = {'aspect': aspect, 'size': size}

    # If no value is supplied for `index` (the default), plot using default
    # values and arguments passed via `**kwargs`
    if index is None:

        # Select bands and convert to DataArray
        da = ds[bands].to_array().compute()

        # If percentile_stretch is supplied, clip plotting to the requested
        # percentiles by passing explicit vmin/vmax
        if percentile_stretch:
            vmin, vmax = da.quantile(percentile_stretch).values
            kwargs.update({'vmin': vmin, 'vmax': vmax})

        # If there are more than two dimensions and the index dimension == 1,
        # squeeze this dimension out to remove it
        if ((len(ds.dims) > 2) and ('col' not in kwargs)
                and (len(da[index_dim]) == 1)):

            da = da.squeeze(dim=index_dim)

        # If there are more than two dimensions and the index dimension
        # is longer than 1, raise exception to tell user to use 'col'/`index`
        elif ((len(ds.dims) > 2) and ('col' not in kwargs)
              and (len(da[index_dim]) > 1)):

            raise ValueError(
                f'The input dataset `ds` has more than two dimensions: '
                f'{list(ds.dims)}. Please select a single observation '
                'using e.g. `index=0`, or enable faceted plotting by adding '
                'the arguments e.g. `col="time", col_wrap=4` to the function '
                'call')

        img = da.plot.imshow(x=x_dim,
                             y=y_dim,
                             robust=robust,
                             col_wrap=col_wrap,
                             **aspect_size_kwarg,
                             **kwargs)

    # If values provided for `index`, extract corresponding observations and
    # plot as either single image or facet plot
    else:

        # If a float is supplied instead of an integer index, raise exception
        if isinstance(index, float):
            raise TypeError(
                'Please supply `index` as either an integer or a list of '
                'integers')

        # If col argument is supplied as well as `index`, raise exception
        if 'col' in kwargs:
            raise ValueError(
                'Cannot supply both `index` and `col`; please remove one and '
                'try again')

        # Convert index to generic type list so that number of indices supplied
        # can be computed
        index = index if isinstance(index, list) else [index]

        # Select bands and observations and convert to DataArray
        da = ds[bands].isel(**{index_dim: index}).to_array().compute()

        # If percentile_stretch is supplied, clip plotting to the requested
        # percentiles by passing explicit vmin/vmax
        if percentile_stretch:
            vmin, vmax = da.quantile(percentile_stretch).values
            kwargs.update({'vmin': vmin, 'vmax': vmax})

        # If multiple index values are supplied, plot as a faceted plot
        if len(index) > 1:

            img = da.plot.imshow(x=x_dim,
                                 y=y_dim,
                                 robust=robust,
                                 col=index_dim,
                                 col_wrap=col_wrap,
                                 **aspect_size_kwarg,
                                 **kwargs)

        # If only one index is supplied, squeeze out index_dim and plot as a
        # single panel
        else:

            img = da.squeeze(dim=index_dim).plot.imshow(robust=robust,
                                                        **aspect_size_kwarg,
                                                        **kwargs)

    # If an export path is provided, save image to file. Individual and
    # faceted plots expose their figure under different attribute names
    # (`figure` vs `fig`). Only the attribute lookup is guarded, so a genuine
    # failure while saving (bad path, unsupported format, ...) is reported
    # as-is instead of being swallowed and replaced by an unrelated error.
    if savefig_path:

        print(f'Exporting image to {savefig_path}')

        try:
            figure = img.fig
        except AttributeError:
            figure = img.figure

        figure.savefig(savefig_path, **savefig_kwargs)