Esempio n. 1
0
def get_lulc_ds_warp(ds, lulc_source=None):
    """Return the land-use/land-cover (LULC) dataset warped onto the grid of ``ds``.

    The full LULC product is resampled with warplib.memwarp_multi to the
    resolution, extent and spatial reference of ``ds``.

    Parameters
    ----------
    ds : gdal.Dataset
        Dataset defining the target grid.
    lulc_source : str, optional
        Name of the LULC product. When omitted it is determined with
        get_lulc_source(ds).

    Returns
    -------
    The single in-memory dataset produced by warplib.memwarp_multi.
    """
    source = get_lulc_source(ds) if lulc_source is None else lulc_source
    # NLCD is resampled with nearest neighbour so its values are not
    # interpolated; any other source uses cubic spline resampling.
    resample_alg = 'near' if source == 'nlcd' else 'cubicspline'
    full_ds = get_lulc_ds_full(ds, source)
    warped_list = warplib.memwarp_multi([full_ds],
                                        res=ds,
                                        extent=ds,
                                        t_srs=ds,
                                        r=resample_alg)
    return warped_list[0]
Esempio n. 2
0
def main():
    """Display, or save to disk, one or more raster files using matplotlib.

    Options are parsed with getparser() into the ``args`` dict. Each file in
    args['filelist'] is drawn as follows:
      * Single-band rasters are drawn with bma_fig(). The colormap is
        args['cmap'] when given; otherwise it is chosen from the GDAL data type.
      * 3-band Byte rasters are shown as an image read with plt.imread().
      * Any other raster has each band drawn in its own subplot via bma_fig().

    With -link, all inputs share figure 0 and are warped to a common grid. That
    grid uses the SRS of the first input, res_stats[0] (commented as the
    minimum resolution) and the union extent.

    With -overlay, the background raster is loaded into args['bg'], which
    reaches bma_fig() through **args.

    With -of, each figure is saved as '<input basename>_fig.<of>' and
    plt.show() is not called.

    NOTE(review): this function uses Python 2 print statements.
    """
    parser = getparser()
    #Create dictionary of arguments
    args = vars(parser.parse_args())
    
    #Want to enable -full when -of is specified, probably a fancy way to do this with argparse
    if args['of']:
        args['full'] = True

    #Note, imshow has many interpolation types:
    #'none', 'nearest', 'bilinear', 'bicubic', 'spline16', 'spline36', 'hanning', 'hamming', 
    #'hermite', 'kaiser', 'quadric', 'catrom', 'gaussian', 'bessel', 'mitchell', 'sinc', 'lanczos'
    #{'interpolation':'bicubic', 'aspect':'auto'}
    #args['imshow_kwargs']={'interpolation':'bicubic'}
    #'none' requests no interpolation when imshow draws the raster
    args['imshow_kwargs']={'interpolation':'none'}

    #'clipped' means inputs and overlay are already on the same grid, so an overlay file is required
    if args['clipped'] and args['overlay'] is None:
        sys.exit("Must specify an overlay filename with option 'clipped'")

    #Set this as the background numpy array
    #It is filled in below when -overlay is used and forwarded to bma_fig via **args
    args['bg'] = None

    #The shapefile path is only printed here; bma_fig receives it through **args
    if args['shp'] is not None:
        print args['shp']

    #-link: all inputs become subplots of one figure, warped to a common grid
    if args['link']:
        fig = plt.figure(0)
        n_ax = len(args['filelist'])
        src_ds_list = [gdal.Open(fn) for fn in args['filelist']]
        #Common SRS is taken from the first input
        t_srs = geolib.get_ds_srs(src_ds_list[0])
        res_stats = geolib.get_res_stats(src_ds_list, t_srs=t_srs)
        #Use min res
        res = res_stats[0]
        extent = geolib.ds_geom_union_extent(src_ds_list, t_srs=t_srs)
        #print res, extent

    for n,fn in enumerate(args['filelist']):

        #Unreadable inputs are reported and skipped rather than aborting the run
        if not iolib.fn_check(fn):
            print 'Unable to open input file: %s' % fn
            continue

        #Note: this won't work if img1 has 1 band and img2 has 3 bands
        #Hack for now
        #Without -link, each input gets its own figure containing a single axis
        if not args['link']:
            fig = plt.figure(n)
            n_ax = 1
        
        #fig.set_facecolor('black')
        fig.set_facecolor('white')
        fig.canvas.set_window_title(os.path.split(fn)[1])
        #fig.suptitle(os.path.split(fn)[1], fontsize=10)

        #Note: warplib SHOULD internally check to see if extent/resolution/projection are identical
        #This eliminates the need for a clipped flag
        #If user has already warped the background and source data 
        #NOTE(review): with -overlay, the common -link grid computed above is not applied to src_ds
        if args['overlay']:
            if args['clipped']: 
                src_ds = gdal.Open(fn, gdal.GA_ReadOnly)
                #Only load up the bg array once
                if args['bg'] is None:
                    #Need to check that background fn exists
                    print "%s background" % args['overlay']
                    bg_ds = gdal.Open(args['overlay'], gdal.GA_ReadOnly)
                    #Check image dimensions
                    args['bg'] = get_bma(bg_ds, 1, args['full'])
            else:
                #Clip/warp background dataset to match overlay dataset 
                #src_ds, bg_ds = warplib.memwarp_multi_fn([fn, args['overlay']], extent='union')
                src_ds, bg_ds = warplib.memwarp_multi_fn([fn, args['overlay']], extent='first')
                #src_ds, bg_ds = warplib.memwarp_multi_fn([fn, args['overlay']], res='min', extent='first')
                #Want to load up the unique bg array for each input
                args['bg'] = get_bma(bg_ds, 1, args['full'])
        else:
            src_ds = gdal.Open(fn, gdal.GA_ReadOnly)
            if args['link']:
                #Not sure why, but this still warps all linked ds, even when identical res/extent/srs
                #src_ds = warplib.warp(src_ds, res=res, extent=extent, t_srs=t_srs)
                src_ds = warplib.memwarp_multi([src_ds,], res=res, extent=extent, t_srs=t_srs)[0]

        #Colorbar settings; stored in args['cbar_kwargs'] only in the single-band case
        cbar_kwargs={'extend':'both', 'orientation':'vertical', 'shrink':0.7, 'fraction':0.12, 'pad':0.02}

        nbands = src_ds.RasterCount
        b = src_ds.GetRasterBand(1)
        dt = gdal.GetDataTypeName(b.DataType)
        #Eventually, check dt of each band
        print 
        print "%s (%i bands)" % (fn, nbands)
        #Singleband raster
        #NOTE(review): args['cmap'] is written back into args, so the choice made for the
        #first single-band file is reused for every later file
        if (nbands == 1):
            if args['cmap'] is None:
                #Special case to handle ASP float32 grayscale data
                if '-L_sub' in fn or '-R_sub' in fn:
                    args['cmap'] = 'gray'
                else:
                    if (dt == 'Float64') or (dt == 'Float32') or (dt == 'Int32'):
                        args['cmap'] = 'cpt_rainbow'
                    #This is for WV images
                    elif (dt == 'UInt16'):
                        args['cmap'] = 'gray'
                    elif (dt == 'Byte'):
                        args['cmap'] = 'gray'
                    else:
                        args['cmap'] = 'cpt_rainbow'
                """
                if 'count' in fn:
                    args['clim_perc'] = (0,100)
                    cbar_kwargs['extend'] = 'neither'
                    args['cmap'] = 'cpt_rainbow'
                if 'mask' in fn:
                    args['clim'] = (0, 1)
                    #Could be (0, 255)
                    #args['clim_perc'] = (0,100)
                    #Want absolute clim of 0, then perc of 100
                    cbar_kwargs['extend'] = 'neither'
                    args['cmap'] = 'gray'
                """
            args['cbar_kwargs'] = cbar_kwargs
            bma = get_bma(src_ds, 1, args['full'])   
            #Note n+1 here ensures we're assigning subplot correctly here (n is 0-relative, subplot is 1)
            bma_fig(fig, bma, n_subplt=n_ax, subplt=n+1, ds=src_ds, **args)
        #3-band raster, likely disparity map
        #This doesn't work when alpha band is present
        #NOTE(review): this branch never assigns `bma`. With -of, the save code below reads
        #bma.shape, which raises NameError for the first file or reuses a previous file's array.
        #The 'origin' key set here also stays in args['imshow_kwargs'] for later files.
        elif (nbands == 3) and (dt == 'Byte'):
            #For some reason, tifs are vertically flipped
            if (os.path.splitext(fn)[1] == '.tif'):
                args['imshow_kwargs']['origin'] = 'lower'
            #Use gdal dataset here instead of imread(fn)?
            imgplot = plt.imshow(plt.imread(fn), **args['imshow_kwargs'])
            pltlib.hide_ticks(imgplot.axes)
        #Handle the 3-band disparity map case here
        #elif ((dt == 'Float32') or (dt == 'Int32')):
        #Any other band count/dtype: each band is drawn in its own subplot (1..nbands)
        else: 
            if args['cmap'] is None:
                args['cmap'] = 'cpt_rainbow'
            bn = 1
            while bn <= nbands:
                bma = get_bma(src_ds, bn, args['full'])
                bma_fig(fig, bma, n_subplt=nbands, subplt=bn, ds=src_ds, **args)
                bn += 1
        #Want to be better about this else case - lazy for now
        #else:
        #    bma = get_bma(src_ds, 1, args['full'])
        #    bma_fig(fig, bma, **args)

        #Timestamps derived from the filename; currently only printed
        ts = timelib.fn_getdatetime_list(fn) 

        if ts:
            print "Timestamp list: ", ts

        """
        if len(ts) == 1:
            plt.title(ts[0].date())
        elif len(ts) == 2:
            plt.title("%s to %s" % (ts[0].date(), ts[1].date()))
        """
            
        plt.tight_layout()
        
        #Write out the file 
        #Note: make sure display is local for savefig
        if args['of']:
            #Output path: input path without extension + '_fig.' + requested format
            outf = str(os.path.splitext(fn)[0])+'_fig.'+args['of'] 
            #outf = str(os.path.splitext(fn)[0])+'_'+str(os.path.splitext(args['overlay'])[0])+'_fig.'+args['of'] 

            #Note: need to account for colorbar (12%) and title - some percentage of axes beyond bma dimensions
            #Should specify minimum text size for output

            #Caps on the derived figure size (inches) and dpi
            max_size = np.array((10.0,10.0))
            max_dpi = 300.0
            #NOTE(review): args['dpi'] and args['outsize'] are written back into args, so the
            #values derived for the first saved file are reused for all later files
            #If both outsize and dpi are specified, don't try to change, just make the figure
            if (args['outsize'] is None) and (args['dpi'] is None):
                args['dpi'] = 150.0

            #Unspecified out figure size for a given dpi
            #Size = array dimensions (cols, rows) / dpi, replaced by max_size if any side exceeds it
            if (args['outsize'] is None) and (args['dpi'] is not None):
                args['outsize'] = np.array(bma.shape[::-1])/args['dpi']
                if np.any(np.array(args['outsize']) > max_size):
                    args['outsize'] = max_size
            #Specified output figure size, no specified dpi 
            #dpi = largest pixels-per-inch ratio across both axes, capped at max_dpi
            elif (args['outsize'] is not None) and (args['dpi'] is None):
                args['dpi'] = np.min([np.max(np.array(bma.shape[::-1])/np.array(args['outsize'])), max_dpi])
                
            print
            print "Saving output figure:"
            print "Filename: ", outf
            print "Size (in): ", args['outsize']
            print "DPI (px/in): ", args['dpi']
            print "Input dimensions (px): ", bma.shape[::-1]
            print "Output dimensions (px): ", tuple(np.array(args['outsize'])*args['dpi'])
            print

            fig.set_size_inches(args['outsize'])
            #fig.set_size_inches(54.427, 71.87)
            #fig.set_size_inches(40, 87)
            fig.savefig(outf, dpi=args['dpi'], bbox_inches='tight', pad_inches=0, facecolor=fig.get_facecolor(), edgecolor='none')
    #Show the plot - want to show all at once
    if not args['of']: 
        plt.show()
