Ejemplo n.º 1
0
def outupt_video2images(vidpath,
                        img_path,
                        frame_interval=None,
                        if_auto_transpose=True):
    '''
    Output the video frames into images.

    Each selected frame is written to ``img_path`` as ``%06d.png``, where the
    number is the frame's index in the video.

    Inputs:
    vidpath  - the path to the input video
    img_path - the path to the output image folder (created if needed)
    frame_interval (optional) - save only frames whose index is a multiple of
                                this value; None saves every frame
    if_auto_transpose (optional) - when exactly True, frames with more rows
                                   than columns get their first two axes
                                   swapped before saving

    Returns:
    1 (kept for compatibility with existing callers)
    '''
    misc.m_makedir(img_path)
    Frame_array = readVideo(vidpath)
    for idx_frame, frame in enumerate(Frame_array):
        # A single save path serves both the "every frame" and the
        # "every N-th frame" modes.
        if frame_interval is not None and idx_frame % frame_interval != 0:
            continue

        print('saving frame %d' % (idx_frame))
        frame = np.asarray(frame)
        if if_auto_transpose is True and frame.shape[0] > frame.shape[1]:
            # For HxWxC frames this equals frame.transpose([1, 0, 2]); unlike
            # that call it also works for single-channel HxW frames.
            frame = np.swapaxes(frame, 0, 1)
        plt.imsave('%s/%06d.png' % (img_path, idx_frame), arr=frame)

    return 1
Ejemplo n.º 2
0
def export_res_img(ref_dat, BV_measure, d_candi, resfldr, batch_idx,
                   depth_scale=1000, conf_scale=1000):
    '''
    Export the reference image, regressed depth map and confidence map.

    Files written into ``resfldr`` (created if needed):
      img_%05d.png  - the un-normalized reference image
      d_%05d.pgm    - depth map * depth_scale, clipped to the uint16 range
      conf_%05d.pgm - confidence map * conf_scale, clipped to the uint16 range

    Inputs:
    ref_dat     - dict of the reference frame; ref_dat['img'] is its image
                  tensor (channel-first after squeezing)
    BV_measure  - volume indexed as [batch, depth, H, W]; the confidence map
                  is exp(max over dim 1), so it presumably holds
                  log-probabilities — TODO confirm
    d_candi     - candidate depth values, one per depth slice
    resfldr     - output folder
    batch_idx   - index used in the output file names
    depth_scale, conf_scale - multipliers applied before the uint16 cast
    '''

    # depth map #
    nDepth = len(d_candi)
    dmap_height, dmap_width = BV_measure.shape[2], BV_measure.shape[3]
    # Volume whose slice ``idepth`` is constant and equal to d_candi[idepth].
    Depth_val_vol = torch.ones(1, nDepth, dmap_height, dmap_width).cuda()
    for idepth in range(nDepth):
        Depth_val_vol[0, idepth, ...] = Depth_val_vol[0, idepth, ...] * d_candi[idepth]
    dmap_th = depth_regression(Depth_val_vol, BV_measure)
    dmap = torch.FloatTensor(dmap_th).cpu().numpy()

    # confMap: exp of the per-pixel maximum over the depth dimension.
    # (Previously this went CPU -> GPU -> CPU for no effect; a float32 cast
    # on the CPU yields the same array.)
    confMap_log, _ = torch.max(BV_measure, dim=1)
    confmap = torch.exp(confMap_log.squeeze().cpu()).numpy().astype(np.float32)

    img = ref_dat['img']
    img = img.squeeze().cpu().permute(1, 2, 0).numpy()
    img_in_png = (_un_normalize(img) * 255).astype(np.uint8)

    # write to path #
    m_misc.m_makedir(resfldr)
    img_path = '%s/img_%05d.png' % (resfldr, batch_idx)
    d_path = '%s/d_%05d.pgm' % (resfldr, batch_idx)
    conf_path = '%s/conf_%05d.pgm' % (resfldr, batch_idx)

    # Saturate at the uint16 limits before casting: casting out-of-range
    # floats to uint16 does not saturate (the result is platform-dependent,
    # typically wrapping, e.g. 70 m * 1000 -> 4464).
    uint16_max = np.iinfo(np.uint16).max
    plt.imsave(img_path, img_in_png)
    imgIO.export2pgm(d_path,
                     np.clip(dmap * depth_scale, 0, uint16_max).astype(np.uint16))
    imgIO.export2pgm(conf_path,
                     np.clip(confmap * conf_scale, 0, uint16_max).astype(np.uint16))
Ejemplo n.º 3
0
def main():
    '''
    Run DSO on an image sequence and save the estimated camera poses.

    Images matching --name_pattern inside --data_fldr are copied (in sorted
    order, at most --maxframe of them) into a temporary folder, because DSO
    treats every image in its input folder as a frame. DSO writes its result
    to <res_path>/result_dso.txt. The temporary folder is removed afterwards,
    also when copying or DSO fails.
    '''

    import argparse
    print('Parsing arguments ...')
    parser = argparse.ArgumentParser()

    parser.add_argument('--dso_path', required =True, type=str, help='The path to DSO ')
    parser.add_argument('--data_fldr', required =True, type=str, help='The path to the image folder, where .png or .jpg image files are saved')
    parser.add_argument('--cam_info_file', required =True, type=str, help='The path to the .mat file saving the camera info')
    parser.add_argument('--name_pattern', required=True, type=str, help='The name pattern for the image. e.g. *.color.* ')

    # The temp image folder is needed since DSO assumes all images files in one
    # folder are input images. But some datasets include images of multiple
    # types(rgb, depth etc) in one folder. So we will naively copy the input
    # images into the tmp_img_fldr
    parser.add_argument('--temp_img_fldr', required = False, type=str,
                        default = './dso_imgs', help='The path to the temporary image folder')
    parser.add_argument( '--res_path', required = False, type=str,
                        default = './dso_res', help='The path to the DSO output (a txt file containing the camera poses)' )
    parser.add_argument('--minframe', required=False, type=int, default = 0, help='starting frame idx')
    parser.add_argument('--maxframe', required=False, type=int, default = 100, help='ending frame idx')

    args = parser.parse_args()

    minframe = args.minframe
    maxframe = args.maxframe

    res_fldr = args.res_path
    cam_info = sio.loadmat(args.cam_info_file)
    intrinsic_info = {'IntM': cam_info['IntM']}

    files = sorted(glob.glob('%s/%s' % (args.data_fldr, args.name_pattern)))
    if not files:
        # Fail early with a clear message instead of an IndexError on files[0]
        # after the output folders have already been created.
        parser.error('no image matches "%s" in %s' % (args.name_pattern, args.data_fldr))

    # create temp folder #
    print('copying the images..')
    img_fldr_path = str(args.temp_img_fldr)
    m_misc.m_makedir(img_fldr_path)
    m_misc.m_makedir(res_fldr)

    try:
        # Copy at most ``maxframe`` images (none if maxframe <= 0).
        for f in files[:max(maxframe, 0)]:
            shutil.copy(f, img_fldr_path)  # copy image files into the temp folder

        # The image size is taken from the first frame; the context manager
        # closes the file handle.
        with Image.open(files[0]) as im:
            intrinsic_info['img_size'] = im.size

        # run DSO #
        Rts_cam_to_world = dso_io.run_DSO(img_fldr_path = img_fldr_path,
                                          dso_bin_path = args.dso_path,
                                          intrinsic_info = intrinsic_info,
                                          result_path = '%s/result_dso.txt'%( res_fldr ) ,
                                          vig_img_path = './DSO/vignette.png',
                                          min_frame= minframe, max_frame= maxframe,
                                          mode = 1, preset = 2, nogui = 1, use_existing = False )
    finally:
        # delete temp fldr, even if copying or DSO failed #
        shutil.rmtree(img_fldr_path)

    print('\n\n# of frames with pose estimated: %d'%(len(Rts_cam_to_world)))
    print('# of frames in the video: %d'%(len(files)))
