Example #1
0
def compute_cam_px_reproj_err_stats(content_line, idx):
    """
    Summarise the pixel reprojection error magnitude for a set of points seen by one camera
    Parameters
    -----------
    content_line: list
        list of str, one entry per line of run-final_residuals_no_loss_function_raw_pixels.txt
    idx: np.array
        indices of the points whose reprojection error should be read
    Returns
    -----------
    stats: dictionary
        output of malib.get_stats_dict (full=True) for the per-point error magnitude
    """
    # read_px_error hands back the two error components for the requested points
    err_x, err_y = read_px_error(content_line, idx)
    # collapse the two components into a single magnitude per point
    err_magnitude = np.sqrt(err_x**2 + err_y**2)
    return malib.get_stats_dict(err_magnitude, full=True)
Example #2
0
def _save_diff_stats_and_plot(refma, srcma, ref_ds, elev_bin_width,
                              slope_bin_width, outfolder, header_str, stage):
    """
    Save descriptive stats (json) and a binned difference plot (png) for one stage
    Parameters
    -----------
    refma: array
        reference DEM array, as returned by iolib.ds_getma
    srcma: array
        source DEM array on the same grid as refma
    ref_ds: dataset
        warped reference dataset, passed to gdaldem to derive slope
    elev_bin_width: numeric
        bin width handed to cummulative_profile for the elevation profile
    slope_bin_width: numeric
        bin width handed to cummulative_profile for the slope profile
    outfolder: str
        folder in which the json and png are written
    header_str: str
        prefix used for the output file names
    stage: str
        'precoreg' or 'postcoreg', used in file names and in the messages
    Returns
    -----------
    stats: dictionary
        output of malib.get_stats_dict for (refma - srcma)
    """
    is_pre = stage == 'precoreg'
    diff = refma - srcma
    stats = malib.get_stats_dict(diff)
    print("{} descriptive statistics {}".format(
        'Original' if is_pre else 'Final', stats))

    stats_json_fn = os.path.join(
        outfolder, '{}_{}_descriptive_stats.json'.format(header_str, stage))
    # serialise before opening so a failure does not leave an empty file behind
    stats_json = json.dumps(stats)
    with open(stats_json_fn, 'w') as json_file:
        json_file.write(stats_json)
    # report the file that was written (not the json payload itself)
    logging.info("Saved {} stats at {}".format(
        'initial' if is_pre else 'final', stats_json_fn))

    refslope = gdaldem(ref_ds)
    # stats for elevation difference vs reference DEM elevation
    elev_bin, diff_mean, _, _, diff_perc = cummulative_profile(
        refma, diff, elev_bin_width)
    # stats for elevation difference vs reference DEM slope
    slope_bin, diff_mean_s, _, _, diff_perc_s = cummulative_profile(
        refslope, diff, slope_bin_width)

    fig, ax = plt.subplots(1, 2, figsize=(10, 4))
    panels = (
        (ax[0], elev_bin, diff_mean, diff_perc, 'Elevation (m)'),
        (ax[1], slope_bin, diff_mean_s, diff_perc_s, 'Slope (degrees)'),
    )
    for axa, bins, means, perc, xlabel in panels:
        im = axa.scatter(bins, means, c=perc, cmap='inferno')
        axa.set_xlabel(xlabel)
        divider = make_axes_locatable(axa)
        cax = divider.append_axes('right', size='2.5%', pad=0.05)
        plt.colorbar(im,
                     cax=cax,
                     orientation='vertical',
                     label='pixel count percentage')
        axa.axhline(y=0, c='k')
        axa.set_ylabel('Elevation Difference (m)')
    plt.tight_layout()
    plot_fn = os.path.join(outfolder,
                           '{}_{}_binned_plot.png'.format(header_str, stage))
    fig.savefig(plot_fn, dpi=300, bbox_inches='tight', pad_inches=0.1)
    # release the figure, nothing uses it after it has been written
    plt.close(fig)
    logging.info("Saved binned plot at {}".format(plot_fn))
    return stats


def main():
    """
    Compare a source DEM against a reference DEM before and (optionally) after co-registration

    Reads refdem, srcdem, local_ortho, comparison_res, elev_bin_width,
    slope_bin_width and coreg from the parser returned by getparser().
    Both DEMs are warped to a common grid, difference stats and a binned
    plot are written to '<ref>__<src>_comparison_stats'. When args.coreg == 1,
    dem_align.py is run as a subprocess and the same products are written
    again for the aligned DEM (best effort, a failure is logged, not raised).
    """
    parser = getparser()
    args = parser.parse_args()
    refdem = args.refdem
    srcdem = args.srcdem
    header_str = '{}__{}'.format(
        os.path.splitext(os.path.basename(refdem))[0],
        os.path.splitext(os.path.basename(srcdem))[0])
    outfolder = header_str + '_comparison_stats'
    os.makedirs(outfolder, exist_ok=True)

    if args.local_ortho == 1:
        temp_ds = warplib.memwarp_multi_fn([refdem, srcdem])[0]
        bbox = geolib.ds_extent(temp_ds)
        geo_crs = temp_ds.GetProjection()
        print('Bounding box lon_lat is{}'.format(bbox))
        bound_poly = Polygon([[bbox[0], bbox[3]], [bbox[2], bbox[3]],
                              [bbox[2], bbox[1]], [bbox[0], bbox[1]]])
        bound_shp = gpd.GeoDataFrame(index=[0],
                                     geometry=[bound_poly],
                                     crs=geo_crs)
        bound_centroid = bound_shp.centroid
        cx = bound_centroid.x.values[0]
        cy = bound_centroid.y.values[0]
        # pull the two latitudes in by 1/6th of the bbox height from each edge
        pad = np.ptp([bbox[3], bbox[1]]) / 6.0
        lat_1 = bbox[1] + pad
        lat_2 = bbox[3] - pad
        local_ortho = "+proj=ortho +lat_1={} +lat_2={} +lat_0={} +lon_0={} +x_0=0 +y_0=0 +ellps=WGS84 +datum=WGS84 +units=m +no_defs".format(
            lat_1, lat_2, cy, cx)
        logging.info('Local Ortho projection is {}'.format(local_ortho))
        t_srs = local_ortho
    else:
        t_srs = 'first'

    # this step performs the desired warping operation
    ds_list = warplib.memwarp_multi_fn([refdem, srcdem],
                                       res=args.comparison_res,
                                       t_srs=t_srs)
    refma = iolib.ds_getma(ds_list[0])
    srcma = iolib.ds_getma(ds_list[1])
    _save_diff_stats_and_plot(refma, srcma, ds_list[0], args.elev_bin_width,
                              args.slope_bin_width, outfolder, header_str,
                              'precoreg')

    if args.coreg == 1:
        logging.info("will attempt coregisteration")
        if args.local_ortho == 1:
            ref_local_ortho = os.path.splitext(refdem)[0] + '_local_ortho.tif'
            src_local_ortho = os.path.splitext(srcdem)[0] + '_local_ortho.tif'
            # coregisteration works best at mean resolution
            # we will rewarp if the initial args.res was not mean
            if args.comparison_res != 'mean':
                ds_list = warplib.memwarp_multi_fn([refdem, srcdem],
                                                   res='mean',
                                                   t_srs=t_srs)
                refma = iolib.ds_getma(ds_list[0])
                srcma = iolib.ds_getma(ds_list[1])
            iolib.writeGTiff(refma, ref_local_ortho, ds_list[0])
            iolib.writeGTiff(srcma, src_local_ortho, ds_list[1])
            coreg_ref = ref_local_ortho
            src_ref = src_local_ortho
        else:
            coreg_ref = refdem
            src_ref = srcdem
        demcoreg_dir = os.path.join(outfolder, 'coreg_results')
        align_opts = [
            '-mode', 'nuth', '-max_iter', '12', '-max_offset', '400',
            '-outdir', demcoreg_dir
        ]
        align_args = [coreg_ref, src_ref]
        align_cmd = ['dem_align.py'] + align_opts + align_args
        align_status = subprocess.call(align_cmd)
        if align_status != 0:
            logging.warning(
                "dem_align.py exited with status {}".format(align_status))
        # final round of warping and stats calculation, best effort only
        try:
            srcdem_align = glob.glob(os.path.join(demcoreg_dir,
                                                  '*align.tif'))[0]
            logging.info(
                "Attempting stats calculation for aligned DEM {}".format(
                    srcdem_align))
            ds_list = warplib.memwarp_multi_fn([args.refdem, srcdem_align],
                                               res=args.comparison_res,
                                               t_srs=t_srs)
            # names are rebound so the pre-coreg arrays can be released
            refma = iolib.ds_getma(ds_list[0])
            srcma = iolib.ds_getma(ds_list[1])
            _save_diff_stats_and_plot(refma, srcma, ds_list[0],
                                      args.elev_bin_width,
                                      args.slope_bin_width, outfolder,
                                      header_str, 'postcoreg')
        except Exception:
            # keep the run alive, but record the traceback instead of hiding it
            logging.exception(
                "Failed to compute post coreg stats, see corresponding job log"
            )
        # NOTE(review): this message is only emitted on the coreg path, confirm
        # whether it was meant to sit outside the `if args.coreg == 1` branch
        logging.info("Script is complete !")
