示例#1
0
def load_objs(objects):
    """Read a vector file of objects with ``read_vec`` and return the result.

    Parameters
    ----------
    objects : str or PurePath
        Location of the objects. ``PurePath`` instances are converted to
        ``str`` before being passed to ``read_vec``.

    Returns
    -------
    Whatever ``read_vec`` returns for ``objects``; its length is logged.
    """
    source = str(objects) if isinstance(objects, PurePath) else objects

    # TODO: Read in chunks, parallelize, recombine and write
    logger.info('Reading in objects...')
    loaded = read_vec(source)
    logger.info('Objects found: {:,}'.format(len(loaded)))
    return loaded
示例#2
0
    def __init__(self, objects_path, value_fields=None):
        """Load the objects and set up the field names used by this class.

        Parameters
        ----------
        objects_path : gpd.GeoDataFrame or path
            A GeoDataFrame is deep-copied (the caller's frame is left
            untouched) and ``self.objects_path`` is set to None. Anything
            else is stored as ``self.objects_path`` and loaded with
            ``read_vec``.
        value_fields : optional
            Parsed by ``self._parse_value_fields``. Per the comment below it
            describes (field_name, summary_stat) pairs recalculated after
            merging.
        """
        if isinstance(objects_path, gpd.GeoDataFrame):
            self.objects = copy.deepcopy(objects_path)
            self.objects_path = None
        else:
            self.objects_path = objects_path
            self.objects = read_vec(objects_path)

        logger.info('Loaded {:,} objects.'.format(len(self.objects)))

        # Field names
        self.nebs_fld = 'neighbors'
        self._area_fld = 'area'
        self.pseudo_area_fld = 'pseudo_area'
        self.compact_fld = 'compactness'
        self.class_fld = 'class'

        # Merge column names
        self.mc_fld = 'merge_candidates'
        self.mp_fld = 'merge_path'
        # Inits to True, marked False if merged, or considered and unmergeable
        self.m_fld = 'mergeable'
        self.m_seed_fld = 'merge_seed'
        self.m_ct_fld = 'merge_count'
        self.continue_iter = 'continue_iter'
        self.mergeable_ids = None

        # List of (field_name, summary_stat) to be recalculated after merging
        self.value_fields = self._parse_value_fields(value_fields)

        # List of boolean fields holding result of apply a rule
        self.rule_fields = []
        # Properties calculated on demand
        self._num_objs = None
        # Captured before the neighbors column is added below.
        self._fields = list(self.objects.columns)
        self._object_stats = None
        self._area = None
        # Neighbor value fields
        self.nv_fields = list()
        # np.NaN was removed in NumPy 2.0; np.nan works on every version.
        self.objects[self.nebs_fld] = np.nan
        # Rules
        self._rule_fld_name = 'in_field'  # field name in rule dictionaries

        # TODO: check for unique index, create if not
        # Name index if unnamed
        if not self.objects.index.name:
            self.objects.index.name = 'index'
        if not self.objects.index.is_unique:
            logger.warning('Non-unique index not supported.')
示例#3
0
def get_footprint_dems(footprint_path, filepath=get_filepath_field(),
                       dem_name='DEM_NAME', dem_path_fld='dem_path',
                       dem_exist_fld='dem_exist', location=None,
                       use_terrranova=False, use_terranova=None):
    r"""Load Footprint - Check for existence of DEMs.

    Parameters
    ----------
    footprint_path : list, str, path-like or gpd.GeoDataFrame
        - list: DEM paths, used as-is.
        - path ending in ``.txt``: text file with one DEM path per line
          (surrounding whitespace stripped, blank lines skipped).
        - any other path: vector footprint loaded with ``read_vec``.
        - GeoDataFrame: footprint already in memory. The path/existence
          columns are added to this frame in place.
    filepath, dem_name : str
        Footprint columns passed to ``get_dem_path`` to build each DEM path
        when ``location`` is None. Only used for vector/GeoDataFrame input.
    dem_path_fld, dem_exist_fld : str
        Names of the columns holding the DEM path and whether it exists.
    location : str, optional
        Footprint column that already holds the DEM path. On Windows its
        values are converted with ``mnt2v``. Only used for vector or
        GeoDataFrame input.
    use_terrranova : bool
        Original (misspelled) name of ``use_terranova``, kept for backward
        compatibility.
    use_terranova : bool, optional
        When true, replace ``pgc\data\elev`` with ``pgc\terrnva_data\elev``
        in every DEM path. Overrides ``use_terrranova`` when given.

    Returns
    -------
    list
        DEM paths for which ``os.path.exists`` is true.

    Raises
    ------
    TypeError
        If ``footprint_path`` is not one of the supported types.
    """
    # The body previously referenced this name while the parameter was
    # spelled ``use_terrranova``, which raised NameError.
    if use_terranova is None:
        use_terranova = use_terrranova

    logger.debug('Footprint type: {}'.format(type(footprint_path)))
    if isinstance(footprint_path, os.PathLike):
        footprint_path = os.fspath(footprint_path)

    # True when the input already lists DEM paths (list / .txt). Those paths
    # must not be rebuilt from footprint attribute columns they do not have.
    paths_given = False
    if isinstance(footprint_path, list):
        # List of paths
        fp = pd.DataFrame({dem_path_fld: footprint_path})
        paths_given = True
    elif isinstance(footprint_path, str):
        # splitext returns (root, ext); compare only the extension.
        if os.path.splitext(footprint_path)[1] == '.txt':
            # Paths in text file. Strip newlines so os.path.exists can match.
            with open(footprint_path, 'r') as src:
                content = [line.strip() for line in src if line.strip()]
            fp = pd.DataFrame({dem_path_fld: content})
            paths_given = True
        else:
            # Load footprint
            logger.info('Loading DEM footprint...')
            fp = read_vec(footprint_path)
    elif isinstance(footprint_path, gpd.GeoDataFrame):
        # Vector file of footprints with path field
        fp = footprint_path
    else:
        raise TypeError('Unsupported footprint type: {}'.format(
            type(footprint_path)))

    num_fps = len(fp)
    logger.info('Records found: {:,}'.format(num_fps))
    if num_fps == 0:
        # A row-wise apply on an empty frame cannot be assigned to a column.
        return []

    if not paths_given:
        if location is None:
            fp[dem_path_fld] = fp.apply(
                lambda x: get_dem_path(x[filepath], x[dem_name]), axis=1)
        else:
            fp[dem_path_fld] = fp[location]
            if platform.system() == 'Windows':
                fp[dem_path_fld] = fp[dem_path_fld].apply(mnt2v)

    if use_terranova:
        fp[dem_path_fld] = fp[dem_path_fld].apply(
            lambda x: x.replace(r"pgc\data\elev", r"pgc\terrnva_data\elev"))

    logger.info('Verifying DEMs existence at location indicated...')
    fp[dem_exist_fld] = fp[dem_path_fld].apply(os.path.exists)

    dem_paths = list(fp.loc[fp[dem_exist_fld], dem_path_fld])
    # Previously len(fp), which made this warning unreachable.
    num_exist_fps = len(dem_paths)
    if num_exist_fps != num_fps:
        logger.warning('Filepaths could not be found for {} records, skipping...'.format(num_fps - num_exist_fps))

    return dem_paths