Ejemplo n.º 4
0
def export_res_img(ref_dat,
                   BV_measure,
                   d_candi,
                   resfldr,
                   batch_idx,
                   depth_scale=1000,
                   conf_scale=1000):
    '''
    Export the prediction for one reference frame and return it median-scaled.

    Files written into ``resfldr`` (created if needed):
      img_%05d.png  - the un-normalized reference image
      d_%05d.pgm    - raw regressed depth * depth_scale, clipped to uint16
      conf_%05d.pgm - confidence map * conf_scale, clipped to uint16

    After exporting, the depth map is rescaled so that its median over valid
    ground-truth pixels (MIN_DEPTH < gt < MAX_DEPTH) equals the ground-truth
    median, then clamped to [MIN_DEPTH, MAX_DEPTH]. If no ground-truth pixel
    is valid, the scaling step is skipped with a warning.

    Inputs:
    ref_dat    - dict of the reference frame; uses 'img' and 'dmap_imgsize'
    BV_measure - volume indexed as [batch, depth, H, W]
    d_candi    - candidate depth values, one per depth slice
    resfldr    - output folder
    batch_idx  - index used in the output file names
    depth_scale, conf_scale - multipliers applied before the uint16 cast

    Returns:
    (torch.tensor(dmap).unsqueeze(0), ref_dat['dmap_imgsize'])
    '''
    import warnings

    # depth map #
    nDepth = len(d_candi)
    dmap_height, dmap_width = BV_measure.shape[2], BV_measure.shape[3]
    # Volume whose slice ``idepth`` is constant and equal to d_candi[idepth].
    Depth_val_vol = torch.ones(1, nDepth, dmap_height, dmap_width).cuda()
    for idepth in range(nDepth):
        Depth_val_vol[0, idepth, ...] = Depth_val_vol[0, idepth, ...] * d_candi[idepth]
    dmap_th = depth_regression(Depth_val_vol, BV_measure)
    dmap = torch.FloatTensor(dmap_th).cpu().numpy()  ## pred_depth

    # confMap: exp of the per-pixel maximum over the depth dimension.
    # A float32 cast on the CPU replaces the former CPU -> GPU -> CPU trip.
    confMap_log, _ = torch.max(BV_measure, dim=1)
    confmap = torch.exp(confMap_log.squeeze().cpu()).numpy().astype(np.float32)

    img = ref_dat['img']
    img = img.squeeze().cpu().permute(1, 2, 0).numpy()
    img_in_png = (_un_normalize(img) * 255).astype(np.uint8)

    # write to path #
    m_misc.m_makedir(resfldr)
    img_path = '%s/img_%05d.png' % (resfldr, batch_idx)
    d_path = '%s/d_%05d.pgm' % (resfldr, batch_idx)
    conf_path = '%s/conf_%05d.pgm' % (resfldr, batch_idx)

    # Saturate at the uint16 limits: casting out-of-range floats to uint16
    # does not saturate (platform-dependent, typically wrapping).
    uint16_max = np.iinfo(np.uint16).max
    plt.imsave(img_path, img_in_png)
    imgIO.export2pgm(d_path,
                     np.clip(dmap * depth_scale, 0, uint16_max).astype(np.uint16))
    imgIO.export2pgm(conf_path,
                     np.clip(confmap * conf_scale, 0, uint16_max).astype(np.uint16))

    gt = ref_dat['dmap_imgsize']
    # NOTE(review): assumes dmap and gt have the same shape — TODO confirm.
    mask = np.logical_and(gt > MIN_DEPTH, gt < MAX_DEPTH)

    # Median scaling. With an empty mask both medians are NaN, which would turn
    # the whole prediction into NaN; keep the unscaled prediction instead.
    if mask.any():
        ratio = np.median(gt[mask]) / np.median(dmap[mask])
        dmap *= ratio
    else:
        warnings.warn('export_res_img: no valid ground-truth depth for frame '
                      '%d; median scaling skipped' % batch_idx)

    # Clamp in place to the evaluation depth range.
    np.clip(dmap, MIN_DEPTH, MAX_DEPTH, out=dmap)

    return torch.tensor(dmap).unsqueeze(0), ref_dat['dmap_imgsize']
Ejemplo n.º 5
0
def main():
    '''
    Evaluate a pre-trained KV-Net on every test trajectory of a dataset.

    For each trajectory a temporal window of 2 * t_win + 1 frames slides along
    the sequence. For every window whose poses are valid, the network is run on
    the central (reference) frame and its output is exported by
    export_res.export_res_img into ../results/<exp_name>/traj_<traj_idx>/.
    The list of (output index, source image path) pairs is saved there as
    scene_path_info.txt. The previous window's prediction is fed to the next
    window; it is reset at the start of each trajectory and after any window
    with invalid poses.
    '''
    import argparse
    print('Parsing the arguments...')
    parser = argparse.ArgumentParser()

    # exp name #
    parser.add_argument(
        '--exp_name',
        required=True,
        type=str,
        help='The name of the experiment. Used to naming the folders')

    # about testing #
    parser.add_argument('--model_path',
                        type=str,
                        required=True,
                        help='The pre-trained model path for KV-net')
    parser.add_argument('--split_file',
                        type=str,
                        required=True,
                        help='The split txt file')
    parser.add_argument('--frame_interv',
                        default=5,
                        type=int,
                        help='frame interval')
    parser.add_argument('--t_win',
                        type=int,
                        default=2,
                        help='The radius of the temporal window; default=2')
    parser.add_argument('--d_min',
                        type=float,
                        default=0,
                        help='The minimal depth value; default=0')
    parser.add_argument('--d_max',
                        type=float,
                        default=5,
                        help='The maximal depth value; default=5')
    parser.add_argument('--ndepth',
                        type=int,
                        default=64,
                        help='The # of candidate depth values; default=64')
    parser.add_argument('--sigma_soft_max',
                        type=float,
                        default=10.,
                        help='sigma_soft_max, default = 10.')
    parser.add_argument(
        '--feature_dim',
        type=int,
        default=64,
        help='The feature dimension for the feature extractor; default=64')

    # about dataset #
    parser.add_argument('--dataset',
                        type=str,
                        default='scanNet',
                        help='Dataset name: {scanNet, 7scenes, kitti}')
    parser.add_argument('--dataset_path',
                        type=str,
                        default='.',
                        help='Path to the dataset')
    parser.add_argument(
        '--change_aspect_ratio',
        action='store_true',
        default=False,
        help=
        'If we want to change the aspect ratio. This option is only useful for KITTI'
    )

    # parsing parameters #
    args = parser.parse_args()
    exp_name = args.exp_name
    dataset_name = args.dataset
    t_win_r = args.t_win
    nDepth = args.ndepth
    d_candi = np.linspace(args.d_min, args.d_max, nDepth)
    sigma_soft_max = args.sigma_soft_max
    dnet_feature_dim = args.feature_dim
    frame_interv = args.frame_interv  # should be multiple of 5 for scanNet dataset
    d_upsample = None
    d_candi_dmap_ref = d_candi
    # Split file used by the scanNet and KITTI loaders. It used to be
    # referenced there without being defined, raising NameError.
    split_file = args.split_file

    # ===== Dataset selection ======== #
    dataset_path = args.dataset_path
    crop_w = None  # only passed to the KITTI loader
    if dataset_name == 'scanNet':
        import mdataloader.scanNet as dl_scanNet
        dataset_init = dl_scanNet.ScanNet_dataset
        # --frame_interv was previously ignored (hard-coded to 5, its default).
        fun_get_paths = lambda traj_indx: dl_scanNet.get_paths(
            traj_indx,
            frame_interv=frame_interv,
            split_txt=split_file,
            database_path_base=dataset_path)
        img_size = [384, 256]

    elif dataset_name == '7scenes':
        # 7 scenes video #
        import mdataloader.dl_7scenes as dl_7scenes
        dataset_init = dl_7scenes.SevenScenesDataset
        dat_indx_step = 3

        split_file = None if args.split_file == '.' else args.split_file
        fun_get_paths = lambda traj_indx: dl_7scenes.get_paths_1frame(
            traj_indx,
            database_path_base=dataset_path,
            split_txt=split_file,
            dat_indx_step=dat_indx_step)
        img_size = [384, 256]

    elif dataset_name == 'kitti':
        import mdataloader.kitti as dl_kitti
        dataset_init = dl_kitti.KITTI_dataset
        if dataset_path != '.':
            fun_get_paths = lambda traj_indx: dl_kitti.get_paths(
                traj_indx,
                split_txt=split_file,
                mode='val',
                database_path_base=dataset_path)
        else:  # use default database path
            fun_get_paths = lambda traj_indx: dl_kitti.get_paths(
                traj_indx, split_txt=split_file, mode='val')
        # Keep the aspect ratio (wider input) unless asked to change it.
        # No cropping is done in either case.
        img_size = [384, 256] if args.change_aspect_ratio else [768, 256]

    else:
        raise Exception('dataset loader not implemented')

    # trajectory index for testing #
    n_scenes, _, _, _, _ = fun_get_paths(0)
    traj_Indx = np.arange(0, n_scenes)

    fldr_path, img_paths, dmap_paths, poses, intrin_path = fun_get_paths(0)
    # Only the KITTI loader takes crop_w, and only for the 1/4-res dataset.
    extra_kwargs = {'crop_w': crop_w} if dataset_name == 'kitti' else {}
    dataset = dataset_init(True,
                           img_paths,
                           dmap_paths,
                           poses,
                           intrin_path=intrin_path,
                           img_size=img_size,
                           digitize=True,
                           d_candi=d_candi_dmap_ref,
                           resize_dmap=.25,
                           **extra_kwargs)

    dataset_imgsize = dataset_init(True,
                                   img_paths,
                                   dmap_paths,
                                   poses,
                                   intrin_path=intrin_path,
                                   img_size=img_size,
                                   digitize=True,
                                   d_candi=d_candi_dmap_ref,
                                   resize_dmap=1)
    # ================================ #

    print('Initnializing the KV-Net')
    model_KVnet = m_kvnet.KVNET(feature_dim=dnet_feature_dim,
                                cam_intrinsics=dataset.cam_intrinsics,
                                d_candi=d_candi,
                                sigma_soft_max=sigma_soft_max,
                                KVNet_feature_dim=dnet_feature_dim,
                                d_upsample_ratio_KV_net=d_upsample,
                                t_win_r=t_win_r,
                                if_refined=True)

    model_KVnet = torch.nn.DataParallel(model_KVnet)
    model_KVnet.cuda()

    model_path_KV = args.model_path
    print('loading KV_net at %s' % (model_path_KV))
    utils_model.load_pretrained_model(model_KVnet, model_path_KV)
    print('Done')

    for traj_idx in traj_Indx:
        res_fldr = '../results/%s/traj_%d' % (exp_name, traj_idx)
        m_misc.m_makedir(res_fldr)
        scene_path_info = []

        print('Getting the paths for traj_%d' % (traj_idx))
        fldr_path, img_seq_paths, dmap_seq_paths, poses, intrin_path = fun_get_paths(
            traj_idx)
        dataset.set_paths(img_seq_paths, dmap_seq_paths, poses)

        # String equality, not identity: `is` is False for a CLI-supplied
        # 'scanNet', which used to skip this update silently.
        if dataset_name == 'scanNet':
            # For each trajectory in the dataset, we will update the intrinsic matrix #
            dataset.get_cam_intrinsics(intrin_path)

        print('Done')
        dat_array = [dataset[idx] for idx in range(t_win_r * 2 + 1)]
        traj_length = len(dataset)
        print('trajectory length = %d' % (traj_length))

        # The stop bound keeps ref_indx + t_win_r + 1 < traj_length, so the
        # prefetch at the end of each iteration stays in range.
        for frame_cnt, ref_indx in enumerate(
                range(t_win_r, traj_length - t_win_r - 1)):
            valid_seq = check_datArray_pose(dat_array)

            # Read ref. and src. data in the local time window #
            ref_dat, src_dats = m_misc.split_frame_list(dat_array, t_win_r)

            if frame_cnt == 0:  # the first window of the trajectory
                BVs_predict = None

            if valid_seq:
                # Get poses relative to the reference frame #
                src_cam_extMs = m_misc.get_entries_list_dict(src_dats, 'extM')
                src_cam_poses = [
                    warp_homo.get_rel_extrinsicM(ref_dat['extM'], src_cam_extM_)
                    for src_cam_extM_ in src_cam_extMs
                ]
                src_cam_poses = [
                    torch.from_numpy(pose.astype(np.float32)).cuda().unsqueeze(0)
                    for pose in src_cam_poses
                ]

                # src_cam_poses size: N V 4 4 #
                src_cam_poses = torch.cat(src_cam_poses, dim=0).unsqueeze(0)

                # None for the first window or after an invalid one.
                BVs_predict_in = BVs_predict

                print('testing on %d/%d frame in traj %d/%d ... ' %
                      (frame_cnt + 1, traj_length - 2 * t_win_r, traj_idx + 1,
                       len(traj_Indx)))

                BVs_measure, BVs_predict = test_KVNet.test(
                    model_KVnet,
                    d_candi,
                    Ref_Dats=[ref_dat],
                    Src_Dats=[src_dats],
                    Cam_Intrinsics=[dataset.cam_intrinsics],
                    t_win_r=t_win_r,
                    Src_CamPoses=src_cam_poses,
                    BV_predict=BVs_predict_in,
                    R_net=True,
                    Cam_Intrinsics_imgsize=dataset_imgsize.cam_intrinsics,
                    ref_indx=ref_indx)

                export_res.export_res_img(ref_dat, BVs_measure,
                                          d_candi_dmap_ref, res_fldr,
                                          frame_cnt)
                scene_path_info.append(
                    [frame_cnt, dataset[ref_indx]['img_path']])

            else:  # the window contains an invalid pose estimation
                BVs_predict = None
                print('frame_cnt :%d, include invalid poses' % (frame_cnt))

            # Slide the window by one frame #
            dat_array.pop(0)
            dat_array.append(dataset[ref_indx + t_win_r + 1])

        m_misc.save_ScenePathInfo('%s/scene_path_info.txt' % (res_fldr),
                                  scene_path_info)