Example #3
0
def main():
    """
    Build initial camera models and GCP files for a set of frames/scenes

    Reads mode, img, outdir, dem and the mode specific options from the
    parser returned by getparser(). In 'video' mode a subset of frames is
    selected from the frame index and written to csv; in 'triplet' mode the
    image list is taken from the overlap pickle. asp.cam_gen is then mapped
    over all images with p_map, the GCP files are cleaned with asp.clean_gcp
    and the collected cam_gen logs are written to a single log file in outdir.

    Raises
    -----------
    ValueError
        if args.mode is neither 'video' nor 'triplet'
    """
    parser = getparser()
    args = parser.parse_args()
    mode = args.mode
    img_folder = os.path.abspath(args.img)
    outdir = os.path.abspath(args.outdir)
    # makedirs also creates missing parents; exist_ok replaces the separate
    # existence check (the previous os.makedir call does not exist in os)
    os.makedirs(outdir, exist_ok=True)

    if mode == 'video':
        sampling = args.video_sampling_mode
        frame_index = skysat.parse_frame_index(args.frame_index, True)
        num_samples = len(frame_index)
        frames = frame_index.name.values
        sampler = args.sampler
        outdf = os.path.join(outdir, os.path.basename(args.frame_index))
        if sampling == 'sampling_interval':
            print(
                "Hardcoded sampling interval results in frame exclusion at the end of the video sequence based on step size, better to chose the num_images mode and the program will equally distribute accordingly"
            )
            idx = np.arange(0, num_samples, sampler)
            outdf = '{}_sampling_inteval_{}.csv'.format(
                os.path.splitext(outdf)[0], sampler)
        else:
            print("Sampling {} from {} of the input video sequence".format(
                sampler, num_samples))
            idx = np.linspace(0, num_samples - 1, sampler, dtype=int)
            # with fewer than two samples there is no spacing to report
            approx_step = idx[1] - idx[0] if len(idx) > 1 else 0
            outdf = '{}_sampling_inteval_aprox{}.csv'.format(
                os.path.splitext(outdf)[0], approx_step)
        sub_sampled_frames = frames[idx]
        sub_df = frame_index[frame_index['name'].isin(
            list(sub_sampled_frames))]
        sub_df.to_csv(outdf, sep=',', index=False)
        # this is camera/gcp initialisation
        n = len(sub_sampled_frames)
        img_list = [
            glob.glob(os.path.join(img_folder, '{}*.tiff'.format(frame)))[0]
            for frame in sub_sampled_frames
        ]
        pitch = [1] * n
        out_fn = [
            os.path.join(outdir, '{}_frame_idx.tsai'.format(frame))
            for frame in sub_sampled_frames
        ]
        out_gcp = [
            os.path.join(outdir, '{}_frame_idx.gcp'.format(frame))
            for frame in sub_sampled_frames
        ]
        # from here on frame_index is the per-image argument list for cam_gen
        frame_index = [args.frame_index] * n
        camera = [None] * n
        gcp_factor = 4

    elif mode == 'triplet':
        df = pd.read_pickle(args.overlap_pkl)
        img_list = list(
            np.unique(np.array(list(df.img1.values) + list(df.img2.values))))
        img_list = [
            os.path.splitext(os.path.basename(img))[0] for img in img_list
        ]
        cam_list = [
            glob.glob(os.path.join(img_folder, '{}*.tif'.format(img)))[0]
            for img in img_list
        ]
        n = len(img_list)
        if args.product_level == 'l1b':
            pitch = [0.8] * n
        else:
            pitch = [1.0] * n
        out_fn = [
            os.path.join(outdir, '{}_rpc.tsai'.format(frame))
            for frame in img_list
        ]
        out_gcp = [
            os.path.join(outdir, '{}_rpc.gcp'.format(frame))
            for frame in img_list
        ]
        camera = cam_list
        frame_index = [None] * n
        img_list = cam_list
        gcp_factor = 8

    else:
        # fail with a clear message instead of a NameError further down
        raise ValueError(
            "Unsupported mode {!r}, expected 'video' or 'triplet'".format(
                mode))

    fl = [553846.153846] * n
    cx = [1280] * n
    cy = [540] * n
    dem = args.dem
    # use this value for height where DEM has no-data
    ht_datum = [malib.get_stats_dict(iolib.fn_getma(dem))['median']] * n
    gcp_std = [1] * n
    datum = ['WGS84'] * n
    refdem = [dem] * n
    # number of worker processes is hard-coded
    n_proc = 30
    cam_gen_log = p_map(asp.cam_gen,
                        img_list,
                        fl,
                        cx,
                        cy,
                        pitch,
                        ht_datum,
                        gcp_std,
                        out_fn,
                        out_gcp,
                        datum,
                        refdem,
                        camera,
                        frame_index,
                        num_cpus=n_proc)
    print("writing gcp with basename removed")
    # count expected gcp
    print(f"Total expected GCP {gcp_factor*n}")
    asp.clean_gcp(out_gcp, outdir)
    # saving subprocess consolidated log file
    from datetime import datetime
    now = datetime.now()
    log_fn = os.path.join(outdir, 'camgen_{}.log'.format(now))
    print("saving subprocess camgen log at {}".format(log_fn))
    with open(log_fn, 'w') as f:
        f.writelines(cam_gen_log)
    print("Script is complete !")