示例#4
0
def mask_objs(objs, mask_on, out_mask_img=None, out_mask_vec=None):
    """Remove objects lying in the masked areas of a raster.

    The mask of ``mask_on`` is written to disk with
    ``Raster.WriteMaskVector``, read back with ``read_vec``, and the mask
    features whose first column is not the string ``'1'`` are overlaid with
    ``objs`` using ``gpd.overlay`` (default ``how``).

    NOTE(review): ``gpd.overlay`` with its default ``how`` intersects
    geometries, so objects may be clipped rather than kept whole. Confirm
    this is intended.

    Parameters
    ----------
    objs : gpd.GeoDataFrame
        Objects to filter.
    mask_on : path
        Raster whose mask is vectorised.
    out_mask_img, out_mask_vec : str, optional
        Output locations for the mask raster and vector. They default to
        ``/vsimem/`` paths (GDAL's in-memory filesystem).

    Returns
    -------
    gpd.GeoDataFrame
        The overlay result.
    """
    mask_img = r'/vsimem/temp_mask.tif' if out_mask_img is None else out_mask_img
    mask_vec = r'/vsimem/temp_mask.shp' if out_mask_vec is None else out_mask_vec

    # Write mask vector (and raster)
    logger.info('Creating mask from raster: {}'.format(mask_on))
    Raster(mask_on).WriteMaskVector(out_vec=mask_vec,
                                    out_mask_img=mask_img)
    mask_features = read_vec(mask_vec)
    first_column = mask_features.iloc[:, 0]
    valid_areas = mask_features[first_column != '1']

    # Select only objects in valid areas of mask
    logger.info('Removing objects in masked areas...')
    # TODO: use centroids for faster cleanup
    kept = gpd.overlay(objs, valid_areas)
    logger.info('Objects kept: {:,}'.format(len(kept)))
    return kept
示例#5
0
def _zs_stem(path):
    """Return the file name of ``path`` without directory or extension."""
    return os.path.splitext(os.path.basename(path))[0]


def _read_zs_raster_list(txt_path):
    """Parse a text file of ``path~name`` lines into (paths, names) lists.

    Blank lines are skipped. A line without ``~`` uses the raster's file
    stem as its name.
    """
    paths, names = [], []
    with open(txt_path, 'r') as src:
        for line in src:
            line = line.strip()
            if not line:
                continue
            path, _, name = line.partition('~')
            path = path.strip()
            paths.append(path)
            names.append(name.strip() or _zs_stem(path))
    return paths, names


def _resolve_zs_rasters(rasters, names, stats):
    """Normalise calc_zonal_stats' ``rasters`` input into parallel lists.

    Returns
    -------
    tuple
        (rasters, names, stats, bands): raster paths, name prefixes, the list
        of statistics for each raster, and the bands for each raster (None
        means no per-band loop). Dicts and ``.json`` files are delegated to
        ``load_stats_dict``.
    """
    if isinstance(rasters, dict):
        paths, r_names, r_stats, r_bands = load_stats_dict(rasters)
        return paths, r_names, r_stats, r_bands

    if isinstance(rasters, (str, os.PathLike)):
        # A single path was passed instead of a list of paths.
        rasters = [rasters]
    rasters = [os.fspath(r) for r in rasters]

    if len(rasters) == 1:
        logger.info('Reading raster file...')
        ext = os.path.splitext(rasters[0])[1]
        if ext == '.txt':
            # Text file of "path~name" lines
            logger.info('Reading rasters from text file: '
                        '{}'.format(rasters[0]))
            rasters, names = _read_zs_raster_list(rasters[0])
            logger.info('Located rasters:')
            for r, n in zip(rasters, names):
                logger.info('{}: {}'.format(n, r))
        elif ext == '.json':
            logger.info('Reading rasters from json file:'
                        ' {}'.format(rasters[0]))
            paths, r_names, r_stats, r_bands = load_stats_dict(rasters[0])
            return paths, r_names, r_stats, r_bands

    # Raster paths passed directly (or read from a text file)
    if names is None:
        names = [_zs_stem(r) for r in rasters]
    per_raster_stats = [stats for _ in rasters]
    return rasters, list(names), per_raster_stats, [None] * len(rasters)