Esempio n. 3
0
def get_mask(dem_ds,
             mask_list,
             dem_fn=None,
             writeout=False,
             outdir=None,
             args=None):
    """Build a mask of pixels to exclude, on the grid of ``dem_ds``.

    Each requested source in ``mask_list`` contributes a boolean layer in which
    True marks pixels to keep. The layers are combined with logical AND, and
    the result is inverted so that True marks "invalid" pixels, matching the
    numpy.ma convention.

    Parameters
    ----------
    dem_ds : gdal.Dataset
        DEM whose resolution, extent and SRS define the output grid.
    mask_list : list of str
        Mask sources ('glaciers', 'nlcd', 'bareground', 'snodas', 'modscag',
        'toa'), normalised by check_mask_list(). An empty list, or a list
        containing 'none', disables masking.
    dem_fn : str, optional
        DEM filename. It is required for the DEM timestamp ('snodas',
        'modscag'), for locating the TOA product ('toa') and for naming output
        files. Features that need it are skipped with a warning when it is None.
    writeout : bool
        If True, write the intermediate products and masks as GeoTIFFs.
    outdir : str, optional
        Output directory, created if missing. Defaults to the directory of
        ``dem_fn``.
    args : argparse.Namespace, optional
        Filter thresholds. When None, defaults are taken from getparser().

    Returns
    -------
    False when masking is disabled. Otherwise a boolean array (or a numpy
    boolean scalar False when no requested layer was applied) in which True
    marks pixels to exclude.
    """
    mask_list = check_mask_list(mask_list)
    if not mask_list or 'none' in mask_list:
        return False

    #Output directory: create the requested one, or fall back to the DEM's directory
    if outdir is not None:
        if not os.path.exists(outdir):
            os.makedirs(outdir)
    elif dem_fn is not None:
        outdir = os.path.split(os.path.realpath(dem_fn))[0]

    #The DEM timestamp and the output basename both require a filename
    dem_dt = None
    out_fn_base = None
    if dem_fn is not None:
        dem_dt = timelib.fn_getdatetime(dem_fn)
        out_fn_base = os.path.join(
            outdir,
            os.path.splitext(os.path.split(dem_fn)[-1])[0])
    elif writeout:
        print("Warning: no DEM filename provided, not writing out mask products")
        writeout = False

    if args is None:
        #Get default values
        parser = getparser()
        args = parser.parse_args([
            '',
        ])

    #Start from "every pixel valid" and AND in each layer (True = keep)
    newmask = True

    if 'glaciers' in mask_list:
        icemask = get_icemask(dem_ds)
        if writeout:
            out_fn = out_fn_base + '_ice_mask.tif'
            print("Writing out %s" % out_fn)
            iolib.writeGTiff(icemask, out_fn, src_ds=dem_ds)
        newmask = np.logical_and(icemask, newmask)

    #NLCD is warped with nearest-neighbour resampling so its values are not interpolated
    if 'nlcd' in mask_list and args.nlcd_filter != 'none':
        nlcd_ds = gdal.Open(get_nlcd_fn())
        nlcd_ds_warp = warplib.memwarp_multi([nlcd_ds],
                                             res=dem_ds,
                                             extent=dem_ds,
                                             t_srs=dem_ds,
                                             r='near')[0]
        out_fn = None
        if writeout:
            out_fn = out_fn_base + '_nlcd.tif'
        nlcdmask = get_nlcd_mask(nlcd_ds_warp,
                                 filter=args.nlcd_filter,
                                 out_fn=out_fn)
        if writeout:
            out_fn = os.path.splitext(out_fn)[0] + '_mask.tif'
            print("Writing out %s" % out_fn)
            iolib.writeGTiff(nlcdmask, out_fn, src_ds=dem_ds)
        newmask = np.logical_and(nlcdmask, newmask)

    if 'bareground' in mask_list and args.bareground_thresh > 0:
        bareground_ds = gdal.Open(get_bareground_fn())
        bareground_ds_warp = warplib.memwarp_multi([bareground_ds],
                                                   res=dem_ds,
                                                   extent=dem_ds,
                                                   t_srs=dem_ds,
                                                   r='cubicspline')[0]
        out_fn = None
        if writeout:
            out_fn = out_fn_base + '_bareground.tif'
        baregroundmask = get_bareground_mask(
            bareground_ds_warp,
            bareground_thresh=args.bareground_thresh,
            out_fn=out_fn)
        if writeout:
            out_fn = os.path.splitext(out_fn)[0] + '_mask.tif'
            print("Writing out %s" % out_fn)
            iolib.writeGTiff(baregroundmask, out_fn, src_ds=dem_ds)
        newmask = np.logical_and(baregroundmask, newmask)

    if 'snodas' in mask_list and args.snodas_thresh > 0:
        #Get SNODAS snow depth products for DEM timestamp
        snodas_min_dt = datetime(2003, 9, 30)
        if dem_dt is None:
            print("Warning: DEM timestamp unavailable, skipping SNODAS filter")
        elif dem_dt >= snodas_min_dt:
            snodas_ds = get_snodas_ds(dem_dt)
            if snodas_ds is not None:
                snodas_ds_warp = warplib.memwarp_multi([snodas_ds],
                                                       res=dem_ds,
                                                       extent=dem_ds,
                                                       t_srs=dem_ds,
                                                       r='cubicspline')[0]
                #snow depth values are mm, convert to meters
                snodas_depth = iolib.ds_getma(snodas_ds_warp) / 1000.
                if snodas_depth.count() > 0:
                    print(
                        "Applying SNODAS snow depth filter (masking values >= %0.2f m)"
                        % args.snodas_thresh)
                    out_fn = None
                    if writeout:
                        out_fn = out_fn_base + '_snodas_depth.tif'
                        print("Writing out %s" % out_fn)
                        iolib.writeGTiff(snodas_depth,
                                         out_fn,
                                         src_ds=dem_ds)
                    #Mask depths >= threshold (as the message states); pixels without
                    #SNODAS data are also excluded because getmaskarray keeps them masked
                    snodas_mask = np.ma.masked_greater_equal(
                        snodas_depth, args.snodas_thresh)
                    snodas_mask = ~(np.ma.getmaskarray(snodas_mask))
                    if writeout:
                        out_fn = os.path.splitext(out_fn)[0] + '_mask.tif'
                        print("Writing out %s" % out_fn)
                        iolib.writeGTiff(snodas_mask,
                                         out_fn,
                                         src_ds=dem_ds)
                    newmask = np.logical_and(snodas_mask, newmask)
                else:
                    print(
                        "SNODAS grid for input location and timestamp is empty"
                    )

    #These tiles cover CONUS
    #tile_list=('h08v04', 'h09v04', 'h10v04', 'h08v05', 'h09v05')
    if 'modscag' in mask_list and args.modscag_thresh > 0:
        modscag_min_dt = datetime(2000, 2, 24)
        if dem_dt is None:
            print("Warning: DEM timestamp unavailable, skipping MODSCAG filter")
        elif dem_dt < modscag_min_dt:
            print("Warning: DEM timestamp (%s) is before earliest MODSCAG timestamp (%s)" \
                    % (dem_dt, modscag_min_dt))
        else:
            tile_list = get_modis_tile_list(dem_ds)
            print(tile_list)
            pad_days = 7
            modscag_fn_list = get_modscag_fn_list(dem_dt,
                                                  tile_list=tile_list,
                                                  pad_days=pad_days)
            if modscag_fn_list:
                modscag_ds = proc_modscag(modscag_fn_list,
                                          extent=dem_ds,
                                          t_srs=dem_ds)
                modscag_ds_warp = warplib.memwarp_multi([modscag_ds],
                                                        res=dem_ds,
                                                        extent=dem_ds,
                                                        t_srs=dem_ds,
                                                        r='cubicspline')[0]
                print(
                    "Applying MODSCAG fractional snow cover percent filter (masking values >= %0.1f%%)"
                    % args.modscag_thresh)
                modscag_fsca = iolib.ds_getma(modscag_ds_warp)
                out_fn = None
                if writeout:
                    out_fn = out_fn_base + '_modscag_fsca.tif'
                    print("Writing out %s" % out_fn)
                    iolib.writeGTiff(modscag_fsca, out_fn, src_ds=dem_ds)
                #Pixels without MODSCAG data are filled with 0 and therefore kept
                modscag_mask = ~(modscag_fsca.filled(0) >= args.modscag_thresh)
                if writeout:
                    out_fn = os.path.splitext(out_fn)[0] + '_mask.tif'
                    print("Writing out %s" % out_fn)
                    iolib.writeGTiff(modscag_mask, out_fn, src_ds=dem_ds)
                newmask = np.logical_and(modscag_mask, newmask)

    #Use top of atmosphere scaled reflectance values (0-1) to estimate snowcover
    if 'toa' in mask_list:
        if dem_fn is None:
            print("Warning: no DEM filename provided, skipping TOA reflectance filter")
        else:
            toa_ds = gdal.Open(get_toa_fn(dem_fn))
            toa_ds_warp = warplib.memwarp_multi([toa_ds],
                                                res=dem_ds,
                                                extent=dem_ds,
                                                t_srs=dem_ds)[0]
            toa_mask = get_toa_mask(toa_ds_warp, args.toa_thresh)
            if writeout:
                out_fn = out_fn_base + '_toa_mask.tif'
                print("Writing out %s" % out_fn)
                iolib.writeGTiff(toa_mask, out_fn, src_ds=dem_ds)
            newmask = np.logical_and(toa_mask, newmask)

    print(
        "Generating final mask to use for reference surfaces, and applying to input DEM"
    )
    #Invert so True (1) represents an "invalid" pixel, matching the numpy.ma convention.
    #np.logical_not is used instead of `~` because newmask is still the Python bool
    #True when no layer was applied, and ~True == -2, which numpy.ma would treat as
    #"mask everything".
    newmask = np.logical_not(newmask)

    #Dilation only makes sense for a 2-D mask, not the scalar "nothing masked" case
    if args.dilate is not None and np.ndim(newmask) > 0:
        niter = args.dilate
        print("Dilating mask with %i iterations" % niter)
        from scipy import ndimage
        #NOTE(review): the dilation is applied to the valid pixels (~newmask), so the
        #valid region grows and the excluded region shrinks. Confirm this is intended.
        newmask = ~(ndimage.binary_dilation(~newmask, iterations=niter))

    return newmask
