def post_processing(predicted):
    """
    Filter prediction results with post-processing masks (Sahel AEZ).

    Simplified from production code to skip segmentation, probability and
    mode calculations. Pixels of ``predicted.Predictions`` are set to 0 when
    they fall outside the polygons in ``data/Sahel.geojson``, where the 2019
    ``wofs_ls_summary_annual`` frequency exceeds 0.2, where the slope layer
    exceeds 50, or where ``dem_srtm`` elevation exceeds 3600 m. The WDPA
    protected-area mask used for other regions is disabled here.

    :param predicted: The prediction results; must expose a ``Predictions``
        variable and a ``geobox`` onto which the ancillary layers are loaded.
    :return: The masked predictions cast to ``np.int8``.
    """
    dc = Datacube(app='whatever')

    predict = predicted.Predictions

    # --Post process masking-------------------------------------------------

    # mask out classification beyond AEZ boundary
    gdf = gpd.read_file('data/Sahel.geojson')
    with HiddenPrints():  # presumably silences xr_rasterize's printed output
        mask = xr_rasterize(gdf, predicted)
    predict = predict.where(mask, 0)

    # WDPA masking is disabled for this region:
    #     url_wdpa="s3://deafrica-input-datasets/protected_areas/WDPA_southern.tif"
    #     wdpa=rio_slurp_xarray(url_wdpa, gbox=predicted.geobox)
    #     wdpa = wdpa.astype(bool)
    #     predict = predict.where(~wdpa, 0)

    # mask with WOfS. squeeze() drops length-1 dimensions (such as the single
    # 2019 time step returned by dc.load) so that, like the elevation mask
    # below, the water mask does not broadcast an extra dimension onto the
    # predictions.
    wofs = dc.load(product='wofs_ls_summary_annual',
                   like=predicted.geobox,
                   time='2019')
    wofs = wofs.frequency.squeeze() > 0.2  # threshold
    predict = predict.where(~wofs, 0)

    # mask steep slopes
    url_slope = "https://deafrica-input-datasets.s3.af-south-1.amazonaws.com/srtm_dem/srtm_africa_slope.tif"
    slope = rio_slurp_xarray(url_slope, gbox=predicted.geobox)
    slope = slope > 50
    predict = predict.where(~slope, 0)

    # mask where the elevation is above 3600m
    elevation = dc.load(product='dem_srtm', like=predicted.geobox)
    elevation = elevation.elevation > 3600  # threshold
    predict = predict.where(~elevation.squeeze(), 0)

    # set dtype
    return predict.astype(np.int8)
# Example #2
def post_processing(predicted: xr.Dataset) -> xr.DataArray:
    """
    Filter prediction results with post-processing masks (Western AEZ).

    Pixels of ``predicted.Predictions`` are set to 0 when they fall outside
    the polygons in ``data/Western.geojson``, inside the WDPA_western
    protected-area raster, where the ``ga_ls8c_wofs_2_summary`` frequency
    exceeds 0.2, where the slope layer exceeds 35, or where ``dem_srtm``
    elevation exceeds 3600 m.

    :param predicted: The prediction results; must expose a ``Predictions``
        variable and a ``geobox`` onto which the ancillary layers are loaded.
    :return: The masked predictions cast to ``np.int8``.
    """
    dc = Datacube(app='whatever')

    # grab predictions for post process filtering
    predict = predicted.Predictions

    # mask out classification beyond AEZ boundary
    gdf = gpd.read_file('data/Western.geojson')
    with HiddenPrints():  # presumably silences xr_rasterize's printed output
        mask = xr_rasterize(gdf, predicted)
    predict = predict.where(mask, 0)

    # mask with WDPA
    url_wdpa = "s3://deafrica-input-datasets/protected_areas/WDPA_western.tif"
    wdpa = rio_slurp_xarray(url_wdpa, gbox=predicted.geobox)
    wdpa = wdpa.astype(bool)
    predict = predict.where(~wdpa, 0)

    # mask with WOfS. squeeze() drops length-1 dimensions (such as the single
    # time step returned by dc.load) so that, like the elevation mask below,
    # the water mask does not broadcast an extra dimension onto the
    # predictions.
    wofs = dc.load(product='ga_ls8c_wofs_2_summary', like=predicted.geobox)
    wofs = wofs.frequency.squeeze() > 0.2  # threshold
    predict = predict.where(~wofs, 0)

    # mask steep slopes
    url_slope = "https://deafrica-data.s3.amazonaws.com/ancillary/dem-derivatives/cog_slope_africa.tif"
    slope = rio_slurp_xarray(url_slope, gbox=predicted.geobox)
    slope = slope > 35
    predict = predict.where(~slope, 0)

    # mask where the elevation is above 3600m
    elevation = dc.load(product='dem_srtm', like=predicted.geobox)
    elevation = elevation.elevation > 3600  # threshold
    predict = predict.where(~elevation.squeeze(), 0)

    # set dtype
    return predict.astype(np.int8)