def calc_zonal_stats(shp,
                     rasters,
                     names=None,
                     stats=None,
                     area=True,
                     compactness=False,
                     roundness=False,
                     out_path=None):
    """
    Calculate zonal statistics on the given vector file
    for each raster provided.

    Parameters
    ----------
    shp : os.path.abspath or gpd.GeoDataFrame
        Vector file (or in-memory GeoDataFrame) to compute zonal statistics
        for the features in.
    rasters : list, dict or os.path.abspath
        One of:
        - list of raster paths (or a single raster path);
        - a list holding a single path to a .txt file with one
          ``path~name`` entry per line;
        - a list holding a single path to a .json file, or a dict, of
          name: {path: /path/to/raster.tif, stats: ['mean']}
          (parsed by ``load_stats_dict``).
    names : list, optional
        Prefixes for the created stats, in raster order. Taken from the
        .txt/.json/dict input when used. Defaults to each raster's file stem
        for plain paths.
    stats : list, optional
        Statistics to calculate for rasters given as paths or a .txt file.
        The default is None, meaning ['min', 'max', 'mean', 'count',
        'median']. .json/dict inputs carry their own stats.
    area : bool
        True to also compute area of each feature in units of
        projection.
    compactness : bool
        True to also compute compactness of each object.
    roundness : bool
        True to also compute roundness of each object.
    out_path : os.path.abspath, optional
        Path to write the vector file with computed stats. Default is to
        add a '_stats' suffix before the file extension of ``shp``. Required
        when ``shp`` is a GeoDataFrame.

    Returns
    -------
    out_path.

    Raises
    ------
    ValueError
        If ``out_path`` is not given and ``shp`` is not a path.
    FileNotFoundError
        If any raster does not exist.
    """
    if stats is None:
        stats = ['min', 'max', 'mean', 'count', 'median']

    # Resolve the output path first so a bad combination fails before the
    # (potentially slow) statistics are computed.
    if not out_path:
        if not isinstance(shp, (str, os.PathLike)):
            raise ValueError('out_path is required when shp is not a path.')
        out_path = os.path.join(
            os.path.dirname(shp),
            '{}_stats.shp'.format(os.path.basename(shp).split('.')[0]))

    # Load data
    if isinstance(shp, gpd.GeoDataFrame):
        seg = shp
    else:
        logger.info('Reading in segments from: {}...'.format(shp))
        seg = read_vec(shp)
    logger.info('Segments found: {:,}'.format(len(seg)))

    # Determine rasters input type
    rasters, names, stats, bands = _resolve_zs_rasters(rasters, names, stats)

    # Confirm all rasters exist before starting
    missing = [r for r in rasters if not os.path.exists(r)]
    for r in missing:
        logger.error('Raster does not exist: {}'.format(r))
    if missing:
        raise FileNotFoundError('Raster(s) not found: {}'.format(
            ', '.join(str(r) for r in missing)))

    # Stats passed straight to compute_stats when no bands are specified.
    # 'count' is part of the default stats, so it must be accepted here.
    accepted_stats = {
        'min', 'max', 'median', 'sum', 'std', 'mean', 'unique',
        'range', 'majority', 'count'
    }

    # Iterate rasters and compute stats for each
    for r, n, s, bs in zip(rasters, names, stats, bands):
        if bs is None:
            stats_acc = [
                k for k in s
                if k in accepted_stats or k.startswith('percentile_')
            ]
            # TODO: custom stat functions (name:custom_fxn) are not yet
            # supported; other keys are skipped.
            skipped = [k for k in s if k not in stats_acc]
            if skipped:
                logger.warning('Skipping unsupported stats for {}: '
                               '{}'.format(n, skipped))
            seg = compute_stats(gdf=seg, raster=r, name=n, stats=stats_acc)
        else:
            # Compute stats for each band
            for b in bs:
                stats_dict = {x: '{}b{}_{}'.format(n, b, x) for x in s}
                seg = compute_stats(gdf=seg,
                                    raster=r,
                                    stats=stats_dict,
                                    renamer=stats_dict,
                                    band=b)

    # Area recording
    if area:
        seg['area_zs'] = seg.geometry.area

    # Compactness: Polsby-Popper Score -- 1 = circle
    if compactness:
        seg = apply_compactness(seg)

    if roundness:
        seg = apply_roundness(seg)

    # Write segments with stats to new file. A bare file name has no
    # directory component, and os.makedirs('') would raise.
    out_dir = os.path.dirname(out_path)
    if out_dir and not os.path.exists(out_dir):
        os.makedirs(out_dir)
    logger.info('Writing segments with statistics to: {}'.format(out_path))
    write_gdf(seg, out_path)

    return out_path
示例#6
0
def _rules_from_specs(rule_type, specs):
    """Create one rule of ``rule_type`` per (in_field, op, threshold) spec, in order."""
    return [
        create_rule(rule_type=rule_type,
                    in_field=in_field,
                    op=op,
                    threshold=threshold,
                    out_field=True)
        for in_field, op, threshold in specs
    ]


