示例#1
0
文件: ccdc.py 项目: klh5/CCDC
def loadByTile(products, key, min_y, max_y, min_x, max_x, bands):
    """Load a spatial window of one grid cell for every product.

    Args:
        products: Iterable of product names to query.
        key: Cell index, passed to ``GridWorkflow.list_tiles`` as
            ``cell_index``.
        min_y, max_y: Bounds of the slice applied to the second tile axis.
        min_x, max_x: Bounds of the slice applied to the third tile axis.
        bands: Measurements to load, passed through as ``measurements``.

    Returns:
        list: The loaded datasets, one per listed tile, in query order.
        Datasets without any variables are left out.
    """
    ds = []

    for product in products:
        # The context manager closes the connection even when the tile query
        # raises; a trailing dc.close() would be skipped in that case.
        with datacube.Datacube() as dc:
            # Create the GridWorkflow object for this product
            curr_gw = GridWorkflow(dc.index, product=product)

            # Get the list of tiles (one for each time point) for this product
            tile_list = curr_gw.list_tiles(product=product,
                                           cell_index=key,
                                           group_by='solar_day')

        # Retrieve the requested window from each tile in the list
        for tile in tile_list.values():
            dataset = curr_gw.load(tile[0:1, min_y:max_y, min_x:max_x],
                                   measurements=bands)

            if dataset.variables:
                ds.append(dataset)

    return ds
示例#2
0
    def load_tile_data(self, factors):
        """
        Build the factor matrix used for confidence band prediction.

        Each factor's band is loaded for ``self.tile_index``, rescaled or
        clipped according to the factor name prefix, and flattened.

        :param factors: List of factor info as given by Config
        :return: result of ``np.column_stack`` over the flattened factor arrays
        """

        columns = []
        for fac in factors:
            info = self.cfg.get_factor_info(fac)
            with Datacube(app='confidence_layer', env=info['env']) as dc:
                workflow = GridWorkflow(dc.index, self.grid_spec)
                cells = workflow.list_cells(self.tile_index,
                                            product=info['product'])
                # load the data of the tile
                loaded = workflow.load(tile=cells[self.tile_index],
                                       measurements=[info['band']])
                values = loaded.data_vars[info['band']].data

            # Rescale where needed: keep an eye on this, the scaling used
            # during training differs from what is stored on the datacube.
            name = info['name']
            if name.startswith('phat'):
                values = values * 100.0
                values[values < 0.0] = 0.0
            elif name.startswith('mrvbf'):
                values[values > 10] = 10
            elif name.startswith('modis'):
                values[values > 100] = 100

            columns.append(values.ravel())
        return np.column_stack(columns)
示例#3
0
    def load_slice(i):
        """Load time step ``i`` of ``tile`` and apply all configured masks."""
        window = [slice(i, i + 1), slice(None), slice(None)]
        data = GridWorkflow.load(tile[window], **kwargs)
        if mask_nodata:
            data = sensible_mask_invalid_data(data)

        # Fold every mask layer into one combined mask (logical AND).
        combined = None
        for mask_tile, flags, load_args in masks:
            layers = GridWorkflow.load(mask_tile[window], **load_args)
            first_layer, *_ = layers.data_vars.values()
            current = make_mask(first_layer, **flags)
            if combined is None:
                combined = current
            else:
                combined &= current

        if combined is not None:
            # Mask in place when asked to, or when the data was already
            # converted by the nodata masking step: this avoids a second
            # allocation and so raises the largest data set that fits in memory.
            in_place = mask_inplace or mask_nodata
            apply_mask = sensible_where_inplace if in_place else sensible_where
            data = apply_mask(data, combined)

        if src_idx is not None:
            source_column = np.repeat(src_idx, data.time.size)
            data.coords['source'] = ('time', source_column)

        return data
示例#4
0
    def compute_confidence_filtered(self):
        """
        Return the wofs summary ``frequency`` band filtered by the confidence band.

        Pixels whose confidence is at or below the threshold are overwritten
        with ``DEFAULT_FLOAT_NODATA``. The threshold is read from the
        ``confidence_filtering.threshold`` config entry and falls back to
        0.10 only when no threshold is configured.

        :return: the ``frequency`` data reshaped to
            ``self.grid_spec.tile_resolution``
        """

        con_layer = self.compute_confidence()
        env = self.cfg.get_env_of_product('wofs_summary')

        with Datacube(app='wofs_summary', env=env) as dc:
            gwf = GridWorkflow(dc.index, self.grid_spec)
            indexed_tile = gwf.list_cells(self.tile_index,
                                          product='wofs_summary')
            # load the data of the tile
            dataset = gwf.load(tile=indexed_tile[self.tile_index],
                               measurements=['frequency'])
            data = dataset.data_vars['frequency'].data.ravel().reshape(
                self.grid_spec.tile_resolution)

        con_filtering = self.cfg.cfg.get('confidence_filtering') or {}
        threshold = con_filtering.get('threshold')
        # Compare against None rather than truthiness so that an explicitly
        # configured threshold of 0 is honoured instead of becoming 0.10.
        if threshold is None:
            threshold = 0.10

        data[con_layer <= threshold] = DEFAULT_FLOAT_NODATA

        return data
示例#5
0
    def get_dataset_tiles(self,
                          product,
                          product_type=None,
                          platform=None,
                          time=None,
                          longitude=None,
                          latitude=None,
                          measurements=None,
                          output_crs=None,
                          resolution=None,
                          **kwargs):
        """
        Load data tile by tile for a product, optionally restricted by a query.

        Only ``product`` is required. Every other filter left as ``None`` is
        omitted from the query. The longitude/latitude bounds are applied
        only when both are given.

        Args:
            product (string): The name of the product associated with the desired dataset.
            product_type (string): The type of product associated with the desired dataset.
            platform (string): The platform associated with the desired dataset.
            time (tuple): A tuple consisting of the start time and end time for the dataset.
            longitude (tuple): A tuple of floats specifying the min,max longitude bounds.
            latitude (tuple): A tuple of floats specifying the min,max latitude bounds.
            measurements (list): A list of strings that represents all measurements.
            output_crs (string): Forwarded to ``list_cells`` as ``output_crs``.
            resolution (tuple): Forwarded to ``list_cells`` as ``resolution``.
            **kwargs: Accepted but not used by this method.

        Returns:
            dict: Maps each cell key returned by ``list_cells`` to the data
            loaded for that cell.
        """

        # Keep only the filters that were actually supplied.
        optional_filters = {
            'product_type': product_type,
            'platform': platform,
            'time': time,
        }
        query = {
            name: value
            for name, value in optional_filters.items()
            if value is not None
        }
        if not (longitude is None or latitude is None):
            query['longitude'] = longitude
            query['latitude'] = latitude

        # set up the grid workflow
        gw = GridWorkflow(self.dc.index, product=product)

        # dict of tiles.
        request_tiles = gw.list_cells(product=product,
                                      measurements=measurements,
                                      output_crs=output_crs,
                                      resolution=resolution,
                                      **query)

        # Load every listed cell.
        return {
            tile_key: gw.load(tile, measurements=measurements)
            for tile_key, tile in request_tiles.items()
        }
