コード例 #1
0
ファイル: dem_mask.py プロジェクト: whigg/demcoreg
def mask_nlcd(ds,
              valid='rock+ice+water',
              datadir=None,
              mask_glaciers=True,
              out_fn=None):
    """Generate a boolean raster mask of "valid" NLCD land-cover pixels.

    Parameters
    ----------
    ds : gdal.Dataset
        NLCD land use/land cover dataset; band 1 is read as class codes.
    valid : str
        Which classes count as valid (True) pixels. One of 'rock',
        'rock+ice', 'rock+ice+water' (default), 'not_forest',
        'not_forest+not_water'.
    datadir : str, optional
        Data directory passed to get_icemask(); defaults to
        iolib.get_datadir() when glacier masking is enabled.
    mask_glaciers : bool
        If True, multiply the mask by the glacier mask from get_icemask()
        (when one is available).
    out_fn : str, optional
        If given, the raw NLCD class array is written to this GeoTiff.

    Returns
    -------
    numpy.ndarray of bool, or None
        True for valid pixels; None if `valid` is not a recognized type.
    """
    #NLCD class codes referenced here:
    #11 - open water, includes rivers
    #12 - ice
    #31 - rock
    #41, 42, 43 - forest classes (42 - evergreen forest)
    #52 - shrub, <5 m tall, >20% (not currently used)
    #Original nlcd products have nan as ndv; nan never matches a class code
    #Using 'rock+ice+water' preserves the most pixels, although could be problematic over areas with lakes
    #valid -> (class codes, invert). Mask is True where the class is in the
    #codes, or NOT in the codes when invert is set.
    nlcd_filters = {
        'rock': ((31,), False),
        'rock+ice': ((31, 12), False),
        'rock+ice+water': ((31, 12, 11), False),
        'not_forest': ((41, 42, 43), True),
        'not_forest+not_water': ((41, 42, 43, 11), True),
    }

    print("Loading NLCD LULC")
    b = ds.GetRasterBand(1)
    l = b.ReadAsArray()
    print("Filtering NLCD LULC")
    if valid in nlcd_filters:
        codes, invert = nlcd_filters[valid]
        mask = np.isin(l, codes)
        if invert:
            mask = ~mask
    else:
        print("Invalid mask type")
        mask = None

    #Write out original data
    if out_fn is not None:
        print("Writing out %s\n" % out_fn)
        iolib.writeGTiff(l, out_fn, ds)
    #Release the full-size class array before glacier masking
    l = None

    if mask is None:
        #Nothing to combine with the glacier mask (previously `None *= icemask`
        #raised TypeError here)
        return None

    if mask_glaciers:
        if datadir is None:
            datadir = iolib.get_datadir()
        #Note: RGI6.0 includes Fountain's 24K polygons, no longer need to maintain two separate shp
        #so the default glacier shapefile chosen by get_icemask is used
        icemask = get_icemask(ds, datadir=datadir, glac_shp_fn=None)
        if icemask is not None:
            mask *= icemask
    return mask
コード例 #2
0
ファイル: dem_mask.py プロジェクト: whigg/demcoreg
def get_icemask(ds, datadir=None, glac_shp_fn=None):
    """Generate glacier polygon raster mask for input Dataset res/extent.

    Parameters
    ----------
    ds : gdal.Dataset
        Dataset whose grid (projection, extent, resolution) the glacier
        polygons are rasterized onto by geolib.shp2array.
    datadir : str, optional
        Data directory; defaults to iolib.get_datadir().
    glac_shp_fn : str, optional
        Glacier polygon shapefile; defaults to get_glacier_poly(datadir).

    Returns
    -------
    The result of geolib.shp2array(glac_shp_fn, ds), or None when the
    glacier shapefile does not exist.
    """
    if datadir is None:
        datadir = iolib.get_datadir()
    print("Masking glaciers")
    if glac_shp_fn is None:
        glac_shp_fn = get_glacier_poly(datadir)

    if not os.path.exists(glac_shp_fn):
        #Rasterizing a missing file cannot succeed; callers such as
        #mask_nlcd already treat a None icemask as "skip glacier masking"
        print("Unable to locate glacier shp: %s" % glac_shp_fn)
        return None
    print("Found glacier shp: %s" % glac_shp_fn)

    #All of the proj, extent, handling should now occur in shp2array
    icemask = geolib.shp2array(glac_shp_fn, ds)
    return icemask
コード例 #3
0
ファイル: dem_mask.py プロジェクト: whigg/demcoreg
def get_bareground_fn(datadir=None):
    """Return the path of the ~2010 global 30 m bare ground VRT.

    The path is ``<datadir>/bare2010/bare2010.vrt``; `datadir` defaults to
    ``iolib.get_datadir()``. When that file does not exist yet, the external
    shell script `get_bareground.sh` is invoked to fetch the product. The
    returned path is not re-checked after the script runs.

    Note: unzipped file size is 64 GB! Original products are uncompressed,
    and tiles are available globally (including empty data over ocean).
    The shell script compresses all downloaded tiles using lossless LZW
    compression.

    http://landcover.usgs.gov/glc/BareGroundDescriptionAndDownloads.php
    """
    root = iolib.get_datadir() if datadir is None else datadir
    bareground_vrt = os.path.join(root, 'bare2010/bare2010.vrt')
    if os.path.exists(bareground_vrt):
        return bareground_vrt
    #Not present locally: download via the helper script
    subprocess.call(['get_bareground.sh'])
    return bareground_vrt
