Example #1
0
def _pixel_sz_trans(ds: gdal.Dataset, ps: float) -> gdal.Dataset:
    """Resample a dataset to a square pixel size of ``ps``.

    The result is an in-memory (``MEM`` driver) dataset with the same origin,
    projection, band count and band-1 data type as ``ds``; pixel values are
    interpolated with cubic-spline resampling via ``gdal.ReprojectImage``.

    Parameters:
    ------------
    ds  - source dataset
    ps  - target pixel size, in the units of the dataset's geotransform

    Returns:
    ------------
    ``ds`` itself when its first two geotransform terms are ``(0.0, 1.0)``
    (GDAL's default, non-georeferenced transform) or when both its x and y
    pixel sizes already round to ``ps`` (factor 1.00); otherwise the new
    in-memory dataset.

    Raises:
    ------------
    ValueError  - if ``ps`` is not positive.
    """
    if ps <= 0:
        raise ValueError(f"pixel size must be positive, got {ps!r}")

    ds_trans = ds.GetGeoTransform()
    # Per-axis scale factors. ds_trans[5] is the (usually negative) pixel
    # height, so its magnitude is used. Separate factors keep the output
    # extent correct when the source pixels are not square.
    x_factor = ds_trans[1] / ps
    y_factor = abs(ds_trans[5]) / ps

    no_georef = list(ds_trans)[:2] == [0.0, 1.0]
    same_size = round(x_factor, 2) == 1 and round(y_factor, 2) == 1
    if no_georef or same_size:
        return ds

    ds_proj = ds.GetProjection()
    ds_dtype = ds.GetRasterBand(1).DataType
    width, height = ds.RasterXSize, ds.RasterYSize

    ts_trans = list(ds_trans)
    ts_trans[1] = ps
    # Keep the source's row direction (negative pixel height = north-up).
    ts_trans[5] = -ps if ds_trans[5] <= 0 else ps

    # Never request a zero-sized raster when heavily downsampling a tiny input.
    out_width = max(1, int(width * x_factor))
    out_height = max(1, int(height * y_factor))

    mem_drv = gdal.GetDriverByName('MEM')
    dst_ds = mem_drv.Create('', out_width, out_height,
                            ds.RasterCount, ds_dtype)
    dst_ds.SetProjection(ds_proj)
    dst_ds.SetGeoTransform(ts_trans)
    gdal.ReprojectImage(ds, dst_ds, ds_proj, ds_proj,
                        gdalconst.GRA_CubicSpline)

    return dst_ds
Example #2
0
def make_raster(in_ds: gdal.Dataset,
                fn: str,
                data: np.ndarray,
                data_type: object,
                Nodata=None) -> gdal.Dataset:
    """Create a one-band GeoTIFF.

    Parameters:
    ------------
    in_ds       - datasource to copy raster size, projection and geotransform from
    fn          - path to the file to create
    data        - NumPy array containing the data to write (presumably shaped
                  (rows, cols) to match in_ds's raster size — confirm at callers)
    data_type   - output GDAL data type, passed to ``Driver.Create`` as eType
    Nodata      - optional NoData value

    Returns:
    ------------
    out_ds      - the newly created datasource; its band is flushed and has
                  exact statistics computed

    Raises:
    ------------
    RuntimeError - if the GTiff driver could not create ``fn``
    """
    driver = gdal.GetDriverByName('GTiff')
    out_ds: gdal.Dataset = driver.Create(fn, in_ds.RasterXSize,
                                         in_ds.RasterYSize, 1, data_type)
    if out_ds is None:
        raise RuntimeError(f"GTiff driver could not create {fn!r}")
    # Copy the projection (coordinate system information) from the input.
    out_ds.SetProjection(in_ds.GetProjection())
    # Copy the geotransform from the input.
    out_ds.SetGeoTransform(in_ds.GetGeoTransform())
    out_band: gdal.Band = out_ds.GetRasterBand(1)
    if Nodata is not None:
        out_band.SetNoDataValue(Nodata)
    out_band.WriteArray(data)
    out_band.FlushCache()
    # approx_ok=False: compute exact statistics over every pixel.
    out_band.ComputeStatistics(False)
    return out_ds
