Пример #1
0
    def save_data(
        self,
        subfld: str,
        predict: xr.DataArray,
        probabilites: xr.DataArray,
        geobox_used: GeoBox,
    ):
        """
        Save the prediction results to a local folder and write a STAC json.

        Two COGs are written (the predicted mask and the probabilities, both
        cast to uint8), followed by a STAC metadata document describing them.

        :param subfld: local subfolder to save the prediction tifs
        :param predict: predicted binary class label array
        :param probabilites: prediction probabilities array. It is cast to
            uint8 before writing, so fractional values are truncated.
            NOTE(review): presumably already scaled to an integer range
            (e.g. 0-100) by the caller -- confirm.
        :param geobox_used: geobox used for the features for prediction
        :return: None
        """
        output_fld, paths, metadata_path = prepare_the_io_path(
            self.config, subfld)
        x, y = get_xy_from_task(subfld)
        # exist_ok avoids the check-then-create race when several workers
        # write into the same output folder at the same time.
        os.makedirs(output_fld, exist_ok=True)

        self._log.info("collecting mask and write cog.")
        write_cog(
            predict.astype(np.uint8).compute(),
            paths["mask"],
            overwrite=True,
        )

        self._log.info("collecting prob and write cog.")
        write_cog(
            probabilites.astype(np.uint8).compute(),
            paths["prob"],
            overwrite=True,
        )

        self._log.info("collecting the stac json and write out.")

        processing_dt = datetime.datetime.now()

        uuid_hex = uuid.uuid4()
        # Asset file names only (basename), keyed like ``paths``.
        remote_path = {k: osp.basename(p) for k, p in paths.items()}
        remote_metadata_path = metadata_path.replace(self.config.DATA_PATH,
                                                     self.config.REMOTE_PATH)
        stac_doc = StacIntoDc.render_metadata(
            self.config.product,
            geobox_used,
            (x, y),
            self.config.datetime_range,
            uuid_hex,
            remote_path,
            remote_metadata_path,
            processing_dt,
        )

        with open(metadata_path, "w") as fh:
            json.dump(stac_doc, fh, indent=2)
Пример #2
0
def scale_and_clip_dataarray(dataarray: xr.DataArray,
                             *,
                             scale_factor=1,
                             add_offset=0,
                             clip_range=None,
                             valid_range=None,
                             new_nodata=-999,
                             new_dtype='int16'):
    """Rescale a DataArray, cast it to ``new_dtype`` and re-encode its nodata.

    Pixels equal to the input ``attrs['nodata']`` become ``new_nodata``. When
    that nodata is NaN, NaN pixels are used instead, because ``x == nan`` is
    always False.

    Parameters
    ----------
    dataarray : xr.DataArray
        Input array; must carry a ``nodata`` attribute.
    scale_factor, add_offset :
        Applied as ``dataarray * scale_factor + add_offset``.
    clip_range : (min, max), optional
        Passed to ``clip`` after scaling and before the dtype cast.
    valid_range : (min, max), optional
        Inclusive bounds checked after the cast. Values outside them become
        ``new_nodata``.
    new_nodata :
        Value written into masked pixels and stored in ``attrs['nodata']``.
    new_dtype :
        dtype of the returned array.

    Returns
    -------
    xr.DataArray
        New array of dtype ``new_dtype``. Its attrs are a copy of the input's
        with ``nodata`` replaced by ``new_nodata``; the input is not modified.

    Raises
    ------
    KeyError
        If the input has no ``nodata`` attribute.
    """
    orig_attrs = dict(dataarray.attrs)
    nodata = orig_attrs['nodata']

    # ``x == nan`` is always False, so a NaN nodata needs an explicit NaN test.
    nodata_is_nan = isinstance(nodata, (float, np.floating)) and np.isnan(nodata)
    mask = dataarray.isnull() if nodata_is_nan else dataarray == nodata

    # TODO: also treat out-of-range raw values (e.g. > 10000) as nodata.
    scaled = dataarray * scale_factor + add_offset

    if clip_range is not None:
        scaled = scaled.clip(*clip_range)

    scaled = scaled.astype(new_dtype)

    # ``where`` instead of in-place ``.data[mask] = ...`` also works for
    # dask-backed arrays.
    scaled = scaled.where(~mask, new_nodata)

    if valid_range is not None:
        valid_min, valid_max = valid_range
        scaled = scaled.where((scaled >= valid_min) & (scaled <= valid_max),
                              new_nodata)

    # ``where`` with a scalar fill may upcast on some xarray versions, so
    # enforce the requested dtype on the result.
    scaled = scaled.astype(new_dtype)
    scaled.attrs = {**orig_attrs, 'nodata': new_nodata}

    return scaled
    def reapply_mask_to_boolean_xarray(self, variable_of_mask: str,
                                       da: xr.DataArray) -> xr.DataArray:
        """Reapply the mask of ``self.ds[variable_of_mask]`` to a boolean ``da``.

        Boolean comparisons in xarray return False for nan values, so a boolean
        result (for example `self.exceedences`) can no longer distinguish
        "False" from "missing". This re-inserts NaN at the masked positions
        (for example the sea or other invalid values).

        Arguments:
        ---------
        variable_of_mask: str
            the variable in `self.ds` to use as the mask. `get_ds_mask` is
            expected to mark its `np.nan` values as `True` (not verified here).

        da: xr.DataArray
            the boolean DataArray to reapply the mask to

        Returns:
        -------
        xr.DataArray holding 1/0 where unmasked and NaN where masked. `da` is
        cast to `int` first (a `bool` array cannot hold NaN), and `where` then
        promotes it to a float dtype to store the NaNs.

        Raises:
        ------
        TypeError
            if `da` is not of boolean dtype.

        NOTE:
            1. Uses the input dataset for the mask - TODO: does this make
             sense?
        """
        # Explicit check rather than ``assert``: assertions are stripped
        # under ``python -O``.
        if da.dtype != np.dtype('bool'):
            raise TypeError(
                "This function currently works on boolean xr.DataArray "
                f"objects only; got dtype {da.dtype}")

        mask = get_ds_mask(self.ds[variable_of_mask])
        da = da.astype(int).where(~mask)

        return da
Пример #4
0
    def transform_single_date_data(self, data):
        """Build an 8-bit image Dataset from one date's band data.

        For each ``imgband, components`` in ``self.rgb_components``:

        * If ``components`` is callable, it is called with ``data``. The result
          is cast to uint8 without passing through ``compress_band``.
        * Otherwise ``components`` maps ``source_band -> intensity``. Each
          contribution is ``intensity(data[band], band, imgband)`` for a
          callable intensity, or ``data[band] * intensity`` otherwise, and the
          contributions are summed. An empty mapping yields an all-zero band.
          Every band except ``"alpha"`` then goes through
          ``self.compress_band`` before the uint8 cast.

        :param data: presumably an xarray Dataset of source bands for a single
            date, indexed by band name -- confirm with callers.
        :return: Dataset with one uint8 variable per image band.
        """
        imgdata = Dataset()
        for imgband, components in self.rgb_components.items():
            if callable(components):
                imgband_data = components(data)
                dims = imgband_data.dims
                imgband_data = imgband_data.astype('uint8')
                imgdata[imgband] = (dims, imgband_data)
            else:
                imgband_data = None
                for band, intensity in components.items():
                    if callable(intensity):
                        imgband_component = intensity(data[band], band, imgband)
                    else:
                        imgband_component = data[band] * intensity

                    # Out-of-place addition: ``+=`` would write into the first
                    # component, which a callable intensity may have returned
                    # as ``data[band]`` itself (mutating the caller's dataset),
                    # and it cannot broadcast or upcast (uint8 += float raises).
                    if imgband_data is not None:
                        imgband_data = imgband_data + imgband_component
                    else:
                        imgband_data = imgband_component
                if imgband_data is None:
                    imgband_data = np.zeros(list(data.sizes.values()), 'uint8')
                    imgband_data = DataArray(imgband_data, data.coords,
                                             list(data.sizes.keys()))
                if imgband != "alpha":
                    imgband_data = self.compress_band(imgband, imgband_data)
                imgdata[imgband] = (imgband_data.dims,
                                    imgband_data.astype("uint8"))
        return imgdata
    def _richardson_lucy_deconv(
            image: xr.DataArray, iterations: int, psf: np.ndarray
    ) -> xr.DataArray:
        """
        Deconvolve ``image`` with the point spread function ``psf``
        (Richardson-Lucy iterations).

        NOTE(review): the signature has no ``self``/``cls``; it presumably
        relies on a ``@staticmethod`` decorator outside this view -- confirm.

        Parameters
        ----------
        image : xr.DataArray
           Input degraded image (can be N dimensional).
        iterations : int
           Number of iterations. This parameter plays the role of
           regularisation.
        psf : ndarray
           The point spread function, with the same number of dimensions as
           ``image``.

        Returns
        -------
        im_deconv : np.ndarray
           The deconvolved image. The estimate is a numpy array updated in
           place, so a plain ndarray is returned despite the annotation.

        Raises
        ------
        RuntimeError
           If every output value is NaN (typically too many iterations).
        """
        # compute the times for direct convolution and the fft method. The fft is of
        # complexity O(N log(N)) for each dimension and the direct method does
        # straight arithmetic (and is O(n*k) to add n elements k times)
        direct_time = np.prod(image.shape + psf.shape)
        fft_time = np.sum([n * np.log(n) for n in image.shape + psf.shape])

        # see whether the fourier transform convolution method or the direct
        # convolution method is faster (discussed in scikit-image PR #1792)
        time_ratio = 40.032 * fft_time / direct_time

        if time_ratio <= 1 or len(image.shape) > 2:
            convolve_method = fftconvolve
        else:
            convolve_method = convolve

        # ``np.float`` (an alias of the builtin float, i.e. float64) was
        # removed in NumPy 1.24.
        image = image.astype(np.float64)
        psf = psf.astype(np.float64)
        im_deconv = 0.5 * np.ones(image.shape)
        # Mirror the PSF along every axis, not just the first two, so N-d
        # PSFs are handled as the docstring promises (identical for 2-D).
        psf_mirror = np.flip(psf)

        eps = np.finfo(image.dtype).eps
        for _ in range(iterations):
            x = convolve_method(im_deconv, psf, 'same')
            # Guard against division by zero in the ratio below.
            np.place(x, x == 0, eps)
            relative_blur = image / x + eps
            im_deconv *= convolve_method(relative_blur, psf_mirror, 'same')

        if np.all(np.isnan(im_deconv)):
            raise RuntimeError(
                'All-NaN output data detected. Likely cause is that deconvolution has been run for '
                'too many iterations.')

        return im_deconv
