示例#1
0
    def split_raster(rs, split_shp, field_name, temp_dir):
        """Split raster(s) by the polygons of a shapefile, one output per feature.

        For every feature of `split_shp` and every raster in `rs`, ``gdalwarp``
        crops the raster to that feature (``-crop_to_cutline`` with a
        ``-cwhere`` filter on `field_name`) and writes nodata as -9999. Outputs
        are named ``<raster base name>_<field value><raster extension>`` with
        spaces in the field value replaced by underscores.

        Args:
            rs: list of origin raster files; a single path string is also
                accepted and treated as a one-element list.
            split_shp: boundary (ESRI Shapefile) used to split raster.
            field_name: field whose value identifies each split polygon.
            temp_dir: directory to store the split rasters; it is passed to
                ``UtilClass.rmmkdir`` before splitting (presumably removed and
                recreated).

        Raises:
            RuntimeError: if `split_shp` cannot be opened.

        Note:
            The exit status of ``gdalwarp`` is not checked.
        """
        if isinstance(rs, (str, bytes)):
            # A bare path would otherwise be iterated character by character.
            rs = [rs]
        UtilClass.rmmkdir(temp_dir)
        ds = ogr_Open(split_shp)
        if ds is None:
            raise RuntimeError('Failed to open the ESRI Shapefile: %s (missing '
                               'or no read permission)!' % split_shp)
        lyr = ds.GetLayer(0)
        lyr.ResetReading()
        ft = lyr.GetNextFeature()
        while ft:
            cur_field_name = ft.GetFieldAsString(field_name)
            suffix = cur_field_name.replace(' ', '_')
            # OGR SQL: identifiers are double-quoted, string literals are
            # single-quoted, and a single quote inside a literal is doubled.
            where_clause = '"%s" = \'%s\'' % (field_name,
                                              cur_field_name.replace("'", "''"))
            for r in rs:
                # splitext keeps output names unique per feature for any
                # extension (the old '.tif' replace left e.g. '.img' untouched).
                base, ext = os.path.splitext(os.path.basename(r))
                outraster = os.path.join(temp_dir, '%s_%s%s' % (base, suffix, ext))
                subprocess.call(['gdalwarp', r, outraster, '-cutline', split_shp,
                                 '-crop_to_cutline', '-cwhere', where_clause,
                                 '-dstnodata', '-9999'])
            ft = lyr.GetNextFeature()
        ds = None
示例#2
0
文件: raster.py 项目: crazyzlj/PyGeoC
    def split_raster(rs, split_shp, field_name, temp_dir):
        """Split raster(s) by the polygons of a shapefile, one output per feature.

        For every feature of `split_shp` and every raster in `rs`, ``gdalwarp``
        crops the raster to that feature (``-crop_to_cutline`` with a
        ``-cwhere`` filter on `field_name`) and writes nodata as -9999. Outputs
        are named ``<raster base name>_<field value><raster extension>`` with
        spaces in the field value replaced by underscores.

        Args:
            rs: list of origin raster files; a single path string is also
                accepted and treated as a one-element list.
            split_shp: boundary (ESRI Shapefile) used to split raster.
            field_name: field whose value identifies each split polygon.
            temp_dir: directory to store the split rasters; it is passed to
                ``UtilClass.rmmkdir`` before splitting (presumably removed and
                recreated).

        Raises:
            RuntimeError: if `split_shp` cannot be opened.

        Note:
            The exit status of ``gdalwarp`` is not checked.
        """
        if isinstance(rs, (str, bytes)):
            # A bare path would otherwise be iterated character by character.
            rs = [rs]
        UtilClass.rmmkdir(temp_dir)
        ds = ogr_Open(split_shp)
        if ds is None:
            raise RuntimeError('Failed to open the ESRI Shapefile: %s (missing '
                               'or no read permission)!' % split_shp)
        lyr = ds.GetLayer(0)
        lyr.ResetReading()
        ft = lyr.GetNextFeature()
        while ft:
            cur_field_name = ft.GetFieldAsString(field_name)
            suffix = cur_field_name.replace(' ', '_')
            # OGR SQL: identifiers are double-quoted, string literals are
            # single-quoted, and a single quote inside a literal is doubled.
            where_clause = '"%s" = \'%s\'' % (field_name,
                                              cur_field_name.replace("'", "''"))
            for r in rs:
                # splitext keeps output names unique per feature for any
                # extension (the old '.tif' replace left e.g. '.img' untouched).
                base, ext = os.path.splitext(os.path.basename(r))
                outraster = os.path.join(temp_dir, '%s_%s%s' % (base, suffix, ext))
                subprocess.call(['gdalwarp', r, outraster, '-cutline', split_shp,
                                 '-crop_to_cutline', '-cwhere', where_clause,
                                 '-dstnodata', '-9999'])
            ft = lyr.GetNextFeature()
        ds = None