示例#6
0
文件: ccdc.py 项目: klh5/CCDC
def loadAll(products, key, bands):
    """Load every tile of one grid cell for each product.

    Args:
        products: Iterable of product names to query.
        key: Cell index, passed to ``GridWorkflow.list_tiles`` as
            ``cell_index``.
        bands: Measurements to load, passed through as ``measurements``.

    Returns:
        list: The loaded datasets, one per listed tile, in query order.
        Datasets without any variables are left out.
    """
    ds = []

    for product in products:
        # The context manager closes the connection even when the tile query
        # raises; a trailing dc.close() would be skipped in that case.
        with datacube.Datacube() as dc:
            gw = GridWorkflow(dc.index, product=product)

            # Get the list of tiles (one for each time point) for this product
            tile_list = gw.list_tiles(product=product,
                                      cell_index=key,
                                      group_by='solar_day')

        # Load all tiles
        for tile in tile_list.values():
            dataset = gw.load(tile, measurements=bands)

            if dataset.variables:
                ds.append(dataset)

    return ds
def test_wofs_filtered():
    """Plot the raw, reference, confidence and filtered layers for one cell.

    Purely visual: the four images are shown side by side and nothing is
    asserted. Finishes by calling ``compute_and_write`` on the workflow.
    """
    cfg = Config('../configs/template_client.yaml')
    grid_spec = GridSpec(crs=CRS('EPSG:3577'),
                         tile_size=(100000, 100000),
                         resolution=(-25, 25))
    cell_index = (17, -39)
    wf = WofsFiltered(cfg, grid_spec, cell_index)
    confidence = wf.compute_confidence(cell_index)
    filtered = wf.compute_confidence_filtered()

    # Display images: to be removed later
    with Datacube(app='wofs_summary', env='dev') as dc:
        workflow = GridWorkflow(dc.index, grid_spec)
        cells = workflow.list_cells(cell_index,
                                    product='wofs_statistical_summary')
        # load the data of the tile
        loaded = workflow.load(tile=cells[cell_index],
                               measurements=['frequency'])
        frequency = loaded.data_vars['frequency'].data.ravel().reshape(
            grid_spec.tile_resolution)

    # Check with previous run
    with rasterio.open('confidenceFilteredWOfS_17_-39_epsilon=10.tiff') as f:
        data = f.read(1)

    panels = (
        (221, frequency),
        (222, data),
        (223, confidence),
        (224, filtered),
    )
    for position, image in panels:
        plt.subplot(position)
        plt.imshow(image)
    plt.show()
    wf.compute_and_write()
示例#8
0
def save_grid_count_to_file(filename, index, **queryargs):
    """Write the cells matching a query to ``filename`` as JSON.

    Args:
        filename: Path of the file to (over)write.
        index: Datacube index handed to ``GridWorkflow``.
        **queryargs: Query forwarded to ``list_cells``; must contain
            ``product``.
    """
    workflow = GridWorkflow(product=queryargs['product'], index=index)
    matched_cells = workflow.list_cells(group_by='solar_day', **queryargs)
    feature_collection = cells_list_to_featurecollection(matched_cells)

    with open(filename, 'w') as dest:
        json.dump(feature_collection, dest)
示例#9
0
def load_masked_data(sub_tile_slice: Tuple[slice, slice, slice],
                     source_prod: DataSource,
                     geom=None) -> xarray.Dataset:
    """Load one sub-tile of a source product and apply its masks.

    Args:
        sub_tile_slice: Slices used to index the data tile and each mask tile.
        source_prod: Source whose ``spec`` controls measurements, fusing and
            masking, and whose ``data`` / ``masks`` tiles are loaded.
        geom: Optional geometry; when given, the data is additionally
            masked with ``geometry_mask([geom], ..., invert=True)``.

    Returns:
        The masked data, or ``None`` when the slice is entirely NaN or one
        of the mask tiles is missing.
    """
    spec = source_prod.spec

    data_fuse_func = None
    if 'fuse_func' in spec:
        data_fuse_func = import_function(spec['fuse_func'])

    data = GridWorkflow.load(source_prod.data[sub_tile_slice],
                             measurements=spec.get('measurements'),
                             fuse_func=data_fuse_func,
                             skip_broken_datasets=True)

    mask_inplace = spec.get('mask_inplace', False)
    mask_nodata = spec.get('mask_nodata', True)

    if mask_nodata:
        data = sensible_mask_invalid_data(data)

    # Discard the slice when every variable is NaN everywhere.
    nan_flags = xarray.ufuncs.isnan(data).all().data_vars.values()
    if all(flag for flag in nan_flags):
        return None

    in_place = mask_inplace or not mask_nodata
    where = sensible_where_inplace if in_place else sensible_where

    if 'masks' in spec:
        for mask_spec, mask_tile in zip(spec['masks'], source_prod.masks):
            if mask_tile is None:
                # Without mask data the slice cannot be trusted.
                return None

            mask_fuse_func = None
            if 'fuse_func' in mask_spec:
                mask_fuse_func = import_function(mask_spec['fuse_func'])

            measurement = mask_spec['measurement']
            mask = GridWorkflow.load(mask_tile[sub_tile_slice],
                                     measurements=[measurement],
                                     fuse_func=mask_fuse_func,
                                     skip_broken_datasets=True)[measurement]

            data = where(data, make_mask_from_spec(mask, mask_spec))
            # Drop the reference before the next mask is loaded.
            del mask

    if geom is not None:
        data = where(data, geometry_mask([geom], data.geobox, invert=True))

    if source_prod.source_index is not None:
        source_column = np.repeat(source_prod.source_index, data.time.size)
        data.coords['source'] = ('time', source_column)

    return data
示例#10
0
 def _get_factor_datasets(self):
     """Collect the datasets observed in this tile for every model factor."""
     collected = []
     for fac in self.confidence_model.factors:
         info = self.cfg.get_factor_info(fac)
         with Datacube(app='confidence_layer', env=info['env']) as dc:
             workflow = GridWorkflow(dc.index, self.grid_spec)
             observations = workflow.cell_observations(
                 cell_index=self.tile_index, product=info['product'])
             collected.extend(observations[self.tile_index]['datasets'])
     return collected
def _load_clear_observations(nbar_tile, pq_tile, nbar_message, pq_message):
    """Load one sensor's reflectance tile and blank out non-clear pixels.

    A pixel is kept only where ``pixelquality & 15871 == 15871``. The result
    is cast to int16 and ordered with the latest date first.
    """
    data = GridWorkflow.load(
        nbar_tile,
        measurements=['blue', 'green', 'red', 'nir', 'swir1', 'swir2'])
    print(nbar_message, str(datetime.now().time()))
    pq = GridWorkflow.load(pq_tile, fuse_func=pq_fuser)
    print(pq_message, str(datetime.now().time()))

    mask_clear = pq['pixelquality'] & 15871 == 15871
    ndata = data.where(mask_clear).astype(np.int16)
    # sort such that the latest date comes first
    ndata = ndata.sel(time=sorted(ndata.time.values, reverse=True))
    if len(ndata.attrs) == 0:
        # Restore the attributes of the loaded data when masking dropped them.
        ndata.attrs = data.attrs
    return copy(ndata)