Пример #6
0
def apply_additional_transforms(source: MapSource, agg: xr.DataArray):
    """Cast ``agg`` to float64, blank out zeros, then apply the source's extras.

    Zero cells become NaN. Each name in ``source.extras`` that is a key of
    ``additional_transforms`` has its transform applied to the aggregate, in
    the order the extras are listed. Other names are skipped.

    Returns the ``(source, transformed_agg)`` pair.

    Raises ValueError when a name in ``source.extras`` is present in
    ``additional_transforms`` but mapped to ``None``.
    """
    result = agg.astype('float64')
    zero_cells = result.data == 0
    result.data[zero_cells] = np.nan

    for extra_name in source.extras:
        # NOTE(review): names absent from ``additional_transforms`` are skipped
        # silently; the "Invalid transform name" error only fires for keys
        # mapped to None -- confirm this is intended.
        if extra_name not in additional_transforms:
            continue
        transform = additional_transforms.get(extra_name)
        if transform is None:
            raise ValueError(f'Invalid transform name {extra_name}')
        result = transform(result)

    return source, result
Пример #7
0
def new(
    cls: type,
    data: Any,
    name: Optional[Name] = None,
    attrs: Optional[Attrs] = None,
    **coords,
) -> DataArray:
    """Create a custom DataArray from data and coordinates.

    {cls.doc}

    Args:
        data: Values of the DataArray. Its shape must match class ``dims``.
            If class ``dtype`` is defined, it will be casted to that type.
            If it cannot be casted, a ``ValueError`` will be raised.
        name: Name of the DataArray. Default is class ``name``.
        attrs: Attributes of the DataArray. Default is class ``attrs``.
        **coords: Coordinates of the DataArray defined by the class.

    Returns:
        Custom DataArray.

    {cls.coords.doc}

    """
    # NOTE: the docstring above is a template ({cls...} placeholders), so it
    # is intentionally kept verbatim.
    dataarray = DataArray(data, dims=cls.dims, name=name, attrs=attrs)

    # Cast to the dtype declared on the class, if any.
    if cls.dtype is not None:
        dataarray = dataarray.astype(cls.dtype)

    # Fill every coordinate declared on the class. An explicit keyword argument
    # wins over a class-level attribute of the same name; with neither, fail.
    for coord_name, coord in cls.coords.items():
        coord_shape = [dataarray.sizes[dim] for dim in coord.dims]

        if coord_name in coords:
            value = coords[coord_name]
        elif hasattr(cls, coord_name):
            value = getattr(cls, coord_name)
        else:
            raise ValueError(
                f"Default value for a coordinate {coord_name} is not defined. "
                f"It must be given as a keyword argument ({coord_name}=<value>).")

        dataarray.coords[coord_name] = coord.full(coord_shape, value)

    return dataarray
Пример #8
0
    def __init__(
        self,
        ds: xr.DataArray,
        aggregate_dims: list,
    ):
        # Work on float64 data; only convert when the input is not already float64.
        if ds.dtype == np.float64:
            self._ds = ds
        else:
            self._ds = ds.astype(np.float64)
        # Re-attach the input's attrs (the original author observed that the
        # float64 cast drops them).
        self._ds.attrs = ds.attrs
        self._agg_dims = aggregate_dims

        # Fixed parameters.
        self._quantile = 0.5
        self._spre_tol = 1.0e-4

        # Array metrics: all start as None (presumably filled in on demand
        # elsewhere in the class).
        self._mean = None
        self._mean_abs = None
        self._mean_squared = None
        self._root_mean_squared = None
        self._std = None
        self._variance = None
        self._sum = None
        self._sum_squared = None
        self._min_val = None
        self._max_val = None
        self._min_abs = None
        self._max_abs = None
        self._d_range = None
        self._prob_positive = None
        self._prob_negative = None
        self._odds_positive = None
        self._zscore = None
        self._mae_max = None
        self._corr_lag1 = None
        self._lag1 = None
        self._ns_con_var = None
        self._ew_con_var = None
        self._quantile_value = None

        # Single-value metrics.
        self._zscore_cutoff = None
        self._zscore_percent_significant = None

        # Product of the sizes of ``aggregate_dims`` (1 when it is None).
        frame_size = 1
        if aggregate_dims is not None:
            for dim in aggregate_dims:
                frame_size *= int(self._ds.sizes[dim])
        self._frame_size = frame_size
Пример #9
0
def write_diff(foo: xr.DataArray, fn, nodata=-1, dtype=None, debug=False):
    """Write a 2-D DataArray to a single-band GeoTIFF at ``fn``.

    :param foo: 2-D array. Its first dim is written as raster rows and its
        second as columns. ``foo.crs`` and ``foo.transform`` supply the
        georeferencing; only the first six transform coefficients are used.
        ``foo.attrs`` are copied into the GeoTIFF tags.
    :param fn: output path.
    :param nodata: nodata value stored in the GeoTIFF header.
    :param dtype: output dtype; defaults to ``foo.dtype``.
    :param debug: if True, re-open the written file, print it and plot it.
    """
    dtype = dtype if dtype is not None else foo.dtype
    # The raster shape must match ``foo.values``, whose rows run along the
    # first dim (height) and columns along the second (width).
    height = foo.sizes[foo.dims[0]]
    width = foo.sizes[foo.dims[1]]
    with rasterio.open(
            fn,
            'w',
            driver='GTiff',
            height=height,
            width=width,
            count=1,
            dtype=dtype,
            crs=foo.crs,
            transform=foo.transform[:6],
            nodata=nodata,
    ) as dst:
        dst.write(foo.astype(dtype).values, 1)

        # Copy the array metadata into the GeoTIFF tags.
        dst.update_tags(**foo.attrs)

        # TODO: write a colormap (dst.write_colormap) so the output renders
        # without manual styling in e.g. QGIS (may not work with more than
        # 256 levels).
    # The ``with`` block closes the dataset; no explicit close() is needed.

    if debug:
        # Read back the file that was just written.
        tmp = xr.open_rasterio(fn).isel(band=0, drop=True).rename('tmp')
        print(tmp)
        # The figsize is just to avoid an issue with a particular jupyter instance.
        tmp.plot(figsize=(10, 10))
