コード例 #1
0
ファイル: warplib.py プロジェクト: maubat/pygeotools
def parse_extent(extent, src_ds_list=None, t_srs=None):
    """Parse arbitrary input extent

    Parameters
    ----------
    extent : str or gdal.Dataset or filename or list of float or None
        Arbitrary input extent. Accepted forms:
        - 'first', 'last', 'intersection', 'union': computed from src_ds_list
        - 'source' or None: preserve the source extent (returns None)
        - gdal.Dataset, or path to an existing raster: that dataset's extent
        - list, tuple or numpy array: converted to a list as-is
        - whitespace-delimited string, e.g. "xmin ymin xmax ymax"
    src_ds_list : list of gdal.Dataset objects, optional
        Needed (non-empty) if specifying 'first', 'last', 'intersection', or 'union'
    t_srs : osr.SpatialReference() object, optional
        Projection for extent calculations

    Returns
    -------
    extent : list of float
        Output extent [xmin, ymin, xmax, ymax]
        None if source extent should be preserved

    Raises
    ------
    ValueError
        If a keyword extent is given without a non-empty src_ds_list (and no
        file of that name exists), a file cannot be opened as a raster, or a
        string cannot be parsed as numbers
    TypeError
        If extent is of an unsupported type
    """

    #Default to using first t_srs for extent calculations
    if t_srs is not None:
        t_srs = parse_srs(t_srs, src_ds_list)

    if extent is None:
        return None

    #Handle numeric sequences before any string comparison: `in`/`==` between a
    #numpy array and a string is elementwise, and the resulting array has an
    #ambiguous truth value (ValueError)
    if isinstance(extent, (list, tuple, np.ndarray)):
        return list(extent)

    if isinstance(extent, str):
        return _parse_extent_str(extent, src_ds_list, t_srs)

    if isinstance(extent, gdal.Dataset):
        return geolib.ds_geom_extent(extent, t_srs=t_srs)

    raise TypeError("Unsupported extent type: %s" % type(extent).__name__)


def _parse_extent_str(extent, src_ds_list, t_srs):
    """Resolve a string extent: keyword, 'source', raster filename, or numbers."""
    #Valid strings
    extent_str_list = ['first', 'last', 'intersection', 'union']
    have_src = src_ds_list is not None and len(src_ds_list) > 0

    if extent in extent_str_list and have_src:
        return _extent_from_keyword(extent, src_ds_list, t_srs)
    if extent == 'source':
        return None
    if os.path.exists(extent):
        ds = gdal.Open(extent)
        if ds is None:
            raise ValueError("Unable to open extent file as raster: %s" % extent)
        return geolib.ds_geom_extent(ds, t_srs=t_srs)
    if extent in extent_str_list:
        raise ValueError("Extent '%s' requires a non-empty src_ds_list" % extent)

    #Whitespace-delimited numbers; split() tolerates repeated/leading spaces
    values = [float(i) for i in extent.split()]
    if not values:
        raise ValueError("Empty extent string")
    return values


def _extent_from_keyword(extent, src_ds_list, t_srs):
    """Compute the extent named by a keyword from a non-empty src_ds_list."""
    if len(src_ds_list) == 1 and extent in ('intersection', 'union'):
        #Single input: preserve the source extent
        return None
    if extent == 'first':
        return geolib.ds_geom_extent(src_ds_list[0], t_srs=t_srs)
    if extent == 'last':
        return geolib.ds_geom_extent(src_ds_list[-1], t_srs=t_srs)
    if extent == 'intersection':
        #By default, compute_intersection takes ref_srs from ref_ds
        out = geolib.ds_geom_intersection_extent(src_ds_list, t_srs=t_srs)
        #More than one input is guaranteed here (single-input case returned above)
        if out is None:
            sys.exit("Input images do not intersect")
        return out
    #'union' - Need to clean up union t_srs handling
    return geolib.ds_geom_union_extent(src_ds_list, t_srs=t_srs)
