def test_load_data(tmpdir):
    tmpdir = Path(str(tmpdir))

    group_by = query_group_by('time')
    spatial = dict(resolution=(15, -15),
                   offset=(11230, 1381110),)

    nodata = -999
    aa = mk_test_image(96, 64, 'int16', nodata=nodata)

    ds, gbox = gen_tiff_dataset([SimpleNamespace(name='aa', values=aa, nodata=nodata)],
                                tmpdir,
                                prefix='ds1-',
                                timestamp='2018-07-19',
                                **spatial)
    assert ds.time is not None

    ds2, _ = gen_tiff_dataset([SimpleNamespace(name='aa', values=aa, nodata=nodata)],
                              tmpdir,
                              prefix='ds2-',
                              timestamp='2018-07-19',
                              **spatial)
    assert ds.time is not None
    assert ds.time == ds2.time

    sources = Datacube.group_datasets([ds], 'time')
    sources2 = Datacube.group_datasets([ds, ds2], group_by)

    mm = ['aa']
    mm = [ds.type.measurements[k] for k in mm]

    ds_data = Datacube.load_data(sources, gbox, mm)
    assert ds_data.aa.nodata == nodata
    np.testing.assert_array_equal(aa, ds_data.aa.values[0])

    custom_fuser_call_count = 0

    def custom_fuser(dest, delta):
        nonlocal custom_fuser_call_count
        custom_fuser_call_count += 1
        dest[:] += delta

    progress_call_data = []

    def progress_cbk(n, nt):
        progress_call_data.append((n, nt))

    ds_data = Datacube.load_data(sources2, gbox, mm,
                                 fuse_func=custom_fuser,
                                 progress_cbk=progress_cbk)
    assert ds_data.aa.nodata == nodata
    assert custom_fuser_call_count > 0
    np.testing.assert_array_equal(nodata + aa + aa, ds_data.aa.values[0])
    assert progress_call_data == [(1, 2), (2, 2)]
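# Minimal sketch of the `fuse_func` contract exercised above (illustrative only,
# not part of the original tests): a fuser is called with the destination pixels
# already loaded and the newly read pixels for the same region, and must combine
# them in place. `np` is the module-level numpy import already used by the tests,
# and the -999 nodata value mirrors the test above.
def _example_nodata_fuser(dest, src):
    # overwrite only those destination pixels that still hold the nodata marker
    np.copyto(dest, src, where=(dest == -999))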
def check_data_with_api(index, time_slices):
    """Check retrieved data for specific values.

    We scale down by 100 and check for predefined values in the corners.
    """
    from datacube import Datacube
    dc = Datacube(index=index)

    # Make the retrieved data 100 times less granular
    shape_x = int(GEOTIFF['shape']['x'] / 100.0)
    shape_y = int(GEOTIFF['shape']['y'] / 100.0)
    pixel_x = int(GEOTIFF['pixel_size']['x'] * 100)
    pixel_y = int(GEOTIFF['pixel_size']['y'] * 100)

    input_type_name = 'ls5_nbar_albers'
    input_type = dc.index.products.get_by_name(input_type_name)
    geobox = geometry.GeoBox(shape_x + 1, shape_y + 1,
                             Affine(pixel_x, 0.0, GEOTIFF['ul']['x'],
                                    0.0, pixel_y, GEOTIFF['ul']['y']),
                             geometry.CRS(GEOTIFF['crs']))

    observations = dc.find_datasets(product='ls5_nbar_albers', geopolygon=geobox.extent)
    group_by = query_group_by('time')
    sources = dc.group_datasets(observations, group_by)
    data = dc.load_data(sources, geobox, input_type.measurements.values())

    assert hashlib.md5(data.green.data).hexdigest() == '7f5ace486e88d33edf3512e8de6b6996'
    assert hashlib.md5(data.blue.data).hexdigest() == 'b58204f1e10dd678b292df188c242c7e'

    for time_slice in range(time_slices):
        assert data.blue.values[time_slice][-1, -1] == -999
def native_load(ds, measurements=None, basis=None, **kw):
    """Load a single dataset in native resolution.

    :param ds: Dataset
    :param measurements: List of band names to load
    :param basis: Name of the band to use for computing the reference frame;
                  other bands may be reprojected if they use a different pixel grid

    :param **kw: Any other parameter load_data accepts

    :return: Xarray dataset
    """
    from datacube import Datacube
    geobox = native_geobox(ds, measurements, basis)  # early exit via exception if no compatible grid exists

    if measurements is not None:
        mm = [ds.type.measurements[n] for n in measurements]
    else:
        mm = ds.type.measurements

    return Datacube.load_data(Datacube.group_datasets([ds], 'time'),
                              geobox,
                              measurements=mm,
                              **kw)
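# Hedged usage sketch for `native_load` above (not part of the original code).
# The Datacube instance, product name and band names are assumptions for
# illustration only.
def _example_native_load(dc):
    # read one indexed dataset on its own pixel grid instead of a common
    # reprojected output grid
    dss = dc.find_datasets(product='ls5_nbar_albers')
    return native_load(dss[0], measurements=['red', 'nir'], basis='red')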
def fetch(self, grouped: VirtualDatasetBox, **load_settings: Dict[str, Any]) -> xarray.Dataset:
    """ Convert grouped datasets to `xarray.Dataset`. """
    merged = merge_search_terms(select_keys(self, self._LOAD_KEYS),
                                select_keys(load_settings, self._LOAD_KEYS))

    if 'measurements' not in self:
        measurements = load_settings.get('measurements')
    elif 'measurements' not in load_settings:
        measurements = self.get('measurements')
    else:
        for measurement in load_settings['measurements']:
            self._assert(measurement in self['measurements'],
                         '{} not found in {}'.format(measurement, self._product))
        measurements = load_settings['measurements']

    measurements = self.output_measurements(grouped.product_definitions)

    result = Datacube.load_data(grouped.pile, grouped.geobox,
                                list(measurements.values()),
                                fuse_func=merged.get('fuse_func'),
                                dask_chunks=merged.get('dask_chunks'))

    return apply_aliases(result, grouped.product_definitions[self._product], list(measurements))
def fetch(self, grouped: VirtualDatasetBox, **load_settings: Dict[str, Any]) -> xarray.Dataset:
    """ Convert grouped datasets to `xarray.Dataset`. """
    load_keys = self._LOAD_KEYS - {'measurements'}
    merged = merge_search_terms(select_keys(self, load_keys),
                                select_keys(load_settings, load_keys))

    product = grouped.product_definitions[self._product]

    if 'measurements' in self and 'measurements' in load_settings:
        for measurement in load_settings['measurements']:
            self._assert(measurement in self['measurements'],
                         '{} not found in {}'.format(measurement, self._product))

    measurement_dicts = self.output_measurements(grouped.product_definitions,
                                                 load_settings.get('measurements'))

    if grouped.load_natively:
        canonical_names = [product.canonical_measurement(measurement)
                           for measurement in measurement_dicts]
        dataset_geobox = geobox_union_conservative([native_geobox(ds,
                                                                  measurements=canonical_names,
                                                                  basis=merged.get('like'))
                                                    for ds in grouped.box.sum().item()])

        if grouped.geopolygon is not None:
            reproject_roi = compute_reproject_roi(dataset_geobox,
                                                  GeoBox.from_geopolygon(grouped.geopolygon,
                                                                         crs=dataset_geobox.crs,
                                                                         align=dataset_geobox.alignment,
                                                                         resolution=dataset_geobox.resolution))

            self._assert(reproject_roi.is_st, "native load is not axis-aligned")
            self._assert(numpy.isclose(reproject_roi.scale, 1.0),
                         "native load should not require scaling")

            geobox = dataset_geobox[reproject_roi.roi_src]
        else:
            geobox = dataset_geobox
    else:
        geobox = grouped.geobox

    result = Datacube.load_data(grouped.box, geobox,
                                list(measurement_dicts.values()),
                                fuse_func=merged.get('fuse_func'),
                                dask_chunks=merged.get('dask_chunks'),
                                resampling=merged.get('resampling', 'nearest'))

    return result
def check_open_with_api(index, time_slices):
    with rasterio.Env():
        from datacube import Datacube
        dc = Datacube(index=index)

        input_type_name = 'ls5_nbar_albers'
        input_type = dc.index.products.get_by_name(input_type_name)
        geobox = geometry.GeoBox(200, 200,
                                 Affine(25, 0.0, 638000, 0.0, -25, 6276000),
                                 geometry.CRS('EPSG:28355'))
        observations = dc.find_datasets(product='ls5_nbar_albers', geopolygon=geobox.extent)
        group_by = query_group_by('time')
        sources = dc.group_datasets(observations, group_by)
        data = dc.load_data(sources, geobox, input_type.measurements.values())
        assert data.blue.shape == (time_slices, 200, 200)

        chunk_profile = {'time': 1, 'x': 100, 'y': 100}
        lazy_data = dc.load_data(sources, geobox, input_type.measurements.values(),
                                 dask_chunks=chunk_profile)
        assert lazy_data.blue.shape == (time_slices, 200, 200)
        assert (lazy_data.blue.load() == data.blue).all()
def check_open_with_api(driver_manager, time_slices):
    from datacube import Datacube
    dc = Datacube(driver_manager=driver_manager)

    input_type_name = 'ls5_nbar_albers'
    input_type = dc.index.products.get_by_name(input_type_name)
    geobox = geometry.GeoBox(200, 200,
                             Affine(25, 0.0, 638000, 0.0, -25, 6276000),
                             geometry.CRS('EPSG:28355'))
    observations = dc.find_datasets(product='ls5_nbar_albers', geopolygon=geobox.extent)
    group_by = query_group_by('time')
    sources = dc.group_datasets(observations, group_by)
    data = dc.load_data(sources, geobox, input_type.measurements.values(),
                        driver_manager=driver_manager)
    assert data.blue.shape == (time_slices, 200, 200)
def check_open_with_api(index):
    from datacube import Datacube
    dc = Datacube(index=index)

    input_type_name = 'ls5_nbar_albers'
    input_type = dc.index.products.get_by_name(input_type_name)
    geobox = GeoBox(200, 200,
                    Affine(25, 0.0, 1500000, 0.0, -25, -3900000),
                    CRS('EPSG:3577'))
    observations = dc.find_datasets(product='ls5_nbar_albers', geopolygon=geobox.extent)
    group_by = query_group_by('time')
    sources = dc.group_datasets(observations, group_by)
    data = dc.load_data(sources, geobox, input_type.measurements.values())
    assert data.blue.shape == (1, 200, 200)
def _product_fetch_():
    merged = merge_search_terms(select_keys(self, self._LOAD_KEYS),
                                select_keys(load_settings, self._LOAD_KEYS))

    # load_settings should not contain `measurements` for now
    measurements = self.output_measurements(product_definitions)

    result = Datacube.load_data(grouped.pile, grouped.geobox,
                                list(measurements.values()),
                                fuse_func=merged.get('fuse_func'),
                                dask_chunks=merged.get('dask_chunks'))

    return apply_aliases(result, product_definitions[self._product], list(measurements))
def check_data_with_api(index, time_slices):
    """Check retrieved data for specific values.

    We scale down by 100 and check for predefined values in the corners.
    """
    from datacube import Datacube
    dc = Datacube(index=index)

    # TODO: this test needs to change. It checks that results are exactly the
    #       same as at some earlier time, but with the current zoom-out factor it is
    #       hard to verify that results are as expected even with human judgement.
    #       What it should test is that reading natively from the ingested product
    #       gives exactly the same results as reading into the same GeoBox from the
    #       original product. Separate to that, there should be a read test that
    #       confirms that what you read from the native product while changing
    #       projection has the expected values.

    # Make the retrieved data lower res
    ss = 100
    shape_x = int(GEOTIFF['shape']['x'] / ss)
    shape_y = int(GEOTIFF['shape']['y'] / ss)
    pixel_x = int(GEOTIFF['pixel_size']['x'] * ss)
    pixel_y = int(GEOTIFF['pixel_size']['y'] * ss)

    input_type_name = 'ls5_nbar_albers'
    input_type = dc.index.products.get_by_name(input_type_name)
    geobox = geometry.GeoBox(shape_x + 2, shape_y + 2,
                             Affine(pixel_x, 0.0, GEOTIFF['ul']['x'],
                                    0.0, pixel_y, GEOTIFF['ul']['y']),
                             geometry.CRS(GEOTIFF['crs']))

    observations = dc.find_datasets(product='ls5_nbar_albers', geopolygon=geobox.extent)
    group_by = query_group_by('time')
    sources = dc.group_datasets(observations, group_by)
    data = dc.load_data(sources, geobox, input_type.measurements.values())

    assert hashlib.md5(data.green.data).hexdigest() == '0f64647bad54db4389fb065b2128025e'
    assert hashlib.md5(data.blue.data).hexdigest() == '41a7b50dfe5c4c1a1befbc378225beeb'

    for time_slice in range(time_slices):
        assert data.blue.values[time_slice][-1, -1] == -999
def _load_with_native_transform_1(
    sources: xr.DataArray,
    bands: Tuple[str, ...],
    geobox: GeoBox,
    native_transform: Callable[[xr.Dataset], xr.Dataset],
    basis: Optional[str] = None,
    groupby: Optional[str] = None,
    fuser: Optional[Callable[[xr.Dataset], xr.Dataset]] = None,
    resampling: str = "nearest",
    chunks: Optional[Dict[str, int]] = None,
    load_chunks: Optional[Dict[str, int]] = None,
    pad: Optional[int] = None,
) -> xr.Dataset:
    if basis is None:
        basis = bands[0]

    if load_chunks is None:
        load_chunks = chunks

    (ds,) = sources.data[0]
    load_geobox = compute_native_load_geobox(geobox, ds, basis)
    if pad is not None:
        load_geobox = gbox.pad(load_geobox, pad)

    mm = ds.type.lookup_measurements(bands)
    xx = Datacube.load_data(sources, load_geobox, mm, dask_chunks=load_chunks)
    xx = native_transform(xx)

    if groupby is not None:
        if fuser is None:
            fuser = _nodata_fuser  # type: ignore
        xx = xx.groupby(groupby).map(fuser)

    _chunks = None
    if chunks is not None:
        _chunks = tuple(chunks.get(ax, -1) for ax in ("y", "x"))

    return xr_reproject(xx, geobox, chunks=_chunks, resampling=resampling)  # type: ignore
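# Minimal sketch of a `native_transform` callable for the loader above
# (an illustrative assumption, not part of the original code): it receives the
# natively-loaded `xr.Dataset` and must return a transformed `xr.Dataset`.
# The 'pixel_quality' band name and its "clear" bit are made up for the example.
def _example_keep_clear(xx: xr.Dataset) -> xr.Dataset:
    clear = (xx.pixel_quality & 0b1) == 0b1  # assumed bit layout
    return xx.where(clear)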
def test_load_data_cbk(tmpdir):
    from datacube.api import TerminateCurrentLoad

    tmpdir = Path(str(tmpdir))

    spatial = dict(resolution=(15, -15),
                   offset=(11230, 1381110),)

    nodata = -999
    aa = mk_test_image(96, 64, 'int16', nodata=nodata)

    bands = [SimpleNamespace(name=name, values=aa, nodata=nodata)
             for name in ['aa', 'bb']]

    ds, gbox = gen_tiff_dataset(bands,
                                tmpdir,
                                prefix='ds1-',
                                timestamp='2018-07-19',
                                **spatial)
    assert ds.time is not None

    ds2, _ = gen_tiff_dataset(bands,
                              tmpdir,
                              prefix='ds2-',
                              timestamp='2018-07-19',
                              **spatial)
    assert ds.time is not None
    assert ds.time == ds2.time

    sources = Datacube.group_datasets([ds, ds2], 'time')

    progress_call_data = []

    def progress_cbk(n, nt):
        progress_call_data.append((n, nt))

    ds_data = Datacube.load_data(sources, gbox, ds.type.measurements,
                                 progress_cbk=progress_cbk)

    assert progress_call_data == [(1, 4), (2, 4), (3, 4), (4, 4)]
    np.testing.assert_array_equal(aa, ds_data.aa.values[0])
    np.testing.assert_array_equal(aa, ds_data.bb.values[0])

    def progress_cbk_fail_early(n, nt):
        progress_call_data.append((n, nt))
        raise TerminateCurrentLoad()

    def progress_cbk_fail_early2(n, nt):
        progress_call_data.append((n, nt))
        if n > 1:
            raise KeyboardInterrupt()

    progress_call_data = []
    ds_data = Datacube.load_data(sources, gbox, ds.type.measurements,
                                 progress_cbk=progress_cbk_fail_early)

    assert progress_call_data == [(1, 4)]
    assert ds_data.dc_partial_load is True
    np.testing.assert_array_equal(aa, ds_data.aa.values[0])
    np.testing.assert_array_equal(nodata, ds_data.bb.values[0])

    progress_call_data = []
    ds_data = Datacube.load_data(sources, gbox, ds.type.measurements,
                                 progress_cbk=progress_cbk_fail_early2)

    assert ds_data.dc_partial_load is True
    assert progress_call_data == [(1, 4), (2, 4)]
def fetch(self, grouped, product_definitions, **load_settings):
    # type: (DatasetPile, Dict[str, Dict], Dict[str, Any]) -> xarray.Dataset
    """ Convert grouped datasets to `xarray.Dataset`. """
    # TODO: provide `load_lazy` and `load_strict` instead

    # validate data to be loaded
    _ = self.output_measurements(product_definitions)

    if 'product' in self:
        merged = merge_search_terms(select_keys(self, self._LOAD_KEYS),
                                    select_keys(load_settings, self._LOAD_KEYS))

        # load_settings should not contain `measurements`
        measurements = list(self.output_measurements(product_definitions).values())

        def unwrap(_, value):
            return value.pile

        return Datacube.load_data(grouped.map(unwrap).pile,
                                  grouped.geobox, measurements,
                                  fuse_func=merged.get('fuse_func'),
                                  dask_chunks=merged.get('dask_chunks'),
                                  use_threads=merged.get('use_threads'))

    elif 'transform' in self:
        return self._transformation.compute(self._input.fetch(grouped,
                                                              product_definitions,
                                                              **load_settings))

    elif 'collate' in self:
        def is_from(source_index):
            def result(_, value):
                return value.pile[source_index] is not None

            return result

        def strip_source(_, value):
            for data in value.pile:
                if data is not None:
                    return data

            raise ValueError("Every child of DatasetPile object is None")

        def fetch_child(child, source_index, r):
            size = reduce(lambda x, y: x * y, r.shape, 1)

            if size > 0:
                result = child.fetch(r, product_definitions, **load_settings)
                name = self.get('index_measurement_name')

                if name is None:
                    return result

                # implication for dask?
                measurement = Measurement(name=name, dtype='int8', nodata=-1, units='1')
                shape = select_unique([result[band].shape for band in result.data_vars])

                array = numpy.full(shape, source_index, dtype=measurement.dtype)
                first = result[list(result.data_vars)[0]]
                result[name] = xarray.DataArray(array,
                                                dims=first.dims,
                                                coords=first.coords,
                                                name=name).assign_attrs(units=measurement.units,
                                                                        nodata=measurement.nodata)
                return result
            else:
                # empty raster
                return None

        groups = [fetch_child(child, source_index,
                              grouped.filter(is_from(source_index)).map(strip_source))
                  for source_index, child in enumerate(self._children)]

        non_empty = [g for g in groups if g is not None]

        return xarray.concat(non_empty, dim='time').assign_attrs(
            **select_unique([g.attrs for g in non_empty]))

    elif 'juxtapose' in self:
        def select_child(source_index):
            def result(_, value):
                return value.pile[source_index]

            return result

        def fetch_recipe(source_index):
            child_groups = grouped.map(select_child(source_index))
            return DatasetPile(child_groups.pile, grouped.geobox)

        groups = [child.fetch(fetch_recipe(source_index), product_definitions, **load_settings)
                  for source_index, child in enumerate(self._children)]

        return xarray.merge(groups).assign_attrs(**select_unique([g.attrs for g in groups]))

    else:
        raise VirtualProductException("virtual product was not validated")
def dc_load(
    datasets: Sequence[Dataset],
    measurements: Optional[Union[str, Sequence[str]]] = None,
    geobox: Optional[GeoBox] = None,
    groupby: Optional[str] = None,
    resampling: Optional[Union[str, Dict[str, str]]] = None,
    skip_broken_datasets: bool = False,
    chunks: Optional[Dict[str, int]] = None,
    progress_cbk: Optional[Callable[[int, int], Any]] = None,
    fuse_func=None,
    **kw,
) -> xr.Dataset:
    assert len(datasets) > 0

    # dask_chunks is a backward-compatibility alias for chunks
    if chunks is None:
        chunks = kw.pop("dask_chunks", None)
    # group_by is a backward-compatibility alias for groupby
    if groupby is None:
        groupby = kw.pop("group_by", "time")
    # bands is an alias for measurements
    if measurements is None:
        measurements = kw.pop("bands", None)

    # extract all "output_geobox" inputs
    geo_keys = {
        k: kw.pop(k)
        for k in [
            "like",
            "geopolygon",
            "resolution",
            "output_crs",
            "crs",
            "align",
            "x",
            "y",
            "lat",
            "lon",
        ]
        if k in kw
    }

    ds = datasets[0]
    product = ds.type

    if geobox is None:
        geobox = output_geobox(
            grid_spec=product.grid_spec,
            load_hints=product.load_hints(),
            **geo_keys,
            datasets=datasets,
        )
    elif len(geo_keys):
        warn(f"Supplied 'geobox=' parameter aliases {list(geo_keys)} inputs")

    grouped = Datacube.group_datasets(datasets, groupby)
    mm = product.lookup_measurements(measurements)
    return Datacube.load_data(
        grouped,
        geobox,
        mm,
        resampling=resampling,
        fuse_func=fuse_func,
        dask_chunks=chunks,
        skip_broken_datasets=skip_broken_datasets,
        progress_cbk=progress_cbk,
        **kw,
    )
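# Hedged usage sketch for `dc_load` above (not part of the original code); the
# Datacube instance, product name, CRS and resolution are assumptions for
# illustration only.
def _example_dc_load(dc):
    dss = dc.find_datasets(product='ls5_nbar_albers', time='2018-07')
    return dc_load(dss,
                   measurements=['red', 'green', 'blue'],
                   output_crs='EPSG:3577',
                   resolution=(-30, 30),
                   chunks={'x': 2048, 'y': 2048})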