def save_data(
    self,
    subfld: str,
    predict: xr.DataArray,
    probabilities: xr.DataArray,
    geobox_used: GeoBox,
):
    """
    Save the prediction results to the local folder and prepare the STAC JSON.

    :param subfld: local subfolder to save the prediction tifs
    :param predict: predicted binary class label array
    :param probabilities: prediction probabilities array
    :param geobox_used: geobox used for the features for prediction
    :return: None
    """
    output_fld, paths, metadata_path = prepare_the_io_path(self.config, subfld)

    x, y = get_xy_from_task(subfld)

    if not osp.exists(output_fld):
        os.makedirs(output_fld)

    self._log.info("collecting mask and write cog.")
    write_cog(
        predict.astype(np.uint8).compute(),
        paths["mask"],
        overwrite=True,
    )

    self._log.info("collecting prob and write cog.")
    write_cog(
        probabilities.astype(np.uint8).compute(),
        paths["prob"],
        overwrite=True,
    )

    self._log.info("collecting the stac json and write out.")
    processing_dt = datetime.datetime.now()

    uuid_hex = uuid.uuid4()
    remote_paths = dict((k, osp.basename(p)) for k, p in paths.items())
    remote_metadata_path = metadata_path.replace(self.config.DATA_PATH, self.config.REMOTE_PATH)

    stac_doc = StacIntoDc.render_metadata(
        self.config.product,
        geobox_used,
        (x, y),
        self.config.datetime_range,
        uuid_hex,
        remote_paths,
        remote_metadata_path,
        processing_dt,
    )

    with open(metadata_path, "w") as fh:
        json.dump(stac_doc, fh, indent=2)
def scale_and_clip_dataarray(dataarray: xr.DataArray, *, scale_factor=1, add_offset=0,
                             clip_range=None, valid_range=None, new_nodata=-999, new_dtype='int16'):
    orig_attrs = dataarray.attrs
    nodata = dataarray.attrs['nodata']

    mask = dataarray.data == nodata
    # TODO: add another mask here: if data > 10000, also treat it as nodata

    dataarray = dataarray * scale_factor + add_offset

    if clip_range is not None:
        dataarray = dataarray.clip(*clip_range)

    dataarray = dataarray.astype(new_dtype)
    dataarray.data[mask] = new_nodata

    if valid_range is not None:
        valid_min, valid_max = valid_range
        dataarray = dataarray.where(dataarray >= valid_min, new_nodata)
        dataarray = dataarray.where(dataarray <= valid_max, new_nodata)

    dataarray.attrs = orig_attrs
    dataarray.attrs['nodata'] = new_nodata

    return dataarray
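# Hypothetical usage sketch for scale_and_clip_dataarray (not from the original
# source): assumes a reflectance-like DataArray whose raw values are stored as
# unscaled integers with a 'nodata' attribute.
import numpy as np
import xarray as xr

raw = xr.DataArray(
    np.array([[0, 5000, 65535]], dtype='uint16'),
    dims=('y', 'x'),
    attrs={'nodata': 65535},
)
scaled = scale_and_clip_dataarray(
    raw,
    scale_factor=0.0001,   # e.g. convert to 0-1 reflectance
    add_offset=0,
    clip_range=(0, 1),
    new_nodata=-999,
    new_dtype='float32',
)
# scaled.attrs['nodata'] == -999; pixels that were nodata in `raw` are now -999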
def reapply_mask_to_boolean_xarray(self, variable_of_mask: str, da: xr.DataArray) -> xr.DataArray:
    """Reapply the original data mask to a boolean array.

    Because boolean comparisons in xarray return False for nan values,
    we need to reapply the mask from the original `da` to mask out the
    sea or invalid values (for example).

    Arguments:
    ---------
    variable_of_mask: str
        the variable in `self.ds` that you want to use as the mask.
        The `np.nan` values in `self.ds[variable]` will be marked as `True`
    da: xr.DataArray
        the boolean DataArray (for example `self.exceedences`) to reapply the mask to

    Returns:
    -------
    xr.DataArray with dtype of `int`, because the `bool` dtype doesn't
    store masks / nan values very well.

    NOTE:
    1. Uses the input dataset for the mask - TODO: does this make sense?
    """
    assert da.dtype == np.dtype('bool'), \
        "This function currently works on boolean xr.DataArray objects only"
    mask = get_ds_mask(self.ds[variable_of_mask])
    da = da.astype(int).where(~mask)
    return da
def transform_single_date_data(self, data):
    imgdata = Dataset()
    for imgband, components in self.rgb_components.items():
        if callable(components):
            imgband_data = components(data)
            dims = imgband_data.dims
            imgband_data = imgband_data.astype('uint8')
            imgdata[imgband] = (dims, imgband_data)
        else:
            imgband_data = None
            for band, intensity in components.items():
                if callable(intensity):
                    imgband_component = intensity(data[band], band, imgband)
                else:
                    imgband_component = data[band] * intensity
                if imgband_data is not None:
                    imgband_data += imgband_component
                else:
                    imgband_data = imgband_component
            if imgband_data is None:
                imgband_data = np.zeros(list(data.dims.values()), 'uint8')
                imgband_data = DataArray(imgband_data, data.coords, data.dims.keys())
            if imgband != "alpha":
                imgband_data = self.compress_band(imgband, imgband_data)
            imgdata[imgband] = (imgband_data.dims, imgband_data.astype("uint8"))
    return imgdata
def _richardson_lucy_deconv(
        image: xr.DataArray, iterations: int, psf: np.ndarray
) -> xr.DataArray:
    """
    Deconvolves the input image with a specified point spread function.

    Parameters
    ----------
    image : xr.DataArray
        Input degraded image (can be N dimensional).
    psf : ndarray
        The point spread function.
    iterations : int
        Number of iterations. This parameter plays the role of regularisation.

    Returns
    -------
    im_deconv : xr.DataArray
        The deconvolved image.
    """
    # compute the times for direct convolution and the fft method. The fft is of
    # complexity O(N log(N)) for each dimension and the direct method does
    # straight arithmetic (and is O(n*k) to add n elements k times)
    direct_time = np.prod(image.shape + psf.shape)
    fft_time = np.sum([n * np.log(n) for n in image.shape + psf.shape])

    # see whether the fourier transform convolution method or the direct
    # convolution method is faster (discussed in scikit-image PR #1792)
    time_ratio = 40.032 * fft_time / direct_time

    if time_ratio <= 1 or len(image.shape) > 2:
        convolve_method = fftconvolve
    else:
        convolve_method = convolve

    image = image.astype(np.float64)
    psf = psf.astype(np.float64)
    im_deconv = 0.5 * np.ones(image.shape)
    psf_mirror = psf[::-1, ::-1]

    eps = np.finfo(image.dtype).eps
    for _ in range(iterations):
        x = convolve_method(im_deconv, psf, 'same')
        np.place(x, x == 0, eps)
        relative_blur = image / x + eps
        im_deconv *= convolve_method(relative_blur, psf_mirror, 'same')

    if np.all(np.isnan(im_deconv)):
        raise RuntimeError(
            'All-NaN output data detected. Likely cause is that deconvolution has been run for '
            'too many iterations.')

    return im_deconv
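# Illustrative sketch (not from the original source): calling the deconvolution
# helper above on a small synthetic image with a Gaussian point spread function.
# The image and PSF below are made-up test data; `convolve`/`fftconvolve` are the
# SciPy functions the helper expects to find in scope.
import numpy as np
import xarray as xr
from scipy.signal import convolve, fftconvolve

g = np.exp(-0.5 * (np.arange(-2, 3) / 1.0) ** 2)
psf = np.outer(g, g)
psf /= psf.sum()

blurred = xr.DataArray(fftconvolve(np.random.rand(64, 64), psf, mode='same'),
                       dims=('y', 'x'))
restored = _richardson_lucy_deconv(blurred, iterations=10, psf=psf)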
def apply_additional_transforms(source: MapSource, agg: xr.DataArray):
    agg = agg.astype('float64')
    agg.data[agg.data == 0] = np.nan

    for e in source.extras:
        if e in additional_transforms:
            trans = additional_transforms.get(e)
            if trans is not None:
                agg = trans(agg)
        else:
            raise ValueError(f'Invalid transform name {e}')

    return source, agg
def new(
    cls: type,
    data: Any,
    name: Optional[Name] = None,
    attrs: Optional[Attrs] = None,
    **coords,
) -> DataArray:
    """Create a custom DataArray from data and coordinates.

    {cls.doc}

    Args:
        data: Values of the DataArray. Its shape must match class ``dims``.
            If class ``dtype`` is defined, the data will be cast to that type.
            If it cannot be cast, a ``ValueError`` will be raised.
        name: Name of the DataArray. Default is class ``name``.
        attrs: Attributes of the DataArray. Default is class ``attrs``.
        **coords: Coordinates of the DataArray defined by the class.

    Returns:
        Custom DataArray.

    {cls.coords.doc}
    """
    dataarray = DataArray(data, dims=cls.dims, name=name, attrs=attrs)

    if cls.dtype is not None:
        dataarray = dataarray.astype(cls.dtype)

    for name, coord in cls.coords.items():
        shape = [dataarray.sizes[dim] for dim in coord.dims]

        if name in coords:
            dataarray.coords[name] = coord.full(shape, coords[name])
            continue

        if hasattr(cls, name):
            dataarray.coords[name] = coord.full(shape, getattr(cls, name))
            continue

        raise ValueError(
            f"Default value for a coordinate {name} is not defined. "
            f"It must be given as a keyword argument ({name}=<value>).")

    return dataarray
def __init__(
    self,
    ds: xr.DataArray,
    aggregate_dims: list,
):
    self._ds = ds if (ds.dtype == np.float64) else ds.astype(np.float64)
    # For some reason, casting to float64 removes all attrs from the dataset
    self._ds.attrs = ds.attrs

    # array metrics
    self._ns_con_var = None
    self._ew_con_var = None
    self._mean = None
    self._mean_abs = None
    self._std = None
    self._prob_positive = None
    self._odds_positive = None
    self._prob_negative = None
    self._zscore = None
    self._mae_max = None
    self._corr_lag1 = None
    self._lag1 = None
    self._agg_dims = aggregate_dims
    self._quantile_value = None
    self._mean_squared = None
    self._root_mean_squared = None
    self._sum = None
    self._sum_squared = None
    self._variance = None
    self._quantile = 0.5
    self._spre_tol = 1.0e-4
    self._max_abs = None
    self._min_abs = None
    self._d_range = None
    self._min_val = None
    self._max_val = None

    # single value metrics
    self._zscore_cutoff = None
    self._zscore_percent_significant = None

    self._frame_size = 1
    if aggregate_dims is not None:
        for dim in aggregate_dims:
            self._frame_size *= int(self._ds.sizes[dim])
def write_diff(foo: xr.DataArray, fn, nodata=-1, dtype=None, debug=False):
    dtype = dtype if dtype is not None else foo.dtype
    with rasterio.open(
        fn,
        'w',
        driver='GTiff',
        height=len(foo[foo.dims[1]]),
        width=len(foo[foo.dims[0]]),
        count=1,  # len(bands),
        dtype=dtype,
        crs=foo.crs,
        transform=foo.transform[:6],
        nodata=nodata,
    ) as dst:
        dst.write(foo.astype(dtype).values, 1)

        # Copy metadata
        # for attr in foo.attrs:
        #     dst.update_tags(attr=foo.attrs[attr])
        dst.update_tags(**foo.attrs)

        # TODO: Can we add a colormap so we don't have to do it in e.g. QGIS?
        # (may not work w/ more than 256 levels)
        # dst.write_colormap(1, {
        #     0: (255, 0, 0, 255),
        #     255: (0, 0, 255, 255),
        # })

        dst.close()

    if debug:
        tmp = xr.open_rasterio('rep/landcover-hack/bleh_diff.tif').isel(
            band=0, drop=True).rename('tmp')
        print(tmp)
        tmp.plot(figsize=(10, 10))  # The figsize is just to avoid an issue w/ a particular jupyter instance
def reproject_array(
    arr: xr.DataArray,
    spec: RasterSpec,
    interpolation: Literal["linear", "nearest"] = "nearest",
    fill_value: Optional[Union[int, float]] = np.nan,
) -> xr.DataArray:
    """
    Reproject and clip a `~xarray.DataArray` to a new `~.RasterSpec` (CRS, resolution, bounds).

    This interpolates using `xarray.DataArray.interp`, which uses `scipy.interpolate.interpn`
    internally (no GDAL). It is somewhat dask-friendly, in that it at least doesn't trigger
    immediate computation on the array, but it's sub-optimal: the ``x`` and ``y`` dimensions
    are just merged into a single chunk, then interpolated. Since this both eliminates spatial
    parallelism and potentially requires significant amounts of memory, `reproject_array` is
    only recommended on arrays with a relatively small number of spatial chunks.

    Warning
    -------
    This method is very slow on large arrays due to inefficiencies in generating the dask
    graphs. Additionally, all spatial chunking is lost.

    Parameters
    ----------
    arr:
        Array to reproject. It must have ``epsg``, ``x``, and ``y`` coordinates.
        The ``x`` and ``y`` coordinates are assumed to indicate the top-left corner
        of each pixel, not the center.
    spec:
        The `~.RasterSpec` to reproject to.
    interpolation:
        Interpolation method: ``"linear"`` or ``"nearest"``, default ``"nearest"``.
    fill_value:
        Fill output pixels that fall outside the bounds of ``arr`` with this value
        (default NaN).

    Returns
    -------
    xarray.DataArray:
        The clipped and reprojected array.
    """
    # TODO this scipy/`interp`-based approach still isn't block-parallel
    # (seems like xarray just rechunks to fuse all the spatial chunks first),
    # so this both won't scale, and can be crazy slow in dask graph construction
    # (and the rechunk probably eliminates any hope of sending an HLG to the scheduler).

    from_epsg = array_epsg(arr)
    if (
        from_epsg == spec.epsg
        and array_bounds(arr) == spec.bounds
        and arr.shape[-2:] == spec.shape
    ):
        return arr

    as_bool = False
    if arr.dtype.kind == "b":
        # `interp` can't handle boolean arrays
        arr = arr.astype("uint8")
        as_bool = True

    # TODO fastpath when there's no overlap? (graph shouldn't have any IO in it.)
    # Or does that already happen?

    # TODO support `xy_coords="center"`? `spec` assumes topleft;
    # if the x/y coords on the array are center, this will be a half
    # pixel off.

    minx, miny, maxx, maxy = spec.bounds
    height, width = spec.shape

    x = np.linspace(minx, maxx, width, endpoint=False)
    y = np.linspace(maxy, miny, height, endpoint=False)

    if from_epsg == spec.epsg:
        # Simpler case: just interpolate within the same CRS
        result = arr.interp(
            x=x, y=y, method=interpolation, kwargs=dict(fill_value=fill_value)
        )
        return result.astype(bool) if as_bool else result

    # Different CRSs: need to do a 2D interpolation.
    # We do this by, for each point in the output grid, generating
    # the coordinates in the _input_ CRS that correspond to that point.

    reverse_transformer = cached_transformer(
        spec.epsg, from_epsg, skip_equivalent=True, always_xy=True
    )

    xs, ys = np.meshgrid(x, y, copy=False)
    src_xs, src_ys = reverse_transformer.transform(xs, ys, errcheck=True)

    xs_indexer = xr.DataArray(src_xs, dims=["y", "x"], coords=dict(y=y, x=x))
    ys_indexer = xr.DataArray(src_ys, dims=["y", "x"], coords=dict(y=y, x=x))

    # TODO maybe just drop old dims instead?
    old_xdim = f"x_{from_epsg}"
    old_ydim = f"y_{from_epsg}"

    result = arr.rename(x=old_xdim, y=old_ydim).interp(
        {old_xdim: xs_indexer, old_ydim: ys_indexer},
        method=interpolation,
        kwargs=dict(fill_value=fill_value),
    )
    return result.astype(bool) if as_bool else result
def asarray(x, target, dims=None):
    from xarray import DataArray
    from limix._bits.dask import is_dataframe as is_dask_dataframe
    from limix._bits.dask import is_array as is_dask_array
    from limix._bits.dask import is_series as is_dask_series
    from limix._bits.dask import array_shape_reveal
    from limix._bits.xarray import is_dataarray
    from ._conf import CONF
    from numpy import issubdtype, integer

    if target not in CONF["targets"]:
        raise ValueError(f"Unknown target name: {target}.")

    import dask.array as da
    import xarray as xr

    if is_dask_dataframe(x) or is_dask_series(x):
        xidx = x.index.compute()
        x = da.asarray(x)
        x = array_shape_reveal(x)
        x0 = xr.DataArray(x)
        x0.coords[x0.dims[0]] = xidx
        if is_dask_dataframe(x):
            x0.coords[x0.dims[1]] = x.columns
        x = x0
    elif is_dask_array(x):
        x = array_shape_reveal(x)
        x = xr.DataArray(x)

    if not is_dataarray(x):
        x = DataArray(x)

    x.name = target

    while x.ndim < 2:
        rdims = set(CONF["data_dims"]["trait"]).intersection(set(x.coords.keys()))
        rdims = rdims - set(x.dims)
        if len(rdims) == 1:
            dim = rdims.pop()
        else:
            dim = "dim_{}".format(x.ndim)
        x = x.expand_dims(dim, x.ndim)

    if isinstance(dims, (tuple, list)):
        dims = {a: n for a, n in enumerate(dims)}

    dims = _numbered_axes(dims)
    if len(set(dims.values())) < len(dims.values()):
        raise ValueError("`dims` must not contain duplicated values.")

    x = x.rename({x.dims[axis]: name for axis, name in dims.items()})

    x = _set_missing_dim(x, CONF["data_dims"][target])
    x = x.transpose(*CONF["data_dims"][target])

    if issubdtype(x.dtype, integer):
        x = x.astype(float)

    for dim in x.dims:
        if x.coords[dim].dtype.kind in {"U", "S"}:
            x.coords[dim].values = x.coords[dim].values.astype(object)

    return x
def run_bounds(mask: xr.DataArray, dim: str = "time", coord: Optional[Union[bool, str]] = True):
    """Return the start and end dates of boolean runs along a dimension.

    Parameters
    ----------
    mask : xr.DataArray
        Boolean array.
    dim : str
        Dimension along which to look for runs.
    coord : bool or str
        If True, return values of the coordinate; if a string, return values from
        `dim.dt.<coord>`; if False, return indexes.

    Returns
    -------
    xr.DataArray
        With ``dim`` reduced to "events" and "bounds". The events dim is as long
        as needed, padded with NaN or NaT.
    """
    if isinstance(mask.data, dsk.Array):
        raise NotImplementedError(
            "Dask arrays not supported as we can't know the final event number before computing."
        )

    diff = xr.concat((mask.isel({dim: 0}).astype(int), mask.astype(int).diff(dim)), dim)

    nstarts = (diff == 1).sum(dim).max().item()

    def _get_indices(arr, *, N):
        out = np.full((N,), np.nan, dtype=float)
        inds = np.where(arr)[0]
        out[: len(inds)] = inds
        return out

    starts = xr.apply_ufunc(
        _get_indices,
        diff == 1,
        input_core_dims=[[dim]],
        output_core_dims=[["events"]],
        kwargs={"N": nstarts},
        vectorize=True,
    )

    ends = xr.apply_ufunc(
        _get_indices,
        diff == -1,
        input_core_dims=[[dim]],
        output_core_dims=[["events"]],
        kwargs={"N": nstarts},
        vectorize=True,
    )

    if coord:
        crd = mask[dim]
        if isinstance(coord, str):
            crd = getattr(crd.dt, coord)

        starts = lazy_indexing(crd, starts).drop_vars(dim)
        ends = lazy_indexing(crd, ends).drop_vars(dim)
    return xr.concat((starts, ends), "bounds")
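# Hypothetical usage sketch (not from the original source): find the start and
# end dates of "warm" runs in a made-up daily temperature series.
import numpy as np
import pandas as pd
import xarray as xr

time = pd.date_range("2000-01-01", periods=10, freq="D")
tas = xr.DataArray(np.array([5, 6, 21, 22, 23, 7, 20, 21, 6, 5]),
                   dims="time", coords={"time": time})
warm = tas > 15                                  # boolean mask of warm days
bounds = run_bounds(warm, dim="time", coord=True)
# bounds has dims ("bounds", "events"): start/end timestamps of each warm run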
def first_run(
    da: xr.DataArray,
    window: int,
    dim: str = "time",
    coord: Optional[Union[str, bool]] = False,
    ufunc_1dim: Union[str, bool] = "auto",
):
    """Return the index of the first item of the first run of at least a given length.

    Parameters
    ----------
    da : xr.DataArray
        Input N-dimensional DataArray (boolean).
    window : int
        Minimum duration of consecutive run to accumulate values.
    dim : str
        Dimension along which to calculate consecutive run (default: 'time').
    coord : Optional[str]
        If not False, the function returns values along `dim` instead of indexes.
        If `dim` has a datetime dtype, `coord` can also be a str of the name of the
        DateTimeAccessor object to use (ex: 'dayofyear').
    ufunc_1dim : Union[str, bool]
        Use the 1d 'ufunc' version of this function: the default (auto) will attempt
        to select optimal usage based on the number of data points. Using
        ufunc_1dim=True is typically more efficient for DataArrays with a small
        number of grid points.

    Returns
    -------
    out : xr.DataArray
        Index (or coordinate if `coord` is not False) of the first item in the first
        valid run. Returns np.nan if there is no valid run.
    """
    if ufunc_1dim == "auto":
        if isinstance(da.data, dsk.Array) and len(da.chunks[da.dims.index(dim)]) > 1:
            ufunc_1dim = False
        else:
            npts = get_npts(da)
            ufunc_1dim = npts <= npts_opt

    da = da.fillna(0)  # We expect a boolean array, but there could be NaNs nonetheless

    if ufunc_1dim:
        out = first_run_ufunc(x=da, window=window, dim=dim)
    else:
        da = da.astype("int")
        i = xr.DataArray(np.arange(da[dim].size), dims=dim)
        ind = xr.broadcast(i, da)[0].transpose(*da.dims)
        if isinstance(da.data, dsk.Array):
            ind = ind.chunk(da.chunks)
        wind_sum = da.rolling(time=window).sum(allow_lazy=True, skipna=False)
        out = ind.where(wind_sum >= window).min(dim=dim) - (
            window - 1
        )  # remove window - 1 as the rolling result index is the last element of the moving window

    if coord:
        crd = da[dim]
        if isinstance(coord, str):
            crd = getattr(crd.dt, coord)
        out = lazy_indexing(crd, out)

    if dim in out:
        out = out.drop_vars(dim)
    return out
def splev(
    x_new: xarray.DataArray, tck: xarray.Dataset, extrapolate: bool | str = True
) -> xarray.DataArray:
    """Evaluate the B-spline generated with :func:`splrep`.

    :param x_new:
        Any :class:`~xarray.DataArray` with any number of dims, not necessarily
        the original interpolation dim.
        Alternatively, it can be any 1-dimensional array-like; it will be
        automatically converted to a :class:`~xarray.DataArray` on the
        interpolation dim.
    :param xarray.Dataset tck:
        As returned by :func:`splrep`.
        It can have been:

        - transposed (not recommended, as performance will drop if c is not
          C-contiguous)
        - sliced, reordered, or (re)chunked, on any dim except the
          interpolation dim
        - computed from dask to numpy backend
        - round-tripped to disk
    :param extrapolate:
        True
            Extrapolate the first and last polynomial pieces of b-spline
            functions active on the base interval
        False
            Return NaNs outside of the base interval
        'periodic'
            Periodic extrapolation is used
        'clip'
            Return y[0] and y[-1] outside of the base interval
    :returns:
        :class:`~xarray.DataArray` with all dims of the interpolated array,
        minus the interpolation dim, plus all dims of x_new

    See :func:`splrep` for usage example.
    """
    # Pre-process x_new into a DataArray
    if not isinstance(x_new, xarray.DataArray):
        if not isinstance(x_new, dask_array_type):
            x_new = np.array(x_new)
        if x_new.ndim == 0:
            dims = []
        elif x_new.ndim == 1:
            dims = [tck.spline_dim]
        else:
            raise ValueError(
                "N-dimensional x_new is only supported if x_new is a DataArray"
            )
        x_new = xarray.DataArray(x_new, dims=dims, coords={tck.spline_dim: x_new})

    dim = tck.spline_dim
    t = tck.t
    c = tck.c
    k = tck.k

    invalid_dims = {*x_new.dims} & {*c.dims} - {dim}
    if invalid_dims:
        raise ValueError(
            "Overlapping dims between interpolated "
            "array and x_new: %s" % ",".join(str(d) for d in invalid_dims)
        )

    if t.shape != (c.sizes[dim] + k + 1,):
        raise ValueError("Interpolated dimension has been sliced")

    if x_new.dtype.kind == "M":
        # Note that we're modifying the x_new values, not the x_new coords
        # xarray datetime objects are always in ns
        x_new = x_new.astype(float)

    if extrapolate == "clip":
        x = tck.coords[dim].values
        if x.dtype.kind == "M":
            x = x.astype("M8[ns]").astype(float)
        x_new = np.clip(x_new, x[0].tolist(), x[-1].tolist())
        extrapolate = False

    if c.dims[0] != dim:
        c = c.transpose(dim, *[d for d in c.dims if d != dim])

    if any(isinstance(v.data, dask_array_type) for v in (x_new, t, c)):
        if t.chunks and len(t.chunks[0]) > 1:
            raise NotImplementedError(
                "Unsupported: multiple chunks on interpolation dim"
            )
        if c.chunks and len(c.chunks[0]) > 1:
            raise NotImplementedError(
                "Unsupported: multiple chunks on interpolation dim"
            )

        # omitting t and c
        x_new_axes = "abdefghijklm"[: x_new.ndim]
        c_axes = "nopqrsuvwxyz"[: c.ndim - 1]

        y_new = da.blockwise(
            kernels.splev,
            x_new_axes + c_axes,
            x_new.data,
            x_new_axes,
            t.data,
            "t",
            c.data,
            "c" + c_axes,
            k=k,
            extrapolate=extrapolate,
            concatenate=True,
            dtype=float,
        )
    else:
        y_new = kernels.splev(
            x_new.values, t.values, c.values, k, extrapolate=extrapolate
        )

    y_new = xarray.DataArray(y_new, dims=x_new.dims + c.dims[1:], coords=x_new.coords)
    y_new.coords.update({k: c for k, c in c.coords.items() if dim not in c.dims})
    return y_new
class TestXarrayBilinear(unittest.TestCase):
    """Test Xarray/Dask-based bilinear interpolation."""

    def setUp(self):
        """Do some setup for common things."""
        import dask.array as da
        from xarray import DataArray
        from pyresample import geometry, kd_tree

        self.pts_irregular = (np.array([[-1., 1.], ]),
                              np.array([[1., 2.], ]),
                              np.array([[-2., -1.], ]),
                              np.array([[2., -4.], ]))
        self.pts_vert_parallel = (np.array([[-1., 1.], ]),
                                  np.array([[1., 2.], ]),
                                  np.array([[-1., -1.], ]),
                                  np.array([[1., -2.], ]))
        self.pts_both_parallel = (np.array([[-1., 1.], ]),
                                  np.array([[1., 1.], ]),
                                  np.array([[-1., -1.], ]),
                                  np.array([[1., -1.], ]))

        # Area definition with four pixels
        self.target_def = geometry.AreaDefinition('areaD',
                                                  'Europe (3km, HRV, VTC)',
                                                  'areaD',
                                                  {'a': '6378144.0',
                                                   'b': '6356759.0',
                                                   'lat_0': '50.00',
                                                   'lat_ts': '50.00',
                                                   'lon_0': '8.00',
                                                   'proj': 'stere'},
                                                  4, 4,
                                                  [-1370912.72, -909968.64000000001,
                                                   1029087.28, 1490031.3600000001])

        # Input data around the target pixel at 0.63388324, 55.08234642,
        in_shape = (100, 100)
        self.data1 = DataArray(da.ones((in_shape[0], in_shape[1])), dims=('y', 'x'))
        self.data2 = 2. * self.data1
        self.data3 = self.data1 + 9.5
        lons, lats = np.meshgrid(np.linspace(-25., 40., num=in_shape[0]),
                                 np.linspace(45., 75., num=in_shape[1]))
        self.source_def = geometry.SwathDefinition(lons=lons, lats=lats)

        self.radius = 50e3
        self.neighbours = 32
        valid_input_index, output_idxs, index_array, dists = \
            kd_tree.get_neighbour_info(self.source_def, self.target_def,
                                       self.radius, neighbours=self.neighbours,
                                       nprocs=1)
        input_size = valid_input_index.sum()
        index_mask = (index_array == input_size)
        index_array = np.where(index_mask, 0, index_array)
        self.valid_input_index = valid_input_index
        self.index_array = index_array

        shp = self.source_def.shape
        self.cols, self.lines = np.meshgrid(np.arange(shp[1]),
                                            np.arange(shp[0]))

    def test_init(self):
        """Test that the resampler has been initialized correctly."""
        from pyresample.bilinear.xarr import XArrayResamplerBilinear

        # With defaults
        resampler = XArrayResamplerBilinear(self.source_def, self.target_def,
                                            self.radius)
        self.assertTrue(resampler.source_geo_def == self.source_def)
        self.assertTrue(resampler.target_geo_def == self.target_def)
        self.assertEqual(resampler.radius_of_influence, self.radius)
        self.assertEqual(resampler.neighbours, 32)
        self.assertEqual(resampler.epsilon, 0)
        self.assertTrue(resampler.reduce_data)
        # These should be None
        self.assertIsNone(resampler.valid_input_index)
        self.assertIsNone(resampler.valid_output_index)
        self.assertIsNone(resampler.index_array)
        self.assertIsNone(resampler.distance_array)
        self.assertIsNone(resampler.bilinear_t)
        self.assertIsNone(resampler.bilinear_s)
        self.assertIsNone(resampler.slices_x)
        self.assertIsNone(resampler.slices_y)
        self.assertIsNone(resampler.mask_slices)
        self.assertIsNone(resampler.out_coords_x)
        self.assertIsNone(resampler.out_coords_y)
        # self.slices_{x,y} are used in self.slices dict
        self.assertTrue(resampler.slices['x'] is resampler.slices_x)
        self.assertTrue(resampler.slices['y'] is resampler.slices_y)
        # self.out_coords_{x,y} are used in self.out_coords dict
        self.assertTrue(resampler.out_coords['x'] is resampler.out_coords_x)
        self.assertTrue(resampler.out_coords['y'] is resampler.out_coords_y)

        # Override defaults
        resampler = XArrayResamplerBilinear(self.source_def, self.target_def,
                                            self.radius, neighbours=16,
                                            epsilon=0.1, reduce_data=False)
        self.assertEqual(resampler.neighbours, 16)
        self.assertEqual(resampler.epsilon, 0.1)
        self.assertFalse(resampler.reduce_data)

    def test_get_bil_info(self):
        """Test calculation of bilinear info."""
        from pyresample.bilinear.xarr import XArrayResamplerBilinear

        def _check_ts(t__, s__, nans):
            for i, _ in enumerate(t__):
                # Just check the exact value for one pixel
                if i == 5:
                    self.assertAlmostEqual(t__[i], 0.730659147133, 5)
                    self.assertAlmostEqual(s__[i], 0.310314173004, 5)
                # These pixels are outside the area
                elif i in nans:
                    self.assertTrue(np.isnan(t__[i]))
                    self.assertTrue(np.isnan(s__[i]))
                # All the others should have values between 0.0 and 1.0
                else:
                    self.assertTrue(t__[i] >= 0.0)
                    self.assertTrue(s__[i] >= 0.0)
                    self.assertTrue(t__[i] <= 1.0)
                    self.assertTrue(s__[i] <= 1.0)

        # Data reduction enabled (default)
        resampler = XArrayResamplerBilinear(self.source_def, self.target_def,
                                            self.radius, reduce_data=True)
        (t__, s__, slices, mask_slices, out_coords) = resampler.get_bil_info()
        _check_ts(t__.compute(), s__.compute(), [3, 10, 12, 13, 14, 15])

        # Nothing should be masked based on coordinates
        self.assertTrue(np.all(~mask_slices))
        # Four values per output location
        self.assertEqual(mask_slices.shape, (self.target_def.size, 4))

        # self.slices_{x,y} are used in self.slices dict so they
        # should be the same (object)
        self.assertTrue(isinstance(slices, dict))
        self.assertTrue(resampler.slices['x'] is resampler.slices_x)
        self.assertTrue(np.all(resampler.slices['x'] == slices['x']))
        self.assertTrue(resampler.slices['y'] is resampler.slices_y)
        self.assertTrue(np.all(resampler.slices['y'] == slices['y']))

        # self.out_coords_{x,y} are used in self.out_coords dict so they
        # should be the same (object)
        self.assertTrue(isinstance(out_coords, dict))
        self.assertTrue(resampler.out_coords['x'] is resampler.out_coords_x)
        self.assertTrue(np.all(resampler.out_coords['x'] == out_coords['x']))
        self.assertTrue(resampler.out_coords['y'] is resampler.out_coords_y)
        self.assertTrue(np.all(resampler.out_coords['y'] == out_coords['y']))

        # Also some other attributes should have been set
        self.assertTrue(t__ is resampler.bilinear_t)
        self.assertTrue(s__ is resampler.bilinear_s)
        self.assertIsNotNone(resampler.valid_output_index)
        self.assertIsNotNone(resampler.index_array)
        self.assertIsNotNone(resampler.valid_input_index)

        # Data reduction disabled
        resampler = XArrayResamplerBilinear(self.source_def, self.target_def,
                                            self.radius, reduce_data=False)
        (t__, s__, slices, mask_slices, out_coords) = resampler.get_bil_info()
        _check_ts(t__.compute(), s__.compute(), [10, 12, 13, 14, 15])

    def test_get_sample_from_bil_info(self):
        """Test bilinear interpolation as a whole."""
        from pyresample.bilinear.xarr import XArrayResamplerBilinear

        resampler = XArrayResamplerBilinear(self.source_def, self.target_def,
                                            self.radius)
        _ = resampler.get_bil_info()

        # Sample from data1
        res = resampler.get_sample_from_bil_info(self.data1)
        res = res.compute()
        # Check a couple of values
        self.assertEqual(res.values[1, 1], 1.)
        self.assertTrue(np.isnan(res.values[0, 3]))
        # Check that the values haven't gone down or up a lot
        self.assertAlmostEqual(np.nanmin(res.values), 1.)
        self.assertAlmostEqual(np.nanmax(res.values), 1.)
        # Check that dimensions are the same
        self.assertEqual(res.dims, self.data1.dims)

        # Sample from data1, custom fill value
        res = resampler.get_sample_from_bil_info(self.data1, fill_value=-1.0)
        res = res.compute()
        self.assertEqual(np.nanmin(res.values), -1.)

        # Sample from integer data
        res = resampler.get_sample_from_bil_info(self.data1.astype(np.uint8),
                                                 fill_value=None)
        res = res.compute()
        # Five values should be filled with zeros, which is the
        # default fill_value for integer data
        self.assertEqual(np.sum(res == 0), 6)

    @mock.patch('pyresample.bilinear.xarr.setattr')
    def test_compute_indices(self, mock_setattr):
        """Test running .compute() for indices."""
        from pyresample.bilinear.xarr import (XArrayResamplerBilinear,
                                              CACHE_INDICES)

        resampler = XArrayResamplerBilinear(self.source_def, self.target_def,
                                            self.radius)

        # Set indices to Numpy arrays
        for idx in CACHE_INDICES:
            setattr(resampler, idx, np.array([]))
        resampler._compute_indices()
        # None of the indices should have been reassigned
        mock_setattr.assert_not_called()

        # Set indices to a Mock object
        arr = mock.MagicMock()
        for idx in CACHE_INDICES:
            setattr(resampler, idx, arr)
        resampler._compute_indices()
        # All the indices should have been reassigned
        self.assertEqual(mock_setattr.call_count, len(CACHE_INDICES))
        # The compute should have been called the same amount of times
        self.assertEqual(arr.compute.call_count, len(CACHE_INDICES))

    def test_add_missing_coordinates(self):
        """Test coordinate updating."""
        import dask.array as da
        from xarray import DataArray
        from pyresample.bilinear.xarr import XArrayResamplerBilinear

        resampler = XArrayResamplerBilinear(self.source_def, self.target_def,
                                            self.radius)
        bands = ['R', 'G', 'B']
        data = DataArray(da.ones((3, 10, 10)), dims=('bands', 'y', 'x'),
                         coords={'bands': bands,
                                 'y': np.arange(10), 'x': np.arange(10)})
        resampler._add_missing_coordinates(data)
        # X and Y coordinates should not change
        self.assertIsNone(resampler.out_coords_x)
        self.assertIsNone(resampler.out_coords_y)
        self.assertIsNone(resampler.out_coords['x'])
        self.assertIsNone(resampler.out_coords['y'])
        self.assertTrue('bands' in resampler.out_coords)
        self.assertTrue(np.all(resampler.out_coords['bands'] == bands))

    def test_slice_data(self):
        """Test slicing the data."""
        import dask.array as da
        from xarray import DataArray
        from pyresample.bilinear.xarr import XArrayResamplerBilinear

        resampler = XArrayResamplerBilinear(self.source_def, self.target_def,
                                            self.radius)

        # Too many dimensions
        data = DataArray(da.ones((1, 3, 10, 10)))
        with self.assertRaises(ValueError):
            _ = resampler._slice_data(data, np.nan)

        # 2D data
        data = DataArray(da.ones((10, 10)))
        resampler.slices_x = np.random.randint(0, 10, (100, 4))
        resampler.slices_y = np.random.randint(0, 10, (100, 4))
        resampler.mask_slices = np.zeros((100, 4), dtype=bool)
        p_1, p_2, p_3, p_4 = resampler._slice_data(data, np.nan)
        self.assertEqual(p_1.shape, (100, ))
        self.assertTrue(p_1.shape == p_2.shape == p_3.shape == p_4.shape)
        self.assertTrue(np.all(p_1 == 1.0) and np.all(p_2 == 1.0) and
                        np.all(p_3 == 1.0) and np.all(p_4 == 1.0))

        # 2D data with masking
        resampler.mask_slices = np.ones((100, 4), dtype=bool)
        p_1, p_2, p_3, p_4 = resampler._slice_data(data, np.nan)
        self.assertTrue(np.all(np.isnan(p_1)) and np.all(np.isnan(p_2)) and
                        np.all(np.isnan(p_3)) and np.all(np.isnan(p_4)))

        # 3D data
        data = DataArray(da.ones((3, 10, 10)))
        resampler.slices_x = np.random.randint(0, 10, (100, 4))
        resampler.slices_y = np.random.randint(0, 10, (100, 4))
        resampler.mask_slices = np.zeros((100, 4), dtype=bool)
        p_1, p_2, p_3, p_4 = resampler._slice_data(data, np.nan)
        self.assertEqual(p_1.shape, (3, 100))
        self.assertTrue(p_1.shape == p_2.shape == p_3.shape == p_4.shape)

        # 3D data with masking
        resampler.mask_slices = np.ones((100, 4), dtype=bool)
        p_1, p_2, p_3, p_4 = resampler._slice_data(data, np.nan)
        self.assertTrue(np.all(np.isnan(p_1)) and np.all(np.isnan(p_2)) and
                        np.all(np.isnan(p_3)) and np.all(np.isnan(p_4)))

    @mock.patch('pyresample.bilinear.xarr.np.meshgrid')
    def test_get_slices(self, meshgrid):
        """Test slice array creation."""
        from pyresample.bilinear.xarr import XArrayResamplerBilinear

        meshgrid.return_value = (self.cols, self.lines)

        resampler = XArrayResamplerBilinear(self.source_def, self.target_def,
                                            self.radius)
        resampler.valid_input_index = self.valid_input_index
        resampler.index_array = self.index_array

        resampler._get_slices()
        self.assertIsNotNone(resampler.out_coords_x)
        self.assertIsNotNone(resampler.out_coords_y)
        self.assertTrue(resampler.out_coords_x is resampler.out_coords['x'])
        self.assertTrue(resampler.out_coords_y is resampler.out_coords['y'])
        self.assertTrue(np.allclose(
            resampler.out_coords_x,
            [-1070912.72, -470912.72, 129087.28, 729087.28]))
        self.assertTrue(np.allclose(
            resampler.out_coords_y,
            [1190031.36, 590031.36, -9968.64, -609968.64]))

        self.assertIsNotNone(resampler.slices_x)
        self.assertIsNotNone(resampler.slices_y)
        self.assertTrue(resampler.slices_x is resampler.slices['x'])
        self.assertTrue(resampler.slices_y is resampler.slices['y'])
        self.assertTrue(resampler.slices_x.shape == (self.target_def.size, 32))
        self.assertTrue(resampler.slices_y.shape == (self.target_def.size, 32))
        self.assertEqual(np.sum(resampler.slices_x), 12471)
        self.assertEqual(np.sum(resampler.slices_y), 2223)

        self.assertFalse(np.any(resampler.mask_slices))

        # Ensure that source geo def is used in masking
        # Setting target_geo_def to 0-size shouldn't cause any masked values
        resampler.target_geo_def = np.array([])
        resampler._get_slices()
        self.assertFalse(np.any(resampler.mask_slices))
        # Setting source area def to 0-size should mask all values
        resampler.source_geo_def = np.array([[]])
        resampler._get_slices()
        self.assertTrue(np.all(resampler.mask_slices))

    @mock.patch('pyresample.bilinear.xarr.KDTree')
    def test_create_resample_kdtree(self, KDTree):
        """Test that KDTree creation is called."""
        from pyresample.bilinear.xarr import XArrayResamplerBilinear

        resampler = XArrayResamplerBilinear(self.source_def, self.target_def,
                                            self.radius)

        vii, kdtree = resampler._create_resample_kdtree()
        self.assertEqual(np.sum(vii), 2700)
        self.assertEqual(vii.size, self.source_def.size)
        KDTree.assert_called_once()

    @mock.patch('pyresample.bilinear.xarr.query_no_distance')
    def test_query_resample_kdtree(self, qnd):
        """Test that query_no_distance is called in _query_resample_kdtree()."""
        from pyresample.bilinear.xarr import XArrayResamplerBilinear

        resampler = XArrayResamplerBilinear(self.source_def, self.target_def,
                                            self.radius)

        res, none = resampler._query_resample_kdtree(1, 2, 3, 4, reduce_data=5)
        qnd.assert_called_with(2, 3, 4, 1, resampler.neighbours,
                               resampler.epsilon,
                               resampler.radius_of_influence)

    def test_get_input_xy_dask(self):
        """Test computation of input X and Y coordinates in target proj."""
        import dask.array as da
        from pyresample.bilinear.xarr import _get_input_xy_dask
        from pyresample._spatial_mp import Proj

        proj = Proj(self.target_def.proj_str)
        in_x, in_y = _get_input_xy_dask(self.source_def, proj,
                                        da.from_array(self.valid_input_index),
                                        da.from_array(self.index_array))

        self.assertTrue(in_x.shape, (self.target_def.size, 32))
        self.assertTrue(in_y.shape, (self.target_def.size, 32))
        self.assertTrue(in_x.all())
        self.assertTrue(in_y.all())

    def test_mask_coordinates_dask(self):
        """Test masking of invalid coordinates."""
        import dask.array as da
        from pyresample.bilinear.xarr import _mask_coordinates_dask

        lons, lats = _mask_coordinates_dask(
            da.from_array([-200., 0., 0., 0., 200.]),
            da.from_array([0., -100., 0, 100., 0.]))
        lons, lats = da.compute(lons, lats)
        self.assertTrue(lons[2] == lats[2] == 0.0)
        self.assertEqual(np.sum(np.isnan(lons)), 4)
        self.assertEqual(np.sum(np.isnan(lats)), 4)

    def test_get_bounding_corners_dask(self):
        """Test finding surrounding bounding corners."""
        import dask.array as da
        from pyresample.bilinear.xarr import (_get_input_xy_dask,
                                              _get_bounding_corners_dask)
        from pyresample._spatial_mp import Proj
        from pyresample import CHUNK_SIZE

        proj = Proj(self.target_def.proj_str)
        out_x, out_y = self.target_def.get_proj_coords(chunks=CHUNK_SIZE)
        out_x = da.ravel(out_x)
        out_y = da.ravel(out_y)
        in_x, in_y = _get_input_xy_dask(self.source_def, proj,
                                        da.from_array(self.valid_input_index),
                                        da.from_array(self.index_array))
        pt_1, pt_2, pt_3, pt_4, ia_ = _get_bounding_corners_dask(
            in_x, in_y, out_x, out_y, self.neighbours,
            da.from_array(self.index_array))

        self.assertTrue(pt_1.shape == pt_2.shape ==
                        pt_3.shape == pt_4.shape ==
                        (self.target_def.size, 2))
        self.assertTrue(ia_.shape == (self.target_def.size, 4))

        # Check which of the locations has four valid X/Y pairs by
        # finding where there are non-NaN values
        res = da.sum(pt_1 + pt_2 + pt_3 + pt_4, axis=1).compute()
        self.assertEqual(np.sum(~np.isnan(res)), 10)

    def test_get_corner_dask(self):
        """Test finding the closest corners."""
        import dask.array as da
        from pyresample.bilinear.xarr import (_get_corner_dask,
                                              _get_input_xy_dask)
        from pyresample import CHUNK_SIZE
        from pyresample._spatial_mp import Proj

        proj = Proj(self.target_def.proj_str)
        in_x, in_y = _get_input_xy_dask(self.source_def, proj,
                                        da.from_array(self.valid_input_index),
                                        da.from_array(self.index_array))
        out_x, out_y = self.target_def.get_proj_coords(chunks=CHUNK_SIZE)
        out_x = da.ravel(out_x)
        out_y = da.ravel(out_y)

        # Some copy&paste from the code to get the input
        out_x_tile = np.reshape(np.tile(out_x, self.neighbours),
                                (self.neighbours, out_x.size)).T
        out_y_tile = np.reshape(np.tile(out_y, self.neighbours),
                                (self.neighbours, out_y.size)).T
        x_diff = out_x_tile - in_x
        y_diff = out_y_tile - in_y
        stride = np.arange(x_diff.shape[0])

        # Use lower left source pixels for testing
        valid = (x_diff > 0) & (y_diff > 0)
        x_3, y_3, idx_3 = _get_corner_dask(stride, valid, in_x, in_y,
                                           da.from_array(self.index_array))

        self.assertTrue(x_3.shape == y_3.shape == idx_3.shape ==
                        (self.target_def.size, ))
        # Four locations have no data to the lower left of them (the
        # bottom row of the area)
        self.assertEqual(np.sum(np.isnan(x_3.compute())), 4)

    @mock.patch('pyresample.bilinear.xarr._get_ts_parallellogram_dask')
    @mock.patch('pyresample.bilinear.xarr._get_ts_uprights_parallel_dask')
    @mock.patch('pyresample.bilinear.xarr._get_ts_irregular_dask')
    def test_get_ts_dask(self, irregular, uprights, parallellogram):
        """Test that the three separate functions are called."""
        from pyresample.bilinear.xarr import _get_ts_dask

        # All valid values
        t_irr = np.array([0.1, 0.2, 0.3])
        s_irr = np.array([0.1, 0.2, 0.3])
        irregular.return_value = (t_irr, s_irr)
        t__, s__ = _get_ts_dask(1, 2, 3, 4, 5, 6)
        irregular.assert_called_once()
        uprights.assert_not_called()
        parallellogram.assert_not_called()
        self.assertTrue(np.allclose(t__.compute(), t_irr))
        self.assertTrue(np.allclose(s__.compute(), s_irr))

        # NaN in the first step, good value for that location from the
        # second step
        t_irr = np.array([0.1, 0.2, np.nan])
        s_irr = np.array([0.1, 0.2, np.nan])
        irregular.return_value = (t_irr, s_irr)
        t_upr = np.array([3, 3, 0.3])
        s_upr = np.array([3, 3, 0.3])
        uprights.return_value = (t_upr, s_upr)
        t__, s__ = _get_ts_dask(1, 2, 3, 4, 5, 6)
        self.assertEqual(irregular.call_count, 2)
        uprights.assert_called_once()
        parallellogram.assert_not_called()
        # Only the last value of the first step should have been replaced
        t_res = np.array([0.1, 0.2, 0.3])
        s_res = np.array([0.1, 0.2, 0.3])
        self.assertTrue(np.allclose(t__.compute(), t_res))
        self.assertTrue(np.allclose(s__.compute(), s_res))

        # Two NaNs in the first step, one of which is found by the
        # second, and the last bad value is replaced by the third step
        t_irr = np.array([0.1, np.nan, np.nan])
        s_irr = np.array([0.1, np.nan, np.nan])
        irregular.return_value = (t_irr, s_irr)
        t_upr = np.array([3, np.nan, 0.3])
        s_upr = np.array([3, np.nan, 0.3])
        uprights.return_value = (t_upr, s_upr)
        t_par = np.array([4, 0.2, 0.3])
        s_par = np.array([4, 0.2, 0.3])
        parallellogram.return_value = (t_par, s_par)
        t__, s__ = _get_ts_dask(1, 2, 3, 4, 5, 6)
        self.assertEqual(irregular.call_count, 3)
        self.assertEqual(uprights.call_count, 2)
        parallellogram.assert_called_once()
        # Only the last two values should have been replaced
        t_res = np.array([0.1, 0.2, 0.3])
        s_res = np.array([0.1, 0.2, 0.3])
        self.assertTrue(np.allclose(t__.compute(), t_res))
        self.assertTrue(np.allclose(s__.compute(), s_res))

        # Too large and small values should be set to NaN
        t_irr = np.array([1.00001, -0.00001, 1e6])
        s_irr = np.array([1.00001, -0.00001, -1e6])
        irregular.return_value = (t_irr, s_irr)
        # Second step also returns invalid values
        t_upr = np.array([1.00001, 0.2, np.nan])
        s_upr = np.array([-0.00001, 0.2, np.nan])
        uprights.return_value = (t_upr, s_upr)
        # Third step has one new valid value, the last will stay invalid
        t_par = np.array([0.1, 0.2, 4.0])
        s_par = np.array([0.1, 0.2, 4.0])
        parallellogram.return_value = (t_par, s_par)
        t__, s__ = _get_ts_dask(1, 2, 3, 4, 5, 6)
        t_res = np.array([0.1, 0.2, np.nan])
        s_res = np.array([0.1, 0.2, np.nan])
        self.assertTrue(np.allclose(t__.compute(), t_res, equal_nan=True))
        self.assertTrue(np.allclose(s__.compute(), s_res, equal_nan=True))

    def test_get_ts_irregular_dask(self):
        """Test calculations for irregular corner locations."""
        from pyresample.bilinear.xarr import _get_ts_irregular_dask

        res = _get_ts_irregular_dask(self.pts_irregular[0],
                                     self.pts_irregular[1],
                                     self.pts_irregular[2],
                                     self.pts_irregular[3],
                                     0., 0.)
        self.assertEqual(res[0], 0.375)
        self.assertEqual(res[1], 0.5)
        res = _get_ts_irregular_dask(self.pts_vert_parallel[0],
                                     self.pts_vert_parallel[1],
                                     self.pts_vert_parallel[2],
                                     self.pts_vert_parallel[3],
                                     0., 0.)
        self.assertTrue(np.isnan(res[0]))
        self.assertTrue(np.isnan(res[1]))

    def test_get_ts_uprights_parallel(self):
        """Test calculation when uprights are parallel."""
        from pyresample.bilinear import _get_ts_uprights_parallel

        res = _get_ts_uprights_parallel(self.pts_vert_parallel[0],
                                        self.pts_vert_parallel[1],
                                        self.pts_vert_parallel[2],
                                        self.pts_vert_parallel[3],
                                        0., 0.)
        self.assertEqual(res[0], 0.5)
        self.assertEqual(res[1], 0.5)

    def test_get_ts_parallellogram(self):
        """Test calculation when the corners form a parallellogram."""
        from pyresample.bilinear import _get_ts_parallellogram

        res = _get_ts_parallellogram(self.pts_both_parallel[0],
                                     self.pts_both_parallel[1],
                                     self.pts_both_parallel[2],
                                     0., 0.)
        self.assertEqual(res[0], 0.5)
        self.assertEqual(res[1], 0.5)

    def test_calc_abc(self):
        """Test calculation of quadratic coefficients."""
        from pyresample.bilinear.xarr import _calc_abc_dask

        # No np.nan inputs
        pt_1, pt_2, pt_3, pt_4 = self.pts_irregular
        res = _calc_abc_dask(pt_1, pt_2, pt_3, pt_4, 0.0, 0.0)
        self.assertFalse(np.isnan(res[0]))
        self.assertFalse(np.isnan(res[1]))
        self.assertFalse(np.isnan(res[2]))
        # np.nan input -> np.nan output
        res = _calc_abc_dask(np.array([[np.nan, np.nan]]), pt_2, pt_3, pt_4,
                             0.0, 0.0)
        self.assertTrue(np.isnan(res[0]))
        self.assertTrue(np.isnan(res[1]))
        self.assertTrue(np.isnan(res[2]))

    def test_solve_quadratic(self):
        """Test solving quadratic equation."""
        from pyresample.bilinear.xarr import (_solve_quadratic_dask,
                                              _calc_abc_dask)

        res = _solve_quadratic_dask(1, 0, 0).compute()
        self.assertEqual(res, 0.0)
        res = _solve_quadratic_dask(1, 2, 1).compute()
        self.assertTrue(np.isnan(res))
        res = _solve_quadratic_dask(1, 2, 1, min_val=-2.).compute()
        self.assertEqual(res, -1.0)

        # Test that small adjustments work
        pt_1, pt_2, pt_3, pt_4 = self.pts_vert_parallel
        pt_1 = self.pts_vert_parallel[0].copy()
        pt_1[0][0] += 1e-7
        res = _calc_abc_dask(pt_1, pt_2, pt_3, pt_4, 0.0, 0.0)
        res = _solve_quadratic_dask(res[0], res[1], res[2]).compute()
        self.assertAlmostEqual(res[0], 0.5, 5)
        res = _calc_abc_dask(pt_1, pt_3, pt_2, pt_4, 0.0, 0.0)
        res = _solve_quadratic_dask(res[0], res[1], res[2]).compute()
        self.assertAlmostEqual(res[0], 0.5, 5)

    def test_query_no_distance(self):
        """Test KDTree querying."""
        from pyresample.bilinear.xarr import query_no_distance

        kdtree = mock.MagicMock()
        kdtree.query.return_value = (1, 2)
        lons, lats = self.target_def.get_lonlats()
        voi = (lons >= -180) & (lons <= 180) & (lats <= 90) & (lats >= -90)
        res = query_no_distance(lons, lats, voi, kdtree, self.neighbours,
                                0., self.radius)
        # Only the second value from the query is returned
        self.assertEqual(res, 2)
        kdtree.query.assert_called_once()

    def test_get_valid_input_index_dask(self):
        """Test finding valid indices for reduced input data."""
        from pyresample.bilinear.xarr import _get_valid_input_index_dask

        # Do not reduce data
        vii, lons, lats = _get_valid_input_index_dask(self.source_def,
                                                      self.target_def,
                                                      False, self.radius)
        self.assertEqual(vii.shape, (self.source_def.size, ))
        self.assertTrue(vii.dtype == bool)
        # No data has been reduced, whole input is used
        self.assertTrue(vii.compute().all())

        # Reduce data
        vii, lons, lats = _get_valid_input_index_dask(self.source_def,
                                                      self.target_def,
                                                      True, self.radius)
        # 2700 valid input points
        self.assertEqual(vii.compute().sum(), 2700)

    def test_create_empty_bil_info(self):
        """Test creation of empty bilinear info."""
        from pyresample.bilinear.xarr import _create_empty_bil_info

        t__, s__, vii, ia_ = _create_empty_bil_info(self.source_def,
                                                    self.target_def)
        self.assertEqual(t__.shape, (self.target_def.size, ))
        self.assertEqual(s__.shape, (self.target_def.size, ))
        self.assertEqual(ia_.shape, (self.target_def.size, 4))
        self.assertTrue(ia_.dtype == np.int32)
        self.assertEqual(vii.shape, (self.source_def.size, ))
        self.assertTrue(vii.dtype == bool)

    def test_lonlat2xyz(self):
        """Test conversion from geographic to cartesian 3D coordinates."""
        from pyresample.bilinear.xarr import lonlat2xyz
        from pyresample import CHUNK_SIZE

        lons, lats = self.target_def.get_lonlats(chunks=CHUNK_SIZE)
        res = lonlat2xyz(lons, lats)
        self.assertEqual(res.shape, (self.target_def.size, 3))
        vals = [3188578.91069278, -612099.36103276, 5481596.63569999]
        self.assertTrue(np.allclose(res.compute()[0, :], vals))
def ensure_dtype(tensor: xr.DataArray, *, dtype):
    """Convert an array to a given datatype."""
    return tensor.astype(dtype)
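# Minimal usage sketch (not from the original source): cast an integer-count
# DataArray to float32 before downstream arithmetic.
import numpy as np
import xarray as xr

counts = xr.DataArray(np.array([1, 2, 3]), dims="x")
counts_f32 = ensure_dtype(counts, dtype=np.float32)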
def lazy_indexing(da: xr.DataArray, index: xr.DataArray, dim: Optional[str] = None) -> xr.DataArray:
    """Get values of `da` at indices `index` in a NaN-aware and lazy manner.

    Two cases are handled: a 1D `da` indexed with an N-dimensional `index`, and an
    N-dimensional `da` where the dimensions of `index` are a subset of those of `da`.

    Parameters
    ----------
    da : xr.DataArray
        Input array. If not 1D, `dim` must be given and must not appear in index.
    index : xr.DataArray
        N-d integer indices; if da is not 1D, all dimensions of index must be in da.
    dim : str, optional
        Dimension along which to index, unused if `da` is 1D, should not be present in `index`.

    Returns
    -------
    xr.DataArray
        Values of `da` at indices `index`.
    """
    if da.ndim == 1:
        # Case where da is 1D and index is N-D
        # Slightly better performance using map_blocks, over an apply_ufunc
        def _index_from_1d_array(array, indices):
            return array[indices, ]

        idx_ndim = index.ndim
        if idx_ndim == 0:
            # The 0-D index case, we add a dummy dimension to help dask
            dim = xr.core.utils.get_temp_dimname(da.dims, "x")
            index = index.expand_dims(dim)
        invalid = index.isnull()  # Which indexes to mask
        # NaN-indexing doesn't work, so fill with 0 and cast to int
        index = index.fillna(0).astype(int)
        # for each chunk of index, take corresponding values from da
        func = partial(_index_from_1d_array, da)
        out = index.map_blocks(func)
        # mask where index was NaN
        out = out.where(~invalid)
        if idx_ndim == 0:
            # 0-D case, drop useless coords and dummy dim
            out = out.drop_vars(da.dims[0]).squeeze()
        return out

    # Case where index.dims is a subset of da.dims.
    if dim is None:
        diff_dims = set(da.dims) - set(index.dims)
        if len(diff_dims) == 0:
            raise ValueError(
                "da must have at least one dimension more than index for lazy_indexing."
            )
        if len(diff_dims) > 1:
            raise ValueError(
                "If da has more than one dimension more than index, the indexing dim must be given through `dim`"
            )
        dim = diff_dims.pop()

    def _index_from_nd_array(array, indices):
        return np.take_along_axis(array, indices[..., np.newaxis], axis=-1)[..., 0]

    return xr.apply_ufunc(
        _index_from_nd_array,
        da,
        index.astype(int),
        input_core_dims=[[dim], []],
        output_core_dims=[[]],
        dask="parallelized",
        output_dtypes=[da.dtype],
    )
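# Hypothetical usage sketch (not from the original source): pick the time-axis
# value at the index returned by an argmax-style computation, exercising the
# 1D-`da` branch of lazy_indexing.
import numpy as np
import xarray as xr

time = xr.DataArray(np.arange(5), dims="time")
tas = xr.DataArray(np.array([[1., 3., 2., 5., 4.], [9., 1., 1., 1., 1.]]),
                   dims=("site", "time"))
idx = tas.argmax("time")               # dims ("site",): integer indices into "time"
peak_time = lazy_indexing(time, idx)   # values of `time` at those indices, per site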
def get_optimal_chk(
    arr: xr.DataArray,
    dim_grp=[("frame",), ("height", "width")],
    csize=256,
    dtype: Optional[type] = None,
) -> dict:
    """
    Compute the optimal chunk size across all dimensions of the input array.

    This function uses the `dask` auto-chunking mechanism to determine the optimal
    chunk size of an array. The difference between this and directly using "auto" as
    the chunk size is that it understands which dimensions are usually chunked
    together, with the help of `dim_grp`. It also supports computing chunks for a
    custom `dtype` and explicit requirements on chunk size.

    Parameters
    ----------
    arr : xr.DataArray
        The input array to estimate chunk size for.
    dim_grp : list, optional
        List of tuples specifying which dimensions are usually chunked together
        during computation. For each tuple in the list, it is assumed that only
        dimensions in the tuple will be chunked while all other dimensions in the
        input `arr` will not be chunked. Each dimension in the input `arr` should
        appear once and only once across the list. By default
        `[("frame",), ("height", "width")]`.
    csize : int, optional
        The desired space each chunk should occupy, specified in MB. By default `256`.
    dtype : type, optional
        The datatype of `arr` during actual computation, in case it will be different
        from the current `arr.dtype`. By default `None`.

    Returns
    -------
    chk_compute : dict
        Dictionary mapping dimension names to chunk sizes intended for computation.
    chk_store_da : dict
        Dictionary mapping dimension names to the chunk sizes suggested by dask's
        auto-chunking across all dimensions.
    """
    if dtype is not None:
        arr = arr.astype(dtype)
    dims = arr.dims
    if not dim_grp:
        dim_grp = [(d,) for d in dims]
    chk_compute = dict()
    for dg in dim_grp:
        d_rest = set(dims) - set(dg)
        dg_dict = {d: "auto" for d in dg}
        dr_dict = {d: -1 for d in d_rest}
        dg_dict.update(dr_dict)
        with da.config.set({"array.chunk-size": "{}MiB".format(csize)}):
            arr_chk = arr.chunk(dg_dict)
        chk = get_chunksize(arr_chk)
        chk_compute.update({d: chk[d] for d in dg})
    with da.config.set({"array.chunk-size": "{}MiB".format(csize)}):
        arr_chk = arr.chunk({d: "auto" for d in dims})
    chk_store_da = get_chunksize(arr_chk)
    chk_store = dict()
    for d in dims:
        ncomp = int(arr.sizes[d] / chk_compute[d])
        sz = np.array(factors(ncomp)) * chk_compute[d]
        chk_store[d] = sz[np.argmin(np.abs(sz - chk_store_da[d]))]
    return chk_compute, chk_store_da
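# Hypothetical usage sketch (not from the original source): estimate chunk sizes
# for a small synthetic movie with dims ("frame", "height", "width"), assuming the
# helper above and its `get_chunksize`/`factors` dependencies are importable.
import numpy as np
import xarray as xr

movie = xr.DataArray(
    np.zeros((1000, 512, 512), dtype=np.uint8),
    dims=("frame", "height", "width"),
)
chk_compute, chk_store = get_optimal_chk(movie, csize=64, dtype=np.float32)
# chk_compute maps each dimension to a chunk size suited for computation,
# e.g. {"frame": ..., "height": ..., "width": ...}
rechunked = movie.chunk(chk_compute)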