コード例 #2
0
def shp_overlay(ax, ds, shp_fn, gt=None, color='darkgreen'):
    """Plot vector features from shp_fn on top of raster ds in pixel coordinates.

    Point, line and polygon (exterior ring outline only) features are
    reprojected to the raster SRS when the layer SRS differs, vertices outside
    the raster extent are dropped, and the remainder is converted to pixel
    coordinates with gt and drawn with ax.plot. Other geometry types and
    features without geometry are skipped.

    Parameters
    ----------
    ax : matplotlib axes
        Axes to draw on (autoscaling is disabled once a feature is plotted)
    ds : gdal.Dataset
        Raster defining the target SRS, clipping extent and pixel grid
    shp_fn : str
        Vector file readable by ogr.Open
    gt : list of float, optional
        Geotransform for map->pixel conversion; defaults to ds.GetGeoTransform()
    color : str, optional
        Color for markers/lines

    Raises
    ------
    ValueError
        If shp_fn cannot be opened by OGR
    """
    from osgeo import ogr
    from pygeotools.lib import geolib
    #ogr2ogr -f "ESRI Shapefile" output.shp input.shp -clipsrc xmin ymin xmax ymax
    shp_ds = ogr.Open(shp_fn)
    if shp_ds is None:
        raise ValueError("Unable to open vector file: %s" % shp_fn)
    lyr = shp_ds.GetLayer()
    lyr_srs = lyr.GetSpatialRef()
    lyr.ResetReading()

    #Raster properties are identical for every feature: compute them once
    ds_srs = geolib.get_ds_srs(ds)
    if gt is None:
        gt = ds.GetGeoTransform()
    #GetSpatialRef() may return None (e.g. shapefile without .prj): plot as-is
    reproject = (lyr_srs is not None and ds_srs is not None
                 and not lyr_srs.IsSame(ds_srs))
    #Indexed as [xmin, ymin, xmax, ymax] below to drop out-of-raster vertices
    ds_extent = geolib.ds_geom_extent(ds)

    #Note: this is inefficient for large numbers of features
    #Should produce collections of points or lines, then have single plot call
    for feat in lyr:
        geom = feat.GetGeometryRef()
        if geom is None:
            continue
        #Flatten so 2.5D variants (e.g. wkbLineString25D) match the 2D type codes
        geom_type = ogr.GT_Flatten(geom.GetGeometryType())
        if geom_type == ogr.wkbPoint:
            mX, mY, z = geom.GetPoint()
            attr = {'marker':'o', 'markersize':5, 'linestyle':'None'}
        elif geom_type == ogr.wkbLineString:
            l_arr, mX, mY = geolib.line2pts(geom)
            z = 0
            attr = {'marker':None, 'linestyle':'-', 'linewidth':1.0, 'alpha':0.8}
        elif geom_type == ogr.wkbPolygon:
            #Placeholder: outline of the exterior ring only, no fill or holes
            #Note: this should be done with the matplotlib patch functionality
            #http://matplotlib.org/users/path_tutorial.html
            print("Polygon support not yet implemented")
            #Polygon vertices live on its rings, not on the polygon itself
            ring = geom.GetGeometryRef(0)
            if ring is None:
                continue
            l_arr, mX, mY = geolib.line2pts(ring)
            z = 0
            #No 'facecolor': ax.plot creates Line2D objects, which reject it
            attr = {'marker':None, 'linestyle':'-'}
        else:
            #Unsupported type (e.g. Multi*): skip instead of reusing the previous
            #feature's coordinates/attributes
            continue

        if reproject:
            mX, mY, z = geolib.cT_helper(mX, mY, z, lyr_srs, ds_srs)

        #Drop vertices outside the raster extent
        mX = np.ma.array(mX)
        mY = np.ma.array(mY)
        mX[mX < ds_extent[0]] = np.ma.masked
        mX[mX > ds_extent[2]] = np.ma.masked
        mY[mY < ds_extent[1]] = np.ma.masked
        mY[mY > ds_extent[3]] = np.ma.masked
        mask = np.ma.getmaskarray(mY) | np.ma.getmaskarray(mX)
        mX = mX[~mask]
        mY = mY[~mask]

        if mX.count() == 0:
            continue

        ax.set_autoscale_on(False)
        if geom_type == ogr.wkbPoint:
            pX, pY = geolib.mapToPixel(np.array(mX), np.array(mY), gt)
            ax.plot(pX, pY, color=color, **attr)
            continue

        #Split into separate segments wherever the step in l_arr exceeds twice
        #its median (presumably where out-of-extent vertices were removed), so
        #disconnected pieces are not joined by a straight line
        l_arr = np.ma.array(l_arr)[~mask]
        l_step = np.diff(l_arr)
        lmed = np.ma.median(l_step)
        lbreaks = list((l_step > lmed*2).nonzero()[0])
        #A break at index b separates vertex b from vertex b+1
        start = 0
        for end in lbreaks + [l_arr.size]:
            pX, pY = geolib.mapToPixel(np.array(mX[start:end+1]),
                                       np.array(mY[start:end+1]), gt)
            ax.plot(pX, pY, color=color, **attr)
            start = end+1
