def _get_crs_from_attrs(obj):
    """
    Look for an attribute named `crs` containing a CRS string.

    Search order:

    1. attrs of the spatial coordinates
    2. attrs of the data variable (first data variable for a Dataset)
    3. attrs of `obj` itself

    Returns
    =======
    Content of `.attrs['crs']`, usually a string.
    None if not present in any of the places listed above.

    Raises
    ======
    ValueError
        If the spatial coordinates carry different, non-missing `crs` values.
    """
    if isinstance(obj, xarray.Dataset):
        if len(obj.data_vars) == 0:
            # No data variables to inspect: fall back to dataset-level attrs
            return obj.attrs.get('crs', None)
        data_array = next(iter(obj.data_vars.values()))
    else:
        data_array = obj

    crs = None
    sdims = spatial_dims(data_array, relaxed=True)
    if sdims is not None:
        crs_set = {data_array[d].attrs.get('crs', None) for d in sdims}
        # A coordinate without a `crs` attr is "unknown", not a conflicting
        # value: only compare the crs values that are actually present.
        crs_set.discard(None)
        if len(crs_set) > 1:
            raise ValueError('Spatial dimensions have different crs.')
        if crs_set:
            crs = crs_set.pop()

    if crs is None:
        # fall back option: data variable attrs first, then obj attrs
        crs = data_array.attrs.get('crs', None) or obj.attrs.get('crs', None)
    return crs
def _xarray_affine_impl(obj):
    """
    Compute the pixel-to-world affine transform of `obj` from its spatial
    coordinate values.

    Returns
    =======
    `(affine, sdims)` where `affine` is the result of `affine_from_axis` and
    `sdims` is the pair of spatial dimension names (unpacked below as y, x),
    or `(None, None)` when no spatial dimensions are found.
    """
    sdims = spatial_dims(obj, relaxed=True)
    if sdims is None:
        return None, None

    yy, xx = (obj[dim] for dim in sdims)
    # Per-axis `resolution` attrs, ordered (x, y) to match the axis arguments.
    # Presumably used by affine_from_axis as a fallback when the step cannot
    # be derived from the coordinate values (TODO confirm against its docs).
    # A concrete tuple rather than a one-shot generator, so the callee may
    # safely test, index or iterate it more than once.
    fallback_res = (xx.attrs.get('resolution', None),
                    yy.attrs.get('resolution', None))
    return affine_from_axis(xx.values, yy.values, fallback_res), sdims
def xr_reproject_array(
    src: xr.DataArray,
    geobox: GeoBox,
    resampling: str = "nearest",
    chunks: Optional[Tuple[int, int]] = None,
    dst_nodata: Optional[NodataType] = None,
) -> xr.DataArray:
    """
    Reproject DataArray to a given GeoBox.

    :param src       : Input src[(time,) y,x (, band)]; must have a geobox
    :param geobox    : GeoBox of the destination
    :param resampling: Resampling strategy as a string: nearest, bilinear, average, mode ...
    :param chunks    : In Y,X dimensions only, default is to use input chunk size
    :param dst_nodata: nodata marker for dst image (default is to use src.nodata)
    :raises ValueError: if ``src`` has no geobox
    :return: New DataArray named like ``src``. Its dims are those of ``src``
             with the spatial pair replaced by ``geobox.dims``. Its coords come
             from ``geobox`` plus the non-spatial coords of ``src``. Its attrs
             contain only ``nodata`` (when known). Dask input gives dask output.
    """
    src_nodata = getattr(src, "nodata", None)
    if dst_nodata is None:
        dst_nodata = src_nodata

    src_geobox = src.geobox
    # Explicit check rather than `assert`: asserts are stripped under `-O`,
    # which would defer the failure to an obscure error further down.
    if src_geobox is None:
        raise ValueError("Source array has no geobox, cannot reproject")

    yx_dims = spatial_dims(src)
    src_dims = tuple(src.dims)
    # src_dims[axis:axis + 2] is treated as the (y, x) pair, i.e. the spatial
    # dimensions are assumed to be adjacent in src.
    axis = src_dims.index(yx_dims[0])
    dst_dims = src_dims[:axis] + geobox.dims + src_dims[axis + 2:]

    coords = geobox.xr_coords(with_crs=True)

    # copy non-spatial coords from src to dst
    src_non_spatial_dims = src_dims[:axis] + src_dims[axis + 2:]
    for dim in src_non_spatial_dims:
        if dim not in coords:
            coords[dim] = src.coords[dim]

    attrs = {}
    if dst_nodata is not None:
        attrs["nodata"] = dst_nodata

    if is_dask_collection(src):
        data = dask_reproject(
            src.data,
            src_geobox,
            geobox,
            resampling=resampling,
            chunks=chunks,
            src_nodata=src_nodata,
            dst_nodata=dst_nodata,
            axis=axis,
        )
    else:
        data = _reproject_block_impl(
            src.data,
            src_geobox,
            geobox,
            resampling=resampling,
            src_nodata=src_nodata,
            dst_nodata=dst_nodata,
            axis=axis,
        )

    return xr.DataArray(data, name=src.name, coords=coords, dims=dst_dims, attrs=attrs)
