Example #1
def test_woffles(query, expected):
    dc = Datacube(app='test_wofls')

    bands = ['blue', 'green', 'red', 'nir', 'swir1', 'swir2']  # input bands needed from the EO data
    source = dc.load(product='ls8_nbar_albers', measurements=bands, **query)
    pq = dc.load(product='ls8_pq_albers', like=source)
    dsm = dc.load(product='dsm1sv10', like=source, time=('1900-01-01', '2100-01-01'), resampling='cubic')

    wofls_output = woffles(*(x.isel(time=0) for x in [source, pq, dsm]))

    assert (wofls_output == expected).all()
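For context, a minimal sketch of how such a test might be driven. The query below is a hypothetical lon/lat window and the reference array is a placeholder; neither comes from the original test suite.

# Hypothetical datacube query covering a small area; all values are placeholders.
sample_query = {
    'x': (149.00, 149.05),
    'y': (-35.30, -35.25),
    'time': ('2016-01-01', '2016-01-31'),
}

# `expected` would normally be a stored reference WOfL raster, e.g.:
# expected = np.load('data/expected_wofl.npy')  # placeholder path
# test_woffles(sample_query, expected)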
Example #2
def post_processing(predicted):
    """
    filter prediction results with post processing filters.
    
    Simplified from production code to skip
    segmentation, probability, and mode calcs

    """

    dc = Datacube(app='whatever')

    predict = predicted.Predictions

    #--Post process masking---------------------------------------------------------------
    #print("  masking with AEZ,WDPA,WOfS,slope & elevation")

    # mask out classification beyond AEZ boundary
    gdf = gpd.read_file('data/Sahel.geojson')
    with HiddenPrints():
        mask = xr_rasterize(gdf, predicted)
    predict = predict.where(mask, 0)

    # mask with WDPA
    #     url_wdpa="s3://deafrica-input-datasets/protected_areas/WDPA_southern.tif"
    #     wdpa=rio_slurp_xarray(url_wdpa, gbox=predicted.geobox)
    #     wdpa = wdpa.astype(bool)
    #     predict = predict.where(~wdpa, 0)

    # mask with WOfS
    wofs = dc.load(product='wofs_ls_summary_annual',
                   like=predicted.geobox,
                   time='2019')
    wofs = wofs.frequency > 0.2  # threshold
    predict = predict.where(~wofs, 0)

    # mask steep slopes
    url_slope = "https://deafrica-input-datasets.s3.af-south-1.amazonaws.com/srtm_dem/srtm_africa_slope.tif"
    slope = rio_slurp_xarray(url_slope, gbox=predicted.geobox)
    slope = slope > 50
    predict = predict.where(~slope, 0)

    # mask where the elevation is above 3600m
    elevation = dc.load(product='dem_srtm', like=predicted.geobox)
    elevation = elevation.elevation > 3600  # threshold
    predict = predict.where(~elevation.squeeze(), 0)

    # set dtype
    predict = predict.astype(np.int8)

    return predict
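The masking idiom used above is plain xarray: `DataArray.where(cond, other)` keeps values where `cond` is True and writes `other` elsewhere, so `predict.where(~wofs, 0)` zeroes pixels flagged as water. A minimal, self-contained sketch with synthetic data:

import numpy as np
import xarray as xr

# Synthetic single-band "prediction" and a boolean water mask (True = water).
predict = xr.DataArray(np.ones((3, 3), dtype=np.int8), dims=('y', 'x'))
water = xr.DataArray(np.array([[True, False, False],
                               [False, False, False],
                               [False, False, True]]), dims=('y', 'x'))

# Keep predictions only where there is no water; set water pixels to 0.
predict = predict.where(~water, 0)
print(predict.values)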
Example #3
def post_processing(
    predicted: xr.Dataset,
) -> xr.DataArray:
    """
    Filter prediction results with post-processing filters.

    :param predicted: The prediction results
    """

    dc = Datacube(app='whatever')

    # grab predictions for post-process filtering
    predict = predicted.Predictions

    # mask out classification beyond AEZ boundary
    gdf = gpd.read_file('data/Western.geojson')
    with HiddenPrints():
        mask = xr_rasterize(gdf, predicted)
    predict = predict.where(mask, 0)

    # mask with WDPA
    url_wdpa = "s3://deafrica-input-datasets/protected_areas/WDPA_western.tif"
    wdpa = rio_slurp_xarray(url_wdpa, gbox=predicted.geobox)
    wdpa = wdpa.astype(bool)
    predict = predict.where(~wdpa, 0)

    # mask with WOfS
    wofs = dc.load(product='ga_ls8c_wofs_2_summary', like=predicted.geobox)
    wofs = wofs.frequency > 0.2  # threshold
    predict = predict.where(~wofs, 0)

    # mask steep slopes
    url_slope = "https://deafrica-data.s3.amazonaws.com/ancillary/dem-derivatives/cog_slope_africa.tif"
    slope = rio_slurp_xarray(url_slope, gbox=predicted.geobox)
    slope = slope > 35
    predict = predict.where(~slope, 0)

    # mask where the elevation is above 3600m
    elevation = dc.load(product='dem_srtm', like=predicted.geobox)
    elevation = elevation.elevation > 3600  # threshold
    predict = predict.where(~elevation.squeeze(), 0)

    # set dtype
    predict = predict.astype(np.int8)

    return predict
Example #4
def calculate_index_task(params):
    item = params.get('item')
    index = params.get('index', 'rgb')
    dc = Datacube(config="datacube.conf")
    product = "ls8_level1_usgs"
    x = (item["bbox"][0], item["bbox"][2])
    y = (item["bbox"][1], item["bbox"][3])
    time = item["properties"]["datetime"].split("T")[0]
    measurements = ["band_2", "band_3", "band_4"]

    query = {
        'x': x,
        'y': y,
        'time': time,
        'measurements': ['nbart_red', 'nbart_green', 'nbart_blue'],
        'output_crs': 'EPSG:4326',
        'resolution': (-0.001, 0.001),
    }

    ds = dc.load(product=product, **query)
    print(ds)
    rgb_da = ds.to_array()
    suffix = 'rgb'
    filename = f'{item["id"]}_{suffix}.tif'
    path = config.STATIC_DIR / filename
    write_cog(geo_im=rgb_da, fname=str(path), overwrite=True)
    return {"success": True, "url": str(path)}
Example #5
def xadataset_from_odcdataset(datasets: Union[List[ODCDataset],
                                              ODCDataset] = None,
                              ids: Union[List[UUID], UUID] = None,
                              measurements: List[str] = None) -> xa.Dataset:
    """ Loads a xaDataset from ODCDatasets or ODCDataset ids
     :param datasets: ODCDataset(s), optional
     :param ids: ODCDataset id(s), optional
     :param measurements: list of measurements/bands to load, optional
     :return: xa.Dataset containing given ODCDatasets or IDs """

    dc = Datacube(app="dataset_from_ODCDataset")

    if not datasets:
        if not isinstance(ids, list):
            ids = [ids]
        datasets = [dc.index.datasets.get(id_) for id_ in ids]

    if not isinstance(datasets, list):
        datasets = [datasets]

    product_name = datasets[0].metadata_doc["product"]["name"]
    crs = datasets[0].crs
    res = (10, -10)  # TODO: handle other resolutions; note ODC expects (y, x), typically (-10, 10) for a north-up CRS

    ds = dc.load(product=product_name,
                 dask_chunks={},
                 measurements=measurements,
                 output_crs=str(crs),
                 resolution=res,
                 datasets=datasets)
    return ds
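A hedged usage sketch: the UUID below is a placeholder for a dataset id that actually exists in your index, and the band names are assumptions about the indexed product.

from uuid import UUID

# Load one indexed dataset lazily (the helper passes dask_chunks={}) by its id.
ds = xadataset_from_odcdataset(
    ids=UUID("00000000-0000-0000-0000-000000000000"),  # placeholder id
    measurements=["B04", "B08"],                        # hypothetical band names
)
print(ds)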
Example #6
def post_processing(
    data: xr.Dataset,
    predicted: xr.Dataset,
    config: FeaturePathConfig,
    geobox_used: GeoBox,
) -> xr.DataArray:
    """
    filter prediction results with post processing filters.
    :param data: raw data with all features to run prediction
    :param predicted: The prediction results
    :param config:  FeaturePathConfig configureation
    :param geobox_used: Geobox used to generate the prediciton feature
    :return: only predicted binary class label
    """
    # post prediction filtering
    predict = predicted.Predictions
    query = config.query.copy()
    # Update dc query with geometry
    # geobox_used = self.geobox_dict[(x, y)]
    query["geopolygon"] = Geometry(geobox_used.extent.geom,
                                   crs=geobox_used.crs)

    dc = Datacube(app=__name__)
    # mask with WOFS
    # wofs_query = query.pop("measurements")
    wofs = dc.load(product="ga_ls8c_wofs_2_summary", **query)
    wofs = wofs.frequency > 0.2  # threshold
    predict = predict.where(~wofs, 0)

    # mask steep slopes
    slope = data.slope > 35
    predict = predict.where(~slope, 0)

    # mask where the elevation is above 3600m
    query.pop("time")
    elevation = dc.load(product="srtm", **query)
    elevation = elevation.elevation > 3600
    predict = predict.where(~elevation.squeeze(), 0)
    return predict
Example #7
def rgb_task2(item):
    dc = Datacube(config="datacube.conf")
    product = "ls8_level1_usgs"
    time = item["properties"]["datetime"].split("T")[0]
    x = (item["bbox"][0], item["bbox"][2])
    y = (item["bbox"][1], item["bbox"][3])
    measurements = ["B2"]
    ds = dc.load(product=product,
                 measurements=measurements,
                 time=time,
                 x=x,
                 y=y,
                 output_crs='EPSG:4326',
                 resolution=(-0.001, 0.001))
    suffix = '_'.join(measurements)
    filename = f'{item["id"]}_{suffix}.tif'

    path = write_cog(
        ds.to_array(),
        Path('/static') / filename,
    )
    return {"success": True, "url": str(path)}
Example #8
def cli(product, input_prefix, location, verbose):
    """
    Generate mosaic overviews of the stats data.

    An intermediate cache file is generated and stored in the output location
    during this process.
    Note: The input bucket must be public otherwise the data can not be listed.
    """

    product = product_from_yaml(product)
    if verbose:
        print(f"Preparing mosaics for {product.name} product")

    dss = s3_fetch_dss(input_prefix, product, glob="*.json")

    if verbose:
        print(f"Writing {location}/{product.name}.db")

    cache = create_cache(f"{location}/{product.name}.db")
    cache.bulk_save(dss)
    if verbose:
        print(f"Found {cache.count:,d} datasets")

    dc = Datacube()
    dss = list(cache.get_all())
    xx = dc.load(
        datasets=dss,
        dask_chunks={
            "x": 3200,
            "y": 3200
        },
        resolution=(-120, 120),
        measurements=["red", "green", "blue"],
    )

    save(xx, location, product.name, verbose)
Example #9
def post_processing(predicted):
    """
    filter prediction results with post processing filters.
    :param predicted: The prediction results

    """

    dc = Datacube(app='whatever')

    # grab predictions and proba for post process filtering
    predict = predicted.Predictions
    #     proba = predicted.Probabilities
    #     proba = proba.where(predict == 1, 100 - proba)  # crop proba only

    #     #------image seg and filtering -------------
    #     # write out ndvi for image seg
    #     ndvi = assign_crs(predicted[["NDVI_S1", "NDVI_S2"]],
    #                       crs=predicted.geobox.crs)

    #     # call function with dask delayed
    #     filtered = image_segmentation(ndvi, predict)

    #     # convert delayed object to dask array
    #     filtered = dask.array.from_delayed(filtered.squeeze(),
    #                                        shape=predict.shape,
    #                                        dtype=np.int8)

    #     # convert dask array to xr.DataArray
    #     filtered = xr.DataArray(filtered,
    #                             coords=predict.coords,
    #                             attrs=predict.attrs)

    # --Post process masking------------------------------------------------

    # merge back together for masking
    ds = xr.Dataset({"mask":
                     predict})  #, "prob": proba, "filtered": filtered})

    # mask out classification beyond AEZ boundary
    gdf = gpd.read_file(
        'https://github.com/digitalearthafrica/crop-mask/blob/main/testing/eastern_cropmask/data/Eastern.geojson?raw=true'
    )
    with HiddenPrints():
        mask = xr_rasterize(gdf, predicted)
    mask = mask.chunk({})
    ds = ds.where(mask, 0)

    # mask with WDPA
    wdpa = rio_slurp_xarray(
        "s3://deafrica-input-datasets/protected_areas/WDPA_eastern.tif",
        gbox=predicted.geobox)
    wdpa = wdpa.chunk({})
    wdpa = wdpa.astype(bool)
    ds = ds.where(~wdpa, 0)

    # mask with WOFS
    wofs = dc.load(product="ga_ls8c_wofs_2_summary",
                   like=predicted.geobox,
                   dask_chunks={})
    wofs = wofs.frequency > 0.2  # threshold
    ds = ds.where(~wofs, 0)

    # mask steep slopes
    slope = rio_slurp_xarray(
        'https://deafrica-data.s3.amazonaws.com/ancillary/dem-derivatives/cog_slope_africa.tif',
        gbox=predicted.geobox)
    slope = slope.chunk({})
    slope = slope > 35
    ds = ds.where(~slope, 0)

    # mask where the elevation is above 3600m
    elevation = dc.load(product="dem_srtm",
                        like=predicted.geobox,
                        dask_chunks={})
    elevation = elevation.elevation > 3600  # threshold
    ds = ds.where(~elevation.squeeze(), 0)

    return ds.squeeze()
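Because every input here is opened with `dask_chunks={}`, the returned dataset is lazy; a hedged usage note, assuming `predicted` is a dask-backed xr.Dataset with a `Predictions` variable and a geobox:

masked = post_processing(predicted)  # builds the masking graph, reads nothing yet
masked = masked.compute()            # triggers the actual loads and masking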
Example #10
def create_mosaic(
    dc: Datacube,
    product: str,
    out_product: str,
    time: Tuple[str, str],
    time_str: str,
    bands: Tuple[str, ...],
    s3_output_root: str,
    split_bands: bool = False,
    resolution: int = 120,
    overwrite: bool = False,
):
    log = setup_logging()
    log.info(f"Creating mosaic for {product} over {time}")

    client = start_local_dask()

    assets = {}
    data = dc.load(
        product=product,
        time=time,
        resolution=(-resolution, resolution),
        dask_chunks={"x": 2048, "y": 2048},
        measurements=bands,
    )

    # This is a bad idea, we run out of memory
    # data.persist()

    if not split_bands:
        log.info("Creating a single tif file")
        out_file = _get_path(s3_output_root, out_product, time_str, "tif")
        exists = s3_head_object(out_file) is not None
        skip_writing = not (not exists or overwrite)
        try:
            asset, _ = _save_opinionated_cog(
                data,
                out_file,
                skip_writing=skip_writing,
            )
        except ValueError:
            log.exception(
                "Failed to create COG, please check that you only have one timestep in the period."
            )
            exit(1)
        assets[bands[0]] = asset
        if skip_writing:
            log.info(f"File exists, and overwrite is False. Not writing {out_file}")
        else:
            log.info(f"Finished writing: {asset.href}")
    else:
        log.info("Creating multiple tif files")

        for band in bands:
            out_file = _get_path(
                s3_output_root, out_product, time_str, "tif", band=band
            )
            exists = s3_head_object(out_file) is not None
            skip_writing = not (not exists or overwrite)

            try:
                asset, band = _save_opinionated_cog(
                    data=data,
                    out_file=out_file,
                    band=band,
                    skip_writing=skip_writing,
                )
            except ValueError:
                log.exception(
                    "Failed to create COG, please check that you only have one timestep in the period."
                )
                exit(1)
            assets[band] = asset
            if skip_writing:
                log.info(f"File exists, and overwrite is False. Not writing {out_file}")
            else:
                log.info(f"Finished writing: {asset.href}")
                # Aggressively heavy handed, but we get memory leaks otherwise
                client.restart()

    out_stac_file = _get_path(s3_output_root, out_product, time_str, "stac-item.json")
    item = create_stac_item(
        assets[bands[0]].href,
        id=f"{product}_{time_str}",
        assets=assets,
        with_proj=True,
        properties={
            "odc:product": out_product,
            "start_datetime": f"{time[0]}T00:00:00Z",
            "end_datetime": f"{time[1]}T23:59:59Z",
        },
    )
    item.set_self_href(out_stac_file)

    log.info(f"Writing STAC: {out_stac_file}")
    client = s3_client(aws_unsigned=False)
    s3_dump(
        data=json.dumps(item.to_dict(), indent=2),
        url=item.self_href,
        ACL="bucket-owner-full-control",
        ContentType="application/json",
        s3=client,
    )
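A hedged sketch of a call to `create_mosaic`; the product name, time range, and S3 prefix are placeholders rather than values from any real deployment.

from datacube import Datacube

dc = Datacube(app="mosaic_example")

create_mosaic(
    dc,
    product="gm_s2_annual",                        # hypothetical input product
    out_product="gm_s2_annual_lowres",             # hypothetical output name
    time=("2019-01-01", "2019-12-31"),
    time_str="2019",
    bands=("red", "green", "blue"),
    s3_output_root="s3://example-bucket/mosaics",  # placeholder bucket
    split_bands=True,
    resolution=120,
    overwrite=False,
)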
Example #11
def get_data_opensource_shapefile(prod_info, acq_min, acq_max, shapefile,
                                  no_partial_scenes):

    datacube_config = prod_info[0]
    source_prod = prod_info[1]
    source_band_list = prod_info[2]
    mask_band = prod_info[3]

    if datacube_config != 'default':
        remotedc = Datacube(config=datacube_config)
    else:
        remotedc = Datacube()

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        with fiona.open(shapefile) as shapes:
            crs = geometry.CRS(shapes.crs_wkt)
            first_geometry = next(iter(shapes))['geometry']
            geom = geometry.Geometry(first_geometry, crs=crs)

            return_data = {}
            data = xr.Dataset()

            if source_prod != '':
                # get a sample dataset to decide the target epsg
                fd_query = {'time': (acq_min, acq_max), 'geopolygon': geom}
                sample_fd_ds = remotedc.find_datasets(product=source_prod,
                                                      group_by='solar_day',
                                                      **fd_query)

                if (len(sample_fd_ds)) > 0:
                    # decide pixel size for output data
                    pixel_x, pixel_y = get_pixel_size(sample_fd_ds[0],
                                                      source_band_list)
                    log.info(
                        'Output pixel size for product {}: x={}, y={}'.format(
                            source_prod, pixel_x, pixel_y))

                    # get target epsg from metadata
                    target_epsg = get_epsg(sample_fd_ds[0])
                    log.info('CRS for product {}: {}'.format(
                        source_prod, target_epsg))

                    query = {
                        'time': (acq_min, acq_max),
                        'geopolygon': geom,
                        'output_crs': target_epsg,
                        'resolution': (-pixel_y, pixel_x),
                        'measurements': source_band_list
                    }

                    if 's2' in source_prod:
                        data = remotedc.load(product=source_prod,
                                             group_by='solar_day',
                                             **query)
                    else:
                        data = remotedc.load(product=source_prod,
                                             align=(pixel_x / 2.0,
                                                    pixel_y / 2.0),
                                             group_by='solar_day',
                                             **query)

                    # remove cloud and nodata
                    data = remove_cloud_nodata(source_prod, data, mask_band)

                    if data.data_vars:
                        mask = geometry_mask([geom], data.geobox, invert=True)
                        data = data.where(mask)

                    if no_partial_scenes:
                        # calculate valid data percentage
                        data = only_return_whole_scene(data)

                return_data = {
                    source_prod: {
                        'data': data,
                        'mask_band': mask_band,
                        'find_list': sample_fd_ds
                    }
                }

    return return_data
Example #12
def get_data_opensource(prod_info, input_lon, input_lat, acq_min, acq_max,
                        window_size, no_partial_scenes):

    datacube_config = prod_info[0]
    source_prod = prod_info[1]
    source_band_list = prod_info[2]
    mask_band = prod_info[3]

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        if datacube_config != 'default':
            remotedc = Datacube(config=datacube_config)
        else:
            remotedc = Datacube()

        return_data = {}
        data = xr.Dataset()

        if source_prod != '':
            # find dataset to get metadata
            fd_query = {
                'time': (acq_min, acq_max),
                'x': (input_lon, input_lon + window_size / 100000),
                'y': (input_lat, input_lat + window_size / 100000),
            }
            sample_fd_ds = remotedc.find_datasets(product=source_prod,
                                                  group_by='solar_day',
                                                  **fd_query)

            if (len(sample_fd_ds)) > 0:
                # decide pixel size for output data
                pixel_x, pixel_y = get_pixel_size(sample_fd_ds[0],
                                                  source_band_list)

                log.info('Output pixel size for product {}: x={}, y={}'.format(
                    source_prod, pixel_x, pixel_y))

                # get target epsg from metadata
                target_epsg = get_epsg(sample_fd_ds[0])
                log.info('CRS for product {}: {}'.format(
                    source_prod, target_epsg))

                x1, y1, x2, y2 = setQueryExtent(target_epsg, input_lon,
                                                input_lat, window_size)

                query = {
                    'time': (acq_min, acq_max),
                    'x': (x1, x2),
                    'y': (y1, y2),
                    'crs': target_epsg,
                    'output_crs': target_epsg,
                    'resolution': (-pixel_y, pixel_x),
                    'measurements': source_band_list
                }

                if 's2' in source_prod:
                    data = remotedc.load(product=source_prod,
                                         group_by='solar_day',
                                         **query)
                else:
                    data = remotedc.load(product=source_prod,
                                         align=(pixel_x / 2.0, pixel_y / 2.0),
                                         group_by='solar_day',
                                         **query)
                # remove cloud and nodata
                data = remove_cloud_nodata(source_prod, data, mask_band)

                if no_partial_scenes:
                    # calculate valid data percentage
                    data = only_return_whole_scene(data)

            return_data = {
                source_prod: {
                    'data': data,
                    'mask_band': mask_band,
                    'find_list': sample_fd_ds
                }
            }

    return return_data
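From the positional unpacking at the top of both functions, `prod_info` is a 4-tuple of (datacube config name, product name, band list, mask band). A hedged usage sketch with invented values:

# Hypothetical prod_info; the product and band names must exist in the target index.
prod_info = ('default',
             's2a_ard_granule',
             ['nbart_red', 'nbart_green', 'nbart_blue'],
             'fmask')

result = get_data_opensource(prod_info,
                             input_lon=149.1, input_lat=-35.3,
                             acq_min='2021-01-01', acq_max='2021-01-31',
                             window_size=1000,  # ~0.01 degrees after the / 100000
                             no_partial_scenes=True)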
Example #13
def FindOutHowFullTheDamIs(shapes, crs):
    """
    This is where the code processing is actually done. This code takes in a polygon, and the
    shapefile's crs and performs a polygon drill into the wofs_albers product. The resulting
    xarray, which contains the water classified pixels for that polygon over every available
    timestep, is used to calculate the percentage of the water body that is wet at each time step.
    The outputs are written to a csv file named using the polygon ID.

    Inputs:
    shapes - polygon to be interrogated
    crs - crs of the shapefile

    Outputs:
    True or False - False if something unexpected happened, so the function can be run again.
    a csv file on disk is appended for every valid polygon.
    """
    dc = Datacube(app='Polygon drill')
    first_geometry = shapes['geometry']
    if 'ID' in shapes['properties'].keys():
        polyName = shapes['properties']['ID']
    else:
        polyName = shapes['properties']['FID']

    strPolyName = str(polyName).zfill(6)
    fpath = os.path.join(output_dir, f'{strPolyName[0:4]}/{strPolyName}.csv')

    # start_date = get_last_date(fpath)
    # NOTE: start_date is hard-coded below, so the `else` branch always runs
    start_date = '2021-05-01'
    if start_date is None:
        time_period = ('2021-03-01', current_time.strftime('%Y-%m-%d'))
        # print(f'There is no csv for {strPolyName}')
        # return 1
    else:
        time_period = ('2021-03-01', current_time.strftime('%Y-%m-%d'))

        geom = geometry.Geometry(first_geometry, crs=crs)

        ## Set up the query, and load in all of the WOFS layers
        query = {'geopolygon': geom, 'time': time_period}
        #         WOFL = dc.load(product='wofs_albers', **query)
        WOFL = dc.load(product='wofs_albers',
                       group_by='solar_day',
                       fuse_func=wofls_fuser,
                       **query)
        if len(WOFL.attrs) == 0:
            print(f'There is no new data for {strPolyName}')
            return 2
        # Make a mask based on the polygon (to remove extra data outside of the polygon)
        mask = rasterio.features.geometry_mask(
            [geom.to_crs(WOFL.geobox.crs)],
            out_shape=WOFL.geobox.shape,
            transform=WOFL.geobox.affine,
            all_touched=False,
            invert=True)
        wofl_masked = WOFL.water.where(mask)
        ## Work out how full the dam is at every time step
        DamCapacityPc = []
        DamCapacityCt = []
        LSA_WetPc = []
        DryObserved = []
        InvalidObservations = []
        for ix, times in enumerate(WOFL.time):

            # Grab the data for our timestep
            AllTheBitFlags = wofl_masked.isel(time=ix)

            # Find all the wet/dry pixels for that timestep
            LSA_Wet = AllTheBitFlags.where(
                AllTheBitFlags == 136).count().item()
            LSA_Dry = AllTheBitFlags.where(AllTheBitFlags == 8).count().item()
            WetPixels = AllTheBitFlags.where(
                AllTheBitFlags == 128).count().item() + LSA_Wet
            DryPixels = AllTheBitFlags.where(
                AllTheBitFlags == 0).count().item() + LSA_Dry

            # Apply the mask and count the number of observations
            MaskedAll = AllTheBitFlags.count().item()
            # Turn our counts into percents
            try:
                WaterPercent = WetPixels / MaskedAll * 100
                DryPercent = DryPixels / MaskedAll * 100
                UnknownPercent = (MaskedAll -
                                  (WetPixels + DryPixels)) / MaskedAll * 100
                LSA_WetPercent = LSA_Wet / MaskedAll * 100
            except ZeroDivisionError:
                WaterPercent = 0.0
                DryPercent = 0.0
                UnknownPercent = 100.0
                LSA_WetPercent = 0.0
            # Append the percentages to a list for each timestep
            DamCapacityPc.append(WaterPercent)
            InvalidObservations.append(UnknownPercent)
            DryObserved.append(DryPercent)
            DamCapacityCt.append(WetPixels)
            LSA_WetPc.append(LSA_WetPercent)

        ## Filter out timesteps with less than 90% valid observations
        try:
            ValidMask = [
                i for i, x in enumerate(InvalidObservations) if x < 10
            ]
            if len(ValidMask) > 0:
                ValidObs = WOFL.time[ValidMask].dropna(dim='time')
                ValidCapacityPc = [DamCapacityPc[i] for i in ValidMask]
                ValidCapacityCt = [DamCapacityCt[i] for i in ValidMask]
                ValidLSApc = [LSA_WetPc[i] for i in ValidMask]
                ValidObs = ValidObs.to_dataframe()
                if 'spatial_ref' in ValidObs.columns:
                    ValidObs = ValidObs.drop(columns=['spatial_ref'])

                DateList = ValidObs.to_csv(
                    None,
                    header=False,
                    index=False,
                    date_format="%Y-%m-%dT%H:%M:%SZ").split('\n')
                rows = zip(DateList, ValidCapacityPc, ValidCapacityCt,
                           ValidLSApc)

                if DateList:
                    strPolyName = str(polyName).zfill(6)
                    fpath = os.path.join(
                        output_dir, f'{strPolyName[0:4]}/{strPolyName}.csv')
                    os.makedirs(os.path.dirname(fpath), exist_ok=True)
                    with open(fpath, 'w') as f:
                        writer = csv.writer(f)
                        Headings = [
                            'Observation Date', 'Wet pixel percentage',
                            'Wet pixel count (n = {0})'.format(MaskedAll),
                            'LSA Wet Pixel Pct'
                        ]
                        writer.writerow(Headings)
                        for row in rows:
                            writer.writerow(row)
            else:
                print(f'{polyName} has no new good (90percent) valid data')
            return 1
        except Exception:
            print(f'This polygon isn\'t working...: {polyName}')
            return 3
Example #14
def post_processing(
    predicted: xr.Dataset, urls: Dict[str, Any]
) -> Tuple[xr.DataArray, xr.DataArray, xr.DataArray]:
    """
    Run the delayed post_processing functions, then create a lazy
    xr.Dataset to satisfy odc-stats
    """
    dc = Datacube(app="whatever")

    # grab predictions and proba for post process filtering
    predict = predicted.Predictions
    proba = predicted.Probabilities
    proba = proba.where(predict == 1, 100 - proba)  # crop proba only

    # ------image seg and filtering -------------
    # write out ndvi for image seg
    ndvi = assign_crs(predicted[["NDVI_S1", "NDVI_S2"]], crs=predicted.geobox.crs)

    # call function with dask delayed
    filtered = image_segmentation(ndvi, predict)

    # convert delayed object to dask array
    filtered = dask.array.from_delayed(
        filtered.squeeze(), shape=predict.shape, dtype=np.uint8
    )

    # convert dask array to xr.DataArray
    filtered = xr.DataArray(filtered, coords=predict.coords, attrs=predict.attrs)

    # --Post process masking----------------------------------------

    # merge back together for masking
    ds = xr.Dataset({"mask": predict, "prob": proba, "filtered": filtered})

    # mask out classification beyond AEZ boundary
    gdf = gpd.read_file(urls["aez"])
    with HiddenPrints():
        mask = xr_rasterize(gdf, predicted)
    mask = mask.chunk({})
    ds = ds.where(mask, 0)

    # mask with WDPA
    wdpa = rio_slurp_xarray(urls["wdpa"], gbox=predicted.geobox)
    wdpa = wdpa.chunk({})
    wdpa = wdpa.astype(bool)
    ds = ds.where(~wdpa, 0)

    # mask with WOFS
    wofs = dc.load(product='wofs_ls_summary_annual',
                   like=predicted.geobox,
                   dask_chunks={},
                   time='2019')
    wofs = wofs.frequency > 0.20  # threshold
    ds = ds.where(~wofs, 0)

    # mask steep slopes
    slope = rio_slurp_xarray(urls["slope"], gbox=predicted.geobox)
    slope = slope.chunk({})
    slope = slope > 50
    ds = ds.where(~slope, 0)

    # mask where the elevation is above 3600m
    elevation = dc.load(product="dem_srtm", like=predicted.geobox, dask_chunks={})
    elevation = elevation.elevation > 3600  # threshold
    ds = ds.where(~elevation.squeeze(), 0)

    return ds.squeeze()