def build_my_dataset(dt1, dt2, products, period, odir, cell=None, **prod):
    """Build the clear-pixel dataset of one cell from LS7 and/or LS8 tiles.

    Args:
        dt1, dt2, products, period, odir: Accepted but not used by this
            function.
        cell: Key under which the result is stored in the returned dict.
        **prod: Tiles to load. ``ls7``/``pq7`` and ``ls8``/``pq8`` hold the
            reflectance and pixel-quality tiles of each sensor. A sensor is
            skipped when its ``pq`` entry is missing or falsy.

    Returns:
        dict: ``{cell: dataset}``. When both sensors are present they are
        concatenated along ``time`` (latest first). Empty when neither
        sensor was supplied.
    """
    print(" building dataset for my cell ", cell)
    print(" loading data at ", str(datetime.now().time()))

    # .get() so that a sensor which was not passed at all is skipped instead
    # of raising KeyError.
    ls_7 = None
    if prod.get('pq7'):
        ls_7 = _load_clear_observations(prod['ls7'], prod['pq7'],
                                        " loaded nbar data for LS7",
                                        " loaded pq data for sensor LS7 ")
    ls_8 = None
    if prod.get('pq8'):
        ls_8 = _load_clear_observations(prod['ls8'], prod['pq8'],
                                        " loaded nbar data for LS8",
                                        " loaded pq data for LS8 ")

    ls_new = {}
    if ls_8 is not None and ls_7 is not None:
        merged = xr.concat([ls_8, ls_7], dim='time')
        ls_new[cell] = merged.sel(
            time=sorted(merged.time.values, reverse=True))
    elif ls_7 is not None:
        ls_new[cell] = ls_7
    elif ls_8 is not None:
        ls_new[cell] = ls_8
    return ls_new
示例#12
0
def load_masked_data(sub_tile_slice: Tuple[slice, slice, slice],
                     source_prod: DataSource) -> xarray.Dataset:
    """Load one sub-tile of a source product and apply its masks.

    Each entry of ``spec['masks']`` is turned into a boolean mask by the
    first of ``flags``, ``less_than`` or ``greater_than`` that is set. When
    none is set, the loaded mask measurement is used as-is.

    Args:
        sub_tile_slice: Slices used to index the data tile and each mask tile.
        source_prod: Source whose ``spec`` controls measurements, fusing and
            masking, and whose ``data`` / ``masks`` tiles are loaded.

    Returns:
        The masked data, or ``None`` when the slice is entirely NaN or one
        of the mask tiles is missing.
    """
    spec = source_prod.spec

    data_fuse_func = None
    if 'fuse_func' in spec:
        data_fuse_func = import_function(spec['fuse_func'])

    data = GridWorkflow.load(source_prod.data[sub_tile_slice],
                             measurements=spec.get('measurements'),
                             fuse_func=data_fuse_func,
                             skip_broken_datasets=True)

    mask_inplace = spec.get('mask_inplace', False)
    mask_nodata = spec.get('mask_nodata', True)

    if mask_nodata:
        data = sensible_mask_invalid_data(data)

    # Discard the slice when every variable is NaN everywhere.
    nan_flags = xarray.ufuncs.isnan(data).all().data_vars.values()
    if all(flag for flag in nan_flags):
        return None

    apply_mask = sensible_where_inplace if mask_inplace else sensible_where

    if 'masks' in spec:
        for mask_spec, mask_tile in zip(spec['masks'], source_prod.masks):
            if mask_tile is None:
                # Without mask data the slice cannot be trusted.
                return None

            mask_fuse_func = None
            if 'fuse_func' in mask_spec:
                mask_fuse_func = import_function(mask_spec['fuse_func'])

            measurement = mask_spec['measurement']
            mask = GridWorkflow.load(mask_tile[sub_tile_slice],
                                     measurements=[measurement],
                                     fuse_func=mask_fuse_func,
                                     skip_broken_datasets=True)[measurement]

            if mask_spec.get('flags') is not None:
                mask = make_mask(mask, **mask_spec['flags'])
            elif mask_spec.get('less_than') is not None:
                mask = mask < float(mask_spec['less_than'])
            elif mask_spec.get('greater_than') is not None:
                mask = mask > float(mask_spec['greater_than'])

            data = apply_mask(data, mask)
            # Drop the reference before the next mask is loaded.
            del mask

    if source_prod.source_index is not None:
        source_column = np.repeat(source_prod.source_index, data.time.size)
        data.coords['source'] = ('time', source_column)

    return data
示例#13
0
def extract_tile_db(tile, sp, training_set, sample):
    """Function to extract data under training geometries for a given tile

    Meant to be called within a dask.distributed.Cluster.map() over a list of tiles
    returned by GridWorkflow.list_cells
    Called in model_fit command line

    Args:
        tile: Datacube tile as returned by GridWorkflow.list_cells()
        sp: Spatial aggregation function
        training_set (str): Training data identifier (training_set field)
        sample (float): Proportion of training data to sample from the complete set

    Returns:
        The value returned by ``zonal_stats_xarray`` (predictors and target
        values), or ``[None, None]`` when any step fails. Failures are
        logged with their traceback rather than raised.
    """
    # Function-scope import keeps this block self-contained.
    import logging

    try:
        # Load tile as Dataset
        xr_dataset = GridWorkflow.load(tile[1])
        # Query the training geometries fitting into the extent of xr_dataset
        db = VectorDb()
        fc = db.load_training_from_dataset(xr_dataset,
                                           training_set=training_set,
                                           sample=sample)
        # fc is a feature collection with one property (class)
        # Overlay geometries and xr_dataset and perform extraction combined with spatial aggregation
        extract = zonal_stats_xarray(xr_dataset, fc, field='class', aggregation=sp)
        # Release the feature collection before returning
        fc = None
        gc.collect()
        return extract
    except Exception:
        # Best-effort by design: one failing tile must not abort the whole
        # map(). Record why it failed instead of discarding the error.
        logging.getLogger(__name__).exception(
            'Training data extraction failed for a tile')
        return [None, None]
示例#14
0
def predict_object(tile, model_name, segmentation_name,
                   categorical_variables, aggregation, name):
    """Predict a label for every segmentation object intersecting a tile

    The predictions and their confidences are written to the database as
    PredictClassification rows, 10000 at a time.

    Args:
        tile: Datacube tile as returned by GridWorkflow.list_cells()
        model_name (str): Name under which the trained model is referenced in the
            database
        segmentation_name (str): Name of the segmentation to use
        categorical_variables (list): List of strings corresponding to categorical
            features.
        aggregation (str): Spatial aggregation method to use
        name: Value stored in the ``name`` field of every row written

    Returns:
        bool: True on success; False when any step raised (the reason is printed).
    """
    try:
        # Load the tile and the segmentation objects that fall inside it
        geoarray = GridWorkflow.load(tile[1])
        fc = load_segmentation_from_dataset(geoarray, segmentation_name)
        # Extract the feature matrix and the object ids
        X, y = zonal_stats_xarray(dataset=geoarray, fc=fc, field='id',
                                  categorical_variables=categorical_variables,
                                  aggregation=aggregation)
        # Release inputs that are no longer needed
        geoarray = None
        fc = None
        gc.collect()

        # Load model
        PredModel = BaseModel.from_db(model_name)
        model_id = Model.objects.get(name=model_name).id
        try:
            # Avoid opening several threads in each process
            PredModel.model.n_jobs = 1
        except Exception as e:
            pass

        # Run prediction
        y_pred = PredModel.predict(X)
        y_conf = PredModel.predict_confidence(X)
        # Release the feature matrix and the model
        X = None
        PredModel = None
        gc.collect()

        # Write labels to database combining chunking and bulk_create
        for sub_zip in chunk(zip(y, y_pred, y_conf), 10000):
            rows = [
                PredictClassification(model_id=model_id, predict_object_id=i,
                                      tag_id=pred, confidence=conf, name=name)
                for i, pred, conf in sub_zip
            ]
            PredictClassification.objects.bulk_create(rows)
            rows = None
            gc.collect()

        y = None
        y_pred = None
        y_conf = None
        gc.collect()
        return True
    except Exception as e:
        print('Prediction failed because: %s' % e)
        return False