Esempio n. 4
0
def compute_offset(ref_dem_ds, src_dem_ds, src_dem_fn, mode='nuth', remove_outliers=True, max_offset=100,
        max_dz=100, slope_lim=(0.1, 40), mask_list=None, plot=True):
    """Estimate the horizontal and vertical offset between two DEMs.

    Both DEMs are warped with warplib.memwarp_multi to a common grid:
    res='max', extent='intersection', SRS of the source DEM and cubic
    resampling. Their difference (src - ref) is then restricted to static
    surfaces using get_mask(), optional outlier filtering and a slope filter.
    Finally, a shift is estimated with the selected ``mode``.

    Parameters
    ----------
    ref_dem_ds, src_dem_ds : gdal.Dataset
        Reference and source DEMs.
    src_dem_fn : str
        Source DEM filename, passed to get_mask().
    mode : str
        One of:
          * 'nuth' - Nuth and Kaab (2011) slope/aspect fit.
          * 'sad' - sum of absolute differences.
          * 'ncc' - normalized cross-correlation.
          * 'all' - not implemented; the horizontal shift stays 0.
          * 'none' - no horizontal alignment; only the median vertical bias
            is returned.
    remove_outliers : bool
        Apply outlier_filter(f=3, max_dz=max_dz) to the difference map.
    max_offset : float
        Maximum expected horizontal offset (map units). It sets the sad/ncc
        search padding.
    max_dz : float
        Passed to outlier_filter.
    slope_lim : tuple
        Passed to get_filtered_slope.
    mask_list : list of str, optional
        Static-surface masks for get_mask(). Defaults to ['glaciers'].
    plot : bool
        Request a diagnostic figure from the ncc/nuth routines.

    Returns
    -------
    (dx, dy, dz, static_mask, fig)
        Shifts to apply to the source DEM. They are the negated measured
        src - ref offsets. ``static_mask`` is the final boolean mask (True =
        excluded), and ``fig`` is the diagnostic figure or None.
    """
    #Avoid a shared mutable default argument
    if mask_list is None:
        mask_list = ['glaciers']

    #Make sure the input datasets have the same resolution/extent
    #Use projection of source DEM
    ref_dem_clip_ds, src_dem_clip_ds = warplib.memwarp_multi([ref_dem_ds, src_dem_ds],
            res='max', extent='intersection', t_srs=src_dem_ds, r='cubic')

    #Compute size of NCC and SAD search window in pixels
    res = float(geolib.get_res(ref_dem_clip_ds, square=True)[0])
    max_offset_px = (max_offset/res) + 1
    pad = (int(max_offset_px), int(max_offset_px))

    #This will be updated geotransform for src_dem
    src_dem_gt = np.array(src_dem_clip_ds.GetGeoTransform())

    #Load the arrays
    ref_dem = iolib.ds_getma(ref_dem_clip_ds, 1)
    src_dem = iolib.ds_getma(src_dem_clip_ds, 1)

    print("Elevation difference stats for uncorrected input DEMs (src - ref)")
    diff = src_dem - ref_dem

    static_mask = get_mask(src_dem_clip_ds, mask_list, src_dem_fn)
    diff = np.ma.array(diff, mask=static_mask)

    if diff.count() == 0:
        sys.exit("No overlapping, unmasked pixels shared between input DEMs")

    if remove_outliers:
        diff = outlier_filter(diff, f=3, max_dz=max_dz)

    #Want to use higher quality DEM, should determine automatically from original res/count
    slope = get_filtered_slope(src_dem_clip_ds, slope_lim=slope_lim)

    print("Computing aspect")
    aspect = geolib.gdaldem_mem_ds(src_dem_clip_ds, processing='aspect', returnma=True, computeEdges=False)

    #Release the warped datasets
    ref_dem_clip_ds = None
    src_dem_clip_ds = None

    #Apply slope filter to diff; np.ma.array keeps the existing mask, so the
    #slope mask is combined with it
    diff = np.ma.array(diff, mask=np.ma.getmaskarray(slope))

    #Filtering may remove every remaining pixel
    if diff.count() == 0:
        sys.exit("No unmasked pixels remain after outlier and slope filtering")

    #Get final mask after filtering
    static_mask = np.ma.getmaskarray(diff)

    #Compute stats for new masked difference map
    print("Filtered difference map")
    diff_stats = malib.print_stats(diff)
    #Element 5 is reported as the median dz (see the 'nuth' print below)
    dz = diff_stats[5]

    print("Computing sub-pixel offset between DEMs using mode: %s" % mode)

    #By default, don't create output figure
    fig = None

    #Default horizontal shift is (0,0)
    dx = 0
    dy = 0

    #Sum of absolute differences
    if mode == "sad":
        ref_dem = np.ma.array(ref_dem, mask=static_mask)
        src_dem = np.ma.array(src_dem, mask=static_mask)
        m, int_offset, sp_offset = coreglib.compute_offset_sad(ref_dem, src_dem, pad=pad)
        #Geotransform has negative y resolution, so don't need negative sign
        #np array is positive down
        #GDAL coordinates are positive up
        dx = sp_offset[1]*src_dem_gt[1]
        dy = sp_offset[0]*src_dem_gt[5]
    #Normalized cross-correlation of clipped, overlapping areas
    elif mode == "ncc":
        ref_dem = np.ma.array(ref_dem, mask=static_mask)
        src_dem = np.ma.array(src_dem, mask=static_mask)
        m, int_offset, sp_offset, fig = coreglib.compute_offset_ncc(ref_dem, src_dem,
                pad=pad, prefilter=False, plot=plot)
        dx = sp_offset[1]*src_dem_gt[1]
        dy = sp_offset[0]*src_dem_gt[5]
    #Nuth and Kaab (2011)
    elif mode == "nuth":
        #Compute relationship between elevation difference, slope and aspect
        fit_param, fig = coreglib.compute_offset_nuth(diff, slope, aspect, plot=plot)
        if fit_param is None:
            print("Failed to calculate horizontal shift")
        else:
            #fit_param[0] is magnitude of shift vector
            #fit_param[1] is direction of shift vector
            #fit_param[2] is mean bias divided by tangent of mean slope
            dx = fit_param[0]*np.sin(np.deg2rad(fit_param[1]))
            dy = fit_param[0]*np.cos(np.deg2rad(fit_param[1]))
            med_slope = malib.fast_median(slope)
            nuth_dz = fit_param[2]*np.tan(np.deg2rad(med_slope))
            print('Median dz: %0.2f\nNuth dz: %0.2f' % (dz, nuth_dz))
    elif mode == "all":
        print("Not yet implemented")
        #Want to compare all methods, average offsets
    elif mode == "none":
        #Previously this branch referenced undefined names (outprefix, src_dem_orig)
        #and always crashed. Now it returns only the vertical bias so the caller
        #removes it; dx and dy stay 0.
        print("Skipping horizontal alignment, only removing median vertical bias over static surfaces")
    #Note: minus signs here since we are computing dz=(src-ref), but adjusting src
    return -dx, -dy, -dz, static_mask, fig
