Пример #1
0
def makefig(dem, hs, anomaly, ds, title=None):
    """Draw a two-panel figure: DEM over its shaded relief, and an anomaly map.

    Left panel: `hs` in grayscale with `dem` drawn semi-transparently on top,
    plus a scalebar and an elevation colorbar. Right panel: `anomaly` with a
    diverging colormap and its own colorbar.

    Args:
        dem: Elevation array shown with the 'cpt_rainbow' colormap.
        hs: Array shown in grayscale underneath `dem` (presumably a
            hillshade -- confirm with caller).
        anomaly: Array shown in the right panel with the 'RdBu' colormap.
        ds: Dataset passed through to pltlib.shp_overlay.
        title: Optional title applied to BOTH panels.

    Returns:
        The matplotlib figure.

    Note:
        `shp_fn` is not a parameter; it is read from the enclosing module
        scope and must be defined there before this function is called.
    """
    # Fixed display ranges (low, high) for each layer
    hs_lo, hs_hi = 1, 255
    dem_lo, dem_hi = 1600, 2100
    anom_lo, anom_hi = -15, 15
    # Pixel size handed to the scalebar helper
    res = 8

    f, axa = plt.subplots(1, 2, figsize=(10, 5))
    left_ax = axa[0]
    right_ax = axa[1]

    # Left panel: grayscale base layer first, then the DEM at 50% opacity
    left_ax.imshow(hs, vmin=hs_lo, vmax=hs_hi, cmap='gray')
    dem_im = left_ax.imshow(dem, vmin=dem_lo, vmax=dem_hi,
                            cmap='cpt_rainbow', alpha=0.5)
    pltlib.add_scalebar(left_ax, res=res)
    pltlib.add_cbar(left_ax, dem_im, label='Elevation (m WGS84)')

    # Right panel: anomaly map
    anomaly_im = right_ax.imshow(anomaly, vmin=anom_lo, vmax=anom_hi,
                                 cmap='RdBu')
    pltlib.add_cbar(right_ax, anomaly_im, label='Elevation Anomaly (m)')

    # Optional vector overlay on the anomaly panel (module-level shp_fn)
    if shp_fn is not None:
        pltlib.shp_overlay(right_ax, ds, shp_fn, color='darkgreen')

    plt.tight_layout()

    # Cosmetic pass over both panels
    show_title = title is not None
    for panel in axa:
        pltlib.hide_ticks(panel)
        panel.set_facecolor('k')
        if show_title:
            panel.set_title(title)

    return f