示例#3
0
 def ogrwkt2shapely(input_shape, id_field):
     """Return shape objects list and ids list.

     Args:
         input_shape: path of the vector file (e.g., ESRI Shapefile) to read.
         id_field: name of the attribute field collected for each feature.

     Returns:
         tuple (shapely_objects, id_list): shapely geometries built from each
         feature's WKT and the corresponding `id_field` values, in the order
         the layer yields its features.

     Raises:
         RuntimeError: if `input_shape` cannot be opened.
     """
     # CAUTION, IMPORTANT
     # Because shapely is dependent on sqlite, and the version is not consistent
     #    with GDAL executable (e.g., located in C:\GDAL_x64\bin), thus the shapely
     #    must be locally imported here.
     from shapely.wkt import loads as shapely_loads
     shapely_objects = list()
     id_list = list()
     shp = ogr_Open(input_shape)
     if shp is None:
         raise RuntimeError('The input ESRI Shapefile: %s is not existed or has '
                            'no read permission!' % input_shape)
     lyr = shp.GetLayer()
     # Normalize the field name and resolve its index once, not per feature.
     if isinstance(id_field, text_type):
         id_field = str(id_field)
     id_index = lyr.GetLayerDefn().GetFieldIndex(id_field)
     # Read features sequentially instead of GetFeature(n) for n in
     # range(GetFeatureCount()): that assumes FIDs are exactly 0..count-1, which
     # is not guaranteed (e.g., drivers whose FIDs start at 1, or shapefiles
     # holding deleted but not yet repacked records).
     lyr.ResetReading()
     feat = lyr.GetNextFeature()
     while feat is not None:
         # This function may print Failed `CDLL(/opt/local/lib/libgeos_c.dylib)` in macOS
         # Don't worry about that!
         shapely_objects.append(shapely_loads(feat.geometry().ExportToWkt()))
         id_list.append(feat.GetField(id_index))
         feat = lyr.GetNextFeature()
     return shapely_objects, id_list
 def ogrwkt2shapely(input_shape, id_field):
     """Return shape objects list and ids list.

     Args:
         input_shape: path of the vector file (e.g., ESRI Shapefile) to read.
         id_field: name of the attribute field collected for each feature.

     Returns:
         tuple (shapely_objects, id_list): shapely geometries built from each
         feature's WKT and the corresponding `id_field` values, in the order
         the layer yields its features.

     Raises:
         RuntimeError: if `input_shape` cannot be opened.
     """
     # CAUTION, IMPORTANT
     # Because shapely is dependent on sqlite, and the version is not consistent
     #    with GDAL executable (e.g., located in C:\GDAL_x64\bin), thus the shapely
     #    must be locally imported here.
     from shapely.wkt import loads as shapely_loads
     shapely_objects = list()
     id_list = list()
     shp = ogr_Open(input_shape)
     if shp is None:
         raise RuntimeError(
             'The input ESRI Shapefile: %s is not existed or has '
             'no read permission!' % input_shape)
     lyr = shp.GetLayer()
     # Normalize the field name and resolve its index once, not per feature.
     if isinstance(id_field, text_type):
         id_field = str(id_field)
     id_index = lyr.GetLayerDefn().GetFieldIndex(id_field)
     # Read features sequentially instead of GetFeature(n) for n in
     # range(GetFeatureCount()): that assumes FIDs are exactly 0..count-1, which
     # is not guaranteed (e.g., drivers whose FIDs start at 1, or shapefiles
     # holding deleted but not yet repacked records).
     lyr.ResetReading()
     feat = lyr.GetNextFeature()
     while feat is not None:
         # This function may print Failed `CDLL(/opt/local/lib/libgeos_c.dylib)` in macOS
         # Don't worry about that!
         shapely_objects.append(shapely_loads(feat.geometry().ExportToWkt()))
         id_list.append(feat.GetField(id_index))
         feat = lyr.GetNextFeature()
     return shapely_objects, id_list
    def add_group_field(shp_file, subbasin_field_name, group_metis_dict):
        """Add group information to subbasin ESRI shapefile.

        For each entry ``n`` at position ``i`` of the 'group' list, every
        subbasin listed in `group_metis_dict` gets GROUP = n,
        KMETIS = d['kmetis'][i] and PMETIS = d['pmetis'][i]. The shapefile is
        then copied to ``<shp_file without extension>_<n>.shp``. Afterwards
        `shp_file` itself holds the values of the last group. Missing GROUP,
        KMETIS and PMETIS fields are created as integer fields. Nothing happens
        if `group_metis_dict` is empty.

        Args:
            shp_file: Subbasin Shapefile, opened for update.
            subbasin_field_name: field name of subbasin ID.
            group_metis_dict: returned by func`metis_partition`; maps subbasin
                ID -> dict with 'kmetis' and 'pmetis' lists indexed in parallel
                with the 'group' list. The 'group' list is read from subbasin 1,
                or from the first entry if there is no subbasin 1.

        Raises:
            RuntimeError: if `shp_file` cannot be opened for update.
        """
        if not group_metis_dict:
            return
        ds_reach = ogr_Open(shp_file, update=True)
        if ds_reach is None:
            raise RuntimeError('Failed to open the ESRI Shapefile: %s for '
                               'update!' % shp_file)
        layer_reach = ds_reach.GetLayer(0)
        layer_def = layer_reach.GetLayerDefn()
        icode = layer_def.GetFieldIndex(str(subbasin_field_name))
        grp_fld = str(ImportReaches2Mongo._GROUP)
        kmetis_fld = str(ImportReaches2Mongo._KMETIS)
        pmetis_fld = str(ImportReaches2Mongo._PMETIS)

        # Check all three fields first, then create whichever are missing.
        missing = [fld for fld in (grp_fld, kmetis_fld, pmetis_fld)
                   if layer_def.GetFieldIndex(fld) < 0]
        for fld in missing:
            layer_reach.CreateField(ogr_FieldDefn(fld, OFTInteger))

        # subbasin ID -> feature
        ftmap = dict()
        layer_reach.ResetReading()
        ft = layer_reach.GetNextFeature()
        while ft is not None:
            ftmap[ft.GetFieldAsInteger(icode)] = ft
            ft = layer_reach.GetNextFeature()

        if 1 in group_metis_dict:
            groups = group_metis_dict[1]['group']
        else:
            # NOTE(review): assumes every entry carries the same 'group' list.
            groups = next(iter(group_metis_dict.values()))['group']

        prefix = os.path.splitext(shp_file)[0]
        for i, n in enumerate(groups):
            for node, d in group_metis_dict.items():
                feature = ftmap[node]
                feature.SetField(grp_fld, n)
                feature.SetField(kmetis_fld, d['kmetis'][i])
                feature.SetField(pmetis_fld, d['pmetis'][i])
                layer_reach.SetFeature(feature)
            # Flush pending writes (and any newly created fields) before
            # copying. Otherwise the copy may not contain this group's values.
            layer_reach.SyncToDisk()
            FileClass.copy_files(shp_file, prefix + "_" + str(n) + ".shp")

        layer_reach.SyncToDisk()
        ds_reach.Destroy()
        del ds_reach