Ejemplo n.º 6
0
def export_res_refineNet(ref_dat, BV_measure, d_candi, res_fldr, batch_idx, diff_vrange_ratio=4,
        cam_pose=None, cam_intrinM=None, output_pngs=False, save_mat=True, output_dmap_ref=True):
    '''
    Export visualizations and optional .mat / .png files for one frame.

    Written into ``res_fldr`` (created if needed):
      input.png, conf.png, dmap_raw.png - input image (min-max normalized),
                                          confidence map (jet, [0, 1]) and
                                          predicted depth (gray, [0, d_candi.max()])
      dmap_ref.png, dmaps_diff.png      - only if output_dmap_ref: GT depth and
                                          |GT - prediction| on pixels with GT > 0
      res_%05d.png                      - the above images combined by cat_imgs
      depth_%05d.mat                    - if save_mat: dmap, img, confMap and
                                          img_path; with output_dmap_ref also
                                          dmap_ref, plus cam_pose and cam_intrinM
                                          when cam_pose is given
      output_pngs/*.png                 - if output_pngs: depth * 1000 (uint16),
                                          RGB image, confidence * 255 (uint8) and,
                                          with output_dmap_ref, GT depth * 1000

    Inputs:
    ref_dat  - dict of the reference frame; uses 'img', 'img_path' and, with
               output_dmap_ref, 'dmap_imgsize'
    BV_measure - volume indexed as [batch, depth, H, W]; treated as
               log-valued (exp of its max is the confidence) — TODO confirm
    d_candi  - numpy array of candidate depth values
    diff_vrange_ratio - the diff image's color range is [0, d_candi.max() / ratio]
    '''

    # img_in: min-max normalize to [0, 255] for display. A constant image
    # would divide by zero (NaN), so it is exported as black instead.
    img_up = ref_dat['img']
    img_in_raw = img_up.squeeze().cpu().permute(1, 2, 0).numpy()
    img_min, img_max = img_in_raw.min(), img_in_raw.max()
    if img_max > img_min:
        img_in = (img_in_raw - img_min) / (img_max - img_min) * 255.
    else:
        img_in = np.zeros_like(img_in_raw)

    # confMap #
    confMap_log, _ = torch.max(BV_measure, dim=1)
    confMap_log = torch.exp(confMap_log.squeeze().cpu())
    confMap_log = confMap_log.cpu().numpy()

    # depth map #
    dmap = m_misc.depth_val_regression(BV_measure, d_candi, BV_log=True).squeeze().cpu().numpy()

    # save up-sampled results #
    resfldr = res_fldr
    m_misc.m_makedir(resfldr)

    img_up_path = '%s/input.png' % (resfldr,)
    conf_up_path = '%s/conf.png' % (resfldr,)
    dmap_raw_path = '%s/dmap_raw.png' % (resfldr,)
    final_res_up = '%s/res_%05d.png' % (resfldr, batch_idx)

    if output_dmap_ref:  # output GT depth
        ref_up = '%s/dmap_ref.png' % (resfldr,)
        res_up_diff = '%s/dmaps_diff.png' % (resfldr,)
        dmap_ref = ref_dat['dmap_imgsize']
        dmap_ref = dmap_ref.squeeze().cpu().numpy()
        # np.float was removed in NumPy 1.24; np.float64 is the same dtype.
        mask_dmap = (dmap_ref > 0).astype(np.float64)
        dmaps_diff = np.abs(dmap_ref - dmap) * mask_dmap
        plt.imsave(res_up_diff, dmaps_diff, vmin=0, vmax=d_candi.max() / diff_vrange_ratio)
        plt.imsave(ref_up, dmap_ref, vmax=d_candi.max(), vmin=0, cmap='gray')

    plt.imsave(conf_up_path, confMap_log, vmin=0, vmax=1, cmap='jet')
    plt.imsave(dmap_raw_path, dmap, vmin=0., vmax=d_candi.max(), cmap='gray')
    plt.imsave(img_up_path, img_in.astype(np.uint8))

    # output the depth as .mat files #
    fname_mat = '%s/depth_%05d.mat' % (resfldr, batch_idx)
    img_path = ref_dat['img_path']
    if save_mat:
        if not output_dmap_ref:
            mdict = {'dmap': dmap, 'img': img_in_raw, 'confMap': confMap_log, 'img_path': img_path}
        elif cam_pose is None:
            mdict = {'dmap_ref': dmap_ref, 'dmap': dmap, 'img': img_in_raw, 'confMap': confMap_log,
                     'img_path': img_path}
        else:
            mdict = {'dmap_ref': dmap_ref, 'dmap': dmap,
                     'img': img_in_raw, 'cam_pose': cam_pose,
                     'confMap': confMap_log, 'cam_intrinM': cam_intrinM,
                     'img_path': img_path}
        sio.savemat(fname_mat, mdict)

    print('export to %s' % (final_res_up))

    if output_dmap_ref:
        cat_imgs((img_up_path, conf_up_path, dmap_raw_path, res_up_diff, ref_up), final_res_up)
    else:
        cat_imgs((img_up_path, conf_up_path, dmap_raw_path), final_res_up)

    if output_pngs:
        import cv2
        png_fldr = '%s/output_pngs' % (res_fldr, )
        m_misc.m_makedir(png_fldr)
        # Saturate instead of letting out-of-range values wrap on the cast.
        uint16_max = np.iinfo(np.uint16).max
        depth_png = np.clip(dmap * 1000, 0, uint16_max).astype(np.uint16)
        img_in_png = (_un_normalize(img_in_raw) * 255).astype(np.uint8)
        confMap_png = np.clip(confMap_log * 255, 0, 255).astype(np.uint8)
        cv2.imwrite('%s/d_%05d.png' % (png_fldr, batch_idx), depth_png)
        # The image is RGB (it is written with plt.imsave above), but
        # cv2.imwrite expects BGR; convert so the channels are not swapped.
        cv2.imwrite('%s/rgb_%05d.png' % (png_fldr, batch_idx),
                    cv2.cvtColor(img_in_png, cv2.COLOR_RGB2BGR))
        cv2.imwrite('%s/conf_%05d.png' % (png_fldr, batch_idx), confMap_png)

        if output_dmap_ref:
            depth_ref_png = np.clip(dmap_ref * 1000, 0, uint16_max).astype(np.uint16)
            cv2.imwrite('%s/dref_%05d.png' % (png_fldr, batch_idx), depth_ref_png)
