示例#1
0
def _pixel_sz_trans(ds: gdal.Dataset, ps: float) -> gdal.Dataset:
    """Resample a dataset so that its pixels have edge length ``ps``.

    The input dataset itself is returned when its geotransform starts with
    ``(0.0, 1.0)`` or when the ratio between the current pixel width and
    ``ps`` rounds (to two decimals) to 1. Otherwise a new in-memory ('MEM')
    dataset is created, georeferenced with the new pixel size, and filled
    from ``ds`` with cubic-spline resampling.

    Parameters:
    ------------
    ds      - source dataset
    ps      - target pixel size; written as the pixel width and, negated,
              as the pixel height of the output geotransform

    Returns:
    ------------
    The original dataset (nothing to do) or the resampled in-memory dataset.
    """
    src_geo = ds.GetGeoTransform()
    # > 1 enlarges the raster (finer pixels), < 1 shrinks it.
    scale = src_geo[1] / ps

    # Nothing to do: origin 0.0 / pixel width 1.0 geotransform, or the
    # pixel size is already (almost) the requested one.
    if list(src_geo)[:2] == [0.0, 1.0] or round(scale, 2) == 1:
        return ds

    projection = ds.GetProjection()
    band_type = ds.GetRasterBand(1).DataType
    src_cols, src_rows = ds.RasterXSize, ds.RasterYSize

    # Same origin and rotation terms, new pixel width / height.
    dst_geo = [src_geo[0], ps, src_geo[2], src_geo[3], src_geo[4], -ps]

    resampled = gdal.GetDriverByName('MEM').Create(
        '', int(src_cols * scale), int(src_rows * scale),
        ds.RasterCount, band_type)
    resampled.SetProjection(projection)
    resampled.SetGeoTransform(dst_geo)
    # Source and target share the projection; only the grid changes.
    gdal.ReprojectImage(ds, resampled, projection, projection,
                        gdalconst.GRA_CubicSpline)

    return resampled
示例#2
0
def make_raster(in_ds: gdal.Dataset,
                fn: str,
                data: np.ndarray,
                data_type: object,
                Nodata=None) -> gdal.Dataset:
    """Write a NumPy array to a new single-band GeoTIFF.

    Parameters:
    ------------
    in_ds       - dataset whose size, projection and geotransform are copied
    fn          - path of the file to create
    data        - NumPy array holding the values to write into band 1
    data_type   - GDAL data type of the output band
    Nodata      - optional NoData value; left unset when None

    Returns:
    ------------
    out_ds      - the newly created dataset
    """
    gtiff = gdal.GetDriverByName('GTiff')
    out_ds: gdal.Dataset = gtiff.Create(fn, in_ds.RasterXSize,
                                        in_ds.RasterYSize, 1, data_type)

    # Georeferencing is taken verbatim from the input dataset:
    # projection (coordinate system) first, then the geotransform.
    out_ds.SetProjection(in_ds.GetProjection())
    out_ds.SetGeoTransform(in_ds.GetGeoTransform())

    band: gdal.Band = out_ds.GetRasterBand(1)
    if Nodata is not None:
        band.SetNoDataValue(Nodata)
    band.WriteArray(data)
    band.FlushCache()
    # False is passed as the first (approx_ok) argument.
    band.ComputeStatistics(False)
    return out_ds
示例#3
0
def write_mask_to_file(f: gdal.Dataset, file_name: str,
                       mask: np.ndarray) -> None:
    """Save a 2-D mask as a single-band GeoTIFF georeferenced like ``f``.

    Parameters:
    ------------
    f           - dataset providing the projection and geotransform
    file_name   - path of the GeoTIFF to create
    mask        - 2-D array written into band 1
    """
    # The first axis of the array becomes the raster's y size, the second
    # axis its x size.
    n_rows, n_cols = mask.shape
    gtiff = gdal.GetDriverByName('GTiff')
    out_image = gtiff.Create(file_name, n_cols, n_rows, bands=1)

    out_image.SetProjection(f.GetProjection())
    out_image.SetGeoTransform(f.GetGeoTransform())

    mask_band = out_image.GetRasterBand(1)
    mask_band.WriteArray(mask)
    out_image.FlushCache()