Example #4
0
def main(argv=None):
    """
    Co-register a source DEM to a reference DEM and write aligned products, stats and plots.

    Options are read from the parser returned by getparser(): ref_fn, src_fn,
    mode, mask_list, max_offset, max_dz, slope_lim, tiltcorr, polyorder, res,
    max_iter, tol and outdir.

    Workflow, as implemented below:
      1. Warp both DEMs to a common grid (intersection extent, args.res,
         source DEM srs, cubic resampling).
      2. Loop: call compute_offset to get dx/dy/dz, accumulate them and apply
         them to the in-memory source copy. Iteration stops once the iteration
         counter exceeds max_iter or the incremental shift magnitude drops
         below tol. The script exits through sys.exit when the cumulative shift
         magnitude exceeds max_offset.
      3. On the stop condition the aligned difference and its stats are
         computed. If tiltcorr is set and has not been done yet, a 2D polynomial
         of order polyorder is fit to the filtered residuals, removed from the
         source copy, and the loop continues with an extended max_iter.
         Otherwise the loop ends.
      4. Write difference maps, the shifted source DEM (full and filtered),
         the original difference map, a json file with stats and a summary
         figure, all named from outprefix.

    Parameters
    -----------
    argv: unused
        not passed to parser.parse_args(), which is called without arguments
    """
    parser = getparser()
    args = parser.parse_args()

    #Should check that files exist
    ref_dem_fn = args.ref_fn
    src_dem_fn = args.src_fn

    mode = args.mode
    mask_list = args.mask_list
    max_offset = args.max_offset
    max_dz = args.max_dz
    slope_lim = tuple(args.slope_lim)
    tiltcorr = args.tiltcorr
    polyorder = args.polyorder
    #Note: res is overwritten below with the numeric resolution of the warped source dataset
    res = args.res

    #Maximum number of iterations
    max_iter = args.max_iter

    #These are tolerances (in meters) to stop iteration
    tol = args.tol
    #min_dx, min_dy, min_dz are only referenced by the commented-out per-axis stop condition in the loop
    min_dx = tol
    min_dy = tol
    min_dz = tol

    #Default output directory sits next to the source DEM
    outdir = args.outdir
    if outdir is None:
        outdir = os.path.splitext(src_dem_fn)[0] + '_dem_align'

    if tiltcorr:
        outdir += '_tiltcorr'
        #Tracks whether the single tiltcorr pass has been applied inside the loop
        tiltcorr_done = False
        #Relax tolerance for initial round of co-registration
        #tiltcorr_tol = 0.1
        #if tol < tiltcorr_tol:
        #    tol = tiltcorr_tol

    if not os.path.exists(outdir):
        os.makedirs(outdir)

    #Output prefix is <outdir>/<src basename>_<ref basename>, extensions removed
    outprefix = '%s_%s' % (os.path.splitext(os.path.split(src_dem_fn)[-1])[0], \
            os.path.splitext(os.path.split(ref_dem_fn)[-1])[0])
    outprefix = os.path.join(outdir, outprefix)

    print("\nReference: %s" % ref_dem_fn)
    print("Source: %s" % src_dem_fn)
    print("Mode: %s" % mode)
    print("Output: %s\n" % outprefix)

    src_dem_ds = gdal.Open(src_dem_fn)
    ref_dem_ds = gdal.Open(ref_dem_fn)

    #Get local cartesian coordinate system
    #local_srs = geolib.localtmerc_ds(src_dem_ds)
    #Use original source dataset coordinate system
    #Potentially issues with distortion and xyz/tiltcorr offsets for DEM with large extent
    local_srs = geolib.get_ds_srs(src_dem_ds)
    #local_srs = geolib.get_ds_srs(ref_dem_ds)

    #Resample to common grid
    ref_dem_res = float(geolib.get_res(ref_dem_ds, t_srs=local_srs, square=True)[0])
    #Create a copy to be updated in place
    src_dem_ds_align = iolib.mem_drv.CreateCopy('', src_dem_ds, 0)
    src_dem_res = float(geolib.get_res(src_dem_ds, t_srs=local_srs, square=True)[0])
    #Drop the reference to the on-disk source dataset; it is reopened after the loop
    src_dem_ds = None
    #Resample to user-specified resolution
    ref_dem_ds, src_dem_ds_align = warplib.memwarp_multi([ref_dem_ds, src_dem_ds_align], \
            extent='intersection', res=args.res, t_srs=local_srs, r='cubic')

    #Numeric resolution of the warped source dataset, used for later warps and the scalebar
    res = float(geolib.get_res(src_dem_ds_align, square=True)[0])
    print("\nReference DEM res: %0.2f" % ref_dem_res)
    print("Source DEM res: %0.2f" % src_dem_res)
    print("Resolution for coreg: %s (%0.2f m)\n" % (args.res, res))

    #Iteration number
    n = 1
    #Cumulative offsets
    dx_total = 0
    dy_total = 0
    dz_total = 0

    #Now iteratively update geotransform and vertical shift
    while True:
        print("*** Iteration %i ***" % n)
        dx, dy, dz, static_mask, fig = compute_offset(ref_dem_ds, src_dem_ds_align, src_dem_fn, mode, max_offset, \
                mask_list=mask_list, max_dz=max_dz, slope_lim=slope_lim, plot=True)
        xyz_shift_str_iter = "dx=%+0.2fm, dy=%+0.2fm, dz=%+0.2fm" % (dx, dy, dz)
        print("Incremental offset: %s" % xyz_shift_str_iter)

        dx_total += dx
        dy_total += dy
        dz_total += dz

        xyz_shift_str_cum = "dx=%+0.2fm, dy=%+0.2fm, dz=%+0.2fm" % (dx_total, dy_total, dz_total)
        print("Cumulative offset: %s" % xyz_shift_str_cum)
        #String to append to output filenames
        #Rebuilt every iteration, the value from the last iteration is used for all outputs after the loop
        xyz_shift_str_cum_fn = '_%s_x%+0.2f_y%+0.2f_z%+0.2f' % (mode, dx_total, dy_total, dz_total)

        #Should make an animation of this converging
        #Only the first iteration (here) and the stopping iteration (below) write an offset plot
        if n == 1: 
            #static_mask_orig = static_mask
            if fig is not None:
                dst_fn = outprefix + '_%s_iter%02i_plot.png' % (mode, n)
                print("Writing offset plot: %s" % dst_fn)
                fig.gca().set_title("Incremental: %s\nCumulative: %s" % (xyz_shift_str_iter, xyz_shift_str_cum))
                fig.savefig(dst_fn, dpi=300)

        #Apply the horizontal shift to the original dataset
        src_dem_ds_align = coreglib.apply_xy_shift(src_dem_ds_align, dx, dy, createcopy=False)
        #Apply the vertical shift of this iteration as well
        src_dem_ds_align = coreglib.apply_z_shift(src_dem_ds_align, dz, createcopy=False)

        n += 1
        print("\n")
        #If magnitude of shift in all directions is less than tol
        #if n > max_iter or (abs(dx) <= min_dx and abs(dy) <= min_dy and abs(dz) <= min_dz):
        #If magnitude of shift is less than tol
        dm = np.sqrt(dx**2 + dy**2 + dz**2)
        dm_total = np.sqrt(dx_total**2 + dy_total**2 + dz_total**2)

        #Abort the whole script when the cumulative shift magnitude is larger than allowed
        if dm_total > max_offset:
            sys.exit("Total offset exceeded specified max_offset (%0.2f m). Consider increasing -max_offset argument" % max_offset)

        #Stop iteration
        if n > max_iter or dm < tol:

            #NOTE(review): n was already incremented above, so the iteration number in this
            #filename is one higher than the iteration that produced fig
            if fig is not None:
                dst_fn = outprefix + '_%s_iter%02i_plot.png' % (mode, n)
                print("Writing offset plot: %s" % dst_fn)
                fig.gca().set_title("Incremental:%s\nCumulative:%s" % (xyz_shift_str_iter, xyz_shift_str_cum))
                fig.savefig(dst_fn, dpi=300)

            #Compute final elevation difference
            if True:
                ref_dem_clip_ds_align, src_dem_clip_ds_align = warplib.memwarp_multi([ref_dem_ds, src_dem_ds_align], \
                        res=res, extent='intersection', t_srs=local_srs, r='cubic')
                ref_dem_align = iolib.ds_getma(ref_dem_clip_ds_align, 1)
                src_dem_align = iolib.ds_getma(src_dem_clip_ds_align, 1)
                ref_dem_clip_ds_align = None

                #Difference is source minus reference
                diff_align = src_dem_align - ref_dem_align
                src_dem_align = None
                ref_dem_align = None

                #Get updated, final mask
                static_mask_final = get_mask(src_dem_clip_ds_align, mask_list, src_dem_fn)
                #Also exclude pixels that are masked in the difference itself
                static_mask_final = np.logical_or(np.ma.getmaskarray(diff_align), static_mask_final)
                
                #Final stats, before outlier removal
                diff_align_compressed = diff_align[~static_mask_final]
                diff_align_stats = malib.get_stats_dict(diff_align_compressed, full=True)

                #Prepare filtered version for tiltcorr fit
                diff_align_filt = np.ma.array(diff_align, mask=static_mask_final)
                diff_align_filt = outlier_filter(diff_align_filt, f=3, max_dz=max_dz)
                #diff_align_filt = outlier_filter(diff_align_filt, perc=(12.5, 87.5), max_dz=max_dz)
                #Additionally mask pixels that are masked in the filtered slope
                slope = get_filtered_slope(src_dem_clip_ds_align)
                diff_align_filt = np.ma.array(diff_align_filt, mask=np.ma.getmaskarray(slope))
                diff_align_filt_stats = malib.get_stats_dict(diff_align_filt, full=True)

            #Fit 2D polynomial to residuals and remove
            #To do: add support for along-track and cross-track artifacts
            if tiltcorr and not tiltcorr_done:
                print("\n************")
                print("Calculating 'tiltcorr' 2D polynomial fit to residuals with order %i" % polyorder)
                print("************\n")
                gt = src_dem_clip_ds_align.GetGeoTransform()

                #Need to apply the mask here, so we're only fitting over static surfaces
                #Note that the origmask=False will compute vals for all x and y indices, which is what we want
                vals, resid, coeff = geolib.ma_fitpoly(diff_align_filt, order=polyorder, gt=gt, perc=(0,100), origmask=False)
                #vals, resid, coeff = geolib.ma_fitplane(diff_align_filt, gt, perc=(12.5, 87.5), origmask=False)

                #Should write out coeff or grid with correction 

                #Stored in the json stats file after the loop
                vals_stats = malib.get_stats_dict(vals)

                #Want to have max_tilt check here
                #max_tilt = 4.0 #m
                #Should do percentage
                #vals.ptp() > max_tilt

                #Note: dimensions of ds and vals will be different as vals are computed for clipped intersection
                #Need to recompute planar offset for full src_dem_ds_align extent and apply
                xgrid, ygrid = geolib.get_xy_grids(src_dem_ds_align)
                valgrid = geolib.polyval2d(xgrid, ygrid, coeff) 
                #For results of ma_fitplane
                #valgrid = coeff[0]*xgrid + coeff[1]*ygrid + coeff[2]
                src_dem_ds_align = coreglib.apply_z_shift(src_dem_ds_align, -valgrid, createcopy=False)

                if True:
                    print("Creating plot of polynomial fit to residuals")
                    fig, axa = plt.subplots(1,2, figsize=(8, 4))
                    dz_clim = malib.calcperc_sym(vals, (2, 98))
                    ax = pltlib.iv(diff_align_filt, ax=axa[0], cmap='RdBu', clim=dz_clim, \
                            label='Residual dz (m)', scalebar=False)
                    ax = pltlib.iv(valgrid, ax=axa[1], cmap='RdBu', clim=dz_clim, \
                            label='Polyfit dz (m)', ds=src_dem_ds_align)
                    #if tiltcorr:
                        #xyz_shift_str_cum_fn += "_tiltcorr"
                    tiltcorr_fig_fn = outprefix + '%s_polyfit.png' % xyz_shift_str_cum_fn
                    print("Writing out figure: %s\n" % tiltcorr_fig_fn)
                    fig.savefig(tiltcorr_fig_fn, dpi=300)

                print("Applying tilt correction to difference map")
                diff_align -= vals

                #Should iterate until tilts are below some threshold
                #For now, only do one tiltcorr
                tiltcorr_done=True
                #Now use original tolerance, and number of iterations 
                tol = args.tol
                #Allow up to args.max_iter further iterations after the tilt correction
                max_iter = n + args.max_iter
            else:
                break

    #Everything below relies on variables left over from the last loop iteration
    #(diff_align, diff_align_filt, src_dem_clip_ds_align, xyz_shift_str_cum_fn, static_mask, dm_total, ...)
    if True:
        #Write out aligned difference map for clipped extent with vertical offset removed
        align_diff_fn = outprefix + '%s_align_diff.tif' % xyz_shift_str_cum_fn
        print("Writing out aligned difference map with median vertical offset removed")
        iolib.writeGTiff(diff_align, align_diff_fn, src_dem_clip_ds_align)

    if True:
        #Write out filtered aligned difference map
        align_diff_filt_fn = outprefix + '%s_align_diff_filt.tif' % xyz_shift_str_cum_fn
        print("Writing out filtered aligned difference map with median vertical offset removed")
        iolib.writeGTiff(diff_align_filt, align_diff_filt_fn, src_dem_clip_ds_align)

    #Extract final center coordinates for intersection
    center_coord_ll = geolib.get_center(src_dem_clip_ds_align, t_srs=geolib.wgs_srs)
    center_coord_xy = geolib.get_center(src_dem_clip_ds_align)
    src_dem_clip_ds_align = None

    #Write out final aligned src_dem 
    align_fn = outprefix + '%s_align.tif' % xyz_shift_str_cum_fn
    print("Writing out shifted src_dem with median vertical offset removed: %s" % align_fn)
    #Open original uncorrected dataset at native resolution
    src_dem_ds = gdal.Open(src_dem_fn)
    src_dem_ds_align = iolib.mem_drv.CreateCopy('', src_dem_ds, 0)
    #Apply final horizontal and vertical shift to the original dataset
    #Note: potentially issues if we used a different projection during coregistration!
    src_dem_ds_align = coreglib.apply_xy_shift(src_dem_ds_align, dx_total, dy_total, createcopy=False)
    src_dem_ds_align = coreglib.apply_z_shift(src_dem_ds_align, dz_total, createcopy=False)
    if tiltcorr:
        #Re-evaluate the fitted polynomial (coeff from the loop) on the native-resolution grid
        xgrid, ygrid = geolib.get_xy_grids(src_dem_ds_align)
        valgrid = geolib.polyval2d(xgrid, ygrid, coeff) 
        #For results of ma_fitplane
        #valgrid = coeff[0]*xgrid + coeff[1]*ygrid + coeff[2]
        src_dem_ds_align = coreglib.apply_z_shift(src_dem_ds_align, -valgrid, createcopy=False)
    #Might be cleaner way to write out MEM ds directly to disk
    src_dem_full_align = iolib.ds_getma(src_dem_ds_align)
    iolib.writeGTiff(src_dem_full_align, align_fn, src_dem_ds_align)

    if True:
        #Output final aligned src_dem, masked so only best pixels are preserved
        #Useful if creating a new reference product
        #Can also use apply_mask.py 
        print("Applying filter to shiftec src_dem")
        #Warp the filtered difference map (read back from disk) onto the grid of the aligned source DEM
        align_diff_filt_full_ds = warplib.memwarp_multi_fn([align_diff_filt_fn,], res=src_dem_ds_align, extent=src_dem_ds_align, \
                t_srs=src_dem_ds_align)[0]
        align_diff_filt_full = iolib.ds_getma(align_diff_filt_full_ds)
        align_diff_filt_full_ds = None
        align_fn_masked = outprefix + '%s_align_filt.tif' % xyz_shift_str_cum_fn
        #Aligned DEM values, masked wherever the warped filtered difference is masked
        iolib.writeGTiff(np.ma.array(src_dem_full_align, mask=np.ma.getmaskarray(align_diff_filt_full)), \
                align_fn_masked, src_dem_ds_align)

    src_dem_full_align = None
    src_dem_ds_align = None

    #Compute original elevation difference
    if True:
        #src_dem_ds here is the unshifted dataset reopened from disk above
        ref_dem_clip_ds, src_dem_clip_ds = warplib.memwarp_multi([ref_dem_ds, src_dem_ds], \
                res=res, extent='intersection', t_srs=local_srs, r='cubic')
        src_dem_ds = None
        ref_dem_ds = None
        ref_dem_orig = iolib.ds_getma(ref_dem_clip_ds)
        src_dem_orig = iolib.ds_getma(src_dem_clip_ds)
        #Needed for plotting
        ref_dem_hs = geolib.gdaldem_mem_ds(ref_dem_clip_ds, processing='hillshade', returnma=True, computeEdges=True)
        src_dem_hs = geolib.gdaldem_mem_ds(src_dem_clip_ds, processing='hillshade', returnma=True, computeEdges=True)
        diff_orig = src_dem_orig - ref_dem_orig
        #Only compute stats over valid surfaces
        static_mask_orig = get_mask(src_dem_clip_ds, mask_list, src_dem_fn)
        #Note: this doesn't include outlier removal or slope mask!
        static_mask_orig = np.logical_or(np.ma.getmaskarray(diff_orig), static_mask_orig)
        #For some reason, ASTER DEM diff have a spike near the 0 bin, could be an issue with masking?
        diff_orig_compressed = diff_orig[~static_mask_orig]
        diff_orig_stats = malib.get_stats_dict(diff_orig_compressed, full=True)

        #Prepare filtered version for comparison 
        diff_orig_filt = np.ma.array(diff_orig, mask=static_mask_orig)
        diff_orig_filt = outlier_filter(diff_orig_filt, f=3, max_dz=max_dz)
        #diff_orig_filt = outlier_filter(diff_orig_filt, perc=(12.5, 87.5), max_dz=max_dz)
        slope = get_filtered_slope(src_dem_clip_ds)
        diff_orig_filt = np.ma.array(diff_orig_filt, mask=np.ma.getmaskarray(slope))
        diff_orig_filt_stats = malib.get_stats_dict(diff_orig_filt, full=True)

        #Write out original difference map
        print("Writing out original difference map for common intersection before alignment")
        orig_diff_fn = outprefix + '_orig_diff.tif'
        iolib.writeGTiff(diff_orig, orig_diff_fn, ref_dem_clip_ds)
        src_dem_clip_ds = None
        ref_dem_clip_ds = None

    #Collect inputs, resolutions, shifts and before/after stats into one json file
    if True:
        align_stats_fn = outprefix + '%s_align_stats.json' % xyz_shift_str_cum_fn
        align_stats = {}
        align_stats['src_fn'] = src_dem_fn 
        align_stats['ref_fn'] = ref_dem_fn 
        align_stats['align_fn'] = align_fn 
        align_stats['res'] = {} 
        align_stats['res']['src'] = src_dem_res
        align_stats['res']['ref'] = ref_dem_res
        align_stats['res']['coreg'] = res
        align_stats['center_coord'] = {'lon':center_coord_ll[0], 'lat':center_coord_ll[1], \
                'x':center_coord_xy[0], 'y':center_coord_xy[1]}
        #dm_total is the cumulative shift magnitude from the last loop iteration
        align_stats['shift'] = {'dx':dx_total, 'dy':dy_total, 'dz':dz_total, 'dm':dm_total}
        #This tiltcorr flag gets set to false, need better flag
        if tiltcorr:
            align_stats['tiltcorr'] = {}
            align_stats['tiltcorr']['coeff'] = coeff.tolist()
            align_stats['tiltcorr']['val_stats'] = vals_stats
        align_stats['before'] = diff_orig_stats
        align_stats['before_filt'] = diff_orig_filt_stats
        align_stats['after'] = diff_align_stats
        align_stats['after_filt'] = diff_align_filt_stats
        
        import json
        with open(align_stats_fn, 'w') as f:
            json.dump(align_stats, f)

    #Create output plot
    if True:
        print("Creating final plot")
        kwargs = {'interpolation':'none'}
        #f, axa = plt.subplots(2, 4, figsize=(11, 8.5))
        #Note: f is rebound here from the json file handle to the figure
        f, axa = plt.subplots(2, 4, figsize=(16, 8))
        #All panels except the last one (the histogram) get a black background and hidden ticks
        for ax in axa.ravel()[:-1]:
            ax.set_facecolor('k')
            pltlib.hide_ticks(ax)
        #Top row: reference DEM, source DEM (same color limits), mask used for co-registration
        dem_clim = malib.calcperc(ref_dem_orig, (2,98))
        axa[0,0].imshow(ref_dem_hs, cmap='gray', **kwargs)
        im = axa[0,0].imshow(ref_dem_orig, cmap='cpt_rainbow', clim=dem_clim, alpha=0.6, **kwargs)
        pltlib.add_cbar(axa[0,0], im, arr=ref_dem_orig, clim=dem_clim, label=None)
        pltlib.add_scalebar(axa[0,0], res=res)
        axa[0,0].set_title('Reference DEM')
        axa[0,1].imshow(src_dem_hs, cmap='gray', **kwargs)
        im = axa[0,1].imshow(src_dem_orig, cmap='cpt_rainbow', clim=dem_clim, alpha=0.6, **kwargs)
        pltlib.add_cbar(axa[0,1], im, arr=src_dem_orig, clim=dem_clim, label=None)
        axa[0,1].set_title('Source DEM')
        #axa[0,2].imshow(~static_mask_orig, clim=(0,1), cmap='gray')
        #static_mask is the mask returned by the last compute_offset call in the loop
        axa[0,2].imshow(~static_mask, clim=(0,1), cmap='gray', **kwargs)
        axa[0,2].set_title('Surfaces for co-registration')
        #Bottom row: difference before, difference after, filtered difference after, histogram
        dz_clim = malib.calcperc_sym(diff_orig_compressed, (5, 95))
        im = axa[1,0].imshow(diff_orig, cmap='RdBu', clim=dz_clim)
        pltlib.add_cbar(axa[1,0], im, arr=diff_orig, clim=dz_clim, label=None)
        axa[1,0].set_title('Elev. Diff. Before (m)')
        im = axa[1,1].imshow(diff_align, cmap='RdBu', clim=dz_clim)
        pltlib.add_cbar(axa[1,1], im, arr=diff_align, clim=dz_clim, label=None)
        axa[1,1].set_title('Elev. Diff. After (m)')

        #Fixed, tighter color limits for the filtered difference panel
        #tight_dz_clim = (-1.0, 1.0)
        tight_dz_clim = (-2.0, 2.0)
        #tight_dz_clim = (-10.0, 10.0)
        #tight_dz_clim = malib.calcperc_sym(diff_align_filt, (5, 95))
        im = axa[1,2].imshow(diff_align_filt, cmap='RdBu', clim=tight_dz_clim)
        pltlib.add_cbar(axa[1,2], im, arr=diff_align_filt, clim=tight_dz_clim, label=None)
        axa[1,2].set_title('Elev. Diff. After (m)')

        #Tried to insert Nuth fig here
        #ax_nuth.change_geometry(1,2,1)
        #f.axes.append(ax_nuth)

        #Histograms of the before/after differences over the same 128 bin edges
        bins = np.linspace(dz_clim[0], dz_clim[1], 128)
        axa[1,3].hist(diff_orig_compressed, bins, color='g', label='Before', alpha=0.5)
        axa[1,3].hist(diff_align_compressed, bins, color='b', label='After', alpha=0.5)
        axa[1,3].set_xlim(*dz_clim)
        axa[1,3].axvline(0, color='k', linewidth=0.5, linestyle=':')
        axa[1,3].set_xlabel('Elev. Diff. (m)')
        axa[1,3].set_ylabel('Count (px)')
        axa[1,3].set_title("Source - Reference")
        before_str = 'Before\nmed: %0.2f\nnmad: %0.2f' % (diff_orig_stats['med'], diff_orig_stats['nmad'])
        axa[1,3].text(0.05, 0.95, before_str, va='top', color='g', transform=axa[1,3].transAxes, fontsize=8)
        after_str = 'After\nmed: %0.2f\nnmad: %0.2f' % (diff_align_stats['med'], diff_align_stats['nmad'])
        axa[1,3].text(0.65, 0.95, after_str, va='top', color='b', transform=axa[1,3].transAxes, fontsize=8)

        #This is empty
        axa[0,3].axis('off')

        suptitle = '%s\nx: %+0.2fm, y: %+0.2fm, z: %+0.2fm' % (os.path.split(outprefix)[-1], dx_total, dy_total, dz_total)
        f.suptitle(suptitle)
        f.tight_layout()
        plt.subplots_adjust(top=0.90)

        fig_fn = outprefix + '%s_align.png' % xyz_shift_str_cum_fn
        print("Writing out figure: %s" % fig_fn)
        f.savefig(fig_fn, dpi=300)
