def __add_dummy(self, ds: ogr.DataSource) -> None:
    '''Append the 'LAND_DUMMY' layer to ``ds`` when the current chart needs it.

    A dummy layer is required only for the 'ctmed', 'cpmed' and 'oimed'
    charts, and only while ``ds`` holds an even number of layers.
    ``self.need_dummy`` records the decision. When required, the layer is
    built from ``self.aoi`` by ``make_dummy``. Its ``self.const['burn_field']``
    field is filled with ``self.const['land']``. The layer is then copied into
    ``ds`` with OVERWRITE=YES.
    '''
    dummy_charts = ['ctmed', 'cpmed', 'oimed']
    has_even_layers = ds.GetLayerCount() % 2 == 0
    # `and` short-circuits, so self.chart is only read for even layer counts.
    self.need_dummy = has_even_layers and self.chart in dummy_charts
    if not self.need_dummy:
        return

    burn = (self.const['burn_field'], self.const['land'])
    dummy_ds = make_dummy(self.aoi, 'LAND_DUMMY', burn)
    dummy_layer = dummy_ds.GetLayer()
    ds.CopyLayer(dummy_layer, dummy_layer.GetName(), ['OVERWRITE=YES'])
    # Drop our reference to the temporary dummy datasource.
    dummy_ds = None
def is_datasource_valid(self, ds: ogr.DataSource) -> bool:
    '''Report whether ``ds`` exposes a temporal field of a datetime type.

    Only single-layer datasources are supported; any other layer count is
    treated as invalid. The first field whose lower-cased name equals
    ``self.temporal_field`` decides the outcome. The datasource is valid if
    that field's OGR type is one of ``self.DATETIME_TYPES``. Returns False
    when no field matches.
    '''
    if ds.GetLayerCount() != 1:
        # TODO: add support for multi-layer datasources
        return False

    layer_defn = ds.GetLayer(0).GetLayerDefn()
    for idx in range(layer_defn.GetFieldCount()):
        field_defn = layer_defn.GetFieldDefn(idx)
        if field_defn.GetName().lower() != self.temporal_field:
            continue
        return field_defn.GetType() in self.DATETIME_TYPES
    return False
def close_shapefile(shp: ogr.DataSource) -> None:
    '''Release (close) an OGR datasource.

    The datasource name is read before ``Release()`` is called. This lets it
    be logged and reported even if releasing fails. A datasource whose
    release has already been attempted may no longer be safe to query.

    Args:
        shp: the datasource to release.

    Raises:
        baseException: with ``baseException.ERR_CODE_LEVEL`` when
            ``Release()`` raises; the original error is passed along and
            chained as the cause.
    '''
    name = shp.GetName()
    logging.info(f'closing shapefile {name}')
    try:
        shp.Release()
    except Exception as e:
        raise baseException(f'shapefile: {name} could not be closed',
                            baseException.ERR_CODE_LEVEL, e) from e
def __rasterize(self, ds: ogr.DataSource) -> gdal.Dataset:
    '''Rasterize every layer of ``ds`` into its own band of one raster.

    The raster is created at '/vsimem/rasters.tif' (GDAL's in-memory
    filesystem) by ``create_tif``. It covers ``self.extent`` in ``self.srs``
    at ``self.pixel_size``, and receives the layer count as its last argument
    (presumably the band count). Its nodata value is set to
    ``self.const['nodata']``. Layer k (0-based) is passed through
    ``reproject`` with ``self.srs``. It is then rasterized into band k + 1
    using ``self.const['burn_field']``.

    Returns:
        The raster dataset produced by ``create_tif``.
    '''
    raster = create_tif('/vsimem/rasters.tif', self.extent, self.srs,
                        self.pixel_size, ds.GetLayerCount())
    set_nodata(raster, self.const['nodata'])
    # GDAL band indices are 1-based.
    for band_index, layer in enumerate(ds, start=1):
        reproject(layer, self.srs)
        rasterize(layer, raster, band_index, self.const['burn_field'])
    return raster