示例#6
0
    def add_channel_width_to_shp(reach_shp_file,
                                 stream_link_file,
                                 width_data,
                                 default_depth=1.5):
        """Add channel/reach width and default depth to ESRI shapefile.

        The width of a reach is the mean of `width_data` over all non-nodata
        cells of the stream link raster whose (int-truncated) value equals the
        reach's LINKNO. Reaches without such cells get width 1. Every reach gets
        depth `default_depth`. WIDTH and DEPTH fields are created if missing.

        Args:
            reach_shp_file: reach ESRI shapefile, updated in place.
            stream_link_file: stream link raster whose cell values are reach IDs.
            width_data: 2D grid indexed as ``width_data[row][col]`` with the same
                row/col as the stream link raster (alignment is assumed, not
                checked here).
            default_depth: depth value written to every reach.

        Raises:
            RuntimeError: if `reach_shp_file` cannot be opened for update.
        """
        stream_link = RasterUtilClass.read_raster(stream_link_file)
        n_rows = stream_link.nRows
        n_cols = stream_link.nCols
        nodata_value = stream_link.noDataValue
        data_stream = stream_link.data

        ch_width_dic = dict()
        ch_num_dic = dict()

        for i in range(n_rows):
            for j in range(n_cols):
                if abs(data_stream[i][j] - nodata_value) > UTIL_ZERO:
                    tmpid = int(data_stream[i][j])
                    ch_num_dic[tmpid] = ch_num_dic.get(tmpid, 0) + 1
                    ch_width_dic[tmpid] = ch_width_dic.get(tmpid, 0) + width_data[i][j]

        # float() avoids floor division under Python 2 for integer width data.
        for k in ch_num_dic:
            ch_width_dic[k] = float(ch_width_dic[k]) / ch_num_dic[k]

        # add channel width_data field to reach shp file
        ds_reach = ogr_Open(reach_shp_file, update=True)
        if ds_reach is None:
            raise RuntimeError('Failed to open the ESRI Shapefile: %s for '
                               'update!' % reach_shp_file)
        layer_reach = ds_reach.GetLayer(0)
        layer_def = layer_reach.GetLayerDefn()
        i_link = layer_def.GetFieldIndex(ImportReaches2Mongo._LINKNO)
        i_width = layer_def.GetFieldIndex(ImportReaches2Mongo._WIDTH)
        i_depth = layer_def.GetFieldIndex(ImportReaches2Mongo._DEPTH)
        if i_width < 0:
            new_field = ogr_FieldDefn(ImportReaches2Mongo._WIDTH, OFTReal)
            layer_reach.CreateField(new_field)
        if i_depth < 0:
            new_field = ogr_FieldDefn(ImportReaches2Mongo._DEPTH, OFTReal)
            layer_reach.CreateField(new_field)

        layer_reach.ResetReading()
        ft = layer_reach.GetNextFeature()
        while ft is not None:
            tmpid = ft.GetFieldAsInteger(i_link)
            # O(1) dict lookup; the old `in list(keys())` scanned every reach.
            ft.SetField(ImportReaches2Mongo._WIDTH, ch_width_dic.get(tmpid, 1))
            ft.SetField(ImportReaches2Mongo._DEPTH, default_depth)
            layer_reach.SetFeature(ft)
            ft = layer_reach.GetNextFeature()

        layer_reach.SyncToDisk()
        ds_reach.Destroy()
        del ds_reach
示例#7
0
    def add_group_field(shp_file, subbasin_field_name, n, group_kmetis,
                        group_pmetis, ns):
        """Write kmetis/pmetis group IDs into a subbasin shapefile and copy it.

        The k-th node of `ns` receives GROUP = group_kmetis[k] and
        PMETIS = group_pmetis[k] on the feature whose subbasin field equals that
        node. Missing GROUP/PMETIS fields are created as integer fields. After
        the dataset is flushed and closed, the shapefile is copied to
        ``<shp_file without extension>_<n>.shp``.

        Args:
            shp_file: Subbasin Shapefile, opened for update.
            subbasin_field_name: field name of subbasin.
            n: divide number, used in the copied file name.
            group_kmetis: kmetis group per node, parallel to `ns`.
            group_pmetis: pmetis group per node, parallel to `ns`.
            ns: a list of the nodes (subbasin IDs) in the graph.

        Returns:
            (group_dic, group_dic_pmetis): node -> kmetis group and
            node -> pmetis group.
        """
        ds_reach = ogr_Open(shp_file, update=True)
        layer_reach = ds_reach.GetLayer(0)
        layer_def = layer_reach.GetLayerDefn()
        code_idx = layer_def.GetFieldIndex(subbasin_field_name)
        absent = [fld_name for fld_name in (ImportReaches2Mongo._GROUP,
                                            ImportReaches2Mongo._PMETIS)
                  if layer_def.GetFieldIndex(fld_name) < 0]
        for fld_name in absent:
            layer_reach.CreateField(ogr_FieldDefn(fld_name, OFTInteger))

        # subbasin code -> feature
        features_by_code = {}
        layer_reach.ResetReading()
        feature = layer_reach.GetNextFeature()
        while feature is not None:
            features_by_code[feature.GetFieldAsInteger(code_idx)] = feature
            feature = layer_reach.GetNextFeature()

        group_dic = {}
        group_dic_pmetis = {}
        for idx, node in enumerate(ns):
            kgrp = group_kmetis[idx]
            pgrp = group_pmetis[idx]
            group_dic[node] = kgrp
            group_dic_pmetis[node] = pgrp
            target = features_by_code[node]
            target.SetField(ImportReaches2Mongo._GROUP, kgrp)
            target.SetField(ImportReaches2Mongo._PMETIS, pgrp)
            layer_reach.SetFeature(target)

        layer_reach.SyncToDisk()
        ds_reach.Destroy()
        del ds_reach

        # Keep a per-divide-number snapshot of the updated shapefile.
        dstfile = ''.join([os.path.splitext(shp_file)[0], "_", str(n), ".shp"])
        FileClass.copy_files(shp_file, dstfile)
        return group_dic, group_dic_pmetis