Пример #10
0
def reproject_array(
    arr: xr.DataArray,
    spec: RasterSpec,
    interpolation: Literal["linear", "nearest"] = "nearest",
    fill_value: Optional[Union[int, float]] = np.nan,
) -> xr.DataArray:
    """
    Reproject and clip a `~xarray.DataArray` to a new `~.RasterSpec` (CRS, resolution, bounds).

    This interpolates using `xarray.DataArray.interp`, which uses `scipy.interpolate.interpn` internally (no GDAL).
    It is somewhat dask-friendly, in that it at least doesn't trigger immediate computation on the array,
    but it's sub-optimal: the ``x`` and ``y`` dimensions are just merged into a single chunk, then interpolated.

    Since this both eliminates spatial parallelism, and potentially requires significant amounts of memory,
    `reproject_array` is only recommended on arrays with a relatively small number of spatial chunks.

    Warning
    -------
    This method is very slow on large arrays due to inefficiencies in generating the dask graphs.
    Additionally, all spatial chunking is lost.

    Parameters
    ----------
    arr:
        Array to reproject. It must have ``epsg``, ``x``, and ``y`` coordinates.
        The ``x`` and ``y`` coordinates are assumed to indicate the top-left corner
        of each pixel, not the center.
    spec:
        The `~.RasterSpec` to reproject to.
    interpolation:
        Interpolation method: ``"linear"`` or ``"nearest"``, default ``"nearest"``.
    fill_value:
        Fill output pixels that fall outside the bounds of ``arr`` with this value (default NaN).

    Returns
    -------
    xarray.DataArray:
        The clipped and reprojected array.
    """
    # TODO this scipy/`interp`-based approach still isn't block-parallel
    # (seems like xarray just rechunks to fuse all the spatial chunks first),
    # so this both won't scale, and can be crazy slow in dask graph construction
    # (and the rechunk probably eliminates any hope of sending an HLG to the scheduler).

    from_epsg = array_epsg(arr)
    # Fast path: nothing to do when CRS, bounds and spatial shape already match.
    # ``spec.shape`` is (height, width), so compare it with the trailing two
    # axes of ``arr`` (assumed to be (..., y, x) -- TODO confirm for all callers);
    # the leading axes are the non-spatial ones.
    if (
        from_epsg == spec.epsg
        and array_bounds(arr) == spec.bounds
        and tuple(arr.shape[-2:]) == tuple(spec.shape)
    ):
        return arr

    as_bool = False
    if arr.dtype.kind == "b":
        # `interp` can't handle boolean arrays
        arr = arr.astype("uint8")
        as_bool = True

    # TODO fastpath when there's no overlap? (graph shouldn't have any IO in it.)
    # Or does that already happen?

    # TODO support `xy_coords="center"`? `spec` assumes topleft;
    # if the x/y coords on the array are center, this will be a half
    # pixel off.
    minx, miny, maxx, maxy = spec.bounds
    height, width = spec.shape

    x = np.linspace(minx, maxx, width, endpoint=False)
    y = np.linspace(maxy, miny, height, endpoint=False)

    if from_epsg == spec.epsg:
        # Simpler case: just interpolate within the same CRS
        result = arr.interp(
            x=x, y=y, method=interpolation, kwargs=dict(fill_value=fill_value)
        )
        return result.astype(bool) if as_bool else result

    # Different CRSs: need to do a 2D interpolation.
    # We do this by, for each point in the output grid, generating
    # the coordinates in the _input_ CRS that correspond to that point.

    reverse_transformer = cached_transformer(
        spec.epsg, from_epsg, skip_equivalent=True, always_xy=True
    )

    xs, ys = np.meshgrid(x, y, copy=False)
    src_xs, src_ys = reverse_transformer.transform(xs, ys, errcheck=True)

    xs_indexer = xr.DataArray(src_xs, dims=["y", "x"], coords=dict(y=y, x=x))
    ys_indexer = xr.DataArray(src_ys, dims=["y", "x"], coords=dict(y=y, x=x))

    # TODO maybe just drop old dims instead?
    old_xdim = f"x_{from_epsg}"
    old_ydim = f"y_{from_epsg}"

    result = arr.rename(x=old_xdim, y=old_ydim).interp(
        {old_xdim: xs_indexer, old_ydim: ys_indexer},
        method=interpolation,
        kwargs=dict(fill_value=fill_value),
    )
    return result.astype(bool) if as_bool else result
Пример #11
0
def asarray(x, target, dims=None):
    """Convert ``x`` into an (at least 2-D) DataArray laid out for ``target``.

    ``x`` may be a dask DataFrame/Series, a dask array, an xarray DataArray or
    anything ``DataArray`` accepts. For dask frames and series the index (and
    for frames also the columns) become the coordinates of the first (and
    second) dim. The result is named ``target``, padded to 2 dims, renamed
    according to ``dims``, completed and transposed to
    ``CONF["data_dims"][target]``. Integer data is cast to float, and
    string-typed dimension coordinates are converted to object dtype.

    :param x: input data.
    :param target: key of ``CONF["targets"]``.
    :param dims: axis -> dim-name mapping, or a tuple/list of names by axis.
    :raises ValueError: for an unknown ``target`` or duplicated ``dims`` values.
    """
    from xarray import DataArray
    from limix._bits.dask import is_dataframe as is_dask_dataframe
    from limix._bits.dask import is_array as is_dask_array
    from limix._bits.dask import is_series as is_dask_series
    from limix._bits.dask import array_shape_reveal
    from limix._bits.xarray import is_dataarray
    from ._conf import CONF
    from numpy import issubdtype, integer

    if target not in CONF["targets"]:
        raise ValueError(f"Unknown target name: {target}.")

    import dask.array as da
    import xarray as xr

    if is_dask_dataframe(x) or is_dask_series(x):
        # Capture the frame-specific metadata *before* converting: once ``x``
        # is a dask array it has no columns and is no longer a DataFrame, so
        # checking afterwards silently dropped the column labels.
        is_frame = is_dask_dataframe(x)
        columns = x.columns if is_frame else None
        xidx = x.index.compute()
        x = da.asarray(x)
        x = array_shape_reveal(x)
        x0 = xr.DataArray(x)
        x0.coords[x0.dims[0]] = xidx
        if is_frame:
            x0.coords[x0.dims[1]] = columns
        x = x0
    elif is_dask_array(x):
        x = array_shape_reveal(x)
        x = xr.DataArray(x)

    if not is_dataarray(x):
        x = DataArray(x)

    # ``rename`` returns a shallow copy, so a DataArray passed in by the
    # caller does not get its ``name`` overwritten.
    x = x.rename(target)

    while x.ndim < 2:
        # Prefer a known trait dim that exists as a coordinate but not yet as
        # a dimension; otherwise use a generic "dim_<n>" name.
        rdims = set(CONF["data_dims"]["trait"]).intersection(
            set(x.coords.keys()))
        rdims = rdims - set(x.dims)
        if len(rdims) == 1:
            dim = rdims.pop()
        else:
            dim = "dim_{}".format(x.ndim)
        x = x.expand_dims(dim, x.ndim)

    if isinstance(dims, (tuple, list)):
        dims = dict(enumerate(dims))
    dims = _numbered_axes(dims)
    if len(set(dims.values())) < len(dims.values()):
        raise ValueError("`dims` must not contain duplicated values.")

    x = x.rename({x.dims[axis]: name for axis, name in dims.items()})
    x = _set_missing_dim(x, CONF["data_dims"][target])
    x = x.transpose(*CONF["data_dims"][target])

    if issubdtype(x.dtype, integer):
        x = x.astype(float)

    for dim in x.dims:
        if x.coords[dim].dtype.kind in {"U", "S"}:
            x.coords[dim].values = x.coords[dim].values.astype(object)

    return x
Пример #12
0
def run_bounds(mask: xr.DataArray,
               dim: str = "time",
               coord: Optional[Union[bool, str]] = True):
    """Return the start and end dates of boolean runs along a dimension.

    Parameters
    ----------
    mask : xr.DataArray
      Boolean array.
    dim : str
      Dimension along which to look for runs.
    coord : bool or str
      If True, return values of the coordinate, if a string, returns values from `dim.dt.<coord>`.
      if False, return indexes.

    Returns
    -------
    xr.DataArray
      With ``dim`` reduced to "events" and "bounds". The events dim is as long as needed, padded with NaN or NaT.
      The start is the first element of a run. The end is the first element
      *after* the run (exclusive), so a run that reaches the last element
      along ``dim`` has a NaN/NaT end.

    Raises
    ------
    NotImplementedError
      If ``mask`` is dask-backed.
    """
    if isinstance(mask.data, dsk.Array):
        raise NotImplementedError(
            "Dask arrays not supported as we can't know the final event number before computing."
        )

    # +1 where a run starts, -1 at the first element after a run, 0 elsewhere.
    # The first element is prepended as-is so a run starting at index 0 counts.
    diff = xr.concat((mask.isel({
        dim: 0
    }).astype(int), mask.astype(int).diff(dim)), dim)

    # Largest number of runs at any point; sets the length of "events".
    nstarts = (diff == 1).sum(dim).max().item()

    def _get_indices(arr, *, N):
        # Positions of True values, right-padded with NaN to length N.
        out = np.full((N, ), np.nan, dtype=float)
        inds = np.where(arr)[0]
        out[:len(inds)] = inds
        return out

    def _event_indices(condition):
        # Apply _get_indices along ``dim``, producing an "events" dim.
        return xr.apply_ufunc(
            _get_indices,
            condition,
            input_core_dims=[[dim]],
            output_core_dims=[["events"]],
            kwargs={"N": nstarts},
            vectorize=True,
        )

    starts = _event_indices(diff == 1)
    ends = _event_indices(diff == -1)

    if coord:
        crd = mask[dim]
        if isinstance(coord, str):
            crd = getattr(crd.dt, coord)

        # ``drop`` is deprecated in xarray; ``drop_vars`` is its replacement.
        starts = lazy_indexing(crd, starts).drop_vars(dim)
        ends = lazy_indexing(crd, ends).drop_vars(dim)
    return xr.concat((starts, ends), "bounds")