def classify_rts(sub_objects_path,
                 super_objects_path,
                 headwall_candidates_out=None,
                 headwall_candidates_centroid_out=None,
                 rts_predis_out=None,
                 rts_candidates_out=None,
                 aoi_path=None,
                 headwall_candidates_in=None,
                 aoi=None):
    """Classify headwall candidates, then RTS candidates, and write them.

    Parameters
    ----------
    sub_objects_path
        Objects classified as headwall candidates. Loaded with
        ``ImageObjects``, or with ``read_vec`` + ``select_in_aoi`` when
        ``aoi_path`` is given.
    super_objects_path
        Objects classified as RTS candidates.
    headwall_candidates_out
        Where the classified headwall objects are written.
    headwall_candidates_centroid_out
        Currently unused (the code that wrote it is commented out).
    rts_predis_out : optional
        If given, the classified super objects are written here before
        touching RTS candidates are dissolved.
    rts_candidates_out
        Where the final objects are written; also returned.
    aoi_path : optional
        Vector AOI used to subset the sub objects (by centroid).
    headwall_candidates_in : optional
        Previously classified headwall objects. When given, headwall
        classification is skipped.
    aoi : optional
        NOTE(review): the value passed in is never used; it is only
        overwritten from ``aoi_path``. Confirm whether callers expect it to
        be honoured.

    Returns
    -------
    rts_candidates_out
    """
    logger.info('Classifying RTS...')

    #%% RULESET
    # Headwall Rules
    logger.info('Setting up headwall candidate rules...')
    # All simple threshold rules
    r_simple_thresholds = _rules_from_specs(threshold_rule, [
        (rug_mean, operator.gt, 0.2),  # Ruggedness
        (sa_rat_mean, operator.gt, 1.01),  # Surface Area Ratio
        (slope_mean, operator.gt, 8),  # Slope (min)
        (slope_mean, operator.lt, 25),  # Slope (max)
        (ndvi_mean, operator.lt, 0),  # NDVI
        (med_mean, operator.lt, 0),  # MED
        (cur_mean, operator.gt, 2.5),  # Curvature (high)
        (delev_mean, operator.lt, -0.5),  # Difference in DEMs
    ])
    # Adjacency rules (adjacent to, or is)
    r_adj_rules = _rules_from_specs(adj_or_is_rule, [
        (cur_mean, operator.lt, -15),  # Adjacent low curvature (was -30)
        (cur_mean, operator.gt, 30),  # Adjacent high curvature
        (med_mean, operator.lt, -0.2),  # Adjacent low MED
    ])

    #%% RTS Rules
    logger.info('Setting up RTS candidate rules...')
    r_rts_simple_thresholds = _rules_from_specs(threshold_rule, [
        (ndvi_mean, operator.lt, 0),
        (med_mean, operator.lt, 0.1),
        (slope_mean, operator.gt, 3),
        (slope_mean, operator.lt, 20),
        (delev_mean, operator.lt, -0.5),
        (contains_hw, operator.eq, True),
    ])

    #%% HEADWALL CANDIDATES
    logger.info('Classifying headwall candidate objects...')
    if not headwall_candidates_in:
        logger.info('Loading headwall candidate objects...')
        if aoi_path:
            aoi = read_vec(aoi_path)
            logger.info('Subsetting objects to AOI...')
            gdf = select_in_aoi(read_vec(sub_objects_path), aoi, centroid=True)
            hwc = ImageObjects(objects_path=gdf, value_fields=value_fields)
        else:
            hwc = ImageObjects(objects_path=sub_objects_path,
                               value_fields=value_fields)

        # Classify headwalls
        logger.info('Determining headwall candidates...')
        hwc.classify_objects(hw_candidate,
                             threshold_rules=r_simple_thresholds,
                             adj_rules=r_adj_rules)
        logger.info('Headwall candidates found: {:,}'.format(
            len(hwc.objects[hwc.objects[hwc.class_fld] == hw_candidate])))

        # Write headwall candidates
        logger.info('Writing headwall candidates...')
        hwc.write_objects(headwall_candidates_out,
                          to_str_cols=to_str_cols,
                          overwrite=True)
    else:
        hwc = ImageObjects(objects_path=headwall_candidates_in,
                           value_fields=value_fields)

    #%% RETROGRESSIVE THAW SLUMPS
    logger.info('Loading RTS candidate objects...')
    so = ImageObjects(super_objects_path, value_fields=value_fields)
    logger.info('Determining RTS candidates...')

    # Select the headwall candidates once, rather than inside every per-row
    # lambda below.
    hw_cands = hwc.objects[hwc.objects[hwc.class_fld] == hw_candidate]

    # Find objects that contain headwalls (centroids) of a higher elevation
    # than themselves. Assigned first so the column keeps its original
    # position; an earlier non-centroid computation of this column was
    # always overwritten by this one and has been dropped.
    so.objects[contains_hw_gtr] = so.objects.apply(
        lambda x: overlay_any_objects(x.geometry,
                                      hw_cands,
                                      predicate='contains',
                                      threshold=x[elev_mean],
                                      other_value_field=elev_mean,
                                      op=operator.gt,
                                      others_centroid=True),
        axis=1)
    so.objects[contains_hw] = so.objects.apply(
        lambda x: overlay_any_objects(x.geometry,
                                      hw_cands,
                                      predicate='contains'),
        axis=1)
    so.objects[contains_hw_cent] = so.objects.apply(
        lambda x: overlay_any_objects(x.geometry,
                                      hw_cands,
                                      predicate='contains',
                                      others_centroid=True),
        axis=1)

    #%% Classify
    so.classify_objects(class_name=rts_candidate,
                        threshold_rules=r_rts_simple_thresholds)

    logger.info('RTS candidates found: {}'.format(
        len(so.objects[so.objects[so.class_fld] == rts_candidate])))

    if rts_predis_out:
        # Write classified objects before growing
        so.write_objects(rts_predis_out,
                         to_str_cols=to_str_cols,
                         overwrite=True)

    #%% Dissolve touching candidates
    rts_dissolved = dissolve_touching(
        so.objects[so.objects[so.class_fld] == rts_candidate])
    so.objects = pd.concat(
        [so.objects[so.objects[so.class_fld] != rts_candidate], rts_dissolved])

    #%% Write RTS candidates
    logger.info('Writing RTS candidates...')
    so.write_objects(rts_candidates_out,
                     to_str_cols=to_str_cols,
                     overwrite=True)

    return rts_candidates_out