示例#8
0
    def add_channel_width_to_shp(reach_shp_file, stream_link_file,
                                 width_data, default_depth=1.5):
        """Add channel/reach width and default depth to ESRI shapefile.

        The width of a reach is the mean of `width_data` over all non-nodata
        cells of the stream link raster whose (int-truncated) value equals the
        reach's LINKNO. Reaches without such cells get width 1. Every reach gets
        depth `default_depth`. WIDTH and DEPTH fields are created if missing.

        Args:
            reach_shp_file: reach ESRI shapefile, updated in place.
            stream_link_file: stream link raster whose cell values are reach IDs.
            width_data: 2D grid indexed as ``width_data[row][col]`` with the same
                row/col as the stream link raster (alignment is assumed, not
                checked here).
            default_depth: depth value written to every reach.

        Raises:
            RuntimeError: if `reach_shp_file` cannot be opened for update.
        """
        stream_link = RasterUtilClass.read_raster(stream_link_file)
        n_rows = stream_link.nRows
        n_cols = stream_link.nCols
        nodata_value = stream_link.noDataValue
        data_stream = stream_link.data

        ch_width_dic = dict()
        ch_num_dic = dict()

        for i in range(n_rows):
            for j in range(n_cols):
                if abs(data_stream[i][j] - nodata_value) > UTIL_ZERO:
                    tmpid = int(data_stream[i][j])
                    ch_num_dic[tmpid] = ch_num_dic.get(tmpid, 0) + 1
                    ch_width_dic[tmpid] = ch_width_dic.get(tmpid, 0) + width_data[i][j]

        # float() avoids floor division under Python 2 for integer width data.
        for k in ch_num_dic:
            ch_width_dic[k] = float(ch_width_dic[k]) / ch_num_dic[k]

        # add channel width_data field to reach shp file
        ds_reach = ogr_Open(reach_shp_file, update=True)
        if ds_reach is None:
            raise RuntimeError('Failed to open the ESRI Shapefile: %s for '
                               'update!' % reach_shp_file)
        layer_reach = ds_reach.GetLayer(0)
        layer_def = layer_reach.GetLayerDefn()
        i_link = layer_def.GetFieldIndex(ImportReaches2Mongo._LINKNO)
        i_width = layer_def.GetFieldIndex(ImportReaches2Mongo._WIDTH)
        i_depth = layer_def.GetFieldIndex(ImportReaches2Mongo._DEPTH)
        if i_width < 0:
            new_field = ogr_FieldDefn(ImportReaches2Mongo._WIDTH, OFTReal)
            layer_reach.CreateField(new_field)
        if i_depth < 0:
            new_field = ogr_FieldDefn(ImportReaches2Mongo._DEPTH, OFTReal)
            layer_reach.CreateField(new_field)

        layer_reach.ResetReading()
        ft = layer_reach.GetNextFeature()
        while ft is not None:
            tmpid = ft.GetFieldAsInteger(i_link)
            # O(1) dict lookup; the old `in list(keys())` scanned every reach.
            ft.SetField(ImportReaches2Mongo._WIDTH, ch_width_dic.get(tmpid, 1))
            ft.SetField(ImportReaches2Mongo._DEPTH, default_depth)
            layer_reach.SetFeature(ft)
            ft = layer_reach.GetNextFeature()

        layer_reach.SyncToDisk()
        ds_reach.Destroy()
        del ds_reach
    def read_reach_downstream_info(reach_shp, is_taudem=True):
        # type: (AnyStr, bool) -> Dict[int, Dict[AnyStr, Union[int, float]]]
        """Read the downstream ID and channel attributes of every reach.

        Args:
            reach_shp: reach ESRI shapefile.
            is_taudem: is TauDEM or not, true is default. When False (ArcSWAT),
                the _LINKNO/_DSLINKNO/_SLOPE/_LENGTH field names stored on
                ImportReaches2Mongo are overwritten with ArcSWAT names. This
                changes class-level state for all later callers.

        Returns:
            rch_dict: {stream ID: {'downstream': downstreamID,
                                   'depth': depth value (1.5 if no depth field),
                                   'slope': slope value (MINI_SLOPE if the
                                            'Slope' field is missing or not
                                            greater than MINI_SLOPE),
                                   'width': width value (5. if no width field),
                                   'length': value of the 'Length' field}
                                  }

        Raises:
            RuntimeError: if `reach_shp` cannot be opened.
        """
        rch_dict = dict()

        ds_reach = ogr_Open(reach_shp)
        if ds_reach is None:
            raise RuntimeError('Failed to open the ESRI Shapefile: %s!' % reach_shp)
        layer_reach = ds_reach.GetLayer(0)
        layer_def = layer_reach.GetLayerDefn()
        if not is_taudem:  # For ArcSWAT
            ImportReaches2Mongo._LINKNO = 'FROM_NODE'
            ImportReaches2Mongo._DSLINKNO = 'TO_NODE'
            ImportReaches2Mongo._SLOPE = 'Slo2'  # TauDEM: Slope (tan); ArcSWAT: Slo2 (100*tan)
            ImportReaches2Mongo._LENGTH = 'Len2'  # TauDEM: Length; ArcSWAT: Len2
        ifrom = layer_def.GetFieldIndex(str(ImportReaches2Mongo._LINKNO))
        ito = layer_def.GetFieldIndex(str(ImportReaches2Mongo._DSLINKNO))
        idph = layer_def.GetFieldIndex(str(ImportReaches2Mongo._DEPTH))
        islp = layer_def.GetFieldIndex(str('Slope'))
        iwth = layer_def.GetFieldIndex(str(ImportReaches2Mongo._WIDTH))
        ilen = layer_def.GetFieldIndex(str('Length'))

        ft = layer_reach.GetNextFeature()
        while ft is not None:
            nfrom = ft.GetFieldAsInteger(ifrom)
            nto = ft.GetFieldAsInteger(ito)
            slope = ft.GetFieldAsDouble(islp) if islp >= 0 else MINI_SLOPE
            # GetFieldIndex returns -1 for a missing field; index 0 is a valid
            # field, so "present" is >= 0 (the old `> 0.` skipped field 0).
            rch_dict[nfrom] = {
                'downstream': nto,
                'depth': ft.GetFieldAsDouble(idph) if idph >= 0 else 1.5,
                'slope': slope if slope > MINI_SLOPE else MINI_SLOPE,
                'width': ft.GetFieldAsDouble(iwth) if iwth >= 0 else 5.,
                'length': ft.GetFieldAsDouble(ilen)
            }

            ft = layer_reach.GetNextFeature()

        return rch_dict
    def add_group_field(shp_file, subbasin_field_name, group_metis_dict):
        """Add group information to subbasin ESRI shapefile.

        For each entry ``n`` at position ``i`` of the 'group' list, every
        subbasin listed in `group_metis_dict` gets GROUP = n,
        KMETIS = d['kmetis'][i] and PMETIS = d['pmetis'][i]. The shapefile is
        then copied to ``<shp_file without extension>_<n>.shp``. Afterwards
        `shp_file` itself holds the values of the last group. Missing GROUP,
        KMETIS and PMETIS fields are created as integer fields. Nothing happens
        if `group_metis_dict` is empty.

        Args:
            shp_file: Subbasin Shapefile, opened for update.
            subbasin_field_name: field name of subbasin ID.
            group_metis_dict: returned by func`metis_partition`; maps subbasin
                ID -> dict with 'kmetis' and 'pmetis' lists indexed in parallel
                with the 'group' list. The 'group' list is read from subbasin 1,
                or from the first entry if there is no subbasin 1.

        Raises:
            RuntimeError: if `shp_file` cannot be opened for update.
        """
        if not group_metis_dict:
            return
        ds_reach = ogr_Open(shp_file, update=True)
        if ds_reach is None:
            raise RuntimeError('Failed to open the ESRI Shapefile: %s for '
                               'update!' % shp_file)
        layer_reach = ds_reach.GetLayer(0)
        layer_def = layer_reach.GetLayerDefn()
        icode = layer_def.GetFieldIndex(subbasin_field_name)
        grp_fld = ImportReaches2Mongo._GROUP
        kmetis_fld = ImportReaches2Mongo._KMETIS
        pmetis_fld = ImportReaches2Mongo._PMETIS

        # Check all three fields first, then create whichever are missing.
        missing = [fld for fld in (grp_fld, kmetis_fld, pmetis_fld)
                   if layer_def.GetFieldIndex(fld) < 0]
        for fld in missing:
            layer_reach.CreateField(ogr_FieldDefn(fld, OFTInteger))

        # subbasin ID -> feature
        ftmap = dict()
        layer_reach.ResetReading()
        ft = layer_reach.GetNextFeature()
        while ft is not None:
            ftmap[ft.GetFieldAsInteger(icode)] = ft
            ft = layer_reach.GetNextFeature()

        if 1 in group_metis_dict:
            groups = group_metis_dict[1]['group']
        else:
            # NOTE(review): assumes every entry carries the same 'group' list.
            groups = next(iter(group_metis_dict.values()))['group']

        prefix = os.path.splitext(shp_file)[0]
        for i, n in enumerate(groups):
            for node, d in group_metis_dict.items():
                feature = ftmap[node]
                feature.SetField(grp_fld, n)
                feature.SetField(kmetis_fld, d['kmetis'][i])
                feature.SetField(pmetis_fld, d['pmetis'][i])
                layer_reach.SetFeature(feature)
            # Flush pending writes (and any newly created fields) before
            # copying. Otherwise the copy may not contain this group's values.
            layer_reach.SyncToDisk()
            FileClass.copy_files(shp_file, prefix + "_" + str(n) + ".shp")

        layer_reach.SyncToDisk()
        ds_reach.Destroy()
        del ds_reach