示例#4
0
def world_to_pixel(image_dataset: gdal.Dataset, longitude: float,
                   latitude: float) -> (int, int):
    """Convert a world coordinate to the nearest pixel (column, row).

    Only the origin and pixel-size terms of the geotransform are used, so
    the result is meaningful for unrotated rasters (geotransform[2] and
    geotransform[4] equal to 0).

    Parameters:
    ------------
    image_dataset   - dataset whose geotransform defines the pixel grid
    longitude       - x coordinate in the dataset's coordinate system
    latitude        - y coordinate in the dataset's coordinate system

    Returns:
    ------------
    (x, y)          - column and row offsets, rounded to the nearest integer
    """
    geotransform = image_dataset.GetGeoTransform()

    ulx, uly = geotransform[0], geotransform[3]
    # geotransform[5] is the signed pixel height (negative for north-up
    # rasters), so dividing by it handles non-square pixels as well.
    x_size, y_size = geotransform[1], geotransform[5]

    # The column is measured from the upper-left *x* origin (the previous
    # code subtracted the y origin here). The builtin ``int`` replaces the
    # removed ``np.int`` alias.
    x = np.round((longitude - ulx) / x_size).astype(int)
    y = np.round((latitude - uly) / y_size).astype(int)

    return x, y
示例#5
0
def _create_blank_raster(
    in_data_set: gdal.Dataset,
    out_raster_path: Path,
    nr_bands: int = 1,
    no_data: float = np.nan,
    e_type: int = 6,
):
    """Create an empty GeoTIFF shaped and georeferenced like ``in_data_set``.

    The size is taken from band 1 of the input; projection and geotransform
    are copied from the input dataset. The NoData value is set on band 1 of
    the output only. The new dataset is flushed and released before
    returning, so nothing is returned.

    Parameters:
    ------------
    in_data_set     - template dataset
    out_raster_path - path of the GeoTIFF to create
    nr_bands        - number of bands in the output
    no_data         - NoData value assigned to band 1
    e_type          - GDAL data type code (6 is presumably GDT_Float32 -
                      TODO confirm against the gdal constants)
    """
    driver = gdal.GetDriverByName("GTiff")
    template_band = in_data_set.GetRasterBand(1)
    n_cols, n_rows = template_band.XSize, template_band.YSize

    blank = driver.Create(
        out_raster_path.as_posix(),
        xsize=n_cols,
        ysize=n_rows,
        bands=nr_bands,
        eType=e_type,
        options=["BIGTIFF=IF_NEEDED"],
    )
    blank.SetProjection(in_data_set.GetProjection())
    blank.SetGeoTransform(in_data_set.GetGeoTransform())
    blank.GetRasterBand(1).SetNoDataValue(no_data)
    blank.FlushCache()
    # Drop the only reference so the dataset is closed.
    del blank
示例#6
0
def gdal_to_json(ds: gdal.Dataset):
    """Summarise a dataset as a plain dictionary.

    Returned keys:
    bbox    - (y, x) of pixel (0, 0) followed by (y, x) of pixel
              (xsize, ysize), both mapped through the geotransform
    gt      - the geotransform as returned by GDAL
    srs     - PROJ.4 string of the spatial reference from ``get_srs``
    size    - (xsize, ysize)
    data    - one flattened list of values per band
    ndv     - NoData value of each band

    NOTE(review): the geotransform is requested with
    ``can_return_null=True`` but is indexed unconditionally, so a dataset
    without one raises TypeError - confirm whether that is intended.
    """
    gt = ds.GetGeoTransform(can_return_null=True)
    xsize, ysize = ds.RasterXSize, ds.RasterYSize
    srs = get_srs(ds).ExportToProj4()

    def pixel_to_world(col, row):
        # Apply the affine geotransform to a pixel/line position.
        world_x = gt[0] + gt[1] * col + gt[2] * row
        world_y = gt[3] + gt[4] * col + gt[5] * row
        return world_x, world_y

    x_first, y_first = pixel_to_world(0, 0)
    x_last, y_last = pixel_to_world(xsize, ysize)

    band_numbers = range(1, ds.RasterCount + 1)
    band_values = [
        ds.ReadAsArray(band_list=[number]).ravel().tolist()
        for number in band_numbers
    ]
    nodata_values = [
        ds.GetRasterBand(number).GetNoDataValue() for number in band_numbers
    ]

    return {
        'bbox': (y_first, x_first, y_last, x_last),
        'gt': gt,
        'srs': srs,
        'size': (xsize, ysize),
        'data': band_values,
        'ndv': nodata_values,
    }