示例#7
0
def main(
    dem,
    dem_prev,
    project_dir,
    config,
    aoi=None,
    image=None,
    pansh_img=None,
    skip_steps=None,
):
    # Convert to path objects
    if image is not None:
        image = Path(image)
    dem = Path(dem)
    dem_prev = Path(dem_prev)

    # %% Get configuration settings
    config = get_config(config_file=config)

    # Project config settings
    project_config = config['project']
    EPSG = project_config['EPSG']
    if aoi is None:
        aoi = Path(project_config['AOI'])
    else:
        aoi = Path(aoi)
    fill_nodata = project_config['fill_nodata']

    # Headwall and RTS config settings
    hw_config = config['headwall']
    rts_config = config['rts']
    grow_config = config['grow']

    # Preprocessing
    pansh_config = config['pansharpen']
    BITDEPTH = pansh_config['t']
    STRETCH = pansh_config['c']

    dem_deriv_config = config[DEM_DERIV]
    med_config = dem_deriv_config[MED]
    curv_config = dem_deriv_config[CURV]
    img_deriv_config = config[IMG_DERIV]
    edge_config = img_deriv_config[EDGE_EX]

    # %% Build project directory structure
    logger.info('Creating project directory structure...')
    project_dir = Path(project_dir)
    if not project_dir.exists():
        logger.info('Creating project parent directory: '
                    '{}'.format(project_dir))
        os.makedirs(project_dir)
    SCRATCH_DIR = project_dir / 'scratch'
    IMG_DIR = project_dir / 'img'
    PANSH_DIR = project_dir / 'pansh'
    NDVI_DIR = project_dir / 'ndvi'
    DEM_DIR = project_dir / 'dem'
    DEM_PREV_DIR = project_dir / dem_prev_k
    DEM_DERIV_DIR = DEM_DIR / 'deriv'
    SEG_DIR = project_dir / 'seg'
    # HW_DIR = SEG_DIR / 'headwall'
    HW_SEG_GPKG = SEG_DIR / 'headwall.gpkg'
    # RTS_DIR = SEG_DIR / 'rts'
    RTS_SEG_GPKG = SEG_DIR / 'rts.gpkg'
    # GROW_DIR = SEG_DIR / 'grow'
    GROW_SEG_GPKG = SEG_DIR / 'grow.gpkg'
    # CLASS_DIR = project_dir / 'classified'
    CLASS_GPKG = project_dir / 'classified.gpkg'

    for d in [
            SCRATCH_DIR,
            IMG_DIR,
            PANSH_DIR,
            NDVI_DIR,
            DEM_DIR,
            DEM_PREV_DIR,
            DEM_DERIV_DIR,
            SEG_DIR,
            # HW_DIR, RTS_DIR, GROW_DIR, CLASS_DIR
    ]:
        if not d.exists():
            os.makedirs(d)
    out_vec_fmt = project_config[out_vec_fmt_k]

    # %% Imagery Preprocessing
    logger.info('\n\n***PREPROCESSING***')
    # Pansharpen
    if pansh_img is None:
        if pan not in skip_steps:
            logger.info('Pansharpening: {}'.format(image.name))
            pansh_cmd = '{} {} {} -p {} -d {} -t {} -c {} ' \
                        '--skip-dem-overlap-check'.format(PANSH_PY,
                                                          image,
                                                          PANSH_DIR,
                                                          EPSG,
                                                          dem,
                                                          BITDEPTH,
                                                          STRETCH)
            run_subprocess(pansh_cmd)

        # Determine output name
        pansh_img = PANSH_DIR / '{}_{}{}{}_pansh.tif'.format(
            image.stem, bitdepth_lut[BITDEPTH], STRETCH, EPSG)
    else:
        pansh_img = Path(pansh_img)

    # NDVI
    if ndvi not in skip_steps:
        logger.info('Creating NDVI from: {}'.format(pansh_img.name))
        ndvi_cmd = '{} {}'.format(NDVI_PY, pansh_img, NDVI_DIR)
        run_subprocess(ndvi_cmd)
    # Determine NDVI name
    ndvi_img = NDVI_DIR / '{}_ndvi.tif'.format(pansh_img.stem)
    # ndvi_img = NDVI_DIR / '{}_ndvi.tif'.format(image.stem)

    # %% Clip to AOI
    # Organize inputs
    inputs = {
        img_k: pansh_img,
        ndvi_k: ndvi_img,
        dem_k: dem,
        dem_prev_k: dem_prev,
    }

    if aoi:
        logger.info('Clipping inputs to AOI: {}'.format(aoi))
        for k, r in tqdm(inputs.items()):
            # out_path = r.parent / '{}{}{}'.format(r.stem, clip_sfx,
            #                                       r.suffix)
            out_path = project_dir / k / '{}{}{}'.format(
                r.stem, clip_sfx, r.suffix)
            if clip_step not in skip_steps:
                logger.debug('Clipping input {} to AOI: {}'.format(
                    k, aoi.name))
                clip_rasters(str(aoi),
                             str(r),
                             out_path=str(out_path),
                             out_suffix='',
                             skip_srs_check=True)
            inputs[k] = out_path
    if fill_nodata:
        logger.info('Filling internal NoData gaps in sources...')
        for k, r in tqdm(inputs.items()):
            # Only fill image no data
            if k in [img_k, ndvi_k]:
                filled = inputs[k].parent / '{}_filled{}'.format(
                    inputs[k].stem, inputs[k].suffix)
                if fill_step not in skip_steps:
                    fill_internal_nodata(inputs[k], filled, str(aoi))
                inputs[k] = filled

    # %% EdgeExtraction
    edge_config[img_k] = inputs[img_k]
    edge_config[out_dir] = IMG_DIR
    edge = otb_ee.create_outname(**edge_config)
    if edge_extraction not in skip_steps:
        logger.info('Creating EdgeExtraction')
        otb_ee.otb_edge_extraction(**edge_config)
    inputs[edge_k] = edge

    # %% DEM Derivatives
    if dem_deriv not in skip_steps:
        # DEM Diff
        logger.info('Creating DEM Difference...')
        logger.info('DEM1: {}'.format(inputs[dem_k]))
        logger.info('DEM2: {}'.format(inputs[dem_prev_k]))
        diff = DEM_DERIV_DIR / 'dem_diff.tif'
        difference_dems(str(inputs[dem_k]),
                        str(inputs[dem_prev_k]),
                        out_dem=str(diff))

        # Slope
        logger.info('Creating slope...')
        slope = DEM_DERIV_DIR / '{}_slope{}'.format(dem.stem, dem.suffix)
        gdal_dem_derivative(str(inputs[dem_k]), str(slope), 'slope')
        # Ruggedness
        logger.info('Creating ruggedness index...')
        ruggedness = DEM_DERIV_DIR / '{}_rugged{}'.format(dem.stem, dem.suffix)
        gdal_dem_derivative(str(inputs[dem_k]), str(ruggedness), 'TRI')

        # MED
        logger.info('Creating Maximum Elevation Deviation...')
        med = wbt_med(str(inputs[dem_k]),
                      out_dir=str(DEM_DERIV_DIR),
                      **med_config)

        # Curvature
        logger.info('Creating profile curvature...')
        curvature = wbt_curvature(str(inputs[dem_k]),
                                  out_dir=str(DEM_DERIV_DIR),
                                  **curv_config)

        # Surface Area Ratio
        logger.info('Creating Surface Area Ratio...')
        sar = wbt_sar(str(inputs[dem_k]), out_dir=str(DEM_DERIV_DIR))

        inputs[med_k] = med
        inputs[curv_k] = curvature
        inputs[slope_k] = slope
        inputs[rugged_k] = ruggedness
        inputs[diff_k] = diff
        inputs[sar_k] = sar

    # %% SEGMENTATION PREPROCESSING - Segment, calculate zonal statistics
    # %%
    # HEADWALL
    logger.info('\n\n***HEADWALL***')
    # Segmentation
    hw_config[seg][params][img_k] = inputs[img_k]
    # hw_config[seg][params][out_dir] = HW_DIR
    hw_seg_out = hw_config[seg][params][out_seg] = str(HW_SEG_GPKG /
                                                       'headwall_seg')
    if hw_config[seg][alg] == grm:
        logger.info('Segmenting subobjects (headwalls)...')
        if hw_seg not in skip_steps:
            hw_objects = otb_grm.otb_grm(drop_smaller=0.5,
                                         **hw_config[seg][params])
        else:
            hw_objects = otb_grm.create_outname(**hw_config[seg][params],
                                                name_only=True)
            logger.debug('Using provided headwall segmentation: '
                         '{}'.format(
                             Path(hw_config[seg][params][out_seg]).relative_to(
                                 project_dir)))

    # %% Cleanup
    # Create path to write cleaned objects to
    # hw_objects = Path(hw_objects)
    # cleaned_objects_out = str(hw_objects.parent / '{}{}{}'.format(
    #     hw_objects.stem, clean_sfx, hw_objects.suffix))
    cleaned_objects_out = hw_seg_out + '_cleaned'

    if hw_clean not in skip_steps:
        if hw_config[cleanup][cleanup]:
            logger.info('Cleaning up subobjects...')
            cleanup_params = hw_config[cleanup][params]
            cleanup_params[mask_on] = str(inputs[dem_k])
            # hw_objects = Path(hw_objects)
            hw_objects = cleanup_objects(input_objects=hw_seg_out,
                                         out_objects=cleaned_objects_out,
                                         **cleanup_params)
    else:
        logger.debug('Using provided cleaned headwall objects'
                     '{}: '.format(
                         Path(cleaned_objects_out).relative_to(project_dir)))
        # hw_objects = cleaned_objects_out

    # %% Zonal Stats
    logger.info('Calculating zonal statistics on headwall objects...')
    # hw_objects_path = Path(hw_objects)
    # hw_zs_out_path = '{}_zs{}'.format(
    #     hw_objects_path.parent / hw_objects_path.stem, hw_objects_path.suffix)
    hw_zs_out = cleaned_objects_out + '_zs'

    if hw_zs not in skip_steps:
        # Calculate zonal stats
        zonal_stats_inputs = {
            k: {
                'path': v,
                'stats': hw_config[zonal_stats][zs_stats]
            }
            for k, v in inputs.items()
            if k in hw_config[zonal_stats][zs_rasters]
        }
        if bands_k in hw_config[zonal_stats].keys():
            zonal_stats_inputs[img_k][bands_k] = hw_config[zonal_stats][
                bands_k]
        hw_objects = calc_zonal_stats(shp=cleaned_objects_out,
                                      rasters=zonal_stats_inputs,
                                      out_path=hw_zs_out)
    else:
        logger.debug('Using provided headwall objects with zonal stats: '
                     '{}'.format(Path(hw_zs_out).relative_to(project_dir)))
        # hw_objects = hw_zs_out

    # %%
    # RTS
    logger.info('\n\n***RTS***')
    # Naming
    rts_config[seg][params][img_k] = inputs[img_k]
    # rts_config[seg][params][out_dir] = RTS_DIR
    rts_seg_out = rts_config[seg][params][out_seg] = str(RTS_SEG_GPKG /
                                                         'rts_seg')

    # Segmentation
    if rts_config[seg][alg] == grm:
        logger.info('Segmenting superobjects (RTS)...')
        if rts_seg not in skip_steps:
            rts_objects = otb_grm.otb_grm(drop_smaller=0.5,
                                          **rts_config[seg][params])
        else:
            # rts_objects = otb_grm.create_outname(**rts_config[seg][params],
            #                              name_only=True)
            logger.debug('Using provided RTS seg: '
                         '{}'.format(
                             Path(rts_config[seg][params]
                                  [out_seg]).relative_to(project_dir)))
        # rts_objects = Path(rts_objects)

    # %% Cleanup
    # cleaned_objects_out = str(rts_objects.parent / '{}_cln{}'.format(
    #     rts_objects.stem, rts_objects.suffix))
    cleaned_objects_out = rts_seg_out + '_cleaned'

    if rts_clean not in skip_steps:
        if rts_config[cleanup][cleanup]:
            logger.info('Cleaning up objects...')
            cleanup_params = rts_config[cleanup][params]
            cleanup_params[mask_on] = str(inputs[dem_k])
            # rts_objects = Path(rts_objects)
            rts_objects = cleanup_objects(input_objects=rts_seg_out,
                                          out_objects=cleaned_objects_out,
                                          **cleanup_params)
    else:
        logger.debug('Using provided cleaned RTS objects: '
                     '{}'.format(
                         Path(cleaned_objects_out).relative_to(project_dir)))
        # rts_objects = cleaned_objects_out

    # %% Zonal Stats
    logger.info('Calculating zonal statistics on super objects...')
    # rts_objects_path = Path(rts_objects)
    # rts_zs_out_path = '{}_zs{}'.format(rts_objects_path.parent /
    #                                    rts_objects_path.stem,
    #                                    rts_objects_path.suffix)
    rts_zs_out = cleaned_objects_out + '_zs'

    if rts_zs not in skip_steps:
        # Calculate zonal_stats
        zonal_stats_inputs = {
            k: {
                'path': v,
                'stats': rts_config[zonal_stats][zs_stats]
            }
            for k, v in inputs.items()
            if k in rts_config[zonal_stats][zs_rasters]
        }
        rts_objects = calc_zonal_stats(shp=cleaned_objects_out,
                                       rasters=zonal_stats_inputs,
                                       out_path=rts_zs_out)
    else:
        logger.debug('Using provided RTS zonal stats objects: '
                     '{}'.format(Path(rts_zs_out).relative_to(project_dir)))
        # rts_objects = rts_zs_out_path

    # %% CLASSIFICATION
    if hw_config[classification_k][hw_class_out_k]:
        # hw_class_out = CLASS_DIR / '{}{}'.format(hw_class_out_k, out_vec_fmt)
        hw_class_out = CLASS_GPKG / 'headwalls'
    if hw_config[classification_k][hw_class_out_cent_k]:
        # hw_class_out_centroid = CLASS_DIR / '{}_cent{}'.format(hw_class_out_k,
        #                                                        out_vec_fmt)
        hw_class_out_centroid = CLASS_GPKG / 'headwall_centers'

    # Pass path to classified headwall objects if using previously classified
    if hw_class in skip_steps:
        hw_candidates_in = hw_class_out
    else:
        hw_candidates_in = None

    if rts_config[classification_k][rts_predis_out_k]:
        # rts_predis_out = CLASS_DIR / '{}{}'.format(rts_predis_out_k, out_vec_fmt)
        rts_predis_out = CLASS_GPKG / 'rts_predissolve'
    if rts_config[classification_k][rts_class_out_k]:
        # rts_class_out = CLASS_DIR / '{}{}'.format(rts_class_out_k, out_vec_fmt)
        rts_class_out = CLASS_GPKG / 'rts_candidates'

    if rts_class not in skip_steps:
        logger.info('Classifying RTS...')
        rts_objects = classify_rts(
            sub_objects_path=hw_zs_out,
            super_objects_path=rts_zs_out,
            headwall_candidates_out=hw_class_out,
            headwall_candidates_centroid_out=hw_class_out_centroid,
            rts_predis_out=rts_predis_out,
            rts_candidates_out=rts_class_out,
            aoi_path=None,
            headwall_candidates_in=hw_candidates_in,
            aoi=aoi)
    else:
        logger.debug('Using provided classified RTS objects: '
                     '{}'.format(Path(rts_class_out).relative_to(project_dir)))
        rts_objects = rts_class_out

    #%% GROW OBJECTS
    logger.info('\n\n***GROWING***')
    logger.info('Creating grow subobjects..')
    # Segment AOI into simple grow
    grow_config[seg][params][img_k] = inputs[img_k]
    # grow_config[seg][params][out_dir] = GROW_DIR
    grow_seg_out = grow_config[seg][params][out_seg] = str(GROW_SEG_GPKG /
                                                           'grow')

    if grow_seg not in skip_steps:
        grow = otb_grm.otb_grm(drop_smaller=0.5, **grow_config[seg][params])
    else:
        grow = otb_grm.create_outname(**grow_config[seg][params],
                                      name_only=True)
        logger.debug('Using provided grow objects: '
                     '{}'.format(Path(grow).relative_to(project_dir)))

    # Cleanup
    # grow = Path(grow)
    # cleaned_grow = str(grow.parent / '{}_cln{}'.format(grow.stem, grow.suffix))
    cleaned_grow_out = grow_config[seg][params][out_seg] + '_cleaned'

    if grow_clean not in skip_steps:
        if grow_config[cleanup][cleanup]:
            logger.info('Cleaning up objects...')
            cleanup_params = grow_config[cleanup][params]
            cleanup_params[mask_on] = str(inputs[dem_k])
            cleaned_grow = cleanup_objects(input_objects=grow_seg_out,
                                           out_objects=cleaned_grow_out,
                                           **cleanup_params)

    grow_zs_out = cleaned_grow_out + '_zs'
    if grow_zs not in skip_steps:
        logger.info('Merging RTS candidates into grow objects...')
        # Load small objects
        logger.debug(cleaned_grow_out)
        # so = gpd.read_file(cleaned_grow_out)
        so = read_vec(cleaned_grow_out)

        # Burn rts in, including class name
        logger.debug(rts_objects)
        # r = gpd.read_file(rts_objects)
        r = read_vec(rts_objects)
        r = r[r[class_fld] == rts_candidate][[class_fld, r.geometry.name]]

        # Erase subobjects under RTS candidates
        diff = gpd.overlay(so, r, how='difference')
        # Merge RTS candidates back in
        merged = pd.concat([diff, r])
        # merged_out = SEG_DIR / GROW_DIR / 'merged.shp'
        merged_out = GROW_SEG_GPKG / 'merged'
        write_gdf(merged, merged_out)

        # Zonal Stats
        zonal_stats_inputs = {
            k: {
                'path': v,
                'stats': grow_config[zonal_stats][zs_stats]
            }
            for k, v in inputs.items()
            if k in grow_config[zonal_stats][zs_rasters]
        }

        logger.info('Calculating zonal statistics on grow objects...')
        logger.debug('Computing zonal statistics on: '
                     '{}'.format(zonal_stats_inputs.keys()))
        grow = calc_zonal_stats(shp=merged_out,
                                rasters=zonal_stats_inputs,
                                out_path=str(grow_zs_out))

    # Do growing
    logger.info('Growing RTS objects into subobjects...')

    # Grow from rts
    # grow_objects = ImageObjects(grow_zs_out_path,
    #                             value_fields=zonal_stats_inputs)
    # # TODO: Remove after converting to use gpkg. This is just because of the
    # #  shapefile field size limit
    # grow_objects.objects.rename(columns={'ruggedness': 'ruggedness_mean'},
    #                             inplace=True)
    #
    # rts_objects = ImageObjects(rts_objects)
    # grown = grow_rts_candidates(rts_objects, grow_objects)

    grown, grow_candidates = grow_rts_simple(grow_zs_out)

    grow_candidates_out = CLASS_GPKG / 'grow_candidates'
    # logger.info('Writing grow objects to file: {}'.format(CLASS_DIR / 'grow_candidates.shp'))
    logger.info('Writing grow objects to file: {}'.format(grow_candidates_out))
    write_gdf(grow_candidates.objects, grow_candidates_out)

    # n = datetime.now().strftime('%Y%b%d_%H%M%S').lower()

    # rts_classified = CLASS_DIR / 'RTS.shp'
    rts_classified = CLASS_GPKG / 'RTS'
    logger.info('Writing classfied RTS features: {}'.format(rts_classified))
    write_gdf(grown, rts_classified)

    logger.info('Done')