示例#11
0
 def ogrwkt2shapely(input_shape, id_field):
     """Return shape objects list and ids list.

     Args:
         input_shape: path of the vector file (e.g., ESRI Shapefile) to read.
         id_field: name of the attribute field collected for each feature.

     Returns:
         tuple (shapely_objects, id_list): shapely geometries built from each
         feature's WKT and the corresponding `id_field` values, in the order
         the layer yields its features.

     Raises:
         RuntimeError: if `input_shape` cannot be opened.
     """
     from shapely.wkt import loads as shapely_loads
     shapely_objects = []
     id_list = []
     shp = ogr_Open(input_shape)
     if shp is None:
         raise RuntimeError('The input ESRI Shapefile: %s is not existed or has '
                            'no read permission!' % input_shape)
     lyr = shp.GetLayer()
     try:
         # Python 2: convert a unicode field name to a byte string (as before).
         if isinstance(id_field, unicode):
             id_field = id_field.encode()
     except NameError:
         # Python 3: the `unicode` builtin is gone and str is already text.
         pass
     id_index = lyr.GetLayerDefn().GetFieldIndex(id_field)
     # Read features sequentially instead of GetFeature(n) for n in
     # range(GetFeatureCount()): that assumes FIDs are exactly 0..count-1, which
     # is not guaranteed (e.g., drivers whose FIDs start at 1, or shapefiles
     # holding deleted but not yet repacked records).
     lyr.ResetReading()
     feat = lyr.GetNextFeature()
     while feat is not None:
         shapely_objects.append(shapely_loads(feat.geometry().ExportToWkt()))
         id_list.append(feat.GetField(id_index))
         feat = lyr.GetNextFeature()
     return shapely_objects, id_list
    def read_reach_downstream_info(reach_shp, is_taudem=True):
        """Read the downstream ID and channel attributes of every reach.

        Args:
            reach_shp: reach ESRI shapefile.
            is_taudem: is TauDEM or not, true is default. When False (ArcSWAT),
                the _LINKNO/_DSLINKNO/_SLOPE/_LENGTH field names stored on
                ImportReaches2Mongo are overwritten with ArcSWAT names. This
                changes class-level state for all later callers.

        Returns:
            rch_dict: {stream ID: {'downstream': downstreamID,
                                   'depth': depth value (1.5 if no depth field),
                                   'slope': slope value (MINI_SLOPE if the
                                            'Slope' field is missing or not
                                            greater than MINI_SLOPE),
                                   'width': width value (5. if no width field),
                                   'length': value of the 'Length' field}
                                  }

        Raises:
            RuntimeError: if `reach_shp` cannot be opened.
        """
        rch_dict = dict()

        ds_reach = ogr_Open(reach_shp)
        if ds_reach is None:
            raise RuntimeError('Failed to open the ESRI Shapefile: %s!' % reach_shp)
        layer_reach = ds_reach.GetLayer(0)
        layer_def = layer_reach.GetLayerDefn()
        if not is_taudem:  # For ArcSWAT
            ImportReaches2Mongo._LINKNO = 'FROM_NODE'
            ImportReaches2Mongo._DSLINKNO = 'TO_NODE'
            ImportReaches2Mongo._SLOPE = 'Slo2'  # TauDEM: Slope (tan); ArcSWAT: Slo2 (100*tan)
            ImportReaches2Mongo._LENGTH = 'Len2'  # TauDEM: Length; ArcSWAT: Len2
        ifrom = layer_def.GetFieldIndex(ImportReaches2Mongo._LINKNO)
        ito = layer_def.GetFieldIndex(ImportReaches2Mongo._DSLINKNO)
        idph = layer_def.GetFieldIndex(ImportReaches2Mongo._DEPTH)
        islp = layer_def.GetFieldIndex('Slope')
        iwth = layer_def.GetFieldIndex(ImportReaches2Mongo._WIDTH)
        ilen = layer_def.GetFieldIndex('Length')

        ft = layer_reach.GetNextFeature()
        while ft is not None:
            nfrom = ft.GetFieldAsInteger(ifrom)
            nto = ft.GetFieldAsInteger(ito)
            slope = ft.GetFieldAsDouble(islp) if islp >= 0 else MINI_SLOPE
            # GetFieldIndex returns -1 for a missing field; index 0 is a valid
            # field, so "present" is >= 0 (the old `> 0.` skipped field 0).
            rch_dict[nfrom] = {'downstream': nto,
                               'depth': ft.GetFieldAsDouble(idph) if idph >= 0 else 1.5,
                               'slope': slope if slope > MINI_SLOPE else MINI_SLOPE,
                               'width': ft.GetFieldAsDouble(iwth) if iwth >= 0 else 5.,
                               'length': ft.GetFieldAsDouble(ilen)}

            ft = layer_reach.GetNextFeature()

        return rch_dict