コード例 #4
0
ファイル: dem_mask.py プロジェクト: whigg/demcoreg
def get_glacier_poly(datadir=None):
    """Return the path of the merged Randolph Glacier Inventory (RGI 6.0) shapefile.

    The path is ``<datadir>/rgi60/regions/rgi60_merge.shp``; `datadir`
    defaults to ``iolib.get_datadir()``. If that file is missing, the
    external shell script `get_rgi.sh` is invoked, which unzips the RGI
    download and merges the regional shapefiles into a single global one.
    The returned path is not re-checked after the script runs.

    (For reference, the earlier full RGI 5.0 archive, rgi50.zip, was 410 MB.)

    http://www.glims.org/RGI/
    """
    root = datadir if datadir is not None else iolib.get_datadir()
    #Ideally get_rgi.sh would report the path it produced
    rgi_shp = os.path.join(root, 'rgi60/regions/rgi60_merge.shp')
    if not os.path.exists(rgi_shp):
        subprocess.call(['get_rgi.sh'])
    return rgi_shp
コード例 #5
0
ファイル: dem_mask.py プロジェクト: whigg/demcoreg
def get_nlcd_fn(datadir=None):
    """Return the path of the 2011 NLCD Land Use Land Cover 30 m GeoTiff.

    The file lives under `datadir` (default: ``iolib.get_datadir()``). If it
    is not there, the external shell script `get_nlcd.sh` is invoked to
    fetch it. The returned path is not re-checked after the script runs.

    The original .img product needs ~17 GB; get_nlcd.sh produces a
    compressed GTiff of about 1.1 GB, which is the file referenced here.

    http://www.mrlc.gov/nlcd11_leg.php
    """
    base = datadir if datadir is not None else iolib.get_datadir()
    nlcd_fn = os.path.join(
        base,
        'nlcd_2011_landcover_2011_edition_2014_10_10/nlcd_2011_landcover_2011_edition_2014_10_10.tif'
    )
    if os.path.exists(nlcd_fn):
        return nlcd_fn
    subprocess.call(['get_nlcd.sh'])
    return nlcd_fn
コード例 #6
0
ファイル: dem_mask.py プロジェクト: dshean/demcoreg
#TODO: need to clean up toa handling

import sys
import os
import subprocess
import glob
import argparse

from osgeo import gdal, ogr, osr
import numpy as np

from datetime import datetime, timedelta

from pygeotools.lib import iolib, warplib, geolib, timelib

#Module-level default data directory, resolved once at import time via pygeotools iolib.get_datadir()
datadir = iolib.get_datadir()


def get_nlcd_fn(yr=2016):
    """Locate the NLCD Land Use Land Cover (LULC) 30 m grid for year `yr`.

    Intended to call the external shell script `get_nlcd.sh` to fetch the
    grid for 2011, 2013 or 2016 (default).

    NOTE(review): in this excerpt the body after the docstring consists only
    of comments, so as shown the function returns None and never reads `yr`.
    The implementation appears to be truncated here; confirm against the
    full file.

    http://www.mrlc.gov/nlcd11_leg.php
    """
    #Historical filename candidates, kept for reference:
    #This is original filename, which requires ~17 GB
    #nlcd_fn = os.path.join(datadir, 'nlcd_2011_landcover_2011_edition_2014_10_10/nlcd_2011_landcover_2011_edition_2014_10_10.img')
    #get_nlcd.sh creates a compressed GTiff, which is 1.1 GB
    #nlcd_fn = os.path.join(datadir, 'nlcd_2011_landcover_2011_edition_2014_10_10/nlcd_2011_landcover_2011_edition_2014_10_10.tif')
    #nlcd_fn = os.path.join(datadir, 'NLCD_{0}_Land_Cover_L48_20190424/NLCD_{0}_Land_Cover_L48_20190424.tif'.format(str(yr)))