コード例 #3
0
ファイル: warplib.py プロジェクト: snowfox1939/pygeotools
def warp(src_ds, res=None, extent=None, t_srs=None, r='cubic', driver=mem_drv, dst_fn=None, dst_ndv=None, verbose=True):
    """Warp an input dataset with predetermined arguments specifying output res/extent/srs

    This is the function that actually calls gdal.ReprojectImage

    Parameters
    ----------
    src_ds : gdal.Dataset object
        Dataset to be warped
    res : float, optional
        Desired output resolution; defaults to the source resolution in t_srs
    extent : list of float, optional
        Desired output extent [xmin, ymin, xmax, ymax] in t_srs coordinate
        system; defaults to the source extent in t_srs
    t_srs : osr.SpatialReference(), optional
        Desired output spatial reference; defaults to the source SRS
    r : str
        Desired resampling algorithm (parsed by parse_rs_alg)
    driver : GDAL Driver to use for warp
        Either MEM or GTiff
    dst_fn : str, optional
        Output filename (for disk warp)
    dst_ndv : float, optional
        Desired output NoData Value, applied to every output band; defaults
        to the NoData value of source band 1
    verbose : bool
        Print output dimensions and show the GDAL progress bar

    Returns
    -------
    dst_ds : gdal.Dataset object
        Warped dataset (either in memory or on disk)

    Raises
    ------
    ValueError
        If res is not positive, or extent has zero/negative width or height
    RuntimeError
        If the driver fails to create the output dataset
    """
    src_srs = geolib.get_ds_srs(src_ds)

    if t_srs is None:
        t_srs = geolib.get_ds_srs(src_ds)

    src_gt = src_ds.GetGeoTransform()
    #Note: get_res returns [x_res, y_res]
    #Could just use gt here and average x_res and y_res
    src_res = geolib.get_res(src_ds, t_srs=t_srs, square=True)[0]

    if res is None:
        res = src_res

    if extent is None:
        extent = geolib.ds_geom_extent(src_ds, t_srs=t_srs)

    #Resampling notes: GDAL Lanczos creates block artifacts; would prefer
    #gaussian/Lanczos for downsampling once gdalwarp supports it.
    #See http://blog.codinghorror.com/better-image-resizing/
    #(suggests cubic for downsampling, bilinear for upsampling)
    gra = parse_rs_alg(r)

    #At this point, the resolution and extent values must be float
    #Extent must be list
    res = float(res)
    extent = [float(i) for i in extent]
    if res <= 0:
        raise ValueError("Output resolution must be positive, got %s" % res)
    if extent[2] <= extent[0] or extent[3] <= extent[1]:
        raise ValueError("Invalid extent (xmax <= xmin or ymax <= ymin): %s" % extent)

    #Create progress function
    prog_func = gdal.TermProgress if verbose else None

    if dst_fn is None:
        #Dummy filename for in-memory output (it does not appear to be
        #retrievable later via GetFileList())
        dst_fn = ''

    #Compute output image dimensions; round() gives 0 when the extent is
    #smaller than half a pixel, and GDAL cannot create a 0-size dataset
    dst_nl = max(1, int(round((extent[3] - extent[1])/res)))
    dst_ns = max(1, int(round((extent[2] - extent[0])/res)))
    if verbose:
        print('nl: %i ns: %i res: %0.3f' % (dst_nl, dst_ns, res))

    #Create output dataset with the source band count and data type
    src_dt = src_ds.GetRasterBand(1).DataType
    dst_ds = driver.Create(dst_fn, dst_ns, dst_nl, src_ds.RasterCount, src_dt)
    if dst_ds is None:
        raise RuntimeError("Unable to create output dataset %r" % dst_fn)

    dst_ds.SetProjection(t_srs.ExportToWkt())
    #Might be an issue to use src_gt rotation terms here with arbitrary extent/res
    dst_gt = [extent[0], res, src_gt[2], extent[3], src_gt[4], -res]
    dst_ds.SetGeoTransform(dst_gt)

    #A single NoData value is used for all output bands (band 1's if not given)
    if dst_ndv is None:
        dst_ndv = iolib.get_ndv_b(src_ds.GetRasterBand(1))
    for n in range(1, src_ds.RasterCount+1):
        b = dst_ds.GetRasterBand(n)
        b.SetNoDataValue(dst_ndv)
        b.Fill(dst_ndv)

    #Optionally smooth the input before downsampling to prevent aliasing, fill gaps
    #Disabled: pretty inefficient, as it creates another intermediate dataset
    gauss = False
    if gauss:
        src_ds = _gauss_smooth_ds(src_ds, src_srs, src_gt, res/src_res, dst_ndv, verbose)
        #In theory, NN should be fine since we already smoothed.  In practice, cubic still provides slightly better results
        #gra = gdal.GRA_NearestNeighbour

    #Note: default maxerror=0.0, second 0.0 argument
    gdal.ReprojectImage(src_ds, dst_ds, src_srs.ExportToWkt(), t_srs.ExportToWkt(), gra, 0.0, 0.0, prog_func)

    #Note: flushing to disk for non-MEM drivers is done in diskwarp
    #Return GDAL dataset object
    return dst_ds