示例#13
0
    def serialize_streamnet(streamnet_file, output_reach_file):
        """Eliminate reach with zero length and return the reach ID map.

        `streamnet_file` is copied to `output_reach_file`, and the copy is
        edited in place:
        - reaches whose length is below DELTA are deleted;
        - the surviving LINKNO values are renumbered 1..N in ascending order of
          the original IDs;
        - each DSLINKNO is redirected past any chain of deleted zero-length
          reaches, or set to -1 if the resolved downstream reach does not
          survive.
        Deleted records are then removed with the shapefile REPACK statement.

        Args:
            streamnet_file: original stream net ESRI shapefile
            output_reach_file: serialized stream net, ESRI shapefile

        Returns:
            id pairs {origin: newly assigned}
        """
        FileClass.copy_files(streamnet_file, output_reach_file)
        ds_reach = ogr_Open(output_reach_file, update=True)
        layer_reach = ds_reach.GetLayer(0)
        layer_def = layer_reach.GetLayerDefn()
        i_link = layer_def.GetFieldIndex(FLD_LINKNO)
        i_link_downslope = layer_def.GetFieldIndex(FLD_DSLINKNO)
        i_len = layer_def.GetFieldIndex(REACH_LENGTH)

        kept_ids = set()
        # zero-length reach ID -> its downstream reach ID
        output_dic = {}
        ft = layer_reach.GetNextFeature()
        while ft is not None:
            link_id = ft.GetFieldAsInteger(i_link)
            if link_id not in kept_ids:
                if ft.GetFieldAsDouble(i_len) < DELTA:
                    output_dic[link_id] = ft.GetFieldAsInteger(i_link_downslope)
                else:
                    kept_ids.add(link_id)
            ft = layer_reach.GetNextFeature()

        id_map = {old_id: new_id
                  for new_id, old_id in enumerate(sorted(kept_ids), 1)}

        # change old ID to new ID
        layer_reach.ResetReading()
        ft = layer_reach.GetNextFeature()
        while ft is not None:
            link_id = ft.GetFieldAsInteger(i_link)
            if link_id not in id_map:
                layer_reach.DeleteFeature(ft.GetFID())
                ft = layer_reach.GetNextFeature()
                continue

            # Follow consecutive deleted zero-length reaches (any depth, not
            # just two hops) to the first surviving one. `visited` guards
            # against cycles in the input data.
            ds_id = ft.GetFieldAsInteger(i_link_downslope)
            visited = set()
            while (ds_id not in id_map and ds_id in output_dic
                   and ds_id not in visited):
                visited.add(ds_id)
                ds_id = output_dic[ds_id]

            ft.SetField(FLD_LINKNO, id_map[link_id])
            ft.SetField(FLD_DSLINKNO, id_map.get(ds_id, -1))
            layer_reach.SetFeature(ft)
            ft = layer_reach.GetNextFeature()
        # REPACK needs the actual layer name (the file's base name for a
        # shapefile); a hard-coded 'reach' only worked for reach.shp.
        ds_reach.ExecuteSQL(str('REPACK %s' % layer_reach.GetName()))
        layer_reach.SyncToDisk()
        ds_reach.Destroy()
        del ds_reach
        return id_map
示例#14
0
    def add_channel_width_depth_to_shp(reach_shp_file, stream_link_file,
                                       width_file, depth_file):
        """Average channel width/depth per reach and store them in reach.shp.

        Every non-nodata cell of the stream link raster contributes its width
        and depth raster values to the reach whose ID is the int-truncated cell
        value. The per-reach means are written to the WIDTH and DEPTH fields of
        the feature whose LINKNO matches. These fields are created if absent.
        Features with no contributing cells get width 5. and depth 1.5.

        Args:
            reach_shp_file: reach ESRI shapefile, updated in place.
            stream_link_file: stream link raster whose cell values are reach IDs.
            width_file: channel width raster, read with the same row/col indices
                as the stream link raster (alignment is assumed, not checked).
            depth_file: channel depth raster, read the same way.
        """
        link_raster = RasterUtilClass.read_raster(stream_link_file)
        row_count = link_raster.nRows
        col_count = link_raster.nCols
        link_nodata = link_raster.noDataValue
        link_grid = link_raster.data
        width_grid = RasterUtilClass.read_raster(width_file).data
        depth_grid = RasterUtilClass.read_raster(depth_file).data

        # reach ID -> [cell count, width sum, depth sum]
        totals = {}
        for row in range(row_count):
            for col in range(col_count):
                cell = link_grid[row][col]
                if abs(cell - link_nodata) <= UTIL_ZERO:
                    continue
                acc = totals.setdefault(int(cell), [0, 0, 0])
                acc[0] += 1
                acc[1] += width_grid[row][col]
                acc[2] += depth_grid[row][col]

        # reach ID -> (mean width, mean depth)
        means = {rid: (wsum / cnt, dsum / cnt)
                 for rid, (cnt, wsum, dsum) in totals.items()}

        # add channel width and depth fields to reach shp file or update values if the fields exist
        ds_reach = ogr_Open(reach_shp_file, update=True)
        layer_reach = ds_reach.GetLayer(0)
        layer_def = layer_reach.GetLayerDefn()
        link_idx = layer_def.GetFieldIndex(str(ImportReaches2Mongo._LINKNO))
        width_fld = str(ImportReaches2Mongo._WIDTH)
        depth_fld = str(ImportReaches2Mongo._DEPTH)
        absent = [fld_name for fld_name in (width_fld, depth_fld)
                  if layer_def.GetFieldIndex(fld_name) < 0]
        for fld_name in absent:
            layer_reach.CreateField(ogr_FieldDefn(fld_name, OFTReal))

        layer_reach.ResetReading()
        feature = layer_reach.GetNextFeature()
        while feature is not None:
            width_val, depth_val = means.get(feature.GetFieldAsInteger(link_idx),
                                             (5., 1.5))
            feature.SetField(width_fld, width_val)
            feature.SetField(depth_fld, depth_val)
            layer_reach.SetFeature(feature)
            feature = layer_reach.GetNextFeature()

        layer_reach.SyncToDisk()
        ds_reach.Destroy()
        del ds_reach
