Example #1
def test_load_data(tmpdir):
    tmpdir = Path(str(tmpdir))

    group_by = query_group_by('time')
    spatial = dict(resolution=(15, -15),
                   offset=(11230, 1381110),)

    nodata = -999
    aa = mk_test_image(96, 64, 'int16', nodata=nodata)

    ds, gbox = gen_tiff_dataset([SimpleNamespace(name='aa', values=aa, nodata=nodata)],
                                tmpdir,
                                prefix='ds1-',
                                timestamp='2018-07-19',
                                **spatial)
    assert ds.time is not None

    ds2, _ = gen_tiff_dataset([SimpleNamespace(name='aa', values=aa, nodata=nodata)],
                              tmpdir,
                              prefix='ds2-',
                              timestamp='2018-07-19',
                              **spatial)
    assert ds2.time is not None
    assert ds.time == ds2.time

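    # Group along time: both datasets carry the same timestamp, so `sources2`
    # ends up with a single time slice containing two datasets to fuse.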
    sources = Datacube.group_datasets([ds], 'time')
    sources2 = Datacube.group_datasets([ds, ds2], group_by)

    mm = ['aa']
    mm = [ds.type.measurements[k] for k in mm]

    ds_data = Datacube.load_data(sources, gbox, mm)
    assert ds_data.aa.nodata == nodata
    np.testing.assert_array_equal(aa, ds_data.aa.values[0])

    custom_fuser_call_count = 0

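    # With a custom fuse_func each source is folded into the nodata-initialised
    # destination, hence the expected fused result of nodata + aa + aa below.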
    def custom_fuser(dest, delta):
        nonlocal custom_fuser_call_count
        custom_fuser_call_count += 1
        dest[:] += delta

    progress_call_data = []

    def progress_cbk(n, nt):
        progress_call_data.append((n, nt))

    ds_data = Datacube.load_data(sources2, gbox, mm, fuse_func=custom_fuser,
                                 progress_cbk=progress_cbk)
    assert ds_data.aa.nodata == nodata
    assert custom_fuser_call_count > 0
    np.testing.assert_array_equal(nodata + aa + aa, ds_data.aa.values[0])

    assert progress_call_data == [(1, 2), (2, 2)]
Example #2
    def group(self, datasets: VirtualDatasetBag, **group_settings: Dict[str, Any]) -> VirtualDatasetBox:
        geopolygon = datasets.geopolygon
        selected = list(datasets.bag)

        # geobox
        merged = merge_search_terms(self, group_settings)

        try:
            geobox = output_geobox(datasets=selected,
                                   grid_spec=datasets.product_definitions[self._product].grid_spec,
                                   geopolygon=geopolygon, **select_keys(merged, self._GEOBOX_KEYS))
            load_natively = False

        except ValueError:
            # we are not calculating geoboxes here for the moment
            # since it may require filesystem access
            # in ODC 2.0 the dataset should know the information required
            geobox = None
            load_natively = True

        # group by time
        group_query = query_group_by(**select_keys(merged, self._GROUPING_KEYS))

        # information needed for Datacube.load_data
        return VirtualDatasetBox(Datacube.group_datasets(selected, group_query),
                                 geobox,
                                 load_natively,
                                 datasets.product_definitions,
                                 geopolygon=None if not load_natively else geopolygon)
Example #3
    def __call__(self, index, product, time, group_by) -> Tile:
        # Build the tile for a specific polygon whose boundary is known
        output_crs = CRS(self.storage['crs'])
        filtered_items = [
            'geopolygon', 'lon', 'lat', 'longitude', 'latitude', 'x', 'y'
        ]
        filtered_dict = {
            k: v
            for k, v in self.input_region.items() if k in filtered_items
        }
        if self.feature is not None:
            filtered_dict['geopolygon'] = self.feature.geopolygon
            geopoly = filtered_dict['geopolygon']
        else:
            geopoly = query_geopolygon(**self.input_region)

        dc = Datacube(index=index)
        datasets = dc.find_datasets(product=product,
                                    time=time,
                                    group_by=group_by,
                                    **filtered_dict)
        group_by = query_group_by(group_by=group_by)
        sources = dc.group_datasets(datasets, group_by)
        output_resolution = [
            self.storage['resolution'][dim] for dim in output_crs.dimensions
        ]
        geopoly = geopoly.to_crs(output_crs)
        geobox = GeoBox.from_geopolygon(geopoly, resolution=output_resolution)

        return Tile(sources, geobox)
Example #4
def native_load(ds, measurements=None, basis=None, **kw):
    """Load single dataset in native resolution.

    :param ds: Dataset
    :param measurements: List of band names to load
    :param basis: Name of the band to use for computing reference frame, other
    bands might be reprojected if they use different pixel grid

    :param **kw: Any other parameter load_data accepts

    :return: Xarray dataset
    """
    from datacube import Datacube
    geobox = native_geobox(
        ds, measurements,
        basis)  # early exit via exception if no compatible grid exists
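    # Resolve band names to measurement definitions; default to all bands
    # defined by the dataset's product.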
    if measurements is not None:
        mm = [ds.type.measurements[n] for n in measurements]
    else:
        mm = ds.type.measurements

    return Datacube.load_data(Datacube.group_datasets([ds], 'time'),
                              geobox,
                              measurements=mm,
                              **kw)
Example #5
def test_grouping_datasets():
    def group_func(d):
        return d.time

    dimension = 'time'
    units = None
    datasets = [
        SimpleNamespace(time=datetime.datetime(2016, 1, 1),
                        value='foo',
                        id=UUID(int=10)),
        SimpleNamespace(time=datetime.datetime(2016, 2, 1),
                        value='bar',
                        id=UUID(int=1)),
        SimpleNamespace(time=datetime.datetime(2016, 1, 1),
                        value='flim',
                        id=UUID(int=9)),
    ]

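    # Two of the three datasets share the 2016-01-01 timestamp, so grouping
    # on time should yield two groups: one with two datasets, one with one.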
    group_by = GroupBy(dimension, group_func, units, sort_key=group_func)
    grouped = Datacube.group_datasets(datasets, group_by)
    dss = grouped.isel(time=0).values[()]
    assert isinstance(dss, tuple)
    assert len(dss) == 2
    assert [ds.value for ds in dss] == ['flim', 'foo']

    dss = grouped.isel(time=1).values[()]
    assert isinstance(dss, tuple)
    assert len(dss) == 1
    assert [ds.value for ds in dss] == ['bar']

    assert str(grouped.time.dtype) == 'datetime64[ns]'
    assert grouped.loc['2016-01-01':'2016-01-15']
Example #6
def check_data_with_api(index, time_slices):
    """Chek retrieved data for specific values.

    We scale down by 100 and check for predefined values in the
    corners.
    """
    from datacube import Datacube
    dc = Datacube(index=index)

    # Make the retrieved data 100 times less granular
    shape_x = int(GEOTIFF['shape']['x'] / 100.0)
    shape_y = int(GEOTIFF['shape']['y'] / 100.0)
    pixel_x = int(GEOTIFF['pixel_size']['x'] * 100)
    pixel_y = int(GEOTIFF['pixel_size']['y'] * 100)

    input_type_name = 'ls5_nbar_albers'
    input_type = dc.index.products.get_by_name(input_type_name)
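    # Build a GeoBox anchored at the GeoTIFF's upper-left corner, covering the
    # image extent at the coarser resolution computed above.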
    geobox = geometry.GeoBox(
        shape_x + 1, shape_y + 1,
        Affine(pixel_x, 0.0, GEOTIFF['ul']['x'], 0.0, pixel_y,
               GEOTIFF['ul']['y']), geometry.CRS(GEOTIFF['crs']))
    observations = dc.find_datasets(product='ls5_nbar_albers',
                                    geopolygon=geobox.extent)
    group_by = query_group_by('time')
    sources = dc.group_datasets(observations, group_by)
    data = dc.load_data(sources, geobox, input_type.measurements.values())
    assert hashlib.md5(
        data.green.data).hexdigest() == '7f5ace486e88d33edf3512e8de6b6996'
    assert hashlib.md5(
        data.blue.data).hexdigest() == 'b58204f1e10dd678b292df188c242c7e'
    for time_slice in range(time_slices):
        assert data.blue.values[time_slice][-1, -1] == -999
Example #7
        def _product_group_():
            # select only those inside the ROI
            # ROI could be smaller than the query for the `query` method

            geopolygon = query_geopolygon(**search_terms)
            if geopolygon is not None:
                selected = list(
                    select_datasets_inside_polygon(datasets.pile, geopolygon))
            else:
                geopolygon = datasets.geopolygon
                selected = list(datasets.pile)

            # geobox
            merged = merge_search_terms(
                select_keys(self, self._NON_SPATIAL_KEYS),
                select_keys(search_terms, self._NON_SPATIAL_KEYS))

            geobox = output_geobox(datasets=selected,
                                   grid_spec=datasets.grid_spec,
                                   geopolygon=geopolygon,
                                   **select_keys(merged, self._GEOBOX_KEYS))

            # group by time
            group_query = query_group_by(
                **select_keys(merged, self._GROUPING_KEYS))

            # information needed for Datacube.load_data
            return VirtualDatasetBox(
                Datacube.group_datasets(selected, group_query), geobox,
                datasets.product_definitions)
Example #8
def _do_fc_task(config, task):
    """
    Load data, run FC algorithm, attach metadata, and write output.
    :param dict config: Config object
    :param dict task: Dictionary of tasks
    :return: Dataset objects representing the generated data that can be added to the index
    :rtype: list(datacube.model.Dataset)
    """
    global_attributes = config['global_attributes']
    variable_params = config['variable_params']
    output_product = config['fc_product']

    file_path = Path(task['filename_dataset'])

    uri, band_uris = calc_uris(file_path, variable_params)
    output_measurements = config['fc_product'].measurements.values()

    nbart = io.native_load(task['dataset'], measurements=config['load_bands'])
    if config['band_mapping'] is not None:
        nbart = nbart.rename(config['band_mapping'])

    fc_dataset = run_fc(nbart, output_measurements,
                        config.get('sensor_regression_coefficients'))

    def _make_dataset(labels, sources):
        assert sources
        dataset = make_dataset(product=output_product,
                               sources=sources,
                               extent=nbart.geobox.extent,
                               center_time=labels['time'],
                               uri=uri,
                               band_uris=band_uris,
                               app_info=_get_app_metadata(config),
                               valid_data=polygon_from_sources_extents(
                                   sources, nbart.geobox))
        return dataset

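    # Group the single input dataset along time so that the output dataset
    # document can reference it as provenance via xr_apply/_make_dataset below.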
    source = Datacube.group_datasets([task['dataset']], 'time')

    datasets = xr_apply(source, _make_dataset, dtype='O')
    fc_dataset['dataset'] = datasets_to_doc(datasets)

    base, ext = os.path.splitext(file_path)
    if ext == '.tif':
        dataset_to_geotif_yaml(
            dataset=fc_dataset,
            odc_dataset=datasets.item(),
            filename=file_path,
            variable_params=variable_params,
        )
    else:
        write_dataset_to_netcdf(
            dataset=fc_dataset,
            filename=file_path,
            global_attributes=global_attributes,
            variable_params=variable_params,
        )

    return datasets
Example #9
def _group_datasets_by_date(datasets):
    def group_func(d):
        return d['time'].date()

    def sort_key(d):
        return d['time']
    dimension = 'time'
    units = None

    group_by = GroupBy(dimension, group_func, units, sort_key)
    return Datacube.group_datasets(datasets, group_by)
Example #10
def check_open_with_api(driver_manager, time_slices):
    from datacube import Datacube
    dc = Datacube(driver_manager=driver_manager)

    input_type_name = 'ls5_nbar_albers'
    input_type = dc.index.products.get_by_name(input_type_name)
    geobox = geometry.GeoBox(200, 200, Affine(25, 0.0, 638000, 0.0, -25, 6276000), geometry.CRS('EPSG:28355'))
    observations = dc.find_datasets(product='ls5_nbar_albers', geopolygon=geobox.extent)
    group_by = query_group_by('time')
    sources = dc.group_datasets(observations, group_by)
    data = dc.load_data(sources, geobox, input_type.measurements.values(), driver_manager=driver_manager)
    assert data.blue.shape == (time_slices, 200, 200)
Example #11
    def __call__(self, tile_idx, _y=None):
        from datacube import Datacube
        from datacube.api.grid_workflow import Tile

        if _y is not None:
            tile_idx = (tile_idx, _y)

        k = self._key_fmt.format(*tile_idx)
        dss = list(self._cache.stream_group(k))

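        # Rebuild the Tile for this grid cell: geobox from the grid spec plus
        # the streamed datasets grouped with the pre-configured grouper.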
        geobox = self._grid_spec.tile_geobox(tile_idx)
        sources = Datacube.group_datasets(dss, self._grouper)
        return Tile(sources, geobox)
Example #12
    def __call__(self, tile_idx, _y=None, group_by=None):
        if _y is not None:
            tile_idx = (tile_idx, _y)

        if group_by is None:
            group_by = self._default_groupby

        dss = list(self._cache.stream_grid_tile(tile_idx, grid=self._grid))
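        # "nothing" appears to keep each dataset as its own slice via
        # group_by_nothing; any other value goes to Datacube.group_datasets.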
        if group_by == "nothing":
            sources = group_by_nothing(dss)
        else:
            sources = Datacube.group_datasets(dss, group_by)

        geobox = self._gs.tile_geobox(tile_idx)
        return Tile(sources, geobox)
Example #13
def test_group_datasets_by_time():
    bands = [dict(name='a')]
    # Same time instant but one explicitly marked as UTC
    ds1 = mk_sample_dataset(bands, timestamp="2019-01-01T23:24:00Z")
    ds2 = mk_sample_dataset(bands, timestamp="2019-01-01T23:24:00")
    # Same "time" but in a different timezone, and actually later
    ds3 = mk_sample_dataset(bands, timestamp="2019-01-01T23:24:00-1")
    assert ds1.center_time.tzinfo is not None
    assert ds2.center_time.tzinfo is None
    assert ds3.center_time.tzinfo is not None

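    # ds1 and ds2 refer to the same instant (the naive timestamp groups with
    # the UTC one), while ds3 is an hour later in UTC, so two time slices
    # are expected.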
    xx = Datacube.group_datasets([ds1, ds2, ds3], 'time')
    assert xx.time.shape == (2, )
    assert len(xx.data[0]) == 2
    assert len(xx.data[1]) == 1
Example #14
def check_open_with_api(index):
    from datacube import Datacube
    dc = Datacube(index=index)

    input_type_name = 'ls5_nbar_albers'
    input_type = dc.index.products.get_by_name(input_type_name)

    geobox = GeoBox(200, 200, Affine(25, 0.0, 1500000, 0.0, -25, -3900000),
                    CRS('EPSG:3577'))
    observations = dc.find_datasets(product='ls5_nbar_albers',
                                    geopolygon=geobox.extent)
    group_by = query_group_by('time')
    sources = dc.group_datasets(observations, group_by)
    data = dc.load_data(sources, geobox, input_type.measurements.values())
    assert data.blue.shape == (1, 200, 200)
Example #15
class ArbitraryTileMaker(object):
    """
    Create a :class:`Tile` which can be used by :class:`GridWorkflow` to later load the required data.

    :param input_region: dictionary of spatial limits used to search for datasets, e.g.
            a geopolygon or lat/lon boundaries

    """
    def __init__(self, index, input_region, storage):
        self.dc = Datacube(index=index)
        self.input_region = input_region
        self.storage = storage

    def __call__(self, product, time, group_by) -> Tile:
        # Build the tile for a specific polygon whose boundary is known
        output_crs = CRS(self.storage['crs'])
        filtered_item = [
            'geopolygon', 'lon', 'lat', 'longitude', 'latitude', 'x', 'y'
        ]
        filtered_dict = {
            k: v
            for k, v in filter(lambda t: t[0] in filtered_item,
                               self.input_region.items())
        }
        if 'feature_id' in self.input_region:
            filtered_dict['geopolygon'] = Geometry(
                self.input_region['geom_feat'],
                CRS(self.input_region['crs_txt']))
            geopoly = filtered_dict['geopolygon']
        else:
            geopoly = query_geopolygon(**self.input_region)
        datasets = self.dc.find_datasets(product=product,
                                         time=time,
                                         group_by=group_by,
                                         **filtered_dict)
        group_by = query_group_by(group_by=group_by)
        sources = self.dc.group_datasets(datasets, group_by)
        output_resolution = [
            self.storage['resolution'][dim] for dim in output_crs.dimensions
        ]
        geopoly = geopoly.to_crs(output_crs)
        geobox = GeoBox.from_geopolygon(geopoly, resolution=output_resolution)

        return Tile(sources, geobox)
Example #16
def check_open_with_api(index, time_slices):
    with rasterio.Env():
        from datacube import Datacube
        dc = Datacube(index=index)

        input_type_name = 'ls5_nbar_albers'
        input_type = dc.index.products.get_by_name(input_type_name)
        geobox = geometry.GeoBox(200, 200, Affine(25, 0.0, 638000, 0.0, -25, 6276000), geometry.CRS('EPSG:28355'))
        observations = dc.find_datasets(product='ls5_nbar_albers', geopolygon=geobox.extent)
        group_by = query_group_by('time')
        sources = dc.group_datasets(observations, group_by)
        data = dc.load_data(sources, geobox, input_type.measurements.values())
        assert data.blue.shape == (time_slices, 200, 200)

        chunk_profile = {'time': 1, 'x': 100, 'y': 100}
        lazy_data = dc.load_data(sources, geobox, input_type.measurements.values(), dask_chunks=chunk_profile)
        assert lazy_data.blue.shape == (time_slices, 200, 200)
        assert (lazy_data.blue.load() == data.blue).all()
Example #17
def check_data_with_api(index, time_slices):
    """Chek retrieved data for specific values.

    We scale down by 100 and check for predefined values in the
    corners.
    """
    from datacube import Datacube
    dc = Datacube(index=index)

    # TODO: this test needs to change, it tests that results are exactly the
    #       same as some time before, but with the current zoom out factor it's
    #       hard to verify that results are as expected even with human
    #       judgement. What it should test is that reading native from the
    #       ingested product gives exactly the same results as reading into the
    #       same GeoBox from the original product. Separate to that there
    #       should be a read test that confirms that what you read from native
    #       product while changing projection is of expected value

    # Make the retrieved data lower res
    ss = 100
    shape_x = int(GEOTIFF['shape']['x'] / ss)
    shape_y = int(GEOTIFF['shape']['y'] / ss)
    pixel_x = int(GEOTIFF['pixel_size']['x'] * ss)
    pixel_y = int(GEOTIFF['pixel_size']['y'] * ss)

    input_type_name = 'ls5_nbar_albers'
    input_type = dc.index.products.get_by_name(input_type_name)
    geobox = geometry.GeoBox(
        shape_x + 2, shape_y + 2,
        Affine(pixel_x, 0.0, GEOTIFF['ul']['x'], 0.0, pixel_y,
               GEOTIFF['ul']['y']), geometry.CRS(GEOTIFF['crs']))
    observations = dc.find_datasets(product='ls5_nbar_albers',
                                    geopolygon=geobox.extent)
    group_by = query_group_by('time')
    sources = dc.group_datasets(observations, group_by)
    data = dc.load_data(sources, geobox, input_type.measurements.values())
    assert hashlib.md5(
        data.green.data).hexdigest() == '0f64647bad54db4389fb065b2128025e'
    assert hashlib.md5(
        data.blue.data).hexdigest() == '41a7b50dfe5c4c1a1befbc378225beeb'
    for time_slice in range(time_slices):
        assert data.blue.values[time_slice][-1, -1] == -999
Example #18
    def group(self, datasets: VirtualDatasetBag,
              **search_terms: Dict[str, Any]) -> VirtualDatasetBox:
        geopolygon = datasets.geopolygon
        selected = list(datasets.pile)

        # geobox
        merged = merge_search_terms(self, search_terms)

        geobox = output_geobox(
            datasets=selected,
            grid_spec=datasets.product_definitions[self._product].grid_spec,
            geopolygon=geopolygon,
            **select_keys(merged, self._GEOBOX_KEYS))

        # group by time
        group_query = query_group_by(
            **select_keys(merged, self._GROUPING_KEYS))

        # information needed for Datacube.load_data
        return VirtualDatasetBox(
            Datacube.group_datasets(selected, group_query), geobox,
            datasets.product_definitions)
Example #19
    def group(self, datasets: VirtualDatasetBag,
              **search_terms: Dict[str, Any]) -> VirtualDatasetBox:
        """
        Datasets grouped by their timestamps.
        :param datasets: the `VirtualDatasetBag` to fetch data from
        :param search_terms: query terms specifying a spatial sub-region
        """
        grid_spec = datasets.grid_spec
        geopolygon = datasets.geopolygon

        if 'product' in self:
            # select only those inside the ROI
            # ROI could be smaller than the query for the `query` method
            if query_geopolygon(**search_terms) is not None:
                geopolygon = query_geopolygon(**search_terms)
                selected = list(
                    select_datasets_inside_polygon(datasets.pile, geopolygon))
            else:
                selected = list(datasets.pile)

            # geobox
            merged = merge_search_terms(
                select_keys(self, self._NON_SPATIAL_KEYS),
                select_keys(search_terms, self._NON_SPATIAL_KEYS))

            geobox = output_geobox(datasets=selected,
                                   grid_spec=grid_spec,
                                   geopolygon=geopolygon,
                                   **select_keys(merged, self._GEOBOX_KEYS))

            # group by time
            group_query = query_group_by(
                **select_keys(merged, self._GROUPING_KEYS))

            # information needed for Datacube.load_data
            return VirtualDatasetBox(
                Datacube.group_datasets(selected, group_query), geobox,
                datasets.product_definitions)

        elif 'transform' in self:
            return self._input.group(datasets, **search_terms)

        elif 'collate' in self:
            self._assert(
                'collate' in datasets.pile
                and len(datasets.pile['collate']) == len(self._children),
                "invalid dataset pile")

            def build(source_index, product, dataset_pile):
                grouped = product.group(
                    VirtualDatasetBag(dataset_pile, datasets.grid_spec,
                                      datasets.geopolygon,
                                      datasets.product_definitions),
                    **search_terms)

                def tag(_, value):
                    return {'collate': (source_index, value)}

                return grouped.map(tag)

            groups = [
                build(source_index, product, dataset_pile)
                for source_index, (product, dataset_pile) in enumerate(
                    zip(self._children, datasets.pile['collate']))
            ]

            return VirtualDatasetBox(
                xarray.concat([grouped.pile for grouped in groups],
                              dim='time'),
                select_unique([grouped.geobox for grouped in groups]),
                merge_dicts(
                    [grouped.product_definitions for grouped in groups]))

        elif 'juxtapose' in self:
            self._assert(
                'juxtapose' in datasets.pile
                and len(datasets.pile['juxtapose']) == len(self._children),
                "invalid dataset pile")

            groups = [
                product.group(
                    VirtualDatasetBag(dataset_pile, datasets.grid_spec,
                                      datasets.geopolygon,
                                      datasets.product_definitions),
                    **search_terms) for product, dataset_pile in zip(
                        self._children, datasets.pile['juxtapose'])
            ]

            aligned_piles = xarray.align(*[grouped.pile for grouped in groups])

            def tuplify(indexes, _):
                return {
                    'juxtapose':
                    [pile.sel(**indexes).item() for pile in aligned_piles]
                }

            return VirtualDatasetBox(
                xr_apply(aligned_piles[0], tuplify),
                select_unique([grouped.geobox for grouped in groups]),
                merge_dicts(
                    [grouped.product_definitions for grouped in groups]))

        else:
            raise VirtualProductException("virtual product was not validated")
Example #20
def dc_load(
    datasets: Sequence[Dataset],
    measurements: Optional[Union[str, Sequence[str]]] = None,
    geobox: Optional[GeoBox] = None,
    groupby: Optional[str] = None,
    resampling: Optional[Union[str, Dict[str, str]]] = None,
    skip_broken_datasets: bool = False,
    chunks: Optional[Dict[str, int]] = None,
    progress_cbk: Optional[Callable[[int, int], Any]] = None,
    fuse_func=None,
    **kw,
) -> xr.Dataset:
    assert len(datasets) > 0

    # dask_chunks is a backward-compatibility alias for chunks
    if chunks is None:
        chunks = kw.pop("dask_chunks", None)
    # group_by is a backward-compatibility alias for groupby
    if groupby is None:
        groupby = kw.pop("group_by", "time")
    # bands alias for measurements
    if measurements is None:
        measurements = kw.pop("bands", None)

    # extract all "output_geobox" inputs
    geo_keys = {
        k: kw.pop(k)
        for k in [
            "like",
            "geopolygon",
            "resolution",
            "output_crs",
            "crs",
            "align",
            "x",
            "y",
            "lat",
            "lon",
        ] if k in kw
    }

    ds = datasets[0]
    product = ds.type

    if geobox is None:
        geobox = output_geobox(
            grid_spec=product.grid_spec,
            load_hints=product.load_hints(),
            **geo_keys,
            datasets=datasets,
        )
    elif len(geo_keys):
        warn(f"Supplied 'geobox=' parameter aliases {list(geo_keys)} inputs")

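    # Group along the requested dimension and load every requested measurement
    # into a single xarray Dataset.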
    grouped = Datacube.group_datasets(datasets, groupby)
    mm = product.lookup_measurements(measurements)
    return Datacube.load_data(
        grouped,
        geobox,
        mm,
        resampling=resampling,
        fuse_func=fuse_func,
        dask_chunks=chunks,
        skip_broken_datasets=skip_broken_datasets,
        progress_cbk=progress_cbk,
        **kw,
    )
Example #21
def test_load_data_cbk(tmpdir):
    from datacube.api import TerminateCurrentLoad

    tmpdir = Path(str(tmpdir))

    spatial = dict(
        resolution=(15, -15),
        offset=(11230, 1381110),
    )

    nodata = -999
    aa = mk_test_image(96, 64, 'int16', nodata=nodata)

    bands = [
        SimpleNamespace(name=name, values=aa, nodata=nodata)
        for name in ['aa', 'bb']
    ]

    ds, gbox = gen_tiff_dataset(bands,
                                tmpdir,
                                prefix='ds1-',
                                timestamp='2018-07-19',
                                **spatial)
    assert ds.time is not None

    ds2, _ = gen_tiff_dataset(bands,
                              tmpdir,
                              prefix='ds2-',
                              timestamp='2018-07-19',
                              **spatial)
    assert ds2.time is not None
    assert ds.time == ds2.time

    sources = Datacube.group_datasets([ds, ds2], 'time')
    progress_call_data = []

    def progress_cbk(n, nt):
        progress_call_data.append((n, nt))

    ds_data = Datacube.load_data(sources,
                                 gbox,
                                 ds.type.measurements,
                                 progress_cbk=progress_cbk)

    assert progress_call_data == [(1, 4), (2, 4), (3, 4), (4, 4)]
    np.testing.assert_array_equal(aa, ds_data.aa.values[0])
    np.testing.assert_array_equal(aa, ds_data.bb.values[0])

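    # Raising TerminateCurrentLoad (or KeyboardInterrupt) from the callback
    # aborts the load early; the partially populated result is then flagged
    # with the dc_partial_load attribute, as asserted below.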
    def progress_cbk_fail_early(n, nt):
        progress_call_data.append((n, nt))
        raise TerminateCurrentLoad()

    def progress_cbk_fail_early2(n, nt):
        progress_call_data.append((n, nt))
        if n > 1:
            raise KeyboardInterrupt()

    progress_call_data = []
    ds_data = Datacube.load_data(sources,
                                 gbox,
                                 ds.type.measurements,
                                 progress_cbk=progress_cbk_fail_early)

    assert progress_call_data == [(1, 4)]
    assert ds_data.dc_partial_load is True
    np.testing.assert_array_equal(aa, ds_data.aa.values[0])
    np.testing.assert_array_equal(nodata, ds_data.bb.values[0])

    progress_call_data = []
    ds_data = Datacube.load_data(sources,
                                 gbox,
                                 ds.type.measurements,
                                 progress_cbk=progress_cbk_fail_early2)

    assert ds_data.dc_partial_load is True
    assert progress_call_data == [(1, 4), (2, 4)]
Example #22
    def group(self, datasets, **search_terms):
        # type: (QueryResult, Dict[str, Any]) -> DatasetPile
        """
        Datasets grouped by their timestamps.
        :param datasets: the `QueryResult` to fetch data from
        :param search_terms: query terms specifying a spatial sub-region
        """
        grid_spec = datasets.grid_spec

        if 'product' in self:
            # select only those inside the ROI
            # ROI could be smaller than the query for `query`
            spatial_query = reject_keys(search_terms, self._NON_SPATIAL_KEYS)
            selected = list(
                select_datasets_inside_polygon(
                    datasets.pile, query_geopolygon(**spatial_query)))

            # geobox
            merged = merge_search_terms(
                select_keys(self, self._NON_SPATIAL_KEYS),
                select_keys(spatial_query, self._NON_SPATIAL_KEYS))

            geobox = output_geobox(datasets=selected,
                                   grid_spec=grid_spec,
                                   **select_keys(merged, self._GEOBOX_KEYS),
                                   **spatial_query)

            # group by time
            group_query = query_group_by(
                **select_keys(merged, self._GROUPING_KEYS))

            def wrap(_, value):
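                # re-wrap each grouped value as a QueryResult so downstream
                # code keeps working with the same container type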
                return QueryResult(value, grid_spec)

            # information needed for Datacube.load_data
            return DatasetPile(Datacube.group_datasets(selected, group_query),
                               geobox).map(wrap)

        elif 'transform' in self:
            return self._input.group(datasets, **search_terms)

        elif 'collate' in self:
            self._assert(
                len(datasets.pile) == len(self._children),
                "invalid dataset pile")

            def build(source_index, product, dataset_pile):
                grouped = product.group(dataset_pile, **search_terms)

                def tag(_, value):
                    in_position = [
                        value if i == source_index else None
                        for i, _ in enumerate(datasets.pile)
                    ]
                    return QueryResult(in_position, grid_spec)

                return grouped.map(tag)

            groups = [
                build(source_index, product, dataset_pile)
                for source_index, (product, dataset_pile) in enumerate(
                    zip(self._children, datasets.pile))
            ]

            return DatasetPile(
                xarray.concat([grouped.pile for grouped in groups],
                              dim='time'),
                select_unique([grouped.geobox for grouped in groups]))

        elif 'juxtapose' in self:
            self._assert(
                len(datasets.pile) == len(self._children),
                "invalid dataset pile")

            groups = [
                product.group(datasets, **search_terms)
                for product, datasets in zip(self._children, datasets.pile)
            ]

            aligned_piles = xarray.align(*[grouped.pile for grouped in groups])
            child_groups = [
                DatasetPile(aligned_piles[i], grouped.geobox)
                for i, grouped in enumerate(groups)
            ]

            def tuplify(indexes, _):
                return QueryResult([
                    grouped.pile.sel(**indexes).item()
                    for grouped in child_groups
                ], grid_spec)

            return DatasetPile(
                child_groups[0].map(tuplify).pile,
                select_unique([grouped.geobox for grouped in groups]))

        else:
            raise VirtualProductException("virtual product was not validated")