Пример #13
0
def first_run(
    da: xr.DataArray,
    window: int,
    dim: str = "time",
    coord: Optional[Union[str, bool]] = False,
    ufunc_1dim: Union[str, bool] = "auto",
):
    """Return the index of the first item of the first run of at least a given length.

    Parameters
    ----------
    da : xr.DataArray
      Input N-dimensional DataArray (boolean)
    window : int
      Minimum duration of consecutive run to accumulate values.
    dim : str
      Dimension along which to calculate consecutive run (default: 'time').
    coord : Optional[str]
      If not False, the function returns values along `dim` instead of indexes.
      If `dim` has a datetime dtype, `coord` can also be a str of the name of the
      DateTimeAccessor object to use (ex: 'dayofyear').
    ufunc_1dim : Union[str, bool]
      Use the 1d 'ufunc' version of this function : default (auto) will attempt to select optimal
      usage based on number of data points.  Using 1D_ufunc=True is typically more efficient
      for dataarray with a small number of gridpoints.

    Returns
    -------
    out : xr.DataArray
      Index (or coordinate if `coord` is not False) of first item in first valid run. Returns np.nan if there are no valid run.
    """
    if ufunc_1dim == "auto":
        # Multi-chunk dask along ``dim`` cannot use the 1-D ufunc path.
        if isinstance(da.data,
                      dsk.Array) and len(da.chunks[da.dims.index(dim)]) > 1:
            ufunc_1dim = False
        else:
            npts = get_npts(da)
            ufunc_1dim = npts <= npts_opt

    # We expect a boolean array, but there could be NaNs nonetheless.
    da = da.fillna(0)

    if ufunc_1dim:
        out = first_run_ufunc(x=da, window=window, dim=dim)

    else:
        da = da.astype("int")
        # Integer position along ``dim``, broadcast to the full array shape.
        i = xr.DataArray(np.arange(da[dim].size), dims=dim)
        ind = xr.broadcast(i, da)[0].transpose(*da.dims)
        if isinstance(da.data, dsk.Array):
            ind = ind.chunk(da.chunks)
        # Roll along ``dim`` rather than a hard-coded "time", so that the
        # ``dim`` argument is honoured. ``allow_lazy`` is not passed: it was
        # deprecated in xarray reductions (laziness is now the default).
        wind_sum = da.rolling(**{dim: window}).sum(skipna=False)
        # The rolling result is labelled at the last element of each window;
        # subtract ``window - 1`` to get the index of its first element.
        out = ind.where(wind_sum >= window).min(dim=dim) - (window - 1)

    if coord:
        crd = da[dim]
        if isinstance(coord, str):
            crd = getattr(crd.dt, coord)

        out = lazy_indexing(crd, out)

    # ``dim in out`` would test the array *values*; check the coordinates.
    if dim in out.coords:
        out = out.drop_vars(dim)

    return out
Пример #14
0
def splev(
    x_new: xarray.DataArray, tck: xarray.Dataset, extrapolate: bool | str = True
) -> xarray.DataArray:
    """Evaluate the B-spline generated with :func:`splrep`.

    :param x_new:
        Any :class:`~xarray.DataArray` with any number of dims, not necessarily
        the original interpolation dim.
        Alternatively, it can be any 1-dimensional array-like; it will be
        automatically converted to a :class:`~xarray.DataArray` on the
        interpolation dim.

    :param xarray.Dataset tck:
        As returned by :func:`splrep`.
        It can have been:

        - transposed (not recommended, as performance will
          drop if c is not C-contiguous)
        - sliced, reordered, or (re)chunked, on any
          dim except the interpolation dim
        - computed from dask to numpy backend
        - round-tripped to disk

    :param extrapolate:
        True
            Extrapolate the first and last polynomial pieces of b-spline
            functions active on the base interval
        False
            Return NaNs outside of the base interval
        'periodic'
            Periodic extrapolation is used
        'clip'
            Return y[0] and y[-1] outside of the base interval

    :returns:
        :class:`~xarray.DataArray` with all dims of the interpolated array,
        minus the interpolation dim, plus all dims of x_new

    :raises ValueError:
        if a non-DataArray x_new has more than 1 dimension, if x_new shares a
        dim (other than the interpolation dim) with the coefficients, or if
        the knot vector length does not match the coefficients.
    :raises NotImplementedError:
        if t or c are chunked along the interpolation dim.

    See :func:`splrep` for usage example.
    """
    # Pre-process x_new into a DataArray
    if not isinstance(x_new, xarray.DataArray):
        if not isinstance(x_new, dask_array_type):
            x_new = np.array(x_new)
        # 0-d input becomes a scalar; 1-d input is placed on the spline dim.
        if x_new.ndim == 0:
            dims = []
        elif x_new.ndim == 1:
            dims = [tck.spline_dim]
        else:
            raise ValueError(
                "N-dimensional x_new is only supported if " "x_new is a DataArray"
            )
        x_new = xarray.DataArray(x_new, dims=dims, coords={tck.spline_dim: x_new})

    dim = tck.spline_dim
    t = tck.t
    c = tck.c
    k = tck.k

    # `-` binds tighter than `&`: this is x_new.dims & (c.dims - {dim}),
    # i.e. non-interpolation dims that both arrays share.
    invalid_dims = {*x_new.dims} & {*c.dims} - {dim}
    if invalid_dims:
        raise ValueError(
            "Overlapping dims between interpolated "
            "array and x_new: %s" % ",".join(str(d) for d in invalid_dims)
        )

    # The knot vector must have c.sizes[dim] + k + 1 entries; any other
    # length means tck was sliced along the interpolation dim.
    if t.shape != (c.sizes[dim] + k + 1,):
        raise ValueError("Interpolated dimension has been sliced")

    if x_new.dtype.kind == "M":
        # Note that we're modifying the x_new values, not the x_new coords
        # xarray datetime objects are always in ns
        x_new = x_new.astype(float)

    if extrapolate == "clip":
        # Clamp x_new to the original sample range, then evaluate without
        # extrapolation. Datetime bounds are converted to float ns to match
        # x_new above.
        x = tck.coords[dim].values
        if x.dtype.kind == "M":
            x = x.astype("M8[ns]").astype(float)
        x_new = np.clip(x_new, x[0].tolist(), x[-1].tolist())
        extrapolate = False

    # The kernels expect the interpolation dim first on c.
    if c.dims[0] != dim:
        c = c.transpose(dim, *[d for d in c.dims if d != dim])

    if any(isinstance(v.data, dask_array_type) for v in (x_new, t, c)):
        # Each block of t / c must hold the whole interpolation dim.
        if t.chunks and len(t.chunks[0]) > 1:
            raise NotImplementedError(
                "Unsupported: multiple chunks on interpolation dim"
            )
        if c.chunks and len(c.chunks[0]) > 1:
            raise NotImplementedError(
                "Unsupported: multiple chunks on interpolation dim"
            )

        # omitting t and c
        # One index letter per x_new dim and per non-interpolation dim of c.
        # 'c' and 't' are excluded from these pools because they label the
        # coefficient and knot axes below (at most 12 dims each).
        x_new_axes = "abdefghijklm"[: x_new.ndim]
        c_axes = "nopqrsuvwxyz"[: c.ndim - 1]

        # The 't' and 'c' axes do not appear in the output index, so they are
        # contracted; concatenate=True passes them to the kernel as whole arrays.
        y_new = da.blockwise(
            kernels.splev,
            x_new_axes + c_axes,
            x_new.data,
            x_new_axes,
            t.data,
            "t",
            c.data,
            "c" + c_axes,
            k=k,
            extrapolate=extrapolate,
            concatenate=True,
            dtype=float,
        )
    else:
        y_new = kernels.splev(
            x_new.values, t.values, c.values, k, extrapolate=extrapolate
        )

    # Output dims: x_new's dims followed by c's non-interpolation dims.
    y_new = xarray.DataArray(y_new, dims=x_new.dims + c.dims[1:], coords=x_new.coords)
    # Carry over c's coords that do not depend on the interpolation dim
    # (k and c inside the comprehension are local to it).
    y_new.coords.update({k: c for k, c in c.coords.items() if dim not in c.dims})
    return y_new