Example #3
0
    def save_prj_file(cls, output_path: str, ds: gdal.Dataset) -> bool:
        """Write the dataset's projection to an ESRI-style ``.prj`` sidecar file.

        The sidecar path is ``output_path`` with its extension replaced by
        ``.prj``. The WKT is converted with ``MorphToESRI`` before writing.

        Returns:
            True once the file has been written.
        """
        src_srs = osr.SpatialReference()
        src_srs.ImportFromWkt(ds.GetProjection())
        # ESRI tools expect their own WKT dialect in .prj files.
        src_srs.MorphToESRI()
        src_wkt = src_srs.ExportToWkt()

        prj_path = os.path.splitext(output_path)[0] + '.prj'
        # The context manager closes the file even if the write raises.
        with open(prj_path, 'wt') as prj_file:
            prj_file.write(src_wkt)
        return True
Example #4
0
    def load_from_dataset(self, image_dataset: gdal.Dataset) -> Image:
        """Read every band of ``image_dataset`` into an ``Image``.

        Arrays with more than two dimensions are transposed so that the first
        axis (the band axis in GDAL's ``ReadAsArray`` output) becomes the last
        one; 2-D single-band arrays are passed through unchanged.
        """
        transform = self._load_geotransform(image_dataset)
        wkt = image_dataset.GetProjection()
        data = image_dataset.ReadAsArray()

        # Move the band axis last: (bands, rows, cols) -> (rows, cols, bands).
        data = data.transpose(1, 2, 0) if data.ndim > 2 else data

        return Image(data, transform, wkt)
Example #5
0
def write_mask_to_file(f: gdal.Dataset, file_name: str,
                       mask: np.ndarray) -> None:
    """Write a 2-D ``mask`` array to ``file_name`` as a one-band GeoTIFF.

    Projection and geotransform are copied from ``f``. No data type is given
    to ``Create``, so the driver's default (GDT_Byte) is used.
    """
    # mask.shape is (rows, cols); Create expects xsize=cols, ysize=rows.
    n_rows, n_cols = mask.shape
    driver = gdal.GetDriverByName('GTiff')
    out_image = driver.Create(file_name, n_cols, n_rows, bands=1)
    out_image.SetProjection(f.GetProjection())
    out_image.SetGeoTransform(f.GetGeoTransform())
    band = out_image.GetRasterBand(1)
    band.WriteArray(mask)
    out_image.FlushCache()
Example #6
0
 def _prepare_bound_checker(self, grib_tmp_700: gdal.Dataset):
     """Lazily build ``self.bound_checker`` from a GRIB dataset.

     While ``self.bound_checker`` is falsy, a ``BoundingBoxChecker`` is built
     from the dataset geometry and a transformer from the raster's CRS to
     ``NAD83_CRS``; once set, later calls reuse the existing checker.
     """
     if self.bound_checker:
         logger.info('Re-using bound checker.')
         return

     logger.info('Creating bound checker.')
     geometry = get_dataset_geometry(grib_tmp_700)
     raster_crs = CRS.from_string(grib_tmp_700.GetProjection())
     # Transformer from whatever CRS the raster uses to geographic coordinates.
     to_geographic = get_transformer(raster_crs, NAD83_CRS)
     self.bound_checker = BoundingBoxChecker(geometry, to_geographic)