Example #5
0
def main():
    """Run the bundle adjustment workflow selected by ``args.mode``.

    Command-line arguments are parsed with ``getparser()``. Image, camera and
    GCP lists are assembled from the supplied paths and one of the following
    workflows is executed through ``run_cmd('bundle_adjust', ...)``:

    * ``full_video``: two rounds for a video sequence described by a frame
      index CSV; copies of the initial/final reprojection error files are
      saved after round 1.
    * ``full_triplet``: two rounds for triplet stereo collections, with
      ``fixed_cam_idx`` passed to each round.
    * ``transform_pc_align``: a single run that only transforms the cameras.
    * ``general_ba``: a single general purpose run.

    Any other mode only prints the completion message.
    """
    parser = getparser()
    args = parser.parse_args()
    img = args.img
    # populate image list; fall back to the '.tiff' extension when fewer than
    # two '.tif' files are found
    img_list = sorted(glob.glob(os.path.join(img, '*.tif')))
    if len(img_list) < 2:
        img_list = sorted(glob.glob(os.path.join(img, '*.tiff')))
        if os.path.islink(img_list[0]):
            img_list = [os.readlink(x) for x in img_list]

    # populate camera model list
    if args.cam:
        cam = os.path.abspath(args.cam)
        if 'run' in os.path.basename(cam):
            # cam is treated as a file prefix rather than a directory
            cam_list = sorted(glob.glob(cam + '-*.tsai'))
        else:
            cam_list = sorted(glob.glob(os.path.join(cam, '*.tsai')))
        cam_list = cam_list[:len(img_list)]

    session = args.t

    # output ba_prefix
    if args.ba_prefix:
        ba_prefix = os.path.abspath(args.ba_prefix)

    # Optional inputs get explicit defaults so that later branches can
    # reference them even when the corresponding option was not supplied.
    initial_transform = None
    if args.initial_transform:
        initial_transform = os.path.abspath(args.initial_transform)
    # NOTE(review): input_adjustments is parsed but not forwarded to any
    # get_ba_opts call below - confirm whether it should be.
    input_adjustments = None
    if args.input_adjustments:
        input_adjustments = os.path.abspath(args.input_adjustments)

    # triplet stereo overlap list
    overlap_list = None
    if args.overlap_list:
        overlap_list = os.path.abspath(args.overlap_list)

    # Populate GCP list
    gcp_list = []
    if args.gcp:
        gcp_list = sorted(glob.glob(os.path.join(args.gcp, '*.gcp')))

    mode = args.mode
    if args.bound:
        bound = gpd.read_file(args.bound)
        geo_crs = {'init': 'epsg:4326'}
        # compare by value: an identity test against a fresh dict is always true
        if bound.crs != geo_crs:
            bound = bound.to_crs(geo_crs)
        lon_min, lat_min, lon_max, lat_max = bound.total_bounds

    # Select whether to float both translation/rotation, or only rotation
    if args.camera_param2float == 'trans+rot':
        cam_wt = 0
    else:
        # this will invoke adjustment with rotation weight of 0 and translation weight of 0.4
        cam_wt = None
    print(f"Camera weight is {cam_wt}")

    # not commonly used
    if args.dem:
        dem = iolib.fn_getma(args.dem)
        dem_stats = malib.get_stats_dict(dem)
        # pad the DEM elevation range by 500 on both sides
        min_elev, max_elev = [dem_stats['min'] - 500, dem_stats['max'] + 500]

    if mode == 'full_video':
        # read subsampled frame index, populate gcp, image and camera models appropriately
        frame_index = args.frame_index
        df = pd.read_csv(frame_index)
        gcp = os.path.abspath(args.gcp)

        # block to determine automatically overlap limit of 40 seconds for computing match points
        df['dt'] = [
            datetime.strptime(date.split('+00:00')[0], '%Y-%m-%dT%H:%M:%S.%f')
            for date in df.datetime.values
        ]
        delta = (df.dt.values[1] - df.dt.values[0]) / np.timedelta64(1, 's')
        # overlap limit is hardcoded to give 40 seconds of coverage
        # (builtin int replaces the np.int alias, which newer NumPy lacks)
        overlap_limit = int(np.ceil(40 / delta))
        print("Calculated overlap limit as {}".format(overlap_limit))

        img_list = [
            glob.glob(os.path.join(img, '*{}*.tiff'.format(x)))[0]
            for x in df.name.values
        ]
        cam_list = [
            glob.glob(os.path.join(cam, '*{}*.tsai'.format(x)))[0]
            for x in df.name.values
        ]
        gcp_list = [
            glob.glob(os.path.join(gcp, '*{}*.gcp'.format(x)))[0]
            for x in df.name.values
        ]
        #also append the clean gcp here
        print(os.path.join(gcp, '*clean*_gcp.gcp'))
        gcp_list.append(glob.glob(os.path.join(gcp, '*clean*_gcp.gcp'))[0])

        # this attempt did not work here
        # but given videos small footprint, the median (scale)+trans+rotation is good enough for all terrain
        # so reverting back to them
        #stereo_baseline = 10
        #fix_cam_idx = np.array([0]+[0+stereo_baseline])
        #ip_per_tile is switched to default, as due to high scene to scene overlap and limited perspective difference, this produces abundant matches

        round1_opts = get_ba_opts(ba_prefix,
                                  overlap_limit=overlap_limit,
                                  flavor='2round_gcp_1',
                                  session=session,
                                  ip_per_tile=4000,
                                  num_iterations=args.num_iter,
                                  num_pass=args.num_pass,
                                  camera_weight=cam_wt,
                                  fixed_cam_idx=None,
                                  robust_threshold=None)
        print("Running round 1 bundle adjustment for input video sequence")
        if session == 'nadirpinhole':
            ba_args = img_list + cam_list
        else:
            ba_args = img_list
        # Check if this command executed till last
        print('Running bundle adjustment round1')
        run_cmd('bundle_adjust', round1_opts + ba_args)

        # Make files used to evaluate solution quality
        init_residual_fn_def = sorted(
            glob.glob(ba_prefix + '*initial*no_loss_*pointmap*.csv'))[0]
        init_per_cam_reproj_err = sorted(
            glob.glob(
                ba_prefix +
                '-*initial_residuals_no_loss_function_raw_pixels.txt'))[0]
        init_per_cam_reproj_err_disk = os.path.splitext(
            init_per_cam_reproj_err)[0] + '_initial_per_cam_reproj_error.txt'
        init_residual_fn = os.path.splitext(
            init_residual_fn_def)[0] + '_initial_reproj_error.csv'
        shutil.copy2(init_residual_fn_def, init_residual_fn)
        shutil.copy2(init_per_cam_reproj_err, init_per_cam_reproj_err_disk)
        # Copy final reprojection error files before transforming cameras
        final_residual_fn_def = sorted(
            glob.glob(ba_prefix + '*final*no_loss_*pointmap*.csv'))[0]
        final_residual_fn = os.path.splitext(
            final_residual_fn_def)[0] + '_final_reproj_error.csv'
        final_per_cam_reproj_err = sorted(
            glob.glob(ba_prefix +
                      '-*final_residuals_no_loss_function_raw_pixels.txt'))[0]
        final_per_cam_reproj_err_disk = os.path.splitext(
            final_per_cam_reproj_err)[0] + '_final_per_cam_reproj_error.txt'
        shutil.copy2(final_residual_fn_def, final_residual_fn)
        shutil.copy2(final_per_cam_reproj_err, final_per_cam_reproj_err_disk)

        if session == 'nadirpinhole':
            # prepare for second run to apply a constant transform to the self-consistent models using initial ground footprints
            identifier = os.path.basename(cam_list[0]).split(
                df.name.values[0])[0]
            print(ba_prefix + identifier +
                  '-{}*.tsai'.format(df.name.values[0]))
            # pick up the camera models written under ba_prefix by round 1
            cam_list = [
                glob.glob(ba_prefix + identifier + '-{}*.tsai'.format(img))[0]
                for img in df.name.values
            ]
            print(len(cam_list))
            ba_args = img_list + cam_list + gcp_list

            #fixed_cam_idx2 = np.delete(np.arange(len(img_list),dtype=int),fix_cam_idx)
            round2_opts = get_ba_opts(ba_prefix,
                                      overlap_limit=overlap_limit,
                                      flavor='2round_gcp_2',
                                      session=session,
                                      gcp_transform=True,
                                      camera_weight=0,
                                      num_iterations=0,
                                      num_pass=1)
        else:
            # round 1 is adjust file
            round2_opts = get_ba_opts(ba_prefix,
                                      overlap_limit=overlap_limit,
                                      input_adjustments=ba_prefix,
                                      flavor='2round_gcp_2',
                                      session=session)
            ba_args = img_list + gcp_list
        print("running round 2 bundle adjustment for input video sequence")
        run_cmd('bundle_adjust', round2_opts + ba_args)

    elif mode == 'full_triplet':
        if args.overlap_list is None:
            print(
                "Attempted bundle adjust will be expensive, will try to find matches in each and every pair"
            )
        # the concept is simple
        #first 3 cameras, and then corresponding first three cameras from next collection are fixed in the first go
        # these serve as a kind of #GCP, preventing a large drift in the triangulated points/camera extrinsics during optimization
        img_time_identifier_list = np.array(
            [os.path.basename(img).split('_')[1] for img in img_list])
        img_time_unique_list = np.unique(img_time_identifier_list)
        second_collection_list = np.where(
            img_time_identifier_list == img_time_unique_list[1])[0][[0, 1, 2]]
        fix_cam_idx = np.array([0, 1, 2] + list(second_collection_list))
        print(type(fix_cam_idx))

        round1_opts = get_ba_opts(ba_prefix,
                                  session=session,
                                  num_iterations=args.num_iter,
                                  num_pass=args.num_pass,
                                  fixed_cam_idx=fix_cam_idx,
                                  overlap_list=args.overlap_list,
                                  camera_weight=cam_wt)
        # enter round2_opts here only ?
        if session == 'nadirpinhole':
            ba_args = img_list + cam_list
        else:
            ba_args = img_list
        print(
            "Running round 1 bundle adjustment for given triplet stereo combination"
        )
        run_cmd('bundle_adjust', round1_opts + ba_args)

        # Save the first and foremost bundle adjustment reprojection error file
        init_residual_fn_def = sorted(
            glob.glob(ba_prefix + '*initial*no_loss_*pointmap*.csv'))[0]
        init_residual_fn = os.path.splitext(
            init_residual_fn_def)[0] + '_initial_reproj_error.csv'
        init_per_cam_reproj_err = sorted(
            glob.glob(
                ba_prefix +
                '-*initial_residuals_no_loss_function_raw_pixels.txt'))[0]
        init_per_cam_reproj_err_disk = os.path.splitext(
            init_per_cam_reproj_err)[0] + '_initial_per_cam_reproj_error.txt'
        shutil.copy2(init_residual_fn_def, init_residual_fn)
        shutil.copy2(init_per_cam_reproj_err, init_per_cam_reproj_err_disk)

        if session == 'nadirpinhole':
            identifier = os.path.basename(cam_list[0]).split('_', 14)[0][:2]
            print(ba_prefix + '-{}*.tsai'.format(identifier))
            cam_list = sorted(
                glob.glob(
                    os.path.join(ba_prefix + '-{}*.tsai'.format(identifier))))
            ba_args = img_list + cam_list
            # round 2 fixes every camera that was NOT fixed in round 1
            fixed_cam_idx2 = np.delete(np.arange(len(img_list), dtype=int),
                                       fix_cam_idx)
            round2_opts = get_ba_opts(ba_prefix,
                                      overlap_list=overlap_list,
                                      session=session,
                                      fixed_cam_idx=fixed_cam_idx2,
                                      camera_weight=cam_wt)
        else:
            # round 1 is adjust file
            # Only camera model parameters for the first three stereo pairs float in this round
            # This branch needs args.dem and args.bound, which define the
            # elevation and lon/lat limits used below.
            # NOTE(review): the original passed `overlap_limit` positionally
            # here, a name that is only assigned in the 'full_video' branch;
            # the overlap list is passed instead, matching the nadirpinhole
            # branch - confirm against get_ba_opts.
            round2_opts = get_ba_opts(
                ba_prefix,
                overlap_list=overlap_list,
                input_adjustments=ba_prefix,
                flavor='2round_gcp_2',
                session=session,
                elevation_limit=[min_elev, max_elev],
                lon_lat_limit=[lon_min, lat_min, lon_max, lat_max])
            ba_args = img_list + gcp_list

        print(
            "running round 2 bundle adjustment for given triplet stereo combination"
        )
        run_cmd('bundle_adjust', round2_opts + ba_args)

        # Save state for final condition reprojection errors for the sparse triangulated points
        final_residual_fn_def = sorted(
            glob.glob(ba_prefix + '*final*no_loss_*pointmap*.csv'))[0]
        final_residual_fn = os.path.splitext(
            final_residual_fn_def)[0] + '_final_reproj_error.csv'
        shutil.copy2(final_residual_fn_def, final_residual_fn)
        final_per_cam_reproj_err = sorted(
            glob.glob(ba_prefix +
                      '-*final_residuals_no_loss_function_raw_pixels.txt'))[0]
        final_per_cam_reproj_err_disk = os.path.splitext(
            final_per_cam_reproj_err)[0] + '_final_per_cam_reproj_error.txt'
        shutil.copy2(final_per_cam_reproj_err, final_per_cam_reproj_err_disk)

    elif mode == 'transform_pc_align':
        # input is just a transform from pc_align or something similar with no optimization
        # (this branch used to be nested inside 'full_triplet' and could never run)
        if session == 'nadirpinhole':
            ba_args = img_list + cam_list + gcp_list
            ba_opt = get_ba_opts(ba_prefix,
                                 overlap_list,
                                 flavor='2round_gcp_2',
                                 session=session,
                                 gcp_transform=True)
        else:
            ba_args = img_list + gcp_list
            ba_opt = get_ba_opts(ba_prefix,
                                 overlap_list,
                                 initial_transform=initial_transform,
                                 flavor='2round_gcp_2',
                                 session=session,
                                 gcp_transform=True)
        print("Simply transforming the cameras without optimization")
        run_cmd('bundle_adjust', ba_opt + ba_args, 'Running bundle adjust')

    elif mode == 'general_ba':
        # general usecase bundle adjust
        round1_opts = get_ba_opts(ba_prefix,
                                  overlap_limit=args.overlap_limit,
                                  flavor='2round_gcp_1',
                                  session=session)
        print("Running general purpose bundle adjustment")
        if session == 'nadirpinhole':
            ba_args = img_list + cam_list
        else:
            ba_args = img_list
        # Check if this command executed till last
        run_cmd('bundle_adjust', round1_opts + ba_args,
                'Running bundle adjust')

    print("Script is complete !")