示例#7
0
def get_geotransform_and_size(
        ds: gdal.Dataset) -> Tuple[GeoTransform, Tuple[int, int]]:
    """Return the dataset's geotransform and its (x size, y size) in pixels."""
    geotransform = ds.GetGeoTransform()
    raster_size = (ds.RasterXSize, ds.RasterYSize)
    return geotransform, raster_size
示例#8
0
def shp_file_to_csv(shp_ds: ogr.DataSource, raster_ds: gdal.Dataset) -> list:
    """
    Sample the raster grid inside the polygons of a shape file.

    Every ring of every polygon feature is scanned on a grid whose spacing
    is the raster's pixel size, starting at the ring envelope's upper-left
    corner. Each grid point accepted by ``point_in_poly`` yields one tuple
    (xoff, yoff, px, py, class):
    xoff, yoff  - integer pixel offsets of the point in the raster
    px, py      - point coordinates in the shared spatial reference
    class       - value of the feature's 'class' field (land-cover code)

    Processing stops early, returning what was collected so far, when the
    layer and the raster report different authority codes, when a feature
    lacks the 'class' field, or when a geometry is not a POLYGON. An error
    message is printed in each case.

    Parameters:
    --------------
    shp_ds      - shape file (GeometryType is Polygon) data source
    raster_ds   - raster file dataset

    Returns:
    -------------
    train_data_coords   - list of (xoff, yoff, px, py, class) tuples

    """
    train_data_coords = []

    # First layer of the shape file and its spatial reference.
    poly_lyr: ogr.Layer = shp_ds.GetLayer(0)
    shp_osr: osr.SpatialReference = poly_lyr.GetSpatialRef()

    # Spatial reference of the raster, rebuilt from its WKT projection.
    raster_osr = osr.SpatialReference()
    raster_osr.ImportFromWkt(raster_ds.GetProjection())

    gtiff_geotrans = raster_ds.GetGeoTransform()

    # Both inputs must report the same authority code.
    if raster_osr.GetAttrValue('AUTHORITY', 1) != shp_osr.GetAttrValue(
            'AUTHORITY', 1):
        print(
            'Error: The shape file and the raster file have the differnet spatial refer'
        )
        return train_data_coords

    # Inverse geotransform maps world coordinates back to pixel offsets.
    inv_gt = gdal.InvGeoTransform(gtiff_geotrans)
    x_step, y_step = gtiff_geotrans[1], gtiff_geotrans[5]

    for feat_i in range(poly_lyr.GetFeatureCount()):
        poly_feat: ogr.Feature = poly_lyr.GetFeature(feat_i)

        # The class code is mandatory.
        if 'class' not in poly_feat.keys():
            print("Error: The shape file don't have the 'class' Field")
            break
        name = poly_feat.GetField('class')

        poly_geom: ogr.Geometry = poly_feat.geometry()
        if poly_geom.GetGeometryName() != 'POLYGON':
            print("Error: The geometry type of shape file isn't the polygon.")
            break

        # NOTE(review): every ring is sampled the same way, so interior
        # rings (holes), if present, contribute points too - confirm intent.
        for ring_i in range(poly_geom.GetGeometryCount()):
            ring: ogr.Geometry = poly_geom.GetGeometryRef(ring_i)

            # Envelope order is (min x, max x, min y, max y).
            left, right, lower, upper = ring.GetEnvelope()
            points = ring.GetPoints()

            # The y grid is identical for every column; build it once.
            # It runs from `upper` towards `lower`, i.e. it expects a
            # negative pixel height.
            ys = np.arange(upper, lower, y_step)
            for px in np.arange(left, right, x_step):
                for py in ys:
                    if point_in_poly(px, py, points):
                        # Truncate the fractional pixel position to integers.
                        xoff, yoff = map(
                            int, gdal.ApplyGeoTransform(inv_gt, px, py))
                        # One uniformly shaped record per hit. The previous
                        # code additionally appended (px, py, name), which
                        # produced a list mixing 5- and 3-tuples.
                        train_data_coords.append((xoff, yoff, px, py, name))
    return train_data_coords