示例#15
0
def main(products, year, month, save):
    """Report datasets that are matched / unmatched across several products.

    Args:
        products: Product names to compare; the first one configures the
            ``GridWorkflow``.
        year: Year to restrict the query to, or ``None`` for no time filter.
        month: Month (number) within ``year``, or ``None`` for the whole year.
        save: Path to pickle the raw query result to, or ``None`` to skip.
    """
    from datacube_stats.utils.query import multi_product_list_cells
    import datacube
    from datacube.api import GridWorkflow

    query = {}
    if year is not None:
        if month is not None:
            # The range ends on the first day of the following month; after
            # December that is January of the next year, not "month 13".
            if month == 12:
                end_year, end_month = int(year) + 1, 1
            else:
                end_year, end_month = year, month + 1
            query['time'] = ('{}-{}-01'.format(year, month),
                             '{}-{}-01'.format(end_year, end_month))
        else:
            query['time'] = ('{}-01-01'.format(year), '{}-12-31'.format(year))

    dc = datacube.Datacube(app='dbg')
    gw = GridWorkflow(product=products[0], index=dc.index)

    click.echo('## Starting to run query', err=True)
    t_start = time.time()
    co_common, co_unmatched = multi_product_list_cells(products, gw, **query)
    t_took = time.time() - t_start
    click.echo('## Completed in {} seconds'.format(t_took), err=True)

    if save is not None:
        click.echo('## Saving data to {}'.format(save), err=True)
        # The with-block closes the file; no explicit close() is needed.
        with open(save, 'wb') as f:
            pickle.dump(dict(co_common=co_common, co_unmatched=co_unmatched),
                        f)
        click.echo(' done')

    click.echo('## Processing results,  ...wait', err=True)

    coverage = set(flat_map_ds(ds_to_key, co_common[0]))
    um = set(flat_map_ds(ds_to_key, co_unmatched[0]))

    # These tiles have both matched and unmatched data on the same solar day
    # It's significant cause these are the ones that will interfere with
    # masking if masking is done the "usual way"
    um_with_siblings = um & coverage

    click.echo('## Found {} matched records and {} unmatched'.format(
        len(coverage), len(um)))
    click.echo(
        '##   Of {} unmatched records {} are "dangerous" for masking'.format(
            len(um), len(um_with_siblings)))
    click.echo('##')

    def dump_unmatched_ds(ds, cell_idx, solar_day):
        """Print one unmatched dataset; '!' marks keys that also have matched data."""
        k = ds_to_key(ds, cell_idx, solar_day)
        flag = '!' if k in coverage else '.'
        click.echo('{} {} {} {}'.format(k, flag, ds.id, ds.local_path))

    for (idx, product) in enumerate(products):
        click.echo('## unmatched ###########################')
        click.echo('## {}'.format(product))
        click.echo('########################################')
        flat_foreach_ds(dump_unmatched_ds, co_unmatched[idx])
示例#16
0
    def load_slice(i):
        """Load time step ``i`` of ``tile``, then apply the pixel masks and
        the optional geometry mask."""
        window = [slice(i, i + 1), slice(None), slice(None)]
        data = GridWorkflow.load(tile[window], **kwargs)
        if mask_nodata:
            data = sensible_mask_invalid_data(data)

        # Fold every (optionally inverted) mask layer into one combined mask.
        combined = None
        for (mask_tile, flags, load_args), invert in zip(masks, inverts):
            layers = GridWorkflow.load(mask_tile[window], **load_args)
            first_layer, *_ = layers.data_vars.values()
            # TODO make use of make_mask_from_spec here
            current = make_mask(first_layer, **flags)
            if invert:
                current = np.logical_not(current)

            if combined is None:
                combined = current
            else:
                combined &= current

        # Pick the masking routine once; it is shared by both steps below.
        in_place = mask_inplace or not mask_nodata
        where = sensible_where_inplace if in_place else sensible_where

        if combined is not None:
            data = where(data, combined)

        if geom is not None:
            data = where(data, geometry_mask([geom], data.geobox, invert=True))

        if src_idx is not None:
            source_column = np.repeat(src_idx, data.time.size)
            data.coords['source'] = ('time', source_column)

        return data
def list_all_cells(index, products, period, dt1, dt2):
    """List the cells of every PQ and NBAR product between ``dt2`` and ``dt1``.

    Returns:
        tuple: ``(cell_list, my_cell_info)``. ``cell_list`` holds the keys of
        all PQ cells formatted as ``'(a,b)'`` strings. ``my_cell_info`` maps
        each product name to its ``list_cells`` result.
    """
    print("date range is FROM " + str(dt2) + " TO " + str(dt1))
    gw = GridWorkflow(index=index, product=products[0])
    time_range = (dt2, dt1)
    my_cell_info = defaultdict(dict)
    pq_cells = {}
    print(" database querying for listing all cells starts at " +
          str(datetime.now()))

    for prod in PQ_PRODUCTS:
        found = gw.list_cells(product=prod,
                              time=time_range,
                              group_by='solar_day')
        my_cell_info[prod] = found
        pq_cells.update(found)

    for prod in NBAR_PRODUCTS:
        my_cell_info[prod] = gw.list_cells(product=prod,
                                           time=time_range,
                                           group_by='solar_day')

    cell_list = ['({0},{1})'.format(a, b) for (a, b) in pq_cells]
    print(" database query done for  all cells " + str(len(cell_list)) +
          str(datetime.now()))
    return cell_list, my_cell_info
示例#18
0
def run(tile, center_dt, path):
    """Basic datapreparation recipe 001

    Computes mean NDVI for a landsat collection over a given time frame.
    An output file that already exists is not recomputed.

    Args:
        tile (tuple): Tuple of (tile indices, Tile object). Tile object can be
            loaded as xarray.Dataset using gwf.load()
        center_dt (datetime): Date to be used in making the filename
        path (str): Directory where files generated are to be written

    Return:
        str: The filename of the netcdf file created (or already present),
        or None when processing the tile raised an exception
    """
    try:
        center_dt = center_dt.strftime("%Y-%m-%d")
        basename = 'ndvi_mean_%d_%d_%s.nc' % (tile[0][0], tile[0][1], center_dt)
        nc_filename = os.path.join(path, basename)
        if os.path.isfile(nc_filename):
            logger.warning(
                '%s already exists. Returning filename for database indexing',
                nc_filename)
            return nc_filename

        # Load Landsat sr
        chunking = {'x': 1667, 'y': 1667}
        sr = GridWorkflow.load(tile[1], dask_chunks=chunking)

        # Compute ndvi and keep it only where the pixel is flagged clear
        sr['ndvi'] = (sr.nir - sr.red) / (sr.nir + sr.red) * 10000
        clear = masking.make_mask(sr.pixel_qa, clear=True)
        input_bands = ['pixel_qa', 'blue', 'red', 'green', 'nir', 'swir1', 'swir2']
        ndvi_clear = sr.drop(input_bands).where(clear)

        # Reduce over time, then convert and write the result
        ndvi_mean = ndvi_clear.mean('time', keep_attrs=True)
        ndvi_mean['ndvi'].attrs['nodata'] = -9999
        ndvi_mean_int = ndvi_mean.apply(to_int)
        ndvi_mean_int.attrs['crs'] = sr.attrs['crs']
        write_dataset_to_netcdf(ndvi_mean_int,
                                nc_filename,
                                netcdfparams={'zlib': True})
        return nc_filename
    except Exception as e:
        logger.info('Tile (%d, %d) not processed. %s' %
                    (tile[0][0], tile[0][1], e))
        return None