# Example #3
def post_processing(predicted):
    """
    Mask prediction results with the post-processing filters (Eastern AEZ).

    Starting from ``predicted.Predictions``, every pixel is set to 0 where it
    lies outside the Eastern AEZ polygons, inside the WDPA_eastern
    protected-area raster, where the ``ga_ls8c_wofs_2_summary`` frequency
    exceeds 0.2, where the slope layer exceeds 35, or where ``dem_srtm``
    elevation exceeds 3600 m. Ancillary layers are loaded onto
    ``predicted.geobox`` and turned into dask arrays (``chunk({})`` /
    ``dask_chunks={}``).

    Image segmentation and the crop-probability layer used by the production
    pipeline are not computed here; only the hard predictions are masked.

    :param predicted: The prediction results; must expose ``Predictions``
        and ``geobox``.
    :return: An ``xr.Dataset`` holding a single ``mask`` variable, with all
        length-1 dimensions squeezed out.
    """
    dc = Datacube(app='whatever')

    ds = xr.Dataset({"mask": predicted.Predictions})
    geobox = predicted.geobox

    # Outside the Eastern AEZ boundary -> 0
    aez_polygons = gpd.read_file(
        'https://github.com/digitalearthafrica/crop-mask/blob/main/testing/eastern_cropmask/data/Eastern.geojson?raw=true'
    )
    with HiddenPrints():  # presumably silences xr_rasterize's printed output
        in_aez = xr_rasterize(aez_polygons, predicted)
    ds = ds.where(in_aez.chunk({}), 0)

    # Protected areas (WDPA) -> 0
    protected = rio_slurp_xarray(
        "s3://deafrica-input-datasets/protected_areas/WDPA_eastern.tif",
        gbox=geobox,
    ).chunk({}).astype(bool)
    ds = ds.where(~protected, 0)

    # Frequently wet pixels (WOfS frequency above 0.2) -> 0
    water = dc.load(product="ga_ls8c_wofs_2_summary",
                    like=geobox,
                    dask_chunks={})
    ds = ds.where(~(water.frequency > 0.2), 0)

    # Steep slopes (above 35) -> 0
    steep = rio_slurp_xarray(
        'https://deafrica-data.s3.amazonaws.com/ancillary/dem-derivatives/cog_slope_africa.tif',
        gbox=geobox,
    ).chunk({}) > 35
    ds = ds.where(~steep, 0)

    # High ground (elevation above 3600 m) -> 0; squeeze drops the
    # length-1 dimensions of the loaded DEM before masking.
    dem = dc.load(product="dem_srtm",
                  like=geobox,
                  dask_chunks={})
    ds = ds.where(~(dem.elevation > 3600).squeeze(), 0)

    return ds.squeeze()
    def handle_draw(self, action, geo_json):
        nonlocal polygon_number

        # Execute behaviour based on what the user draws
        if geo_json['geometry']['type'] == 'Polygon':

            info.clear_output(wait=True)  # wait=True reduces flicker effect
            
            # Save geojson polygon to io temporary file to be rasterized later
            jsonData = json.dumps(geo_json)
            binaryData = jsonData.encode()
            io = BytesIO(binaryData)
            io.seek(0)
            
            # Read the polygon as a geopandas dataframe
            gdf = gpd.read_file(io)
            gdf.crs = "EPSG:4326"

            # Convert the drawn geometry to pixel coordinates
            xr_poly = xr_rasterize(gdf, ds.NDVI.isel(time=0), crs='EPSG:6933')

            # Construct a mask to only select pixels within the drawn polygon
            masked_ds = ds.NDVI.where(xr_poly)
            
            masked_ds_mean = masked_ds.mean(dim=['x', 'y'], skipna=True)
            colour = colour_list[polygon_number % len(colour_list)]

            # Add a layer to the map to make the most recently drawn polygon
            # the same colour as the line on the plot
            studyarea_map.add_layer(
                GeoJSON(
                    data=geo_json,
                    style={
                        'color': colour,
                        'opacity': 1,
                        'weight': 4.5,
                        'fillOpacity': 0.0
                    }
                )
            )

            # add new data to the plot
            xr.plot.plot(
                masked_ds_mean,
                marker='*',
                color=colour,
                ax=ax
            )

            # reset titles back to custom
            ax.set_title("Average NDVI from Sentinel-2")
            ax.set_xlabel("Date")
            ax.set_ylabel("NDVI")

            # refresh display
            fig_display.clear_output(wait=True)  # wait=True reduces flicker effect
            with fig_display:
                display(fig)
                
            with info:
                print("Plot status: polygon sucessfully added to plot.")

            # Iterate the polygon number before drawing another polygon
            polygon_number = polygon_number + 1

        else:
            info.clear_output(wait=True)
            with info:
                print("Plot status: this drawing tool is not currently "
                      "supported. Please use the polygon tool.")