def update_from_ogr_data_source(self, ds: ogr.DataSource) -> None:
    '''Refresh layer metadata from a single-layer OGR datasource.

    Sets ``self.layer_name`` and ``self.fields`` (all field names, in schema
    order). ``self.time_field_idx`` is set to the index of a field meeting
    both conditions (the last such field wins):

    - its lower-cased name is in ``self.DATETIME_FIELDS``;
    - its OGR type is *not* in ``self.DATETIME_TYPES``.

    Datasources with any layer count other than one are ignored, and the
    object is left unchanged.
    '''
    if ds.GetLayerCount() != 1:
        # TODO: add support for multi-layer datasources
        return

    layer = ds.GetLayer(0)
    layer_defn = layer.GetLayerDefn()
    self.layer_name = layer.GetName()

    names = []
    for idx in range(layer_defn.GetFieldCount()):
        field_defn = layer_defn.GetFieldDefn(idx)
        name = field_defn.GetName()
        names.append(name)
        # NOTE(review): the index is recorded only when the type is NOT a
        # datetime type — the opposite of is_datasource_valid's check.
        # Confirm this is intentional (e.g. text dates that need parsing).
        is_time_name = name.lower() in self.DATETIME_FIELDS
        if is_time_name and field_defn.GetType() not in self.DATETIME_TYPES:
            self.time_field_idx = idx
    self.fields = names
def ogr_to_ogr(
    bbox: BBOX,
    src_layers: List[ogr.Layer],
    dst_datasource: ogr.DataSource,
    dst_layer_name: str,
    dst_crs_code: str,
) -> None:
    """Clip source layers to ``bbox`` and merge them into one new layer.

    Processing steps:

    - Layer creation: the destination layer ``dst_layer_name`` is created in
      ``dst_datasource``. Its spatial reference is the EPSG code taken from
      ``dst_crs_code``; only the integer after the last ``:`` is used, e.g.
      ``"EPSG:4326"``. Its geometry type and attribute fields are copied from
      the first source layer that has geometries.
    - Clipping: for every geometry-bearing source layer, ``bbox`` is
      converted to that layer's CRS and used as a spatial filter. Each
      filtered feature's geometry is intersected with it and the result is
      transformed to the destination CRS. A clone of the feature carrying
      that geometry is then written.
    - Skipping: layers without geometries, features without a geometry, and
      features whose clipped geometry is empty (or could not be computed)
      are skipped.

    If no source layer has geometries, no destination layer is created.

    Args:
        bbox: clip extent; ``bbox.transform_as_geom("AUTH:CODE")`` must
            return it as an OGR geometry in that CRS.
        src_layers: layers to clip and merge, processed in order.
        dst_datasource: datasource that receives the new layer.
        dst_layer_name: name of the layer to create.
        dst_crs_code: ``"AUTHORITY:CODE"`` string for the output CRS.
    """
    dst_srs = ogr.osr.SpatialReference()
    dst_srs.ImportFromEPSG(int(dst_crs_code.split(":")[-1]))

    # Created lazily from the first layer that actually has geometries.
    # Keying creation on the loop index left the layer unbound whenever the
    # first source layer was skipped.
    dst_layer = None
    for src_layer in src_layers:
        if src_layer.GetGeomType() == ogr.wkbNone:
            logging.debug(
                f"Layer {src_layer.GetName()} does not contain geometries, skipping"
            )
            continue

        src_layer_srs = src_layer.GetSpatialRef()
        clip_geometry = bbox.transform_as_geom(
            f"{src_layer_srs.GetAuthorityName(None)}:{src_layer_srs.GetAuthorityCode(None)}"
        )

        if dst_layer is None:
            src_defn = src_layer.GetLayerDefn()
            dst_layer = dst_datasource.CreateLayer(
                dst_layer_name, dst_srs, src_defn.GetGeomType())
            for field_idx in range(src_defn.GetFieldCount()):
                dst_layer.CreateField(src_defn.GetFieldDefn(field_idx))

        src_layer.SetSpatialFilter(clip_geometry)
        logging.debug(
            f"Clipped src_layer to {src_layer.GetFeatureCount()} features")

        # Explicit None check: end-of-layer is signalled by None.
        while (src_feature := src_layer.GetNextFeature()) is not None:
            src_geometry = src_feature.GetGeometryRef()
            if src_geometry is None:
                continue
            clipped = src_geometry.Intersection(clip_geometry)
            # The spatial filter only tests envelopes, so the real geometry
            # may still miss the bbox; don't emit empty features for those.
            if clipped is None or clipped.IsEmpty():
                continue
            # Intersection() drops the spatial reference, which TransformTo
            # needs.
            clipped.AssignSpatialReference(src_geometry.GetSpatialReference())
            clipped.TransformTo(dst_srs)
            out_feature = src_feature.Clone()
            out_feature.SetGeometryDirectly(clipped)
            dst_layer.CreateFeature(out_feature)