示例#19
0
    def extract_pq_dataset(self, acq_min, acq_max):
        """List the GQA-filtered PQ cells of ``self.cell`` across all products.

        For each NBAR product in ``self.products`` the matching PQ product is
        queried between ``self.start_epoch`` and ``self.end_epoch``. The
        sources of later products are concatenated along ``time`` onto the
        result of the first successful query.

        A product is skipped when its output file already exists, when it is
        LS7 and ``acq_max`` is after 2003-01-01, or when it is LS8 and
        ``acq_max`` is before 2013-01-01.

        Args:
            acq_min: Accepted but not used by this method.
            acq_max (date): Compared against the LS7 / LS8 cut-off dates.

        Returns:
            tuple: ``(pq, gw)``. ``pq`` is ``None`` when every product was
            skipped.
        """
        gw = GridWorkflow(index=self.index, product=self.products[0])
        pq = None

        for st in self.products:
            if st == 'ls5_nbar_albers':
                prod = 'ls5_pq_albers'
            elif st == 'ls7_nbar_albers':
                prod = 'ls7_pq_albers'
            else:
                prod = 'ls8_pq_albers'
            print ("my cell and sensor", self.cell, st )
            if self.odir:
                # NOTE(review): this read the bare name `cell`, which is not
                # defined in this method; every other reference is self.cell.
                # Confirm no module-level `cell` was intended.
                filepath = self.odir + '/' + 'TIDAL_' + ''.join(map(str, self.cell))  \
                       + "_MEDOID_" + str(self.per) + "_PERC_" + str(self.epoch) + "_EPOCH.nc"
                if os.path.isfile(filepath):
                    print ("file exists " + filepath)
                    continue
            if st == 'ls7_nbar_albers' and acq_max > datetime.strptime("2003-01-01", "%Y-%m-%d").date():
                print ("LS7 post 2003 Jan data is not included")
                continue
            if st == 'ls8_nbar_albers' and acq_max < datetime.strptime("2013-01-01", "%Y-%m-%d").date():
                print ("No data for LS8 and hence not searching")
                continue
            indexers = {'product':prod, 'time':(self.start_epoch, self.end_epoch)}
            # NOTE(review): eval() executes whatever self.cell contains; it
            # must come from a trusted source. ast.literal_eval would be safer
            # if self.cell is always a tuple literal -- confirm before changing.
            cells = list_gqa_filtered_cells(self.index, gw, pix_th=1,
                                            cell_index=eval(self.cell), **indexers)
            # `pq` is still None when earlier products were skipped, so test
            # it explicitly instead of relying on the loop position.
            if pq is not None and len(pq) > 0:
                pq[self.cell].sources = xr.concat(
                    [pq[self.cell].sources, cells[self.cell].sources], dim='time')
            else:
                pq = cells

        return pq, gw
示例#20
0
def find_gaps(datacube, products, query, time_divs=None):
    """ Summarise, for every pair of distinct `products`, where they mismatch.

    Duplicate product names are dropped first; at least two distinct
    products are required.
    """
    products = list(set(products))
    assert len(products) >= 2, "no products to compare"

    grid_workflow = GridWorkflow(datacube.index, product=products[0])

    def mismatch_summary(product1, product2):
        """ Build the mismatch dictionary for one pair of products. """
        subqueries = subdivide_time_domain(time_divs=time_divs, **query)
        left, right = distribute(datacube, product1, product2, grid_workflow,
                                 subqueries)
        summary = {'products': [left.product, right.product]}
        summary[left.product] = left.summary()
        summary[right.product] = right.summary()
        return summary

    summaries = []
    for first, second in combinations(products, 2):
        summaries.append(mismatch_summary(first, second))
    return summaries
示例#21
0
def segment(tile, algorithm, segmentation_meta,
            band_list, extra_args):
    """Segment one tile and store the resulting polygons in the database

    Meant to be called within a dask.distributed.Cluster.map() over a list of tiles
    returned by GridWorkflow.list_cells
    Called in segment command line

    Args:
        tile: Datacube tile as returned by GridWorkflow.list_cells()
        algorithm (str): Name of the segmentation algorithm to apply
        segmentation_meta (madmex.models.SegmentationInformation.object): Django object
            relating to every segmentation object generated by this run
        band_list (list): Optional subset of bands of the product to use for running the segmentation.
        extra_args (dict): dictionary of additional arguments

    Returns:
        bool: True on success; False when loading, segmenting or writing
        raised (the exception is printed).

    Raises:
        ValueError: when the module for ``algorithm`` cannot be imported.
    """
    # Resolve the Segmentation class of the requested algorithm
    try:
        Segmentation = import_module(
            'madmex.segmentation.%s' % algorithm).Segmentation
    except ImportError as e:
        raise ValueError('Invalid model argument')

    try:
        # Load tile
        geoarray = GridWorkflow.load(tile[1], measurements=band_list)
        segmenter = Segmentation.from_geoarray(geoarray, **extra_args)
        segmenter.segment()
        # Drop the input arrays before polygonizing
        segmenter.array = None
        geoarray = None
        segmenter.polygonize()
        segmenter.to_db(segmentation_meta)
        gc.collect()
        return True
    except Exception as e:
        print(e)
        return False
示例#22
0
    def __call__(self, index, sources_spec,
                 date_ranges) -> Iterator[StatsTask]:
        """
        Yield the tasks for every time period, across the spatial grid.

        When ``self.tile_indexes`` is set, tasks are collected tile by tile
        for exactly those tiles; otherwise ``collect_tasks`` is called once
        per time period without a tile index.

        :param index: Datacube Index
        :param sources_spec: forwarded to ``collect_tasks``
        :param date_ranges: iterable of time periods to generate tasks for
        :return: iterator over the generated tasks
        """
        workflow = GridWorkflow(index, grid_spec=self.grid_spec)

        for time_period in date_ranges:
            _LOG.info('Making output product tasks for time period: %s',
                      time_period)
            timer = MultiTimer().start('creating_tasks')
            task_count = 0

            if self.tile_indexes is None:
                for task in self.collect_tasks(workflow, time_period,
                                               sources_spec):
                    task_count += 1
                    yield task
            else:
                for tile_index in self.tile_indexes:
                    _LOG.debug('task for tile %s', tile_index)
                    for task in self.collect_tasks(workflow, time_period,
                                                   sources_spec, tile_index):
                        task_count += 1
                        yield task

            # is timing it still appropriate here?
            timer.pause('creating_tasks')
            if task_count:
                _LOG.info('Created %s tasks for time period: %s. In: %s',
                          task_count, time_period, timer)