Example #7
0
    def load_from_dataset_and_clip(self, image_dataset: gdal.Dataset,
                                   extent: GeoPolygon) -> Image:
        """Read the window of ``image_dataset`` covering ``extent`` and clip it.

        The extent's pixel-space bounding box (each bound truncated with
        ``int``) selects the window to read. The window becomes an ``Image``
        with a shifted geotransform and is then clipped to the extent via
        ``clip_with(..., mask_value=0)``.
        """
        full_transform = self._load_geotransform(image_dataset)
        bbox = [int(b) for b in
                extent.to_pixel(full_transform).polygon.bounds]
        col_off, row_off = bbox[0], bbox[1]
        n_cols, n_rows = bbox[2] - col_off, bbox[3] - row_off

        window = image_dataset.ReadAsArray(col_off, row_off, n_cols, n_rows)
        window_transform = full_transform.subset(x=col_off, y=row_off)
        # Express the extent in the window's own pixel grid.
        window_polygon = extent.to_pixel(window_transform)

        if window.ndim > 2:
            window = window.transpose(1, 2, 0)

        window_image = Image(window, window_transform,
                             image_dataset.GetProjection())
        return window_image.clip_with(window_polygon, mask_value=0)
Example #8
0
def _create_blank_raster(
    in_data_set: gdal.Dataset,
    out_raster_path: Path,
    nr_bands: int = 1,
    no_data: float = np.nan,
    e_type: int = 6,
):
    """Create an empty GeoTIFF shaped like ``in_data_set``.

    The new raster takes its column/row count from band 1 of ``in_data_set``
    and copies its projection and geotransform. Every band gets ``no_data``
    as its NoData value. The dataset is flushed and released before
    returning; nothing is returned.

    Args:
        in_data_set: template dataset for size, projection and geotransform.
        out_raster_path: where to write the GeoTIFF.
        nr_bands: number of bands to create.
        no_data: NoData value assigned to each band.
        e_type: GDAL data type code; the default 6 is ``gdal.GDT_Float32``.

    Raises:
        RuntimeError: if the GTiff driver cannot create the file.
    """
    gtiff_driver = gdal.GetDriverByName("GTiff")
    band = in_data_set.GetRasterBand(1)
    x_size = band.XSize  # number of columns
    y_size = band.YSize  # number of rows
    out_ds = gtiff_driver.Create(out_raster_path.as_posix(),
                                 xsize=x_size,
                                 ysize=y_size,
                                 bands=nr_bands,
                                 eType=e_type,
                                 options=["BIGTIFF=IF_NEEDED"])
    if out_ds is None:
        raise RuntimeError(f"GTiff driver could not create {out_raster_path}")
    out_ds.SetProjection(in_data_set.GetProjection())
    out_ds.SetGeoTransform(in_data_set.GetGeoTransform())
    # Apply NoData to every band (GDAL band indices are 1-based) so
    # multi-band outputs are consistent, not just band 1.
    for band_index in range(1, out_ds.RasterCount + 1):
        out_ds.GetRasterBand(band_index).SetNoDataValue(no_data)
    out_ds.FlushCache()
    # Dropping the last reference closes the dataset and completes the write.
    out_ds = None
Example #9
0
def test_projection_is_wgs84(gdal_dataset: gdal.Dataset):
    """The dataset's coordinate system name should start with "WGS 84".

    The name is the first quoted string of the WKT, e.g. ``GEOGCS["WGS 84",``
    (WKT1) or ``GEOGCRS["WGS 84",`` (WKT2). Parsing it avoids fixed character
    offsets, which only line up with six-letter keywords like ``GEOGCS``.
    """
    import re

    wkt = gdal_dataset.GetProjection()
    match = re.match(r'\s*[A-Za-z_0-9]+\s*\[\s*"([^"]*)"', wkt)
    assert match is not None, f'unrecognised WKT: {wkt!r}'
    # startswith keeps names like "WGS 84 / UTM zone 33N" passing, exactly as
    # the previous fixed-offset comparison did.
    assert match.group(1).startswith('WGS 84'), match.group(1)
Example #10
0
def get_srs_from_ds(ds: gdal.Dataset) -> osr.SpatialReference:
    """Return a new ``osr.SpatialReference`` initialised from ``ds``'s WKT."""
    spatial_ref = osr.SpatialReference()
    wkt = ds.GetProjection()
    spatial_ref.ImportFromWkt(wkt)
    return spatial_ref