Пример #2
0
def bma_fig(fig, bma, cmap='cpt_rainbow', clim=None, clim_perc=(2,98), bg=None, bg_perc=(2,98), n_subplt=1, subplt=1, label=None, title=None, cint=None, alpha=0.5, ticks=False, scalebar=None, ds=None, shp=None, imshow_kwargs={'interpolation':'nearest'}, cbar_kwargs={'extend':'both', 'orientation':'vertical', 'shrink':0.7, 'fraction':0.12, 'pad':0.02}, **kwargs):
    """Add one image subplot (with colorbar and optional overlays) to `fig`.

    Args:
        fig: matplotlib figure that receives the new subplot.
        bma: Masked array to display (its `fill_value` is read when the
            computed color limits collapse to a single value).
        cmap: Name of the colormap for `bma`.
        clim: (min, max) color limits. When None they are computed with
            malib.calcperc(bma, clim_perc).
        clim_perc: Percentiles used when `clim` is None.
        bg: Optional background array drawn in grayscale under `bma`.
        bg_perc: Accepted for interface compatibility; the value is
            overridden internally with (0.05, 99.95).
        n_subplt: Number of subplot columns (single row).
        subplt: 1-based index of this subplot; forced to 1 if n_subplt == 1.
        label: Colorbar label.
        title: Optional axes title.
        cint: Contour interval. When not None, contours are drawn and the
            colorbar ticks are set to the contour levels.
        alpha: Opacity of `bma` when drawn over `bg`.
        ticks: If True (and `ds` is given) call scale_ticks, else hide ticks.
        scalebar: If truthy, add a scalebar. `ds` must be provided, since
            the resolution is read from it.
        ds: Dataset providing the geotransform and resolution.
        shp: Optional shapefile passed to pltlib.shp_overlay.
        imshow_kwargs: Extra keyword arguments for imshow of `bma`.
        cbar_kwargs: Keyword arguments forwarded to pltlib.add_cbar.
        **kwargs: Ignored; accepted so callers can pass parsed options as-is.

    Side effects:
        Sets the module globals `gbma` and `ggt`, connects the `onclick` and
        `enter_axis` canvas callbacks, and installs `format_coord` on the axes.

    Note:
        The contour source `bma_c` is not a parameter; it is read from module
        scope. When it is None, `bma` itself is contoured.
    """
    #Interactive display state shared with the canvas callbacks
    global gbma, ggt

    #We don't use the kwargs, just there to save parsing in main

    if clim is None:
        clim = malib.calcperc(bma, clim_perc)
        #Deal with masked cases
        if clim[0] == clim[1]:
            if clim[0] > bma.fill_value:
                clim = (bma.fill_value, clim[0])
            else:
                clim = (clim[0], bma.fill_value)
        #Parenthesized form is valid on both Python 2 and Python 3
        print("Colorbar limits (%0.1f-%0.1f%%): %0.3f %0.3f" % (clim_perc[0], clim_perc[1], clim[0], clim[1]))
    else:
        print("Colorbar limits: %0.3f %0.3f" % (clim[0], clim[1]))

    #Link all subplots for zoom/pan
    sharex = sharey = None
    if len(fig.get_axes()) > 0:
        sharex = sharey = fig.get_axes()[0]

    #Hack to catch situations with only 1 subplot, but a subplot number > 1
    if n_subplt == 1:
        subplt = 1

    #One row, multiple columns
    ax = fig.add_subplot(1, n_subplt, subplt, sharex=sharex, sharey=sharey)
    ax.patch.set_facecolor('white')

    cmap_name = cmap
    cmap = plt.get_cmap(cmap_name)
    if 'inferno' in cmap_name:
        #Use a gray background
        cmap.set_bad('0.5', alpha=1)
    else:
        #This sets the nodata background to opaque black
        cmap.set_bad('k', alpha=1)

    if title is not None:
        ax.set_title(title)

    #If a background image is provided, plot it first
    if bg is not None:
        #Note, 1 is opaque, 0 completely transparent
        #NOTE(review): this overrides the caller-supplied bg_perc
        bg_perc = (0.05, 99.95)
        bg_alpha = 1.0
        bg_clim = malib.calcperc(bg, bg_perc)
        bg_cmap_name = 'gray'
        bg_cmap = plt.get_cmap(bg_cmap_name)
        if 'inferno' in cmap_name:
            bg_cmap.set_bad('0.5', alpha=1)
        else:
            bg_cmap.set_bad('k', alpha=1)
        #Set the overlay bad values to completely transparent, otherwise darkens the bg
        cmap.set_bad(alpha=0)
        bgplot = ax.imshow(bg, cmap=bg_cmap, clim=bg_clim, alpha=bg_alpha)
        imgplot = ax.imshow(bma, alpha=alpha, cmap=cmap, clim=clim, **imshow_kwargs)
    else:
        imgplot = ax.imshow(bma, cmap=cmap, clim=clim, **imshow_kwargs)

    gt = None
    if ds is not None:
        gt = np.array(ds.GetGeoTransform())
        #Rescale pixel size when bma has different dimensions than ds
        gt_scale_factor = min(np.array([ds.RasterYSize, ds.RasterXSize])/np.array(bma.shape,dtype=float))
        gt[1] *= gt_scale_factor
        gt[5] *= gt_scale_factor
        ds_srs = geolib.get_ds_srs(ds)
        if ticks:
            scale_ticks(ax, ds)
        else:
            pltlib.hide_ticks(ax)
        xres = geolib.get_res(ds)[0]
    else:
        pltlib.hide_ticks(ax)

    #This forces the black line outlining the image subplot to snap to the actual image dimensions
    #'box-forced' was deprecated in matplotlib 2.2 and is rejected by later
    #releases; fall back to the default adjustable there
    try:
        ax.set_adjustable('box-forced')
    except ValueError:
        pass

    #Should set the format based on dtype of input data
    cbar = pltlib.add_cbar(ax, imgplot, label=label, cbar_kwargs=cbar_kwargs)

    #Plot contours every cint interval and update colorbar appropriately
    if cint is not None:
        if bma_c is not None:
            #Separate contour source: derive level range from its own percentiles
            contour_src = bma_c
            contour_clim = malib.calcperc(bma_c)
        else:
            #No separate contour source: contour the displayed array itself,
            #with levels spanning the display limits
            contour_src = bma
            contour_clim = clim
        cstart = int(np.floor(contour_clim[0] / cint)) * cint
        cend = int(np.ceil(contour_clim[1] / cint)) * cint

        clvl = np.arange(cstart, cend+1, cint)
        #The contour keyword is 'linestyles' (plural)
        contours = ax.contour(contour_src, cmap='gray', linestyles='--', levels=clvl, alpha=1.0)

        #Update the cbar with contour locations
        cbar.add_lines(contours)
        cbar.set_ticks(contours.levels)

    #Plot shape overlay, moved code to pltlib
    if shp is not None:
        pltlib.shp_overlay(ax, ds, shp, gt=gt)

    if scalebar:
        #xres is only defined when ds was provided above
        scale_ticks(ax, ds)
        pltlib.add_scalebar(ax, xres)
        if not ticks:
            pltlib.hide_ticks(ax)

    gbma = bma
    ggt = gt

    #Clicking on a subplot will make it active for z-coordinate display
    fig.canvas.mpl_connect('button_press_event', onclick)
    fig.canvas.mpl_connect('axes_enter_event', enter_axis)

    #Add support for interactive z-value display
    ax.format_coord = format_coord