示例#9
0
def shp_files_to_csv(shp_files: list, raster_ds: gdal.Dataset,
                     out_csv_file) -> list:
    """
    Sample the raster grid inside the polygons of several shape files and
    save the samples as CSV.

    The shape files carry no class field; the class code of a file is its
    1-based position in ``shp_files``. Every ring of every polygon feature
    is scanned on a grid whose spacing is the raster's pixel size. Each
    grid point accepted by ``point_in_poly`` yields one tuple
    (xoff, yoff, px, py, class):
    xoff, yoff  - integer pixel offsets of the point in the raster
    px, py      - point coordinates in the shared spatial reference
    class       - class code of the shape file the point came from

    After each shape file the accumulated samples are written to
    ``out_csv_file`` (columns xoff, yoff, px, py, class), overwriting the
    previous content. If a shape file and the raster report different
    authority codes, an error is printed and the samples collected so far
    are returned immediately. A non-POLYGON geometry ends the processing of
    the current file only.

    Parameters:
    --------------
    shp_files       - list of shape file paths (GeometryType is Polygon)
    raster_ds       - raster file dataset
    out_csv_file    - destination passed to ``DataFrame.to_csv``

    Returns:
    -------------
    train_data_coords   - list of (xoff, yoff, px, py, class) tuples

    """
    train_data_coords = []

    for code, shp_path in enumerate(shp_files, start=1):
        shp_ds = ogr.Open(shp_path)
        print(shp_path)

        # First layer of the shape file and its spatial reference.
        poly_lyr: ogr.Layer = shp_ds.GetLayer(0)
        shp_osr: osr.SpatialReference = poly_lyr.GetSpatialRef()

        # Spatial reference of the raster, rebuilt from its WKT projection.
        raster_osr = osr.SpatialReference()
        raster_osr.ImportFromWkt(raster_ds.GetProjection())

        gtiff_geotrans = raster_ds.GetGeoTransform()

        # Both inputs must report the same authority code.
        if raster_osr.GetAttrValue('AUTHORITY', 1) != shp_osr.GetAttrValue(
                'AUTHORITY', 1):
            print(
                'Error: The shape file and the raster file have the differnet spatial refer'
            )
            return train_data_coords

        # Inverse geotransform maps world coordinates back to pixel offsets.
        inv_gt = gdal.InvGeoTransform(gtiff_geotrans)
        x_step, y_step = gtiff_geotrans[1], gtiff_geotrans[5]

        for feat_i in range(poly_lyr.GetFeatureCount()):
            poly_feat: ogr.Feature = poly_lyr.GetFeature(feat_i)

            poly_geom: ogr.Geometry = poly_feat.geometry()
            if poly_geom.GetGeometryName() != 'POLYGON':
                print(
                    "Error: The geometry type of shape file isn't the polygon."
                )
                break

            # NOTE(review): every ring is sampled the same way, so interior
            # rings (holes), if present, contribute points too - confirm.
            for ring_i in range(poly_geom.GetGeometryCount()):
                ring: ogr.Geometry = poly_geom.GetGeometryRef(ring_i)

                # Envelope order is (min x, max x, min y, max y).
                left, right, lower, upper = ring.GetEnvelope()
                points = ring.GetPoints()

                # The y grid is identical for every column; build it once.
                # It runs from `upper` towards `lower`, i.e. it expects a
                # negative pixel height.
                ys = np.arange(upper, lower, y_step)
                for px in np.arange(left, right, x_step):
                    for py in ys:
                        if point_in_poly(px, py, points):
                            # Truncate the fractional pixel position.
                            xoff, yoff = map(
                                int, gdal.ApplyGeoTransform(inv_gt, px, py))
                            train_data_coords.append(
                                (xoff, yoff, px, py, code))

        # Rewrite the CSV with everything gathered so far, so the file on
        # disk always covers all shape files processed up to this point.
        df = pd.DataFrame(train_data_coords,
                          columns=['xoff', 'yoff', 'px', 'py', 'class'])
        df.to_csv(out_csv_file)
        del shp_ds
    print('train.csv have successfully write in.')
    return train_data_coords
示例#10
0
文件: process_grib.py 项目: bcgov/wps
def get_dataset_geometry(dataset: gdal.Dataset) -> (List[int], List[int]):
    """ Return the geotransform of the dataset.

    The value is passed through exactly as ``GetGeoTransform`` returns it.

    NOTE(review): the return annotation describes a pair of lists, which
    does not match the value returned here - confirm which is intended.
    """
    geotransform = dataset.GetGeoTransform()
    return geotransform
示例#11
0
    def _load_geotransform(self, image_dataset: gdal.Dataset) -> Geotransform:
        """Build a ``Geotransform`` from the dataset's GDAL geotransform."""
        raw_geotransform = image_dataset.GetGeoTransform()
        return Geotransform.from_tuple(raw_geotransform)