Example #1
from pathlib import Path
from types import SimpleNamespace

import numpy as np

from datacube import Datacube
from datacube.api.query import query_group_by
# test helpers; assumed to live in datacube's test utilities
from datacube.testutils import mk_test_image, gen_tiff_dataset


def test_load_data(tmpdir):
    tmpdir = Path(str(tmpdir))

    group_by = query_group_by('time')
    spatial = dict(resolution=(15, -15),
                   offset=(11230, 1381110),)

    nodata = -999
    aa = mk_test_image(96, 64, 'int16', nodata=nodata)

    ds, gbox = gen_tiff_dataset([SimpleNamespace(name='aa', values=aa, nodata=nodata)],
                                tmpdir,
                                prefix='ds1-',
                                timestamp='2018-07-19',
                                **spatial)
    assert ds.time is not None

    ds2, _ = gen_tiff_dataset([SimpleNamespace(name='aa', values=aa, nodata=nodata)],
                              tmpdir,
                              prefix='ds2-',
                              timestamp='2018-07-19',
                              **spatial)
    assert ds2.time is not None
    assert ds.time == ds2.time

    sources = Datacube.group_datasets([ds], 'time')
    sources2 = Datacube.group_datasets([ds, ds2], group_by)

    mm = ['aa']
    mm = [ds.type.measurements[k] for k in mm]

    ds_data = Datacube.load_data(sources, gbox, mm)
    assert ds_data.aa.nodata == nodata
    np.testing.assert_array_equal(aa, ds_data.aa.values[0])

    custom_fuser_call_count = 0

    def custom_fuser(dest, delta):
        nonlocal custom_fuser_call_count
        custom_fuser_call_count += 1
        dest[:] += delta

    progress_call_data = []

    def progress_cbk(n, nt):
        progress_call_data.append((n, nt))

    ds_data = Datacube.load_data(sources2, gbox, mm, fuse_func=custom_fuser,
                                 progress_cbk=progress_cbk)
    assert ds_data.aa.nodata == nodata
    assert custom_fuser_call_count > 0
    np.testing.assert_array_equal(nodata + aa + aa, ds_data.aa.values[0])

    assert progress_call_data == [(1, 2), (2, 2)]
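
A fuse_func combines overlapping observations within one time group: it receives the destination buffer (pre-filled with nodata) and the pixels of the next dataset, and must merge them in place. A minimal sketch of an alternative fuser, assuming the same (dest, delta) convention exercised by the test above:

def max_fuser(dest, delta):
    # keep the per-pixel maximum; dest starts out filled with nodata
    np.maximum(dest, delta, out=dest)

# reusing sources2, gbox and mm from the test above
ds_data = Datacube.load_data(sources2, gbox, mm, fuse_func=max_fuser)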
Example #2
def check_data_with_api(index, time_slices):
    """Chek retrieved data for specific values.

    We scale down by 100 and check for predefined values in the
    corners.
    """
    from datacube import Datacube
    dc = Datacube(index=index)

    # Make the retrieved data 100x less granular
    shape_x = int(GEOTIFF['shape']['x'] / 100.0)
    shape_y = int(GEOTIFF['shape']['y'] / 100.0)
    pixel_x = int(GEOTIFF['pixel_size']['x'] * 100)
    pixel_y = int(GEOTIFF['pixel_size']['y'] * 100)

    input_type_name = 'ls5_nbar_albers'
    input_type = dc.index.products.get_by_name(input_type_name)
    geobox = geometry.GeoBox(
        shape_x + 1, shape_y + 1,
        Affine(pixel_x, 0.0, GEOTIFF['ul']['x'], 0.0, pixel_y,
               GEOTIFF['ul']['y']), geometry.CRS(GEOTIFF['crs']))
    observations = dc.find_datasets(product='ls5_nbar_albers',
                                    geopolygon=geobox.extent)
    group_by = query_group_by('time')
    sources = dc.group_datasets(observations, group_by)
    data = dc.load_data(sources, geobox, input_type.measurements.values())
    assert hashlib.md5(
        data.green.data).hexdigest() == '7f5ace486e88d33edf3512e8de6b6996'
    assert hashlib.md5(
        data.blue.data).hexdigest() == 'b58204f1e10dd678b292df188c242c7e'
    for time_slice in range(time_slices):
        assert data.blue.values[time_slice][-1, -1] == -999
Example #3
def native_load(ds, measurements=None, basis=None, **kw):
    """Load single dataset in native resolution.

    :param ds: Dataset
    :param measurements: List of band names to load
    :param basis: Name of the band to use for computing reference frame, other
    bands might be reprojected if they use different pixel grid

    :param **kw: Any other parameter load_data accepts

    :return: Xarray dataset
    """
    from datacube import Datacube
    geobox = native_geobox(
        ds, measurements,
        basis)  # early exit via exception if no compatible grid exists
    if measurements is not None:
        mm = [ds.type.measurements[n] for n in measurements]
    else:
        mm = ds.type.measurements

    return Datacube.load_data(Datacube.group_datasets([ds], 'time'),
                              geobox,
                              measurements=mm,
                              **kw)
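
A hypothetical usage sketch (the band names 'red' and 'nir' are assumptions, not taken from the example); any extra keyword such as dask_chunks is forwarded to load_data:

# hypothetical: load two bands of one dataset on its native grid,
# using the 'red' band's grid as the reference frame
xx = native_load(ds, measurements=['red', 'nir'], basis='red',
                 dask_chunks={'x': 2048, 'y': 2048})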
Example #4
    def fetch(self, grouped: VirtualDatasetBox,
              **load_settings: Dict[str, Any]) -> xarray.Dataset:
        """ Convert grouped datasets to `xarray.Dataset`. """
        merged = merge_search_terms(
            select_keys(self, self._LOAD_KEYS),
            select_keys(load_settings, self._LOAD_KEYS))

        if 'measurements' not in self:
            measurements = load_settings.get('measurements')
        elif 'measurements' not in load_settings:
            measurements = self.get('measurements')
        else:
            for measurement in load_settings['measurements']:
                self._assert(
                    measurement in self['measurements'],
                    '{} not found in {}'.format(measurement, self._product))

            measurements = load_settings['measurements']

        # resolve the selected measurement names against the product definitions
        measurements = self.output_measurements(grouped.product_definitions,
                                                measurements)

        result = Datacube.load_data(grouped.pile,
                                    grouped.geobox,
                                    list(measurements.values()),
                                    fuse_func=merged.get('fuse_func'),
                                    dask_chunks=merged.get('dask_chunks'))

        return apply_aliases(result,
                             grouped.product_definitions[self._product],
                             list(measurements))
Example #5
    def fetch(self, grouped: VirtualDatasetBox,
              **load_settings: Dict[str, Any]) -> xarray.Dataset:
        """ Convert grouped datasets to `xarray.Dataset`. """

        load_keys = self._LOAD_KEYS - {'measurements'}
        merged = merge_search_terms(select_keys(self, load_keys),
                                    select_keys(load_settings, load_keys))

        product = grouped.product_definitions[self._product]

        if 'measurements' in self and 'measurements' in load_settings:
            for measurement in load_settings['measurements']:
                self._assert(
                    measurement in self['measurements'],
                    '{} not found in {}'.format(measurement, self._product))

        measurement_dicts = self.output_measurements(
            grouped.product_definitions, load_settings.get('measurements'))

        if grouped.load_natively:
            canonical_names = [
                product.canonical_measurement(measurement)
                for measurement in measurement_dicts
            ]
            dataset_geobox = geobox_union_conservative([
                native_geobox(ds,
                              measurements=canonical_names,
                              basis=merged.get('like'))
                for ds in grouped.box.sum().item()
            ])

            if grouped.geopolygon is not None:
                reproject_roi = compute_reproject_roi(
                    dataset_geobox,
                    GeoBox.from_geopolygon(
                        grouped.geopolygon,
                        crs=dataset_geobox.crs,
                        align=dataset_geobox.alignment,
                        resolution=dataset_geobox.resolution))

                self._assert(reproject_roi.is_st,
                             "native load is not axis-aligned")
                self._assert(numpy.isclose(reproject_roi.scale, 1.0),
                             "native load should not require scaling")

                geobox = dataset_geobox[reproject_roi.roi_src]
            else:
                geobox = dataset_geobox
        else:
            geobox = grouped.geobox

        result = Datacube.load_data(grouped.box,
                                    geobox,
                                    list(measurement_dicts.values()),
                                    fuse_func=merged.get('fuse_func'),
                                    dask_chunks=merged.get('dask_chunks'),
                                    resampling=merged.get(
                                        'resampling', 'nearest'))

        return result
Example #6
def check_open_with_api(index, time_slices):
    with rasterio.Env():
        from datacube import Datacube
        dc = Datacube(index=index)

        input_type_name = 'ls5_nbar_albers'
        input_type = dc.index.products.get_by_name(input_type_name)
        geobox = geometry.GeoBox(200, 200, Affine(25, 0.0, 638000, 0.0, -25, 6276000), geometry.CRS('EPSG:28355'))
        observations = dc.find_datasets(product='ls5_nbar_albers', geopolygon=geobox.extent)
        group_by = query_group_by('time')
        sources = dc.group_datasets(observations, group_by)
        data = dc.load_data(sources, geobox, input_type.measurements.values())
        assert data.blue.shape == (time_slices, 200, 200)

        chunk_profile = {'time': 1, 'x': 100, 'y': 100}
        lazy_data = dc.load_data(sources, geobox, input_type.measurements.values(), dask_chunks=chunk_profile)
        assert lazy_data.blue.shape == (time_slices, 200, 200)
        assert (lazy_data.blue.load() == data.blue).all()
Example #7
def check_open_with_api(driver_manager, time_slices):
    from datacube import Datacube
    dc = Datacube(driver_manager=driver_manager)

    input_type_name = 'ls5_nbar_albers'
    input_type = dc.index.products.get_by_name(input_type_name)
    geobox = geometry.GeoBox(200, 200, Affine(25, 0.0, 638000, 0.0, -25, 6276000), geometry.CRS('EPSG:28355'))
    observations = dc.find_datasets(product='ls5_nbar_albers', geopolygon=geobox.extent)
    group_by = query_group_by('time')
    sources = dc.group_datasets(observations, group_by)
    data = dc.load_data(sources, geobox, input_type.measurements.values(), driver_manager=driver_manager)
    assert data.blue.shape == (time_slices, 200, 200)
Example #8
def check_open_with_api(index):
    from datacube import Datacube
    dc = Datacube(index=index)

    input_type_name = 'ls5_nbar_albers'
    input_type = dc.index.products.get_by_name(input_type_name)

    geobox = GeoBox(200, 200, Affine(25, 0.0, 1500000, 0.0, -25, -3900000),
                    CRS('EPSG:3577'))
    observations = dc.find_datasets(product='ls5_nbar_albers',
                                    geopolygon=geobox.extent)
    group_by = query_group_by('time')
    sources = dc.group_datasets(observations, group_by)
    data = dc.load_data(sources, geobox, input_type.measurements.values())
    assert data.blue.shape == (1, 200, 200)
Example #9
        def _product_fetch_():
            merged = merge_search_terms(
                select_keys(self, self._LOAD_KEYS),
                select_keys(load_settings, self._LOAD_KEYS))

            # load_settings should not contain `measurements` for now
            measurements = self.output_measurements(product_definitions)

            result = Datacube.load_data(grouped.pile,
                                        grouped.geobox,
                                        list(measurements.values()),
                                        fuse_func=merged.get('fuse_func'),
                                        dask_chunks=merged.get('dask_chunks'))

            return apply_aliases(result, product_definitions[self._product],
                                 list(measurements))
Example #10
def check_data_with_api(index, time_slices):
    """Chek retrieved data for specific values.

    We scale down by 100 and check for predefined values in the
    corners.
    """
    from datacube import Datacube
    dc = Datacube(index=index)

    # TODO: this test needs to change. It tests that results are exactly the
    #       same as at some earlier time, but with the current zoom-out factor
    #       it's hard to verify that results are as expected even with human
    #       judgement. What it should test is that reading natively from the
    #       ingested product gives exactly the same results as reading into the
    #       same GeoBox from the original product. Separately, there should be
    #       a read test confirming that what you read from the native product
    #       while changing projection has the expected values.

    # Make the retrieved data lower res
    ss = 100
    shape_x = int(GEOTIFF['shape']['x'] / ss)
    shape_y = int(GEOTIFF['shape']['y'] / ss)
    pixel_x = int(GEOTIFF['pixel_size']['x'] * ss)
    pixel_y = int(GEOTIFF['pixel_size']['y'] * ss)

    input_type_name = 'ls5_nbar_albers'
    input_type = dc.index.products.get_by_name(input_type_name)
    geobox = geometry.GeoBox(
        shape_x + 2, shape_y + 2,
        Affine(pixel_x, 0.0, GEOTIFF['ul']['x'], 0.0, pixel_y,
               GEOTIFF['ul']['y']), geometry.CRS(GEOTIFF['crs']))
    observations = dc.find_datasets(product='ls5_nbar_albers',
                                    geopolygon=geobox.extent)
    group_by = query_group_by('time')
    sources = dc.group_datasets(observations, group_by)
    data = dc.load_data(sources, geobox, input_type.measurements.values())
    assert hashlib.md5(
        data.green.data).hexdigest() == '0f64647bad54db4389fb065b2128025e'
    assert hashlib.md5(
        data.blue.data).hexdigest() == '41a7b50dfe5c4c1a1befbc378225beeb'
    for time_slice in range(time_slices):
        assert data.blue.values[time_slice][-1, -1] == -999
Example #11
def _load_with_native_transform_1(
    sources: xr.DataArray,
    bands: Tuple[str, ...],
    geobox: GeoBox,
    native_transform: Callable[[xr.Dataset], xr.Dataset],
    basis: Optional[str] = None,
    groupby: Optional[str] = None,
    fuser: Optional[Callable[[xr.Dataset], xr.Dataset]] = None,
    resampling: str = "nearest",
    chunks: Optional[Dict[str, int]] = None,
    load_chunks: Optional[Dict[str, int]] = None,
    pad: Optional[int] = None,
) -> xr.Dataset:
    if basis is None:
        basis = bands[0]

    if load_chunks is None:
        load_chunks = chunks

    (ds, ) = sources.data[0]
    load_geobox = compute_native_load_geobox(geobox, ds, basis)
    if pad is not None:
        load_geobox = gbox.pad(load_geobox, pad)

    mm = ds.type.lookup_measurements(bands)
    xx = Datacube.load_data(sources, load_geobox, mm, dask_chunks=load_chunks)
    xx = native_transform(xx)

    if groupby is not None:
        if fuser is None:
            fuser = _nodata_fuser  # type: ignore
        xx = xx.groupby(groupby).map(fuser)

    _chunks = None
    if chunks is not None:
        _chunks = tuple(chunks.get(ax, -1) for ax in ("y", "x"))

    return xr_reproject(xx, geobox, chunks=_chunks,
                        resampling=resampling)  # type: ignore
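
The native_transform callable runs while the data is still on its native pixel grid, before reprojection to the requested geobox. A minimal sketch of such a transform, assuming a uniform nodata value of -999 (hypothetical):

def mask_nodata(xx: xr.Dataset) -> xr.Dataset:
    # hypothetical transform: turn nodata pixels into NaN before reprojection
    return xx.where(xx != -999)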
Example #12
from pathlib import Path
from types import SimpleNamespace

import numpy as np

from datacube import Datacube
# test helpers; assumed to live in datacube's test utilities
from datacube.testutils import mk_test_image, gen_tiff_dataset


def test_load_data_cbk(tmpdir):
    from datacube.api import TerminateCurrentLoad

    tmpdir = Path(str(tmpdir))

    spatial = dict(
        resolution=(15, -15),
        offset=(11230, 1381110),
    )

    nodata = -999
    aa = mk_test_image(96, 64, 'int16', nodata=nodata)

    bands = [
        SimpleNamespace(name=name, values=aa, nodata=nodata)
        for name in ['aa', 'bb']
    ]

    ds, gbox = gen_tiff_dataset(bands,
                                tmpdir,
                                prefix='ds1-',
                                timestamp='2018-07-19',
                                **spatial)
    assert ds.time is not None

    ds2, _ = gen_tiff_dataset(bands,
                              tmpdir,
                              prefix='ds2-',
                              timestamp='2018-07-19',
                              **spatial)
    assert ds2.time is not None
    assert ds.time == ds2.time

    sources = Datacube.group_datasets([ds, ds2], 'time')
    progress_call_data = []

    def progress_cbk(n, nt):
        progress_call_data.append((n, nt))

    ds_data = Datacube.load_data(sources,
                                 gbox,
                                 ds.type.measurements,
                                 progress_cbk=progress_cbk)

    assert progress_call_data == [(1, 4), (2, 4), (3, 4), (4, 4)]
    np.testing.assert_array_equal(aa, ds_data.aa.values[0])
    np.testing.assert_array_equal(aa, ds_data.bb.values[0])

    def progress_cbk_fail_early(n, nt):
        progress_call_data.append((n, nt))
        raise TerminateCurrentLoad()

    def progress_cbk_fail_early2(n, nt):
        progress_call_data.append((n, nt))
        if n > 1:
            raise KeyboardInterrupt()

    progress_call_data = []
    ds_data = Datacube.load_data(sources,
                                 gbox,
                                 ds.type.measurements,
                                 progress_cbk=progress_cbk_fail_early)

    assert progress_call_data == [(1, 4)]
    assert ds_data.dc_partial_load is True
    np.testing.assert_array_equal(aa, ds_data.aa.values[0])
    np.testing.assert_array_equal(nodata, ds_data.bb.values[0])

    progress_call_data = []
    ds_data = Datacube.load_data(sources,
                                 gbox,
                                 ds.type.measurements,
                                 progress_cbk=progress_cbk_fail_early2)

    assert ds_data.dc_partial_load is True
    assert progress_call_data == [(1, 4), (2, 4)]
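
The callback contract is simply (chunks_completed, total_chunks); raising TerminateCurrentLoad stops the load early and marks the result with dc_partial_load, as the test shows. A minimal sketch of a plain progress reporter using the same signature:

def percent_cbk(n, nt):
    # report progress as a percentage of pixel chunks loaded
    print(f"loaded {n}/{nt} chunks ({100 * n // nt}%)")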
Example #13
    def fetch(self, grouped, product_definitions, **load_settings):
        # type: (DatasetPile, Dict[str, Dict], Dict[str, Any]) -> xarray.Dataset
        """ Convert grouped datasets to `xarray.Dataset`. """
        # TODO: provide `load_lazy` and `load_strict` instead

        # validate data to be loaded
        _ = self.output_measurements(product_definitions)

        if 'product' in self:
            merged = merge_search_terms(
                select_keys(self, self._LOAD_KEYS),
                select_keys(load_settings, self._LOAD_KEYS))

            # load_settings should not contain `measurements`
            measurements = list(
                self.output_measurements(product_definitions).values())

            def unwrap(_, value):
                return value.pile

            return Datacube.load_data(grouped.map(unwrap).pile,
                                      grouped.geobox,
                                      measurements,
                                      fuse_func=merged.get('fuse_func'),
                                      dask_chunks=merged.get('dask_chunks'),
                                      use_threads=merged.get('use_threads'))

        elif 'transform' in self:
            return self._transformation.compute(
                self._input.fetch(grouped, product_definitions,
                                  **load_settings))

        elif 'collate' in self:

            def is_from(source_index):
                def result(_, value):
                    return value.pile[source_index] is not None

                return result

            def strip_source(_, value):
                for data in value.pile:
                    if data is not None:
                        return data

                raise ValueError("Every child of DatasetPile object is None")

            def fetch_child(child, source_index, r):
                size = reduce(lambda x, y: x * y, r.shape, 1)

                if size > 0:
                    result = child.fetch(r, product_definitions,
                                         **load_settings)
                    name = self.get('index_measurement_name')

                    if name is None:
                        return result

                    # implication for dask?
                    measurement = Measurement(name=name,
                                              dtype='int8',
                                              nodata=-1,
                                              units='1')
                    shape = select_unique(
                        [result[band].shape for band in result.data_vars])
                    array = numpy.full(shape,
                                       source_index,
                                       dtype=measurement.dtype)
                    first = result[list(result.data_vars)[0]]
                    result[name] = xarray.DataArray(
                        array, dims=first.dims, coords=first.coords,
                        name=name).assign_attrs(units=measurement.units,
                                                nodata=measurement.nodata)
                    return result
                else:
                    # empty raster
                    return None

            groups = [
                fetch_child(
                    child, source_index,
                    grouped.filter(is_from(source_index)).map(strip_source))
                for source_index, child in enumerate(self._children)
            ]

            non_empty = [g for g in groups if g is not None]

            return xarray.concat(non_empty, dim='time').assign_attrs(
                **select_unique([g.attrs for g in non_empty]))

        elif 'juxtapose' in self:

            def select_child(source_index):
                def result(_, value):
                    return value.pile[source_index]

                return result

            def fetch_recipe(source_index):
                child_groups = grouped.map(select_child(source_index))
                return DatasetPile(child_groups.pile, grouped.geobox)

            groups = [
                child.fetch(fetch_recipe(source_index), product_definitions,
                            **load_settings)
                for source_index, child in enumerate(self._children)
            ]

            return xarray.merge(groups).assign_attrs(
                **select_unique([g.attrs for g in groups]))

        else:
            raise VirtualProductException("virtual product was not validated")
Example #14
def dc_load(
    datasets: Sequence[Dataset],
    measurements: Optional[Union[str, Sequence[str]]] = None,
    geobox: Optional[GeoBox] = None,
    groupby: Optional[str] = None,
    resampling: Optional[Union[str, Dict[str, str]]] = None,
    skip_broken_datasets: bool = False,
    chunks: Optional[Dict[str, int]] = None,
    progress_cbk: Optional[Callable[[int, int], Any]] = None,
    fuse_func=None,
    **kw,
) -> xr.Dataset:
    assert len(datasets) > 0

    # dask_chunks is a backward-compatibility alias for chunks
    if chunks is None:
        chunks = kw.pop("dask_chunks", None)
    # group_by is a backward-compatibility alias for groupby
    if groupby is None:
        groupby = kw.pop("group_by", "time")
    # bands alias for measurements
    if measurements is None:
        measurements = kw.pop("bands", None)

    # extract all "output_geobox" inputs
    geo_keys = {
        k: kw.pop(k)
        for k in [
            "like",
            "geopolygon",
            "resolution",
            "output_crs",
            "crs",
            "align",
            "x",
            "y",
            "lat",
            "lon",
        ] if k in kw
    }

    ds = datasets[0]
    product = ds.type

    if geobox is None:
        geobox = output_geobox(
            grid_spec=product.grid_spec,
            load_hints=product.load_hints(),
            **geo_keys,
            datasets=datasets,
        )
    elif len(geo_keys):
        warn(f"Supplied 'geobox=' parameter aliases {list(geo_keys)} inputs")

    grouped = Datacube.group_datasets(datasets, groupby)
    mm = product.lookup_measurements(measurements)
    return Datacube.load_data(
        grouped,
        geobox,
        mm,
        resampling=resampling,
        fuse_func=fuse_func,
        dask_chunks=chunks,
        skip_broken_datasets=skip_broken_datasets,
        progress_cbk=progress_cbk,
        **kw,
    )
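
A hypothetical call, assuming dss is a non-empty sequence of Dataset objects (for example from dc.find_datasets(...)); since no explicit geobox is passed, resolution and output_crs are forwarded to output_geobox:

xx = dc_load(dss,
             measurements=['red', 'nir'],    # hypothetical band names
             resolution=(-30, 30),
             output_crs='EPSG:3577',
             chunks={'x': 2048, 'y': 2048})  # dask-backed, lazy load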