Esempio n. 5
0
def main(argv=None):
    parser = getparser()
    args = parser.parse_args()

    #Should check that files exist
    ref_dem_fn = args.ref_fn
    src_dem_fn = args.src_fn

    mode = args.mode
    mask_list = args.mask_list
    max_offset = args.max_offset
    max_dz = args.max_dz
    slope_lim = tuple(args.slope_lim)
    tiltcorr = args.tiltcorr
    polyorder = args.polyorder
    res = args.res

    #Maximum number of iterations
    max_iter = args.max_iter

    #These are tolerances (in meters) to stop iteration
    tol = args.tol
    min_dx = tol
    min_dy = tol
    min_dz = tol

    outdir = args.outdir
    if outdir is None:
        outdir = os.path.splitext(src_dem_fn)[0] + '_dem_align'

    if tiltcorr:
        outdir += '_tiltcorr'
        tiltcorr_done = False
        #Relax tolerance for initial round of co-registration
        #tiltcorr_tol = 0.1
        #if tol < tiltcorr_tol:
        #    tol = tiltcorr_tol

    if not os.path.exists(outdir):
        os.makedirs(outdir)

    outprefix = '%s_%s' % (os.path.splitext(os.path.split(src_dem_fn)[-1])[0], \
            os.path.splitext(os.path.split(ref_dem_fn)[-1])[0])
    outprefix = os.path.join(outdir, outprefix)

    print("\nReference: %s" % ref_dem_fn)
    print("Source: %s" % src_dem_fn)
    print("Mode: %s" % mode)
    print("Output: %s\n" % outprefix)

    src_dem_ds = gdal.Open(src_dem_fn)
    ref_dem_ds = gdal.Open(ref_dem_fn)

    #Get local cartesian coordinate system
    #local_srs = geolib.localtmerc_ds(src_dem_ds)
    #Use original source dataset coordinate system
    #Potentially issues with distortion and xyz/tiltcorr offsets for DEM with large extent
    local_srs = geolib.get_ds_srs(src_dem_ds)
    #local_srs = geolib.get_ds_srs(ref_dem_ds)

    #Resample to common grid
    ref_dem_res = float(geolib.get_res(ref_dem_ds, t_srs=local_srs, square=True)[0])
    #Create a copy to be updated in place
    src_dem_ds_align = iolib.mem_drv.CreateCopy('', src_dem_ds, 0)
    src_dem_res = float(geolib.get_res(src_dem_ds, t_srs=local_srs, square=True)[0])
    src_dem_ds = None
    #Resample to user-specified resolution
    ref_dem_ds, src_dem_ds_align = warplib.memwarp_multi([ref_dem_ds, src_dem_ds_align], \
            extent='intersection', res=args.res, t_srs=local_srs, r='cubic')

    res = float(geolib.get_res(src_dem_ds_align, square=True)[0])
    print("\nReference DEM res: %0.2f" % ref_dem_res)
    print("Source DEM res: %0.2f" % src_dem_res)
    print("Resolution for coreg: %s (%0.2f m)\n" % (args.res, res))

    #Iteration number
    n = 1
    #Cumulative offsets
    dx_total = 0
    dy_total = 0
    dz_total = 0

    #Now iteratively update geotransform and vertical shift
    while True:
        print("*** Iteration %i ***" % n)
        dx, dy, dz, static_mask, fig = compute_offset(ref_dem_ds, src_dem_ds_align, src_dem_fn, mode, max_offset, \
                mask_list=mask_list, max_dz=max_dz, slope_lim=slope_lim, plot=True)
        xyz_shift_str_iter = "dx=%+0.2fm, dy=%+0.2fm, dz=%+0.2fm" % (dx, dy, dz)
        print("Incremental offset: %s" % xyz_shift_str_iter)

        dx_total += dx
        dy_total += dy
        dz_total += dz

        xyz_shift_str_cum = "dx=%+0.2fm, dy=%+0.2fm, dz=%+0.2fm" % (dx_total, dy_total, dz_total)
        print("Cumulative offset: %s" % xyz_shift_str_cum)
        #String to append to output filenames
        xyz_shift_str_cum_fn = '_%s_x%+0.2f_y%+0.2f_z%+0.2f' % (mode, dx_total, dy_total, dz_total)

        #Should make an animation of this converging
        if n == 1: 
            #static_mask_orig = static_mask
            if fig is not None:
                dst_fn = outprefix + '_%s_iter%02i_plot.png' % (mode, n)
                print("Writing offset plot: %s" % dst_fn)
                fig.gca().set_title("Incremental: %s\nCumulative: %s" % (xyz_shift_str_iter, xyz_shift_str_cum))
                fig.savefig(dst_fn, dpi=300)

        #Apply the horizontal shift to the original dataset
        src_dem_ds_align = coreglib.apply_xy_shift(src_dem_ds_align, dx, dy, createcopy=False)
        #Should 
        src_dem_ds_align = coreglib.apply_z_shift(src_dem_ds_align, dz, createcopy=False)

        n += 1
        print("\n")
        #If magnitude of shift in all directions is less than tol
        #if n > max_iter or (abs(dx) <= min_dx and abs(dy) <= min_dy and abs(dz) <= min_dz):
        #If magnitude of shift is less than tol
        dm = np.sqrt(dx**2 + dy**2 + dz**2)
        dm_total = np.sqrt(dx_total**2 + dy_total**2 + dz_total**2)

        if dm_total > max_offset:
            sys.exit("Total offset exceeded specified max_offset (%0.2f m). Consider increasing -max_offset argument" % max_offset)

        #Stop iteration
        if n > max_iter or dm < tol:

            if fig is not None:
                dst_fn = outprefix + '_%s_iter%02i_plot.png' % (mode, n)
                print("Writing offset plot: %s" % dst_fn)
                fig.gca().set_title("Incremental:%s\nCumulative:%s" % (xyz_shift_str_iter, xyz_shift_str_cum))
                fig.savefig(dst_fn, dpi=300)

            #Compute final elevation difference
            if True:
                ref_dem_clip_ds_align, src_dem_clip_ds_align = warplib.memwarp_multi([ref_dem_ds, src_dem_ds_align], \
                        res=res, extent='intersection', t_srs=local_srs, r='cubic')
                ref_dem_align = iolib.ds_getma(ref_dem_clip_ds_align, 1)
                src_dem_align = iolib.ds_getma(src_dem_clip_ds_align, 1)
                ref_dem_clip_ds_align = None

                diff_align = src_dem_align - ref_dem_align
                src_dem_align = None
                ref_dem_align = None

                #Get updated, final mask
                static_mask_final = get_mask(src_dem_clip_ds_align, mask_list, src_dem_fn)
                static_mask_final = np.logical_or(np.ma.getmaskarray(diff_align), static_mask_final)
                
                #Final stats, before outlier removal
                diff_align_compressed = diff_align[~static_mask_final]
                diff_align_stats = malib.get_stats_dict(diff_align_compressed, full=True)

                #Prepare filtered version for tiltcorr fit
                diff_align_filt = np.ma.array(diff_align, mask=static_mask_final)
                diff_align_filt = outlier_filter(diff_align_filt, f=3, max_dz=max_dz)
                #diff_align_filt = outlier_filter(diff_align_filt, perc=(12.5, 87.5), max_dz=max_dz)
                slope = get_filtered_slope(src_dem_clip_ds_align)
                diff_align_filt = np.ma.array(diff_align_filt, mask=np.ma.getmaskarray(slope))
                diff_align_filt_stats = malib.get_stats_dict(diff_align_filt, full=True)

            #Fit 2D polynomial to residuals and remove
            #To do: add support for along-track and cross-track artifacts
            if tiltcorr and not tiltcorr_done:
                print("\n************")
                print("Calculating 'tiltcorr' 2D polynomial fit to residuals with order %i" % polyorder)
                print("************\n")
                gt = src_dem_clip_ds_align.GetGeoTransform()

                #Need to apply the mask here, so we're only fitting over static surfaces
                #Note that the origmask=False will compute vals for all x and y indices, which is what we want
                vals, resid, coeff = geolib.ma_fitpoly(diff_align_filt, order=polyorder, gt=gt, perc=(0,100), origmask=False)
                #vals, resid, coeff = geolib.ma_fitplane(diff_align_filt, gt, perc=(12.5, 87.5), origmask=False)

                #Should write out coeff or grid with correction 

                vals_stats = malib.get_stats_dict(vals)

                #Want to have max_tilt check here
                #max_tilt = 4.0 #m
                #Should do percentage
                #vals.ptp() > max_tilt

                #Note: dimensions of ds and vals will be different as vals are computed for clipped intersection
                #Need to recompute planar offset for full src_dem_ds_align extent and apply
                xgrid, ygrid = geolib.get_xy_grids(src_dem_ds_align)
                valgrid = geolib.polyval2d(xgrid, ygrid, coeff) 
                #For results of ma_fitplane
                #valgrid = coeff[0]*xgrid + coeff[1]*ygrid + coeff[2]
                src_dem_ds_align = coreglib.apply_z_shift(src_dem_ds_align, -valgrid, createcopy=False)

                if True:
                    print("Creating plot of polynomial fit to residuals")
                    fig, axa = plt.subplots(1,2, figsize=(8, 4))
                    dz_clim = malib.calcperc_sym(vals, (2, 98))
                    ax = pltlib.iv(diff_align_filt, ax=axa[0], cmap='RdBu', clim=dz_clim, \
                            label='Residual dz (m)', scalebar=False)
                    ax = pltlib.iv(valgrid, ax=axa[1], cmap='RdBu', clim=dz_clim, \
                            label='Polyfit dz (m)', ds=src_dem_ds_align)
                    #if tiltcorr:
                        #xyz_shift_str_cum_fn += "_tiltcorr"
                    tiltcorr_fig_fn = outprefix + '%s_polyfit.png' % xyz_shift_str_cum_fn
                    print("Writing out figure: %s\n" % tiltcorr_fig_fn)
                    fig.savefig(tiltcorr_fig_fn, dpi=300)

                print("Applying tilt correction to difference map")
                diff_align -= vals

                #Should iterate until tilts are below some threshold
                #For now, only do one tiltcorr
                tiltcorr_done=True
                #Now use original tolerance, and number of iterations 
                tol = args.tol
                max_iter = n + args.max_iter
            else:
                break

    if True:
        #Write out aligned difference map for clipped extent with vertial offset removed
        align_diff_fn = outprefix + '%s_align_diff.tif' % xyz_shift_str_cum_fn
        print("Writing out aligned difference map with median vertical offset removed")
        iolib.writeGTiff(diff_align, align_diff_fn, src_dem_clip_ds_align)

    if True:
        #Write out fitered aligned difference map
        align_diff_filt_fn = outprefix + '%s_align_diff_filt.tif' % xyz_shift_str_cum_fn
        print("Writing out filtered aligned difference map with median vertical offset removed")
        iolib.writeGTiff(diff_align_filt, align_diff_filt_fn, src_dem_clip_ds_align)

    #Extract final center coordinates for intersection
    center_coord_ll = geolib.get_center(src_dem_clip_ds_align, t_srs=geolib.wgs_srs)
    center_coord_xy = geolib.get_center(src_dem_clip_ds_align)
    src_dem_clip_ds_align = None

    #Write out final aligned src_dem 
    align_fn = outprefix + '%s_align.tif' % xyz_shift_str_cum_fn
    print("Writing out shifted src_dem with median vertical offset removed: %s" % align_fn)
    #Open original uncorrected dataset at native resolution
    src_dem_ds = gdal.Open(src_dem_fn)
    src_dem_ds_align = iolib.mem_drv.CreateCopy('', src_dem_ds, 0)
    #Apply final horizontal and vertial shift to the original dataset
    #Note: potentially issues if we used a different projection during coregistration!
    src_dem_ds_align = coreglib.apply_xy_shift(src_dem_ds_align, dx_total, dy_total, createcopy=False)
    src_dem_ds_align = coreglib.apply_z_shift(src_dem_ds_align, dz_total, createcopy=False)
    if tiltcorr:
        xgrid, ygrid = geolib.get_xy_grids(src_dem_ds_align)
        valgrid = geolib.polyval2d(xgrid, ygrid, coeff) 
        #For results of ma_fitplane
        #valgrid = coeff[0]*xgrid + coeff[1]*ygrid + coeff[2]
        src_dem_ds_align = coreglib.apply_z_shift(src_dem_ds_align, -valgrid, createcopy=False)
    #Might be cleaner way to write out MEM ds directly to disk
    src_dem_full_align = iolib.ds_getma(src_dem_ds_align)
    iolib.writeGTiff(src_dem_full_align, align_fn, src_dem_ds_align)

    if True:
        #Output final aligned src_dem, masked so only best pixels are preserved
        #Useful if creating a new reference product
        #Can also use apply_mask.py 
        print("Applying filter to shiftec src_dem")
        align_diff_filt_full_ds = warplib.memwarp_multi_fn([align_diff_filt_fn,], res=src_dem_ds_align, extent=src_dem_ds_align, \
                t_srs=src_dem_ds_align)[0]
        align_diff_filt_full = iolib.ds_getma(align_diff_filt_full_ds)
        align_diff_filt_full_ds = None
        align_fn_masked = outprefix + '%s_align_filt.tif' % xyz_shift_str_cum_fn
        iolib.writeGTiff(np.ma.array(src_dem_full_align, mask=np.ma.getmaskarray(align_diff_filt_full)), \
                align_fn_masked, src_dem_ds_align)

    src_dem_full_align = None
    src_dem_ds_align = None

    #Compute original elevation difference
    if True:
        ref_dem_clip_ds, src_dem_clip_ds = warplib.memwarp_multi([ref_dem_ds, src_dem_ds], \
                res=res, extent='intersection', t_srs=local_srs, r='cubic')
        src_dem_ds = None
        ref_dem_ds = None
        ref_dem_orig = iolib.ds_getma(ref_dem_clip_ds)
        src_dem_orig = iolib.ds_getma(src_dem_clip_ds)
        #Needed for plotting
        ref_dem_hs = geolib.gdaldem_mem_ds(ref_dem_clip_ds, processing='hillshade', returnma=True, computeEdges=True)
        src_dem_hs = geolib.gdaldem_mem_ds(src_dem_clip_ds, processing='hillshade', returnma=True, computeEdges=True)
        diff_orig = src_dem_orig - ref_dem_orig
        #Only compute stats over valid surfaces
        static_mask_orig = get_mask(src_dem_clip_ds, mask_list, src_dem_fn)
        #Note: this doesn't include outlier removal or slope mask!
        static_mask_orig = np.logical_or(np.ma.getmaskarray(diff_orig), static_mask_orig)
        #For some reason, ASTER DEM diff have a spike near the 0 bin, could be an issue with masking?
        diff_orig_compressed = diff_orig[~static_mask_orig]
        diff_orig_stats = malib.get_stats_dict(diff_orig_compressed, full=True)

        #Prepare filtered version for comparison 
        diff_orig_filt = np.ma.array(diff_orig, mask=static_mask_orig)
        diff_orig_filt = outlier_filter(diff_orig_filt, f=3, max_dz=max_dz)
        #diff_orig_filt = outlier_filter(diff_orig_filt, perc=(12.5, 87.5), max_dz=max_dz)
        slope = get_filtered_slope(src_dem_clip_ds)
        diff_orig_filt = np.ma.array(diff_orig_filt, mask=np.ma.getmaskarray(slope))
        diff_orig_filt_stats = malib.get_stats_dict(diff_orig_filt, full=True)

        #Write out original difference map
        print("Writing out original difference map for common intersection before alignment")
        orig_diff_fn = outprefix + '_orig_diff.tif'
        iolib.writeGTiff(diff_orig, orig_diff_fn, ref_dem_clip_ds)
        src_dem_clip_ds = None
        ref_dem_clip_ds = None

    if True:
        align_stats_fn = outprefix + '%s_align_stats.json' % xyz_shift_str_cum_fn
        align_stats = {}
        align_stats['src_fn'] = src_dem_fn 
        align_stats['ref_fn'] = ref_dem_fn 
        align_stats['align_fn'] = align_fn 
        align_stats['res'] = {} 
        align_stats['res']['src'] = src_dem_res
        align_stats['res']['ref'] = ref_dem_res
        align_stats['res']['coreg'] = res
        align_stats['center_coord'] = {'lon':center_coord_ll[0], 'lat':center_coord_ll[1], \
                'x':center_coord_xy[0], 'y':center_coord_xy[1]}
        align_stats['shift'] = {'dx':dx_total, 'dy':dy_total, 'dz':dz_total, 'dm':dm_total}
        #This tiltcorr flag gets set to false, need better flag
        if tiltcorr:
            align_stats['tiltcorr'] = {}
            align_stats['tiltcorr']['coeff'] = coeff.tolist()
            align_stats['tiltcorr']['val_stats'] = vals_stats
        align_stats['before'] = diff_orig_stats
        align_stats['before_filt'] = diff_orig_filt_stats
        align_stats['after'] = diff_align_stats
        align_stats['after_filt'] = diff_align_filt_stats
        
        import json
        with open(align_stats_fn, 'w') as f:
            json.dump(align_stats, f)

    #Create output plot
    if True:
        print("Creating final plot")
        kwargs = {'interpolation':'none'}
        #f, axa = plt.subplots(2, 4, figsize=(11, 8.5))
        f, axa = plt.subplots(2, 4, figsize=(16, 8))
        for ax in axa.ravel()[:-1]:
            ax.set_facecolor('k')
            pltlib.hide_ticks(ax)
        dem_clim = malib.calcperc(ref_dem_orig, (2,98))
        axa[0,0].imshow(ref_dem_hs, cmap='gray', **kwargs)
        im = axa[0,0].imshow(ref_dem_orig, cmap='cpt_rainbow', clim=dem_clim, alpha=0.6, **kwargs)
        pltlib.add_cbar(axa[0,0], im, arr=ref_dem_orig, clim=dem_clim, label=None)
        pltlib.add_scalebar(axa[0,0], res=res)
        axa[0,0].set_title('Reference DEM')
        axa[0,1].imshow(src_dem_hs, cmap='gray', **kwargs)
        im = axa[0,1].imshow(src_dem_orig, cmap='cpt_rainbow', clim=dem_clim, alpha=0.6, **kwargs)
        pltlib.add_cbar(axa[0,1], im, arr=src_dem_orig, clim=dem_clim, label=None)
        axa[0,1].set_title('Source DEM')
        #axa[0,2].imshow(~static_mask_orig, clim=(0,1), cmap='gray')
        axa[0,2].imshow(~static_mask, clim=(0,1), cmap='gray', **kwargs)
        axa[0,2].set_title('Surfaces for co-registration')
        dz_clim = malib.calcperc_sym(diff_orig_compressed, (5, 95))
        im = axa[1,0].imshow(diff_orig, cmap='RdBu', clim=dz_clim)
        pltlib.add_cbar(axa[1,0], im, arr=diff_orig, clim=dz_clim, label=None)
        axa[1,0].set_title('Elev. Diff. Before (m)')
        im = axa[1,1].imshow(diff_align, cmap='RdBu', clim=dz_clim)
        pltlib.add_cbar(axa[1,1], im, arr=diff_align, clim=dz_clim, label=None)
        axa[1,1].set_title('Elev. Diff. After (m)')

        #tight_dz_clim = (-1.0, 1.0)
        tight_dz_clim = (-2.0, 2.0)
        #tight_dz_clim = (-10.0, 10.0)
        #tight_dz_clim = malib.calcperc_sym(diff_align_filt, (5, 95))
        im = axa[1,2].imshow(diff_align_filt, cmap='RdBu', clim=tight_dz_clim)
        pltlib.add_cbar(axa[1,2], im, arr=diff_align_filt, clim=tight_dz_clim, label=None)
        axa[1,2].set_title('Elev. Diff. After (m)')

        #Tried to insert Nuth fig here
        #ax_nuth.change_geometry(1,2,1)
        #f.axes.append(ax_nuth)

        bins = np.linspace(dz_clim[0], dz_clim[1], 128)
        axa[1,3].hist(diff_orig_compressed, bins, color='g', label='Before', alpha=0.5)
        axa[1,3].hist(diff_align_compressed, bins, color='b', label='After', alpha=0.5)
        axa[1,3].set_xlim(*dz_clim)
        axa[1,3].axvline(0, color='k', linewidth=0.5, linestyle=':')
        axa[1,3].set_xlabel('Elev. Diff. (m)')
        axa[1,3].set_ylabel('Count (px)')
        axa[1,3].set_title("Source - Reference")
        before_str = 'Before\nmed: %0.2f\nnmad: %0.2f' % (diff_orig_stats['med'], diff_orig_stats['nmad'])
        axa[1,3].text(0.05, 0.95, before_str, va='top', color='g', transform=axa[1,3].transAxes, fontsize=8)
        after_str = 'After\nmed: %0.2f\nnmad: %0.2f' % (diff_align_stats['med'], diff_align_stats['nmad'])
        axa[1,3].text(0.65, 0.95, after_str, va='top', color='b', transform=axa[1,3].transAxes, fontsize=8)

        #This is empty
        axa[0,3].axis('off')

        suptitle = '%s\nx: %+0.2fm, y: %+0.2fm, z: %+0.2fm' % (os.path.split(outprefix)[-1], dx_total, dy_total, dz_total)
        f.suptitle(suptitle)
        f.tight_layout()
        plt.subplots_adjust(top=0.90)

        fig_fn = outprefix + '%s_align.png' % xyz_shift_str_cum_fn
        print("Writing out figure: %s" % fig_fn)
        f.savefig(fig_fn, dpi=300)