コード例 #7
0
ファイル: dem_mask.py プロジェクト: whigg/demcoreg
def main():
    """Generate a reference-surface mask for a DEM and write the masked DEM.

    Arguments come from getparser().
    - The input DEM is optionally combined with SNODAS snow depth, MODSCAG
      fractional snow cover and TOA reflectance products for the DEM
      timestamp. Each is warped to the DEM grid first.
    - An LULC-based rock mask is also applied unless disabled.
    - Pixels passing every enabled filter are kept and written to
      ``<dem>_ref.tif`` when more than 100 valid pixels remain.
    - Intermediate products and masks are also written, because
      ``writeall`` is True.
    """
    #Local import so this function does not depend on a module-level import
    from collections import OrderedDict

    parser = getparser()
    args = parser.parse_args()

    #Write out all mask products for the input DEM
    writeall = True

    mask_glaciers = not args.no_icemask

    #This directory should contain nlcd grid, glacier outlines
    datadir = iolib.get_datadir()

    dem_fn = args.dem_fn
    dem_ds = gdal.Open(dem_fn)
    print(dem_fn)

    #Extract DEM timestamp
    dem_dt = timelib.fn_getdatetime(dem_fn)

    #This will hold datasets for memwarp and output processing
    #Note: the LULC/rock mask is produced separately below (get_lulc_mask)
    ds_dict = OrderedDict()
    ds_dict['dem'] = dem_ds

    ds_dict['snodas'] = None
    if args.snodas:
        #Get SNODAS snow depth products for DEM timestamp
        snodas_min_dt = datetime(2003, 9, 30)
        if dem_dt >= snodas_min_dt:
            snodas_outdir = os.path.join(datadir, 'snodas')
            if not os.path.exists(snodas_outdir):
                os.makedirs(snodas_outdir)
            snodas_ds = get_snodas(dem_dt, snodas_outdir)
            if snodas_ds is not None:
                ds_dict['snodas'] = snodas_ds

    ds_dict['modscag'] = None
    #Get MODSCAG products for DEM timestamp
    #These tiles cover CONUS
    #tile_list=('h08v04', 'h09v04', 'h10v04', 'h08v05', 'h09v05')
    if args.modscag:
        modscag_min_dt = datetime(2000, 2, 24)
        if dem_dt < modscag_min_dt:
            print("\nWarning: DEM timestamp (%s) is before earliest MODSCAG timestamp (%s)\nSkipping..." \
                    % (dem_dt, modscag_min_dt))
        else:
            tile_list = get_modis_tile_list(dem_ds)
            print(tile_list)
            pad_days = 7
            modscag_outdir = os.path.join(datadir, 'modscag')
            if not os.path.exists(modscag_outdir):
                os.makedirs(modscag_outdir)
            modscag_fn_list = get_modscag(dem_dt, modscag_outdir, tile_list,
                                          pad_days)
            if modscag_fn_list:
                modscag_ds = proc_modscag(modscag_fn_list,
                                          extent=dem_ds,
                                          t_srs=dem_ds)
                ds_dict['modscag'] = modscag_ds

    #TODO: need to clean this up
    #Better error handling
    #Use reflectance values to estimate snowcover
    ds_dict['toa'] = None
    if args.toa:
        #Use top of atmosphere scaled reflectance values (0-1)
        toa_ds = get_toa_ds(dem_fn)
        ds_dict['toa'] = toa_ds

    #Cull all of the None ds from the ds_dict
    #Rebuild instead of `del` inside the loop: mutating a dict while
    #iterating over it raises RuntimeError in Python 3
    ds_dict = OrderedDict((k, v) for k, v in ds_dict.items() if v is not None)

    #Warp all masks to DEM extent/res
    #Note: use cubicspline here to avoid artifacts with negative values
    if ds_dict:
        ds_list = warplib.memwarp_multi(list(ds_dict.values()),
                                        res=dem_ds,
                                        extent=dem_ds,
                                        t_srs=dem_ds,
                                        r='cubicspline')
        #Replace each input ds with its warped counterpart, preserving order
        ds_dict = OrderedDict(zip(ds_dict.keys(), ds_list))

    print(' ')
    #Need better handling of ds order based on input ds here

    dem = iolib.ds_getma(ds_dict['dem'])

    #Initialize the mask
    #True (1) represents "valid" unmasked pixel, False (0) represents "invalid" pixel to be masked
    newmask = ~(np.ma.getmaskarray(dem))

    #Basename for output files
    out_fn_base = os.path.splitext(dem_fn)[0]

    #Generate a rockmask
    if args.filter == 'none' or args.bareground_thresh == 0:
        print("Skipping LULC filter")
    else:
        #Note: these now have RGI glacier polygons removed
        #We are almost always going to want LULC mask
        rockmask = get_lulc_mask(dem_ds, mask_glaciers=mask_glaciers, \
                filter=args.filter, bareground_thresh=args.bareground_thresh, out_fn=out_fn_base)
        if writeall:
            out_fn = out_fn_base + '_rockmask.tif'
            print("Writing out %s\n" % out_fn)
            iolib.writeGTiff(rockmask, out_fn, src_ds=ds_dict['dem'])
        newmask = np.logical_and(rockmask, newmask)

    if 'snodas' in ds_dict:
        #SNODAS snow depth filter
        snodas_thresh = args.snodas_thresh
        #snow depth values are mm, convert to meters
        snodas_depth = iolib.ds_getma(ds_dict['snodas']) / 1000.
        if snodas_depth.count() > 0:
            print(
                "Applying SNODAS snow depth filter (masking values >= %0.2f m)"
                % snodas_thresh)
            if writeall:
                out_fn = out_fn_base + '_snodas_depth.tif'
                print("Writing out %s" % out_fn)
                iolib.writeGTiff(snodas_depth, out_fn, src_ds=ds_dict['dem'])
            snodas_mask = np.ma.masked_greater(snodas_depth, snodas_thresh)
            #This should be 1 for valid surfaces with no snow, 0 for snowcovered surfaces
            snodas_mask = ~(np.ma.getmaskarray(snodas_mask))
            if writeall:
                out_fn = out_fn_base + '_snodas_mask.tif'
                print("Writing out %s\n" % out_fn)
                iolib.writeGTiff(snodas_mask, out_fn, src_ds=ds_dict['dem'])
            newmask = np.logical_and(snodas_mask, newmask)
        else:
            print(
                "SNODAS grid for input location and timestamp is empty!\nSkipping...\n"
            )

    if 'modscag' in ds_dict:
        #MODSCAG percent snowcover
        modscag_thresh = args.modscag_thresh
        print(
            "Applying MODSCAG fractional snow cover percent filter (masking values >= %0.1f%%)"
            % modscag_thresh)
        modscag_perc = iolib.ds_getma(ds_dict['modscag'])
        if writeall:
            out_fn = out_fn_base + '_modscag_perc.tif'
            print("Writing out %s" % out_fn)
            iolib.writeGTiff(modscag_perc, out_fn, src_ds=ds_dict['dem'])
        #This should be 1 for valid surfaces with no snow, 0 for snowcovered surfaces
        modscag_mask = ~(modscag_perc.filled(0) >= modscag_thresh)
        if writeall:
            out_fn = out_fn_base + '_modscag_mask.tif'
            print("Writing out %s\n" % out_fn)
            iolib.writeGTiff(modscag_mask, out_fn, src_ds=ds_dict['dem'])
        newmask = np.logical_and(modscag_mask, newmask)

    if 'toa' in ds_dict:
        #TOA reflectance filter
        #This should be 1 for valid surfaces, 0 for snowcovered surfaces
        toa_mask = get_toa_mask(ds_dict['toa'], args.toa_thresh)
        if writeall:
            out_fn = out_fn_base + '_toamask.tif'
            print("Writing out %s\n" % out_fn)
            iolib.writeGTiff(toa_mask, out_fn, src_ds=ds_dict['dem'])
        newmask = np.logical_and(toa_mask, newmask)

    print(
        "Generating final mask to use for reference surfaces, and applying to input DEM"
    )
    #Now invert to use to create final masked array (True = masked)
    newmask = ~newmask

    #Dilate the mask
    #NOTE(review): this dilates the *valid* region (~newmask), i.e. it grows
    #valid pixels rather than the masked area -- confirm this is intended
    if args.dilate is not None:
        niter = args.dilate
        print("Dilating mask with %i iterations" % niter)
        from scipy import ndimage
        #ndimage.binary_dilation replaces the deprecated ndimage.morphology namespace
        newmask = ~(ndimage.binary_dilation(~newmask, iterations=niter))

    #Check that we have enough pixels, good distribution

    #Apply mask to original DEM - use these surfaces for co-registration
    newdem = np.ma.array(dem, mask=newmask)

    min_validpx_count = 100
    #Currently unused; see the commented-out std check below
    min_validpx_std = 10
    validpx_count = newdem.count()
    validpx_std = newdem.std()
    print("\n%i valid pixels in output ref.tif" % validpx_count)
    print("%0.2f m std output ref.tif\n" % validpx_std)
    #if (validpx_count > min_validpx_count) and (validpx_std > min_validpx_std):
    if validpx_count > min_validpx_count:
        #Write out final mask
        out_fn = out_fn_base + '_ref.tif'
        print("Writing out %s\n" % out_fn)
        iolib.writeGTiff(newdem, out_fn, src_ds=ds_dict['dem'])
    else:
        print("Not enough valid pixels!")