Example #11
0
def shp_file_to_csv(shp_ds: ogr.DataSource, raster_ds: gdal.Dataset) -> list:
    """Sample raster-grid points that fall inside the polygons of ``shp_ds``.

    For each polygon, candidate points are generated across its bounding box
    with a step of one raster pixel width (x) and one pixel height (y),
    starting at the box's top-left corner. A point is kept when
    ``point_in_poly`` places it inside the polygon's outer ring and outside
    every hole ring. Each kept point is recorded with its integer pixel
    offsets in ``raster_ds`` and the feature's ``class`` field (the
    land-cover code).

    The shape file and raster must share the same spatial reference
    (compared by AUTHORITY code); otherwise an error is printed and an empty
    list is returned.

    Parameters:
    --------------
    shp_ds      - shape file (GeometryType is Polygon) data source whose
                  features carry a ``class`` field
    raster_ds   - raster file Dataset

    Returns:
    -------------
    train_data_coords - list of ``(xoff, yoff, x, y, class)`` tuples, where
        ``(xoff, yoff)`` are pixel offsets in ``raster_ds`` and ``(x, y)`` are
        coordinates in the shared spatial reference. Processing stops at the
        first feature without a ``class`` field or with a non-polygon
        geometry (an error is printed; samples gathered so far are returned).
    """
    train_data_coords = []

    # First layer of the shape file and its spatial reference.
    poly_lyr: ogr.Layer = shp_ds.GetLayer(0)
    shp_osr: osr.SpatialReference = poly_lyr.GetSpatialRef()

    # Spatial reference and geotransform of the raster.
    raster_osr = osr.SpatialReference()
    raster_osr.ImportFromWkt(raster_ds.GetProjection())
    gtiff_geotrans = raster_ds.GetGeoTransform()

    # A layer without an SRS cannot match the raster's.
    if shp_osr is None or raster_osr.GetAttrValue(
            'AUTHORITY', 1) != shp_osr.GetAttrValue('AUTHORITY', 1):
        print(
            'Error: The shape file and the raster file have the differnet spatial refer'
        )
        return train_data_coords

    inv_gt = gdal.InvGeoTransform(gtiff_geotrans)

    # Iterate the layer directly instead of GetFeature(0..count-1): feature
    # IDs are not guaranteed to be contiguous or to start at 0.
    for poly_feat in poly_lyr:
        # Extract the class code.
        if 'class' not in poly_feat.keys():
            print("Error: The shape file don't have the 'class' Field")
            break
        name = poly_feat.GetField('class')

        poly_geom: ogr.Geometry = poly_feat.geometry()
        if poly_geom.GetGeometryName() != 'POLYGON':
            print("Error: The geometry type of shape file isn't the polygon.")
            break

        # Ring 0 is the outer boundary and any further rings are holes, so
        # only the polygon envelope is scanned and hole interiors are skipped.
        rings = [poly_geom.GetGeometryRef(i).GetPoints()
                 for i in range(poly_geom.GetGeometryCount())]
        if not rings:
            continue
        outer, holes = rings[0], rings[1:]

        # Envelope order is (minX, maxX, minY, maxY).
        left, right, lower, upper = poly_geom.GetEnvelope()
        for px in np.arange(left, right, gtiff_geotrans[1]):
            # gtiff_geotrans[5] is negative for north-up rasters, so this
            # walks from the top edge down towards the bottom edge.
            for py in np.arange(upper, lower, gtiff_geotrans[5]):
                if not point_in_poly(px, py, outer):
                    continue
                if any(point_in_poly(px, py, hole) for hole in holes):
                    continue
                offsets = gdal.ApplyGeoTransform(inv_gt, px, py)
                # Convert to integer pixel offsets.
                xoff, yoff = map(int, offsets)
                train_data_coords.append((xoff, yoff, px, py, name))
    return train_data_coords