示例#23
0
def run(tile, center_dt, path):
    """Basic data preparation recipe 001

    Combines temporal statistics (mean, min, max, std) of surface reflectance
    and ndvi with terrain metrics and writes the result to a netcdf file.

    Args:
        tile (tuple): Tuple of (tile indices, Tile object). Tile object can be
            loaded as xarray.Dataset using gwf.load()
        center_dt (datetime): Date to be used in making the filename
        path (str): Directory where files generated are to be written

    Return:
        str: The filename of the netcdf file created (or of the file that
        already existed), or None when the tile could not be processed
    """
    # Variables that get one output layer per temporal statistic
    band_names = ('blue', 'green', 'red', 'nir', 'swir1', 'swir2', 'ndvi')
    try:
        center_dt = center_dt.strftime("%Y-%m-%d")
        nc_filename = os.path.join(
            path,
            'madmex_001_%d_%d_%s.nc' % (tile[0][0], tile[0][1], center_dt))
        # Nothing to compute when the output file is already on disk
        if os.path.isfile(nc_filename):
            logger.warning(
                '%s already exists. Returning filename for database indexing',
                nc_filename)
            return nc_filename
        # Load Landsat sr
        sr_0 = GridWorkflow.load(tile[1], dask_chunks={'x': 1667, 'y': 1667})
        # Load terrain metrics using same spatial parameters than sr. The
        # with-block closes the datacube connection even when load() raises.
        with datacube.Datacube(
                app='landsat_madmex_001_%s' % randomword(5)) as dc:
            terrain = dc.load(product='srtm_cgiar_mexico',
                              like=sr_0,
                              time=(datetime(1970, 1, 1),
                                    datetime(2018, 1, 1)),
                              dask_chunks={
                                  'x': 1667,
                                  'y': 1667
                              })
        # Mask clouds, shadow, water, ice,... and drop qa layer
        clear = masking.make_mask(sr_0.pixel_qa,
                                  cloud=False,
                                  cloud_shadow=False,
                                  snow=False)
        sr_1 = sr_0.where(clear)
        sr_2 = sr_1.drop('pixel_qa')
        # Convert Landsat data to float (nodata values are converted to np.Nan)
        sr_3 = sr_2.apply(func=to_float, keep_attrs=True)
        # Compute ndvi
        sr_3['ndvi'] = ((sr_3.nir - sr_3.red) / (sr_3.nir + sr_3.red)) * 10000
        sr_3['ndvi'].attrs['nodata'] = -9999
        # Run temporal reductions and rename the variables to <band>_<stat>
        reductions = []
        for stat in ('mean', 'min', 'max', 'std'):
            # Same call as sr_3.mean(...), sr_3.min(...), ... looked up by name
            reduced = getattr(sr_3, stat)('time', keep_attrs=True, skipna=True)
            # rename() returns a new Dataset; the result is re-assigned
            # instead of relying on the ``inplace`` argument, which recent
            # xarray releases no longer accept.
            reduced = reduced.rename(
                {name: '%s_%s' % (name, stat) for name in band_names})
            reductions.append(reduced.apply(to_int))
        # Merge dataarrays
        combined = xr.merge(reductions + [terrain])
        combined.attrs['crs'] = sr_0.attrs['crs']
        write_dataset_to_netcdf(combined, nc_filename)
        return nc_filename
    except Exception as e:
        # Best effort: a failing tile is reported and skipped by the caller
        logger.warning('Tile (%d, %d) not processed. %s',
                       tile[0][0], tile[0][1], e)
        return None
示例#24
0
def detect_and_classify_change(tiles, algorithm, change_meta, band_list, mmu,
                               lc_pre, lc_post, extra_args,
                               filter_labels=True):
    """Run a change detection algorithm between two tiles, classify the results and write to the database

    Meant to be called within a dask.distributed.Cluster.map() over a list of
    (tile_index, [tile0, tile1]) tuples generated by two calls to gwf_query
    Called in detect_change command line

    Args:
        tiles (tuple): Tuple of (tile_index, [tile0, tile1]). Tiles are Datacube
            tiles as returned by GridWorkflow.list_cells()
        algorithm (str): Name of the change detection algorithm to use. It is
            resolved to the module ``madmex.lcc.bitemporal.<algorithm>``
        change_meta (madmex.models.ChangeInformation): Django object containing change
            objects meta information. Resulting from a call to ``get_or_create``
        band_list (list): Optional subset of bands of the product to use for running
            the change detection
        mmu (float or None): Minimum mapping unit in the unit of the tile crs
            (e.g.: squared meters, squared degrees, ...) to apply for filtering
            small change objects
        lc_pre (str): Name of the anterior land cover map to use for change
            classification
        lc_post (str): Name of the post land cover map to use for change
            classification
        extra_args (dict): dictionary of additional arguments
        filter_labels (bool): Whether to apply a filter to remove objects with same
            pre and post label. Defaults to True, in which case objects with same
            label are discarded

    Return:
        bool: True when the tile pair was processed (including the case where
        no change is left after filtering), False when any step raised

    Raises:
        ValueError: If the module for ``algorithm`` cannot be imported. The
            original ImportError is attached as ``__cause__``
    """
    # Load change detection class
    try:
        module = import_module('madmex.lcc.bitemporal.%s' % algorithm)
        BiChange = module.BiChange
    except ImportError as e:
        # Chain the cause so a missing dependency inside the algorithm module
        # is not mistaken for a misspelled algorithm name
        raise ValueError('Invalid algorithm argument') from e

    try:
        # Load geoarrays
        geoarray_pre = GridWorkflow.load(tiles[1][0], measurements=band_list)
        change_pre = BiChange.from_geoarray(geoarray_pre, **extra_args)
        geoarray_post = GridWorkflow.load(tiles[1][1], measurements=band_list)
        # NOTE(review): extra_args is only forwarded for the pre tile -
        # confirm that the post tile is meant to be built without it
        change_post = BiChange.from_geoarray(geoarray_post)
        # Run change detection
        change_pre.run(change_post)
        # Apply mmu filter
        if mmu is not None:
            change_pre.filter_mmu(mmu)
        # Exit function if there are no changes left
        if change_pre.change_array.sum() == 0:
            return True
        # Load pre and post land cover map as feature collections
        fc_pre = change_pre.read_land_cover(lc_pre)
        fc_post = change_pre.read_land_cover(lc_post)
        # Generate feature collection of labelled change objects
        fc_change = change_pre.label_change(fc_pre, fc_post)
        # Optionally filter objects with same pre and post label
        if filter_labels:
            fc_change = BiChange.filter_no_change(fc_change)
        # Write that feature collection to the database
        change_pre.to_db(fc=fc_change, meta=change_meta, pre_name=lc_pre,
                         post_name=lc_post)
        # Drop the references to the large objects so that gc.collect() can
        # reclaim them before the worker picks up its next task
        geoarray_pre = geoarray_post = None
        change_pre = change_post = None
        fc_pre = fc_post = fc_change = None
        gc.collect()
        return True
    except Exception as e:
        # Best effort inside a distributed map: report and flag the failure
        print('Change detection failed because: %s' % e)
        return False
示例#25
0
from rasterio.features import rasterize
from madmex.io.vector_db import VectorDb
from rasterio.features import rasterize
import datacube
from datacube.api import GridWorkflow
from pprint import pprint
from datetime import datetime
from affine import Affine