Esempio n. 6
0
def _default_cmap(fn, dt):
    """Return the default colormap name for a single-band raster.

    fn: input filename; dt: GDAL data type name of band 1 (e.g. 'Float32').
    ASP '-L_sub'/'-R_sub' images and UInt16/Byte imagery are shown in
    grayscale; all other types (e.g. Float64/Float32/Int32) use 'cpt_rainbow'.
    """
    #Special case to handle ASP float32 grayscale data
    if '-L_sub' in fn or '-R_sub' in fn:
        return 'gray'
    #UInt16 is for WV images
    if dt in ('UInt16', 'Byte'):
        return 'gray'
    return 'cpt_rainbow'


def _output_figure_size(dims, outsize=None, dpi=None, max_size=(10.0, 10.0), max_dpi=300.0):
    """Return (outsize, dpi) used to save a figure of an image.

    dims: (ncols, nrows) of the displayed array in pixels.
    outsize: requested figure size in inches, or None.
    dpi: requested dots per inch, or None.
    If both are None, dpi defaults to 150 and the size is derived from dims,
    capped at max_size. If only outsize is given, dpi is chosen so the image
    is rendered near native resolution, capped at max_dpi. If both are given,
    they are returned unchanged.
    """
    max_size = np.array(max_size)
    if outsize is None and dpi is None:
        dpi = 150.0
    if outsize is None:
        #Unspecified out figure size for a given dpi
        outsize = np.array(dims) / dpi
        if np.any(np.array(outsize) > max_size):
            outsize = max_size
    elif dpi is None:
        #Specified output figure size, no specified dpi
        dpi = np.min([np.max(np.array(dims) / np.array(outsize)), max_dpi])
    return outsize, dpi


def main():
    """Display (or save) one or more rasters with matplotlib.

    Options come from getparser(). Each input file gets its own figure, or,
    with 'link', a subplot in one shared figure. Without an overlay, linked
    inputs are warped to the minimum resolution, the union extent and the SRS
    of the first file. When 'of' (output format) is set, '<input>_fig.<of>'
    is written for each input and no interactive window is shown.
    """
    parser = getparser()
    #Create dictionary of arguments
    args = vars(parser.parse_args())

    #Want to enable -full when -of is specified, probably a fancy way to do this with argparse
    if args['of']:
        args['full'] = True

    #User-supplied values; per-file settings are reset from these each
    #iteration so that one file's derived settings do not leak into the next
    user_cmap = args['cmap']
    user_outsize = args['outsize']
    user_dpi = args['dpi']

    #Need to implement better extent handling for link and overlay
    #Can use warplib extent parsing
    extent = 'first'

    #Should accept 'ts' or 'fn' or string here, default is 'ts'
    #Can also accept list for subplots
    title = args['title']

    if args['link']:
        fig = plt.figure(0)
        n_ax = len(args['filelist'])
        src_ds_list = [gdal.Open(fn) for fn in args['filelist']]
        t_srs = geolib.get_ds_srs(src_ds_list[0])
        res_stats = geolib.get_res_stats(src_ds_list, t_srs=t_srs)
        #Use min res
        res = res_stats[0]
        #Linked views use the union of all input extents
        #(geolib.ds_geom_intersection_extent is the alternative)
        extent = geolib.ds_geom_union_extent(src_ds_list, t_srs=t_srs)

    for n, fn in enumerate(args['filelist']):
        if not iolib.fn_check(fn):
            print('Unable to open input file: %s' % fn)
            continue

        #Copy the shared module-level defaults so the RGB 'origin' hack below
        #cannot modify pltlib.imshow_kwargs or affect later files
        args['imshow_kwargs'] = dict(pltlib.imshow_kwargs)
        args['cmap'] = user_cmap

        if title == 'ts':
            ts = timelib.fn_getdatetime_list(fn)

            if ts:
                print("Timestamp list: ", ts)
                if len(ts) == 1:
                    args['title'] = ts[0].date()
                elif len(ts) > 1:
                    args['title'] = "%s to %s" % (ts[0].date(), ts[1].date())
            else:
                print("Unable to extract timestamp")
                args['title'] = None
        elif title == 'fn':
            args['title'] = fn

        #Note: this won't work if img1 has 1 band and img2 has 3 bands
        #Hack for now
        if not args['link']:
            fig = plt.figure(n)
            n_ax = 1

        fig.set_facecolor('white')
        #FigureCanvasBase.set_window_title was removed from matplotlib;
        #the figure manager owns the window (may be absent for some backends)
        manager = getattr(fig.canvas, 'manager', None)
        if manager is not None:
            manager.set_window_title(os.path.split(fn)[1])

        if args['overlay']:
            #Should automatically search for shaded relief with same base fn
            #Clip/warp background dataset to match overlay dataset
            src_ds, bg_ds = warplib.memwarp_multi_fn([fn, args['overlay']],
                                                     extent=extent,
                                                     res='max')
            #Want to load up the unique bg array for each input
            args['bg'] = get_bma(bg_ds, 1, args['full'])
        else:
            src_ds = gdal.Open(fn, gdal.GA_ReadOnly)
            if args['link']:
                src_ds = warplib.memwarp_multi([src_ds],
                                               res=res,
                                               extent=extent,
                                               t_srs=t_srs)[0]

        args['cbar_kwargs'] = pltlib.cbar_kwargs
        if args['no_cbar']:
            args['cbar_kwargs'] = None

        nbands = src_ds.RasterCount
        #Eventually, check dt of each band
        dt = gdal.GetDataTypeName(src_ds.GetRasterBand(1).DataType)
        print("%s (%i bands)" % (fn, nbands))

        #(ncols, nrows) of the displayed array, used to size the output figure
        img_dims = None
        if nbands == 1:
            #Singleband raster
            #TODO: special-case 'count' and 'mask' products (clim, cmap, cbar extend)
            if args['cmap'] is None:
                args['cmap'] = _default_cmap(fn, dt)
            bma = get_bma(src_ds, 1, args['full'])
            if args['invert']:
                bma *= -1
            #Note n+1 here ensures we're assigning subplot correctly here (n is 0-relative, subplot is 1)
            bma_fig(fig, bma, n_subplt=n_ax, subplt=n + 1, ds=src_ds, **args)
            img_dims = bma.shape[::-1]
        elif nbands == 3 and dt == 'Byte':
            #3-band raster, likely an RGB image
            #This doesn't work when alpha band is present
            #For some reason, tifs are vertically flipped
            if os.path.splitext(fn)[1] == '.tif':
                args['imshow_kwargs']['origin'] = 'lower'
            #Use gdal dataset here instead of imread(fn)?
            img = plt.imread(fn)
            imgplot = plt.imshow(img, **args['imshow_kwargs'])
            pltlib.hide_ticks(imgplot.axes)
            #imread returns (rows, cols, channels); keep (cols, rows)
            img_dims = img.shape[1::-1]
        else:
            #Multi-band (e.g. disparity map): one subplot per band
            if args['cmap'] is None:
                args['cmap'] = 'cpt_rainbow'
            for bn in range(1, nbands + 1):
                bma = get_bma(src_ds, bn, args['full'])
                bma_fig(fig,
                        bma,
                        n_subplt=nbands,
                        subplt=bn,
                        ds=src_ds,
                        **args)
                img_dims = bma.shape[::-1]

        plt.tight_layout()

        #Write out the file
        #Note: make sure display is local for savefig
        if args['of']:
            outf = str(os.path.splitext(fn)[0]) + '_fig.' + args['of']

            #Note: need to account for colorbar (12%) and title - some percentage of axes beyond bma dimensions
            #Should specify minimum text size for output
            outsize, dpi = _output_figure_size(img_dims, user_outsize, user_dpi)

            print()
            print("Saving output figure:")
            print("Filename: ", outf)
            print("Size (in): ", outsize)
            print("DPI (px/in): ", dpi)
            print("Input dimensions (px): ", img_dims)
            print("Output dimensions (px): ", tuple(np.array(outsize) * dpi))
            print()

            fig.set_size_inches(outsize)
            fig.savefig(outf,
                        dpi=dpi,
                        bbox_inches='tight',
                        pad_inches=0,
                        facecolor=fig.get_facecolor(),
                        edgecolor='none')
    #Show the plot - want to show all at once
    if not args['of']:
        plt.show()