Ejemplo n.º 7
0
def main():
    import argparse
    print('Parsing the arguments...')
    parser = argparse.ArgumentParser()

    # exp name #
    parser.add_argument(
        '--exp_name',
        required=True,
        type=str,
        help='The name of the experiment. Used to naming the folders')

    # about testing #
    parser.add_argument('--img_name_pattern',
                        type=str,
                        default='*.png',
                        help='image name pattern')
    parser.add_argument('--model_path',
                        type=str,
                        default='.',
                        help='The pre-trained model path for KV-net')
    parser.add_argument('--split_file',
                        type=str,
                        default='.',
                        help='The split txt file')
    parser.add_argument('--t_win',
                        type=int,
                        default=2,
                        help='The radius of the temporal window; default=2')
    parser.add_argument('--d_min',
                        type=float,
                        default=0,
                        help='The minimal depth value; default=0')
    parser.add_argument('--d_max',
                        type=float,
                        default=5,
                        help='The maximal depth value; default=15')
    parser.add_argument('--ndepth',
                        type=int,
                        default=64,
                        help='The # of candidate depth values; default= 128')
    parser.add_argument('--sigma_soft_max',
                        type=float,
                        default=10.,
                        help='sigma_soft_max, default = 500.')
    parser.add_argument(
        '--feature_dim',
        type=int,
        default=64,
        help='The feature dimension for the feature extractor; default=64')

    # about pose #
    parser.add_argument('--intrin_path',
                        type=str,
                        required=True,
                        help='camera intrinic path, saved as .mat')

    parser.add_argument(
        '--dso_res_path',
        type=str,
        default='dso_res/result_dso.txt',
        help=
        'if use DSO pose, specify the path to the DSO results. Should be a .txt file'
    )
    parser.add_argument('--opt_next_frame', action='store_true', help='')
    parser.add_argument('--use_gt_R', action='store_true', help='')
    parser.add_argument('--use_gt_t', action='store_true', help='')
    parser.add_argument('--use_dso_R', action='store_true', help='')
    parser.add_argument('--use_dso_t', action='store_true', help='')
    parser.add_argument('--min_frame_idx', type=int, help=' ', default=0)
    parser.add_argument('--max_frame_idx', type=int, help=' ', default=10000)
    parser.add_argument('--refresh_frames', type=int, help=' ', default=1000)
    parser.add_argument('--LBA_max_iter', type=int, help=' ')
    parser.add_argument('--opt_r', type=int, default=1, help=' ')
    parser.add_argument('--opt_t', type=int, default=1, help=' ')
    parser.add_argument('--LBA_step', type=float, help=' ')
    parser.add_argument('--frame_interv', type=int, default=5, help=' ')

    # about dataset #
    parser.add_argument('--dataset',
                        type=str,
                        default='7scenes',
                        help='Dataset name: {scanNet, 7scenes}')
    parser.add_argument('--dataset_path',
                        type=str,
                        default='.',
                        help='Path to the dataset')

    # about output #
    parser.add_argument('--output_pngs',
                        action='store_true',
                        help='if output pngs')

    # para config. #
    args = parser.parse_args()
    exp_name = args.exp_name
    dataset_name = args.dataset
    t_win_r = args.t_win
    nDepth = args.ndepth

    d_candi = np.linspace(args.d_min, args.d_max, nDepth)

    sigma_soft_max = args.sigma_soft_max  #10.#500.
    dnet_feature_dim = args.feature_dim
    frame_interv = args.frame_interv
    d_candi_dmap_ref = d_candi
    nDepth_dmap_ref = nDepth

    # Initialize data-loader, model and optimizer #
    # ===== Dataset selection ======== #
    dataset_path = args.dataset_path
    if dataset_name == 'scanNet':
        #  deal with 1-frame scanNet data
        import mdataloader.scanNet as dl_scanNet
        dataset_init = dl_scanNet.ScanNet_dataset
        split_txt = './mdataloader/scanNet_split/scannet_val.txt' if args.split_file == '.' else args.split_file
        if not dataset_path == '.':
            # if specify the path, we will assume we are using 1-frame-interval scanNet video #
            fun_get_paths = lambda traj_indx: dl_scanNet.get_paths_1frame(
                traj_indx,
                database_path_base=dataset_path,
                split_txt=split_txt)
            dat_indx_step = 5  #pick this value to make sure the camera baseline is big enough
        else:
            fun_get_paths = lambda traj_indx: dl_scanNet.get_paths(
                traj_indx, frame_interv=5, split_txt=split_txt)
            dat_indx_step = 1
        img_size = [384, 256]
        # trajectory index for training #
        n_scenes, _, _, _, _ = fun_get_paths(0)
        traj_Indx = np.arange(0, n_scenes)

    elif dataset_name == '7scenes':
        # 7 scenes video #
        import mdataloader.dl_7scenes as dl_7scenes
        img_size = [384, 256]
        dataset_init = dl_7scenes.SevenScenesDataset
        dat_indx_step = 5  # pick this value to make sure the camera baseline is big enough
        # trajectory index for training #
        split_file = None if args.split_file == '.' else args.split_file
        fun_get_paths = lambda traj_indx: dl_7scenes.get_paths_1frame(
            traj_indx,
            database_path_base=dataset_path,
            split_txt=split_file,
        )

    elif dataset_name == 'single_folder':
        # images in a single folder specified by the user #
        import mdataloader.mdata as mdata
        img_size = [384, 256]
        dataset_init = mdata.mData
        dat_indx_step = 5  # pick this value to make sure the camera baseline is big enough
        fun_get_paths = lambda traj_indx: mdata.get_paths_1frame(
            traj_indx, dataset_path, args.img_name_pattern)
        traj_Indx = [0]  #dummy

    fldr_path, img_paths, dmap_paths, poses, intrin_path = fun_get_paths(
        traj_Indx[0])

    if dataset_name == 'single_folder':
        intrin_path = args.intrin_path

    dataset = dataset_init(
        True,
        img_paths,
        dmap_paths,
        poses,
        intrin_path=intrin_path,
        img_size=img_size,
        digitize=True,
        d_candi=d_candi_dmap_ref,
        resize_dmap=.25,
    )

    dataset_Himgsize = dataset_init(
        True,
        img_paths,
        dmap_paths,
        poses,
        intrin_path=intrin_path,
        img_size=img_size,
        digitize=True,
        d_candi=d_candi_dmap_ref,
        resize_dmap=.5,
    )

    dataset_imgsize = dataset_init(
        True,
        img_paths,
        dmap_paths,
        poses,
        intrin_path=intrin_path,
        img_size=img_size,
        digitize=True,
        d_candi=d_candi_dmap_ref,
        resize_dmap=1,
    )

    # ================================ #

    print('Initnializing the KV-Net')
    model_KVnet = m_kvnet.KVNET(\
            feature_dim = dnet_feature_dim,
            cam_intrinsics = dataset.cam_intrinsics,
            d_candi = d_candi, sigma_soft_max = sigma_soft_max,
            KVNet_feature_dim = dnet_feature_dim,
            d_upsample_ratio_KV_net = None,
            t_win_r = t_win_r, if_refined = True)

    model_KVnet = torch.nn.DataParallel(model_KVnet)
    model_KVnet.cuda()

    model_path_KV = args.model_path
    print('loading KV_net at %s' % (model_path_KV))
    utils_model.load_pretrained_model(model_KVnet, model_path_KV)
    print('Done')

    for traj_idx in traj_Indx:
        scene_path_info = []
        print('Getting the paths for traj_%d' % (traj_idx))
        fldr_path, img_seq_paths, dmap_seq_paths, poses, intrin_path = fun_get_paths(
            traj_idx)
        res_fldr = '../results/%s/traj_%d' % (exp_name, traj_idx)
        m_misc.m_makedir(res_fldr)

        dataset.set_paths(img_seq_paths, dmap_seq_paths, poses)

        if dataset_name == 'scanNet':
            # the camera intrinsic may be slightly different for different trajectories in scanNet #
            dataset.get_cam_intrinsics(intrin_path)

        print('Done')
        if args.min_frame_idx > 0:
            frame_idxs = np.arange(args.min_frame_idx - t_win_r,
                                   args.min_frame_idx + t_win_r)
            dat_array = [dataset[idx] for idx in frame_idxs]
        else:
            dat_array = [dataset[idx] for idx in range(t_win_r * 2 + 1)]

        DMaps_meas = []
        dso_res_path = args.dso_res_path

        print('init initial pose from DSO estimations ...')
        traj_extMs = init_traj_extMs(traj_len=len(dataset),
                                     dso_res_path=dso_res_path,
                                     if_filter=True,
                                     min_idx=args.min_frame_idx,
                                     max_idx=args.max_frame_idx)
        traj_extMs_init = copy_list(traj_extMs)
        traj_length = min(len(dataset), len(traj_extMs))
        first_frame = True
        for frame_cnt, ref_indx in enumerate(
                range(t_win_r * dat_indx_step + args.min_frame_idx,
                      traj_length - t_win_r * dat_indx_step - dat_indx_step)):
            # ref_indx: the frame index for the reference frame #

            # Read ref. and src. data in the local time window #
            ref_dat, src_dats = m_misc.split_frame_list(dat_array, t_win_r)

            src_frame_idx = [ idx for idx in range(
                ref_indx - t_win_r * dat_indx_step, ref_indx, dat_indx_step) ] + \
                            [ idx for idx in range(
                 ref_indx + dat_indx_step, ref_indx + t_win_r*dat_indx_step+1, dat_indx_step) ]

            valid_seq = dso_io.valid_poses(traj_extMs, src_frame_idx)

            # only look at a subset of frames #
            if ref_indx < args.min_frame_idx:
                valid_seq = False
            if ref_indx > args.max_frame_idx or ref_indx >= traj_length - t_win_r * dat_indx_step - dat_indx_step:
                break
            if frame_cnt == 0 or valid_seq is False:
                BVs_predict = None

            # refresh #
            if ref_indx % args.refresh_frames == 0:
                print('REFRESH !')
                BVs_predict = None
                BVs_predict_in = None
                first_frame = True
                traj_extMs = copy_list(traj_extMs_init)

            if valid_seq:  # if the sequence does not contain invalid pose estimation
                # Get poses #
                src_cam_extMs = [traj_extMs[i] for i in src_frame_idx]
                ref_cam_extM = traj_extMs[ref_indx]
                src_cam_poses = [
                    warp_homo.get_rel_extrinsicM(ref_cam_extM, src_cam_extM_)
                    for src_cam_extM_ in src_cam_extMs
                ]
                src_cam_poses = [
                    torch.from_numpy(pose.astype(
                        np.float32)).cuda().unsqueeze(0)
                    for pose in src_cam_poses
                ]

                # Load the gt pose if available #
                if 'extM' in dataset[0]:
                    src_cam_extMs_ref = [
                        dataset[i]['extM'] for i in src_frame_idx
                    ]
                    ref_cam_extM_ref = dataset[ref_indx]['extM']
                    src_cam_poses_ref = [ warp_homo.get_rel_extrinsicM(ref_cam_extM_ref, src_cam_extM_) \
                                         for src_cam_extM_ in src_cam_extMs_ref ]
                    src_cam_poses_ref = [ torch.from_numpy(pose.astype(np.float32)).cuda().unsqueeze(0) \
                                         for pose in src_cam_poses_ref ]

                # -- Determine the scale, mapping from DSO scale to our working scale -- #
                if frame_cnt == 0 or BVs_predict is None:  # the first window for the traj.
                    _, t_norm_single = get_fb(src_cam_poses,
                                              dataset.cam_intrinsics,
                                              src_cam_pose_next=None)
                    # We need to heurisitcally determine scale_ without using GT pose #
                    t_norms = get_t_norms(traj_extMs, dat_indx_step)
                    scale_ = d_candi.max() / (
                        dataset.cam_intrinsics['focal_length'] *
                        np.array(t_norm_single).max() / 2)
                    scale_ = d_candi.max() / (
                        dataset.cam_intrinsics['focal_length'] *
                        np.array(t_norms).max())
                    scale_ = d_candi.max() / (
                        dataset.cam_intrinsics['focal_length'] *
                        np.array(t_norms).mean() / 2)
                    rescale_traj_t(traj_extMs, scale_)
                    traj_extMs_dso = copy_list(traj_extMs)
                    # Get poses #
                    src_cam_extMs = [traj_extMs[i] for i in src_frame_idx]
                    ref_cam_extM = traj_extMs[ref_indx]
                    src_cam_poses = [
                        warp_homo.get_rel_extrinsicM(ref_cam_extM,
                                                     src_cam_extM_)
                        for src_cam_extM_ in src_cam_extMs
                    ]
                    src_cam_poses = [
                        torch.from_numpy(pose.astype(
                            np.float32)).cuda().unsqueeze(0)
                        for pose in src_cam_poses
                    ]

                # src_cam_poses size: N V 4 4 #
                src_cam_poses = torch.cat(src_cam_poses, dim=0).unsqueeze(0)
                src_frames = [m_misc.get_entries_list_dict(src_dats, 'img')]
                cam_pose_next = traj_extMs[ref_indx + 1]
                cam_pose_next = torch.FloatTensor(
                    warp_homo.get_rel_extrinsicM(traj_extMs[ref_indx],
                                                 cam_pose_next)).cuda()

                BVs_predict_in = None if frame_cnt == 0 or BVs_predict is None \
                                      else BVs_predict

                BVs_measure, BVs_predict = test_KVNet.test(
                    model_KVnet,
                    d_candi,
                    Ref_Dats=[ref_dat],
                    Src_Dats=[src_dats],
                    Cam_Intrinsics=[dataset.cam_intrinsics],
                    t_win_r=t_win_r,
                    Src_CamPoses=src_cam_poses,
                    BV_predict=BVs_predict_in,
                    R_net=True,
                    cam_pose_next=cam_pose_next,
                    ref_indx=ref_indx)

                # export_res.export_res_refineNet(ref_dat,  BVs_measure, d_candi_dmap_ref,
                #                                 res_fldr, ref_indx,
                #                                 save_mat = True, output_pngs = args.output_pngs, output_dmap_ref=False)
                export_res.export_res_img(ref_dat, BVs_measure,
                                          d_candi_dmap_ref, res_fldr,
                                          frame_cnt)
                scene_path_info.append(
                    [frame_cnt, dataset[ref_indx]['img_path']])

                # UPDATE dat_array #
                if dat_indx_step > 1:  # use one-interval video and the frame interval is larger than 5
                    print('updating array ...')
                    dat_array = update_dat_array(dat_array,
                                                 dataset,
                                                 data_interv=1,
                                                 frame_interv=5,
                                                 ref_indx=ref_indx,
                                                 t_win_r=t_win_r)
                    print('done')

                else:
                    dat_array.pop(0)
                    new_dat = dataset[ref_indx + t_win_r + 1]
                    dat_array.append(new_dat)

                # OPTMIZE POSES #
                idx_ref_ = ref_indx + 1
                cam_pose_nextframe = traj_extMs[idx_ref_]
                cam_pose_nextframe = torch.FloatTensor(
                    warp_homo.get_rel_extrinsicM(traj_extMs[ref_indx],
                                                 cam_pose_nextframe)).cuda()

                # get depth and confidence map #
                BV_tmp_ = warp_homo.resample_vol_cuda(\
                                        src_vol = BVs_measure, rel_extM = cam_pose_nextframe.inverse(),
                                        cam_intrinsic = dataset_imgsize.cam_intrinsics,
                                        d_candi = d_candi, d_candi_new = d_candi,
                                        padding_value = math.log(1. / float(len(d_candi)))
                                        ).clamp(max=0, min=-1000.)
                dmap_ref = m_misc.depth_val_regression(BVs_measure,
                                                       d_candi,
                                                       BV_log=True).squeeze()
                conf_map_ref, _ = torch.max(BVs_measure.squeeze(), dim=0)
                dmap_kf = m_misc.depth_val_regression(BV_tmp_.unsqueeze(0),
                                                      d_candi,
                                                      BV_log=True).squeeze()
                conf_map_kf, _ = torch.max(BV_tmp_.squeeze(), dim=0)

                # setup optimization #
                cams_intrin = [
                    dataset.cam_intrinsics, dataset_Himgsize.cam_intrinsics,
                    dataset_imgsize.cam_intrinsics
                ]
                dw_scales = [4, 2, 1]
                LBA_max_iter = args.LBA_max_iter  #10 # 20
                LBA_step = args.LBA_step  #.05 #.01
                if LBA_max_iter <= 1:  # do not do optimization
                    LBA_step = 0.
                opt_vars = [args.opt_r, args.opt_t]

                # initialization for the first time window #
                if first_frame:
                    first_frame = False

                    # optimize the pose for all frames within the window #
                    if LBA_max_iter <= 1:  # for debugging: using GT pose initialization #
                        rel_pose_inits_all_frame, srcs_idx_all_frame = m_misc.get_twin_rel_pose(
                            traj_extMs,
                            idx_ref_,
                            t_win_r * dat_indx_step,
                            1,
                            use_gt_R=True,
                            use_gt_t=True,
                            dataset=dataset,
                            add_noise_gt=False,
                            noise_sigmas=None)
                    else:
                        rel_pose_inits_all_frame, srcs_idx_all_frame = m_misc.get_twin_rel_pose(
                            traj_extMs,
                            ref_indx,
                            t_win_r * dat_indx_step,
                            1,
                            use_gt_R=False,
                            use_gt_t=False,
                            dataset=dataset,
                        )
                    # opt. #
                    img_ref = dataset[ref_indx]['img']
                    imgs_src = [dataset[i]['img'] for i in srcs_idx_all_frame]
                    conf_map_ref = torch.exp(conf_map_ref).squeeze()**2
                    rel_pose_opt = opt_pose_numerical.local_BA_direct(
                        img_ref,
                        imgs_src,
                        dmap_ref.unsqueeze(0).unsqueeze(0),
                        conf_map_ref.unsqueeze(0).unsqueeze(0),
                        cams_intrin,
                        dw_scales,
                        rel_pose_inits_all_frame,
                        max_iter=LBA_max_iter,
                        step=LBA_step,
                        opt_vars=opt_vars)

                    # update #
                    for idx, srcidx in enumerate(srcs_idx_all_frame):
                        traj_extMs[srcidx] = np.matmul(
                            rel_pose_opt[idx].cpu().numpy(),
                            traj_extMs[ref_indx])

                # for next frame #
                if LBA_max_iter <= 1:  # for debugging: using GT pose init.
                    rel_pose_opt, srcs_idx = m_misc.get_twin_rel_pose(
                        traj_extMs,
                        idx_ref_,
                        t_win_r,
                        dat_indx_step,
                        use_gt_R=True,
                        use_gt_t=True,
                        dataset=dataset,
                        add_noise_gt=False,
                        noise_sigmas=None,
                    )
                else:
                    rel_pose_inits, srcs_idx = m_misc.get_twin_rel_pose(
                        traj_extMs,
                        idx_ref_,
                        t_win_r,
                        dat_indx_step,
                        use_gt_R=args.use_gt_R,
                        use_dso_R=args.use_dso_R,
                        use_gt_t=args.use_gt_t,
                        use_dso_t=args.use_dso_t,
                        dataset=dataset,
                        traj_extMs_dso=traj_extMs_dso,
                        opt_next_frame=args.opt_next_frame)

                    img_ref = dataset[idx_ref_]['img']
                    _, src_dats_opt = m_misc.split_frame_list(
                        dat_array, t_win_r)
                    imgs_src = [dat_['img'] for dat_ in src_dats_opt]
                    img_ref = dataset[idx_ref_]['img']
                    imgs_src = [dataset[i] for i in srcs_idx]
                    imgs_src = [img_['img'] for img_ in imgs_src]

                    # opt. #
                    conf_map_kf = torch.exp(conf_map_kf).squeeze()**2
                    rel_pose_opt = \
                            opt_pose_numerical.local_BA_direct_parallel(
                            img_ref, imgs_src,
                            dmap_kf.unsqueeze(0).unsqueeze(0),
                            conf_map_kf.unsqueeze(0).unsqueeze(0), cams_intrin,
                            dw_scales, rel_pose_inits, max_iter = LBA_max_iter,
                            step = LBA_step, opt_vars = opt_vars)

                # update #
                print('idx_ref_: %d' % (idx_ref_))
                print('srcs_idx : ')
                print(srcs_idx)

                print('updating pose ...')
                for idx, srcidx in enumerate(srcs_idx):
                    traj_extMs[srcidx] = np.matmul(
                        rel_pose_opt[idx].cpu().numpy(), traj_extMs[idx_ref_])
                print('done')

            else:  # if the sequence contains invalid pose estimation
                BVs_predict = None
                print('frame_cnt :%d, include invalid poses' % (frame_cnt))
                # UPDATE dat_array #
                if dat_indx_step > 1:  # use one-interval video and the frame interval is larger than 5
                    print('updating array ...')
                    dat_array = update_dat_array(dat_array,
                                                 dataset,
                                                 data_interv=1,
                                                 frame_interv=5,
                                                 ref_indx=ref_indx,
                                                 t_win_r=t_win_r)
                    print('done')

                else:
                    dat_array.pop(0)
                    new_dat = dataset[ref_indx + t_win_r + 1]
                    dat_array.append(new_dat)
        m_misc.save_ScenePathInfo('%s/scene_path_info.txt' % (res_fldr),
                                  scene_path_info)