def make_dummy(ds: ogr.DataSource, name: str, burn: Tuple) -> ogr.DataSource:
    '''Copy the first layer of ``ds`` and fill a constant integer burn field.

    The layer is copied (OVERWRITE=YES) under ``name`` into a new datasource.
    That datasource comes from ``create_shapefile('memory', 'MEMORY')``,
    presumably GDAL's in-memory driver. An integer field named ``burn[0]`` is
    added, and every feature has it set to ``burn[1]``.

    Returns:
        The datasource holding the dummy layer (not the layer itself).

    Raises:
        baseException: with ``baseException.ERR_CODE_LEVEL`` if any step
            fails. ``ds`` is passed to ``close_shapefile`` before raising.
    '''
    logging.info(f'creating dummy: {name} layer')
    try:
        mem_ds = create_shapefile('memory', 'MEMORY')
        dummy_layer = mem_ds.CopyLayer(ds.GetLayer(), name, ['OVERWRITE=YES'])

        burn_field = burn[0]
        dummy_layer.CreateField(ogr.FieldDefn(burn_field, ogr.OFTInteger))
        dummy_layer.ResetReading()
        for feature in dummy_layer:
            feature.SetField(burn_field, burn[1])
            dummy_layer.SetFeature(feature)
        return mem_ds
    except Exception as e:
        # NOTE(review): this closes the caller's ``ds``, not the memory
        # datasource created above. Confirm callers expect their input to
        # be released on failure.
        close_shapefile(ds)
        raise baseException(f'dummy: {name} layer could not be created.',
                            baseException.ERR_CODE_LEVEL, e)
def get_first_layer(ds: ogr.DataSource) -> ogr.Layer:
    """Return the layer at index 0 of ``ds``."""
    first_index = 0
    return ds.GetLayer(first_index)
def create_indexes(in_name: str, out_dataset: DataSource, cols: Tuple[str, ...]) -> None:
    """Create an attribute index on table ``in_name`` for every column in ``cols``.

    One ``CREATE INDEX ON <in_name> USING <column>`` statement is executed
    per column, in order. Table and column names are interpolated verbatim,
    so they must be trusted identifiers.
    """
    statements = (f"CREATE INDEX ON {in_name} USING {col}" for col in cols)
    for sql in statements:
        out_dataset.ExecuteSQL(sql)
def round_extent(ds: ogr.DataSource, pixel_size: int) -> List[float]:
    '''Snap the first layer's extent to the nearest multiple of ``pixel_size``.

    The values keep the order returned by ``Layer.GetExtent()``. Each is
    rounded with Python's built-in ``round`` (ties go to the even multiple).
    The snapped extent can therefore be smaller than the true extent on any
    side.
    '''
    snapped = []
    for coord in ds.GetLayer().GetExtent():
        snapped.append(round(coord / pixel_size) * pixel_size)
    return snapped
