Example #1
def complete_gm_mads(era_base_ds: xr.Dataset, geobox: GeoBox,
                     era: str) -> xr.Dataset:
    """
    merge the geomedian and rainfall chirps data together
    :param era_base_ds:
    :param geobox:
    :param era:
    :return:
    """
    # TODO: this is half year data, require integration tests
    # TODO: load this data once use dask publish (?)
    gm_mads = assign_crs(calculate_indices(era_base_ds))

    rainfall = assign_crs(xr.open_rasterio(
        FeaturePathConfig.rainfall_path[era]),
                          crs="epsg:4326")

    rainfall = chirp_clip(gm_mads, rainfall)

    rainfall = (xr_reproject(rainfall, geobox,
                             "bilinear").drop(["band",
                                               "spatial_ref"]).squeeze())
    gm_mads["rain"] = rainfall

    return gm_mads.rename(
        {var_name: str(var_name) + era.upper()
         for var_name in gm_mads.data_vars})
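A minimal usage sketch, assuming the six-month base dataset has already been loaded (for example via merge_tifs_into_ds further below) and that FeaturePathConfig.rainfall_path is keyed by the same era string; the era value shown is hypothetical:

# hypothetical driver code, not part of the original function
era = "_s1"  # assumed key into FeaturePathConfig.rainfall_path
features = complete_gm_mads(era_base_ds, era_base_ds.geobox, era)
print(list(features.data_vars))  # band names now carry the upper-cased era suffix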
Example #2
    def fun(ds, era):
        # six-month geomedians
        gm_mads = xr_geomedian(ds)
        gm_mads = calculate_indices(
            gm_mads,
            index=["NDVI", "LAI", "MNDWI"],
            drop=False,
            normalise=False,
            collection="s2",
        )

        # rainfall climatology
        if era == "_S1":
            chirps = assign_crs(
                xr.open_rasterio(
                    "/g/data/CHIRPS/cumulative_alltime/CHPclim_jan_jun_cumulative_rainfall.nc"
                ),
                crs="epsg:4326",
            )
        if era == "_S2":
            chirps = assign_crs(
                xr.open_rasterio(
                    "/g/data/CHIRPS/cumulative_alltime/CHPclim_jul_dec_cumulative_rainfall.nc"
                ),
                crs="epsg:4326",
            )

        chirps = xr_reproject(chirps, ds.geobox, "bilinear")
        gm_mads["rain"] = chirps

        for band in gm_mads.data_vars:
            gm_mads = gm_mads.rename({band: band + era})

        return gm_mads
Example #3
def fun(ds, era):
    # geomedian and tmads
    gm_mads = xr_geomedian_tmad(ds)
    gm_mads = calculate_indices(gm_mads,
                                index=['NDVI', 'LAI', 'MNDWI'],
                                drop=False,
                                normalise=False,
                                collection='s2')

    gm_mads['sdev'] = -np.log(gm_mads['sdev'])
    gm_mads['bcdev'] = -np.log(gm_mads['bcdev'])
    gm_mads['edev'] = -np.log(gm_mads['edev'])

    # rainfall climatology
    if era == '_S1':
        chirps = assign_crs(
            xr.open_rasterio(
                '/g/data/CHIRPS/cumulative_alltime/CHPclim_jan_jun_cumulative_rainfall.nc'),
            crs='epsg:4326')
    if era == '_S2':
        chirps = assign_crs(
            xr.open_rasterio(
                '/g/data/CHIRPS/cumulative_alltime/CHPclim_jul_dec_cumulative_rainfall.nc'),
            crs='epsg:4326')

    chirps = xr_reproject(chirps, ds.geobox, "bilinear")
    gm_mads['rain'] = chirps

    for band in gm_mads.data_vars:
        gm_mads = gm_mads.rename({band: band + era})

    return gm_mads
Example #4
def annual_gm_mads_evi_training(ds):
    dc = datacube.Datacube(app='training')
    
    # grab gm+tmads
    gm_mads=dc.load(product='ga_s2_gm',time='2019',like=ds.geobox,
                   measurements=['red', 'blue', 'green', 'nir',
                                 'swir_1', 'swir_2', 'red_edge_1',
                                 'red_edge_2', 'red_edge_3', 'SMAD',
                                 'BCMAD','EMAD'])
    
    gm_mads['SMAD'] = -np.log(gm_mads['SMAD'])
    gm_mads['BCMAD'] = -np.log(gm_mads['BCMAD'])
    gm_mads['EMAD'] = -np.log(gm_mads['EMAD']/10000)
    
    #calculate band indices on gm
    gm_mads = calculate_indices(gm_mads,
                               index=['EVI','LAI','MNDWI'],
                               drop=False,
                               collection='s2')
    
    #normalise spectral GM bands 0-1
    for band in gm_mads.data_vars:
        if band not in ['SMAD', 'BCMAD','EMAD', 'EVI', 'LAI', 'MNDWI']:
            gm_mads[band] = gm_mads[band] / 10000
    
    #calculate EVI on annual timeseries
    evi = calculate_indices(ds,index=['EVI'], drop=True, normalise=True, collection='s2')
    
    # EVI stats 
    gm_mads['evi_std'] = evi.EVI.std(dim='time')
    gm_mads['evi_10'] = evi.EVI.quantile(0.1, dim='time')
    gm_mads['evi_25'] = evi.EVI.quantile(0.25, dim='time')
    gm_mads['evi_75'] = evi.EVI.quantile(0.75, dim='time')
    gm_mads['evi_90'] = evi.EVI.quantile(0.9, dim='time')
    gm_mads['evi_range'] = gm_mads['evi_90'] - gm_mads['evi_10']
    
    #rainfall climatology
    chirps_S1 = xr_reproject(assign_crs(xr.open_rasterio('/g/data/CHIRPS/cumulative_alltime/CHPclim_jan_jun_cumulative_rainfall.nc'),
                                        crs='epsg:4326'), ds.geobox,"bilinear")
    
    chirps_S2 = xr_reproject(assign_crs(xr.open_rasterio('/g/data/CHIRPS/cumulative_alltime/CHPclim_jul_dec_cumulative_rainfall.nc'), 
                                        crs='epsg:4326'), ds.geobox,"bilinear")
        
    gm_mads['rain_S1'] = chirps_S1
    gm_mads['rain_S2'] = chirps_S2
    
    #slope
    url_slope = "https://deafrica-data.s3.amazonaws.com/ancillary/dem-derivatives/cog_slope_africa.tif"
    slope = rio_slurp_xarray(url_slope, gbox=ds.geobox)
    slope = slope.to_dataset(name='slope')#.chunk({'x':2000,'y':2000})
    
    result = xr.merge([gm_mads,slope],compat='override')

    return result.squeeze()
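A hedged sketch of how this training-feature function might be driven; the product name, spatial query, and chunking below are assumptions, not taken from the snippet:

import datacube
dc = datacube.Datacube(app="feature-extraction")
ds = dc.load(
    product="s2_l2a",  # hypothetical Sentinel-2 surface reflectance product
    time="2019",
    measurements=["red", "green", "blue", "nir", "swir_1", "swir_2"],
    x=(31.4, 31.6), y=(-25.1, -24.9),
    output_crs="epsg:6933", resolution=(-10, 10),
    dask_chunks={"x": 1000, "y": 1000},
)
training_features = annual_gm_mads_evi_training(ds)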
Example #5
def merge_tifs_into_ds(
    root_fld: str,
    tifs: List[str],
    rename_dict: Optional[Dict] = None,
    tifs_min_num=8,
) -> xr.Dataset:
    """
    Will be replaced with dc.load(gm_6month) once they've been produced.
    
    use os.walk to get the all files under a folder, it just merge the half year tifs.
    We need combine two half-year tifs ds and add (calculated indices, rainfall, and slope)
    @param tifs: tifs with the bands
    @param root_fld: the parent folder for the sub_fld
    @param tifs_min_num: geo-median tifs is 16 a tile idx
    @param rename_dict: we can put the rename dictionary here
    @return:
    """
    assert len(tifs) > tifs_min_num
    cache = []
    for tif in tifs:
        if tif.endswith(".tif"):
            band_name = re.search(r"_([A-Za-z0-9]+).tif", tif).groups()[0]
            if band_name in ["rgba", "COUNT"]:
                continue

            band_array = assign_crs(xr.open_rasterio(osp.join(
                root_fld, tif)).squeeze().to_dataset(name=band_name),
                                    crs='epsg:6933')
            cache.append(band_array)
    # clean up output
    output = xr.merge(cache).squeeze()
    output = output.drop(["band"])

    return output.rename(rename_dict) if rename_dict else output
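A small driver sketch for the function above; the folder layout and rename keys are hypothetical:

import os
root_fld = "/data/gm_tiles/x012_y034/2019_S1"  # hypothetical tile folder
tifs = [f for f in os.listdir(root_fld) if f.endswith(".tif")]
tile_ds = merge_tifs_into_ds(root_fld, tifs, rename_dict={"B02": "blue"})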
Example #6
def chirp_clip(ds: xr.Dataset, chirps: xr.DataArray) -> xr.DataArray:
    """
     fill na with mean on chirps data
    :param ds: geomedian collected with certain geobox
    :param chirps: rainfall data
    :return: chirps data without na
    """
    # TODO: test with dummy ds and chirps
    # Clip CHIRPS to ~ S2 tile boundaries so we can handle NaNs local to S2 tile
    xmin, xmax = ds.x.values[0], ds.x.values[-1]
    ymin, ymax = ds.y.values[0], ds.y.values[-1]
    inProj = Proj("epsg:6933")
    outProj = Proj("epsg:4326")
    xmin, ymin = transform(inProj, outProj, xmin, ymin)
    xmax, ymax = transform(inProj, outProj, xmax, ymax)

    # create lat/lon indexing slices - buffer S2 bbox by 0.05deg
    # Todo: xmin < 0 and xmax < 0,  x_slice = [], unit tests
    if (xmin < 0) & (xmax < 0):
        x_slice = list(np.arange(xmin + 0.05, xmax - 0.05, -0.05))
    else:
        x_slice = list(np.arange(xmax - 0.05, xmin + 0.05, 0.05))

    if (ymin < 0) & (ymax < 0):
        y_slice = list(np.arange(ymin + 0.05, ymax - 0.05, -0.05))
    else:
        y_slice = list(np.arange(ymin - 0.05, ymax + 0.05, 0.05))

    # index global chirps using buffered s2 tile bbox
    chirps = assign_crs(chirps.sel(x=y_slice, y=x_slice, method="nearest"))

    # fill any NaNs in CHIRPS with local (s2-tile bbox) mean
    return chirps.fillna(chirps.mean())
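A usage sketch, assuming gm_mads is a geomedian dataset in EPSG:6933 and the rainfall raster path is a placeholder:

chirps = assign_crs(xr.open_rasterio("/path/to/CHPclim_cumulative_rainfall.nc"),  # hypothetical path
                    crs="epsg:4326")
rainfall_local = chirp_clip(gm_mads, chirps)  # NaN-free rainfall clipped around the tile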
Example #7
def xr_geomedian_tmad_new(ds, **kw):
    """
    Same as other one but uses reshape_yxbt instead of
    reshape_for_geomedian
    """

    import hdstats

    def gm_tmad(arr, **kw):
        """
        arr: a high dimensional numpy array where the last dimension will be reduced. 
    
        returns: a numpy array with one less dimension than input.
        """
        gm = hdstats.nangeomedian_pcm(arr, **kw)
        nt = kw.pop('num_threads', None)
        emad = hdstats.emad_pcm(arr, gm, num_threads=nt)[:, :, np.newaxis]
        smad = hdstats.smad_pcm(arr, gm, num_threads=nt)[:, :, np.newaxis]
        bcmad = hdstats.bcmad_pcm(arr, gm, num_threads=nt)[:, :, np.newaxis]
        return np.concatenate([gm, emad, smad, bcmad], axis=-1)

    def norm_input(ds):
        if isinstance(ds, xr.Dataset):
            xx = reshape_yxbt(ds, yx_chunks=500)
            return ds, xx, xx.data
        # anything other than an xarray.Dataset would fail the unpacking below,
        # so raise a clear error instead of silently returning None
        raise ValueError("xr_geomedian_tmad_new expects an xarray.Dataset")

    kw.setdefault('nocheck', False)
    kw.setdefault('num_threads', 1)
    kw.setdefault('eps', 1e-6)

    ds, xx, xx_data = norm_input(ds)
    is_dask = dask.is_dask_collection(xx_data)

    if is_dask:
        data = da.map_blocks(lambda x: gm_tmad(x, **kw),
                             xx_data,
                             name=randomize('geomedian'),
                             dtype=xx_data.dtype,
                             chunks=xx_data.chunks[:-2] +
                             (xx_data.chunks[-2][0] + 3, ),
                             drop_axis=3)
    else:
        # non-dask input: compute the geomedian and MADs directly in memory,
        # otherwise `data` would be undefined below
        data = gm_tmad(xx_data, **kw)

    dims = xx.dims[:-1]
    cc = {k: xx.coords[k] for k in dims}
    cc[dims[-1]] = np.hstack(
        [xx.coords[dims[-1]].values, ['edev', 'sdev', 'bcdev']])
    xx_out = xr.DataArray(data, dims=dims, coords=cc)

    if ds is None:
        xx_out.attrs.update(xx.attrs)
        return xx_out

    ds_out = xx_out.to_dataset(dim='band')
    for b in ds.data_vars.keys():
        src, dst = ds[b], ds_out[b]
        dst.attrs.update(src.attrs)

    return assign_crs(ds_out, crs=ds.geobox.crs)
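A sketch of applying the function above to a lazily loaded time series; the chunking values are illustrative only:

ds = ds.chunk({"time": -1, "x": 1000, "y": 1000})  # hypothetical dask-backed Sentinel-2 time series
gm_tmads = xr_geomedian_tmad_new(ds, num_threads=1, eps=1e-6).compute()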
Example #8
def add_chirps(
    urls: Dict[Any, Any],
    ds: xr.Dataset,
    era: str,
    training: bool = True,
    dask_chunks: Dict[Any, Any] = {
        "x": "auto",
        "y": "auto"
    },
) -> Optional[xr.Dataset]:
    # load rainfall climatology
    if era == "_S1":
        chirps = rio_slurp_xarray(urls["chirps"][0])
    if era == "_S2":
        chirps = rio_slurp_xarray(urls["chirps"][1])

    if chirps.size >= 2:
        if training:
            chirps = xr_reproject(chirps, ds.geobox, "bilinear")
            ds["rain"] = chirps
        else:
            # Clip CHIRPS to ~ S2 tile boundaries so we can handle NaNs local to S2 tile
            xmin, xmax = ds.x.values[0], ds.x.values[-1]
            ymin, ymax = ds.y.values[0], ds.y.values[-1]
            inProj = Proj("epsg:6933")
            outProj = Proj("epsg:4326")
            xmin, ymin = transform(inProj, outProj, xmin, ymin)
            xmax, ymax = transform(inProj, outProj, xmax, ymax)

            # create lat/lon indexing slices - buffer S2 bbox by 0.05deg
            if (xmin < 0) & (xmax < 0):
                x_slice = list(np.arange(xmin + 0.05, xmax - 0.05, -0.05))
            else:
                x_slice = list(np.arange(xmax - 0.05, xmin + 0.05, 0.05))

            y_slice = list(np.arange(ymin - 0.05, ymax + 0.1, 0.05))

            # index global chirps using buffered s2 tile bbox
            chirps = assign_crs(
                chirps.sel(longitude=y_slice,
                           latitude=x_slice,
                           method="nearest"))
            # fill any NaNs in CHIRPS with local (s2-tile bbox) mean
            chirps = chirps.fillna(chirps.mean())
            chirps = xr_reproject(chirps, ds.geobox, "bilinear")
            chirps = chirps.chunk(dask_chunks)
            ds["rain"] = chirps

        # rename bands to include era
        for band in ds.data_vars:
            ds = ds.rename({band: band + era})

        return ds

    return None
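A hedged example of calling this version of add_chirps in prediction mode; the URL dictionary mirrors the indexing used above, with the rainfall paths taken from the related example further below:

urls = {"chirps": [
    "s3://deafrica-input-datasets/rainfall/CHPclim_jan_jun_cumulative_rainfall.tif",
    "s3://deafrica-input-datasets/rainfall/CHPclim_jul_dec_cumulative_rainfall.tif",
]}
ds_with_rain = add_chirps(urls, ds, era="_S1", training=False)  # ds: hypothetical Sentinel-2 dataset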
Example #9
def test_assign_crs(odc_style_xr_dataset):
    xx = odc_style_xr_dataset
    assert xx.geobox is not None

    xx_nocrs = remove_crs(xx)
    assert xx_nocrs.geobox is None
    yy = assign_crs(xx_nocrs, 'epsg:4326')
    assert xx_nocrs.geobox is None  # verify source is not modified in place
    assert yy.geobox.crs == 'epsg:4326'

    yy = assign_crs(xx_nocrs.B10, 'epsg:4326')
    assert yy.geobox.crs == 'epsg:4326'

    xx_xr_style_crs = xx_nocrs.copy()
    xx_xr_style_crs.attrs.update(crs='epsg:3857')
    yy = assign_crs(xx_xr_style_crs)
    assert yy.geobox.crs == 'epsg:3857'

    with pytest.raises(ValueError):
        assign_crs(xx_nocrs)
Example #10
    def fun(ds, era):
        # normalise SR and edev bands
        for band in ds.data_vars:
            if band not in ["sdev", "bcdev"]:
                ds[band] = ds[band] / 10000

        gm_mads = calculate_indices(
            ds,
            index=["NDVI", "LAI", "MNDWI"],
            drop=False,
            normalise=False,
            collection="s2",
        )

        gm_mads["sdev"] = -np.log(gm_mads["sdev"])
        gm_mads["bcdev"] = -np.log(gm_mads["bcdev"])
        gm_mads["edev"] = -np.log(gm_mads["edev"])

        # rainfall climatology
        if era == "_S1":
            chirps = assign_crs(
                xr.open_rasterio(
                    "/g/data/CHIRPS/cumulative_alltime/CHPclim_jan_jun_cumulative_rainfall.nc"
                ),
                crs="epsg:4326",
            )
        if era == "_S2":
            chirps = assign_crs(
                xr.open_rasterio(
                    "/g/data/CHIRPS/cumulative_alltime/CHPclim_jul_dec_cumulative_rainfall.nc"
                ),
                crs="epsg:4326",
            )

        chirps = xr_reproject(chirps, ds.geobox, "bilinear")
        gm_mads["rain"] = chirps

        for band in gm_mads.data_vars:
            gm_mads = gm_mads.rename({band: band + era})

        return gm_mads
Example #11
    def fun(ds, era):
        # geomedian and tmads
        # gm_mads = xr_geomedian_tmad(ds)
        gm_mads = xr_geomedian_tmad_new(ds).compute()
        gm_mads = calculate_indices(
            gm_mads,
            index=["NDVI", "LAI", "MNDWI"],
            drop=False,
            normalise=False,
            collection="s2",
        )

        gm_mads["sdev"] = -np.log(gm_mads["sdev"])
        gm_mads["bcdev"] = -np.log(gm_mads["bcdev"])
        gm_mads["edev"] = -np.log(gm_mads["edev"])
        gm_mads = gm_mads.chunk({"x": 2000, "y": 2000})

        # rainfall climatology
        if era == "_S1":
            chirps = assign_crs(
                xr.open_rasterio(
                    "/g/data/CHIRPS/cumulative_alltime/CHPclim_jan_jun_cumulative_rainfall.nc"
                ),
                crs="epsg:4326",
            )
        if era == "_S2":
            chirps = assign_crs(
                xr.open_rasterio(
                    "/g/data/CHIRPS/cumulative_alltime/CHPclim_jul_dec_cumulative_rainfall.nc"
                ),
                crs="epsg:4326",
            )

        chirps = xr_reproject(chirps, ds.geobox, "bilinear")
        chirps = chirps.chunk({"x": 2000, "y": 2000})
        gm_mads["rain"] = chirps

        for band in gm_mads.data_vars:
            gm_mads = gm_mads.rename({band: band + era})

        return gm_mads
Example #12
def add_chirps(ds, era, training=True, dask_chunks={'x': 'auto', 'y': 'auto'}):

    # load rainfall climatology
    if era == "_S1":
        chirps = rio_slurp_xarray(
            "s3://deafrica-input-datasets/rainfall/CHPclim_jan_jun_cumulative_rainfall.tif"
        )
    if era == "_S2":
        chirps = rio_slurp_xarray(
            "s3://deafrica-input-datasets/rainfall/CHPclim_jul_dec_cumulative_rainfall.tif"
        )

    if training:
        chirps = xr_reproject(chirps, ds.geobox, "bilinear")
        ds["rain"] = chirps

    else:
        # Clip CHIRPS to ~ S2 tile boundaries so we can handle NaNs local to S2 tile
        xmin, xmax = ds.x.values[0], ds.x.values[-1]
        ymin, ymax = ds.y.values[0], ds.y.values[-1]
        inProj = Proj("epsg:6933")
        outProj = Proj("epsg:4326")
        xmin, ymin = transform(inProj, outProj, xmin, ymin)
        xmax, ymax = transform(inProj, outProj, xmax, ymax)

        # create lat/lon indexing slices - buffer S2 bbox by 0.05deg
        if (xmin < 0) & (xmax < 0):
            x_slice = list(np.arange(xmin + 0.05, xmax - 0.05, -0.05))
        else:
            x_slice = list(np.arange(xmax - 0.05, xmin + 0.05, 0.05))

        if (ymin < 0) & (ymax < 0):
            y_slice = list(np.arange(ymin + 0.05, ymax - 0.05, -0.05))
        else:
            y_slice = list(np.arange(ymin - 0.05, ymax + 0.05, 0.05))

        # index global chirps using buffered s2 tile bbox
        chirps = assign_crs(
            chirps.sel(longitude=y_slice, latitude=x_slice, method="nearest"))

        # fill any NaNs in CHIRPS with local (s2-tile bbox) mean
        chirps = chirps.fillna(chirps.mean())
        chirps = xr_reproject(chirps, ds.geobox, "bilinear")
        chirps = chirps.chunk(dask_chunks)
        ds["rain"] = chirps

    #rename bands to include era
    for band in ds.data_vars:
        ds = ds.rename({band: band + era})

    return ds
Example #13
def merge_tifs_into_ds(
    root_fld: str,
    tifs: List[str],
    rename_dict: Optional[Dict] = None,
    tifs_min_num=8,
) -> xr.Dataset:
    """
    use os.walk to get the all files under a folder, it just merge the half year tifs.
    We need combine two half-year tifs ds and add (calculated indices, rainfall, and slope)
    :param tifs: tifs with the bands
    :param root_fld: the parent folder for the sub_fld
    :param tifs_min_num: geo-median tifs is 16 a tile idx
    :param rename_dict: we can put the rename dictionary here
    :return:
    """
    # TODO: create dummy datasets to test merge tis
    assert len(tifs) > tifs_min_num
    cache = []
    for tif in tifs:
        if tif.endswith(".tif"):
            band_name = re.search(r"_([A-Za-z0-9]+).tif", tif).groups()[0]
            if band_name in ["rgba", "COUNT"]:
                continue

            band_array = assign_crs(
                xr.open_rasterio(osp.join(
                    root_fld, tif)).squeeze().to_dataset(name=band_name),
                crs="epsg:6933",
            )
            cache.append(band_array)
    # clean up output
    output = xr.merge(cache).squeeze()
    output.attrs["crs"] = "epsg:{}".format(output["spatial_ref"].values)
    output.attrs["tile-task-str"] = "/".join(root_fld.split("/")[-3:])
    output = output.drop(["spatial_ref", "band"])
    return output.rename(rename_dict) if rename_dict else output
Example #14
    def fun(ds, era):
        # normalise SR and edev bands
        for band in ds.data_vars:
            if band not in ["sdev", "bcdev"]:
                ds[band] = ds[band] / 10000

        gm_mads = calculate_indices(
            ds,
            index=["NDVI", "LAI", "MNDWI"],
            drop=False,
            normalise=False,
            collection="s2",
        )

        gm_mads["sdev"] = -np.log(gm_mads["sdev"])
        gm_mads["bcdev"] = -np.log(gm_mads["bcdev"])
        gm_mads["edev"] = -np.log(gm_mads["edev"])

        # rainfall climatology
        if era == "_S1":
            chirps = assign_crs(
                xr.open_rasterio(
                    "/g/data/CHIRPS/cumulative_alltime/CHPclim_jan_jun_cumulative_rainfall.nc"
                ),
                crs="epsg:4326",
            )
        if era == "_S2":
            chirps = assign_crs(
                xr.open_rasterio(
                    "/g/data/CHIRPS/cumulative_alltime/CHPclim_jul_dec_cumulative_rainfall.nc"
                ),
                crs="epsg:4326",
            )

        # Clip CHIRPS to ~ S2 tile boundaries so we can handle NaNs local to S2 tile
        xmin, xmax = ds.x.values[0], ds.x.values[-1]
        ymin, ymax = ds.y.values[0], ds.y.values[-1]
        inProj = Proj("epsg:6933")
        outProj = Proj("epsg:4326")
        xmin, ymin = transform(inProj, outProj, xmin, ymin)
        xmax, ymax = transform(inProj, outProj, xmax, ymax)

        # create lat/lon indexing slices - buffer S2 bbox by 0.05deg
        if (xmin < 0) & (xmax < 0):
            x_slice = list(np.arange(xmin + 0.05, xmax - 0.05, -0.05))
        else:
            x_slice = list(np.arange(xmax - 0.05, xmin + 0.05, 0.05))

        if (ymin < 0) & (ymax < 0):
            y_slice = list(np.arange(ymin + 0.05, ymax - 0.05, -0.05))
        else:
            y_slice = list(np.arange(ymin - 0.05, ymax + 0.05, 0.05))

        # index global chirps using buffered s2 tile bbox
        chirps = assign_crs(chirps.sel(x=y_slice, y=x_slice, method="nearest"))

        # fill any NaNs in CHIRPS with local (s2-tile bbox) mean
        chirps = chirps.fillna(chirps.mean())
        chirps = xr_reproject(chirps, ds.geobox, "bilinear")
        gm_mads["rain"] = chirps

        for band in gm_mads.data_vars:
            gm_mads = gm_mads.rename({band: band + era})

        return gm_mads
Example #15
        # Resample data temporally into time steps, and compute geomedians
        ds_geomedian = ds.groupby(time_steps_var).apply(
            lambda ds_subset: xr_geomedian(
                ds_subset,
                num_threads=1,  # disable internal threading, dask will run several workers concurrently
                eps=0.2 * (1 / 10_000),  # 1/5 pixel value resolution
                nocheck=True,  # disable some checks inside the geomedian library that use too much RAM
            ))

        print('\nGenerating geomedian composites and plotting '
              'filmstrips... (click the Dashboard link above for status)')
        ds_geomedian = ds_geomedian.compute()

        # Reset CRS that is lost during geomedian compositing
        ds_geomedian = assign_crs(ds_geomedian, crs=ds.geobox.crs)

        ############
        # Plotting #
        ############

        # Convert to array and extract vmin/vmax
        output_array = ds_geomedian[['red', 'green', 'blue']].to_array()
        percentiles = output_array.quantile(q=(0.02, 0.98)).values

        # Create the plot with one subplot more than timesteps in the
        # dataset. Figure width is set based on the number of subplots
        # and aspect ratio
        n_obs = output_array.sizes['timestep']
        ratio = output_array.sizes['x'] / output_array.sizes['y']
        fig, axes = plt.subplots(1,
Example #16
def features(ds, era):
    #normalise SR and edev bands
    for band in ds.data_vars:
        if band not in ['sdev', 'bcdev']:
            ds[band] = ds[band] / 10000

    gm_mads = calculate_indices(ds,
                                index=['NDVI', 'LAI', 'MNDWI'],
                                drop=False,
                                normalise=False,
                                collection='s2')

    gm_mads['sdev'] = -np.log(gm_mads['sdev'])
    gm_mads['bcdev'] = -np.log(gm_mads['bcdev'])
    gm_mads['edev'] = -np.log(gm_mads['edev'])

    #rainfall climatology
    if era == '_S1':
        chirps = assign_crs(xr.open_rasterio(
            '/g/data/CHIRPS/cumulative_alltime/CHPclim_jan_jun_cumulative_rainfall.nc'
        ),
                            crs='epsg:4326')

    if era == '_S2':
        chirps = assign_crs(xr.open_rasterio(
            '/g/data/CHIRPS/cumulative_alltime/CHPclim_jul_dec_cumulative_rainfall.nc'
        ),
                            crs='epsg:4326')

    #Clip CHIRPS to ~ S2 tile boundaries so we can handle NaNs local to S2 tile
    xmin, xmax = ds.x.values[0], ds.x.values[-1]
    ymin, ymax = ds.y.values[0], ds.y.values[-1]
    inProj = Proj('epsg:6933')
    outProj = Proj('epsg:4326')
    xmin, ymin = transform(inProj, outProj, xmin, ymin)
    xmax, ymax = transform(inProj, outProj, xmax, ymax)

    #create lat/lon indexing slices - buffer S2 bbox by 0.05deg
    if (xmin < 0) & (xmax < 0):
        x_slice = list(np.arange(xmin + 0.05, xmax - 0.05, -0.05))
    else:
        x_slice = list(np.arange(xmax - 0.05, xmin + 0.05, 0.05))

    if (ymin < 0) & (ymax < 0):
        y_slice = list(np.arange(ymin + 0.05, ymax - 0.05, -0.05))
    else:
        y_slice = list(np.arange(ymin - 0.05, ymax + 0.05, 0.05))

    #index global chirps using buffered s2 tile bbox
    chirps = assign_crs(chirps.sel(x=y_slice, y=x_slice, method='nearest'))

    #fill any NaNs in CHIRPS with local (s2-tile bbox) mean
    chirps = chirps.fillna(chirps.mean())

    #reproject to match satellite data
    chirps = xr_reproject(chirps, ds.geobox, "bilinear")
    gm_mads['rain'] = chirps

    for band in gm_mads.data_vars:
        gm_mads = gm_mads.rename({band: band + era})

    return gm_mads
Example #17
def download_cci_lc(year: str,
                    s3_dst: str,
                    workdir: str,
                    overwrite: bool = False):
    log = setup_logging()
    assets = {}

    cci_lc_version = get_version_from_year(year)
    name = f"{PRODUCT_NAME}_{year}_{cci_lc_version}"

    out_cog = URL(s3_dst) / year / f"{name}.tif"
    out_stac = URL(s3_dst) / year / f"{name}.stac-item.json"

    if s3_head_object(str(out_stac)) is not None and not overwrite:
        log.info(f"{out_stac} exists, skipping")
        return

    workdir = Path(workdir)
    if not workdir.exists():
        workdir.mkdir(parents=True, exist_ok=True)

    # Create a temporary directory to work with
    tmpdir = mkdtemp(prefix=str(f"{workdir}/"))
    log.info(f"Working on {year} in the path {tmpdir}")

    if s3_head_object(str(out_cog)) is None or overwrite:
        log.info(f"Downloading {year}")
        try:
            local_file = Path(tmpdir) / f"{name}.zip"
            if not local_file.exists():
                # Download the file
                c = cdsapi.Client()

                # We could also retrieve the object metadata from the CDS.
                # e.g. f = c.retrieve("series",{params}) | f.location = URL to download
                c.retrieve(
                    "satellite-land-cover",
                    {
                        "format": "zip",
                        "variable": "all",
                        "version": cci_lc_version,
                        "year": str(year),
                    },
                    local_file,
                )

                log.info(f"Downloaded file to {local_file}")
            else:
                log.info(
                    f"File {local_file} exists, continuing without downloading"
                )

            # Unzip the file
            log.info(f"Unzipping {local_file}")
            unzipped = None
            with zipfile.ZipFile(local_file, "r") as zip_ref:
                unzipped = local_file.parent / zip_ref.namelist()[0]
                zip_ref.extractall(tmpdir)

            # Process data
            ds = xr.open_dataset(unzipped)
            # Subset to Africa
            ulx, uly, lrx, lry = AFRICA_BBOX
            # Note: lats are upside down!
            ds_small = ds.sel(lat=slice(uly, lry), lon=slice(ulx, lrx))
            ds_small = assign_crs(ds_small, crs="epsg:4326")

            # Create cog (in memory - :mem: returns bytes object)
            mem_dst = write_cog(
                ds_small.lccs_class,
                ":mem:",
                nodata=0,
                overview_resampling="nearest",
            )

            # Write to s3
            s3_dump(mem_dst, str(out_cog), ACL="bucket-owner-full-control")
            log.info(f"File written to {out_cog}")

        except Exception:
            log.exception(f"Failed to process {name}")
            exit(1)
    else:
        log.info(f"{out_cog} exists, skipping")

    assets["classification"] = pystac.Asset(href=str(out_cog),
                                            roles=["data"],
                                            media_type=pystac.MediaType.COG)

    # Write STAC document
    source_doc = (
        "https://cds.climate.copernicus.eu/cdsapp#!/dataset/satellite-land-cover"
    )
    item = create_stac_item(
        str(out_cog),
        id=str(
            odc_uuid("Copernicus Land Cover", cci_lc_version,
                     [source_doc, name])),
        assets=assets,
        with_proj=True,
        properties={
            "odc:product": PRODUCT_NAME,
            "start_datetime": f"{year}-01-01T00:00:00Z",
            "end_datetime": f"{year}-12-31T23:59:59Z",
        },
    )
    item.add_links([
        pystac.Link(
            target=source_doc,
            title="Source",
            rel=pystac.RelType.DERIVED_FROM,
            media_type="text/html",
        )
    ])
    s3_dump(
        json.dumps(item.to_dict(), indent=2),
        str(out_stac),
        ContentType="application/json",
        ACL="bucket-owner-full-control",
    )
    log.info(f"STAC written to {out_stac}")
Example #18
def gm_mads_evi_rainfall(ds):
    """
    6 monthly and annual 
    gm + mads
    evi stats (10, 50, 90 percentile, range, std)
    rainfall actual stats (min, mean, max, range, std) from monthly data
    rainfall clim stats (min, mean, max, range, std) from monthly data
    """
    dc = datacube.Datacube(app='training')
    ds = ds / 10000
    ds = ds.rename({'nir_1':'nir_wide', 'nir_2':'nir'})
    ds1 = ds.sel(time=slice('2019-01', '2019-06'))
    ds2 = ds.sel(time=slice('2019-07', '2019-12')) 
    
    chirps = []
    chpclim = []
    for m in range(1,13):
        chirps.append(xr_reproject(assign_crs(xr.open_rasterio(f'/g/data/CHIRPS/monthly_2019/chirps-v2.0.2019.{m:02d}.tif').squeeze().expand_dims({'time':[m]}), crs='epsg:4326'), 
                                   ds.geobox, "bilinear"))
        chpclim.append(rio_slurp_xarray(f'https://deafrica-data-dev.s3.amazonaws.com/product-dev/deafrica_chpclim_50n_50s_{m:02d}.tif', gbox=ds.geobox, 
                                        resampling='bilinear').expand_dims({'time':[m]}))
    
    chirps = xr.concat(chirps, dim='time')
    chpclim = xr.concat(chpclim, dim='time')
   
    def fun(ds, chirps, chpclim, era):
        ds = calculate_indices(ds,
                               index=['EVI'],
                               drop=False,
                               normalise=False,
                               collection='s2')        
        #geomedian and tmads
        gm_mads = xr_geomedian_tmad(ds)
        gm_mads = calculate_indices(gm_mads,
                               index=['EVI','NDVI','LAI','MNDWI'],
                               drop=False,
                               normalise=False,
                               collection='s2')
        
        gm_mads['sdev'] = -np.log(gm_mads['sdev'])
        gm_mads['bcdev'] = -np.log(gm_mads['bcdev'])
        gm_mads['edev'] = -np.log(gm_mads['edev'])
        
        # EVI stats 
        gm_mads['evi_10'] = ds.EVI.quantile(0.1, dim='time')
        gm_mads['evi_50'] = ds.EVI.quantile(0.5, dim='time')
        gm_mads['evi_90'] = ds.EVI.quantile(0.9, dim='time')
        gm_mads['evi_range'] = gm_mads['evi_90'] - gm_mads['evi_10']
        gm_mads['evi_std'] = ds.EVI.std(dim='time')

        # rainfall actual
        gm_mads['rain_min'] = chirps.min(dim='time')
        gm_mads['rain_mean'] = chirps.mean(dim='time')
        gm_mads['rain_max'] = chirps.max(dim='time')
        gm_mads['rain_range'] = gm_mads['rain_max'] - gm_mads['rain_min']
        gm_mads['rain_std'] = chirps.std(dim='time')
         
        # rainfall climatology
        gm_mads['rainclim_min'] = chpclim.min(dim='time')
        gm_mads['rainclim_mean'] = chpclim.mean(dim='time')
        gm_mads['rainclim_max'] = chpclim.max(dim='time')
        gm_mads['rainclim_range'] = gm_mads['rainclim_max'] - gm_mads['rainclim_min']
        gm_mads['rainclim_std'] = chpclim.std(dim='time')
                
        for band in gm_mads.data_vars:
            gm_mads = gm_mads.rename({band:band+era})
        
        return gm_mads
    
    epoch0 = fun(ds, chirps, chpclim, era='_S0')
    time, month = slice('2019-01', '2019-06'), slice(1, 6)
    epoch1 = fun(ds.sel(time=time), chirps.sel(time=month), chpclim.sel(time=month), era='_S1')
    time, month = slice('2019-07', '2019-12'), slice(7, 12)
    epoch2 = fun(ds.sel(time=time), chirps.sel(time=month), chpclim.sel(time=month), era='_S2')
    
    #slope
    url_slope = "https://deafrica-data.s3.amazonaws.com/ancillary/dem-derivatives/cog_slope_africa.tif"
    slope = rio_slurp_xarray(url_slope, gbox=ds.geobox)
    slope = slope.to_dataset(name='slope')
    
    result = xr.merge([epoch0,
                       epoch1,
                       epoch2,
                       slope],compat='override')
    
    return result.squeeze()
Example #19
def gm_mads_evi_rainfall(ds):
    """
    6 monthly and annual
    gm + mads
    evi stats (10, 50, 90 percentile, range, std)
    rainfall actual stats (min, mean, max, range, std) from monthly data
    rainfall clim stats (min, mean, max, range, std) from monthly data
    """
    dc = datacube.Datacube(app="training")
    ds = ds / 10000
    ds = ds.rename({"nir_1": "nir_wide", "nir_2": "nir"})
    ds1 = ds.sel(time=slice("2019-01", "2019-06"))
    ds2 = ds.sel(time=slice("2019-07", "2019-12"))

    chirps = []
    chpclim = []
    for m in range(1, 13):
        chirps.append(
            xr_reproject(
                assign_crs(
                    xr.open_rasterio(
                        f"/g/data/CHIRPS/monthly_2019/chirps-v2.0.2019.{m:02d}.tif"
                    )
                    .squeeze()
                    .expand_dims({"time": [m]}),
                    crs="epsg:4326",
                ),
                ds.geobox,
                "bilinear",
            )
        )
        chpclim.append(
            rio_slurp_xarray(
                f"https://deafrica-data-dev.s3.amazonaws.com/product-dev/deafrica_chpclim_50n_50s_{m:02d}.tif",
                gbox=ds.geobox,
                resapling="bilinear",
            ).expand_dims({"time": [m]})
        )

    chirps = xr.concat(chirps, dim="time")
    chpclim = xr.concat(chpclim, dim="time")

    def fun(ds, chirps, chpclim, era):
        ds = calculate_indices(
            ds, index=["EVI"], drop=False, normalise=False, collection="s2"
        )
        # geomedian and tmads
        gm_mads = xr_geomedian_tmad(ds)
        gm_mads = calculate_indices(
            gm_mads,
            index=["EVI", "NDVI", "LAI", "MNDWI"],
            drop=False,
            normalise=False,
            collection="s2",
        )

        gm_mads["sdev"] = -np.log(gm_mads["sdev"])
        gm_mads["bcdev"] = -np.log(gm_mads["bcdev"])
        gm_mads["edev"] = -np.log(gm_mads["edev"])

        # EVI stats
        gm_mads["evi_10"] = ds.EVI.quantile(0.1, dim="time")
        gm_mads["evi_50"] = ds.EVI.quantile(0.5, dim="time")
        gm_mads["evi_90"] = ds.EVI.quantile(0.9, dim="time")
        gm_mads["evi_range"] = gm_mads["evi_90"] - gm_mads["evi_10"]
        gm_mads["evi_std"] = ds.EVI.std(dim="time")

        # rainfall actual
        gm_mads["rain_min"] = chirps.min(dim="time")
        gm_mads["rain_mean"] = chirps.mean(dim="time")
        gm_mads["rain_max"] = chirps.max(dim="time")
        gm_mads["rain_range"] = gm_mads["rain_max"] - gm_mads["rain_min"]
        gm_mads["rain_std"] = chirps.std(dim="time")

        # rainfall climatology
        gm_mads["rainclim_min"] = chpclim.min(dim="time")
        gm_mads["rainclim_mean"] = chpclim.mean(dim="time")
        gm_mads["rainclim_max"] = chpclim.max(dim="time")
        gm_mads["rainclim_range"] = gm_mads["rainclim_max"] - gm_mads["rainclim_min"]
        gm_mads["rainclim_std"] = chpclim.std(dim="time")

        for band in gm_mads.data_vars:
            gm_mads = gm_mads.rename({band: band + era})

        return gm_mads

    epoch0 = fun(ds, chirps, chpclim, era="_S0")
    time, month = slice("2019-01", "2019-06"), slice(1, 6)
    epoch1 = fun(
        ds.sel(time=time), chirps.sel(time=month), chpclim.sel(time=month), era="_S1"
    )
    time, month = slice("2019-07", "2019-12"), slice(7, 12)
    epoch2 = fun(
        ds.sel(time=time), chirps.sel(time=month), chpclim.sel(time=month), era="_S2"
    )

    # slope
    url_slope = "https://deafrica-data.s3.amazonaws.com/ancillary/dem-derivatives/cog_slope_africa.tif"
    slope = rio_slurp_xarray(url_slope, gbox=ds.geobox)
    slope = slope.to_dataset(name="slope")

    result = xr.merge([epoch0, epoch1, epoch2, slope], compat="override")

    return result.squeeze()
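Note that the function above expects a full-year time series containing both NIR bands (it renames "nir_1"/"nir_2" internally) plus the red, blue, green, and SWIR bands used by the EVI and MNDWI indices; a minimal call sketch with a hypothetical, already-loaded dataset:

features = gm_mads_evi_rainfall(ds)  # ds: hypothetical 2019 Sentinel-2 time series with nir_1/nir_2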
Example #20
def annual_gm_mads_evi_training(ds):
    dc = datacube.Datacube(app="training")

    # grab gm+tmads
    gm_mads = dc.load(
        product="ga_s2_gm",
        time="2019",
        like=ds.geobox,
        measurements=[
            "red",
            "blue",
            "green",
            "nir",
            "swir_1",
            "swir_2",
            "red_edge_1",
            "red_edge_2",
            "red_edge_3",
            "SMAD",
            "BCMAD",
            "EMAD",
        ],
    )

    gm_mads["SMAD"] = -np.log(gm_mads["SMAD"])
    gm_mads["BCMAD"] = -np.log(gm_mads["BCMAD"])
    gm_mads["EMAD"] = -np.log(gm_mads["EMAD"] / 10000)

    # calculate band indices on gm
    gm_mads = calculate_indices(
        gm_mads, index=["EVI", "LAI", "MNDWI"], drop=False, collection="s2"
    )

    # normalise spectral GM bands 0-1
    for band in gm_mads.data_vars:
        if band not in ["SMAD", "BCMAD", "EMAD", "EVI", "LAI", "MNDWI"]:
            gm_mads[band] = gm_mads[band] / 10000

    # calculate EVI on annual timeseries
    evi = calculate_indices(
        ds, index=["EVI"], drop=True, normalise=True, collection="s2"
    )

    # EVI stats
    gm_mads["evi_std"] = evi.EVI.std(dim="time")
    gm_mads["evi_10"] = evi.EVI.quantile(0.1, dim="time")
    gm_mads["evi_25"] = evi.EVI.quantile(0.25, dim="time")
    gm_mads["evi_75"] = evi.EVI.quantile(0.75, dim="time")
    gm_mads["evi_90"] = evi.EVI.quantile(0.9, dim="time")
    gm_mads["evi_range"] = gm_mads["evi_90"] - gm_mads["evi_10"]

    # rainfall climatology
    chirps_S1 = xr_reproject(
        assign_crs(
            xr.open_rasterio(
                "/g/data/CHIRPS/cumulative_alltime/CHPclim_jan_jun_cumulative_rainfall.nc"
            ),
            crs="epsg:4326",
        ),
        ds.geobox,
        "bilinear",
    )

    chirps_S2 = xr_reproject(
        assign_crs(
            xr.open_rasterio(
                "/g/data/CHIRPS/cumulative_alltime/CHPclim_jul_dec_cumulative_rainfall.nc"
            ),
            crs="epsg:4326",
        ),
        ds.geobox,
        "bilinear",
    )

    gm_mads["rain_S1"] = chirps_S1
    gm_mads["rain_S2"] = chirps_S2

    # slope
    url_slope = "https://deafrica-data.s3.amazonaws.com/ancillary/dem-derivatives/cog_slope_africa.tif"
    slope = rio_slurp_xarray(url_slope, gbox=ds.geobox)
    slope = slope.to_dataset(name="slope")  # .chunk({'x':2000,'y':2000})

    result = xr.merge([gm_mads, slope], compat="override")

    return result.squeeze()
Example #21
def main(year, crs='EPSG:6933', res=10):
    crs_code = crs.split(':')[1]
    dea_filename = f"deafrica_gmw_{year}_{crs_code}_{res}m.tif"
    if os.path.exists(dea_filename):
        print(f"{dea_filename} already exists")
        return

    # download extents if needed
    gmw_shp = f'GMW_001_GlobalMangroveWatch_{year}/01_Data/GMW_{year}_v2.shp'
    if not os.path.exists(gmw_shp):
        gmw_shp = download_and_unzip_gmw(year=year)

    # extract extents over Africa
    gmw = gpd.read_file(gmw_shp)

    deafrica_extent = gpd.read_file('https://github.com/digitalearthafrica/deafrica-extent/raw/master/africa-extent.json')
    deafrica_extent = deafrica_extent.to_crs(gmw.crs)

    # find everything within deafrica_extent
    gmw_africa = gpd.sjoin(gmw, deafrica_extent,op='intersects')
    # include additional features in the square bounding box
    bound = box(*gmw_africa.total_bounds).buffer(0.001)
    deafrica_extent_square = gpd.GeoDataFrame(gpd.GeoSeries(bound), columns=['geometry'], crs=gmw_africa.crs)
    gmw_africa = gpd.sjoin(gmw, deafrica_extent_square,op='intersects')

    ## output raster setting
    gmw_africa = gmw_africa.to_crs(crs)
    bounds = gmw_africa.total_bounds
    bounds = np.hstack([np.floor(bounds[:2]/10)*10, np.ceil(bounds[2:]/10)*10])

    #transform = Affine(res, 0.0, bounds[0], 0.0, -1*res, bounds[3])
    out_shape = int((bounds[3]-bounds[1])/res), int((bounds[2]-bounds[0])/res)

    #rasterize in tiles
    tile_size = 50000
    ny = np.ceil(out_shape[0]/tile_size).astype(int)
    nx = np.ceil(out_shape[1]/tile_size).astype(int)

    for iy in np.arange(ny):
        for ix in np.arange(nx):
            y0 = bounds[3]-iy*tile_size*res
            x0 = bounds[0]+ix*tile_size*res
            y1 = np.max([bounds[1], bounds[3]-(iy+1)*tile_size*res])
            x1 = np.min([bounds[2], bounds[0]+(ix+1)*tile_size*res])
            
            transform = Affine(res, 0.0, x0, 0.0, -1*res, y0)  # pixel ul
            sub_shape = np.abs((y1-y0)/res).astype(int), np.abs((x1-x0)/res).astype(int)
        
            arr = rasterize(shapes=gmw_africa.geometry,
                            out_shape=sub_shape,
                            transform=transform,
                            fill=0,
                            all_touched=True,
                            default_value=1,
                            dtype=np.uint8)
            
            xarr = xr.DataArray(arr,
                                # pixel center
                                coords={'y':y0-np.arange(sub_shape[0])*res-res/2,
                                        'x':x0+np.arange(sub_shape[1])*res+res/2},
                                dims=('y','x'),
                                name='gmw')
        
            xarr = assign_crs(xarr, str(crs))
            write_cog(xarr, 
                      f'gmw_africa_{year}_{ix}_{iy}.tif',
                      overwrite=True)
            
    cmd = f"gdalbuildvrt gmw_africa_{year}.vrt gmw_africa_{year}_*_*.tif"
    r = subprocess.call(cmd, shell=True)
            
    cmd = f"rio cogeo create --overview-level 0 gmw_africa_{year}.vrt deafrica_gmw_{year}.tif"
    r = subprocess.call(cmd, shell=True)
Example #22
def temporal_statistics(da, stats):
    """
    Obtain generic temporal statistics using the hdstats temporal library:
    https://github.com/daleroberts/hdstats/blob/master/hdstats/ts.pyx
    
    last modified June 2020
    
    Parameters
    ----------
    da :  xarray.DataArray
        DataArray should contain a 3D time series.
    stats : list
        list of temporal statistics to calculate.
        Options include:
            'discordance' = 
            'f_std' = std of discrete fourier transform coefficients, returns
                      three layers: f_std_n1, f_std_n2, f_std_n3
            'f_mean' = mean of discrete fourier transform coefficients, returns
                       three layers: f_mean_n1, f_mean_n2, f_mean_n3
            'f_median' = median of discrete fourier transform coefficients, returns
                         three layers: f_median_n1, f_median_n2, f_median_n3
            'mean_change' = mean of discrete difference along time dimension
            'median_change' = median of discrete difference along time dimension
            'abs_change' = mean of absolute discrete difference along time dimension
            'complexity' = 
            'central_diff' = 
            'num_peaks' : The number of peaks in the timeseries, defined with a local
                          window of size 10.  NOTE: This statistic is very slow
    Outputs
    -------
        xarray.Dataset containing variables for the selected 
        temporal statistics
        
    """

    # if dask arrays then map the blocks
    if dask.is_dask_collection(da):
        if version.parse(xr.__version__) < version.parse("0.16.0"):
            raise TypeError(
                "Dask arrays are only supported by this function when using " +
                "xarray v0.16 or later; run da.compute() before passing the DataArray.")

        # create a template that matches the final datasets dims & vars
        arr = da.isel(time=0).drop("time")

        # deal with the case where fourier is first in the list
        if stats[0] in ("f_std", "f_median", "f_mean"):
            template = xr.zeros_like(arr).to_dataset(name=stats[0] + "_n1")
            template[stats[0] + "_n2"] = xr.zeros_like(arr)
            template[stats[0] + "_n3"] = xr.zeros_like(arr)

            for stat in stats[1:]:
                if stat in ("f_std", "f_median", "f_mean"):
                    template[stat + "_n1"] = xr.zeros_like(arr)
                    template[stat + "_n2"] = xr.zeros_like(arr)
                    template[stat + "_n3"] = xr.zeros_like(arr)
                else:
                    template[stat] = xr.zeros_like(arr)
        else:
            template = xr.zeros_like(arr).to_dataset(name=stats[0])

            for stat in stats:
                if stat in ("f_std", "f_median", "f_mean"):
                    template[stat + "_n1"] = xr.zeros_like(arr)
                    template[stat + "_n2"] = xr.zeros_like(arr)
                    template[stat + "_n3"] = xr.zeros_like(arr)
                else:
                    template[stat] = xr.zeros_like(arr)
        try:
            template = template.drop('spatial_ref')
        except:
            pass

        # ensure the time chunk is set to -1
        da_all_time = da.chunk({"time": -1})

        # apply function across chunks
        lazy_ds = da_all_time.map_blocks(temporal_statistics,
                                         kwargs={"stats": stats},
                                         template=template)

        try:
            crs = da.geobox.crs
            lazy_ds = assign_crs(lazy_ds, str(crs))
        except:
            pass

        return lazy_ds

    # If stats supplied is not a list, convert to list.
    stats = stats if isinstance(stats, list) else [stats]

    # grab all the attributes of the xarray
    x, y, time, attrs = da.x, da.y, da.time, da.attrs

    # deal with any all-NaN pixels by filling with 0's
    mask = da.isnull().all("time")
    da = da.where(~mask, other=0)

    # complete timeseries
    print("Completing...")
    da = fast_completion(da)

    # ensure dim order is correct for functions
    da = da.transpose("y", "x", "time").values

    stats_dict = {
        "discordance": lambda da: hdstats.discordance(da, n=10),
        "f_std": lambda da: hdstats.fourier_std(da, n=3, step=5),
        "f_mean": lambda da: hdstats.fourier_mean(da, n=3, step=5),
        "f_median": lambda da: hdstats.fourier_median(da, n=3, step=5),
        "mean_change": lambda da: hdstats.mean_change(da),
        "median_change": lambda da: hdstats.median_change(da),
        "abs_change": lambda da: hdstats.mean_abs_change(da),
        "complexity": lambda da: hdstats.complexity(da),
        "central_diff": lambda da: hdstats.mean_central_diff(da),
        "num_peaks": lambda da: hdstats.number_peaks(da, 10),
    }

    print("   Statistics:")
    # if one of the fourier functions is first (or only)
    # stat in the list then we need to deal with this
    if stats[0] in ("f_std", "f_median", "f_mean"):
        print("      " + stats[0])
        stat_func = stats_dict.get(str(stats[0]))
        zz = stat_func(da)
        n1 = zz[:, :, 0]
        n2 = zz[:, :, 1]
        n3 = zz[:, :, 2]

        # initialise dataset with first statistic
        ds = xr.DataArray(n1,
                          attrs=attrs,
                          coords={
                              "x": x,
                              "y": y
                          },
                          dims=["y", "x"]).to_dataset(name=stats[0] + "_n1")

        # add other datasets
        for i, j in zip([n2, n3], ["n2", "n3"]):
            ds[stats[0] + "_" + j] = xr.DataArray(i,
                                                  attrs=attrs,
                                                  coords={
                                                      "x": x,
                                                      "y": y
                                                  },
                                                  dims=["y", "x"])
    else:
        # simpler if first function isn't fourier transform
        first_func = stats_dict.get(str(stats[0]))
        print("      " + stats[0])
        ds = first_func(da)

        # convert back to xarray dataset
        ds = xr.DataArray(ds,
                          attrs=attrs,
                          coords={
                              "x": x,
                              "y": y
                          },
                          dims=["y", "x"]).to_dataset(name=stats[0])

    # loop through the other functions
    for stat in stats[1:]:
        print("      " + stat)

        # handle the fourier transform examples
        if stat in ("f_std", "f_median", "f_mean"):
            stat_func = stats_dict.get(str(stat))
            zz = stat_func(da)
            n1 = zz[:, :, 0]
            n2 = zz[:, :, 1]
            n3 = zz[:, :, 2]

            for i, j in zip([n1, n2, n3], ["n1", "n2", "n3"]):
                ds[stat + "_" + j] = xr.DataArray(i,
                                                  attrs=attrs,
                                                  coords={
                                                      "x": x,
                                                      "y": y
                                                  },
                                                  dims=["y", "x"])

        else:
            # Select a stats function from the dictionary
            # and add to the dataset
            stat_func = stats_dict.get(str(stat))
            ds[stat] = xr.DataArray(stat_func(da),
                                    attrs=attrs,
                                    coords={
                                        "x": x,
                                        "y": y
                                    },
                                    dims=["y", "x"])

    # try to add back the geobox
    try:
        crs = da.geobox.crs
        ds = assign_crs(ds, str(crs))
    except:
        pass

    return ds
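A usage sketch; the NDVI DataArray is hypothetical and the statistic names come from the docstring above:

ndvi = ds.NDVI  # hypothetical 3D (time, y, x) DataArray
ts_stats = temporal_statistics(ndvi, stats=["f_mean", "abs_change", "num_peaks"])
print(list(ts_stats.data_vars))  # f_mean_n1..n3, abs_change, num_peaks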
Example #23
def xr_rasterize(gdf,
                 da,
                 attribute_col=False,
                 crs=None,
                 transform=None,
                 name=None,
                 x_dim='x',
                 y_dim='y',
                 export_tiff=None,
                 **rasterio_kwargs):
    """
    Rasterizes a geopandas.GeoDataFrame into an xarray.DataArray.
    
    Parameters
    ----------
    gdf : geopandas.GeoDataFrame
        A geopandas.GeoDataFrame object containing the vector/shapefile
        data you want to rasterise.
    da : xarray.DataArray or xarray.Dataset
        The shape, coordinates, dimensions, and transform of this object 
        are used to build the rasterized shapefile. It effectively 
        provides a template. The attributes of this object are also 
        appended to the output xarray.DataArray.
    attribute_col : string, optional
        Name of the attribute column in the geodataframe that the pixels 
        in the raster will contain.  If set to False, output will be a 
        boolean array of 1's and 0's.
    crs : str, optional
        CRS metadata to add to the output xarray. e.g. 'epsg:3577'.
        The function will attempt to get this info from the input
        xarray object first.
    transform : affine.Affine object, optional
        An affine.Affine object (e.g. `from affine import Affine;
        Affine(30.0, 0.0, 548040.0, 0.0, -30.0, 6886890.0)`) giving the
        affine transformation used to convert raster coordinates 
        (e.g. [0, 0]) to geographic coordinates. If none is provided, 
        the function will attempt to obtain an affine transformation 
        from the xarray object (e.g. either at `da.transform` or
        `da.geobox.transform`).
    x_dim : str, optional
        An optional string allowing you to override the xarray dimension 
        used for x coordinates. Defaults to 'x'. Useful, for example,
        if the x and y dims are instead called 'lat' and 'lon'.
    y_dim : str, optional
        An optional string allowing you to override the xarray dimension 
        used for y coordinates. Defaults to 'y'. Useful, for example,
        if the x and y dims are instead called 'lat' and 'lon'.
    export_tiff: str, optional
        If a filepath is provided (e.g. 'output/output.tif'), a GeoTIFF file
        will be exported. A named array is required for this operation; if one
        is not supplied by the user, a default name, 'data', is used.
    **rasterio_kwargs : 
        A set of keyword arguments to rasterio.features.rasterize
        Can include: 'all_touched', 'merge_alg', 'dtype'.
    
    Returns
    -------
    xarr : xarray.DataArray
    
    """

    # Check for a crs object
    try:
        crs = da.geobox.crs
    except:
        try:
            crs = da.crs
        except:
            if crs is None:
                raise Exception(
                    "Please add a `crs` attribute to the "
                    "xarray.DataArray, or provide a CRS using the "
                    "function's `crs` parameter (e.g. crs='EPSG:3577')")

    # Check if transform is provided as a xarray.DataArray method.
    # If not, require supplied Affine
    if transform is None:
        try:
            # First, try to take transform info from geobox
            transform = da.geobox.transform
        # If no geobox
        except:
            try:
                # Try getting transform from 'transform' attribute
                transform = da.transform
            except:
                # If neither of those options work, raise an exception telling the
                # user to provide a transform
                raise Exception(
                    "Please provide an Affine transform object using the "
                    "`transform` parameter (e.g. `from affine import "
                    "Affine; Affine(30.0, 0.0, 548040.0, 0.0, -30.0, "
                    "6886890.0)`")

    # Grab the 2D dims (not time)
    try:
        dims = da.geobox.dims
    except:
        dims = y_dim, x_dim

    # Coords
    xy_coords = [da[dims[0]], da[dims[1]]]

    # Shape
    try:
        y, x = da.geobox.shape
    except:
        y, x = len(xy_coords[0]), len(xy_coords[1])

    # Reproject shapefile to match CRS of raster
    print(f'Rasterizing to match xarray.DataArray dimensions ({y}, {x})')

    try:
        gdf_reproj = gdf.to_crs(crs=crs)
    except:
        # Sometimes the crs can be a datacube utils CRS object
        # so convert to string before reprojecting
        gdf_reproj = gdf.to_crs(crs={'init': str(crs)})

    # If an attribute column is specified, rasterise using vector
    # attribute values. Otherwise, rasterise into a boolean array
    if attribute_col:
        # Use the geometry and attributes from `gdf` to create an iterable
        shapes = zip(gdf_reproj.geometry, gdf_reproj[attribute_col])
    else:
        # Use geometry directly (will produce a boolean numpy array)
        shapes = gdf_reproj.geometry

    # Rasterise shapes into an array
    arr = rasterio.features.rasterize(shapes=shapes,
                                      out_shape=(y, x),
                                      transform=transform,
                                      **rasterio_kwargs)

    # Convert result to a xarray.DataArray
    xarr = xr.DataArray(arr,
                        coords=xy_coords,
                        dims=dims,
                        attrs=da.attrs,
                        name=name if name else None)

    # Add back crs if xarr.attrs doesn't have it
    if xarr.geobox is None:
        xarr = assign_crs(xarr, str(crs))

    if export_tiff:
        print(f"Exporting GeoTIFF to {export_tiff}")
        write_cog(xarr, export_tiff, overwrite=True)

    return xarr
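A short usage sketch; the shapefile path and template array are placeholders:

import geopandas as gpd
gdf = gpd.read_file("field_boundaries.shp")  # hypothetical vector layer
field_mask = xr_rasterize(gdf, ds.red.isel(time=0), name="field_mask")  # ds: hypothetical loaded dataset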
Example #24
from datacube.testutils.io import rio_slurp_xarray, rio_slurp_reproject
from datacube.utils.geometry import GeoBox, box, CRS
from datacube.utils.cog import write_cog

# from sklearn.impute import SimpleImputer

#dc = datacube.Datacube(config='/home/547/sc0554/datacube.conf', env='lccs_dev')

#query = {'time': ('2015-01-01', '2015-12-31')}
#query['crs'] = 'EPSG:3577'

#data = dc.load(product='fc_percentile_albers_annual', measurements='PV_PC_90', **query)
data = xr.open_rasterio(
    '/g/data/r78/LCCS_Aberystwyth/urban_tests/test_sites_peter/perth_2015_gm.tif'
)
data = assign_crs(data, crs='epsg:3577')
# quickshift expects multiband images with bands in the last dimension
data = data.transpose()
fname = '/g/data/r78/LCCS_Aberystwyth/continental_run_april2020/2015/lccs_2015_L4_0.5.0.tif'
LCCS = rio_slurp_xarray(fname, gbox=data.geobox)
LCCS = LCCS.isel(band=0)
print("LCCS shape", LCCS.shape)
meta_d = LCCS.copy()  ##.squeeze().drop('time')
seg = felzenszwalb(LCCS.data.transpose())
#seg = quickshift(LCCS.data.transpose(), kernel_size=3, convert2lab=False, max_dist=10, ratio=0.5)
print('seg shape', seg.shape)
data_seg_med = scipy.ndimage.median(input=LCCS.data.transpose(),
                                    labels=seg,
                                    index=seg)
#data_seg_med = data_seg_med.squeeze("time").drop("time")
print("seg_med shape", data_seg_med.shape)
Ejemplo n.º 25
0
def xr_phenology(
    da,
    stats=[
        "SOS",
        "POS",
        "EOS",
        "Trough",
        "vSOS",
        "vPOS",
        "vEOS",
        "LOS",
        "AOS",
        "ROG",
        "ROS",
    ],
    method_sos="median",
    method_eos="median",
    complete='fast_complete',
    smoothing=None,
    show_progress=True,
):
    """
    Obtain land surface phenology metrics from an
    xarray.DataArray containing a timeseries of a 
    vegetation index like NDVI.

    last modified June 2020

    Parameters
    ----------
    da :  xarray.DataArray
        DataArray should contain a 2D or 3D time series of a
        vegetation index like NDVI, EVI
    stats : list
        list of phenological statistics to return. Regardless of
        the metrics returned, all statistics are calculated
        due to inter-dependencies between metrics.
        Options include:
            SOS = DOY of start of season
            POS = DOY of peak of season
            EOS = DOY of end of season
            vSOS = Value at start of season
            vPOS = Value at peak of season
            vEOS = Value at end of season
            Trough = Minimum value of season
            LOS = Length of season (DOY)
            AOS = Amplitude of season (in value units)
            ROG = Rate of greening
            ROS = Rate of senescence
    method_sos : str
        If 'first', then vSOS is estimated as the first positive
        slope on the greening side of the curve. If 'median',
        then vSOS is estimated as the median value of the positive
        slopes on the greening side of the curve.
    method_eos : str
        If 'last', then vEOS is estimated as the last negative slope
        on the senescing side of the curve. If 'median', then vEOS is
        estimated as the median value of the negative slopes on the
        senescing side of the curve.
    complete : str
        If 'fast_complete', the timeseries will be completed (gap-filled)
        using fast_completion(); if 'linear', the timeseries will be
        completed using da.interpolate_na(method='linear').
    smoothing : str
        If 'wiener', the timeseries will be smoothed using the
        scipy.signal.wiener filter with a window size of 3. If 'rolling_mean',
        the timeseries is smoothed using a rolling mean with a window size of 3.
        If 'linear', the timeseries is smoothed using
        da.resample(time='1W').interpolate('linear').

    Returns
    -------
        xarray.Dataset containing variables for the selected
        phenology statistics

    """
    # Check inputs before running calculations
    if dask.is_dask_collection(da):
        if version.parse(xr.__version__) < version.parse('0.16.0'):
            raise TypeError(
                "Dask arrays are only supported with xarray >= 0.16.0; " +
                "either upgrade xarray or run da.compute() before " +
                "passing the DataArray.")
        stats_dtype = {
            "SOS": np.int16,
            "POS": np.int16,
            "EOS": np.int16,
            "Trough": np.float32,
            "vSOS": np.float32,
            "vPOS": np.float32,
            "vEOS": np.float32,
            "LOS": np.int16,
            "AOS": np.float32,
            "ROG": np.float32,
            "ROS": np.float32,
        }
        da_template = da.isel(time=0).drop('time')
        template = xr.Dataset({
            var_name: da_template.astype(var_dtype)
            for var_name, var_dtype in stats_dtype.items() if var_name in stats
        })
        da_all_time = da.chunk({'time': -1})

        lazy_phenology = da_all_time.map_blocks(xr_phenology,
                                                kwargs=dict(
                                                    stats=stats,
                                                    method_sos=method_sos,
                                                    method_eos=method_eos,
                                                    complete=complete,
                                                    smoothing=smoothing,
                                                ),
                                                template=xr.Dataset(template))

        try:
            crs = da.geobox.crs
            lazy_phenology = assign_crs(lazy_phenology, str(crs))
        except:
            pass

        return lazy_phenology

    if method_sos not in ("median", "first"):
        raise ValueError("method_sos should be either 'median' or 'first'")

    if method_eos not in ("median", "last"):
        raise ValueError("method_eos should be either 'median' or 'last'")

    # If stats supplied is not a list, convert to list.
    stats = stats if isinstance(stats, list) else [stats]

    # try to grab the CRS info
    try:
        crs = da.geobox.crs
    except:
        pass

    # complete timeseries
    if complete is not None:

        if complete == 'fast_complete':

            if len(da.shape) == 1:
                print(
                    "fast_complete does not operate on 1D timeseries, using 'linear' instead"
                )
                da = da.interpolate_na(dim='time', method='linear')

            else:
                print("Completing using fast_complete...")
                da = fast_completion(da)

        if complete == 'linear':
            print("Completing using linear interp...")
            da = da.interpolate_na(dim='time', method='linear')

    if smoothing is not None:

        if smoothing == "wiener":
            if len(da.shape) == 1:
                print(
                    "wiener method does not operate on 1D timeseries, using 'rolling_mean' instead"
                )
                da = da.rolling(time=3, min_periods=1).mean()

            else:
                print("   Smoothing with wiener filter...")
                da = smooth(da)

        if smoothing == "rolling_mean":
            print("   Smoothing with rolling mean...")
            da = da.rolling(time=3, min_periods=1).mean()

        if smoothing == 'linear':
            print("    Smoothing using linear interpolation...")
            da = da.resample(time='1W').interpolate('linear')

    # set any remaining all-NaN pixels to zero so the metrics can be computed
    mask = da.isnull().all("time")
    da = da.where(~mask, other=0)

    # calculate the statistics
    print("      Phenology...")
    vpos = _vpos(da)
    pos = _pos(da)
    trough = _trough(da)
    aos = _aos(vpos, trough)
    vsos = _vsos(da, pos, method_sos=method_sos)
    sos = _sos(vsos)
    veos = _veos(da, pos, method_eos=method_eos)
    eos = _eos(veos)
    los = _los(da, eos, sos)
    rog = _rog(vpos, vsos, pos, sos)
    ros = _ros(veos, vpos, eos, pos)

    # Dictionary containing the statistics
    stats_dict = {
        "SOS": sos.astype(np.int16),
        "EOS": eos.astype(np.int16),
        "vSOS": vsos.astype(np.float32),
        "vPOS": vpos.astype(np.float32),
        "Trough": trough.astype(np.float32),
        "POS": pos.astype(np.int16),
        "vEOS": veos.astype(np.float32),
        "LOS": los.astype(np.int16),
        "AOS": aos.astype(np.float32),
        "ROG": rog.astype(np.float32),
        "ROS": ros.astype(np.float32),
    }

    # initialise dataset with first statistic
    ds = stats_dict[stats[0]].to_dataset(name=stats[0])

    # add the other stats to the dataset
    for stat in stats[1:]:
        print("         " + stat)
        ds[stat] = stats_dict[stat]

    try:
        ds = assign_crs(ds, str(crs))
    except:
        pass

    return ds.drop('time')
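
# A hedged usage sketch for xr_phenology; `ds` is assumed to be a Dataset
# holding an NDVI time series (e.g. produced by calculate_indices elsewhere
# in this document), so the variable name below is an assumption.
ndvi = ds.NDVI  # 3D (time, y, x) DataArray of a vegetation index

pheno = xr_phenology(
    ndvi,
    stats=["SOS", "POS", "EOS", "LOS", "AOS"],
    method_sos="median",
    method_eos="median",
    complete="linear",
    smoothing="rolling_mean",
)

print(pheno)  # Dataset with one variable per requested statistic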
Ejemplo n.º 26
0
def predict_xr(
    model,
    input_xr,
    chunk_size=None,
    persist=True,
    proba=False,
    clean=False,
    return_input=False,
):
    """
    Using dask-ml ParallelPostFit(), runs the parallel
    predict and predict_proba methods of sklearn
    estimators. Useful for running predictions
    on larger-than-RAM datasets.

    Last modified: September 2020

    Parameters
    ----------
    model : scikit-learn model or compatible object
        Must have a .predict() method that takes numpy arrays.
    input_xr : xarray.DataArray or xarray.Dataset.
        Must have dimensions 'x' and 'y'
    chunk_size : int
        The dask chunk size to use on the flattened array. If this
        is left as None, then the chunk size is inferred from the
        .chunks attribute of `input_xr`
    persist : bool
        If True, and proba=True, then 'input_xr' data will be
        loaded into distributed memory. This will ensure data
        is not loaded twice for the prediction of probabilities,
        but this will only work if the data is not larger than RAM.
    proba : bool
        If True, predict probabilities. This only applies if the
        model has a .predict_proba() method
    clean : bool
        If True, remove Infs and NaNs from input and output arrays
    return_input : bool
        If True, then the data variables in the 'input_xr' dataset will
        be appended to the output xarray dataset.

    Returns
    ----------
    output_xr : xarray.Dataset
        An xarray.Dataset containing the prediction output from model
        with input_xr as input. If proba=True, the dataset will also contain
        the prediction probabilities. Has the same spatiotemporal structure
        as input_xr.

    """
    if chunk_size is None:
        chunk_size = int(input_xr.chunks["x"][0]) * int(
            input_xr.chunks["y"][0])

    # convert model to dask predict
    model = ParallelPostFit(model)

    # with joblib.parallel_backend("dask"):
    x, y, crs = input_xr.x, input_xr.y, input_xr.geobox.crs

    input_data = []

    for var_name in input_xr.data_vars:
        input_data.append(input_xr[var_name])

    input_data_flattened = []
    # TODO: transfer to dask dataframe
    for arr in input_data:
        data = arr.data.flatten().rechunk(chunk_size)
        input_data_flattened.append(data)

    # reshape for prediction
    input_data_flattened = da.array(input_data_flattened).transpose()

    if clean:
        input_data_flattened = da.where(da.isfinite(input_data_flattened),
                                        input_data_flattened, 0)

    if proba and persist:
        # persisting data so we don't require loading all the data twice
        input_data_flattened = input_data_flattened.persist()

    # apply the classification
    print("   predicting...")
    out_class = model.predict(input_data_flattened)

    # Mask out NaN or Inf values in results
    if clean:
        out_class = da.where(da.isfinite(out_class), out_class, 0)

    # Reshape when writing out
    out_class = out_class.reshape(len(y), len(x))

    # stack back into xarray
    output_xr = xr.DataArray(out_class,
                             coords={
                                 "x": x,
                                 "y": y
                             },
                             dims=["y", "x"])

    output_xr = output_xr.to_dataset(name="Predictions")

    if proba:
        print("   probabilities...")
        out_proba = model.predict_proba(input_data_flattened)

        # convert to %
        out_proba = da.max(out_proba, axis=1) * 100.0

        if clean:
            out_proba = da.where(da.isfinite(out_proba), out_proba, 0)

        out_proba = out_proba.reshape(len(y), len(x))

        out_proba = xr.DataArray(out_proba,
                                 coords={
                                     "x": x,
                                     "y": y
                                 },
                                 dims=["y", "x"])
        output_xr["Probabilities"] = out_proba

    if return_input:
        print("   input features...")
        # unflatten the input_data_flattened array and append
        # to the output_xr containing the predictions
        arr = input_xr.to_array()
        stacked = arr.stack(z=["y", "x"])
        # handle multivariable output
        output_px_shape = ()
        if len(input_data_flattened.shape[1:]):
            output_px_shape = input_data_flattened.shape[1:]

        output_features = input_data_flattened.reshape(
            (len(stacked.z), *output_px_shape))

        # set the stacked coordinate to match the input
        output_features = xr.DataArray(
            output_features,
            coords={
                "z": stacked["z"]
            },
            dims=[
                "z",
                *[
                    "output_dim_" + str(idx)
                    for idx in range(len(output_px_shape))
                ],
            ],
        ).unstack()

        # convert to dataset and rename arrays
        output_features = output_features.to_dataset(dim="output_dim_0")
        data_vars = list(input_xr.data_vars)
        output_features = output_features.rename(
            {i: j
             for i, j in zip(output_features.data_vars, data_vars)}  # noqa pylint: disable=unnecessary-comprehension
        )

        # merge with predictions
        output_xr = xr.merge([output_xr, output_features], compat="override")

    return assign_crs(output_xr, str(crs))
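
# A hedged usage sketch for predict_xr, assuming a fitted scikit-learn
# classifier and a dask-backed feature Dataset with x/y dimensions; the
# model, training arrays and `features` Dataset are placeholders.
from sklearn.ensemble import RandomForestClassifier

model = RandomForestClassifier(n_estimators=200)
model.fit(X_train, y_train)        # X_train / y_train prepared elsewhere (hypothetical)

predicted = predict_xr(
    model,
    features,                      # dask-backed xr.Dataset of feature layers
    proba=True,                    # also return per-pixel max class probability (%)
    clean=True,                    # zero-out NaNs/Infs before and after prediction
    return_input=True,             # append the input feature layers to the output
)

print(predicted.Predictions)
print(predicted.Probabilities)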
Ejemplo n.º 27
0
def xr_geomedian_tmad(ds, axis='time', where=None, **kw):
    """
    :param ds: xr.Dataset|xr.DataArray|numpy array
    Other parameters:
    **kwargs -- passed on to pcm.gnmpcm
       maxiters   : int         1000
       eps        : float       0.0001
       num_threads: int| None   None
    """

    import hdstats
    def gm_tmad(arr, **kw):
        """
        arr: a high dimensional numpy array where the last dimension will be reduced. 
    
        returns: a numpy array with one less dimension than input.
        """
        gm = hdstats.nangeomedian_pcm(arr, **kw)
        nt = kw.pop('num_threads', None)
        emad = hdstats.emad_pcm(arr, gm, num_threads=nt)[:,:, np.newaxis]
        smad = hdstats.smad_pcm(arr, gm, num_threads=nt)[:,:, np.newaxis]
        bcmad = hdstats.bcmad_pcm(arr, gm, num_threads=nt)[:,:, np.newaxis]
        return np.concatenate([gm, emad, smad, bcmad], axis=-1)


    def norm_input(ds, axis):
        if isinstance(ds, xr.DataArray):
            xx = ds
            if len(xx.dims) != 4:
                raise ValueError("Expect 4 dimensions on input: y,x,band,time")
            if axis is not None and xx.dims[3] != axis:
                raise ValueError(f"Can only reduce last dimension, expect: y,x,band,{axis}")
            return None, xx, xx.data
        elif isinstance(ds, xr.Dataset):
            xx = reshape_for_geomedian(ds, axis)
            return ds, xx, xx.data
        else:  # assume numpy or similar
            xx_data = ds
            if xx_data.ndim != 4:
                raise ValueError("Expect 4 dimensions on input: y,x,band,time")
            return None, None, xx_data

    kw.setdefault('nocheck', False)
    kw.setdefault('num_threads', 1)
    kw.setdefault('eps', 1e-6)

    ds, xx, xx_data = norm_input(ds, axis)
    is_dask = dask.is_dask_collection(xx_data)

    if where is not None:
        if is_dask:
            raise NotImplementedError("Dask version doesn't support output masking currently")

        if where.shape != xx_data.shape[:2]:
            raise ValueError("Shape for `where` parameter doesn't match")
        set_nan = ~where
    else:
        set_nan = None

    if is_dask:
        if xx_data.shape[-2:] != xx_data.chunksize[-2:]:
            xx_data = xx_data.rechunk(xx_data.chunksize[:2] + (-1, -1))

        data = da.map_blocks(lambda x: gm_tmad(x, **kw),
                             xx_data,
                             name=randomize('geomedian'),
                             dtype=xx_data.dtype, 
                             chunks=xx_data.chunks[:-2] + (xx_data.chunks[-2][0]+3,),
                             drop_axis=3)
    else:
        data = gm_tmad(xx_data, **kw)

    if set_nan is not None:
        data[set_nan, :] = np.nan

    if xx is None:
        return data

    dims = xx.dims[:-1]
    cc = {k: xx.coords[k] for k in dims}
    cc[dims[-1]] = np.hstack([xx.coords[dims[-1]].values,['edev', 'sdev', 'bcdev']])
    xx_out = xr.DataArray(data, dims=dims, coords=cc)

    if ds is None:
        xx_out.attrs.update(xx.attrs)
        return xx_out

    ds_out = xx_out.to_dataset(dim='band')
    for b in ds.data_vars.keys():
        src, dst = ds[b], ds_out[b]
        dst.attrs.update(src.attrs)

    return assign_crs(ds_out, crs=ds.geobox.crs)
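
# A hedged usage sketch for xr_geomedian_tmad, mirroring how it is called in
# the feature functions earlier in this document; the datacube query below is
# illustrative only (product name, extent and resolution are assumptions).
import datacube
import numpy as np

dc = datacube.Datacube(app="gm_tmad_example")

ds = dc.load(product="s2_l2a",
             measurements=["red", "green", "blue", "nir"],
             time=("2019-01", "2019-06"),
             x=(37.0, 37.1), y=(-1.0, -0.9),
             output_crs="epsg:6933", resolution=(-20, 20),
             group_by="solar_day")

gm_mads = xr_geomedian_tmad(ds, num_threads=4)

# The MAD layers come back as 'edev', 'sdev' and 'bcdev'; the feature
# functions above log-transform them before use.
gm_mads["sdev"] = -np.log(gm_mads["sdev"])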
Ejemplo n.º 28
0
def post_processing(
    predicted: xr.Dataset, urls: Dict[str, Any]
) -> xr.Dataset:
    """
    Run the delayed post_processing functions, then create a lazy
    xr.Dataset to satisfy odc-stats
    """
    dc = Datacube(app="whatever")

    # grab predictions and proba for post process filtering
    predict = predicted.Predictions
    proba = predicted.Probabilities
    proba = proba.where(predict == 1, 100 - proba)  # crop proba only

    # ------image seg and filtering -------------
    # write out ndvi for image seg
    ndvi = assign_crs(predicted[["NDVI_S1", "NDVI_S2"]], crs=predicted.geobox.crs)

    # call function with dask delayed
    filtered = image_segmentation(ndvi, predict)

    # convert delayed object to dask array
    filtered = dask.array.from_delayed(
        filtered.squeeze(), shape=predict.shape, dtype=np.uint8
    )

    # convert dask array to xr.DataArray
    filtered = xr.DataArray(filtered, coords=predict.coords, attrs=predict.attrs)

    # --Post process masking----------------------------------------

    # merge back together for masking
    ds = xr.Dataset({"mask": predict, "prob": proba, "filtered": filtered})

    # mask out classification beyond AEZ boundary
    gdf = gpd.read_file(urls["aez"])
    with HiddenPrints():
        mask = xr_rasterize(gdf, predicted)
    mask = mask.chunk({})
    ds = ds.where(mask, 0)

    # mask with WDPA
    wdpa = rio_slurp_xarray(urls["wdpa"], gbox=predicted.geobox)
    wdpa = wdpa.chunk({})
    wdpa = wdpa.astype(bool)
    ds = ds.where(~wdpa, 0)

    # mask with WOFS
    wofs = dc.load(product='wofs_ls_summary_annual',
                   like=predicted.geobox,
                   dask_chunks={},
                   time='2019')
    wofs = wofs.frequency > 0.20  # threshold
    ds = ds.where(~wofs, 0)

    # mask steep slopes
    slope = rio_slurp_xarray(urls["slope"], gbox=predicted.geobox)
    slope = slope.chunk({})
    slope = slope > 50
    ds = ds.where(~slope, 0)

    # mask where the elevation is above 3600m
    elevation = dc.load(product="dem_srtm", like=predicted.geobox, dask_chunks={})
    elevation = elevation.elevation > 3600  # threshold
    ds = ds.where(~elevation.squeeze(), 0)

    return ds.squeeze()
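
# A hedged usage sketch for post_processing; the `urls` entries are
# hypothetical paths, and `predicted` is assumed to be the output of
# predict_xr() merged with the NDVI_S1/NDVI_S2 feature layers.
urls = {
    "aez": "aez_boundary.geojson",   # hypothetical vector boundary
    "wdpa": "wdpa_mask.tif",         # hypothetical protected-areas raster
    "slope": "srtm_slope.tif",       # hypothetical slope raster
}

masked = post_processing(predicted, urls)
masked = masked.compute()            # trigger the lazy dask computation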