def make_ndvi_tasks(index, config, year=None, **kwargs):
    """Generate NDVI tasks for input tiles that do not yet have an NDVI output."""
    input_type = config['nbar_dataset_type']
    output_type = config['ndvi_dataset_type']

    workflow = GridWorkflow(index, output_type.grid_spec)

    # TODO: Filter query to valid options
    query = {}
    if year is not None:
        if isinstance(year, integer_types):
            query['time'] = Range(datetime(year=year, month=1, day=1),
                                  datetime(year=year + 1, month=1, day=1))
        elif isinstance(year, tuple):
            query['time'] = Range(datetime(year=year[0], month=1, day=1),
                                  datetime(year=year[1] + 1, month=1, day=1))

    tiles_in = workflow.list_tiles(product=input_type.name, **query)
    tiles_out = workflow.list_tiles(product=output_type.name, **query)

    def make_task(tile, **task_kwargs):
        task = dict(nbar=workflow.update_tile_lineage(tile))
        task.update(task_kwargs)
        return task

    tasks = (make_task(tile,
                       tile_index=key,
                       filename=get_filename(config, tile_index=key, sources=tile.sources))
             for key, tile in tiles_in.items() if key not in tiles_out)
    return tasks

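# A minimal driver sketch, not part of the original source: it assumes a real
# `index` and an app `config` (with 'nbar_dataset_type' / 'ndvi_dataset_type'
# entries) and hands each generated task to do_ndvi_task(), defined further
# below. Production apps would typically dispatch tasks to an executor instead.
def run_ndvi_tasks(index, config, year=None):
    for task in make_ndvi_tasks(index, config, year=year):
        do_ndvi_task(config, task)
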
def find_diff(input_type, output_type, index, ingestion_bounds=None, **query):
    """List input tiles that have no corresponding output tile, optionally
    clipped to a rectangular ingestion boundary."""
    from datacube.api.grid_workflow import GridWorkflow
    workflow = GridWorkflow(index, output_type.grid_spec)

    tiles_in = workflow.list_tiles(product=input_type.name, **query)
    tiles_out = workflow.list_tiles(product=output_type.name, **query)

    def update_dict(d, **kwargs):
        result = d.copy()
        result.update(kwargs)
        if ingestion_bounds is not None:
            polygon = d['geobox'].geographic_extent.points
            # Need to programmatically figure out the upper left and lower right:
            # polygon[0] = UL, polygon[2] = LR.
            # Rectangle-overlap test (http://www.geeksforgeeks.org/find-two-rectangles-overlap/):
            # no overlap if top-left x1 > bottom-right x2 or top-left x2 > bottom-right x1,
            # or if top-left y1 < bottom-right y2 or top-left y2 < bottom-right y1.
            if polygon[0][0] > ingestion_bounds['right'] or ingestion_bounds['left'] > polygon[2][0]:
                return None
            if polygon[0][1] < ingestion_bounds['bottom'] or ingestion_bounds['top'] < polygon[2][1]:
                return None
            return result
        return result

    tasks = [update_dict(tile, index=key) for key, tile in tiles_in.items() if key not in tiles_out]
    tasks = list(filter(None, tasks))  # drop tiles outside the ingestion bounds
    return tasks

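# Example of the expected `ingestion_bounds` shape (hypothetical coordinates in
# the tiles' geographic CRS). The overlap test above treats polygon[0] as the
# tile's upper-left corner and polygon[2] as its lower-right corner, and drops
# tiles whose rectangle does not overlap these bounds.
example_ingestion_bounds = {'left': 110.0, 'right': 155.0, 'top': -10.0, 'bottom': -45.0}
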
def do_stats(task, config):
    source = task['source']
    measurement_name = source['measurements'][0]

    var_params = get_variable_params(config)

    results = create_output_files(config['stats'], config['location'], measurement_name, task, var_params)

    for tile_index in tile_iter(task['data'], {'x': 1000, 'y': 1000}):
        data = GridWorkflow.load(slice_tile(task['data'], tile_index),
                                 measurements=[measurement_name])[measurement_name]
        data = data.where(data != data.attrs['nodata'])

        for spec, sources in zip(source['masks'], task['masks']):
            mask = GridWorkflow.load(slice_tile(sources, tile_index),
                                     measurements=[spec['measurement']])[spec['measurement']]
            mask = make_mask(mask, **spec['flags'])
            data = data.where(mask)
            del mask

        for stat in config['stats']:
            data_stats = getattr(data, stat['name'])(dim='time')
            results[stat['name']][measurement_name][tile_index][0] = data_stats
            print(data_stats)

    # Close the netCDF files opened by create_output_files (one per stat).
    for nco in results.values():
        nco.close()

def find_diff(input_type, output_type, index, **query):
    from datacube.api.grid_workflow import GridWorkflow
    workflow = GridWorkflow(index, output_type.grid_spec)

    tiles_in = workflow.list_tiles(product=input_type.name, **query)
    tiles_out = workflow.list_tiles(product=output_type.name, **query)

    tasks = [{'tile': tile, 'tile_index': key} for key, tile in tiles_in.items() if key not in tiles_out]
    return tasks

def test_gridworkflow_with_time_depth():
    """Test GridWorkflow with time series.

    Also test `Tile` methods `split` and `split_by_time`.
    """
    from mock import MagicMock
    import datetime

    fakecrs = geometry.CRS('EPSG:4326')

    grid = 100  # spatial frequency in crs units
    pixel = 10  # square pixel linear dimension in crs units
    # if cell(0,0) has lower left corner at grid origin,
    # and cell indices increase toward upper right,
    # then this will be cell(1,-2).
    gridspec = GridSpec(crs=fakecrs, tile_size=(grid, grid),
                        resolution=(-pixel, pixel))  # e.g. product gridspec

    def make_fake_datasets(num_datasets):
        start_time = datetime.datetime(2001, 2, 15)
        delta = datetime.timedelta(days=16)
        for i in range(num_datasets):
            fakedataset = MagicMock()
            fakedataset.extent = geometry.box(left=grid, bottom=-grid,
                                              right=2 * grid, top=-2 * grid,
                                              crs=fakecrs)
            fakedataset.center_time = start_time + (delta * i)
            yield fakedataset

    fakeindex = PickableMock()
    fakeindex.datasets.get_field_names.return_value = ['time']  # permit query on time
    fakeindex.datasets.search_eager.return_value = list(make_fake_datasets(100))

    # ------ test with time dimension ----
    from datacube.api.grid_workflow import GridWorkflow
    gw = GridWorkflow(fakeindex, gridspec)
    query = dict(product='fake_product_name')

    cells = gw.list_cells(**query)
    for cell_index, cell in cells.items():
        # test Tile.split()
        for label, tile in cell.split('time'):
            assert tile.shape == (1, 10, 10)

        # test Tile.split_by_time()
        for year, year_cell in cell.split_by_time(freq='A'):
            for t in year_cell.sources.time.values:
                assert str(t)[:4] == year

def check_open_with_grid_workflow(index):
    from datacube.api.core import Datacube
    dc = Datacube(index=index)

    type_name = 'ls5_nbar_albers'
    dt = dc.index.datasets.types.get_by_name(type_name)

    from datacube.api.grid_workflow import GridWorkflow
    gw = GridWorkflow(dc, dt.grid_spec)

    cells = gw.list_cells(product=type_name)
    assert LBG_CELL in cells

    tiles = gw.list_tiles(product=type_name)
    assert tiles
    assert tiles[LBG_CELL]

    ts, tile = tiles[LBG_CELL].popitem()
    dataset_cell = gw.load(LBG_CELL, tile, measurements=['blue'])
    assert dataset_cell['blue'].size

    dataset_cell = gw.load(LBG_CELL, tile)
    assert all(m in dataset_cell for m in ['blue', 'green', 'red', 'nir', 'swir1', 'swir2'])

    tiles = gw.list_tile_stacks(product=type_name)
    assert tiles
    assert tiles[LBG_CELL]

    tile = tiles[LBG_CELL]
    dataset_cell = gw.load(LBG_CELL, tile, measurements=['blue'])
    assert dataset_cell['blue'].size

    dataset_cell = gw.load(LBG_CELL, tile)
    assert all(m in dataset_cell for m in ['blue', 'green', 'red', 'nir', 'swir1', 'swir2'])

def find_diff(input_type, output_type, driver_manager, time_size, **query):
    from datacube.api.grid_workflow import GridWorkflow
    workflow = GridWorkflow(None, output_type.grid_spec, driver_manager=driver_manager)

    cells_in = workflow.list_cells(product=input_type.name, **query)
    cells_out = workflow.list_cells(product=output_type.name, **query)
    remove_duplicates(cells_in, cells_out)

    tasks = [{'tile': cell, 'tile_index': extent} for extent, cell in cells_in.items()]

    # Split each cell along the time axis into chunks of `time_size` observations,
    # producing one task per chunk.
    new_tasks = []
    for task in tasks:
        tiles = task['tile'].split('time', time_size)
        for t in tiles:
            new_tasks.append({'tile': t[1], 'tile_index': task['tile_index']})
    return new_tasks

def check_open_with_grid_workflow(index):
    type_name = 'ls5_nbar_albers'
    dt = index.datasets.types.get_by_name(type_name)

    from datacube.api.grid_workflow import GridWorkflow
    gw = GridWorkflow(index, dt.grid_spec)

    cells = gw.list_cells(product=type_name)
    assert LBG_CELL in cells

    tile = cells[LBG_CELL]
    dataset_cell = gw.load(tile, measurements=['blue'])
    assert dataset_cell['blue'].size

    dataset_cell = gw.load(tile)
    assert all(m in dataset_cell for m in ['blue', 'green', 'red', 'nir', 'swir1', 'swir2'])

    ts = numpy.datetime64('1992-03-23T23:14:25.500000000')
    tile_key = LBG_CELL + (ts,)

    tiles = gw.list_tiles(product=type_name)
    assert tiles
    assert tile_key in tiles

    tile = tiles[tile_key]
    dataset_cell = gw.load(tile, measurements=['blue'])
    assert dataset_cell['blue'].size

    dataset_cell = gw.load(tile)
    assert all(m in dataset_cell for m in ['blue', 'green', 'red', 'nir', 'swir1', 'swir2'])

def make_tasks(index, config):
    query = dict(time=(datetime(2011, 1, 1), datetime(2011, 2, 1)))

    workflow = GridWorkflow(index, grid_spec=get_grid_spec(config))

    assert len(config['sources']) == 1  # TODO: merge multiple sources

    for source in config['sources']:
        data = workflow.list_cells(product=source['product'], cell_index=(15, -40), **query)
        masks = [workflow.list_cells(product=mask['product'], cell_index=(15, -40), **query)
                 for mask in source['masks']]

        for key in data.keys():
            yield {
                'source': source,
                'index': key,
                'data': data[key],
                'masks': [mask[key] for mask in masks]
            }

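# A usage sketch (assumed wiring, not from the original source): stream the
# tasks yielded by make_tasks() into the do_stats() routine defined earlier
# in this section.
def run_stats(index, config):
    for task in make_tasks(index, config):
        do_stats(task, config)
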
def check_open_with_grid_workflow(driver_manager):
    type_name = 'ls5_nbar_albers'
    dt = driver_manager.index.products.get_by_name(type_name)

    from datacube.api.grid_workflow import GridWorkflow
    gw = GridWorkflow(None, dt.grid_spec, driver_manager=driver_manager)

    cells = gw.list_cells(product=type_name, cell_index=LBG_CELL)
    assert LBG_CELL in cells

    cells = gw.list_cells(product=type_name)
    assert LBG_CELL in cells

    tile = cells[LBG_CELL]
    assert 'x' in tile.dims
    assert 'y' in tile.dims
    assert 'time' in tile.dims
    assert tile.shape[1] == 4000
    assert tile.shape[2] == 4000
    assert tile[:1, :100, :100].shape == (1, 100, 100)

    dataset_cell = gw.load(tile, measurements=['blue'], driver_manager=driver_manager)
    assert dataset_cell['blue'].shape == tile.shape

    for timestamp, tile_slice in tile.split('time'):
        assert tile_slice.shape == (1, 4000, 4000)

    dataset_cell = gw.load(tile, driver_manager=driver_manager)
    assert all(m in dataset_cell for m in ['blue', 'green', 'red', 'nir', 'swir1', 'swir2'])

    ts = numpy.datetime64('1992-03-23T23:14:25.500000000')
    tile_key = LBG_CELL + (ts,)

    tiles = gw.list_tiles(product=type_name)
    assert tiles
    assert tile_key in tiles

    tile = tiles[tile_key]
    dataset_cell = gw.load(tile, measurements=['blue'], driver_manager=driver_manager)
    assert dataset_cell['blue'].size

    dataset_cell = gw.load(tile, driver_manager=driver_manager)
    assert all(m in dataset_cell for m in ['blue', 'green', 'red', 'nir', 'swir1', 'swir2'])

def find_diff(input_type, output_type, bbox, datacube):
    from datacube.api.grid_workflow import GridWorkflow
    workflow = GridWorkflow(datacube, output_type.grid_spec)

    tiles_in = workflow.cell_observations(product=input_type.name)
    tiles_out = workflow.cell_observations(product=output_type.name)

    tasks = []
    for tile_index in set(tiles_in.keys()) | set(tiles_out.keys()):
        sources_in = datacube.product_sources(tiles_in.get(tile_index, []),
                                              lambda ds: ds.center_time,
                                              'time',
                                              'seconds since 1970-01-01 00:00:00')
        sources_out = datacube.product_sources(tiles_out.get(tile_index, []),
                                               lambda ds: ds.center_time,
                                               'time',
                                               'seconds since 1970-01-01 00:00:00')
        diff = numpy.setdiff1d(sources_in.time.values, sources_out.time.values)
        tasks += [(tile_index, sources_in.sel(time=[v])) for v in diff]
    return tasks

def test_create_gridworkflow_with_logging(index):
    """GridWorkflow construction should not fail while a logging StreamHandler is attached."""
    from logging import getLogger, StreamHandler

    logger = getLogger(__name__)
    handler = StreamHandler()
    logger.addHandler(handler)

    try:
        gw = GridWorkflow(index)
    finally:
        logger.removeHandler(handler)

def make_fc_tasks(index: Index, config: dict, query: dict, **kwargs):
    input_product = config['nbar_product']
    output_product = config['fc_product']

    workflow = GridWorkflow(index, output_product.grid_spec)

    tiles_in = workflow.list_tiles(product=input_product.name, **query)
    _LOG.info(f"{len(tiles_in)} {input_product.name} tiles in {repr(query)}")
    tiles_out = workflow.list_tiles(product=output_product.name, **query)
    _LOG.info(f"{len(tiles_out)} {output_product.name} tiles in {repr(query)}")

    return (
        dict(
            nbar=workflow.update_tile_lineage(tile),
            tile_index=key,
            filename=get_filename(config, tile_index=key, sources=tile.sources)
        )
        for key, tile in tiles_in.items() if key not in tiles_out
    )

def test_create_gridworkflow_init_failures(fake_index):
    index = fake_index

    # need product or gridspec
    with pytest.raises(ValueError):
        GridWorkflow(index)

    # test missing product
    with pytest.raises(ValueError):
        GridWorkflow(index, product="no-such-product")

    # test product without a grid spec
    assert fake_index.products.get_by_name("without_gs") is not None
    assert fake_index.products.get_by_name("without_gs").grid_spec is None

    with pytest.raises(ValueError):
        GridWorkflow(index, product="without_gs")

    product = fake_index.products.get_by_name("with_gs")
    assert product is not None
    assert product.grid_spec is not None

    gw = GridWorkflow(index, product="with_gs")
    assert gw.grid_spec is product.grid_spec

def _make_fc_tasks(index: Index, config: dict, query: dict):
    """
    Generate an iterable of 'tasks', matching the provided filter parameters.

    Tasks can be generated for:
    - all of time
    - 1 particular year
    - a range of years
    """
    input_product = config['nbart_product']
    output_product = config['fc_product']

    workflow = GridWorkflow(index, output_product.grid_spec)

    tiles_in = workflow.list_tiles(product=input_product.name, **query)
    _LOG.info(f"{len(tiles_in)} {input_product.name} tiles in {repr(query)}")
    tiles_out = workflow.list_tiles(product=output_product.name, **query)
    _LOG.info(f"{len(tiles_out)} {output_product.name} tiles in {repr(query)}")

    return (dict(nbart=workflow.update_tile_lineage(tile),
                 tile_index=key,
                 filename=_get_filename(config, tile_index=key, sources=tile.sources))
            for key, tile in tiles_in.items() if key not in tiles_out)

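# A hedged driver sketch (hypothetical helper, not from the original source):
# consume the task generator from _make_fc_tasks() with the _do_fc_task()
# routine shown below, ignoring the returned datasets for brevity.
def _run_fc_tasks(index: Index, config: dict, query: dict):
    for task in _make_fc_tasks(index, config, query):
        _do_fc_task(config, task)
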
def do_ndvi_task(config, task):
    global_attributes = config['global_attributes']
    variable_params = config['variable_params']
    file_path = Path(task['filename'])
    output_type = config['ndvi_dataset_type']
    measurement = output_type.measurements['ndvi']
    output_dtype = np.dtype(measurement['dtype'])
    nodata_value = np.dtype(output_dtype).type(measurement['nodata'])

    if file_path.exists():
        raise OSError(errno.EEXIST, 'Output file already exists', str(file_path))

    measurements = ['red', 'nir']

    nbar_tile = task['nbar']
    nbar = GridWorkflow.load(nbar_tile, measurements)

    ndvi = calculate_ndvi(nbar, nodata=nodata_value, dtype=output_dtype, units=measurement['units'])

    def _make_dataset(labels, sources):
        assert len(sources)
        geobox = nbar.geobox
        source_data = union_points(*[dataset.extent.to_crs(geobox.crs).points for dataset in sources])
        valid_data = intersect_points(geobox.extent.points, source_data)
        dataset = make_dataset(product=output_type,
                               sources=sources,
                               extent=geobox.extent,
                               center_time=labels['time'],
                               uri=file_path.absolute().as_uri(),
                               app_info=get_app_metadata(config),
                               valid_data=GeoPolygon(valid_data, geobox.crs))
        return dataset

    datasets = xr_apply(nbar_tile.sources, _make_dataset, dtype='O')
    ndvi['dataset'] = datasets_to_doc(datasets)

    write_dataset_to_netcdf(
        dataset=ndvi,
        filename=file_path,
        global_attributes=global_attributes,
        variable_params=variable_params,
    )
    return datasets

def _do_fc_task(config, task):
    """
    Load data, run FC algorithm, attach metadata, and write output.

    :param dict config: Config object
    :param dict task: Dictionary describing a single task
    :return: Dataset objects representing the generated data that can be added to the index
    :rtype: list(datacube.model.Dataset)
    """
    global_attributes = config['global_attributes']
    variable_params = config['variable_params']
    file_path = Path(task['filename'])
    output_product = config['fc_product']

    if file_path.exists():
        raise OSError(errno.EEXIST, 'Output file already exists', str(file_path))

    nbart_tile: Tile = task['nbart']
    nbart = GridWorkflow.load(nbart_tile, ['green', 'red', 'nir', 'swir1', 'swir2'])

    output_measurements = config['fc_product'].measurements.values()
    fc_dataset = _make_fc_tile(nbart, output_measurements, config.get('sensor_regression_coefficients'))

    def _make_dataset(labels, sources):
        assert sources
        dataset = make_dataset(product=output_product,
                               sources=sources,
                               extent=nbart.geobox.extent,
                               center_time=labels['time'],
                               uri=file_path.absolute().as_uri(),
                               app_info=_get_app_metadata(config),
                               valid_data=polygon_from_sources_extents(sources, nbart.geobox))
        return dataset

    datasets = xr_apply(nbart_tile.sources, _make_dataset, dtype='O')
    fc_dataset['dataset'] = datasets_to_doc(datasets)

    write_dataset_to_netcdf(
        dataset=fc_dataset,
        filename=file_path,
        global_attributes=global_attributes,
        variable_params=variable_params,
    )
    return datasets

def do_fc_task(config, task):
    global_attributes = config['global_attributes']
    variable_params = config['variable_params']
    file_path = Path(task['filename'])
    output_product = config['fc_product']

    if file_path.exists():
        raise OSError(errno.EEXIST, 'Output file already exists', str(file_path))

    nbar_tile: Tile = task['nbar']
    nbar = GridWorkflow.load(nbar_tile, ['green', 'red', 'nir', 'swir1', 'swir2'])

    output_measurements = config['fc_product'].measurements.values()
    fc_dataset = make_fc_tile(nbar, output_measurements, config.get('sensor_regression_coefficients'))

    def _make_dataset(labels, sources):
        assert sources
        dataset = make_dataset(product=output_product,
                               sources=sources,
                               extent=nbar.geobox.extent,
                               center_time=labels['time'],
                               uri=file_path.absolute().as_uri(),
                               app_info=get_app_metadata(config),
                               valid_data=GeoPolygon.from_sources_extents(sources, nbar.geobox))
        return dataset

    datasets = xr_apply(nbar_tile.sources, _make_dataset, dtype='O')
    fc_dataset['dataset'] = datasets_to_doc(datasets)

    write_dataset_to_netcdf(
        dataset=fc_dataset,
        filename=file_path,
        global_attributes=global_attributes,
        variable_params=variable_params,
    )
    return datasets

def test_gridworkflow():
    """ Test GridWorkflow with padding option. """
    from mock import MagicMock
    import datetime

    # ----- fake a datacube -----
    # e.g. let there be a dataset that coincides with a grid cell

    fakecrs = geometry.CRS('EPSG:4326')

    grid = 100  # spatial frequency in crs units
    pixel = 10  # square pixel linear dimension in crs units
    # if cell(0,0) has lower left corner at grid origin,
    # and cell indices increase toward upper right,
    # then this will be cell(1,-2).
    gridspec = GridSpec(crs=fakecrs, tile_size=(grid, grid),
                        resolution=(-pixel, pixel))  # e.g. product gridspec

    fakedataset = MagicMock()
    fakedataset.extent = geometry.box(left=grid, bottom=-grid,
                                      right=2 * grid, top=-2 * grid,
                                      crs=fakecrs)
    fakedataset.center_time = t = datetime.datetime(2001, 2, 15)
    fakedataset.id = uuid.uuid4()

    fakeindex = PickableMock()
    fakeindex._db = None
    fakeindex.datasets.get_field_names.return_value = ['time']  # permit query on time
    fakeindex.datasets.search_eager.return_value = [fakedataset]

    # ------ test without padding ----

    from datacube.api.grid_workflow import GridWorkflow
    gw = GridWorkflow(fakeindex, gridspec)
    # Need to force the fake index otherwise the driver manager will
    # only take its _db
    gw.index = fakeindex
    query = dict(product='fake_product_name', time=('2001-1-1 00:00:00', '2001-3-31 23:59:59'))

    # test backend: that it finds the expected cell/dataset
    assert list(gw.cell_observations(**query).keys()) == [(1, -2)]

    # again but with geopolygon
    assert list(gw.cell_observations(**query,
                                     geopolygon=gridspec.tile_geobox((1, -2)).extent).keys()) == [(1, -2)]

    with pytest.raises(ValueError) as e:
        list(gw.cell_observations(**query,
                                  tile_buffer=(1, 1),
                                  geopolygon=gridspec.tile_geobox((1, -2)).extent).keys())
    assert str(e.value) == 'Cannot process tile_buffering and geopolygon together.'

    # test frontend
    assert len(gw.list_tiles(**query)) == 1

    # ------ introduce padding --------

    assert len(gw.list_tiles(tile_buffer=(20, 20), **query)) == 9

    # ------ add another dataset (to test grouping) -----
    # consider cell (2,-2)

    fakedataset2 = MagicMock()
    fakedataset2.extent = geometry.box(left=2 * grid, bottom=-grid,
                                       right=3 * grid, top=-2 * grid,
                                       crs=fakecrs)
    fakedataset2.center_time = t
    fakedataset2.id = uuid.uuid4()

    def search_eager(lat=None, lon=None, **kwargs):
        return [fakedataset, fakedataset2]

    fakeindex.datasets.search_eager = search_eager

    # unpadded
    assert len(gw.list_tiles(**query)) == 2
    ti = numpy.datetime64(t, 'ns')
    assert set(gw.list_tiles(**query).keys()) == {(1, -2, ti), (2, -2, ti)}

    # padded
    assert len(gw.list_tiles(tile_buffer=(20, 20), **query)) == 12  # not 18=2*9 because of grouping

    # -------- inspect particular returned tile objects --------

    # check the array shape
    tile = gw.list_tiles(**query)[1, -2, ti]  # unpadded example
    assert grid / pixel == 10
    assert tile.shape == (1, 10, 10)

    padded_tile = gw.list_tiles(tile_buffer=(20, 20), **query)[1, -2, ti]  # padded example
    # assert grid/pixel + 2*gw2.grid_spec.padding == 14  # GREG: understand this
    assert padded_tile.shape == (1, 14, 14)

    # count the sources
    assert len(tile.sources.isel(time=0).item()) == 1
    assert len(padded_tile.sources.isel(time=0).item()) == 2

    # check the geocoding
    assert tile.geobox.alignment == padded_tile.geobox.alignment
    assert tile.geobox.affine * (0, 0) == padded_tile.geobox.affine * (2, 2)
    assert tile.geobox.affine * (10, 10) == padded_tile.geobox.affine * (10 + 2, 10 + 2)

    # ------- check loading --------
    # GridWorkflow accesses the load_data API
    # to ultimately convert geobox,sources,measurements to xarray,
    # so only thing to check here is the call interface.

    measurement = dict(nodata=0, dtype=int)  # plain int: numpy.int was removed in recent numpy
    fakedataset.type.lookup_measurements.return_value = {'dummy': measurement}
    fakedataset2.type = fakedataset.type

    from mock import patch
    with patch('datacube.api.core.Datacube.load_data') as loader:
        data = GridWorkflow.load(tile)
        data2 = GridWorkflow.load(padded_tile)
        # Note, could also test Datacube.load for consistency (but may require more patching)

    assert data is data2 is loader.return_value
    assert loader.call_count == 2

    # Note, use of positional arguments here is not robust, could spec mock etc.
    for (args, kwargs), loadable in zip(loader.call_args_list, [tile, padded_tile]):
        args = list(args)
        assert args[0] is loadable.sources
        assert args[1] is loadable.geobox
        assert list(args[2].values())[0] is measurement
        assert 'resampling' in kwargs

    # ------- check single cell index extract -------
    tile = gw.list_tiles(cell_index=(1, -2), **query)
    assert len(tile) == 1
    assert tile[1, -2, ti].shape == (1, 10, 10)
    assert len(tile[1, -2, ti].sources.values[0]) == 1

    padded_tile = gw.list_tiles(cell_index=(1, -2), tile_buffer=(20, 20), **query)
    assert len(padded_tile) == 1
    assert padded_tile[1, -2, ti].shape == (1, 14, 14)
    assert len(padded_tile[1, -2, ti].sources.values[0]) == 2