Пример #15
0
class TestXarrayBilinear(unittest.TestCase):
    """Test Xarray/Dask -based bilinear interpolation."""
    def setUp(self):
        """Do some setup for common things."""
        import dask.array as da
        from xarray import DataArray
        from pyresample import geometry, kd_tree

        # Corner point sets (pt_1, pt_2, pt_3, pt_4) used by the t/s solver
        # tests: an irregular quadrilateral, one whose left and right sides
        # are both vertical (parallel), and a square where both pairs of
        # opposite sides are parallel.
        self.pts_irregular = (np.array([
            [-1., 1.],
        ]), np.array([
            [1., 2.],
        ]), np.array([
            [-2., -1.],
        ]), np.array([
            [2., -4.],
        ]))
        self.pts_vert_parallel = (np.array([
            [-1., 1.],
        ]), np.array([
            [1., 2.],
        ]), np.array([
            [-1., -1.],
        ]), np.array([
            [1., -2.],
        ]))
        self.pts_both_parallel = (np.array([
            [-1., 1.],
        ]), np.array([
            [1., 1.],
        ]), np.array([
            [-1., -1.],
        ]), np.array([
            [1., -1.],
        ]))

        # Area definition with 4 x 4 pixels
        self.target_def = geometry.AreaDefinition(
            'areaD', 'Europe (3km, HRV, VTC)', 'areaD', {
                'a': '6378144.0',
                'b': '6356759.0',
                'lat_0': '50.00',
                'lat_ts': '50.00',
                'lon_0': '8.00',
                'proj': 'stere'
            }, 4, 4,
            [-1370912.72, -909968.64000000001, 1029087.28, 1490031.3600000001])

        # Input data around the target pixel at 0.63388324, 55.08234642
        in_shape = (100, 100)
        self.data1 = DataArray(da.ones((in_shape[0], in_shape[1])),
                               dims=('y', 'x'))
        self.data2 = 2. * self.data1
        self.data3 = self.data1 + 9.5
        lons, lats = np.meshgrid(np.linspace(-25., 40., num=in_shape[0]),
                                 np.linspace(45., 75., num=in_shape[1]))
        self.source_def = geometry.SwathDefinition(lons=lons, lats=lats)

        self.radius = 50e3
        self.neighbours = 32
        valid_input_index, output_idxs, index_array, dists = \
            kd_tree.get_neighbour_info(self.source_def, self.target_def,
                                       self.radius, neighbours=self.neighbours,
                                       nprocs=1)
        # Indices equal to the number of valid inputs presumably mark
        # missing neighbours (kd_tree fill value); replace them with 0 so
        # every entry is a usable index.
        input_size = valid_input_index.sum()
        index_mask = (index_array == input_size)
        index_array = np.where(index_mask, 0, index_array)

        self.valid_input_index = valid_input_index
        self.index_array = index_array

        # Column/line index grids of the source swath, used as the mocked
        # np.meshgrid return value in test_get_slices
        shp = self.source_def.shape
        self.cols, self.lines = np.meshgrid(np.arange(shp[1]),
                                            np.arange(shp[0]))

    def test_init(self):
        """Test that the resampler has been initialized correctly."""
        from pyresample.bilinear.xarr import XArrayResamplerBilinear

        # With defaults
        resampler = XArrayResamplerBilinear(self.source_def, self.target_def,
                                            self.radius)
        self.assertTrue(resampler.source_geo_def == self.source_def)
        self.assertTrue(resampler.target_geo_def == self.target_def)
        self.assertEqual(resampler.radius_of_influence, self.radius)
        self.assertEqual(resampler.neighbours, 32)
        self.assertEqual(resampler.epsilon, 0)
        self.assertTrue(resampler.reduce_data)
        # These should be None
        self.assertIsNone(resampler.valid_input_index)
        self.assertIsNone(resampler.valid_output_index)
        self.assertIsNone(resampler.index_array)
        self.assertIsNone(resampler.distance_array)
        self.assertIsNone(resampler.bilinear_t)
        self.assertIsNone(resampler.bilinear_s)
        self.assertIsNone(resampler.slices_x)
        self.assertIsNone(resampler.slices_y)
        self.assertIsNone(resampler.mask_slices)
        self.assertIsNone(resampler.out_coords_x)
        self.assertIsNone(resampler.out_coords_y)
        # self.slices_{x,y} are used in self.slices dict
        self.assertIs(resampler.slices['x'], resampler.slices_x)
        self.assertIs(resampler.slices['y'], resampler.slices_y)
        # self.out_coords_{x,y} are used in self.out_coords dict
        self.assertIs(resampler.out_coords['x'], resampler.out_coords_x)
        self.assertIs(resampler.out_coords['y'], resampler.out_coords_y)

        # Override defaults
        resampler = XArrayResamplerBilinear(self.source_def,
                                            self.target_def,
                                            self.radius,
                                            neighbours=16,
                                            epsilon=0.1,
                                            reduce_data=False)
        self.assertEqual(resampler.neighbours, 16)
        self.assertEqual(resampler.epsilon, 0.1)
        self.assertFalse(resampler.reduce_data)

    def test_get_bil_info(self):
        """Test calculation of bilinear info."""
        from pyresample.bilinear.xarr import XArrayResamplerBilinear

        def _check_ts(t__, s__, nans):
            """Check pixel 5 exactly, NaN at ``nans``, others in [0, 1]."""
            for i, _ in enumerate(t__):
                # Just check the exact value for one pixel
                if i == 5:
                    self.assertAlmostEqual(t__[i], 0.730659147133, 5)
                    self.assertAlmostEqual(s__[i], 0.310314173004, 5)
                # These pixels are outside the area
                elif i in nans:
                    self.assertTrue(np.isnan(t__[i]))
                    self.assertTrue(np.isnan(s__[i]))
                # All the others should have values between 0.0 and 1.0
                else:
                    self.assertTrue(t__[i] >= 0.0)
                    self.assertTrue(s__[i] >= 0.0)
                    self.assertTrue(t__[i] <= 1.0)
                    self.assertTrue(s__[i] <= 1.0)

        # Data reduction enabled (default)
        resampler = XArrayResamplerBilinear(self.source_def,
                                            self.target_def,
                                            self.radius,
                                            reduce_data=True)
        (t__, s__, slices, mask_slices, out_coords) = resampler.get_bil_info()
        _check_ts(t__.compute(), s__.compute(), [3, 10, 12, 13, 14, 15])

        # Nothing should be masked based on coordinates
        self.assertTrue(np.all(~mask_slices))
        # Four values per output location
        self.assertEqual(mask_slices.shape, (self.target_def.size, 4))

        # self.slices_{x,y} are used in self.slices dict so they
        # should be the same (object)
        self.assertIsInstance(slices, dict)
        self.assertIs(resampler.slices['x'], resampler.slices_x)
        self.assertTrue(np.all(resampler.slices['x'] == slices['x']))
        self.assertIs(resampler.slices['y'], resampler.slices_y)
        self.assertTrue(np.all(resampler.slices['y'] == slices['y']))

        # self.out_coords_{x,y} are used in self.out_coords dict so they
        # should be the same (object)
        self.assertIsInstance(out_coords, dict)
        self.assertIs(resampler.out_coords['x'], resampler.out_coords_x)
        self.assertTrue(np.all(resampler.out_coords['x'] == out_coords['x']))
        self.assertIs(resampler.out_coords['y'], resampler.out_coords_y)
        self.assertTrue(np.all(resampler.out_coords['y'] == out_coords['y']))

        # Also some other attributes should have been set
        self.assertIs(t__, resampler.bilinear_t)
        self.assertIs(s__, resampler.bilinear_s)
        self.assertIsNotNone(resampler.valid_output_index)
        self.assertIsNotNone(resampler.index_array)
        self.assertIsNotNone(resampler.valid_input_index)

        # Data reduction disabled
        resampler = XArrayResamplerBilinear(self.source_def,
                                            self.target_def,
                                            self.radius,
                                            reduce_data=False)
        (t__, s__, slices, mask_slices, out_coords) = resampler.get_bil_info()
        _check_ts(t__.compute(), s__.compute(), [10, 12, 13, 14, 15])

    def test_get_sample_from_bil_info(self):
        """Test bilinear interpolation as a whole."""
        from pyresample.bilinear.xarr import XArrayResamplerBilinear

        resampler = XArrayResamplerBilinear(self.source_def, self.target_def,
                                            self.radius)
        _ = resampler.get_bil_info()

        # Sample from data1
        res = resampler.get_sample_from_bil_info(self.data1)
        res = res.compute()
        # Check couple of values
        self.assertEqual(res.values[1, 1], 1.)
        self.assertTrue(np.isnan(res.values[0, 3]))
        # Check that the values haven't gone down or up a lot
        self.assertAlmostEqual(np.nanmin(res.values), 1.)
        self.assertAlmostEqual(np.nanmax(res.values), 1.)
        # Check that dimensions are the same
        self.assertEqual(res.dims, self.data1.dims)

        # Sample from data1, custom fill value
        res = resampler.get_sample_from_bil_info(self.data1, fill_value=-1.0)
        res = res.compute()
        self.assertEqual(np.nanmin(res.values), -1.)

        # Sample from integer data
        res = resampler.get_sample_from_bil_info(self.data1.astype(np.uint8),
                                                 fill_value=None)
        res = res.compute()
        # Six values should be filled with zeros, which is the
        # default fill_value for integer data
        self.assertEqual(np.sum(res == 0), 6)

    @mock.patch('pyresample.bilinear.xarr.setattr')
    def test_compute_indices(self, mock_setattr):
        """Test running .compute() for indices."""
        from pyresample.bilinear.xarr import (XArrayResamplerBilinear,
                                              CACHE_INDICES)

        resampler = XArrayResamplerBilinear(self.source_def, self.target_def,
                                            self.radius)

        # Set indices to Numpy arrays
        for idx in CACHE_INDICES:
            setattr(resampler, idx, np.array([]))
        resampler._compute_indices()
        # None of the indices should have been reassigned
        mock_setattr.assert_not_called()

        # Set indices to a Mock object
        arr = mock.MagicMock()
        for idx in CACHE_INDICES:
            setattr(resampler, idx, arr)
        resampler._compute_indices()
        # All the indices should have been reassigned
        self.assertEqual(mock_setattr.call_count, len(CACHE_INDICES))
        # The compute should have been called the same amount of times
        self.assertEqual(arr.compute.call_count, len(CACHE_INDICES))

    def test_add_missing_coordinates(self):
        """Test coordinate updating."""
        import dask.array as da
        from xarray import DataArray
        from pyresample.bilinear.xarr import XArrayResamplerBilinear

        resampler = XArrayResamplerBilinear(self.source_def, self.target_def,
                                            self.radius)
        bands = ['R', 'G', 'B']
        data = DataArray(da.ones((3, 10, 10)),
                         dims=('bands', 'y', 'x'),
                         coords={
                             'bands': bands,
                             'y': np.arange(10),
                             'x': np.arange(10)
                         })
        resampler._add_missing_coordinates(data)
        # X and Y coordinates should not change
        self.assertIsNone(resampler.out_coords_x)
        self.assertIsNone(resampler.out_coords_y)
        self.assertIsNone(resampler.out_coords['x'])
        self.assertIsNone(resampler.out_coords['y'])
        self.assertIn('bands', resampler.out_coords)
        self.assertTrue(np.all(resampler.out_coords['bands'] == bands))

    def test_slice_data(self):
        """Test slicing the data."""
        import dask.array as da
        from xarray import DataArray
        from pyresample.bilinear.xarr import XArrayResamplerBilinear

        resampler = XArrayResamplerBilinear(self.source_def, self.target_def,
                                            self.radius)

        # Too many dimensions
        data = DataArray(da.ones((1, 3, 10, 10)))
        with self.assertRaises(ValueError):
            _ = resampler._slice_data(data, np.nan)

        # 2D data
        data = DataArray(da.ones((10, 10)))
        resampler.slices_x = np.random.randint(0, 10, (100, 4))
        resampler.slices_y = np.random.randint(0, 10, (100, 4))
        # np.bool was removed in NumPy 1.24; np.bool_ works on all versions
        resampler.mask_slices = np.zeros((100, 4), dtype=np.bool_)
        p_1, p_2, p_3, p_4 = resampler._slice_data(data, np.nan)
        self.assertEqual(p_1.shape, (100, ))
        self.assertTrue(p_1.shape == p_2.shape == p_3.shape == p_4.shape)
        self.assertTrue(
            np.all(p_1 == 1.0) and np.all(p_2 == 1.0) and np.all(p_3 == 1.0)
            and np.all(p_4 == 1.0))

        # 2D data with masking
        resampler.mask_slices = np.ones((100, 4), dtype=np.bool_)
        p_1, p_2, p_3, p_4 = resampler._slice_data(data, np.nan)
        self.assertTrue(
            np.all(np.isnan(p_1)) and np.all(np.isnan(p_2))
            and np.all(np.isnan(p_3)) and np.all(np.isnan(p_4)))

        # 3D data
        data = DataArray(da.ones((3, 10, 10)))
        resampler.slices_x = np.random.randint(0, 10, (100, 4))
        resampler.slices_y = np.random.randint(0, 10, (100, 4))
        resampler.mask_slices = np.zeros((100, 4), dtype=np.bool_)
        p_1, p_2, p_3, p_4 = resampler._slice_data(data, np.nan)
        self.assertEqual(p_1.shape, (3, 100))
        self.assertTrue(p_1.shape == p_2.shape == p_3.shape == p_4.shape)

        # 3D data with masking
        resampler.mask_slices = np.ones((100, 4), dtype=np.bool_)
        p_1, p_2, p_3, p_4 = resampler._slice_data(data, np.nan)
        self.assertTrue(
            np.all(np.isnan(p_1)) and np.all(np.isnan(p_2))
            and np.all(np.isnan(p_3)) and np.all(np.isnan(p_4)))

    @mock.patch('pyresample.bilinear.xarr.np.meshgrid')
    def test_get_slices(self, meshgrid):
        """Test slice array creation."""
        from pyresample.bilinear.xarr import XArrayResamplerBilinear

        meshgrid.return_value = (self.cols, self.lines)

        resampler = XArrayResamplerBilinear(self.source_def, self.target_def,
                                            self.radius)
        resampler.valid_input_index = self.valid_input_index
        resampler.index_array = self.index_array

        resampler._get_slices()
        self.assertIsNotNone(resampler.out_coords_x)
        self.assertIsNotNone(resampler.out_coords_y)
        self.assertIs(resampler.out_coords_x, resampler.out_coords['x'])
        self.assertIs(resampler.out_coords_y, resampler.out_coords['y'])
        self.assertTrue(
            np.allclose(resampler.out_coords_x,
                        [-1070912.72, -470912.72, 129087.28, 729087.28]))
        self.assertTrue(
            np.allclose(resampler.out_coords_y,
                        [1190031.36, 590031.36, -9968.64, -609968.64]))

        self.assertIsNotNone(resampler.slices_x)
        self.assertIsNotNone(resampler.slices_y)
        self.assertIs(resampler.slices_x, resampler.slices['x'])
        self.assertIs(resampler.slices_y, resampler.slices['y'])
        self.assertTrue(resampler.slices_x.shape == (self.target_def.size, 32))
        self.assertTrue(resampler.slices_y.shape == (self.target_def.size, 32))
        self.assertEqual(np.sum(resampler.slices_x), 12471)
        self.assertEqual(np.sum(resampler.slices_y), 2223)

        self.assertFalse(np.any(resampler.mask_slices))

        # Ensure that source geo def is used in masking
        # Setting target_geo_def to 0-size shouldn't cause any masked values
        resampler.target_geo_def = np.array([])
        resampler._get_slices()
        self.assertFalse(np.any(resampler.mask_slices))
        # Setting source area def to 0-size should mask all values
        resampler.source_geo_def = np.array([[]])
        resampler._get_slices()
        self.assertTrue(np.all(resampler.mask_slices))

    @mock.patch('pyresample.bilinear.xarr.KDTree')
    def test_create_resample_kdtree(self, KDTree):
        """Test that KDTree creation is called."""
        from pyresample.bilinear.xarr import XArrayResamplerBilinear

        resampler = XArrayResamplerBilinear(self.source_def, self.target_def,
                                            self.radius)

        vii, kdtree = resampler._create_resample_kdtree()
        self.assertEqual(np.sum(vii), 2700)
        self.assertEqual(vii.size, self.source_def.size)
        KDTree.assert_called_once()

    @mock.patch('pyresample.bilinear.xarr.query_no_distance')
    def test_query_resample_kdtree(self, qnd):
        """Test that query_no_distance is called in _query_resample_kdtree()."""
        from pyresample.bilinear.xarr import XArrayResamplerBilinear

        resampler = XArrayResamplerBilinear(self.source_def, self.target_def,
                                            self.radius)
        # Only the call arguments are checked; unpacking still verifies
        # that a two-item result is returned
        _res, _none = resampler._query_resample_kdtree(1, 2, 3, 4,
                                                       reduce_data=5)
        qnd.assert_called_with(2, 3, 4, 1, resampler.neighbours,
                               resampler.epsilon,
                               resampler.radius_of_influence)

    def test_get_input_xy_dask(self):
        """Test computation of input X and Y coordinates in target proj."""
        import dask.array as da
        from pyresample.bilinear.xarr import _get_input_xy_dask
        from pyresample._spatial_mp import Proj

        proj = Proj(self.target_def.proj_str)
        in_x, in_y = _get_input_xy_dask(self.source_def, proj,
                                        da.from_array(self.valid_input_index),
                                        da.from_array(self.index_array))

        # assertTrue(shape, expected) would treat the tuple as the failure
        # message and always pass; compare the shapes explicitly.
        self.assertEqual(in_x.shape, (self.target_def.size, 32))
        self.assertEqual(in_y.shape, (self.target_def.size, 32))
        self.assertTrue(in_x.all())
        self.assertTrue(in_y.all())

    def test_mask_coordinates_dask(self):
        """Test masking of invalid coordinates."""
        import dask.array as da
        from pyresample.bilinear.xarr import _mask_coordinates_dask

        lons, lats = _mask_coordinates_dask(
            da.from_array([-200., 0., 0., 0., 200.]),
            da.from_array([0., -100., 0, 100., 0.]))
        lons, lats = da.compute(lons, lats)
        self.assertTrue(lons[2] == lats[2] == 0.0)
        self.assertEqual(np.sum(np.isnan(lons)), 4)
        self.assertEqual(np.sum(np.isnan(lats)), 4)

    def test_get_bounding_corners_dask(self):
        """Test finding surrounding bounding corners."""
        import dask.array as da
        from pyresample.bilinear.xarr import (_get_input_xy_dask,
                                              _get_bounding_corners_dask)
        from pyresample._spatial_mp import Proj
        from pyresample import CHUNK_SIZE

        proj = Proj(self.target_def.proj_str)
        out_x, out_y = self.target_def.get_proj_coords(chunks=CHUNK_SIZE)
        out_x = da.ravel(out_x)
        out_y = da.ravel(out_y)
        in_x, in_y = _get_input_xy_dask(self.source_def, proj,
                                        da.from_array(self.valid_input_index),
                                        da.from_array(self.index_array))
        pt_1, pt_2, pt_3, pt_4, ia_ = _get_bounding_corners_dask(
            in_x, in_y, out_x, out_y, self.neighbours,
            da.from_array(self.index_array))

        self.assertTrue(pt_1.shape == pt_2.shape == pt_3.shape == pt_4.shape ==
                        (self.target_def.size, 2))
        self.assertTrue(ia_.shape == (self.target_def.size, 4))

        # Check which of the locations has four valid X/Y pairs by
        # finding where there are non-NaN values
        res = da.sum(pt_1 + pt_2 + pt_3 + pt_4, axis=1).compute()
        self.assertEqual(np.sum(~np.isnan(res)), 10)

    def test_get_corner_dask(self):
        """Test finding the closest corners."""
        import dask.array as da
        from pyresample.bilinear.xarr import (_get_corner_dask,
                                              _get_input_xy_dask)
        from pyresample import CHUNK_SIZE
        from pyresample._spatial_mp import Proj

        proj = Proj(self.target_def.proj_str)
        in_x, in_y = _get_input_xy_dask(self.source_def, proj,
                                        da.from_array(self.valid_input_index),
                                        da.from_array(self.index_array))
        out_x, out_y = self.target_def.get_proj_coords(chunks=CHUNK_SIZE)
        out_x = da.ravel(out_x)
        out_y = da.ravel(out_y)

        # Some copy&paste from the code to get the input
        out_x_tile = np.reshape(np.tile(out_x, self.neighbours),
                                (self.neighbours, out_x.size)).T
        out_y_tile = np.reshape(np.tile(out_y, self.neighbours),
                                (self.neighbours, out_y.size)).T
        x_diff = out_x_tile - in_x
        y_diff = out_y_tile - in_y
        stride = np.arange(x_diff.shape[0])

        # Use lower left source pixels for testing
        valid = (x_diff > 0) & (y_diff > 0)
        x_3, y_3, idx_3 = _get_corner_dask(stride, valid, in_x, in_y,
                                           da.from_array(self.index_array))

        self.assertTrue(
            x_3.shape == y_3.shape == idx_3.shape == (self.target_def.size, ))
        # Four locations have no data to the lower left of them (the
        # bottom row of the area)
        self.assertEqual(np.sum(np.isnan(x_3.compute())), 4)

    @mock.patch('pyresample.bilinear.xarr._get_ts_parallellogram_dask')
    @mock.patch('pyresample.bilinear.xarr._get_ts_uprights_parallel_dask')
    @mock.patch('pyresample.bilinear.xarr._get_ts_irregular_dask')
    def test_get_ts_dask(self, irregular, uprights, parallellogram):
        """Test that the three separate functions are called."""
        from pyresample.bilinear.xarr import _get_ts_dask

        # All valid values
        t_irr = np.array([0.1, 0.2, 0.3])
        s_irr = np.array([0.1, 0.2, 0.3])
        irregular.return_value = (t_irr, s_irr)
        t__, s__ = _get_ts_dask(1, 2, 3, 4, 5, 6)
        irregular.assert_called_once()
        uprights.assert_not_called()
        parallellogram.assert_not_called()
        self.assertTrue(np.allclose(t__.compute(), t_irr))
        self.assertTrue(np.allclose(s__.compute(), s_irr))

        # NaN in the first step, good value for that location from the
        # second step
        t_irr = np.array([0.1, 0.2, np.nan])
        s_irr = np.array([0.1, 0.2, np.nan])
        irregular.return_value = (t_irr, s_irr)
        t_upr = np.array([3, 3, 0.3])
        s_upr = np.array([3, 3, 0.3])
        uprights.return_value = (t_upr, s_upr)
        t__, s__ = _get_ts_dask(1, 2, 3, 4, 5, 6)
        self.assertEqual(irregular.call_count, 2)
        uprights.assert_called_once()
        parallellogram.assert_not_called()
        # Only the last value of the first step should have been replaced
        t_res = np.array([0.1, 0.2, 0.3])
        s_res = np.array([0.1, 0.2, 0.3])
        self.assertTrue(np.allclose(t__.compute(), t_res))
        self.assertTrue(np.allclose(s__.compute(), s_res))

        # Two NaNs in the first step, one of which are found by the
        # second, and the last bad value is replaced by the third step
        t_irr = np.array([0.1, np.nan, np.nan])
        s_irr = np.array([0.1, np.nan, np.nan])
        irregular.return_value = (t_irr, s_irr)
        t_upr = np.array([3, np.nan, 0.3])
        s_upr = np.array([3, np.nan, 0.3])
        uprights.return_value = (t_upr, s_upr)
        t_par = np.array([4, 0.2, 0.3])
        s_par = np.array([4, 0.2, 0.3])
        parallellogram.return_value = (t_par, s_par)
        t__, s__ = _get_ts_dask(1, 2, 3, 4, 5, 6)
        self.assertEqual(irregular.call_count, 3)
        self.assertEqual(uprights.call_count, 2)
        parallellogram.assert_called_once()
        # Only the last two values should have been replaced
        t_res = np.array([0.1, 0.2, 0.3])
        s_res = np.array([0.1, 0.2, 0.3])
        self.assertTrue(np.allclose(t__.compute(), t_res))
        self.assertTrue(np.allclose(s__.compute(), s_res))

        # Too large and small values should be set to NaN
        t_irr = np.array([1.00001, -0.00001, 1e6])
        s_irr = np.array([1.00001, -0.00001, -1e6])
        irregular.return_value = (t_irr, s_irr)
        # Second step also returns invalid values
        t_upr = np.array([1.00001, 0.2, np.nan])
        s_upr = np.array([-0.00001, 0.2, np.nan])
        uprights.return_value = (t_upr, s_upr)
        # Third step has one new valid value, the last will stay invalid
        t_par = np.array([0.1, 0.2, 4.0])
        s_par = np.array([0.1, 0.2, 4.0])
        parallellogram.return_value = (t_par, s_par)
        t__, s__ = _get_ts_dask(1, 2, 3, 4, 5, 6)

        t_res = np.array([0.1, 0.2, np.nan])
        s_res = np.array([0.1, 0.2, np.nan])
        self.assertTrue(np.allclose(t__.compute(), t_res, equal_nan=True))
        self.assertTrue(np.allclose(s__.compute(), s_res, equal_nan=True))

    def test_get_ts_irregular_dask(self):
        """Test calculations for irregular corner locations."""
        from pyresample.bilinear.xarr import _get_ts_irregular_dask

        res = _get_ts_irregular_dask(self.pts_irregular[0],
                                     self.pts_irregular[1],
                                     self.pts_irregular[2],
                                     self.pts_irregular[3], 0., 0.)
        self.assertEqual(res[0], 0.375)
        self.assertEqual(res[1], 0.5)
        res = _get_ts_irregular_dask(self.pts_vert_parallel[0],
                                     self.pts_vert_parallel[1],
                                     self.pts_vert_parallel[2],
                                     self.pts_vert_parallel[3], 0., 0.)
        self.assertTrue(np.isnan(res[0]))
        self.assertTrue(np.isnan(res[1]))

    def test_get_ts_uprights_parallel(self):
        """Test calculation when uprights are parallel."""
        from pyresample.bilinear import _get_ts_uprights_parallel

        res = _get_ts_uprights_parallel(self.pts_vert_parallel[0],
                                        self.pts_vert_parallel[1],
                                        self.pts_vert_parallel[2],
                                        self.pts_vert_parallel[3], 0., 0.)
        self.assertEqual(res[0], 0.5)
        self.assertEqual(res[1], 0.5)

    def test_get_ts_parallellogram(self):
        """Test calculation when the corners form a parallellogram."""
        from pyresample.bilinear import _get_ts_parallellogram

        res = _get_ts_parallellogram(self.pts_both_parallel[0],
                                     self.pts_both_parallel[1],
                                     self.pts_both_parallel[2], 0., 0.)
        self.assertEqual(res[0], 0.5)
        self.assertEqual(res[1], 0.5)

    def test_calc_abc(self):
        """Test calculation of quadratic coefficients."""
        from pyresample.bilinear.xarr import _calc_abc_dask

        # No np.nan inputs
        pt_1, pt_2, pt_3, pt_4 = self.pts_irregular
        res = _calc_abc_dask(pt_1, pt_2, pt_3, pt_4, 0.0, 0.0)
        self.assertFalse(np.isnan(res[0]))
        self.assertFalse(np.isnan(res[1]))
        self.assertFalse(np.isnan(res[2]))
        # np.nan input -> np.nan output
        res = _calc_abc_dask(np.array([[np.nan, np.nan]]), pt_2, pt_3, pt_4,
                             0.0, 0.0)
        self.assertTrue(np.isnan(res[0]))
        self.assertTrue(np.isnan(res[1]))
        self.assertTrue(np.isnan(res[2]))

    def test_solve_quadratic(self):
        """Test solving quadratic equation."""
        from pyresample.bilinear.xarr import (_solve_quadratic_dask,
                                              _calc_abc_dask)

        res = _solve_quadratic_dask(1, 0, 0).compute()
        self.assertEqual(res, 0.0)
        res = _solve_quadratic_dask(1, 2, 1).compute()
        self.assertTrue(np.isnan(res))
        res = _solve_quadratic_dask(1, 2, 1, min_val=-2.).compute()
        self.assertEqual(res, -1.0)
        # Test that small adjustments work
        pt_1, pt_2, pt_3, pt_4 = self.pts_vert_parallel
        # Copy so that the in-place nudge below does not modify the fixture
        pt_1 = self.pts_vert_parallel[0].copy()
        pt_1[0][0] += 1e-7
        res = _calc_abc_dask(pt_1, pt_2, pt_3, pt_4, 0.0, 0.0)
        res = _solve_quadratic_dask(res[0], res[1], res[2]).compute()
        self.assertAlmostEqual(res[0], 0.5, 5)
        res = _calc_abc_dask(pt_1, pt_3, pt_2, pt_4, 0.0, 0.0)
        res = _solve_quadratic_dask(res[0], res[1], res[2]).compute()
        self.assertAlmostEqual(res[0], 0.5, 5)

    def test_query_no_distance(self):
        """Test KDTree querying."""
        from pyresample.bilinear.xarr import query_no_distance

        kdtree = mock.MagicMock()
        kdtree.query.return_value = (1, 2)
        lons, lats = self.target_def.get_lonlats()
        voi = (lons >= -180) & (lons <= 180) & (lats <= 90) & (lats >= -90)
        res = query_no_distance(lons, lats, voi, kdtree, self.neighbours, 0.,
                                self.radius)
        # Only the second value from the query is returned
        self.assertEqual(res, 2)
        kdtree.query.assert_called_once()

    def test_get_valid_input_index_dask(self):
        """Test finding valid indices for reduced input data."""
        from pyresample.bilinear.xarr import _get_valid_input_index_dask

        # Do not reduce data
        vii, lons, lats = _get_valid_input_index_dask(self.source_def,
                                                      self.target_def, False,
                                                      self.radius)
        self.assertEqual(vii.shape, (self.source_def.size, ))
        # np.bool was removed in NumPy 1.24; np.bool_ works on all versions
        self.assertTrue(vii.dtype == np.bool_)
        # No data has been reduced, whole input is used
        self.assertTrue(vii.compute().all())

        # Reduce data
        vii, lons, lats = _get_valid_input_index_dask(self.source_def,
                                                      self.target_def, True,
                                                      self.radius)
        # 2700 valid input points
        self.assertEqual(vii.compute().sum(), 2700)

    def test_create_empty_bil_info(self):
        """Test creation of empty bilinear info."""
        from pyresample.bilinear.xarr import _create_empty_bil_info

        t__, s__, vii, ia_ = _create_empty_bil_info(self.source_def,
                                                    self.target_def)
        self.assertEqual(t__.shape, (self.target_def.size, ))
        self.assertEqual(s__.shape, (self.target_def.size, ))
        self.assertEqual(ia_.shape, (self.target_def.size, 4))
        self.assertTrue(ia_.dtype == np.int32)
        self.assertEqual(vii.shape, (self.source_def.size, ))
        self.assertTrue(vii.dtype == np.bool_)

    def test_lonlat2xyz(self):
        """Test conversion from geographic to cartesian 3D coordinates."""
        from pyresample.bilinear.xarr import lonlat2xyz
        from pyresample import CHUNK_SIZE

        lons, lats = self.target_def.get_lonlats(chunks=CHUNK_SIZE)
        res = lonlat2xyz(lons, lats)
        self.assertEqual(res.shape, (self.target_def.size, 3))
        vals = [3188578.91069278, -612099.36103276, 5481596.63569999]
        self.assertTrue(np.allclose(res.compute()[0, :], vals))