Esempio n. 7
0
def compute_offset(dem1_ds,
                   dem2_ds,
                   dem2_fn,
                   mode='nuth',
                   max_offset_m=100,
                   remove_outliers=True,
                   apply_mask=True,
                   outprefix=None):
    """Estimate the shift that aligns a source DEM (dem2) to a reference DEM (dem1).

    Both datasets are warped to their common intersection at the coarser
    ('max') resolution in the projection of dem2. The elevation difference is
    computed as dz = dem2 - dem1.

    Parameters
    ----------
    dem1_ds : GDAL dataset
        Reference DEM.
    dem2_ds : GDAL dataset
        Source DEM to be aligned.
    dem2_fn : str
        Source DEM filename; passed to get_mask.
    mode : str
        'nuth' (Nuth and Kaab 2011), 'ncc' (normalized cross-correlation),
        'sad' (sum of absolute differences), or 'none' (write dem2 with the
        median bias removed and exit). 'all' is not implemented.
    max_offset_m : float
        Maximum expected horizontal offset; sets the search padding for 'sad'/'ncc'.
    remove_outliers : bool
        Mask differences outside median +/- 3*NMAD (used by the 'nuth' fit only).
    apply_mask : bool
        Restrict the comparison to pixels not masked by get_mask.
    outprefix : str, optional
        Output path prefix for mode 'none'; defaults to dem2_fn without extension.

    Returns
    -------
    (-dx, -dy, -dz, static_mask, fig)
        Shifts to apply to dem2 (map units of dem2's projection, presumably
        meters), the static-surface mask (None if apply_mask is False), and an
        optional matplotlib figure from the offset method.
    """
    #Fail fast, before the expensive warps, on modes that cannot produce dx/dy
    if mode == "all":
        #Want to compare all methods, average offsets
        sys.exit("Mode 'all' is not yet implemented")
    if mode not in ("sad", "ncc", "nuth", "none"):
        sys.exit("Unrecognized alignment mode: %s" % mode)

    #Make sure the input datasets have the same resolution/extent
    #Use projection of source DEM
    dem1_clip_ds, dem2_clip_ds = warplib.memwarp_multi([dem1_ds, dem2_ds], \
            res='max', extent='intersection', t_srs=dem2_ds)

    #Compute size of NCC and SAD search window in pixels
    res = float(geolib.get_res(dem1_clip_ds, square=True)[0])
    max_offset_px = (max_offset_m / res) + 1
    pad = (int(max_offset_px), int(max_offset_px))

    #This will be updated geotransform for dem2
    dem2_gt = np.array(dem2_clip_ds.GetGeoTransform())

    #Load the arrays
    dem1 = iolib.ds_getma(dem1_clip_ds, 1)
    dem2 = iolib.ds_getma(dem2_clip_ds, 1)

    #Compute difference for unaligned inputs
    print("Elevation difference stats for uncorrected input DEMs")
    #Shouldn't need to worry about common mask here, as both inputs are ma
    diff_euler = dem2 - dem1

    static_mask = None
    if apply_mask:
        #Need dem2_fn here to find TOA fn
        static_mask = get_mask(dem2_clip_ds, dem2_fn)
        dem1 = np.ma.array(dem1, mask=static_mask)
        dem2 = np.ma.array(dem2, mask=static_mask)
        diff_euler = np.ma.array(diff_euler, mask=static_mask)
        static_mask = np.ma.getmaskarray(diff_euler)

    if diff_euler.count() == 0:
        sys.exit("No overlapping, unmasked pixels shared between input DEMs")

    #Compute stats for new masked difference map
    diff_stats = malib.print_stats(diff_euler)
    #Median elevation difference (dem2 - dem1)
    dz = diff_stats[5]

    #This needs further testing
    if remove_outliers:
        nmad = diff_stats[6]
        f = 3
        rmin = dz - f * nmad
        rmax = dz + f * nmad
        diff_euler = np.ma.masked_outside(diff_euler, rmin, rmax)
        #Should also apply to original dem1 and dem2 for sad and ncc

    print("Computing sub-pixel offset between DEMs using mode: %s" % mode)

    #By default, don't create output figure
    fig = None

    if mode == "sad":
        #Sum of absolute differences
        _, _, sp_offset = coreglib.compute_offset_sad(dem1, dem2, pad=pad)
        #Geotransform has negative y resolution, so don't need negative sign
        #np array is positive down
        #GDAL coordinates are positive up
        dx = sp_offset[1] * dem2_gt[1]
        dy = sp_offset[0] * dem2_gt[5]
    elif mode == "ncc":
        #Normalized cross-correlation of clipped, overlapping areas
        _, _, sp_offset, fig = coreglib.compute_offset_ncc(dem1, dem2, \
                pad=pad, prefilter=False, plot=True)
        dx = sp_offset[1] * dem2_gt[1]
        dy = sp_offset[0] * dem2_gt[5]
    elif mode == "nuth":
        #Nuth and Kaab (2011)
        print("Computing slope and aspect")
        dem1_slope = geolib.gdaldem_mem_ds(dem1_clip_ds,
                                           processing='slope',
                                           returnma=True)
        dem1_aspect = geolib.gdaldem_mem_ds(dem1_clip_ds,
                                            processing='aspect',
                                            returnma=True)
        #Compute relationship between elevation difference, slope and aspect
        fit_param, fig = coreglib.compute_offset_nuth(diff_euler, dem1_slope,
                                                      dem1_aspect)
        #fit_param[0] is magnitude of shift vector
        #fit_param[1] is direction of shift vector
        #fit_param[2] is mean bias divided by tangent of mean slope
        dx = fit_param[0] * np.sin(np.deg2rad(fit_param[1]))
        dy = fit_param[0] * np.cos(np.deg2rad(fit_param[1]))
    else:
        #mode == "none": a hack to apply only the median bias computed over
        #static surfaces, then stop
        print(
            "Skipping alignment, writing out DEM with median bias over static surfaces removed"
        )
        if outprefix is None:
            outprefix = os.path.splitext(dem2_fn)[0]
        dst_fn = outprefix + '_med%0.1f.tif' % dz
        #dz = median(dem2 - dem1), so subtracting it removes the bias from dem2
        iolib.writeGTiff(iolib.ds_getma(dem2_ds) - dz, dst_fn, dem2_ds)
        sys.exit()
    #Note: minus signs here since we are computing dz=(src-ref), but adjusting src
    return -dx, -dy, -dz, static_mask, fig