コード例 #8
0
ファイル: dem_align.py プロジェクト: TristanBlus/demcoreg
def main(args=None):
    parser = getparser()
    args = parser.parse_args()

    # Should check that files exist
    ref_dem_fn = args.ref_fn
    src_dem_fn = args.src_fn

    mode = args.mode
    mask_list = args.mask_list
    max_offset = args.max_offset
    max_dz = args.max_dz
    slope_lim = tuple(args.slope_lim)
    tiltcorr = args.tiltcorr
    polyorder = args.polyorder
    res = args.res

    # Maximum number of iterations
    max_iter = args.max_iter

    # These are tolerances (in meters) to stop iteration
    tol = args.tol
    min_dx = tol
    min_dy = tol
    min_dz = tol

    outdir = args.outdir
    if outdir is None:
        outdir = os.path.splitext(src_dem_fn)[0] + '_dem_align'

    if tiltcorr:
        outdir += '_tiltcorr'
        tiltcorr_done = False
        # Relax tolerance for initial round of co-registration
        # tiltcorr_tol = 0.1
        # if tol < tiltcorr_tol:
        #    tol = tiltcorr_tol

    if not os.path.exists(outdir):
        os.makedirs(outdir)

    outprefix = '%s_%s' % (os.path.splitext(os.path.split(src_dem_fn)[-1])[0],
                           os.path.splitext(os.path.split(ref_dem_fn)[-1])[0])
    outprefix = os.path.join(outdir, outprefix)

    print("\nReference: %s" % ref_dem_fn)
    print("Source: %s" % src_dem_fn)
    print("Mode: %s" % mode)
    print("Output: %s\n" % outprefix)

    src_dem_ds = gdal.Open(src_dem_fn)
    ref_dem_ds = gdal.Open(ref_dem_fn)

    # Get local cartesian coordinate system
    # local_srs = geolib.localtmerc_ds(src_dem_ds)
    # Use original source dataset coordinate system
    # Potentially issues with distortion and xyz/tiltcorr offsets for DEM with large extent
    local_srs = geolib.get_ds_srs(src_dem_ds)
    # local_srs = geolib.get_ds_srs(ref_dem_ds)

    # Resample to common grid
    ref_dem_res = geolib.get_res(ref_dem_ds, t_srs=local_srs, square=False)
    # Create a copy to be updated in place
    src_dem_ds_align = iolib.mem_drv.CreateCopy('', src_dem_ds, 0)
    src_dem_res = geolib.get_res(src_dem_ds, t_srs=local_srs, square=False)
    src_dem_ds = None
    # Resample to user-specified resolution
    ref_dem_ds, src_dem_ds_align = warplib.memwarp_multi([ref_dem_ds, src_dem_ds_align],
                                                         extent='intersection', res=args.res, t_srs=local_srs,
                                                         r='cubic')

    res = geolib.get_res(src_dem_ds_align, square=False)
    print("\nReference DEM res: %0.2s" % ref_dem_res)
    print("Source DEM res: %0.2s" % src_dem_res)
    print("Resolution for coreg: %s (%0.2s m)\n" % (args.res, res))

    # Iteration number
    n = 1
    # Cumulative offsets
    dx_total = 0
    dy_total = 0
    dz_total = 0

    # Now iteratively update geotransform and vertical shift
    while True:
        print("*** Iteration %i ***" % n)
        dx, dy, dz, static_mask, fig = compute_offset(ref_dem_ds, src_dem_ds_align, src_dem_fn, mode, max_offset,
                                                      mask_list=mask_list, max_dz=max_dz, slope_lim=slope_lim,
                                                      plot=True)
        xyz_shift_str_iter = "dx=%+0.2fm, dy=%+0.2fm, dz=%+0.2fm" % (dx, dy, dz)
        print("Incremental offset: %s" % xyz_shift_str_iter)

        dx_total += dx
        dy_total += dy
        dz_total += dz

        xyz_shift_str_cum = "dx=%+0.2fm, dy=%+0.2fm, dz=%+0.2fm" % (dx_total, dy_total, dz_total)
        print("Cumulative offset: %s" % xyz_shift_str_cum)
        # String to append to output filenames
        xyz_shift_str_cum_fn = '_%s_x%+0.2f_y%+0.2f_z%+0.2f' % (mode, dx_total, dy_total, dz_total)

        # Should make an animation of this converging
        if n == 1:
            # static_mask_orig = static_mask
            if fig is not None:
                dst_fn = outprefix + '_%s_iter%02i_plot.png' % (mode, n)
                print("Writing offset plot: %s" % dst_fn)
                fig.gca().set_title("Incremental: %s\nCumulative: %s" % (xyz_shift_str_iter, xyz_shift_str_cum))
                fig.savefig(dst_fn, dpi=300)

        # Apply the horizontal shift to the original dataset
        src_dem_ds_align = coreglib.apply_xy_shift(src_dem_ds_align, dx, dy, createcopy=False)
        # Should
        src_dem_ds_align = coreglib.apply_z_shift(src_dem_ds_align, dz, createcopy=False)

        n += 1
        print("\n")
        # If magnitude of shift in all directions is less than tol
        # if n > max_iter or (abs(dx) <= min_dx and abs(dy) <= min_dy and abs(dz) <= min_dz):
        # If magnitude of shift is less than tol
        dm = np.sqrt(dx ** 2 + dy ** 2 + dz ** 2)
        dm_total = np.sqrt(dx_total ** 2 + dy_total ** 2 + dz_total ** 2)

        if dm_total > max_offset:
            sys.exit(
                "Total offset exceeded specified max_offset (%0.2f m). Consider increasing -max_offset argument" %
                max_offset)

        # Stop iteration
        if n > max_iter or dm < tol:

            if fig is not None:
                dst_fn = outprefix + '_%s_iter%02i_plot.png' % (mode, n)
                print("Writing offset plot: %s" % dst_fn)
                fig.gca().set_title("Incremental:%s\nCumulative:%s" % (xyz_shift_str_iter, xyz_shift_str_cum))
                fig.savefig(dst_fn, dpi=300)

            # Compute final elevation difference
            if True:
                ref_dem_clip_ds_align, src_dem_clip_ds_align = warplib.memwarp_multi([ref_dem_ds, src_dem_ds_align],
                                                                                     res=res, extent='intersection',
                                                                                     t_srs=local_srs, r='cubic')
                ref_dem_align = iolib.ds_getma(ref_dem_clip_ds_align, 1)
                src_dem_align = iolib.ds_getma(src_dem_clip_ds_align, 1)
                # ref_dem_clip_ds_align = None

                diff_align = src_dem_align - ref_dem_align
                src_dem_align = None
                ref_dem_align = None

                # Get updated, final mask
                mask_glac = get_mask(src_dem_clip_ds_align, mask_list, src_dem_fn, erode=False)
                # mask_glac_erode = get_mask(src_dem_clip_ds_align, mask_list, src_dem_fn, erode=False)
                mask_glac = np.logical_or(np.ma.getmaskarray(diff_align), mask_glac)

                # Final stats, before outlier removal
                diff_align_compressed = diff_align[~mask_glac]
                diff_align_stats = malib.get_stats_dict(diff_align_compressed, full=True)

                # Prepare filtered version for tiltcorr fit

                # 冰川区内大坡度区域
                slope = get_filtered_slope(src_dem_clip_ds_align, slope_lim=(0.01, 35))
                mask_glac_outlier = np.logical_and(mask_glac, np.ma.getmaskarray(slope))

                diff_glac_outlier = np.ma.array(diff_align, mask=~mask_glac_outlier)
                if diff_glac_outlier.count() > 0:
                    diff_align_glac_outlier = outlier_filter(np.ma.array(diff_align, mask=~mask_glac_outlier), f=2,
                                                             max_dz=100)
                    diff_align_glac_outlier[mask_glac_outlier == False] = diff_align[mask_glac_outlier == False]
                else:
                    diff_align_glac_outlier = np.ma.array(diff_align, mask=None)

                diff_align_filt_nonglac = np.ma.array(diff_align_glac_outlier, mask=mask_glac)
                diff_align_filt_compressed = diff_align[~mask_glac]
                diff_align_filt_nonglac = outlier_filter(diff_align_filt_nonglac, f=3, max_dz=max_dz)

                diff_align_filt_stats = malib.get_stats_dict(diff_align_filt_nonglac, full=True)

                diff_align_filt = np.ma.array(diff_align_filt_nonglac, mask=None)
                diff_align_filt_mask = np.ma.getmaskarray(diff_align_filt_nonglac)

                diff_align_filt[mask_glac == True] = diff_align_glac_outlier[mask_glac == True]

            # Fit 2D polynomial to residuals and remove
            # To do: add support for along-track and cross-track artifacts
            if tiltcorr and not tiltcorr_done:
                print("\n************")
                print("Calculating 'tiltcorr' 2D polynomial fit to residuals with order %i" % polyorder)
                print("************\n")
                gt = src_dem_clip_ds_align.GetGeoTransform()

                # Need to apply the mask here, so we're only fitting over static surfaces
                # Note that the origmask=False will compute vals for all x and y indices, which is what we want
                vals, resid, coeff = geolib.ma_fitpoly(diff_align_filt_nonglac, order=polyorder, gt=gt, perc=(2, 98),
                                                       origmask=False)
                # vals, resid, coeff = geolib.ma_fitplane(diff_align_filt, gt, perc=(12.5, 87.5), origmask=False)

                # Should write out coeff or grid with correction

                vals_stats = malib.get_stats_dict(vals)

                # Want to have max_tilt check here
                # max_tilt = 4.0 #m
                # Should do percentage
                # vals.ptp() > max_tilt

                # Note: dimensions of ds and vals will be different as vals are computed for clipped intersection
                # Need to recompute planar offset for full src_dem_ds_align extent and apply
                xgrid, ygrid = geolib.get_xy_grids(src_dem_ds_align)
                valgrid = geolib.polyval2d(xgrid, ygrid, coeff)
                # For results of ma_fitplane
                # valgrid = coeff[0]*xgrid + coeff[1]*ygrid + coeff[2]
                src_dem_ds_align = coreglib.apply_z_shift(src_dem_ds_align, -valgrid, createcopy=False)

                # if True:
                #     print("Creating plot of polynomial fit to residuals")
                #     fig, axa = plt.subplots(1,2, figsize=(8, 4))
                #     dz_clim = malib.calcperc_sym(vals, (2, 98))
                #     ax = pltlib.iv(diff_align_filt_nonglac, ax=axa[0], cmap='RdBu', clim=dz_clim, \
                #             label='Residual dz (m)', scalebar=False)
                #     ax = pltlib.iv(valgrid, ax=axa[1], cmap='RdBu', clim=dz_clim, \
                #             label='Polyfit dz (m)', ds=src_dem_ds_align)
                #     #if tiltcorr:
                #         #xyz_shift_str_cum_fn += "_tiltcorr"
                #     tiltcorr_fig_fn = outprefix + '%s_polyfit.png' % xyz_shift_str_cum_fn
                #     print("Writing out figure: %s\n" % tiltcorr_fig_fn)
                #     fig.savefig(tiltcorr_fig_fn, dpi=300)

                print("Applying tilt correction to difference map")
                diff_align -= vals

                # Should iterate until tilts are below some threshold
                # For now, only do one tiltcorr
                tiltcorr_done = True
                # Now use original tolerance, and number of iterations
                tol = args.tol
                max_iter = n + args.max_iter
            else:
                break

    if True:
        # Write out aligned difference map for clipped extent with vertial offset removed
        align_diff_fn = outprefix + '%s_align_diff.tif' % xyz_shift_str_cum_fn
        print("Writing out aligned difference map with median vertical offset removed")
        iolib.writeGTiff(diff_align, align_diff_fn, src_dem_clip_ds_align)

    if True:
        # Write out fitered aligned difference map
        align_diff_filt_fn = outprefix + '%s_align_diff_filt.tif' % xyz_shift_str_cum_fn
        print("Writing out filtered aligned difference map with median vertical offset removed")
        iolib.writeGTiff(diff_align_filt, align_diff_filt_fn, src_dem_clip_ds_align)

    # Extract final center coordinates for intersection
    center_coord_ll = geolib.get_center(src_dem_clip_ds_align, t_srs=geolib.wgs_srs)
    center_coord_xy = geolib.get_center(src_dem_clip_ds_align)
    src_dem_clip_ds_align = None

    # Write out final aligned src_dem
    align_fn = outprefix + '%s_align.tif' % xyz_shift_str_cum_fn
    print("Writing out shifted src_dem with median vertical offset removed: %s" % align_fn)
    # Open original uncorrected dataset at native resolution
    src_dem_ds = gdal.Open(src_dem_fn)
    src_dem_ds_align = iolib.mem_drv.CreateCopy('', src_dem_ds, 0)
    # Apply final horizontal and vertial shift to the original dataset
    # Note: potentially issues if we used a different projection during coregistration!
    src_dem_ds_align = coreglib.apply_xy_shift(src_dem_ds_align, dx_total, dy_total, createcopy=False)
    src_dem_ds_align = coreglib.apply_z_shift(src_dem_ds_align, dz_total, createcopy=False)
    if tiltcorr:
        xgrid, ygrid = geolib.get_xy_grids(src_dem_ds_align)
        valgrid = geolib.polyval2d(xgrid, ygrid, coeff)
        # For results of ma_fitplane
        # valgrid = coeff[0]*xgrid + coeff[1]*ygrid + coeff[2]
        src_dem_ds_align = coreglib.apply_z_shift(src_dem_ds_align, -valgrid, createcopy=False)
    # Might be cleaner way to write out MEM ds directly to disk
    src_dem_full_align = iolib.ds_getma(src_dem_ds_align)
    iolib.writeGTiff(src_dem_full_align, align_fn, src_dem_ds_align)

    if True:
        # Output final aligned src_dem, masked so only best pixels are preserved
        # Useful if creating a new reference product
        # Can also use apply_mask.py
        print("Applying filter to shifted src_dem")
        align_diff_filt_full_ds = \
            warplib.memwarp_multi_fn([align_diff_filt_fn, ], res=src_dem_ds_align, extent=src_dem_ds_align,
                                     t_srs=src_dem_ds_align)[0]
        align_diff_filt_full = iolib.ds_getma(align_diff_filt_full_ds)
        align_diff_filt_full_ds = None
        align_fn_masked = outprefix + '%s_align_filt.tif' % xyz_shift_str_cum_fn
        iolib.writeGTiff(np.ma.array(src_dem_full_align, mask=np.ma.getmaskarray(align_diff_filt_full)),
                         align_fn_masked, src_dem_ds_align)

    del src_dem_full_align
    del src_dem_ds_align

    # Compute original elevation difference
    if True:
        ref_dem_clip_ds, src_dem_clip_ds = warplib.memwarp_multi([ref_dem_ds, src_dem_ds],
                                                                 res=res, extent='intersection', t_srs=local_srs,
                                                                 r='cubic')
        # src_dem_ds = None
        ref_dem_ds = None
        ref_dem_orig = iolib.ds_getma(ref_dem_clip_ds)
        src_dem_orig = iolib.ds_getma(src_dem_clip_ds)
        # Needed for plotting
        ref_dem_hs = geolib.gdaldem_mem_ds(ref_dem_clip_ds, processing='hillshade', returnma=True, computeEdges=True)
        src_dem_hs = geolib.gdaldem_mem_ds(src_dem_clip_ds, processing='hillshade', returnma=True, computeEdges=True)
        diff_orig = src_dem_orig - ref_dem_orig
        # Only compute stats over valid surfaces
        static_mask_orig = get_mask(src_dem_clip_ds, mask_list, src_dem_fn)
        # Note: this doesn't include outlier removal or slope mask!
        static_mask_orig = np.logical_or(np.ma.getmaskarray(diff_orig), static_mask_orig)
        # For some reason, ASTER DEM diff have a spike near the 0 bin, could be an issue with masking?
        diff_orig_compressed = diff_orig[~static_mask_orig]
        diff_orig_stats = malib.get_stats_dict(diff_orig_compressed, full=True)

        # Prepare filtered version for comparison
        diff_orig_filt = np.ma.array(diff_orig, mask=static_mask_orig)
        diff_orig_filt = outlier_filter(diff_orig_filt, f=3, max_dz=max_dz)
        # diff_orig_filt = outlier_filter(diff_orig_filt, perc=(12.5, 87.5), max_dz=max_dz)
        slope = get_filtered_slope(src_dem_clip_ds)
        diff_orig_filt = np.ma.array(diff_orig_filt, mask=np.ma.getmaskarray(slope))
        diff_orig_filt_stats = malib.get_stats_dict(diff_orig_filt, full=True)

        # Write out original difference map
        print("Writing out original difference map for common intersection before alignment")
        orig_diff_fn = outprefix + '_orig_diff.tif'
        iolib.writeGTiff(diff_orig, orig_diff_fn, ref_dem_clip_ds)
        # src_dem_clip_ds = None
        ref_dem_clip_ds = None

    if True:
        align_stats_fn = outprefix + '%s_align_stats.json' % xyz_shift_str_cum_fn
        align_stats = {}
        align_stats['src_fn'] = src_dem_fn
        align_stats['ref_fn'] = ref_dem_fn
        align_stats['align_fn'] = align_fn
        align_stats['res'] = {}
        align_stats['res']['src'] = src_dem_res
        align_stats['res']['ref'] = ref_dem_res
        align_stats['res']['coreg'] = res
        align_stats['center_coord'] = {'lon': center_coord_ll[0], 'lat': center_coord_ll[1],
                                       'x': center_coord_xy[0], 'y': center_coord_xy[1]}
        align_stats['shift'] = {'dx': dx_total, 'dy': dy_total, 'dz': dz_total, 'dm': dm_total}
        # This tiltcorr flag gets set to false, need better flag
        if tiltcorr:
            align_stats['tiltcorr'] = {}
            align_stats['tiltcorr']['coeff'] = coeff.tolist()
            align_stats['tiltcorr']['val_stats'] = vals_stats
        align_stats['before'] = diff_orig_stats
        align_stats['before_filt'] = diff_orig_filt_stats
        align_stats['after'] = diff_align_stats
        align_stats['after_filt'] = diff_align_filt_stats

        import json
        with open(align_stats_fn, 'w') as f:
            json.dump(align_stats, f)

    # Create output plot
    if True:
        datadir = iolib.get_datadir()
        shp_fn = os.path.join(datadir, 'gamdam/gamdam_merge_refine_line.shp')
        shp_ds = ogr.Open(shp_fn)
        lyr = shp_ds.GetLayer()
        lyr_srs = lyr.GetSpatialRef()
        shp_extent = geolib.lyr_extent(lyr)
        ds_extent = geolib.ds_extent(src_dem_ds, t_srs=lyr_srs)
        if geolib.extent_compare(shp_extent, ds_extent) is False:
            ext = '_n' + str(int(center_coord_ll[0])) + '_n' + str(int(center_coord_ll[1])).zfill(3)
            # ext = os.path.splitext(os.path.split(ref_dem_fn)[-1])[0][4:13]
            out_fn = os.path.splitext(shp_fn)[0] + ext + '_clip.shp'
            geolib.clip_shp(shp_fn, extent=ds_extent, out_fn=out_fn)
            shp_fn = out_fn

        print("Creating final plot")
        # f, axa = plt.subplots(2, 4, figsize=(11, 8.5))
        f, axa = plt.subplots(2, 4, figsize=(16, 8))
        # for ax in axa.ravel()[:-1]:
        #     ax.set_facecolor('w')
        #     pltlib.hide_ticks(ax)
        dem_clim = malib.calcperc(ref_dem_orig, (2, 98))
        axa[0, 0].imshow(ref_dem_hs, cmap='gray')
        im = axa[0, 0].imshow(ref_dem_orig, cmap='terrain', clim=dem_clim, alpha=0.6)
        pltlib.add_cbar(axa[0, 0], im, arr=ref_dem_orig, clim=dem_clim, label=None)
        pltlib.add_scalebar(axa[0, 0], res=res[0])
        axa[0, 0].set_title('Reference DEM')
        axa[0, 0].set_facecolor('w')
        pltlib.hide_ticks(axa[0, 0])
        # pltlib.shp_overlay(axa[0,0], src_dem_clip_ds, shp_fn, color='k')

        axa[0, 1].imshow(src_dem_hs, cmap='gray')
        im = axa[0, 1].imshow(src_dem_orig, cmap='terrain', clim=dem_clim, alpha=0.6)
        pltlib.add_cbar(axa[0, 1], im, arr=src_dem_orig, clim=dem_clim, label=None)
        axa[0, 1].set_title('Source DEM')
        axa[0, 1].set_facecolor('w')
        pltlib.hide_ticks(axa[0, 1])
        # pltlib.shp_overlay(axa[0,1], src_dem_clip_ds, shp_fn, color='k')
        # axa[0,2].imshow(~static_mask_orig, clim=(0,1), cmap='gray')
        axa[0, 2].imshow(~mask_glac, clim=(0, 1), cmap='gray')
        axa[0, 2].set_title('Surfaces for co-registration')
        axa[0, 2].set_facecolor('w')
        pltlib.hide_ticks(axa[0, 2])

        dz_clim = malib.calcperc_sym(diff_align_filt[mask_glac], (1, 99))
        dz_clim_noglac = malib.calcperc_sym(diff_orig_compressed, (1, 99))

        # dz_clim = (-10, 10)
        # dz_clim_noglac = (-10, 10)

        # axa[0,3].imshow(~static_mask_gla, clim=(0,1), cmap='gray')
        # axa[0,3].set_title('static_mask_gla2')
        # # dz_clim = malib.calcperc_sym(diff_orig_compressed, (1, 99))
        bins = np.linspace(dz_clim_noglac[0], dz_clim_noglac[1], 256)
        # bins = np.linspace(-50, 50, 256)
        axa[0, 3].hist(diff_orig_compressed, bins, color='b', label='Before', alpha=0.5)
        # axa[1,3].hist(diff_align_compressed, bins, color='g', label='After', alpha=0.5)
        axa[0, 3].hist(diff_align_filt_compressed, bins, color='g', label='Filter', alpha=0.5)
        # axa[0, 3].set_xlim(*dz_clim_noglac)
        axa[0, 3].set_xlim(-50, 50)
        axa[0, 3].axvline(0, color='k', linewidth=0.5, linestyle=':')
        axa[0, 3].set_xlabel('Elev. Diff. (m)')
        axa[0, 3].set_ylabel('Count (px)')
        axa[0, 3].set_title("Source - Reference")
        before_str = 'Before\nmed: %0.2f\nnmad: %0.2f' % (diff_orig_stats['med'], diff_orig_stats['nmad'])
        axa[0, 3].text(0.05, 0.95, before_str, va='top', color='b', transform=axa[0, 3].transAxes, fontsize=8)
        # after_str = 'After\nmed: %0.2f\nnmad: %0.2f' % (diff_align_stats['med'], diff_align_stats['nmad'])
        # axa[1,3].text(0.05, 0.65, after_str, va='top', color='g', transform=axa[1,3].transAxes, fontsize=8)
        filt_str = 'Filter\nmed: %0.2f\nnmad: %0.2f' % (diff_align_filt_stats['med'], diff_align_filt_stats['nmad'])
        axa[0, 3].text(0.65, 0.95, filt_str, va='top', color='g', transform=axa[0, 3].transAxes, fontsize=8)

        axa[1, 0].imshow(ref_dem_hs, cmap='gray')
        im = axa[1, 0].imshow(diff_orig, cmap='cpt_rainbow_r', clim=dz_clim, alpha=0.6)
        pltlib.add_cbar(axa[1, 0], im, arr=diff_orig, clim=dz_clim, label=None)
        axa[1, 0].set_title('Elev. Diff. Before (m)')
        axa[1, 0].set_facecolor('w')
        pltlib.hide_ticks(axa[1, 0])
        # pltlib.shp_overlay(axa[1,0], src_dem_clip_ds, shp_fn, color='k')

        axa[1, 1].imshow(ref_dem_hs, cmap='gray')
        im = axa[1, 1].imshow(diff_align, cmap='cpt_rainbow_r', clim=dz_clim, alpha=0.6)
        pltlib.add_cbar(axa[1, 1], im, arr=diff_align, clim=dz_clim, label=None)
        axa[1, 1].set_title('Elev. Diff. After (m)')
        axa[1, 1].set_facecolor('w')
        pltlib.hide_ticks(axa[1, 1])
        # pltlib.shp_overlay(axa[1,1], src_dem_clip_ds, shp_fn, color='k')

        # tight_dz_clim = (-1.0, 1.0)
        # tight_dz_clim = (-10.0, 10.0)
        # tight_dz_clim = malib.calcperc_sym(diff_align_filt, (5, 95))
        # im = axa[1,2].imshow(diff_align_filt, cmap='cpt_rainbow', clim=tight_dz_clim)
        # pltlib.add_cbar(axa[1,2], im, arr=diff_align_filt, clim=tight_dz_clim, label=None)
        # axa[1,2].set_title('Elev. Diff. Remove. Outliers (m)')
        axa[1, 2].imshow(ref_dem_hs, cmap='gray')
        im = axa[1, 2].imshow(diff_align_filt, cmap='cpt_rainbow_r', clim=dz_clim, alpha=0.6)
        pltlib.add_cbar(axa[1, 2], im, arr=diff_align_filt, clim=dz_clim, label=None)
        axa[1, 2].set_title('Elev. Diff. Remove. Outliers (m)')
        axa[1, 2].set_facecolor('w')
        pltlib.hide_ticks(axa[1, 2])
        # pltlib.shp_overlay(axa[1,2], src_dem_clip_ds, shp_fn, color='k')

        tight_dz_clim = (-10, 10)
        axa[1, 3].imshow(ref_dem_hs, cmap='gray')
        im = axa[1, 3].imshow(diff_align_filt_nonglac, cmap='cpt_rainbow_r', clim=tight_dz_clim, alpha=0.6)
        pltlib.add_cbar(axa[1, 3], im, arr=diff_align_filt_nonglac, clim=tight_dz_clim, label=None)
        axa[1, 3].set_title('Elev. Diff. NoGlac (m)')
        axa[1, 3].set_facecolor('w')
        pltlib.hide_ticks(axa[1, 3])

        # Tried to insert Nuth fig here
        # ax_nuth.change_geometry(1,2,1)
        # f.axes.append(ax_nuth)

        suptitle = '%s\nx: %+0.2fm, y: %+0.2fm, z: %+0.2fm' % (
            os.path.split(outprefix)[-1], dx_total, dy_total, dz_total)
        f.suptitle(suptitle)
        f.tight_layout()
        plt.subplots_adjust(top=0.90)

        fig_fn = outprefix + '%s_align.png' % xyz_shift_str_cum_fn
        print("Writing out figure: %s" % fig_fn)
        f.savefig(fig_fn, dpi=450)

        if True:
            fig2 = plt.figure(0)
            ax = fig2.add_subplot(1, 1, 1)
            ax.imshow(ref_dem_hs, cmap='gray')
            im = ax.imshow(diff_align_filt, cmap='cpt_rainbow_r', clim=dz_clim, alpha=0.6)
            pltlib.add_cbar(ax, im, arr=diff_align_filt, clim=dz_clim, label=None)
            ax.set_title('cLon: %0.1fE    cLat: %0.1fN\n\nElev. Diff. After. Coreg. (m)' % (
                center_coord_ll[1], center_coord_ll[0]))
            ax.set_facecolor('w')
            pltlib.hide_ticks(ax)
            # pltlib.latlon_ticks(ax, lat_in=0.25, lon_in=0.25, in_crs=local_srs.ExportToProj4())
            pltlib.shp_overlay(ax, src_dem_clip_ds, shp_fn, color='k')

            fig2_fn = outprefix + '_align_diff.png'
            fig2.savefig(fig2_fn, dpi=600, bbox_inches='tight', pad_inches=0.1)