示例#15
0
    def add_channel_width_depth_to_shp(reach_shp_file, stream_link_file, width_file, depth_file):
        """Average channel width/depth per reach and store them in reach.shp.

        Every non-nodata cell of the stream link raster contributes its width
        and depth raster values to the reach whose ID is the int-truncated cell
        value. The per-reach means are written to the WIDTH and DEPTH fields of
        the feature whose LINKNO matches. These fields are created if absent.
        Features with no contributing cells get width 5. and depth 1.5.

        Args:
            reach_shp_file: reach ESRI shapefile, updated in place.
            stream_link_file: stream link raster whose cell values are reach IDs.
            width_file: channel width raster, read with the same row/col indices
                as the stream link raster (alignment is assumed, not checked).
            depth_file: channel depth raster, read the same way.
        """
        link_raster = RasterUtilClass.read_raster(stream_link_file)
        row_count = link_raster.nRows
        col_count = link_raster.nCols
        link_nodata = link_raster.noDataValue
        link_grid = link_raster.data
        width_grid = RasterUtilClass.read_raster(width_file).data
        depth_grid = RasterUtilClass.read_raster(depth_file).data

        # reach ID -> [cell count, width sum, depth sum]
        totals = {}
        for row in range(row_count):
            for col in range(col_count):
                cell = link_grid[row][col]
                if abs(cell - link_nodata) <= UTIL_ZERO:
                    continue
                acc = totals.setdefault(int(cell), [0, 0, 0])
                acc[0] += 1
                acc[1] += width_grid[row][col]
                acc[2] += depth_grid[row][col]

        # reach ID -> (mean width, mean depth)
        means = {rid: (wsum / cnt, dsum / cnt)
                 for rid, (cnt, wsum, dsum) in totals.items()}

        # add channel width and depth fields to reach shp file or update values if the fields exist
        ds_reach = ogr_Open(reach_shp_file, update=True)
        layer_reach = ds_reach.GetLayer(0)
        layer_def = layer_reach.GetLayerDefn()
        link_idx = layer_def.GetFieldIndex(ImportReaches2Mongo._LINKNO)
        width_fld = ImportReaches2Mongo._WIDTH
        depth_fld = ImportReaches2Mongo._DEPTH
        absent = [fld_name for fld_name in (width_fld, depth_fld)
                  if layer_def.GetFieldIndex(fld_name) < 0]
        for fld_name in absent:
            layer_reach.CreateField(ogr_FieldDefn(fld_name, OFTReal))

        layer_reach.ResetReading()
        feature = layer_reach.GetNextFeature()
        while feature is not None:
            width_val, depth_val = means.get(feature.GetFieldAsInteger(link_idx),
                                             (5., 1.5))
            feature.SetField(width_fld, width_val)
            feature.SetField(depth_fld, depth_val)
            layer_reach.SetFeature(feature)
            feature = layer_reach.GetNextFeature()

        layer_reach.SyncToDisk()
        ds_reach.Destroy()
        del ds_reach
示例#16
0
    def serialize_streamnet(streamnet_file, output_reach_file):
        """Eliminate reach with zero length and return the reach ID map.

        `streamnet_file` is copied to `output_reach_file`, and the copy is
        edited in place:
        - reaches whose length is below DELTA are deleted;
        - the surviving LINKNO values are renumbered 1..N in ascending order of
          the original IDs;
        - each DSLINKNO is redirected past any chain of deleted zero-length
          reaches, or set to -1 if the resolved downstream reach does not
          survive.
        Deleted records are then removed with the shapefile REPACK statement.

        Args:
            streamnet_file: original stream net ESRI shapefile
            output_reach_file: serialized stream net, ESRI shapefile

        Returns:
            id pairs {origin: newly assigned}
        """
        FileClass.copy_files(streamnet_file, output_reach_file)
        ds_reach = ogr_Open(output_reach_file, update=True)
        layer_reach = ds_reach.GetLayer(0)
        layer_def = layer_reach.GetLayerDefn()
        i_link = layer_def.GetFieldIndex(FLD_LINKNO)
        i_link_downslope = layer_def.GetFieldIndex(FLD_DSLINKNO)
        i_len = layer_def.GetFieldIndex(REACH_LENGTH)

        kept_ids = set()
        # zero-length reach ID -> its downstream reach ID
        output_dic = {}
        ft = layer_reach.GetNextFeature()
        while ft is not None:
            link_id = ft.GetFieldAsInteger(i_link)
            if link_id not in kept_ids:
                if ft.GetFieldAsDouble(i_len) < DELTA:
                    output_dic[link_id] = ft.GetFieldAsInteger(i_link_downslope)
                else:
                    kept_ids.add(link_id)
            ft = layer_reach.GetNextFeature()

        id_map = {old_id: new_id
                  for new_id, old_id in enumerate(sorted(kept_ids), 1)}

        # change old ID to new ID
        layer_reach.ResetReading()
        ft = layer_reach.GetNextFeature()
        while ft is not None:
            link_id = ft.GetFieldAsInteger(i_link)
            if link_id not in id_map:
                layer_reach.DeleteFeature(ft.GetFID())
                ft = layer_reach.GetNextFeature()
                continue

            # Follow consecutive deleted zero-length reaches (any depth, not
            # just two hops) to the first surviving one. `visited` guards
            # against cycles in the input data.
            ds_id = ft.GetFieldAsInteger(i_link_downslope)
            visited = set()
            while (ds_id not in id_map and ds_id in output_dic
                   and ds_id not in visited):
                visited.add(ds_id)
                ds_id = output_dic[ds_id]

            ft.SetField(FLD_LINKNO, id_map[link_id])
            ft.SetField(FLD_DSLINKNO, id_map.get(ds_id, -1))
            layer_reach.SetFeature(ft)
            ft = layer_reach.GetNextFeature()
        # REPACK needs the actual layer name (the file's base name for a
        # shapefile); a hard-coded 'reach' only worked for reach.shp.
        ds_reach.ExecuteSQL("REPACK %s" % layer_reach.GetName())
        layer_reach.SyncToDisk()
        ds_reach.Destroy()
        del ds_reach
        return id_map