Esempio n. 8
0
def main2(args):
    #Should check that files exist
    dem1_fn = args.ref_fn
    dem2_fn = args.src_fn
    mode = args.mode
    apply_mask = not args.nomask
    max_offset_m = args.max_offset
    tiltcorr = args.tiltcorr

    #These are tolerances (in meters) to stop iteration
    tol = args.tol
    min_dx = tol
    min_dy = tol
    min_dz = tol

    #Maximum number of iterations
    max_n = 10

    outdir = args.outdir
    if outdir is None:
        outdir = os.path.splitext(dem2_fn)[0] + '_dem_align'

    if not os.path.exists(outdir):
        os.makedirs(outdir)

    outprefix = '%s_%s' % (os.path.splitext(os.path.split(dem2_fn)[-1])[0], \
            os.path.splitext(os.path.split(dem1_fn)[-1])[0])
    outprefix = os.path.join(outdir, outprefix)

    print("\nReference: %s" % dem1_fn)
    print("Source: %s" % dem2_fn)
    print("Mode: %s" % mode)
    print("Output: %s\n" % outprefix)

    dem2_ds = gdal.Open(dem2_fn, gdal.GA_ReadOnly)
    #Often the "ref" DEM is high-res lidar or similar
    #This is a shortcut to resample to match "source" DEM
    dem1_ds = warplib.memwarp_multi_fn([
        dem1_fn,
    ],
                                       res=dem2_ds,
                                       extent=dem2_ds,
                                       t_srs=dem2_ds)[0]
    #dem1_ds = gdal.Open(dem1_fn, gdal.GA_ReadOnly)

    #Create a copy to be updated in place
    dem2_ds_align = iolib.mem_drv.CreateCopy('', dem2_ds, 0)
    #dem2_ds_align = dem2_ds

    #Iteration number
    n = 1
    #Cumulative offsets
    dx_total = 0
    dy_total = 0
    dz_total = 0

    #Now iteratively update geotransform and vertical shift
    while True:
        print("*** Iteration %i ***" % n)
        dx, dy, dz, static_mask, fig = compute_offset(dem1_ds,
                                                      dem2_ds_align,
                                                      dem2_fn,
                                                      mode,
                                                      max_offset_m,
                                                      apply_mask=apply_mask)
        if n == 1:
            static_mask_orig = static_mask
        xyz_shift_str_iter = "dx=%+0.2fm, dy=%+0.2fm, dz=%+0.2fm" % (dx, dy,
                                                                     dz)
        print("Incremental offset: %s" % xyz_shift_str_iter)

        #Should make an animation of this converging
        if fig is not None:
            dst_fn = outprefix + '_%s_iter%i_plot.png' % (mode, n)
            print("Writing offset plot: %s" % dst_fn)
            fig.gca().set_title(xyz_shift_str_iter)
            fig.savefig(dst_fn, dpi=300, bbox_inches='tight', pad_inches=0.1)

        #Apply the horizontal shift to the original dataset
        dem2_ds_align = coreglib.apply_xy_shift(dem2_ds_align,
                                                dx,
                                                dy,
                                                createcopy=False)
        dem2_ds_align = coreglib.apply_z_shift(dem2_ds_align,
                                               dz,
                                               createcopy=False)

        dx_total += dx
        dy_total += dy
        dz_total += dz
        print("Cumulative offset: dx=%+0.2fm, dy=%+0.2fm, dz=%+0.2fm" %
              (dx_total, dy_total, dz_total))

        #Fit plane to residuals and remove
        #Might be better to do this after converging
        """
        if tiltcorr:
            print("Applying planar tilt correction")
            gt = dem2_ds_align.GetGeoTransform()
            #Need to compute diff_euler here
            #Copy portions of compute_offset, create new function 
            vals, resid, coeff = geolib.ma_fitplane(diff_euler_align, gt, perc=(4, 96))
            dem2_ds_align = coreglib.apply_z_shift(dem2_ds_align, -vals, createcopy=False)
        """

        n += 1
        print("\n")
        #If magnitude of shift in all directions is less than tol
        #if n > max_n or (abs(dx) <= min_dx and abs(dy) <= min_dy and abs(dz) <= min_dz):
        #If magnitude of shift is less than tol
        dm = np.sqrt(dx**2 + dy**2 + dz**2)
        if n > max_n or dm < tol:
            break

    #String to append to output filenames
    xyz_shift_str_cum = '_%s_x%+0.2f_y%+0.2f_z%+0.2f' % (mode, dx_total,
                                                         dy_total, dz_total)
    if tiltcorr:
        xyz_shift_str_cum += "_tiltcorr"

    #Compute original elevation difference
    if True:
        dem1_clip_ds, dem2_clip_ds = warplib.memwarp_multi([dem1_ds, dem2_ds], \
                res='max', extent='intersection', t_srs=dem2_ds)
        dem1_orig = iolib.ds_getma(dem1_clip_ds, 1)
        dem2_orig = iolib.ds_getma(dem2_clip_ds, 1)
        diff_euler_orig = dem2_orig - dem1_orig
        if not apply_mask:
            static_mask_orig = np.ma.getmaskarray(diff_euler_orig)
        diff_euler_orig_compressed = diff_euler_orig[~static_mask_orig]
        diff_euler_orig_stats = np.array(
            malib.print_stats(diff_euler_orig_compressed))

        #Write out original eulerian difference map
        print(
            "Writing out original euler difference map for common intersection before alignment"
        )
        dst_fn = outprefix + '_orig_dz_eul.tif'
        iolib.writeGTiff(diff_euler_orig, dst_fn, dem1_clip_ds)

    #Compute final elevation difference
    if True:
        dem1_clip_ds_align, dem2_clip_ds_align = warplib.memwarp_multi([dem1_ds, dem2_ds_align], \
                res='max', extent='intersection', t_srs=dem2_ds_align)
        dem1_align = iolib.ds_getma(dem1_clip_ds_align, 1)
        dem2_align = iolib.ds_getma(dem2_clip_ds_align, 1)
        diff_euler_align = dem2_align - dem1_align
        if not apply_mask:
            static_mask = np.ma.getmaskarray(diff_euler_align)
        diff_euler_align_compressed = diff_euler_align[~static_mask]
        diff_euler_align_stats = np.array(
            malib.print_stats(diff_euler_align_compressed))

        #Fit plane to residuals and remove
        if tiltcorr:
            print("Applying planar tilt correction")
            gt = dem1_clip_ds_align.GetGeoTransform()
            #Need to apply the mask here, so we're only fitting over static surfaces
            #Note that the origmask=False will compute vals for all x and y indices, which is what we want
            vals, resid, coeff = geolib.ma_fitplane(np.ma.array(diff_euler_align, mask=static_mask), \
                    gt, perc=(4, 96), origmask=False)
            #Remove planar offset from difference map
            diff_euler_align -= vals
            #Remove planar offset from aligned dem2
            #Note: dimensions of ds and vals will be different as vals are computed for clipped intersection
            #Recompute planar offset for dem2_ds_align extent
            xgrid, ygrid = geolib.get_xy_grids(dem2_ds_align)
            vals = coeff[0] * xgrid + coeff[1] * ygrid + coeff[2]
            dem2_ds_align = coreglib.apply_z_shift(dem2_ds_align,
                                                   -vals,
                                                   createcopy=False)
            if not apply_mask:
                static_mask = np.ma.getmaskarray(diff_euler_align)
            diff_euler_align_compressed = diff_euler_align[~static_mask]
            diff_euler_align_stats = np.array(
                malib.print_stats(diff_euler_align_compressed))
            print("Creating fitplane plot")
            fig, ax = plt.subplots(figsize=(6, 6))
            fitplane_clim = malib.calcperc(vals, (2, 98))
            im = ax.imshow(vals, cmap='cpt_rainbow', clim=fitplane_clim)
            res = float(geolib.get_res(dem2_clip_ds, square=True)[0])
            pltlib.add_scalebar(ax, res=res)
            pltlib.hide_ticks(ax)
            pltlib.add_cbar(ax, im, label='Fit plane residuals (m)')
            fig.tight_layout()
            dst_fn1 = outprefix + '%s_align_dz_eul_fitplane.png' % xyz_shift_str_cum
            print("Writing out figure: %s" % dst_fn1)
            fig.savefig(dst_fn1, dpi=300, bbox_inches='tight', pad_inches=0.1)

        #Compute higher-order fits?
        #Could also attempt to model along-track and cross-track artifacts

        #Write out aligned eulerian difference map for clipped extent with vertial offset removed
        dst_fn = outprefix + '%s_align_dz_eul.tif' % xyz_shift_str_cum
        print(
            "Writing out aligned difference map with median vertical offset removed"
        )
        iolib.writeGTiff(diff_euler_align, dst_fn, dem1_clip_ds)

    #Write out aligned dem_2 with vertial offset removed
    if True:
        dst_fn2 = outprefix + '%s_align.tif' % xyz_shift_str_cum
        print(
            "Writing out shifted dem2 with median vertical offset removed: %s"
            % dst_fn2)
        #Might be cleaner way to write out MEM ds directly to disk
        dem2_align = iolib.ds_getma(dem2_ds_align)
        iolib.writeGTiff(dem2_align, dst_fn2, dem2_ds_align)
        dem2_ds_align = None

    #Create output plot
    if True:
        print("Creating final plot")
        dem1_hs = geolib.gdaldem_mem_ma(dem1_orig, dem1_clip_ds, returnma=True)
        dem2_hs = geolib.gdaldem_mem_ma(dem2_orig, dem2_clip_ds, returnma=True)
        f, axa = plt.subplots(2, 3, figsize=(11, 8.5))
        for ax in axa.ravel()[:-1]:
            ax.set_facecolor('k')
            pltlib.hide_ticks(ax)
        dem_clim = malib.calcperc(dem1_orig, (2, 98))
        axa[0, 0].imshow(dem1_hs, cmap='gray')
        axa[0, 0].imshow(dem1_orig,
                         cmap='cpt_rainbow',
                         clim=dem_clim,
                         alpha=0.6)
        res = float(geolib.get_res(dem1_clip_ds, square=True)[0])
        pltlib.add_scalebar(axa[0, 0], res=res)
        axa[0, 0].set_title('Reference DEM')
        axa[0, 1].imshow(dem2_hs, cmap='gray')
        axa[0, 1].imshow(dem2_orig,
                         cmap='cpt_rainbow',
                         clim=dem_clim,
                         alpha=0.6)
        axa[0, 1].set_title('Source DEM')
        axa[0, 2].imshow(~static_mask_orig, clim=(0, 1), cmap='gray')
        axa[0, 2].set_title('Surfaces for co-registration')
        dz_clim = malib.calcperc_sym(diff_euler_orig_compressed, (5, 95))
        im = axa[1, 0].imshow(diff_euler_orig, cmap='RdBu', clim=dz_clim)
        pltlib.add_cbar(axa[1, 0], im, label=None)
        axa[1, 0].set_title('Elev. Diff. Before (m)')
        im = axa[1, 1].imshow(diff_euler_align, cmap='RdBu', clim=dz_clim)
        pltlib.add_cbar(axa[1, 1], im, label=None)
        axa[1, 1].set_title('Elev. Diff. After (m)')

        #Tried to insert Nuth fig here
        #ax_nuth.change_geometry(1,2,1)
        #f.axes.append(ax_nuth)

        bins = np.linspace(dz_clim[0], dz_clim[1], 128)
        axa[1, 2].hist(diff_euler_orig_compressed,
                       bins,
                       color='g',
                       label='Before',
                       alpha=0.5)
        axa[1, 2].hist(diff_euler_align_compressed,
                       bins,
                       color='b',
                       label='After',
                       alpha=0.5)
        axa[1, 2].axvline(0, color='k', linewidth=0.5, linestyle=':')
        axa[1, 2].set_xlabel('Elev. Diff. (m)')
        axa[1, 2].set_ylabel('Count (px)')
        axa[1, 2].set_title("Source - Reference")
        #axa[1,2].legend(loc='upper right')
        #before_str = 'Before\nmean: %0.2f\nstd: %0.2f\nmed: %0.2f\nnmad: %0.2f' % tuple(diff_euler_orig_stats[np.array((3,4,5,6))])
        #after_str = 'After\nmean: %0.2f\nstd: %0.2f\nmed: %0.2f\nnmad: %0.2f' % tuple(diff_euler_align_stats[np.array((3,4,5,6))])
        before_str = 'Before\nmed: %0.2f\nnmad: %0.2f' % tuple(
            diff_euler_orig_stats[np.array((5, 6))])
        axa[1, 2].text(0.05,
                       0.95,
                       before_str,
                       va='top',
                       color='g',
                       transform=axa[1, 2].transAxes)
        after_str = 'After\nmed: %0.2f\nnmad: %0.2f' % tuple(
            diff_euler_align_stats[np.array((5, 6))])
        axa[1, 2].text(0.65,
                       0.95,
                       after_str,
                       va='top',
                       color='b',
                       transform=axa[1, 2].transAxes)

        suptitle = '%s\nx: %+0.2fm, y: %+0.2fm, z: %+0.2fm' % (
            os.path.split(outprefix)[-1], dx_total, dy_total, dz_total)
        f.suptitle(suptitle)
        f.tight_layout()
        plt.subplots_adjust(top=0.90)

        dst_fn = outprefix + '%s_align.png' % xyz_shift_str_cum
        print("Writing out figure: %s" % dst_fn)
        f.savefig(dst_fn, dpi=300, bbox_inches='tight', pad_inches=0.1)

        #Removing residual planar tilt can introduce additional slope/aspect dependent offset
        #Want to run another round of main dem_align after removing planar tilt
        if tiltcorr:
            print("\n Rerunning after applying tilt correction \n")
            #Create copy of original arguments
            import copy
            args2 = copy.copy(args)
            #Use aligned, tilt-corrected DEM as input src_fn for second round
            args2.src_fn = dst_fn2
            #Assume we've already corrected most of the tilt during first round (also prevents endless loop)
            args2.tiltcorr = False
            main2(args2)