def rgb(ds,
        bands=None,
        index=None,
        index_dim='time',
        robust=True,
        percentile_stretch=None,
        col_wrap=4,
        size=6,
        aspect=None,
        savefig_path=None,
        savefig_kwargs=None,
        **kwargs):
    """
    Takes an xarray dataset and plots RGB images using three imagery
    bands (e.g ['nbart_red', 'nbart_green', 'nbart_blue']). The `index`
    parameter allows easily selecting individual or multiple images for
    RGB plotting. Images can be saved to file by specifying an output
    path using `savefig_path`.

    This function was designed to work as an easier-to-use wrapper
    around xarray's `.plot.imshow()` functionality.

    Last modified: September 2020

    Parameters
    ----------
    ds : xarray Dataset
        A two-dimensional or multi-dimensional array to plot as an RGB
        image. If the array has more than two dimensions (e.g. multiple
        observations along a 'time' dimension), either use `index` to
        select one (`index=0`) or multiple observations (`index=[0, 1]`),
        or create a custom faceted plot using e.g. `col="time"`.
    bands : list of strings, optional
        A list of three strings giving the band names to plot. Defaults
        to ['nbart_red', 'nbart_green', 'nbart_blue'].
    index : integer or list of integers, optional
        `index` can be used to select one (`index=0`) or multiple
        observations (`index=[0, 1]`) from the input dataset for plotting.
        If multiple images are requested these will be plotted as a
        faceted plot.
    index_dim : string, optional
        The dimension along which observations should be plotted if
        multiple observations are requested using `index`. Defaults to
        `time`.
    robust : bool, optional
        Produces an enhanced image where the colormap range is computed
        with 2nd and 98th percentiles instead of the extreme values.
        Defaults to True.
    percentile_stretch : tuple of floats, optional
        A tuple of two floats (between 0.00 and 1.00) that can be used
        to clip the colormap range to manually specified percentiles to
        get more control over the brightness and contrast of the image.
        The default is None; '(0.02, 0.98)' is equivalent to
        `robust=True`. If this parameter is used, `robust` will have no
        effect.
    col_wrap : integer, optional
        The number of columns allowed in faceted plots. Defaults to 4.
    size : integer, optional
        The height (in inches) of each plot. Defaults to 6.
    aspect : integer, optional
        Aspect ratio of each facet in the plot, so that aspect * size
        gives width of each facet in inches. Defaults to None, which
        will calculate the aspect based on the x and y dimensions of
        the input data.
    savefig_path : string, optional
        Path to export image file for the RGB plot. Defaults to None,
        which does not export an image file.
    savefig_kwargs : dict, optional
        A dict of keyword arguments to pass to
        `matplotlib.pyplot.savefig` when exporting an image file.
        Defaults to None (no extra arguments). For all available
        options, see:
        https://matplotlib.org/api/_as_gen/matplotlib.pyplot.savefig.html
    **kwargs : optional
        Additional keyword arguments to pass to `xarray.plot.imshow()`.
        For example, the function can be used to plot into an existing
        matplotlib axes object by passing an `ax` keyword argument.
        For more options, see:
        http://xarray.pydata.org/en/stable/generated/xarray.plot.imshow.html

    Returns
    -------
    None. One or multiple observations are drawn as an RGB plot, and
    optionally written to an image file.

    Raises
    ------
    TypeError
        If `index` is a float.
    ValueError
        If both `index` and `col` are supplied, or if `ds` has more than
        two dimensions with several observations along `index_dim` while
        neither `index` nor `col` is given.
    """
    # Resolve defaults here rather than in the signature so that a single
    # mutable default object is never shared between calls.
    if bands is None:
        bands = ['nbart_red', 'nbart_green', 'nbart_blue']
    if savefig_kwargs is None:
        savefig_kwargs = {}

    # Get names of x and y dims
    # TODO: remove geobox and try/except once datacube 1.8 is default
    try:
        y_dim, x_dim = ds.geobox.dimensions
    except AttributeError:
        from datacube.utils import spatial_dims
        y_dim, x_dim = spatial_dims(ds)

    # If ax is supplied via kwargs, ignore aspect and size
    if 'ax' in kwargs:
        # Empty aspect/size kwargs: the supplied axes determine the layout
        aspect_size_kwarg = {}
    else:
        # Compute image aspect from the data when not supplied
        if not aspect:
            aspect = image_aspect(ds)
        aspect_size_kwarg = {'aspect': aspect, 'size': size}

    # If no value is supplied for `index` (the default), plot using default
    # values and arguments passed via `**kwargs`
    if index is None:

        # Select bands and convert to DataArray
        da = ds[bands].to_array().compute()

        # If percentile_stretch is supplied, clip plotting to percentile vmin, vmax
        if percentile_stretch:
            vmin, vmax = da.quantile(percentile_stretch).values
            kwargs.update({'vmin': vmin, 'vmax': vmax})

        # More than two dimensions and no faceting requested: the extra
        # `index_dim` must be collapsed to a single observation.
        needs_single_obs = (len(ds.dims) > 2) and ('col' not in kwargs)

        if needs_single_obs and (len(da[index_dim]) == 1):
            # Only one observation: squeeze the index dimension out
            da = da.squeeze(dim=index_dim)

        elif needs_single_obs and (len(da[index_dim]) > 1):
            # Several observations: tell the user to use `col` or `index`
            raise ValueError(
                f'The input dataset `ds` has more than two dimensions: '
                f'{list(ds.dims.keys())}. Please select a single observation '
                'using e.g. `index=0`, or enable faceted plotting by adding '
                'the arguments e.g. `col="time", col_wrap=4` to the function '
                'call')

        img = da.plot.imshow(x=x_dim,
                             y=y_dim,
                             robust=robust,
                             col_wrap=col_wrap,
                             **aspect_size_kwarg,
                             **kwargs)

    # If values provided for `index`, extract corresponding observations and
    # plot as either single image or facet plot
    else:

        # If a float is supplied instead of an integer index, raise exception
        if isinstance(index, float):
            raise TypeError(
                'Please supply `index` as either an integer or a list of '
                'integers')

        # If col argument is supplied as well as `index`, raise exception
        if 'col' in kwargs:
            raise ValueError(
                'Cannot supply both `index` and `col`; please remove one and '
                'try again')

        # Convert index to generic type list so that number of indices supplied
        # can be computed
        index = index if isinstance(index, list) else [index]

        # Select bands and observations and convert to DataArray
        da = ds[bands].isel(**{index_dim: index}).to_array().compute()

        # If percentile_stretch is supplied, clip plotting to percentile vmin, vmax
        if percentile_stretch:
            vmin, vmax = da.quantile(percentile_stretch).values
            kwargs.update({'vmin': vmin, 'vmax': vmax})

        # If multiple index values are supplied, plot as a faceted plot
        if len(index) > 1:

            img = da.plot.imshow(x=x_dim,
                                 y=y_dim,
                                 robust=robust,
                                 col=index_dim,
                                 col_wrap=col_wrap,
                                 **aspect_size_kwarg,
                                 **kwargs)

        # If only one index is supplied, squeeze out index_dim and plot as a
        # single panel. x/y are passed explicitly (as in the other branches)
        # so xarray does not have to guess which dims are spatial.
        else:

            img = da.squeeze(dim=index_dim).plot.imshow(x=x_dim,
                                                        y=y_dim,
                                                        robust=robust,
                                                        **aspect_size_kwarg,
                                                        **kwargs)

    # If an export path is provided, save image to file. Individual and
    # faceted plots have a different API (figure vs fig); only the attribute
    # lookup is guarded so that genuine savefig errors still propagate.
    if savefig_path:
        print(f'Exporting image to {savefig_path}')
        try:
            figure = img.fig
        except AttributeError:
            figure = img.figure
        figure.savefig(savefig_path, **savefig_kwargs)