Пример #16
0
def ensure_dtype(tensor: xr.DataArray, *, dtype):
    """
    Cast ``tensor`` to the requested data type.

    :param tensor: array whose values are converted via its ``astype`` method
    :param dtype: target data type, forwarded unchanged to ``astype``
    :return: the object produced by ``tensor.astype(dtype)``
    """
    converted = tensor.astype(dtype)
    return converted
Пример #17
0
def lazy_indexing(da: xr.DataArray,
                  index: xr.DataArray,
                  dim: Optional[str] = None) -> xr.DataArray:
    """Get values of `da` at indices `index` in a NaN-aware and lazy manner.

    Two cases are handled:

    * `da` is 1D: `index` may have any number of dimensions (including 0).
      Values are gathered chunk by chunk with ``map_blocks``, and positions
      where `index` is NaN are masked (set to NaN) in the output.
    * `da` is N-D: the dimensions of `index` must be a subset of those of
      `da`. Values are taken along `dim` with ``np.take_along_axis`` through
      ``xr.apply_ufunc`` (dask-parallelized).

    Parameters
    ----------
    da : xr.DataArray
      Input array. If not 1D, `dim` must be given and must not appear in index.
    index : xr.DataArray
      N-d integer indices, if da is not 1D, all dimensions of index must be in da
    dim : str, optional
      Dimension along which to index, unused if `da` is 1D,
      should not be present in `index`. If omitted in the N-D case, it is
      inferred as the single dimension of `da` that `index` lacks.

    Returns
    -------
    xr.DataArray
      Values of `da` at indices `index`

    Raises
    ------
    ValueError
      In the N-D case, when `dim` is not given and `da` does not have exactly
      one dimension more than `index`.
    """
    if da.ndim == 1:
        # Case where da is 1D and index is N-D
        # Slightly better performance using map_blocks, over an apply_ufunc
        def _index_from_1d_array(array, indices):
            return array[indices, ]

        idx_ndim = index.ndim
        if idx_ndim == 0:
            # The 0-D index case, we add a dummy dimension to help dask.
            # A separate name is used so the `dim` argument is not clobbered.
            tmp_dim = xr.core.utils.get_temp_dimname(da.dims, "x")
            index = index.expand_dims(tmp_dim)
        invalid = index.isnull()  # Which indexes to mask
        # NaN-indexing doesn't work, so fill with 0 and cast to int
        index = index.fillna(0).astype(int)
        # for each chunk of index, take corresponding values from da
        func = partial(_index_from_1d_array, da)
        out = index.map_blocks(func)
        # mask where index was NaN
        out = out.where(~invalid)
        if idx_ndim == 0:
            # 0-D case: drop the dummy dim, and drop `da`'s dimension
            # coordinate if the indexing carried it over. It is absent when
            # `da` has no coordinate for that dimension, so ignore a miss
            # instead of raising ValueError.
            out = out.drop_vars(da.dims[0], errors="ignore").squeeze()
        return out

    # Case where index.dims is a subset of da.dims.
    if dim is None:
        diff_dims = set(da.dims) - set(index.dims)
        if len(diff_dims) == 0:
            raise ValueError(
                "da must have at least one dimension more than index for lazy_indexing."
            )
        if len(diff_dims) > 1:
            raise ValueError(
                "If da has more than one dimension more than index, the indexing dim must be given through `dim`"
            )
        dim = diff_dims.pop()

    def _index_from_nd_array(array, indices):
        # apply_ufunc moves the core dim `dim` last, so gather along axis -1.
        return np.take_along_axis(array, indices[..., np.newaxis],
                                  axis=-1)[..., 0]

    # NOTE(review): unlike the 1D branch, NaN indices are not filled/masked
    # here; casting NaN to int is undefined, so callers should pass NaN-free
    # indices in the N-D case.
    return xr.apply_ufunc(
        _index_from_nd_array,
        da,
        index.astype(int),
        input_core_dims=[[dim], []],
        output_core_dims=[[]],
        dask="parallelized",
        output_dtypes=[da.dtype],
    )