Esempio n. 9
0
def main():
    """Build a reference-surface mask for a DEM and write out the masked DEM.

    Command-line driven (see getparser). Starting from the DEM's own valid
    pixels, pixels flagged by each enabled filter are removed:

    - LULC/bare-ground mask from get_lulc_mask (skipped when args.filter is
      'none' or args.bareground_thresh is 0)
    - SNODAS snow depth >= args.snodas_thresh (m), when args.snodas is set
    - MODSCAG fractional snow cover >= args.modscag_thresh (%), when
      args.modscag is set
    - TOA reflectance mask from get_toa_mask, when args.toa is set

    All auxiliary datasets are warped to the DEM grid first. Intermediate
    products are written next to the DEM (e.g. ``<dem>_rockmask.tif``,
    ``<dem>_snodas_mask.tif``). The final masked DEM is written to
    ``<dem>_ref.tif`` only if it retains more than 100 valid pixels.
    """
    parser = getparser()
    args = parser.parse_args()

    #Write out all intermediate mask products for the input DEM
    writeall = True

    mask_glaciers = not args.no_icemask

    #This directory should contain nlcd grid, glacier outlines
    datadir = iolib.get_datadir()

    dem_fn = args.dem_fn
    dem_ds = gdal.Open(dem_fn)
    if dem_ds is None:
        sys.exit("Unable to open input DEM: %s" % dem_fn)
    print(dem_fn)

    #Extract DEM timestamp (may be None if the filename has no recognizable date)
    dem_dt = timelib.fn_getdatetime(dem_fn)
    if dem_dt is None and (args.snodas or args.modscag):
        print("\nWarning: Unable to extract timestamp from %s\nSkipping SNODAS and MODSCAG filters..." \
                % dem_fn)

    #This will hold datasets for memwarp and output processing
    #Only datasets that were actually obtained are added, so no culling is needed later
    #Note: LULC is handled separately by get_lulc_mask below
    ds_dict = OrderedDict()
    ds_dict['dem'] = dem_ds

    if args.snodas and dem_dt is not None:
        #Get SNODAS snow depth products for DEM timestamp
        snodas_min_dt = datetime(2003, 9, 30)
        if dem_dt >= snodas_min_dt:
            snodas_outdir = os.path.join(datadir, 'snodas')
            if not os.path.exists(snodas_outdir):
                os.makedirs(snodas_outdir)
            snodas_ds = get_snodas(dem_dt, snodas_outdir)
            if snodas_ds is not None:
                ds_dict['snodas'] = snodas_ds

    #Get MODSCAG products for DEM timestamp
    #These tiles cover CONUS
    #tile_list=('h08v04', 'h09v04', 'h10v04', 'h08v05', 'h09v05')
    if args.modscag and dem_dt is not None:
        modscag_min_dt = datetime(2000, 2, 24)
        if dem_dt < modscag_min_dt:
            print("\nWarning: DEM timestamp (%s) is before earliest MODSCAG timestamp (%s)\nSkipping..." \
                    % (dem_dt, modscag_min_dt))
        else:
            tile_list = get_modis_tile_list(dem_ds)
            print(tile_list)
            pad_days = 7
            modscag_outdir = os.path.join(datadir, 'modscag')
            if not os.path.exists(modscag_outdir):
                os.makedirs(modscag_outdir)
            modscag_fn_list = get_modscag(dem_dt, modscag_outdir, tile_list,
                                          pad_days)
            if modscag_fn_list:
                modscag_ds = proc_modscag(modscag_fn_list,
                                          extent=dem_ds,
                                          t_srs=dem_ds)
                ds_dict['modscag'] = modscag_ds

    #TODO: need to clean this up, better error handling
    #Use top of atmosphere scaled reflectance values (0-1) to estimate snowcover
    if args.toa:
        toa_ds = get_toa_ds(dem_fn)
        if toa_ds is not None:
            ds_dict['toa'] = toa_ds

    #Warp all datasets to DEM extent/res/srs
    #Note: use cubicspline here to avoid artifacts with negative values
    #Pass a real list (not a dict view) and rebuild the dict from the results,
    #rather than mutating it while iterating
    ds_keys = list(ds_dict.keys())
    ds_list = warplib.memwarp_multi(list(ds_dict.values()),
                                    res=dem_ds,
                                    extent=dem_ds,
                                    t_srs=dem_ds,
                                    r='cubicspline')
    ds_dict = OrderedDict(zip(ds_keys, ds_list))

    print(' ')

    dem = iolib.ds_getma(ds_dict['dem'])

    #Initialize the mask
    #True (1) represents "valid" unmasked pixel, False (0) represents "invalid" pixel to be masked
    newmask = ~(np.ma.getmaskarray(dem))

    #Basename for output files
    out_fn_base = os.path.splitext(dem_fn)[0]

    #Generate a rockmask
    if args.filter == 'none' or args.bareground_thresh == 0:
        print("Skipping LULC filter")
    else:
        #Note: these now have RGI glacier polygons removed
        #We are almost always going to want LULC mask
        rockmask = get_lulc_mask(dem_ds, mask_glaciers=mask_glaciers, \
                filter=args.filter, bareground_thresh=args.bareground_thresh, out_fn=out_fn_base)
        if writeall:
            out_fn = out_fn_base + '_rockmask.tif'
            print("Writing out %s\n" % out_fn)
            iolib.writeGTiff(rockmask, out_fn, src_ds=ds_dict['dem'])
        newmask = np.logical_and(rockmask, newmask)

    if 'snodas' in ds_dict:
        #SNODAS snow depth filter
        snodas_thresh = args.snodas_thresh
        #snow depth values are mm, convert to meters
        snodas_depth = iolib.ds_getma(ds_dict['snodas']) / 1000.
        if snodas_depth.count() > 0:
            print(
                "Applying SNODAS snow depth filter (masking values >= %0.2f m)"
                % snodas_thresh)
            if writeall:
                out_fn = out_fn_base + '_snodas_depth.tif'
                print("Writing out %s" % out_fn)
                iolib.writeGTiff(snodas_depth, out_fn, src_ds=ds_dict['dem'])
            #Use >= to match the message above and the MODSCAG filter below
            snodas_mask = np.ma.masked_greater_equal(snodas_depth, snodas_thresh)
            #This should be 1 for valid surfaces with no snow, 0 for snowcovered surfaces
            #Note: pixels without SNODAS data are already masked, so they are also set to 0 here
            snodas_mask = ~(np.ma.getmaskarray(snodas_mask))
            if writeall:
                out_fn = out_fn_base + '_snodas_mask.tif'
                print("Writing out %s\n" % out_fn)
                iolib.writeGTiff(snodas_mask, out_fn, src_ds=ds_dict['dem'])
            newmask = np.logical_and(snodas_mask, newmask)
        else:
            print(
                "SNODAS grid for input location and timestamp is empty!\nSkipping...\n"
            )

    if 'modscag' in ds_dict:
        #MODSCAG percent snowcover
        modscag_thresh = args.modscag_thresh
        print(
            "Applying MODSCAG fractional snow cover percent filter (masking values >= %0.1f%%)"
            % modscag_thresh)
        modscag_perc = iolib.ds_getma(ds_dict['modscag'])
        if writeall:
            out_fn = out_fn_base + '_modscag_perc.tif'
            print("Writing out %s" % out_fn)
            iolib.writeGTiff(modscag_perc, out_fn, src_ds=ds_dict['dem'])
        #Pixels without MODSCAG data are filled with 0, i.e. treated as snow-free
        modscag_mask = (modscag_perc.filled(0) >= modscag_thresh)
        #This should be 1 for valid surfaces with no snow, 0 for snowcovered surfaces
        modscag_mask = ~(modscag_mask)
        if writeall:
            out_fn = out_fn_base + '_modscag_mask.tif'
            print("Writing out %s\n" % out_fn)
            iolib.writeGTiff(modscag_mask, out_fn, src_ds=ds_dict['dem'])
        newmask = np.logical_and(modscag_mask, newmask)

    if 'toa' in ds_dict:
        #TOA reflectance filter
        #This should be 1 for valid surfaces, 0 for snowcovered surfaces
        toa_mask = get_toa_mask(ds_dict['toa'], args.toa_thresh)
        if writeall:
            out_fn = out_fn_base + '_toamask.tif'
            print("Writing out %s\n" % out_fn)
            iolib.writeGTiff(toa_mask, out_fn, src_ds=ds_dict['dem'])
        newmask = np.logical_and(toa_mask, newmask)

    #TODO: optionally filter based on expected snowline (e.g., simple altitude cutoff)

    print(
        "Generating final mask to use for reference surfaces, and applying to input DEM"
    )
    #Now invert, so True marks pixels to be masked in the final masked array
    newmask = ~newmask

    #Dilate the mask
    #Note: binary_dilation with iterations < 1 repeats until convergence, which would
    #unmask nearly everything, so only dilate for a positive iteration count
    #NOTE(review): this dilates the valid (~newmask) region, shrinking the masked area - confirm intended direction
    if args.dilate is not None and args.dilate > 0:
        niter = args.dilate
        print("Dilating mask with %i iterations" % niter)
        from scipy import ndimage
        newmask = ~(ndimage.binary_dilation(~newmask, iterations=niter))

    #Apply mask to original DEM - use these surfaces for co-registration
    #Note: np.ma.array keeps the DEM's existing nodata mask in addition to newmask
    newdem = np.ma.array(dem, mask=newmask)

    #TODO: also check for a good distribution of valid pixels (e.g., std > 10 m)
    min_validpx_count = 100
    validpx_count = newdem.count()
    validpx_std = newdem.std()
    print("\n%i valid pixels in output ref.tif" % validpx_count)
    print("%0.2f m std output ref.tif\n" % validpx_std)
    if validpx_count > min_validpx_count:
        #Write out final masked DEM
        out_fn = out_fn_base + '_ref.tif'
        print("Writing out %s\n" % out_fn)
        iolib.writeGTiff(newdem, out_fn, src_ds=ds_dict['dem'])
    else:
        print("Not enough valid pixels!")
Esempio n. 10
0
def main():
    """Fetch SNODAS snow depth for a raster's timestamp and write it out.

    Command-line driven (see getparser). The timestamp comes from
    ``args.date`` when given, otherwise from the input filename. The SNODAS
    grid is retrieved via get_snodas using ``<args.datadir>/snodas`` as its
    output directory. Depth values are converted from mm to m. Two GeoTIFFs
    are written next to the input raster:

    - ``<fn>_snodas_depth.tif``: SNODAS at its source resolution, clipped to
      the input raster extent and reprojected to its SRS
    - ``<fn>_snodas_depth_warp.tif``: SNODAS warped to the input raster's
      resolution, extent and SRS

    Exits via sys.exit if the input cannot be opened, no timestamp can be
    determined, the timestamp predates SNODAS coverage, or no SNODAS grid is
    returned.
    """
    parser = getparser()
    args = parser.parse_args()

    #This directory will store SNODAS products
    datadir = args.datadir
    if not os.path.exists(datadir):
        os.makedirs(datadir)

    fn = args.fn
    ds = gdal.Open(fn)
    if ds is None:
        sys.exit("Unable to open input raster: %s" % fn)
    print(fn)

    #Use explicitly specified date if available, otherwise extract timestamp from input filename
    if args.date is not None:
        dt = timelib.fn_getdatetime(args.date)
    else:
        dt = timelib.fn_getdatetime(fn)
    if dt is None:
        sys.exit("Unable to determine timestamp from input filename or specified date")

    out_fn_base = os.path.splitext(fn)[0]

    snodas_min_dt = datetime(2003, 9, 30)
    if dt < snodas_min_dt:
        sys.exit("Timestamp is earlier than valid SNODAS model range")

    snodas_outdir = os.path.join(datadir, 'snodas')
    if not os.path.exists(snodas_outdir):
        os.makedirs(snodas_outdir)
    snodas_ds_full = get_snodas(dt, snodas_outdir)
    if snodas_ds_full is None:
        sys.exit("Unable to obtain SNODAS grid for timestamp %s" % dt)

    #Clip to input extent and reproject to input srs, keeping source SNODAS resolution
    snodas_ds = warplib.memwarp_multi([snodas_ds_full],
                                      res='source',
                                      extent=ds,
                                      t_srs=ds,
                                      r='cubicspline')[0]
    #snow depth values are mm, convert to meters
    snodas_depth = iolib.ds_getma(snodas_ds) / 1000.

    if snodas_depth.count() > 0:
        #Write out at original resolution
        out_fn = out_fn_base + '_snodas_depth.tif'
        print("Writing out %s" % out_fn)
        iolib.writeGTiff(snodas_depth, out_fn, src_ds=snodas_ds)

        #Warp to match input raster
        #Note: use cubicspline here to avoid artifacts with negative values
        ds_out = warplib.memwarp_multi([snodas_ds],
                                       res=ds,
                                       extent=ds,
                                       t_srs=ds,
                                       r='cubicspline')[0]

        #Write out warped version (ds_out values are still mm, convert to meters)
        snodas_depth = iolib.ds_getma(ds_out) / 1000.
        out_fn = out_fn_base + '_snodas_depth_warp.tif'
        print("Writing out %s" % out_fn)
        iolib.writeGTiff(snodas_depth, out_fn, src_ds=ds_out)
    else:
        print("SNODAS grid for input location and timestamp is empty!")