示例#17
0
    def down_stream(reach_shp, is_taudem=True):
        """Read a reach shapefile and derive stream topology, attributes and orders.

        Args:
            reach_shp: reach ESRI shapefile.
            is_taudem: True (default) if the shapefile comes from TauDEM;
                False for ArcSWAT, whose link, downstream, slope and length
                fields have different names.

        Returns:
            A 7-tuple of dicts keyed by reach (link) id:
            down_stream_dic: reach id -> downstream reach id
            downstream_up_order_dic: orders assigned from the outlet upstream by
                ImportReaches2Mongo.stream_orders_from_outlet_up, then reversed
                (new = max_order - old + 1)
            upstream_down_order_dic: orders assigned from the source reaches
                downstream (sources get 1)
            depth_dic: reach depth (1 if the depth field is absent)
            slope_dic: reach slope, never below MINI_SLOPE (MINI_SLOPE if the
                slope field is absent)
            width_dic: reach width (10 if the width field is absent)
            len_dic: reach length

            Both order dicts only contain reaches that take part in at least one
            link (i.e. have a positive downstream id or receive an upstream reach).

        Raises:
            ValueError: if no outlet reach can be found, or if the linked reaches
                contain a cycle (the source-ordering pass could not terminate).
        """
        down_stream_dic = {}
        depth_dic = {}
        slope_dic = {}
        width_dic = {}
        len_dic = {}
        ds_reach = ogr_Open(reach_shp)
        layer_reach = ds_reach.GetLayer(0)
        layer_def = layer_reach.GetLayerDefn()
        if not is_taudem:  # For ArcSWAT
            # NOTE(review): this rebinds class-level attributes, so the ArcSWAT
            # field names stay in effect for every later call (including calls
            # with is_taudem=True) and for other code reading these attributes.
            ImportReaches2Mongo._LINKNO = 'FROM_NODE'
            ImportReaches2Mongo._DSLINKNO = 'TO_NODE'
            ImportReaches2Mongo._SLOPE = 'Slo2'  # TauDEM: Slope (tan); ArcSWAT: Slo2 (100*tan)
            ImportReaches2Mongo._LENGTH = 'Len2'  # TauDEM: Length; ArcSWAT: Len2
        i_from = layer_def.GetFieldIndex(ImportReaches2Mongo._LINKNO)
        i_to = layer_def.GetFieldIndex(ImportReaches2Mongo._DSLINKNO)
        i_depth = layer_def.GetFieldIndex(ImportReaches2Mongo._DEPTH)
        i_slope = layer_def.GetFieldIndex(ImportReaches2Mongo._SLOPE)
        i_width = layer_def.GetFieldIndex(ImportReaches2Mongo._WIDTH)
        i_len = layer_def.GetFieldIndex(ImportReaches2Mongo._LENGTH)

        g = nx.DiGraph()
        ft = layer_reach.GetNextFeature()
        while ft is not None:
            node_from = ft.GetFieldAsInteger(i_from)
            node_to = ft.GetFieldAsInteger(i_to)

            # Optional attributes fall back to defaults when their field index
            # is -1 (field not present in the layer).
            if i_depth > -1:
                depth_dic[node_from] = ft.GetFieldAsDouble(i_depth)
            else:
                depth_dic[node_from] = 1

            # Guard on the slope field's own index, not the depth field's.
            if i_slope > -1:
                slope = ft.GetFieldAsDouble(i_slope)
                slope_dic[node_from] = MINI_SLOPE if slope < MINI_SLOPE else slope
            else:
                slope_dic[node_from] = MINI_SLOPE

            if i_width > -1:
                width_dic[node_from] = ft.GetFieldAsDouble(i_width)
            else:
                width_dic[node_from] = 10

            len_dic[node_from] = ft.GetFieldAsDouble(i_len)
            down_stream_dic[node_from] = node_to
            # Only positive downstream ids become graph edges; non-positive
            # ids (e.g. -1) are treated as "no downstream reach".
            if node_to > 0:
                g.add_edge(node_from, node_to)
            ft = layer_reach.GetNextFeature()
        # All attributes have been read; release the OGR data source.
        ds_reach = None

        # find outlet subbasin: a node with no downstream edge
        # (if several exist, the last one visited wins)
        outlet = -1
        for node in g.nodes():
            if g.out_degree(node) == 0:
                outlet = node
        if outlet < 0:
            raise ValueError("Can't find outlet subbasin ID, please check!")
        print('outlet subbasin:%d' % outlet)

        # assign order from outlet to upstream subbasins
        downstream_up_order_dic = {}
        ImportReaches2Mongo.stream_orders_from_outlet_up(
            downstream_up_order_dic, g, outlet, 1)
        # reverse the order numbers: the largest order becomes 1
        max_order = max([0] + list(downstream_up_order_dic.values()))
        for k, v in downstream_up_order_dic.items():
            downstream_up_order_dic[k] = max_order - v + 1

        # assign order from the source subbasins: repeatedly peel off the
        # reaches that have no remaining upstream reach (this consumes g)
        upstream_down_order_dic = dict()
        order_num = 1
        while g.number_of_nodes() > 0:
            sources = [node for node in g.nodes() if g.in_degree(node) == 0]
            if not sources:
                # Every remaining reach still has an upstream reach, so the
                # remaining graph contains a cycle and peeling cannot finish.
                raise ValueError('Cycle detected in reach network, please check!')
            for node in sources:
                upstream_down_order_dic[node] = order_num
            g.remove_nodes_from(sources)
            order_num += 1

        return (down_stream_dic, downstream_up_order_dic,
                upstream_down_order_dic, depth_dic, slope_dic, width_dic,
                len_dic)