# Demo script: load one tile of the 'ls8_espa_mexico' product, fetch the
# training data matching it and collect the training geometries.

# Load a test dataset
dc = datacube.Datacube()
gw = GridWorkflow(dc.index, product='ls8_espa_mexico')
# Query the cells for the given x/y ranges and time window
tile_dict = gw.list_cells(product='ls8_espa_mexico',
                          x=(-104, -102),
                          y=(19, 21),
                          time=(datetime(2017, 1, 1), datetime(2017, 2, 1)))
# list_cells returns a mapping; turn it into an indexable list of
# (key, tile) pairs
tile_list = list(tile_dict.items())
# Load the tile of the fourth pair (index 3); fails with IndexError when the
# query returned fewer than four cells
sr = gw.load(tile_list[3][1])

# Visualize Dataset metadata
print(sr)

# Load training data corresponding to that dataset
db = VectorDb()
fc = db.load_training_from_dataset(sr)

# Visualize first element of feature collection
pprint(fc[0])

# Rasterize the feature collection
# Collect the 'geometry' entry of every feature
geom_list = [x['geometry'] for x in fc]
示例#26
0
def predict_pixel_tile(tile, model_id, outdir=None):
    """Predict one tile with a trained model and write the result as a GeoTiff

    Meant to be called within a dask.distributed.Cluster.map() over a list of tiles
    returned by GridWorkflow.list_cells
    Called in model_predict command line

    Args:
        tile: Datacube tile as returned by GridWorkflow.list_cells(), used as a
            (tile index, Tile object) pair
        model_id (str): Database identifier of trained model to use. The model
            must have been trained against a numeric dependent variable.
            (See --encode flag in model_fit command line)
        outdir (str): Directory where the output file is written. It has to be
            provided despite the ``None`` default: building the filename from
            ``None`` raises, which is reported like any other failure. The
            directory must already exist, it is therefore a good idea to
            generate it in the command line function before sending the tasks

    Return:
        str: Path of the GeoTiff written to disk (int16, single band,
        lzw compressed), or None when any step raised
    """
    # TODO: How to handle data type. When ran in classification mode int16 would always
    # be fine, but this function could potentially also be ran in regression mode
    try:
        # Instantiate the model class that matches the stored model
        model = BaseModel.from_db(model_id)
        try:
            # One thread per process is enough; not every model exposes n_jobs
            model.model.n_jobs = 1
        except Exception:
            pass
        # Output path is derived from the model id and the tile index
        tile_index = tile[0]
        basename = 'prediction_%s_%d_%d.tif' % (model_id, tile_index[0],
                                                tile_index[1])
        filename = os.path.join(outdir, basename)
        # Read the tile and move the leading axis of the stacked array last
        dataset = GridWorkflow.load(tile[1])
        cube = np.moveaxis(dataset.to_array().squeeze().values, 0, 2)
        # Flatten the first two axes so that every row is one sample
        height, width, n_features = cube.shape[:3]
        samples = cube.reshape((height * width, n_features))
        # Predict, then fold the flat result back onto the two leading axes
        prediction = model.predict(samples).reshape((height, width))
        # Write array to geotiff
        profile = dict(width=prediction.shape[1],
                       height=prediction.shape[0],
                       affine=dataset.affine,
                       crs=dataset.crs.crs_str,
                       count=1,
                       dtype='int16',
                       compress='lzw',
                       driver='GTiff')
        with rasterio.open(filename, 'w', **profile) as dst:
            dst.write(prediction.astype('int16'), 1)
        return filename
    except Exception as e:
        # Best effort inside a distributed map: report and signal the failure
        print(e)
        return None
示例#27
0
def gwf_query(product, lat=None, long=None, region=None, begin=None, end=None,
              view=True):
    """Run a spatial query on a datacube product using either coordinates or a region name

    Wrapper function to call at the beginning of nearly all spatial processing command lines

    Args:
        product (str): Name of an ingested datacube product. The product to query
        lat (tuple): Optional. For coordinate based spatial query. Tuple of min and max
            latitudes in decimal degrees.
        long (tuple): Optional. For coordinate based spatial query. Tuple of min and max
            longitudes in decimal degrees.
        region (str): Optional name of a region or country whose geometry is present in the database
            region or country table. Overrides lat and long when present (not None).
            Countries must be queried using ISO code (e.g.: 'MEX' for Mexico)
        begin (str): Date string in the form '%Y-%m-%d'. For temporally bounded queries
        end (str): Date string in the form '%Y-%m-%d'. For temporally bounded queries.
            The time filter is only applied when both ``begin`` and ``end`` are given
        view (bool): Returns a view instead of the dictionary returned by ``GridWorkflow.list_cells``.
            Useful when the output is to be used directly as an iterable (e.g. in ``distributed.map``)
            Default to True

    Returns:
        dict or view: Dictionary (view) of Tile index, Tile key value pair

    Raises:
        ValueError: If neither ``region`` nor both ``lat`` and ``long`` are provided

    Example:

        >>> from madmex.wrappers import gwf_query

        >>> # Using region name, time unbounded
        >>> tiles_list = gwf_query(product='ls8_espa_mexico', region='Jalisco')
        >>> # Using region name, time windowed
        >>> tiles_list = gwf_query(product='ls8_espa_mexico', region='Jalisco',
        ...                        begin = '2017-01-01', end='2017-03-31')
        >>> # Using lat long box, time windowed
        >>> tiles_list = gwf_query(product='ls8_espa_mexico', lat=[19, 22], long=[-104, -102],
        ...                        begin = '2017-01-01', end='2017-03-31')
    """
    query_params = {'product': product}
    if region is not None:
        # Query database and build a datacube.utils.Geometry(geopolygon).
        # The country table is tried first, the region table is the fallback.
        try:
            query_set = Country.objects.get(name=region)
        except Country.DoesNotExist:
            query_set = Region.objects.get(name=region)
        region_json = json.loads(query_set.the_geom.geojson)
        crs = CRS('EPSG:%d' % query_set.the_geom.srid)
        geom = Geometry(region_json, crs)
        query_params.update(geopolygon=geom)
    elif lat is not None and long is not None:
        query_params.update(x=long, y=lat)
    else:
        raise ValueError('Either a region name or a lat and long must be provided')

    if begin is not None and end is not None:
        begin = datetime.strptime(begin, "%Y-%m-%d")
        end = datetime.strptime(end, "%Y-%m-%d")
        query_params.update(time=(begin, end))

    # GridWorkflow object. The with-block releases the datacube connection
    # once the cells are listed, also when the query itself raises.
    with datacube.Datacube() as dc:
        gwf = GridWorkflow(dc.index, product=product)
        tile_dict = gwf.list_cells(**query_params)
    # Iterable (dictionary view (analog to list of tuples))
    if view:
        tile_dict = tile_dict.items()
    return tile_dict
示例#28
0
    for name, var in dataset_to_transform.data_vars.items():
        new_df[name] = np.reshape(var.data, -1)

# Points at the edge of the image could return empty arrays (all 0's) - this will remove any columns to which this applies
    new_df = new_df.dropna(axis=1, how='all')

    return new_df