def _gauss_smooth_ds(src_ds, src_srs, src_gt, res_ratio, ndv, verbose=True):
    """Return an in-memory copy of src_ds with every band Gaussian smoothed.

    The filter window grows with res_ratio (output res / source res); if the
    window would be <= 1 pixel, src_ds is returned unchanged.
    """
    from pygeotools.lib import filtlib
    if verbose:
        print("Resampling factor: %0.3f" % res_ratio)
    #Might be more efficient to do iterative gauss filter with size 3, rather than larger windows
    f_size = int(math.floor(res_ratio/2.)*2+1)
    #This is conservative to avoid filling holes with noise
    #f_size = math.floor(res_ratio/2.)*2-1
    if f_size <= 1:
        return src_ds

    if verbose:
        print("Smoothing window size: %i" % f_size)
    #One temp dataset holding all filtered bands - avoid overwriting original
    src_dt = src_ds.GetRasterBand(1).DataType
    temp_ds = gdal.GetDriverByName('MEM').Create('', src_ds.RasterXSize, src_ds.RasterYSize, src_ds.RasterCount, src_dt)
    temp_ds.SetProjection(src_srs.ExportToWkt())
    temp_ds.SetGeoTransform(src_gt)
    for n in range(1, src_ds.RasterCount+1):
        temp_b = temp_ds.GetRasterBand(n)
        temp_b.SetNoDataValue(ndv)
        temp_b.Fill(ndv)
        src_a = iolib.b_getma(src_ds.GetRasterBand(n))
        #Want to run with maskfill, so only fills gaps, without expanding isolated points
        src_a = filtlib.gauss_fltr_astropy(src_a, size=f_size)
        #Write masked pixels as NoData rather than whatever data lies under the mask
        temp_b.WriteArray(np.ma.filled(src_a, ndv))
    return temp_ds