Example #6
0
def main():
    """Run the bundle adjustment workflow selected by ``args.mode``.

    Command-line arguments are parsed with ``getparser()``. Image, camera and
    GCP lists are assembled from the supplied paths and one of the following
    workflows is executed through ``run_cmd('bundle_adjust', ...)``:

    * ``full_video``: round 2 for a video sequence described by a frame index
      CSV (the round 1 invocation is currently commented out).
    * ``full_triplet``: two rounds for triplet stereo collections.
    * ``transform_pc_align``: a single run that only transforms the cameras.
    * ``general_ba``: a single general purpose run.

    Any other mode only prints the completion message.
    """
    parser = getparser()
    args = parser.parse_args()
    img = args.img
    # populate image list; fall back to the '.tiff' extension when fewer than
    # two '.tif' files are found
    img_list = sorted(glob.glob(os.path.join(img, '*.tif')))
    if len(img_list) < 2:
        img_list = sorted(glob.glob(os.path.join(img, '*.tiff')))
        if os.path.islink(img_list[0]):
            img_list = [os.readlink(x) for x in img_list]
    # populate camera model list
    if args.cam:
        cam = os.path.abspath(args.cam)
        if 'run' in os.path.basename(cam):
            # cam is treated as a file prefix rather than a directory
            cam_list = sorted(glob.glob(cam + '-*.tsai'))
        else:
            cam_list = sorted(glob.glob(os.path.join(cam, '*.tsai')))
        cam_list = cam_list[:len(img_list)]
    session = args.t
    ba_prefix = os.path.abspath(args.ba_prefix)

    # Optional inputs get explicit defaults so that later branches can
    # reference them even when the corresponding option was not supplied.
    initial_transform = None
    if args.initial_transform:
        initial_transform = os.path.abspath(args.initial_transform)
    # NOTE(review): input_adjustments is parsed but not forwarded to any
    # get_ba_opts call below - confirm whether it should be.
    input_adjustments = None
    if args.input_adjustments:
        input_adjustments = os.path.abspath(args.input_adjustments)
    overlap_list = None
    if args.overlap_list:
        overlap_list = os.path.abspath(args.overlap_list)
    gcp_list = []
    if args.gcp:
        gcp_list = sorted(glob.glob(os.path.join(args.gcp, '*.gcp')))

    mode = args.mode
    if args.bound:
        bound = gpd.read_file(args.bound)
        geo_crs = {'init': 'epsg:4326'}
        # compare by value: an identity test against a fresh dict is always true
        if bound.crs != geo_crs:
            bound = bound.to_crs(geo_crs)
        lon_min, lat_min, lon_max, lat_max = bound.total_bounds
    if args.dem:
        dem = iolib.fn_getma(args.dem)
        dem_stats = malib.get_stats_dict(dem)
        # pad the DEM elevation range by 500 on both sides
        min_elev, max_elev = [dem_stats['min'] - 500, dem_stats['max'] + 500]

    if mode == 'full_video':
        frame_index = args.frame_index
        df = pd.read_csv(frame_index)
        gcp = os.path.abspath(args.gcp)
        df['dt'] = [
            datetime.strptime(date.split('+00:00')[0], '%Y-%m-%dT%H:%M:%S.%f')
            for date in df.datetime.values
        ]
        delta = (df.dt.values[1] - df.dt.values[0]) / np.timedelta64(1, 's')
        # overlap limit is hardcoded to give 40 seconds of coverage
        # (builtin int replaces the np.int alias, which newer NumPy lacks)
        overlap_limit = int(np.ceil(40 / delta))
        print(f"Calculated overlap limit as {overlap_limit}")
        img_list = [
            glob.glob(os.path.join(img, f'*{x}*.tiff'))[0]
            for x in df.name.values
        ]
        cam_list = [
            glob.glob(os.path.join(cam, f'*{x}*.tsai'))[0]
            for x in df.name.values
        ]
        gcp_list = [
            glob.glob(os.path.join(gcp, f'*{x}*.gcp'))[0]
            for x in df.name.values
        ]
        #also append the clean gcp here
        print(os.path.join(gcp, '*clean*_gcp.gcp'))
        gcp_list.append(glob.glob(os.path.join(gcp, '*clean*_gcp.gcp'))[0])
        round1_opts = get_ba_opts(ba_prefix,
                                  overlap_limit=overlap_limit,
                                  flavor='2round_gcp_1',
                                  session=session,
                                  num_iterations=args.num_iter)
        print("Running round 1 bundle adjustment for input video sequence")
        if session == 'nadirpinhole':
            ba_args = img_list + cam_list
        else:
            ba_args = img_list
        # Check if this command executed till last
        print('Running bundle adjustment round1')
        # NOTE(review): the round 1 invocation below is commented out, so
        # round 2 presumably relies on outputs already present under
        # ba_prefix - confirm this is intended.
        #run_cmd('bundle_adjust', round1_opts+ba_args)
        if session == 'nadirpinhole':
            identifier = os.path.basename(cam_list[0]).split(
                df.name.values[0])[0]
            print(ba_prefix + identifier + f'-{df.name.values[0]}*.tsai')
            # pick up the camera models stored under ba_prefix
            cam_list = [
                glob.glob(ba_prefix + identifier + f'-{img}*.tsai')[0]
                for img in df.name.values
            ]
            print(len(cam_list))
            ba_args = img_list + cam_list + gcp_list
            round2_opts = get_ba_opts(ba_prefix,
                                      overlap_limit=overlap_limit,
                                      flavor='2round_gcp_2',
                                      session=session,
                                      gcp_transform=True)
        else:
            # round 1 is adjust file
            round2_opts = get_ba_opts(ba_prefix,
                                      overlap_limit=overlap_limit,
                                      input_adjustments=ba_prefix,
                                      flavor='2round_gcp_2',
                                      session=session)
            ba_args = img_list + gcp_list
        print("running round 2 bundle adjustment for input video sequence")
        run_cmd('bundle_adjust', round2_opts + ba_args)

    elif mode == 'full_triplet':
        if args.overlap_list is None:
            print(
                "Attempted bundle adjust will be expensive, will try to find matches in each and every pair"
            )
            round1_opts = get_ba_opts(ba_prefix,
                                      flavor='2round_gcp_1',
                                      session=session,
                                      num_iterations=args.num_iter)
            # enter round2_opts here only ?
        else:
            round1_opts = get_ba_opts(ba_prefix,
                                      overlap_list=overlap_list,
                                      flavor='2round_gcp_1',
                                      session=session,
                                      num_iterations=args.num_iter)
        if session == 'nadirpinhole':
            ba_args = img_list + cam_list
        else:
            ba_args = img_list
        print(
            "Running round 1 bundle adjustment for given triplet stereo combination"
        )
        run_cmd('bundle_adjust', round1_opts + ba_args)
        if session == 'nadirpinhole':
            identifier = os.path.basename(cam_list[0]).split(
                os.path.splitext(os.path.basename(img_list[0]))[0], 2)[0]
            print(ba_prefix + f'-{identifier}*.tsai')
            # sorted so the camera order is deterministic, like the sorted
            # img_list (glob returns matches in arbitrary order)
            cam_list = sorted(
                glob.glob(os.path.join(ba_prefix + f'-{identifier}*.tsai')))
            ba_args = img_list + cam_list + gcp_list
            round2_opts = get_ba_opts(ba_prefix,
                                      overlap_list=overlap_list,
                                      flavor='2round_gcp_2',
                                      session=session,
                                      gcp_transform=True)
        else:
            # round 1 is adjust file
            # This branch needs args.dem and args.bound, which define the
            # elevation and lon/lat limits used below.
            # NOTE(review): the original passed `overlap_limit` positionally
            # here, a name that is only assigned in the 'full_video' branch;
            # the overlap list is passed instead, matching the nadirpinhole
            # branch - confirm against get_ba_opts.
            round2_opts = get_ba_opts(
                ba_prefix,
                overlap_list=overlap_list,
                input_adjustments=ba_prefix,
                flavor='2round_gcp_2',
                session=session,
                elevation_limit=[min_elev, max_elev],
                lon_lat_limit=[lon_min, lat_min, lon_max, lat_max])
            ba_args = img_list + gcp_list
        print(
            "running round 2 bundle adjustment for given triplet stereo combination"
        )
        run_cmd('bundle_adjust', round2_opts + ba_args)

    elif mode == 'transform_pc_align':
        # input is just a transform from pc_align or something similar with no optimization
        # (this branch used to be nested inside 'full_triplet' and could never run)
        if session == 'nadirpinhole':
            ba_args = img_list + cam_list + gcp_list
            ba_opt = get_ba_opts(ba_prefix,
                                 overlap_list,
                                 flavor='2round_gcp_2',
                                 session=session,
                                 gcp_transform=True)
        else:
            ba_args = img_list + gcp_list
            ba_opt = get_ba_opts(ba_prefix,
                                 overlap_list,
                                 initial_transform=initial_transform,
                                 flavor='2round_gcp_2',
                                 session=session,
                                 gcp_transform=True)
        print("Simply transforming the cameras without optimization")
        run_cmd('bundle_adjust', ba_opt + ba_args, 'Running bundle adjust')

    elif mode == 'general_ba':
        # general usecase bundle adjust
        round1_opts = get_ba_opts(ba_prefix,
                                  overlap_limit=args.overlap_limit,
                                  flavor='2round_gcp_1',
                                  session=session)
        print("Running general purpose bundle adjustment")
        if session == 'nadirpinhole':
            ba_args = img_list + cam_list
        else:
            ba_args = img_list
        # Check if this command executed till last
        run_cmd('bundle_adjust', round1_opts + ba_args,
                'Running bundle adjust')

    print("Script is complete !")