def shp_file_to_csv(shp_ds: ogr.DataSource, raster_ds: gdal.Dataset) -> list:
    """
    Sample a regular grid of points inside every polygon of ``shp_ds``.

    The grid spacing is the raster's pixel size: GeoTransform[1] in x and
    GeoTransform[5] in y. For each polygon, grid points are laid over the
    bounding box of its exterior ring. A point is kept when ``point_in_poly``
    places it inside the exterior ring and outside every interior ring
    (hole). Each kept point yields exactly one tuple (x, y, class):

    - (x, y) are coordinates in the shapefile's spatial reference, which
      must match the raster's;
    - class is the value of the feature's 'class' field (a land-cover code).

    Parameters:
    --------------
    shp_ds - Data source whose first layer holds POLYGON features with a
             'class' field
    raster_ds - raster Dataset supplying the spatial reference and pixel size

    Returns:
    -------------
    train_data_coords - list of (x, y, class) tuples. If the two spatial
    references' authority codes differ, an error is printed and an empty list
    is returned. If a feature lacks the 'class' field or is not a POLYGON, an
    error is printed and the points collected so far are returned.
    """
    train_data_coords = []
    poly_lyr: ogr.Layer = shp_ds.GetLayer(0)

    # Compare the authority codes (e.g. EPSG numbers) of both spatial references.
    shp_osr: osr.SpatialReference = poly_lyr.GetSpatialRef()
    raster_osr = osr.SpatialReference()
    raster_osr.ImportFromWkt(raster_ds.GetProjection())
    if raster_osr.GetAttrValue('AUTHORITY', 1) != shp_osr.GetAttrValue(
            'AUTHORITY', 1):
        print(
            'Error: The shape file and the raster file have the differnet spatial refer'
        )
        return train_data_coords

    gtiff_geotrans = raster_ds.GetGeoTransform()
    step_x, step_y = gtiff_geotrans[1], gtiff_geotrans[5]

    # Iterate the layer directly: FIDs are not guaranteed to be 0..N-1 (some
    # formats start at 1), so GetFeature(index) could return None.
    poly_lyr.ResetReading()
    for poly_feat in poly_lyr:
        if 'class' not in poly_feat.keys():
            print("Error: The shape file don't have the 'class' Field")
            break
        class_value = poly_feat.GetField('class')

        poly_geom: ogr.Geometry = poly_feat.geometry()
        if poly_geom is None:
            continue
        if poly_geom.GetGeometryName() != 'POLYGON':
            print("Error: The geometry type of shape file isn't the polygon.")
            break

        for px, py in _grid_points_in_polygon(poly_geom, step_x, step_y):
            train_data_coords.append((px, py, class_value))
    return train_data_coords


def _grid_points_in_polygon(poly_geom, step_x, step_y):
    """Yield grid points (x, y) inside ``poly_geom``'s exterior ring but not in its holes."""
    ring_count = poly_geom.GetGeometryCount()
    if ring_count == 0:
        return
    # Ring 0 is the exterior ring; any further rings are holes.
    exterior = poly_geom.GetGeometryRef(0)
    exterior_points = exterior.GetPoints()
    hole_points = [poly_geom.GetGeometryRef(i).GetPoints()
                   for i in range(1, ring_count)]

    # Envelope order is (west, east, south, north).
    left, right, lower, upper = exterior.GetEnvelope()
    # Assumes a north-up raster (negative step_y), so rows run from the
    # upper edge down to the lower edge.
    for px in np.arange(left, right, step_x):
        for py in np.arange(upper, lower, step_y):
            if not point_in_poly(px, py, exterior_points):
                continue
            if any(point_in_poly(px, py, hole) for hole in hole_points):
                continue
            yield px, py