Ejemplo n.º 8
0
def main():
    """Train KV-Net (optionally with its refinement net) from the command line.

    Steps:
      1. Parse arguments and derive the depth candidates.
      2. Set up TensorBoard logging and a stdout ``Logger``.
      3. Build the dataset, the DataParallel KV-Net model and an Adam
         optimizer, and optionally load a pre-trained checkpoint.
      4. Run ``--nepoch`` epochs over the batches produced by
         ``batch_loader.Batch_Loader``.

    Checkpoints are written every ``--save_model_interv`` iterations to
    ``./saved_models/<exp_name>``.

    Only ``--dataset kitti`` is currently wired up for training. For
    ``scanNet`` the paths are resolved, but dataset construction raises.
    """
    import argparse
    print('Parsing the arguments...')
    parser = argparse.ArgumentParser()

    # exp name #
    parser.add_argument(
        '--exp_name',
        required=True,
        type=str,
        help='The name of the experiment. Used to naming the folders')

    # nepoch #
    parser.add_argument('--nepoch',
                        required=True,
                        type=int,
                        help='# of epochs to run')

    # if pretrain #
    parser.add_argument('--pre_trained',
                        action='store_true',
                        default=False,
                        help='If use the pre-trained model; (False)')

    # logging #
    parser.add_argument('--TB_add_img_interv',
                        type=int,
                        default=50,
                        help='The interval for logging one training image')

    parser.add_argument('--pre_trained_model_path',
                        type=str,
                        default='.',
                        help='The pre-trained model path for KV-net')

    # model saving #
    parser.add_argument(
        '--save_model_interv',
        type=int,
        default=5000,
        help='The interval of iters to save the model; default: 5000')

    # tensorboard #
    parser.add_argument(
        '--TB_fldr',
        type=str,
        default='runs',
        help='The tensorboard logging root folder; default: runs')

    # about training #
    # NOTE(review): store_true combined with default=True means this flag is
    # always True and cannot be switched off from the command line.
    parser.add_argument(
        '--RNet',
        action='store_true',
        help='if use refinement net to improve the depth resolution',
        default=True)

    parser.add_argument('--weight_var',
                        default=.001,
                        type=float,
                        help='weight for the variance loss, if we use L1 loss')

    parser.add_argument(
        '--pose_noise_level',
        default=0,
        type=float,
        help='Noise level for pose. Used for training with pose noise')

    parser.add_argument('--frame_interv',
                        default=5,
                        type=int,
                        help='frame interval')

    parser.add_argument('--LR', default=.001, type=float, help='Learning rate')

    parser.add_argument('--t_win',
                        type=int,
                        default=2,
                        help='The radius of the temporal window; default=2')

    parser.add_argument('--d_min',
                        type=float,
                        default=0,
                        help='The minimal depth value; default=0')

    parser.add_argument('--d_max',
                        type=float,
                        default=15,
                        help='The maximal depth value; default=15')

    parser.add_argument('--ndepth',
                        type=int,
                        default=128,
                        help='The # of candidate depth values; default= 128')

    # NOTE(review): --grad_clip / --grad_clip_max are parsed but not used
    # anywhere in this function. If clipping is intended, it has to happen
    # inside train_KVNet.train.
    parser.add_argument('--grad_clip',
                        action='store_true',
                        help='if clip the gradient')

    parser.add_argument('--grad_clip_max',
                        type=float,
                        default=2,
                        help='the maximal norm of the gradient')

    parser.add_argument('--sigma_soft_max',
                        type=float,
                        default=500.,
                        help='sigma_soft_max, default = 500.')

    parser.add_argument(
        '--feature_dim',
        type=int,
        default=64,
        help='The feature dimension for the feature extractor; default=64')

    parser.add_argument(
        '--batch_size',
        type=int,
        default=0,
        help='The batch size for training; default=0, means batch_size=nGPU')

    # about dataset #
    parser.add_argument('--dataset',
                        type=str,
                        default='scanNet',
                        help='Dataset name: {scanNet, kitti,}')
    parser.add_argument('--dataset_path',
                        type=str,
                        default='.',
                        help='Path to the dataset')
    parser.add_argument(
        '--change_aspect_ratio',
        action='store_true',
        default=False,
        help=
        'If we want to change the aspect ratio. This option is only useful for KITTI'
    )
    # The default must be an int: argparse does not apply `type` to
    # non-string defaults, and `range(ngpu)` below rejects floats.
    parser.add_argument('--ngpu', type=int, default=1, help='How many GPU')

    # para config. #
    args = parser.parse_args()
    exp_name = args.exp_name
    saved_model_path = './saved_models/%s' % (exp_name)
    dataset_name = args.dataset

    # batch_size == 0 means one sample per visible CUDA device #
    if args.batch_size == 0:
        batch_size = torch.cuda.device_count()
    else:
        batch_size = args.batch_size

    n_epoch = args.nepoch
    TB_add_img_interv = args.TB_add_img_interv
    pre_trained = args.pre_trained
    t_win_r = args.t_win
    nDepth = args.ndepth
    d_candi = np.linspace(args.d_min, args.d_max, nDepth)
    # 4x finer depth candidates; only handed to the dataset constructor #
    d_candi_up = np.linspace(args.d_min, args.d_max, nDepth * 4)
    LR = args.LR
    sigma_soft_max = args.sigma_soft_max  #10.#500.
    dnet_feature_dim = args.feature_dim
    d_candi_dmap_ref = d_candi
    ngpu = args.ngpu

    # saving model config.#
    m_misc.m_makedir(saved_model_path)
    savemodel_interv = args.save_model_interv

    # writer #
    log_dir = '%s/%s' % (args.TB_fldr, exp_name)
    writer = SummaryWriter(log_dir=log_dir, comment='%s' % (exp_name))
    m_misc.save_args(args, '%s/tr_paras.txt' %
                     (log_dir))  # save the training parameters #
    logfile = os.path.join(log_dir, 'log_' + str(time.time()) + '.txt')
    stdout = Logger.Logger(logfile)
    sys.stdout = stdout

    # Everything after the stdout redirect runs under try/finally so the
    # writer is closed and the Logger is delinked even if training fails.
    try:
        # Initialize data-loader, model and optimizer #

        # ===== Dataset selection ======== #
        dataset_path = args.dataset_path
        if dataset_name == 'scanNet':
            dataset_init = dl_scanNet.ScanNet_dataset

            if dataset_path != '.':
                fun_get_paths = lambda traj_indx: dl_scanNet.get_paths(
                    traj_indx,
                    frame_interv=5,
                    split_txt='./mdataloader/scanNet_split/scannet_train.txt',
                    database_path_base=dataset_path)
            else:
                fun_get_paths = lambda traj_indx: dl_scanNet.get_paths(
                    traj_indx,
                    frame_interv=5,
                    split_txt='./mdataloader/scanNet_split/scannet_train.txt')

            img_size = [384, 256]

            # trajectory index for training #
            n_scenes, _, _, _, _ = fun_get_paths(0)
            traj_Indx = np.arange(0, n_scenes)

        elif dataset_name == 'kitti':
            import mdataloader.kitti as dl_kitti
            dataset_init = dl_kitti.KITTI_dataset

            if dataset_path != '.':
                fun_get_paths = lambda traj_indx: dl_kitti.get_paths(
                    traj_indx,
                    split_txt='./mdataloader/kitti_split/training.txt',
                    mode='train',
                    database_path_base=dataset_path)
            else:  # use default database path
                fun_get_paths = lambda traj_indx: dl_kitti.get_paths(
                    traj_indx,
                    split_txt='./mdataloader/kitti_split/training.txt',
                    mode='train')

            if not args.change_aspect_ratio:
                # keep the aspect ratio and do cropping
                img_size = [768, 256]
                crop_w = 384
            else:
                # change the aspect ratio and do NOT crop
                img_size = [384, 256]
                crop_w = None

            n_scenes, _, _, _, _ = fun_get_paths(0)
            traj_Indx = np.arange(0, n_scenes)

        else:
            raise Exception('dataset not implemented ')

        fldr_path, img_paths, dmap_paths, poses, intrin_path = fun_get_paths(0)
        if dataset_name == 'kitti':
            dataset = dataset_init(True,
                                   img_paths,
                                   dmap_paths,
                                   poses,
                                   intrin_path=intrin_path,
                                   img_size=img_size,
                                   digitize=True,
                                   d_candi=d_candi_dmap_ref,
                                   d_candi_up=d_candi_up,
                                   resize_dmap=.25,
                                   crop_w=crop_w)
        else:
            # NOTE(review): the training loop below reads "left_camera"
            # entries and "left_*" keys from the batch loader. Presumably only
            # the KITTI loader provides those, so scanNet training is not
            # wired up here.
            raise Exception('dataset not implemented ')
        # ================================ #

        print('Initnializing the KV-Net')
        model_KVnet = m_kvnet.KVNET(feature_dim=dnet_feature_dim,
                                    cam_intrinsics=dataset.cam_intrinsics,
                                    d_candi=d_candi,
                                    sigma_soft_max=sigma_soft_max,
                                    KVNet_feature_dim=dnet_feature_dim,
                                    d_upsample_ratio_KV_net=None,
                                    t_win_r=t_win_r,
                                    if_refined=args.RNet)

        model_KVnet = torch.nn.DataParallel(model_KVnet,
                                            device_ids=range(ngpu),
                                            dim=0)
        model_KVnet.cuda(0)

        optimizer_KV = optim.Adam(model_KVnet.parameters(),
                                  lr=LR,
                                  betas=(.9, .999))

        model_path_KV = args.pre_trained_model_path
        # Compare by value: `is not '.'` compared object identity, so an
        # explicit "." on the command line would have been "loaded".
        if model_path_KV != '.' and pre_trained:
            print('loading KV_net at %s' % (model_path_KV))
            utils_model.load_pretrained_model(model_KVnet, model_path_KV,
                                              optimizer_KV)

        print('Done')

        LOSS = []
        total_iter = 0
        # True once a frame had no valid sample; the predicted belief volume
        # is then not propagated into the next frame.
        prev_invalid = False

        for iepoch in range(n_epoch):
            BatchScheduler = batch_loader.Batch_Loader(
                batch_size=batch_size,
                fun_get_paths=fun_get_paths,
                dataset_traj=dataset,
                nTraj=len(traj_Indx),
                dataset_name=dataset_name)

            for batch_idx in range(len(BatchScheduler)):
                for frame_count, ref_indx in enumerate(
                        range(BatchScheduler.traj_len)):
                    local_info = BatchScheduler.local_info_full()
                    n_valid_batch = local_info['is_valid'].sum()

                    if n_valid_batch > 0:
                        local_info_valid = batch_loader.get_valid_items(
                            local_info)
                        # keep only the left-camera view of every frame #
                        ref_dats_in = [
                            ref_dat["left_camera"]
                            for ref_dat in local_info_valid['ref_dats']
                        ]
                        src_dats_in = [[
                            src_dat["left_camera"] for src_dat in src_dats
                        ] for src_dats in local_info_valid['src_dats']]

                        cam_intrin_in = local_info_valid['left_cam_intrins']
                        src_cam_poses_in = torch.cat(
                            local_info_valid['left_src_cam_poses'], dim=0)

                        if args.pose_noise_level > 0:
                            src_cam_poses_in = add_noise2pose(
                                src_cam_poses_in, args.pose_noise_level)

                        if frame_count == 0 or prev_invalid:
                            prev_invalid = False
                            BVs_predict_in = None
                            print('frame_count ==0 or invalid previous frame')
                        else:
                            BVs_predict_in = batch_loader.get_valid_BVs(
                                BVs_predict, local_info['is_valid'])

                        BVs_measure, BVs_predict, loss, dmap_log_l, dmap_log_h = train_KVNet.train(
                            n_valid_batch,
                            model_KVnet,
                            optimizer_KV,
                            t_win_r,
                            d_candi,
                            Ref_Dats=ref_dats_in,
                            Src_Dats=src_dats_in,
                            Src_CamPoses=src_cam_poses_in,
                            BVs_predict=BVs_predict_in,
                            Cam_Intrinsics=cam_intrin_in,
                            weight_var=args.weight_var,
                            loss_type='NLL',
                            mGPU=True)

                        BVs_measure = BVs_measure.detach()
                        loss_v = float(loss.data.cpu().numpy())

                        if n_valid_batch < BatchScheduler.batch_size:
                            BVs_predict = batch_loader.fill_BVs_predict(
                                BVs_predict, local_info['is_valid'])

                    else:
                        # No valid sample in this frame: repeat the last
                        # logged loss. Use NaN if nothing has been logged
                        # yet, instead of raising IndexError on LOSS[-1].
                        loss_v = LOSS[-1] if LOSS else float('nan')
                        prev_invalid = True

                    # Update dat_array #
                    if frame_count < BatchScheduler.traj_len - 1:
                        BatchScheduler.proceed_frame()

                    total_iter += 1

                    # logging #
                    if frame_count > 0:
                        LOSS.append(loss_v)
                        print('video batch %d / %d, iter: %d, frame_count: %d; Epoch: %d / %d, loss = %.5f'\
                              %(batch_idx + 1, len(BatchScheduler), total_iter, frame_count, iepoch + 1, n_epoch, loss_v))

                        writer.add_scalar('data/train_error', float(loss_v),
                                          total_iter)

                    if total_iter % savemodel_interv == 0:
                        # if training, save the model #
                        savefilename = saved_model_path + '/kvnet_checkpoint_iter_' + str(
                            total_iter) + '.tar'
                        torch.save(
                            {
                                'iter': total_iter,
                                'frame_count': frame_count,
                                'ref_indx': ref_indx,
                                'traj_idx': batch_idx,
                                'state_dict': model_KVnet.state_dict(),
                                'optimizer': optimizer_KV.state_dict(),
                                'loss': loss_v
                            }, savefilename)

                    # Only log images produced by this iteration. On an
                    # all-invalid batch the tensors below would be stale from
                    # an earlier frame, or undefined before the first valid one.
                    if total_iter % TB_add_img_interv == 0 and n_valid_batch > 0:
                        th_dmaps_log = torch.FloatTensor(
                            dmap_log_l.astype(np.float32))
                        th_dmaps_log = th_dmaps_log.unsqueeze(0)
                        th_dmaps_log = (th_dmaps_log /
                                        (d_candi_dmap_ref.max())).clamp(0, 1)
                        th_dmaps_log = th_dmaps_log.repeat([3, 1, 1])
                        input_img_log = ref_dats_in[0]['img'].clone()
                        input_img_log = (input_img_log - input_img_log.min()) / (
                            input_img_log.max() - input_img_log.min())
                        input_img_log = input_img_log.squeeze()

                        # assuming N=1 for BVs_measure #
                        confMap_log, _ = torch.max(BVs_measure[0, ...], dim=0)
                        confMap_log = torch.exp(confMap_log.squeeze().cpu())
                        confMap_log /= confMap_log.max()
                        confMap_log = confMap_log.repeat([3, 1, 1])
                        writer.add_image('%s/tr_dmaps' % (exp_name),
                                         th_dmaps_log, total_iter)
                        writer.add_image('%s/tr_input' % (exp_name),
                                         input_img_log, total_iter)
                        writer.add_image('%s/conf_map' % (exp_name),
                                         confMap_log, total_iter)

                        # up-sample branch #
                        # dmap_log_h == -1 appears to be the sentinel for "no
                        # up-sampled map". Test it explicitly: `is not -1`
                        # relied on int caching, and `!=` on an ndarray would
                        # be ambiguous.
                        if not (isinstance(dmap_log_h, int)
                                and dmap_log_h == -1):
                            th_dmaps_up_log = torch.FloatTensor(
                                dmap_log_h.astype(np.float32))
                            th_dmaps_up_log = th_dmaps_up_log.unsqueeze(0)
                            th_dmaps_up_log = (th_dmaps_up_log /
                                               (d_candi_dmap_ref.max())).clamp(
                                                   0, 1)
                            th_dmaps_up_log = th_dmaps_up_log.repeat(
                                [3, 1, 1])
                            writer.add_image('%s/tr_dmaps_up' % (exp_name),
                                             th_dmaps_up_log, total_iter)

                BatchScheduler.proceed_batch()
    finally:
        writer.close()
        stdout.delink()