Пример #3
0
def bma_fig(fig,
            bma,
            cmap='cpt_rainbow',
            clim=None,
            clim_perc=(2, 98),
            bg=None,
            bg_perc=(2, 98),
            n_subplt=1,
            subplt=1,
            label=None,
            title=None,
            contour_int=None,
            contour_fn=None,
            alpha=0.5,
            ticks=False,
            scalebar=None,
            ds=None,
            shp=None,
            imshow_kwargs={'interpolation': 'nearest'},
            cbar_kwargs={'orientation': 'vertical'},
            **kwargs):
    """Add one image subplot (with colorbar and optional overlays) to `fig`.

    Args:
        fig: matplotlib figure that receives the new subplot.
        bma: Array to display.
        cmap: Name of the colormap for `bma`.
        clim: (min, max) color limits. When None they come from
            pltlib.get_clim(bma, clim_perc=clim_perc).
        clim_perc: Percentiles used when `clim` is None.
        bg: Optional background array drawn in grayscale under `bma`,
            stretched to the fixed range (1, 255).
        bg_perc: Accepted for interface compatibility; currently unused.
        n_subplt: Number of subplot columns (single row).
        subplt: 1-based index of this subplot; forced to 1 if n_subplt == 1.
        label: Colorbar label.
        title: Optional axes title.
        contour_int: Contour interval. When not None, labelled contours
            are drawn.
        contour_fn: Optional filename of a raster to contour instead of `bma`.
        alpha: Opacity of `bma` when drawn over `bg`.
        ticks: If True (and `ds` is given) call scale_ticks, else hide ticks.
        scalebar: If truthy, add a scalebar. `ds` must be provided, since
            the resolution is read from it.
        ds: Dataset providing the geotransform and resolution.
        shp: Optional shapefile passed to pltlib.shp_overlay.
        imshow_kwargs: Extra keyword arguments for imshow of `bma`.
        cbar_kwargs: Keyword arguments forwarded to pltlib.add_cbar. A falsy
            value suppresses the colorbar. The dict is copied before the
            'extend' entry is filled in, so the caller's dict (and the shared
            default) is never modified.
        **kwargs: Ignored; accepted so callers can pass parsed options as-is.

    Side effects:
        Sets the module globals `gbma` and `ggt`, connects the `onclick` and
        `enter_axis` canvas callbacks, and installs `format_coord` on the axes.
    """
    #Interactive display state shared with the canvas callbacks
    global gbma, ggt

    #We don't use the kwargs, just there to save parsing in main

    if clim is None:
        clim = pltlib.get_clim(bma, clim_perc=clim_perc)

    print("Colorbar limits: %0.3f %0.3f" % (clim[0], clim[1]))

    #Link all subplots for zoom/pan
    sharex = sharey = None
    if len(fig.get_axes()) > 0:
        sharex = sharey = fig.get_axes()[0]

    #Hack to catch situations with only 1 subplot, but a subplot number > 1
    if n_subplt == 1:
        subplt = 1

    #One row, multiple columns
    ax = fig.add_subplot(1, n_subplt, subplt, sharex=sharex, sharey=sharey)
    ax.patch.set_facecolor('white')

    #Set appropriate nodata value color
    cmap_name = cmap
    cmap = pltlib.cmap_setndv(cmap_name)

    if title is not None:
        ax.set_title(title)

    #If a background image is provided, plot it first
    if bg is not None:
        #Note, alpha=1 is opaque, 0 completely transparent
        bg_alpha = 1.0
        #NOTE(review): fixed stretch; the bg_perc argument is not consulted
        bg_clim = (1, 255)
        bg_cmap_name = 'gray'
        bg_cmap = pltlib.cmap_setndv(bg_cmap_name, cmap_name)
        #Set the overlay bad values to completely transparent, otherwise darkens the bg
        cmap.set_bad(alpha=0)
        bgplot = ax.imshow(bg, cmap=bg_cmap, clim=bg_clim, alpha=bg_alpha)
        imgplot = ax.imshow(bma,
                            alpha=alpha,
                            cmap=cmap,
                            clim=clim,
                            **imshow_kwargs)
    else:
        imgplot = ax.imshow(bma, cmap=cmap, clim=clim, **imshow_kwargs)

    gt = None
    if ds is not None:
        gt = np.array(ds.GetGeoTransform())
        #Rescale pixel size when bma has different dimensions than ds
        gt_scale_factor = min(
            np.array([ds.RasterYSize, ds.RasterXSize]) /
            np.array(bma.shape, dtype=float))
        gt[1] *= gt_scale_factor
        gt[5] *= gt_scale_factor
        ds_srs = geolib.get_ds_srs(ds)
        if ticks:
            scale_ticks(ax, ds)
        else:
            pltlib.hide_ticks(ax)
        xres = geolib.get_res(ds)[0]
    else:
        pltlib.hide_ticks(ax)

    if cbar_kwargs:
        #Work on a private copy: writing 'extend' into the argument would
        #leak into the shared default dict and into the caller's dict
        cbar_kwargs = dict(cbar_kwargs)

        #Determine whether we need to add extend triangles to colorbar
        cbar_kwargs['extend'] = pltlib.get_cbar_extend(bma, clim)

        #Add the colorbar to the axes
        cbar = pltlib.add_cbar(ax,
                               imgplot,
                               label=label,
                               cbar_kwargs=cbar_kwargs)

    #Plot contours every contour_int interval
    if contour_int is not None:
        if contour_fn is not None:
            contour_bma = iolib.fn_getma(contour_fn)
            contour_bma_clim = malib.calcperc(contour_bma)
        else:
            contour_bma = bma
            contour_bma_clim = clim

        #NOTE(review): hard-coded range overrides the limits computed above;
        #kept as-is to preserve current output -- confirm whether intended
        contour_bma_clim = (100, 250)
        cstart = int(np.floor(contour_bma_clim[0] / contour_int)) * contour_int
        cend = int(np.ceil(contour_bma_clim[1] / contour_int)) * contour_int

        clvl = np.arange(cstart, cend + 1, contour_int)
        #The contour keyword is 'linestyles' (plural)
        contour_prop = {
            'levels': clvl,
            'linestyles': '-',
            'linewidths': 0.5,
            'alpha': 1.0
        }
        contour_cmap = 'gray_r'
        #This prevents white contours
        contour_cmap_clim = (0, contour_bma_clim[-1])
        contours = ax.contour(contour_bma,
                              cmap=contour_cmap,
                              vmin=contour_cmap_clim[0],
                              vmax=contour_cmap_clim[-1],
                              **contour_prop)

        #Add labels
        ax.clabel(contours,
                  inline=True,
                  inline_spacing=0,
                  fontsize=4,
                  fmt='%i')

    #Plot shape overlay, moved code to pltlib
    if shp is not None:
        pltlib.shp_overlay(ax, ds, shp, gt=gt, color='k')

    if scalebar:
        #xres is only defined when ds was provided above
        scale_ticks(ax, ds)
        sb_loc = pltlib.best_scalebar_location(bma)
        pltlib.add_scalebar(ax, xres, location=sb_loc)
        if not ticks:
            pltlib.hide_ticks(ax)

    #Set up interactive display
    gbma = bma
    ggt = gt

    #Clicking on a subplot will make it active for z-coordinate display
    fig.canvas.mpl_connect('button_press_event', onclick)
    fig.canvas.mpl_connect('axes_enter_event', enter_axis)

    #Add support for interactive z-value display
    ax.format_coord = format_coord