Example #12
0
def shp_files_to_csv(shp_files: list, raster_ds: gdal.Dataset,
                     out_csv_file) -> list:
    """Sample raster-grid points inside the polygons of several shape files.

    Each shape file stands for one class: the file at position ``i`` in
    ``shp_files`` is labelled with class code ``i + 1``. For each polygon,
    candidate points are generated across its bounding box with a step of one
    raster pixel width (x) and one pixel height (y), starting at the box's
    top-left corner. A point is kept when ``point_in_poly`` places it inside
    the outer ring and outside every hole ring. All samples are written once,
    after every file is processed, to ``out_csv_file`` with columns
    ``xoff, yoff, px, py, class`` (plus the pandas default index).

    Parameters:
    --------------
    shp_files    - list of shape file paths (GeometryType is Polygon)
    raster_ds    - raster file Dataset
    out_csv_file - path or buffer passed to ``DataFrame.to_csv``

    Returns:
    -------------
    train_data_coords - list of ``(xoff, yoff, px, py, class)`` tuples.
        If a shape file's spatial reference (AUTHORITY code) differs from the
        raster's, an error is printed and the samples gathered so far are
        returned without writing the CSV.

    Raises:
    -------------
    RuntimeError - if a shape file cannot be opened by OGR.
    """
    train_data_coords = []

    # The raster's SRS and geotransform do not change between shape files.
    raster_osr = osr.SpatialReference()
    raster_osr.ImportFromWkt(raster_ds.GetProjection())
    raster_authority = raster_osr.GetAttrValue('AUTHORITY', 1)
    gtiff_geotrans = raster_ds.GetGeoTransform()
    inv_gt = gdal.InvGeoTransform(gtiff_geotrans)

    for code, shp_file in enumerate(shp_files, start=1):
        shp_ds = ogr.Open(shp_file)
        print(shp_file)
        if shp_ds is None:
            raise RuntimeError(f'Could not open shape file {shp_file!r}')

        # First layer of the shape file and its spatial reference.
        poly_lyr: ogr.Layer = shp_ds.GetLayer(0)
        shp_osr: osr.SpatialReference = poly_lyr.GetSpatialRef()
        if (shp_osr is None
                or shp_osr.GetAttrValue('AUTHORITY', 1) != raster_authority):
            print(
                'Error: The shape file and the raster file have the differnet spatial refer'
            )
            return train_data_coords

        # Iterate the layer directly: feature IDs are not guaranteed to be
        # contiguous or to start at 0.
        for poly_feat in poly_lyr:
            poly_geom: ogr.Geometry = poly_feat.geometry()
            if poly_geom.GetGeometryName() != 'POLYGON':
                print(
                    "Error: The geometry type of shape file isn't the polygon."
                )
                break

            # Ring 0 is the outer boundary; further rings are holes.
            rings = [poly_geom.GetGeometryRef(i).GetPoints()
                     for i in range(poly_geom.GetGeometryCount())]
            if not rings:
                continue
            outer, holes = rings[0], rings[1:]

            # Envelope order is (minX, maxX, minY, maxY).
            left, right, lower, upper = poly_geom.GetEnvelope()
            for px in np.arange(left, right, gtiff_geotrans[1]):
                # gtiff_geotrans[5] is negative for north-up rasters.
                for py in np.arange(upper, lower, gtiff_geotrans[5]):
                    if not point_in_poly(px, py, outer):
                        continue
                    if any(point_in_poly(px, py, hole) for hole in holes):
                        continue
                    offsets = gdal.ApplyGeoTransform(inv_gt, px, py)
                    # Convert to integer pixel offsets.
                    xoff, yoff = map(int, offsets)
                    train_data_coords.append((xoff, yoff, px, py, code))
        del shp_ds

    # Write once, instead of rewriting the cumulative CSV after every file.
    df = pd.DataFrame(train_data_coords,
                      columns=['xoff', 'yoff', 'px', 'py', 'class'])
    df.to_csv(out_csv_file)
    print('train.csv have successfully write in.')
    return train_data_coords