# Script entry point: list, product by product, the tiles available for one
# fixed cell.
if __name__ == '__main__':

    # Create datacube object
    dc = datacube.Datacube()

    # Create GridWorkflow object so we can work with tiles
    # NOTE(review): this instance is replaced on the first loop iteration
    # below; its only visible effect is an IndexError when sref_products is
    # empty - confirm whether it is still needed
    gw = GridWorkflow(dc.index, product=sref_products[-1])

    # List to store the three datasets (LS5, LS7, LS8)
    sref_ds = []

    # The key represents which tile we are using
    key = (5, -28)

    # Need to fetch the tiles for each product separately
    for product in sref_products:

        # One GridWorkflow per product
        gw = GridWorkflow(dc.index, product=product)

        # Get the list of tiles (one for each time point) for this product
        tile_list = gw.list_tiles(product=product, cell_index=key)
示例#29
0
def run(tile, center_dt, path):
    """Basic data preparation recipe 001

    Combines temporal statistics of surface reflectance, ndvi and ndmi with
    terrain metrics and writes the result to a netcdf file. Every band gets a
    temporal mean; min and max are only computed for the two indices.

    Args:
        tile (tuple): Tuple of (tile indices, Tile object). Tile object can be
            loaded as xarray.Dataset using gwf.load()
        center_dt (datetime): Date to be used in making the filename
        path (str): Directory where files generated are to be written

    Return:
        str: The filename of the netcdf file created (or of the file that
        already existed), or None when the tile could not be processed
    """
    # Variables renamed to <name>_mean after the temporal mean
    mean_names = ('blue', 'green', 'red', 're1', 're2', 're3', 'nir',
                  'swir1', 'swir2', 'ndmi', 'ndvi')
    try:
        center_dt = center_dt.strftime("%Y-%m-%d")
        nc_filename = os.path.join(path, 's2_20m_001_%d_%d_%s.nc' % (tile[0][0], tile[0][1], center_dt))
        # Nothing to compute when the output file is already on disk
        if os.path.isfile(nc_filename):
            logger.warning('%s already exists. Returning filename for database indexing', nc_filename)
            return nc_filename
        # Load surface reflectance and convert it to float
        sr_0 = GridWorkflow.load(tile[1], dask_chunks={'x': 1000, 'y': 1000})
        sr_0 = sr_0.apply(func=to_float, keep_attrs=True)
        # Load terrain metrics using same spatial parameters than sr. The
        # with-block closes the datacube connection even when load() raises.
        with datacube.Datacube(app='s2_20m_001_%s' % randomword(5)) as dc:
            terrain = dc.load(product='srtm_cgiar_mexico', like=sr_0,
                              time=(datetime(1970, 1, 1), datetime(2018, 1, 1)),
                              dask_chunks={'x': 1000, 'y': 1000})
        # Keep clear pixels (2: Dark features, 4: Vegetation, 5: Not vegetated,
        # 6: Water, 7: Unclassified, 11: Snow/Ice)
        # NOTE(review): class 8 is kept as well although the list above does
        # not describe it - confirm that this is intended
        sr_1 = sr_0.where(sr_0.pixel_qa.isin([2,4,5,6,7,8,11]))
        sr_1 = sr_1.drop('pixel_qa')
        # Compute ndvi
        sr_1['ndvi'] = ((sr_1.nir - sr_1.red) / (sr_1.nir + sr_1.red)) * 10000
        sr_1['ndvi'].attrs['nodata'] = 0
        # Compute ndmi
        sr_1['ndmi'] = ((sr_1.nir - sr_1.swir1) / (sr_1.nir + sr_1.swir1)) * 10000
        sr_1['ndmi'].attrs['nodata'] = 0
        # Run temporal reductions and rename DataArrays. rename() returns a
        # new Dataset; the result is re-assigned instead of relying on the
        # ``inplace`` argument, which recent xarray releases no longer accept.
        sr_mean = sr_1.mean('time', keep_attrs=True, skipna=True)
        sr_mean = sr_mean.rename(
            {name: '%s_mean' % name for name in mean_names})
        # Compute min/max only for vegetation indices, in the order
        # ndvi_max, ndvi_min, ndmi_max, ndmi_min
        index_layers = []
        for index_name in ('ndvi', 'ndmi'):
            for stat in ('max', 'min'):
                # Same call as sr_1.ndvi.max(...), ... looked up by name
                layer = getattr(sr_1[index_name], stat)(
                    'time', keep_attrs=True, skipna=True)
                layer = layer.rename('%s_%s' % (index_name, stat))
                layer.attrs['nodata'] = 0
                index_layers.append(to_int(layer))
        # Merge dataarrays
        combined = xr.merge([sr_mean.apply(to_int)] + index_layers + [terrain])
        combined.attrs['crs'] = sr_0.attrs['crs']
        # Evaluate the dask graph before writing
        combined = combined.compute()
        write_dataset_to_netcdf(combined, nc_filename)
        return nc_filename
    except Exception as e:
        # Best effort: a failing tile is reported and skipped by the caller
        logger.warning('Tile (%d, %d) not processed. %s', tile[0][0], tile[0][1], e)
        return None
示例#30
0
def multi_product_list_cells(products,
                             gw,
                             cell_index=None,
                             product_query=None,
                             **query):
    """This is similar to GridWorkflow.list_cells but generalised to multiple
    products. Only datasets that are available in all of the products are
    reported.

    Datasets that do not have a full set across all products are returned in a
    separate group.


    products      -- list of product names
    gw            -- Preconfigured GridWorkflow object
    cell_index    -- Limit search area to a single cell
    product_query -- Product specific query, dict product_name => product specific query
    **query       -- Common query parameters across all products

    Returns:

    co_common     -- Cell observation that have full set across products
    co_unmatched  -- Cell observations where at least one product is missing

    Type of `co_common, co_unmatched` is list of dictionaries of tiles.

    `type(co_common[product_idx:Int][cell_idx:(Int,Int)]) == datacube.api.Tile`

    An empty `products` list yields two empty lists.
    """
    if product_query is None:
        product_query = {}

    # Stand-in for a product that has no observations in a given cell
    empty_cell = dict(datasets=[], geobox=None)
    co_common = [dict() for _ in products]
    co_unmatched = [dict() for _ in products]

    group_by = query_group_by(**query)

    # One {cell index: observations} mapping per product, in product order
    obs = [
        gw.cell_observations(product=product,
                             cell_index=cell_index,
                             **product_query.get(product, {}),
                             **query) for product in products
    ]

    # set of all cell indexes found across all products; union() copes with
    # an empty product list and avoids repeated list concatenation
    all_cell_idx = set().union(*(o.keys() for o in obs))

    def cell_is_empty(c):
        return len(c['datasets']) == 0

    for cidx in all_cell_idx:
        common, unmatched = common_obs_per_cell(
            *[o.get(cidx, empty_cell) for o in obs])

        for i, product_obs in enumerate(obs):
            # Products without data in this cell only took part through the
            # empty stand-in and get no entry of their own
            if cidx not in product_obs:
                continue

            if not cell_is_empty(common[i]):
                co_common[i][cidx] = common[i]

            if not cell_is_empty(unmatched[i]):
                co_unmatched[i][cidx] = unmatched[i]

    co_common = [
        GridWorkflow.group_into_cells(c, group_by=group_by) for c in co_common
    ]
    co_unmatched = [
        GridWorkflow.group_into_cells(c, group_by=group_by)
        for c in co_unmatched
    ]

    return co_common, co_unmatched