Пример #4
0
def main(args=None):
    parser = getparser()
    args = parser.parse_args()

    # Should check that files exist
    ref_dem_fn = args.ref_fn
    src_dem_fn = args.src_fn

    mode = args.mode
    mask_list = args.mask_list
    max_offset = args.max_offset
    max_dz = args.max_dz
    slope_lim = tuple(args.slope_lim)
    tiltcorr = args.tiltcorr
    polyorder = args.polyorder
    res = args.res

    # Maximum number of iterations
    max_iter = args.max_iter

    # These are tolerances (in meters) to stop iteration
    tol = args.tol
    min_dx = tol
    min_dy = tol
    min_dz = tol

    outdir = args.outdir
    if outdir is None:
        outdir = os.path.splitext(src_dem_fn)[0] + '_dem_align'

    if tiltcorr:
        outdir += '_tiltcorr'
        tiltcorr_done = False
        # Relax tolerance for initial round of co-registration
        # tiltcorr_tol = 0.1
        # if tol < tiltcorr_tol:
        #    tol = tiltcorr_tol

    if not os.path.exists(outdir):
        os.makedirs(outdir)

    outprefix = '%s_%s' % (os.path.splitext(os.path.split(src_dem_fn)[-1])[0],
                           os.path.splitext(os.path.split(ref_dem_fn)[-1])[0])
    outprefix = os.path.join(outdir, outprefix)

    print("\nReference: %s" % ref_dem_fn)
    print("Source: %s" % src_dem_fn)
    print("Mode: %s" % mode)
    print("Output: %s\n" % outprefix)

    src_dem_ds = gdal.Open(src_dem_fn)
    ref_dem_ds = gdal.Open(ref_dem_fn)

    # Get local cartesian coordinate system
    # local_srs = geolib.localtmerc_ds(src_dem_ds)
    # Use original source dataset coordinate system
    # Potentially issues with distortion and xyz/tiltcorr offsets for DEM with large extent
    local_srs = geolib.get_ds_srs(src_dem_ds)
    # local_srs = geolib.get_ds_srs(ref_dem_ds)

    # Resample to common grid
    ref_dem_res = geolib.get_res(ref_dem_ds, t_srs=local_srs, square=False)
    # Create a copy to be updated in place
    src_dem_ds_align = iolib.mem_drv.CreateCopy('', src_dem_ds, 0)
    src_dem_res = geolib.get_res(src_dem_ds, t_srs=local_srs, square=False)
    src_dem_ds = None
    # Resample to user-specified resolution
    ref_dem_ds, src_dem_ds_align = warplib.memwarp_multi([ref_dem_ds, src_dem_ds_align],
                                                         extent='intersection', res=args.res, t_srs=local_srs,
                                                         r='cubic')

    res = geolib.get_res(src_dem_ds_align, square=False)
    print("\nReference DEM res: %0.2s" % ref_dem_res)
    print("Source DEM res: %0.2s" % src_dem_res)
    print("Resolution for coreg: %s (%0.2s m)\n" % (args.res, res))

    # Iteration number
    n = 1
    # Cumulative offsets
    dx_total = 0
    dy_total = 0
    dz_total = 0

    # Now iteratively update geotransform and vertical shift
    while True:
        print("*** Iteration %i ***" % n)
        dx, dy, dz, static_mask, fig = compute_offset(ref_dem_ds, src_dem_ds_align, src_dem_fn, mode, max_offset,
                                                      mask_list=mask_list, max_dz=max_dz, slope_lim=slope_lim,
                                                      plot=True)
        xyz_shift_str_iter = "dx=%+0.2fm, dy=%+0.2fm, dz=%+0.2fm" % (dx, dy, dz)
        print("Incremental offset: %s" % xyz_shift_str_iter)

        dx_total += dx
        dy_total += dy
        dz_total += dz

        xyz_shift_str_cum = "dx=%+0.2fm, dy=%+0.2fm, dz=%+0.2fm" % (dx_total, dy_total, dz_total)
        print("Cumulative offset: %s" % xyz_shift_str_cum)
        # String to append to output filenames
        xyz_shift_str_cum_fn = '_%s_x%+0.2f_y%+0.2f_z%+0.2f' % (mode, dx_total, dy_total, dz_total)

        # Should make an animation of this converging
        if n == 1:
            # static_mask_orig = static_mask
            if fig is not None:
                dst_fn = outprefix + '_%s_iter%02i_plot.png' % (mode, n)
                print("Writing offset plot: %s" % dst_fn)
                fig.gca().set_title("Incremental: %s\nCumulative: %s" % (xyz_shift_str_iter, xyz_shift_str_cum))
                fig.savefig(dst_fn, dpi=300)

        # Apply the horizontal shift to the original dataset
        src_dem_ds_align = coreglib.apply_xy_shift(src_dem_ds_align, dx, dy, createcopy=False)
        # Should
        src_dem_ds_align = coreglib.apply_z_shift(src_dem_ds_align, dz, createcopy=False)

        n += 1
        print("\n")
        # If magnitude of shift in all directions is less than tol
        # if n > max_iter or (abs(dx) <= min_dx and abs(dy) <= min_dy and abs(dz) <= min_dz):
        # If magnitude of shift is less than tol
        dm = np.sqrt(dx ** 2 + dy ** 2 + dz ** 2)
        dm_total = np.sqrt(dx_total ** 2 + dy_total ** 2 + dz_total ** 2)

        if dm_total > max_offset:
            sys.exit(
                "Total offset exceeded specified max_offset (%0.2f m). Consider increasing -max_offset argument" %
                max_offset)

        # Stop iteration
        if n > max_iter or dm < tol:

            if fig is not None:
                dst_fn = outprefix + '_%s_iter%02i_plot.png' % (mode, n)
                print("Writing offset plot: %s" % dst_fn)
                fig.gca().set_title("Incremental:%s\nCumulative:%s" % (xyz_shift_str_iter, xyz_shift_str_cum))
                fig.savefig(dst_fn, dpi=300)

            # Compute final elevation difference
            if True:
                ref_dem_clip_ds_align, src_dem_clip_ds_align = warplib.memwarp_multi([ref_dem_ds, src_dem_ds_align],
                                                                                     res=res, extent='intersection',
                                                                                     t_srs=local_srs, r='cubic')
                ref_dem_align = iolib.ds_getma(ref_dem_clip_ds_align, 1)
                src_dem_align = iolib.ds_getma(src_dem_clip_ds_align, 1)
                # ref_dem_clip_ds_align = None

                diff_align = src_dem_align - ref_dem_align
                src_dem_align = None
                ref_dem_align = None

                # Get updated, final mask
                mask_glac = get_mask(src_dem_clip_ds_align, mask_list, src_dem_fn, erode=False)
                # mask_glac_erode = get_mask(src_dem_clip_ds_align, mask_list, src_dem_fn, erode=False)
                mask_glac = np.logical_or(np.ma.getmaskarray(diff_align), mask_glac)

                # Final stats, before outlier removal
                diff_align_compressed = diff_align[~mask_glac]
                diff_align_stats = malib.get_stats_dict(diff_align_compressed, full=True)

                # Prepare filtered version for tiltcorr fit

                # 冰川区内大坡度区域
                slope = get_filtered_slope(src_dem_clip_ds_align, slope_lim=(0.01, 35))
                mask_glac_outlier = np.logical_and(mask_glac, np.ma.getmaskarray(slope))

                diff_glac_outlier = np.ma.array(diff_align, mask=~mask_glac_outlier)
                if diff_glac_outlier.count() > 0:
                    diff_align_glac_outlier = outlier_filter(np.ma.array(diff_align, mask=~mask_glac_outlier), f=2,
                                                             max_dz=100)
                    diff_align_glac_outlier[mask_glac_outlier == False] = diff_align[mask_glac_outlier == False]
                else:
                    diff_align_glac_outlier = np.ma.array(diff_align, mask=None)

                diff_align_filt_nonglac = np.ma.array(diff_align_glac_outlier, mask=mask_glac)
                diff_align_filt_compressed = diff_align[~mask_glac]
                diff_align_filt_nonglac = outlier_filter(diff_align_filt_nonglac, f=3, max_dz=max_dz)

                diff_align_filt_stats = malib.get_stats_dict(diff_align_filt_nonglac, full=True)

                diff_align_filt = np.ma.array(diff_align_filt_nonglac, mask=None)
                diff_align_filt_mask = np.ma.getmaskarray(diff_align_filt_nonglac)

                diff_align_filt[mask_glac == True] = diff_align_glac_outlier[mask_glac == True]

            # Fit 2D polynomial to residuals and remove
            # To do: add support for along-track and cross-track artifacts
            if tiltcorr and not tiltcorr_done:
                print("\n************")
                print("Calculating 'tiltcorr' 2D polynomial fit to residuals with order %i" % polyorder)
                print("************\n")
                gt = src_dem_clip_ds_align.GetGeoTransform()

                # Need to apply the mask here, so we're only fitting over static surfaces
                # Note that the origmask=False will compute vals for all x and y indices, which is what we want
                vals, resid, coeff = geolib.ma_fitpoly(diff_align_filt_nonglac, order=polyorder, gt=gt, perc=(2, 98),
                                                       origmask=False)
                # vals, resid, coeff = geolib.ma_fitplane(diff_align_filt, gt, perc=(12.5, 87.5), origmask=False)

                # Should write out coeff or grid with correction

                vals_stats = malib.get_stats_dict(vals)

                # Want to have max_tilt check here
                # max_tilt = 4.0 #m
                # Should do percentage
                # vals.ptp() > max_tilt

                # Note: dimensions of ds and vals will be different as vals are computed for clipped intersection
                # Need to recompute planar offset for full src_dem_ds_align extent and apply
                xgrid, ygrid = geolib.get_xy_grids(src_dem_ds_align)
                valgrid = geolib.polyval2d(xgrid, ygrid, coeff)
                # For results of ma_fitplane
                # valgrid = coeff[0]*xgrid + coeff[1]*ygrid + coeff[2]
                src_dem_ds_align = coreglib.apply_z_shift(src_dem_ds_align, -valgrid, createcopy=False)

                # if True:
                #     print("Creating plot of polynomial fit to residuals")
                #     fig, axa = plt.subplots(1,2, figsize=(8, 4))
                #     dz_clim = malib.calcperc_sym(vals, (2, 98))
                #     ax = pltlib.iv(diff_align_filt_nonglac, ax=axa[0], cmap='RdBu', clim=dz_clim, \
                #             label='Residual dz (m)', scalebar=False)
                #     ax = pltlib.iv(valgrid, ax=axa[1], cmap='RdBu', clim=dz_clim, \
                #             label='Polyfit dz (m)', ds=src_dem_ds_align)
                #     #if tiltcorr:
                #         #xyz_shift_str_cum_fn += "_tiltcorr"
                #     tiltcorr_fig_fn = outprefix + '%s_polyfit.png' % xyz_shift_str_cum_fn
                #     print("Writing out figure: %s\n" % tiltcorr_fig_fn)
                #     fig.savefig(tiltcorr_fig_fn, dpi=300)

                print("Applying tilt correction to difference map")
                diff_align -= vals

                # Should iterate until tilts are below some threshold
                # For now, only do one tiltcorr
                tiltcorr_done = True
                # Now use original tolerance, and number of iterations
                tol = args.tol
                max_iter = n + args.max_iter
            else:
                break

    if True:
        # Write out aligned difference map for clipped extent with vertial offset removed
        align_diff_fn = outprefix + '%s_align_diff.tif' % xyz_shift_str_cum_fn
        print("Writing out aligned difference map with median vertical offset removed")
        iolib.writeGTiff(diff_align, align_diff_fn, src_dem_clip_ds_align)

    if True:
        # Write out fitered aligned difference map
        align_diff_filt_fn = outprefix + '%s_align_diff_filt.tif' % xyz_shift_str_cum_fn
        print("Writing out filtered aligned difference map with median vertical offset removed")
        iolib.writeGTiff(diff_align_filt, align_diff_filt_fn, src_dem_clip_ds_align)

    # Extract final center coordinates for intersection
    center_coord_ll = geolib.get_center(src_dem_clip_ds_align, t_srs=geolib.wgs_srs)
    center_coord_xy = geolib.get_center(src_dem_clip_ds_align)
    src_dem_clip_ds_align = None

    # Write out final aligned src_dem
    align_fn = outprefix + '%s_align.tif' % xyz_shift_str_cum_fn
    print("Writing out shifted src_dem with median vertical offset removed: %s" % align_fn)
    # Open original uncorrected dataset at native resolution
    src_dem_ds = gdal.Open(src_dem_fn)
    src_dem_ds_align = iolib.mem_drv.CreateCopy('', src_dem_ds, 0)
    # Apply final horizontal and vertial shift to the original dataset
    # Note: potentially issues if we used a different projection during coregistration!
    src_dem_ds_align = coreglib.apply_xy_shift(src_dem_ds_align, dx_total, dy_total, createcopy=False)
    src_dem_ds_align = coreglib.apply_z_shift(src_dem_ds_align, dz_total, createcopy=False)
    if tiltcorr:
        xgrid, ygrid = geolib.get_xy_grids(src_dem_ds_align)
        valgrid = geolib.polyval2d(xgrid, ygrid, coeff)
        # For results of ma_fitplane
        # valgrid = coeff[0]*xgrid + coeff[1]*ygrid + coeff[2]
        src_dem_ds_align = coreglib.apply_z_shift(src_dem_ds_align, -valgrid, createcopy=False)
    # Might be cleaner way to write out MEM ds directly to disk
    src_dem_full_align = iolib.ds_getma(src_dem_ds_align)
    iolib.writeGTiff(src_dem_full_align, align_fn, src_dem_ds_align)

    if True:
        # Output final aligned src_dem, masked so only best pixels are preserved
        # Useful if creating a new reference product
        # Can also use apply_mask.py
        print("Applying filter to shifted src_dem")
        align_diff_filt_full_ds = \
            warplib.memwarp_multi_fn([align_diff_filt_fn, ], res=src_dem_ds_align, extent=src_dem_ds_align,
                                     t_srs=src_dem_ds_align)[0]
        align_diff_filt_full = iolib.ds_getma(align_diff_filt_full_ds)
        align_diff_filt_full_ds = None
        align_fn_masked = outprefix + '%s_align_filt.tif' % xyz_shift_str_cum_fn
        iolib.writeGTiff(np.ma.array(src_dem_full_align, mask=np.ma.getmaskarray(align_diff_filt_full)),
                         align_fn_masked, src_dem_ds_align)

    del src_dem_full_align
    del src_dem_ds_align

    # Compute original elevation difference
    if True:
        ref_dem_clip_ds, src_dem_clip_ds = warplib.memwarp_multi([ref_dem_ds, src_dem_ds],
                                                                 res=res, extent='intersection', t_srs=local_srs,
                                                                 r='cubic')
        # src_dem_ds = None
        ref_dem_ds = None
        ref_dem_orig = iolib.ds_getma(ref_dem_clip_ds)
        src_dem_orig = iolib.ds_getma(src_dem_clip_ds)
        # Needed for plotting
        ref_dem_hs = geolib.gdaldem_mem_ds(ref_dem_clip_ds, processing='hillshade', returnma=True, computeEdges=True)
        src_dem_hs = geolib.gdaldem_mem_ds(src_dem_clip_ds, processing='hillshade', returnma=True, computeEdges=True)
        diff_orig = src_dem_orig - ref_dem_orig
        # Only compute stats over valid surfaces
        static_mask_orig = get_mask(src_dem_clip_ds, mask_list, src_dem_fn)
        # Note: this doesn't include outlier removal or slope mask!
        static_mask_orig = np.logical_or(np.ma.getmaskarray(diff_orig), static_mask_orig)
        # For some reason, ASTER DEM diff have a spike near the 0 bin, could be an issue with masking?
        diff_orig_compressed = diff_orig[~static_mask_orig]
        diff_orig_stats = malib.get_stats_dict(diff_orig_compressed, full=True)

        # Prepare filtered version for comparison
        diff_orig_filt = np.ma.array(diff_orig, mask=static_mask_orig)
        diff_orig_filt = outlier_filter(diff_orig_filt, f=3, max_dz=max_dz)
        # diff_orig_filt = outlier_filter(diff_orig_filt, perc=(12.5, 87.5), max_dz=max_dz)
        slope = get_filtered_slope(src_dem_clip_ds)
        diff_orig_filt = np.ma.array(diff_orig_filt, mask=np.ma.getmaskarray(slope))
        diff_orig_filt_stats = malib.get_stats_dict(diff_orig_filt, full=True)

        # Write out original difference map
        print("Writing out original difference map for common intersection before alignment")
        orig_diff_fn = outprefix + '_orig_diff.tif'
        iolib.writeGTiff(diff_orig, orig_diff_fn, ref_dem_clip_ds)
        # src_dem_clip_ds = None
        ref_dem_clip_ds = None

    if True:
        align_stats_fn = outprefix + '%s_align_stats.json' % xyz_shift_str_cum_fn
        align_stats = {}
        align_stats['src_fn'] = src_dem_fn
        align_stats['ref_fn'] = ref_dem_fn
        align_stats['align_fn'] = align_fn
        align_stats['res'] = {}
        align_stats['res']['src'] = src_dem_res
        align_stats['res']['ref'] = ref_dem_res
        align_stats['res']['coreg'] = res
        align_stats['center_coord'] = {'lon': center_coord_ll[0], 'lat': center_coord_ll[1],
                                       'x': center_coord_xy[0], 'y': center_coord_xy[1]}
        align_stats['shift'] = {'dx': dx_total, 'dy': dy_total, 'dz': dz_total, 'dm': dm_total}
        # This tiltcorr flag gets set to false, need better flag
        if tiltcorr:
            align_stats['tiltcorr'] = {}
            align_stats['tiltcorr']['coeff'] = coeff.tolist()
            align_stats['tiltcorr']['val_stats'] = vals_stats
        align_stats['before'] = diff_orig_stats
        align_stats['before_filt'] = diff_orig_filt_stats
        align_stats['after'] = diff_align_stats
        align_stats['after_filt'] = diff_align_filt_stats

        import json
        with open(align_stats_fn, 'w') as f:
            json.dump(align_stats, f)

    # Create output plot
    # Builds two figures from the co-registration results computed above:
    #   1. a 2x4 summary figure saved as outprefix + xyz_shift_str_cum_fn + '_align.png'
    #   2. a single-panel map of diff_align_filt saved as outprefix + '_align_diff.png'
    # NOTE(review): `if True:` makes this section unconditional; presumably a
    # leftover on/off switch.
    if True:
        # Shapefile used for the overlay on the second figure; it lives under
        # the data directory reported by iolib.get_datadir().
        datadir = iolib.get_datadir()
        shp_fn = os.path.join(datadir, 'gamdam/gamdam_merge_refine_line.shp')
        # NOTE(review): the result of ogr.Open is used without a None check.
        shp_ds = ogr.Open(shp_fn)
        lyr = shp_ds.GetLayer()
        lyr_srs = lyr.GetSpatialRef()
        # Compare the layer extent with the source DEM extent (the latter is
        # requested with t_srs=lyr_srs).
        shp_extent = geolib.lyr_extent(lyr)
        ds_extent = geolib.ds_extent(src_dem_ds, t_srs=lyr_srs)
        if geolib.extent_compare(shp_extent, ds_extent) is False:
            # Clip the shapefile to the DEM extent, write it next to the
            # original with a suffix built from the integer parts of
            # center_coord_ll, and use the clipped copy from here on.
            # NOTE(review): both suffix parts use the '_n' prefix, although the
            # title of the second figure below labels center_coord_ll[1] as
            # longitude ('E') - confirm whether the second one should differ.
            ext = '_n' + str(int(center_coord_ll[0])) + '_n' + str(int(center_coord_ll[1])).zfill(3)
            # ext = os.path.splitext(os.path.split(ref_dem_fn)[-1])[0][4:13]
            out_fn = os.path.splitext(shp_fn)[0] + ext + '_clip.shp'
            geolib.clip_shp(shp_fn, extent=ds_extent, out_fn=out_fn)
            shp_fn = out_fn

        print("Creating final plot")
        # f, axa = plt.subplots(2, 4, figsize=(11, 8.5))
        f, axa = plt.subplots(2, 4, figsize=(16, 8))
        # for ax in axa.ravel()[:-1]:
        #     ax.set_facecolor('w')
        #     pltlib.hide_ticks(ax)

        # Panel [0, 0]: reference DEM drawn semi-transparently over its
        # hillshade. dem_clim comes from malib.calcperc(ref_dem_orig, (2, 98))
        # and is reused for the source DEM panel so both share one color scale.
        dem_clim = malib.calcperc(ref_dem_orig, (2, 98))
        axa[0, 0].imshow(ref_dem_hs, cmap='gray')
        im = axa[0, 0].imshow(ref_dem_orig, cmap='terrain', clim=dem_clim, alpha=0.6)
        pltlib.add_cbar(axa[0, 0], im, arr=ref_dem_orig, clim=dem_clim, label=None)
        # Only this panel gets a scalebar; it is given res[0].
        pltlib.add_scalebar(axa[0, 0], res=res[0])
        axa[0, 0].set_title('Reference DEM')
        axa[0, 0].set_facecolor('w')
        pltlib.hide_ticks(axa[0, 0])
        # pltlib.shp_overlay(axa[0,0], src_dem_clip_ds, shp_fn, color='k')

        # Panel [0, 1]: source DEM over its own hillshade, same dem_clim.
        axa[0, 1].imshow(src_dem_hs, cmap='gray')
        im = axa[0, 1].imshow(src_dem_orig, cmap='terrain', clim=dem_clim, alpha=0.6)
        pltlib.add_cbar(axa[0, 1], im, arr=src_dem_orig, clim=dem_clim, label=None)
        axa[0, 1].set_title('Source DEM')
        axa[0, 1].set_facecolor('w')
        pltlib.hide_ticks(axa[0, 1])
        # pltlib.shp_overlay(axa[0,1], src_dem_clip_ds, shp_fn, color='k')

        # Panel [0, 2]: inverted mask_glac shown as a 0/1 grayscale image.
        # axa[0,2].imshow(~static_mask_orig, clim=(0,1), cmap='gray')
        axa[0, 2].imshow(~mask_glac, clim=(0, 1), cmap='gray')
        axa[0, 2].set_title('Surfaces for co-registration')
        axa[0, 2].set_facecolor('w')
        pltlib.hide_ticks(axa[0, 2])

        # Color limits for the difference maps:
        #   dz_clim        - from diff_align_filt restricted to mask_glac
        #   dz_clim_noglac - from diff_orig_compressed; only used below to
        #                    define the histogram bin range
        dz_clim = malib.calcperc_sym(diff_align_filt[mask_glac], (1, 99))
        dz_clim_noglac = malib.calcperc_sym(diff_orig_compressed, (1, 99))

        # dz_clim = (-10, 10)
        # dz_clim_noglac = (-10, 10)

        # axa[0,3].imshow(~static_mask_gla, clim=(0,1), cmap='gray')
        # axa[0,3].set_title('static_mask_gla2')
        # # dz_clim = malib.calcperc_sym(diff_orig_compressed, (1, 99))

        # Panel [0, 3]: overlaid histograms of the differences before alignment
        # (blue) and after alignment + filtering (green), using 256 bin edges
        # spanning dz_clim_noglac.
        bins = np.linspace(dz_clim_noglac[0], dz_clim_noglac[1], 256)
        # bins = np.linspace(-50, 50, 256)
        axa[0, 3].hist(diff_orig_compressed, bins, color='b', label='Before', alpha=0.5)
        # axa[1,3].hist(diff_align_compressed, bins, color='g', label='After', alpha=0.5)
        axa[0, 3].hist(diff_align_filt_compressed, bins, color='g', label='Filter', alpha=0.5)
        # The x-axis is fixed to (-50, 50) independently of the bin range.
        # axa[0, 3].set_xlim(*dz_clim_noglac)
        axa[0, 3].set_xlim(-50, 50)
        axa[0, 3].axvline(0, color='k', linewidth=0.5, linestyle=':')
        axa[0, 3].set_xlabel('Elev. Diff. (m)')
        axa[0, 3].set_ylabel('Count (px)')
        axa[0, 3].set_title("Source - Reference")
        # Annotate the 'med' and 'nmad' entries of the before / filtered stats
        # dicts in axes coordinates (top-left and top-right of the panel).
        before_str = 'Before\nmed: %0.2f\nnmad: %0.2f' % (diff_orig_stats['med'], diff_orig_stats['nmad'])
        axa[0, 3].text(0.05, 0.95, before_str, va='top', color='b', transform=axa[0, 3].transAxes, fontsize=8)
        # after_str = 'After\nmed: %0.2f\nnmad: %0.2f' % (diff_align_stats['med'], diff_align_stats['nmad'])
        # axa[1,3].text(0.05, 0.65, after_str, va='top', color='g', transform=axa[1,3].transAxes, fontsize=8)
        filt_str = 'Filter\nmed: %0.2f\nnmad: %0.2f' % (diff_align_filt_stats['med'], diff_align_filt_stats['nmad'])
        axa[0, 3].text(0.65, 0.95, filt_str, va='top', color='g', transform=axa[0, 3].transAxes, fontsize=8)

        # Bottom row: difference maps over the reference hillshade. The first
        # three panels share dz_clim so they can be compared directly.

        # Panel [1, 0]: difference before alignment.
        axa[1, 0].imshow(ref_dem_hs, cmap='gray')
        im = axa[1, 0].imshow(diff_orig, cmap='cpt_rainbow_r', clim=dz_clim, alpha=0.6)
        pltlib.add_cbar(axa[1, 0], im, arr=diff_orig, clim=dz_clim, label=None)
        axa[1, 0].set_title('Elev. Diff. Before (m)')
        axa[1, 0].set_facecolor('w')
        pltlib.hide_ticks(axa[1, 0])
        # pltlib.shp_overlay(axa[1,0], src_dem_clip_ds, shp_fn, color='k')

        # Panel [1, 1]: difference after alignment.
        axa[1, 1].imshow(ref_dem_hs, cmap='gray')
        im = axa[1, 1].imshow(diff_align, cmap='cpt_rainbow_r', clim=dz_clim, alpha=0.6)
        pltlib.add_cbar(axa[1, 1], im, arr=diff_align, clim=dz_clim, label=None)
        axa[1, 1].set_title('Elev. Diff. After (m)')
        axa[1, 1].set_facecolor('w')
        pltlib.hide_ticks(axa[1, 1])
        # pltlib.shp_overlay(axa[1,1], src_dem_clip_ds, shp_fn, color='k')

        # Panel [1, 2]: difference after alignment and filtering.
        # tight_dz_clim = (-1.0, 1.0)
        # tight_dz_clim = (-10.0, 10.0)
        # tight_dz_clim = malib.calcperc_sym(diff_align_filt, (5, 95))
        # im = axa[1,2].imshow(diff_align_filt, cmap='cpt_rainbow', clim=tight_dz_clim)
        # pltlib.add_cbar(axa[1,2], im, arr=diff_align_filt, clim=tight_dz_clim, label=None)
        # axa[1,2].set_title('Elev. Diff. Remove. Outliers (m)')
        axa[1, 2].imshow(ref_dem_hs, cmap='gray')
        im = axa[1, 2].imshow(diff_align_filt, cmap='cpt_rainbow_r', clim=dz_clim, alpha=0.6)
        pltlib.add_cbar(axa[1, 2], im, arr=diff_align_filt, clim=dz_clim, label=None)
        axa[1, 2].set_title('Elev. Diff. Remove. Outliers (m)')
        axa[1, 2].set_facecolor('w')
        pltlib.hide_ticks(axa[1, 2])
        # pltlib.shp_overlay(axa[1,2], src_dem_clip_ds, shp_fn, color='k')

        # Panel [1, 3]: diff_align_filt_nonglac on a fixed (-10, 10) color
        # range instead of the data-derived dz_clim.
        tight_dz_clim = (-10, 10)
        axa[1, 3].imshow(ref_dem_hs, cmap='gray')
        im = axa[1, 3].imshow(diff_align_filt_nonglac, cmap='cpt_rainbow_r', clim=tight_dz_clim, alpha=0.6)
        pltlib.add_cbar(axa[1, 3], im, arr=diff_align_filt_nonglac, clim=tight_dz_clim, label=None)
        axa[1, 3].set_title('Elev. Diff. NoGlac (m)')
        axa[1, 3].set_facecolor('w')
        pltlib.hide_ticks(axa[1, 3])

        # Tried to insert Nuth fig here
        # ax_nuth.change_geometry(1,2,1)
        # f.axes.append(ax_nuth)

        # Figure title: basename of outprefix plus the total x/y/z shifts.
        suptitle = '%s\nx: %+0.2fm, y: %+0.2fm, z: %+0.2fm' % (
            os.path.split(outprefix)[-1], dx_total, dy_total, dz_total)
        f.suptitle(suptitle)
        f.tight_layout()
        # Presumably lowers the top of the subplot area to leave room for the
        # two-line suptitle. NOTE(review): plt.subplots_adjust acts on pyplot's
        # current figure, which is assumed to still be `f` at this point.
        plt.subplots_adjust(top=0.90)

        fig_fn = outprefix + '%s_align.png' % xyz_shift_str_cum_fn
        print("Writing out figure: %s" % fig_fn)
        f.savefig(fig_fn, dpi=450)

        # Second figure: single-panel map of the filtered post-alignment
        # difference with the (possibly clipped) shapefile drawn on top.
        # NOTE(review): `if True:` is always taken; presumably a leftover switch.
        if True:
            # NOTE(review): plt.figure(0) asks for figure number 0 explicitly;
            # if a figure with that number already exists it is reused rather
            # than created fresh - confirm that is intended.
            fig2 = plt.figure(0)
            ax = fig2.add_subplot(1, 1, 1)
            ax.imshow(ref_dem_hs, cmap='gray')
            im = ax.imshow(diff_align_filt, cmap='cpt_rainbow_r', clim=dz_clim, alpha=0.6)
            pltlib.add_cbar(ax, im, arr=diff_align_filt, clim=dz_clim, label=None)
            # The title labels center_coord_ll[1] as 'cLon' and
            # center_coord_ll[0] as 'cLat'.
            ax.set_title('cLon: %0.1fE    cLat: %0.1fN\n\nElev. Diff. After. Coreg. (m)' % (
                center_coord_ll[1], center_coord_ll[0]))
            ax.set_facecolor('w')
            pltlib.hide_ticks(ax)
            # pltlib.latlon_ticks(ax, lat_in=0.25, lon_in=0.25, in_crs=local_srs.ExportToProj4())
            pltlib.shp_overlay(ax, src_dem_clip_ds, shp_fn, color='k')

            fig2_fn = outprefix + '_align_diff.png'
            fig2.savefig(fig2_fn, dpi=600, bbox_inches='tight', pad_inches=0.1)