Ejemplo n.º 9
0
def main():
    """Evaluate a pre-trained KV-Net on every trajectory of a validation split.

    All settings come from the command line (see the ``add_argument`` calls
    below). For each trajectory the function slides a window of
    ``2 * t_win + 1`` frames over the sequence, runs ``test_KVNet.test`` on
    every reference frame whose window has valid poses, and prints:

    * the per-trajectory averages from ``export_res.AverageMeter``
      (RMSE, AbsRel, Log10, SquaRel, rmselog, Delta1-3, GPU time),
    * the running mean of those metrics over the trajectories done so far,
    * the mean wall-clock time per reference frame.

    It also writes ``scene_path_info.txt`` to
    ``../results/<exp_name>/traj_<idx>``.

    Raises:
        Exception: if ``--dataset`` names a loader that is not implemented
            here (only ``kitti`` and ``dm`` are).
    """
    import argparse

    def _str2bool(value):
        """Parse a command-line boolean ('true'/'false', 'yes'/'no', '1'/'0')."""
        # argparse's ``type=bool`` maps every non-empty string - including
        # "False" - to True, so the usual spellings are parsed explicitly.
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y'):
            return True
        if lowered in ('0', 'false', 'f', 'no', 'n'):
            return False
        raise argparse.ArgumentTypeError(
            'expected a boolean value, got %r' % (value, ))

    print('Parsing the arguments...')
    parser = argparse.ArgumentParser()

    # exp name #
    parser.add_argument(
        '--exp_name',
        required=True,
        type=str,
        help='The name of the experiment. Used to naming the folders')

    # about testing #
    parser.add_argument('--model_path',
                        type=str,
                        required=True,
                        help='The pre-trained model path for KV-net')
    # '.' means "not given": the dm loader then falls back to its bundled
    # split file (see the dataset selection below).
    parser.add_argument(
        '--split_file',
        type=str,
        default='.',
        help='The split txt file; for --dataset dm, "." selects '
        './mdataloader/dm_split/dm_split.txt')
    # NOTE(review): parsed for CLI compatibility but not used in this
    # function (originally noted: should be a multiple of 5 for scanNet).
    parser.add_argument('--frame_interv',
                        default=5,
                        type=int,
                        help='frame interval')
    parser.add_argument('--t_win',
                        type=int,
                        default=2,
                        help='The radius of the temporal window; default=2')
    parser.add_argument('--d_min',
                        type=float,
                        default=0,
                        help='The minimal depth value; default=0')
    parser.add_argument('--d_max',
                        type=float,
                        default=5,
                        help='The maximal depth value; default=5')
    parser.add_argument('--ndepth',
                        type=int,
                        default=64,
                        help='The # of candidate depth values; default=64')
    parser.add_argument('--sigma_soft_max',
                        type=float,
                        default=10.,
                        help='sigma_soft_max, default = 10.')
    parser.add_argument(
        '--feature_dim',
        type=int,
        default=64,
        help='The feature dimension for the feature extractor; default=64')

    # about dataset #
    # NOTE(review): the default 'scanNet' has no loader below, so --dataset
    # must currently be given explicitly.
    parser.add_argument('--dataset',
                        type=str,
                        default='scanNet',
                        help='Dataset name; implemented loaders: {kitti, dm}')
    parser.add_argument('--dataset_path',
                        type=str,
                        default='.',
                        help='Path to the dataset')
    # Accepts "--change_aspect_ratio True/False" as before, and also the
    # bare flag "--change_aspect_ratio" (meaning True).
    parser.add_argument(
        '--change_aspect_ratio',
        type=_str2bool,
        nargs='?',
        const=True,
        default=False,
        help=
        'If we want to change the aspect ratio. This option is only useful for KITTI'
    )

    # parsing parameters #
    args = parser.parse_args()
    exp_name = args.exp_name
    dataset_name = args.dataset
    t_win_r = args.t_win
    nDepth = args.ndepth
    d_candi = np.linspace(args.d_min, args.d_max, nDepth)
    sigma_soft_max = args.sigma_soft_max
    dnet_feature_dim = args.feature_dim
    d_upsample = None
    # The reference depth maps use the same candidates as the network.
    d_candi_dmap_ref = d_candi
    split_file = args.split_file
    dataset_path = args.dataset_path

    # ===== Dataset selection ======== #
    if dataset_name == 'kitti':
        import mdataloader.kitti as dl_kitti
        dataset_init = dl_kitti.KITTI_dataset
        get_paths_impl = dl_kitti.get_paths
        # NOTE(review): both aspect-ratio modes currently use this size and
        # no cropping, so --change_aspect_ratio has no effect here.
        img_size = [768, 356]
        crop_w = None
    elif dataset_name == 'dm':
        import mdataloader.dm as dl_dm
        dataset_init = dl_dm.DMdataset
        get_paths_impl = dl_dm.get_paths
        if split_file == '.':
            split_file = './mdataloader/dm_split/dm_split.txt'
        # NOTE(review): --change_aspect_ratio does not change this size.
        img_size = [786, 256]
    else:
        raise Exception('dataset loader not implemented')

    # Only pass database_path_base for a non-default path, so the loader's
    # own default location is used otherwise.
    get_paths_kwargs = {'split_txt': split_file, 'mode': 'val'}
    if dataset_path != '.':
        get_paths_kwargs['database_path_base'] = dataset_path

    def fun_get_paths(traj_indx):
        return get_paths_impl(traj_indx, **get_paths_kwargs)

    # The first value returned by get_paths is used as the number of
    # trajectories; trajectory 0 also supplies the paths used to build the
    # datasets (whose camera intrinsics are handed to the model).
    n_scenes, img_paths, dmap_paths, poses, intrin_path = fun_get_paths(0)
    traj_Indx = np.arange(0, n_scenes)

    dataset_kwargs = dict(intrin_path=intrin_path,
                          img_size=img_size,
                          digitize=True,
                          d_candi=d_candi_dmap_ref)
    # Only the KITTI loader is given a crop width.
    crop_kwargs = {'crop_w': crop_w} if dataset_name == 'kitti' else {}
    dataset = dataset_init(True, img_paths, dmap_paths, poses,
                           resize_dmap=.25, **dataset_kwargs, **crop_kwargs)
    # Same data with resize_dmap=1; only its cam_intrinsics are used below.
    dataset_imgsize = dataset_init(True, img_paths, dmap_paths, poses,
                                   resize_dmap=1, **dataset_kwargs)
    # ================================ #
    print('Initnializing the KV-Net')
    model_KVnet = m_kvnet.KVNET(feature_dim=dnet_feature_dim,
                                cam_intrinsics=dataset.cam_intrinsics,
                                d_candi=d_candi,
                                sigma_soft_max=sigma_soft_max,
                                KVNet_feature_dim=dnet_feature_dim,
                                d_upsample_ratio_KV_net=d_upsample,
                                t_win_r=t_win_r,
                                if_refined=True)

    model_KVnet = torch.nn.DataParallel(model_KVnet)
    model_KVnet.cuda()

    model_path_KV = args.model_path
    print('loading KV_net at %s' % (model_path_KV))
    utils_model.load_pretrained_model(model_KVnet, model_path_KV)
    print('Done')

    # (label in the running summary, attribute of the averaged result)
    running_metrics = (('rmse', 'rmse'), ('absrel', 'absrel'),
                       ('lg10', 'lg10'), ('squarel', 'squarel'),
                       ('rmselog', 'rmselog'), ('D1', 'delta1'),
                       ('D2', 'delta2'), ('D3', 'delta3'))
    metric_sums = {attr: 0 for _, attr in running_metrics}

    for n_done, traj_idx in enumerate(traj_Indx, start=1):
        res_fldr = '../results/%s/traj_%d' % (exp_name, traj_idx)
        m_misc.m_makedir(res_fldr)
        scene_path_info = []

        print('Getting the paths for traj_%d' % (traj_idx))
        _, img_seq_paths, dmap_seq_paths, poses, intrin_path = fun_get_paths(
            traj_idx)
        dataset.set_paths(img_seq_paths, dmap_seq_paths, poses)

        # Currently unreachable: no scanNet loader is selected above.
        if dataset_name == 'scanNet':
            # For each trajectory in the dataset, update the intrinsic matrix #
            dataset.get_cam_intrinsics(intrin_path)

        print('Done')
        # dat_array[i] holds dataset[ref_indx - t_win_r + i] for the current
        # reference index; the window slides by one frame per iteration.
        dat_array = [dataset[idx] for idx in range(t_win_r * 2 + 1)]
        traj_length = len(dataset)
        print('trajectory length = %d' % (traj_length))
        # Number of reference frames that have a full temporal window.
        n_ref_frames = traj_length - 2 * t_win_r

        average_meter = export_res.AverageMeter()
        # Volume propagated between consecutive frames; reset per trajectory.
        BVs_predict = None

        ### inference time
        torch.cuda.synchronize()
        start = time.time()
        iter_end = start

        for frame_cnt, ref_indx in enumerate(
                range(t_win_r, traj_length - t_win_r)):
            result = export_res.Result()

            # Time since the previous iteration finished; this covers
            # loading the frame that was appended to the window.
            torch.cuda.synchronize()
            data_time = time.time() - iter_end

            valid_seq = check_datArray_pose(dat_array)

            # Read ref. and src. data in the local time window #
            ref_dat, src_dats = m_misc.split_frame_list(dat_array, t_win_r)

            if not valid_seq:
                # Do not carry the previous volume into the next valid window.
                BVs_predict = None
                print('frame_cnt :%d, include invalid poses' % (frame_cnt))
            else:
                # Poses of the source frames relative to the reference frame.
                src_cam_extMs = m_misc.get_entries_list_dict(src_dats, 'extM')
                rel_poses = [
                    warp_homo.get_rel_extrinsicM(ref_dat['extM'], src_extM)
                    for src_extM in src_cam_extMs
                ]
                src_cam_poses = [
                    torch.from_numpy(pose.astype(
                        np.float32)).cuda().unsqueeze(0)
                    for pose in rel_poses
                ]
                # src_cam_poses size: N V 4 4 #
                src_cam_poses = torch.cat(src_cam_poses, dim=0).unsqueeze(0)

                # gpu_time brackets only the inference call.
                torch.cuda.synchronize()
                gpu_start = time.time()
                BVs_measure, BVs_predict = test_KVNet.test(
                    model_KVnet,
                    d_candi,
                    Ref_Dats=[ref_dat],
                    Src_Dats=[src_dats],
                    Cam_Intrinsics=[dataset.cam_intrinsics],
                    t_win_r=t_win_r,
                    Src_CamPoses=src_cam_poses,
                    BV_predict=BVs_predict,
                    R_net=True,
                    Cam_Intrinsics_imgsize=dataset_imgsize.cam_intrinsics,
                    ref_indx=ref_indx)
                torch.cuda.synchronize()
                gpu_time = time.time() - gpu_start

                pred_depth, gt = export_res.do_evaluation(
                    ref_dat, BVs_measure, d_candi_dmap_ref)
                result.evaluate(pred_depth.data, gt.data)
                average_meter.update(result, gpu_time, data_time,
                                     n_ref_frames)

                # dat_array[t_win_r] is dataset[ref_indx]; reusing it avoids
                # loading the reference item again just to read its path.
                scene_path_info.append(
                    [frame_cnt, dat_array[t_win_r]['img_path']])

            iter_end = time.time()

            # Slide the window forward while there is a frame left to load,
            # so the last full-window reference frame is evaluated too.
            next_indx = ref_indx + t_win_r + 1
            if next_indx < traj_length:
                dat_array.pop(0)
                dat_array.append(dataset[next_indx])

        avg = average_meter.average()

        print('\n*\n'
              'RMSE={average.rmse:.3f}\n'
              'AbsRel={average.absrel:.3f}\n'
              'Log10={average.lg10:.3f}\n'
              'SquaRel={average.squarel:.3f}\n'
              'rmselog={average.rmselog:.3f}\n'
              'Delta1={average.delta1:.3f}\n'
              'Delta2={average.delta2:.3f}\n'
              'Delta3={average.delta3:.3f}\n'
              't_GPU={time:.3f}\n'.format(average=avg, time=avg.gpu_time))

        ### inference time
        torch.cuda.synchronize()
        end = time.time()

        # Running mean of the per-trajectory averages over trajectories so far.
        for _, attr in running_metrics:
            metric_sums[attr] += getattr(avg, attr)
        print(*['%s={%.3f}\n' % (label, metric_sums[attr] / n_done)
                for label, attr in running_metrics])

        # Mean wall-clock time per reference frame for this trajectory.
        print((end - start) / n_ref_frames)

        m_misc.save_ScenePathInfo('%s/scene_path_info.txt' % (res_fldr),
                                  scene_path_info)