示例#8
0
def clip_rasters(shp_p, rasters, out_path=None, out_dir=None, out_suffix='_clip',
                 out_prj_shp=None, raster_ext=None, move_meta=False,
                 in_mem=False, skip_srs_check=False, overwrite=False):
    """
    Clip (warp) one or more rasters to the features of a vector file.

    Each raster is warped with the vector as a GDAL cutline, cropped to the
    cutline extent, keeping the raster's own x/y pixel size and aligning the
    output to that pixel grid.

    Parameters
    ----------
    shp_p : str or os.PathLike
        Path to the clipping vector. If it has more than one feature, the
        features are dissolved into a single temporary in-memory shapefile.
    rasters : str, os.PathLike, or list/tuple of those
        Raster(s) to clip. All rasters are assumed to share the spatial
        reference of the first one (only the first is checked).
    out_path : str, optional
        Explicit output path. It is used for every raster, so it is only
        meaningful when a single raster is passed.
    out_dir : str, optional
        Directory for outputs when `out_path` is not given. If None, outputs
        are written to GDAL's in-memory filesystem ('/vsimem').
    out_suffix : str
        Suffix appended to the raster base name to build the output name.
    out_prj_shp : str, optional
        Where to write the reprojected clipping vector if its spatial
        reference does not match the rasters'. Defaults to `shp_p` with
        '.shp' replaced by '_prj.shp'.
    raster_ext : optional
        Passed through to `move_meta_files`.
    move_meta : bool
        If True, move each raster's metadata files to the output directory.
    in_mem : bool
        Write outputs to '/vsimem'. This is forced on when `out_dir` is None.
    skip_srs_check : bool
        Skip the spatial-reference comparison and reprojection step.
    overwrite : bool
        If False, existing outputs are left untouched. Their paths are still
        returned.

    Returns
    -------
    str or list of str
        The output path as a str if exactly one path results, otherwise a
        list of output paths (one per input raster).
    """
    # TODO: Fix permission error if out_prj_shp not supplied -- create in-mem
    #  OGR?
    # Use in memory directory if specified
    if out_dir is None:
        in_mem = True
    if in_mem:
        out_dir = r'/vsimem'

    # Normalize inputs to plain strings. pathlib.Path objects must be
    # converted: Path.replace() is a file *rename*, not a string substitution.
    shp_p = os.fspath(shp_p)
    if isinstance(rasters, (list, tuple)):
        rasters = [os.fspath(r) for r in rasters]
    else:
        rasters = [os.fspath(rasters)]
    if not rasters:
        logger.warning('No rasters provided to clip.')
        return []

    # Check that spatial references match; if not, reproject the clipping
    # vector (assumes all rasters have the same projection as the first).
    # TODO: support different extension (slow to check all of them in the loop below)
    check_raster = rasters[0]
    reprojected = False
    if not skip_srs_check:
        logger.debug('Checking spatial reference match:\n{}\n{}'.format(shp_p, check_raster))
        sr_match = check_sr(shp_p, check_raster)
        if not sr_match:
            logger.debug('Spatial references do not match. Reprojecting to AOI...')
            if not out_prj_shp:
                # NOTE(review): for non-.shp inputs (e.g. a GeoPackage layer
                # path) this leaves the path unchanged -- pass out_prj_shp
                # explicitly in that case.
                out_prj_shp = shp_p.replace('.shp', '_prj.shp')
            shp_p = ogr_reproject(shp_p,
                                  to_sr=get_raster_sr(check_raster),
                                  output_shp=out_prj_shp)
            reprojected = True

    _shp_driver, shp_layer = detect_ogr_driver(shp_p)
    shp = read_vec(shp_p)
    if len(shp) > 1:
        logger.debug('Dissolving clipping shape with multiple features...')
        shp['dissolve'] = 1
        shp = shp.dissolve(by='dissolve')
        shp_p = r'/vsimem/clip_shp_dissolve.shp'
        shp.to_file(shp_p)
        # The cutline is now a standalone shapefile, so any layer name
        # detected for the original source (e.g. a GeoPackage layer) no
        # longer applies and must not be passed to GDAL.
        shp_layer = None

    # Cutline source is the same for every raster; build it once.
    if shp_layer is not None:
        # Layered source: GDAL opens the parent datasource and selects the layer.
        cutline_kwargs = {'cutlineDSName': str(Path(shp_p).parent),
                          'cutlineLayer': shp_layer}
    else:
        cutline_kwargs = {'cutlineDSName': str(shp_p)}

    # Do the 'warping' / clipping
    warped = []
    for raster_p in rasters:
        # TODO: Handle this with platform.sys and pathlib.Path objects
        raster_p = raster_p.replace(r'\\', os.sep)
        raster_p = raster_p.replace(r'/', os.sep)

        # Create out_path if not provided
        if out_path:
            raster_out_path = out_path
        else:
            if not out_dir:
                logger.debug('NO OUT_DIR')
            raster_out_name = '{}{}.tif'.format(
                os.path.basename(raster_p).split('.')[0], out_suffix)
            raster_out_path = os.path.join(out_dir, raster_out_name)

        # Clip to shape
        logger.info('Clipping:\n{}\n\t---> '
                    '{}'.format(os.path.basename(raster_p),
                                raster_out_path))
        if os.path.exists(raster_out_path) and not overwrite:
            logger.warning('Outpath exists, skipping: '
                           '{}'.format(raster_out_path))
        else:
            raster_ds = gdal.Open(raster_p, gdal.GA_ReadOnly)
            geotransform = raster_ds.GetGeoTransform()
            # Keep the source pixel size (geotransform[1] / [5]).
            warp_options = gdal.WarpOptions(cropToCutline=True,
                                            targetAlignedPixels=True,
                                            xRes=geotransform[1],
                                            yRes=geotransform[5],
                                            **cutline_kwargs)
            gdal.Warp(raster_out_path, raster_ds, options=warp_options)
            # Close the raster
            raster_ds = None
            logger.debug('Clipped raster created at {}'.format(raster_out_path))

        # Return the output path whether it was just written or already
        # existed, so callers always get one path per input raster.
        warped.append(raster_out_path)

        # Move meta-data files if specified
        if move_meta:
            logger.debug('Moving metadata files to clip destination...')
            move_meta_files(raster_p, out_dir, raster_ext=raster_ext)

    # Remove the reprojected vector, but only if this call created it.
    if in_mem and reprojected:
        remove_shp(out_prj_shp)

    # If only one raster provided, just return the single path as str
    if len(warped) == 1:
        warped = warped[0]

    return warped