Пример #18
0
def get_optimal_chk(
    arr: xr.DataArray,
    dim_grp=(("frame",), ("height", "width")),
    csize=256,
    dtype: Optional[type] = None,
) -> tuple:
    """
    Compute the optimal chunk size across all dimensions of the input array.

    This function use `dask` autochunking mechanism to determine the optimal
    chunk size of an array. The difference between this and directly using
    "auto" as chunksize is that it understands which dimensions are usually
    chunked together with the help of `dim_grp`. It also support computing
    chunks for custom `dtype` and explicit requirement of chunk size.

    Parameters
    ----------
    arr : xr.DataArray
        The input array to estimate for chunk size.
    dim_grp : sequence of tuple, optional
        Tuples specifying which dimensions are usually chunked together
        during computation. For each tuple, it is assumed that only
        dimensions in the tuple will be chunked while all other dimensions in
        the input `arr` will not be chunked. Each dimension in the input `arr`
        should appear once and only once across the groups. An empty value
        (e.g. `None` or `[]`) puts every dimension in its own group. By
        default `(("frame",), ("height", "width"))`.
    csize : int, optional
        The desired space each chunk should occupy, specified in MB. By default
        `256`.
    dtype : type, optional
        The datatype of `arr` during actual computation in case that will be
        different from the current `arr.dtype`. By default `None`.

    Returns
    -------
    chk_compute : dict
        Chunk size for every dimension listed in `dim_grp`, computed with only
        that dimension's group auto-chunked and every other dimension of `arr`
        kept as a single chunk (-1).
    chk_store : dict
        Chunk sizes reported by `get_chunksize` when every dimension of `arr`
        is auto-chunked at once.
    """
    if dtype is not None:
        arr = arr.astype(dtype)
    dims = arr.dims
    if not dim_grp:
        dim_grp = [(d,) for d in dims]
    # "auto" chunking is steered by dask's `array.chunk-size` setting, so
    # every auto-chunk call below runs under this config.
    chunk_cfg = {"array.chunk-size": "{}MiB".format(csize)}
    chk_compute = dict()
    for dg in dim_grp:
        # Auto-chunk only the dims of this group; keep all others whole.
        dg_dict = {d: "auto" for d in dg}
        dg_dict.update({d: -1 for d in set(dims) - set(dg)})
        with da.config.set(chunk_cfg):
            arr_chk = arr.chunk(dg_dict)
        chk = get_chunksize(arr_chk)
        chk_compute.update({d: chk[d] for d in dg})
    with da.config.set(chunk_cfg):
        arr_chk = arr.chunk({d: "auto" for d in dims})
    chk_store_da = get_chunksize(arr_chk)
    return chk_compute, chk_store_da