# Example #5
def post_processing(
    predicted: xr.Dataset, urls: Dict[str, Any]
) -> xr.Dataset:
    """
    Segment, filter and mask prediction results, returning a lazy xr.Dataset.

    Image segmentation runs through dask delayed and is wrapped back into a
    dask-backed DataArray. Every ancillary layer is chunked or loaded with
    dask, so the returned Dataset stays lazy (to satisfy odc-stats).

    Pixels of all three output variables are set to 0 when any of these holds:

    - outside the AEZ polygons;
    - inside the WDPA raster;
    - the 2019 ``wofs_ls_summary_annual`` frequency exceeds 0.2;
    - the slope layer exceeds 50;
    - ``dem_srtm`` elevation exceeds 3600 m.

    :param predicted: Dataset with ``Predictions``, ``Probabilities``,
        ``NDVI_S1`` and ``NDVI_S2`` variables and a ``geobox``.
    :param urls: Locations of the ancillary inputs, read under the keys
        ``"aez"`` (vector file for ``gpd.read_file``), ``"wdpa"`` and
        ``"slope"`` (rasters for ``rio_slurp_xarray``).
    :return: Dataset with all length-1 dimensions squeezed out, containing:

        - ``mask``: the masked predictions;
        - ``prob``: ``Probabilities`` where the prediction is 1, else
          ``100 - Probabilities``. This presumably gives the crop
          probability in percent; confirm against the model.
        - ``filtered``: the segmentation-filtered predictions.
    """
    dc = Datacube(app="whatever")

    # grab predictions and proba for post process filtering
    predict = predicted.Predictions
    proba = predicted.Probabilities
    proba = proba.where(predict == 1, 100 - proba)  # crop proba only

    # ------image seg and filtering -------------
    # write out ndvi for image seg
    ndvi = assign_crs(predicted[["NDVI_S1", "NDVI_S2"]], crs=predicted.geobox.crs)

    # call function with dask delayed
    filtered = image_segmentation(ndvi, predict)

    # convert delayed object to dask array
    filtered = dask.array.from_delayed(
        filtered.squeeze(), shape=predict.shape, dtype=np.uint8
    )

    # convert dask array to xr.Datarray
    filtered = xr.DataArray(filtered, coords=predict.coords, attrs=predict.attrs)

    # --Post process masking----------------------------------------

    # merge back together for masking
    ds = xr.Dataset({"mask": predict, "prob": proba, "filtered": filtered})

    # mask out classification beyond AEZ boundary
    gdf = gpd.read_file(urls["aez"])
    with HiddenPrints():  # presumably silences xr_rasterize's printed output
        mask = xr_rasterize(gdf, predicted)
    mask = mask.chunk({})
    ds = ds.where(mask, 0)

    # mask with WDPA
    wdpa = rio_slurp_xarray(urls["wdpa"], gbox=predicted.geobox)
    wdpa = wdpa.chunk({})
    wdpa = wdpa.astype(bool)
    ds = ds.where(~wdpa, 0)

    # mask with WOfS. squeeze() drops length-1 dimensions (such as the single
    # 2019 time step returned by dc.load), mirroring the elevation mask, so
    # the water mask broadcasts onto ds instead of having to align a
    # time index with it.
    wofs = dc.load(product='wofs_ls_summary_annual',
                   like=predicted.geobox,
                   dask_chunks={},
                   time='2019')
    wofs = wofs.frequency.squeeze() > 0.20  # threshold
    ds = ds.where(~wofs, 0)

    # mask steep slopes
    slope = rio_slurp_xarray(urls["slope"], gbox=predicted.geobox)
    slope = slope.chunk({})
    slope = slope > 50
    ds = ds.where(~slope, 0)

    # mask where the elevation is above 3600m
    elevation = dc.load(product="dem_srtm", like=predicted.geobox, dask_chunks={})
    elevation = elevation.elevation > 3600  # threshold
    ds = ds.where(~elevation.squeeze(), 0)

    return ds.squeeze()