def do_evaluation(ref_dat, BV_measure, d_candi):
    '''
    Resize the regressed depth map to the raw-size GT depth map for evaluation.
    `image` is assumed to be the PIL Image module (e.g. `import PIL.Image as image`).
    '''
    dmap = m_misc.depth_val_regression(
        BV_measure, d_candi, BV_log=True).squeeze().cpu().numpy()  # (256, 768)
    gt = ref_dat['dmap_rawsize']
    raw_h, raw_w = gt.shape[1], gt.shape[2]  # gt is 1 x H x W
    dmap = image.fromarray(dmap)
    pred_depth = dmap.resize((raw_w, raw_h), image.NEAREST)  # PIL resize takes (width, height)
    pred_depth = torch.FloatTensor(np.array(pred_depth)).unsqueeze(0)
    return pred_depth, gt
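
# A minimal sketch of what m_misc.depth_val_regression(BV, d_candi, BV_log=True) is
# assumed to compute, based on how it is used in this file: the per-pixel expectation
# over the depth candidates of a log-probability volume (DPV) of shape N x D x H x W.
# The helper name and exact normalization are assumptions, not the repo's actual code.
def _depth_expectation_sketch(dpv_log, d_candi):
    # dpv_log: N x D x H x W log-probability volume; d_candi: (D,) numpy array
    prob = torch.exp(dpv_log)                                   # back to probabilities
    depths = torch.from_numpy(d_candi).float().to(prob.device)  # (D,)
    depths = depths.view(1, -1, 1, 1)                           # broadcast over N, H, W
    return torch.sum(prob * depths, dim=1)                      # N x H x W depth map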
def train(nGPU, model_KV, optimizer_KV, t_win_r, d_candi, Ref_Dats, Src_Dats,
          Src_CamPoses, BVs_predict, Cam_Intrinsics, refine_dup=False,
          weight_var=.001, loss_type='NLL', mGPU=False,
          Cam_Intrinsics_spatial_up=None, return_confmap_up=False):
    r'''
    Perform one single training iteration.
    Supports multi-GPU training; to do this we treat each trajectory as one batch.

    Inputs:
        model_KV -
        optimizer_KV -
        Ref_Dats - list of ref_dat
        Src_Dats - list of lists of src_dat: [list_src_dats_traj_0, ...];
                   list_src_dats_traj_0[iframe]: NCHW
        Src_CamPoses - N x V x 4 x 4, where N: batch size (# of batched trajectories),
                       V: # of src. views
        BVs_predict - N x D x H_feat x W_feat
        Cam_Intrinsics - list of camera intrinsics for the batched trajectories
        refine_dup - if True, upsample the depth dimension in the refinement net
        loss_type - {'L1', 'NLL'}
            L1  - calculate the mean value from the low-res. DPV and filter it with DGF
                  to get the L1 loss at high res.; in addition, also calculate the variance loss
            NLL - calculate the NLL loss from the low-res. DPV

    Outputs:
        r_dpv, BVs_predict_out, loss, dmap_log_lres, dmap_log_hres
        (plus confmap_up if return_confmap_up is True)
    '''
    # TODO:
    # Make this work with bs 4 first
    # Then make it work with just image input 3
    # Then change it to the new dataloader left right
    # Then change it to the multi-process thing

    # prepare the inputs #
    ref_frame = torch.cat(tuple([ref_dat['img'] for ref_dat in Ref_Dats]), dim=0)
    src_frames_list = [torch.cat(tuple([src_dat_frame['img']
                                        for src_dat_frame in src_dats_traj]), dim=0).unsqueeze(0)
                       for src_dats_traj in Src_Dats]
    src_frames = torch.cat(tuple(src_frames_list), dim=0)
    optimizer_KV.zero_grad()

    # if upsampling d in the refinement net #
    if refine_dup:
        dup4_candi = np.linspace(0, d_candi.max(), 4 * len(d_candi))

    # KV-net forward pass #
    # model_KV supports multiple GPUs #
    BatchIdx_range = torch.FloatTensor(np.arange(nGPU))
    IntMs = torch.cat([cam_intrin['intrinsic_M_cuda'].unsqueeze(0)
                       for cam_intrin in Cam_Intrinsics], dim=0)
    unit_ray_Ms_2D = torch.cat([cam_intrin['unit_ray_array_2D'].unsqueeze(0)
                                for cam_intrin in Cam_Intrinsics], dim=0)
    bsize = src_frames.shape[0]

    dmap_cur_refined, dmap_refined, d_dpv, kv_dpv = model_KV(
        ref_frame=ref_frame.cuda(0),
        src_frames=src_frames.cuda(0),
        src_cam_poses=Src_CamPoses.cuda(0),
        BatchIdx=BatchIdx_range.cuda(0),
        cam_intrinsics=None,
        BV_predict=BVs_predict,
        mGPU=mGPU,
        IntMs=IntMs.cuda(0),
        unit_ray_Ms_2D=unit_ray_Ms_2D.cuda(0))

    # Get losses #
    loss = 0.
    for ibatch in range(d_dpv.shape[0]):
        if loss_type == 'NLL':
            # NLL loss (D-net) #
            depth_ref = Ref_Dats[ibatch]['dmap'].cuda(kv_dpv.get_device())
            if refine_dup:
                depth_ref_imgsize = Ref_Dats[ibatch]['dmap_up4_imgsize_digit'].cuda(kv_dpv.get_device())
            else:
                depth_ref_imgsize = Ref_Dats[ibatch]['dmap_imgsize_digit'].cuda(kv_dpv.get_device())

            loss = loss + F.nll_loss(d_dpv[ibatch, :, :, :].unsqueeze(0),
                                     depth_ref, ignore_index=0)
            loss = loss + F.nll_loss(dmap_cur_refined[ibatch, :, :, :].unsqueeze(0),
                                     depth_ref_imgsize, ignore_index=0)

            if BVs_predict is not None:
                if m_misc.valid_dpv(BVs_predict[ibatch, ...]):
                    # refined #
                    loss = loss + F.nll_loss(kv_dpv[ibatch, :, :, :].unsqueeze(0),
                                             depth_ref, ignore_index=0)
                    loss = loss + F.nll_loss(dmap_refined[ibatch, :, :, :].unsqueeze(0),
                                             depth_ref_imgsize, ignore_index=0)

            dmap_kv_lowres = m_misc.depth_val_regression(kv_dpv[0, ...].unsqueeze(0),
                                                         d_candi, BV_log=True)

        elif loss_type == 'L1':
            if mGPU:
                raise Exception('not implemented for multiple GPUs')

            # L1 loss #
            depth_ref = Ref_Dats[ibatch]['dmap_imgsize'].cuda().unsqueeze(0)
            l1_loss_mask = depth_ref > 0.
            l1_loss_mask = l1_loss_mask.type_as(depth_ref)
            loss_BV_cur_L1 = F.l1_loss(dmap_cur_refined * l1_loss_mask,
                                       depth_ref.cuda().squeeze(1) * l1_loss_mask)
            if m_misc.valid_dpv(BVs_predict[ibatch, ...]):
                loss_KV_L1 = F.l1_loss(dmap_refined * l1_loss_mask,
                                       depth_ref.cuda().squeeze(1) * l1_loss_mask)

            # variance #
            dmap_d_lowres = m_misc.depth_val_regression(d_dpv, d_candi, BV_log=True)
            loss_BV_cur_var = torch.mean(m_misc.depth_var(d_dpv, dmap_d_lowres, d_candi))
            if m_misc.valid_dpv(BVs_predict[ibatch, ...]):
                dmap_kv_lowres = m_misc.depth_val_regression(kv_dpv, d_candi, BV_log=True)
                loss_KV_var = torch.mean(m_misc.depth_var(kv_dpv, dmap_kv_lowres, d_candi))
                loss = loss_BV_cur_L1 + loss_KV_L1 + weight_var * (loss_KV_var + loss_BV_cur_var)
            else:
                loss = loss_BV_cur_L1 + weight_var * loss_BV_cur_var
                dmap_kv_lowres = dmap_d_lowres

    # Backward pass #
    if mGPU:
        loss = loss / torch.tensor(float(bsize)).cuda(loss.get_device())
    loss.backward()
    optimizer_KV.step()

    # BV_predict estimation (3D re-sampling) #
    d_dpv = d_dpv.detach()
    kv_dpv = kv_dpv.detach()
    # dmap_cur_refined is the integer sentinel -1 when no refinement net is used #
    r_dpv = dmap_cur_refined.detach() if isinstance(dmap_cur_refined, torch.Tensor) \
        else dmap_refined.detach()

    BVs_predict_out = []
    for ibatch in range(d_dpv.shape[0]):
        rel_Rt = Src_CamPoses[ibatch, t_win_r, :, :].inverse()
        BV_predict = warp_homo.resample_vol_cuda(
            src_vol=kv_dpv[ibatch, ...].unsqueeze(0),
            rel_extM=rel_Rt.cuda(kv_dpv.get_device()),
            cam_intrinsic=Cam_Intrinsics[ibatch],
            d_candi=d_candi,
            padding_value=math.log(1. / float(len(d_candi)))
        ).clamp(max=0, min=-1000.).unsqueeze(0)
        BVs_predict_out.append(BV_predict)
    BVs_predict_out = torch.cat(BVs_predict_out, dim=0)

    # logging (for single GPU) #
    depth_ref_lowres = Ref_Dats[0]['dmap_raw'].cpu().squeeze().numpy()
    depth_kv_lres_log = dmap_kv_lowres[0, ...].detach().cpu().squeeze().numpy()
    dmap_log_lres = np.hstack([depth_kv_lres_log, depth_ref_lowres])

    if dmap_refined.dim() < 4:
        # refined depth map #
        depth_kv_hres_log = dmap_refined.detach().cpu().squeeze().numpy()
        depth_ref_highres = depth_ref.detach().cpu().squeeze().numpy()
    else:
        # refined DPV #
        if refine_dup:
            depth_kv_hres_log = m_misc.depth_val_regression(
                dmap_refined[0, ...].unsqueeze(0), dup4_candi,
                BV_log=True).detach().cpu().squeeze().numpy()
        else:
            depth_kv_hres_log = m_misc.depth_val_regression(
                dmap_refined[0, ...].unsqueeze(0), d_candi,
                BV_log=True).detach().cpu().squeeze().numpy()

    depth_ref_imgsize_raw = Ref_Dats[0]['dmap_imgsize'].squeeze().cpu().numpy()
    dmap_log_hres = np.hstack([depth_kv_hres_log, depth_ref_imgsize_raw])

    if return_confmap_up:
        confmap_up = torch.exp(dmap_refined[0, ...].detach())
        confmap_up, _ = torch.max(confmap_up, dim=0)
        return r_dpv, BVs_predict_out, loss, dmap_log_lres, dmap_log_hres, confmap_up.cpu()
    else:
        return r_dpv, BVs_predict_out, loss, dmap_log_lres, dmap_log_hres
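
# A minimal, hypothetical sketch of how train() is meant to be driven from an outer
# loop: the BVs_predict volume returned by one iteration (already re-sampled into the
# next reference view) is fed back in on the next call. The iterable `batched_trajs`
# below is a placeholder for the repo's actual dataloading, not part of this file.
def _training_loop_sketch(nGPU, model_KV, optimizer_KV, t_win_r, d_candi, batched_trajs):
    BVs_predict = None
    for ref_dats, src_dats, src_cam_poses, cam_intrinsics in batched_trajs:
        _, BVs_predict, loss, dmap_log_lres, dmap_log_hres = train(
            nGPU, model_KV, optimizer_KV, t_win_r, d_candi,
            Ref_Dats=ref_dats, Src_Dats=src_dats, Src_CamPoses=src_cam_poses,
            BVs_predict=BVs_predict, Cam_Intrinsics=cam_intrinsics,
            loss_type='NLL', mGPU=(nGPU > 1))
        print('loss = %.4f' % loss.item())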
def export_res_refineNet(ref_dat, BV_measure, d_candi, res_fldr, batch_idx,
                         diff_vrange_ratio=4, cam_pose=None, cam_intrinM=None,
                         output_pngs=False, save_mat=True, output_dmap_ref=True):
    '''
    Export results
    '''
    # img_in #
    img_up = ref_dat['img']
    img_in_raw = img_up.squeeze().cpu().permute(1, 2, 0).numpy()
    img_in = (img_in_raw - img_in_raw.min()) / (img_in_raw.max() - img_in_raw.min()) * 255.

    # confMap #
    confMap_log, _ = torch.max(BV_measure, dim=1)
    confMap_log = torch.exp(confMap_log.squeeze().cpu())
    confMap_log = confMap_log.cpu().numpy()

    # depth map #
    nDepth = len(d_candi)
    dmap_height, dmap_width = BV_measure.shape[2], BV_measure.shape[3]
    dmap = m_misc.depth_val_regression(BV_measure, d_candi, BV_log=True).squeeze().cpu().numpy()

    # save up-sampled results #
    resfldr = res_fldr
    m_misc.m_makedir(resfldr)

    img_up_path = '%s/input.png' % (resfldr,)
    conf_up_path = '%s/conf.png' % (resfldr,)
    dmap_raw_path = '%s/dmap_raw.png' % (resfldr,)
    final_res_up = '%s/res_%05d.png' % (resfldr, batch_idx)

    if output_dmap_ref:  # output GT depth
        ref_up = '%s/dmap_ref.png' % (resfldr,)
        res_up_diff = '%s/dmaps_diff.png' % (resfldr,)
        dmap_ref = ref_dat['dmap_imgsize']
        dmap_ref = dmap_ref.squeeze().cpu().numpy()
        mask_dmap = (dmap_ref > 0).astype(float)
        dmap_diff_raw = np.abs(dmap_ref - dmap) * mask_dmap
        dmaps_diff = dmap_diff_raw
        plt.imsave(res_up_diff, dmaps_diff, vmin=0, vmax=d_candi.max() / diff_vrange_ratio)
        plt.imsave(ref_up, dmap_ref, vmax=d_candi.max(), vmin=0, cmap='gray')

    plt.imsave(conf_up_path, confMap_log, vmin=0, vmax=1, cmap='jet')
    plt.imsave(dmap_raw_path, dmap, vmin=0., vmax=d_candi.max(), cmap='gray')
    plt.imsave(img_up_path, img_in.astype(np.uint8))

    # output the depth as .mat files #
    fname_mat = '%s/depth_%05d.mat' % (resfldr, batch_idx)
    img_path = ref_dat['img_path']
    if save_mat:
        if not output_dmap_ref:
            mdict = {'dmap': dmap, 'img': img_in_raw,
                     'confMap': confMap_log, 'img_path': img_path}
        elif cam_pose is None:
            mdict = {'dmap_ref': dmap_ref, 'dmap': dmap, 'img': img_in_raw,
                     'confMap': confMap_log, 'img_path': img_path}
        else:
            mdict = {'dmap_ref': dmap_ref, 'dmap': dmap, 'img': img_in_raw,
                     'cam_pose': cam_pose, 'confMap': confMap_log,
                     'cam_intrinM': cam_intrinM, 'img_path': img_path}
        sio.savemat(fname_mat, mdict)

    print('export to %s' % (final_res_up))
    if output_dmap_ref:
        cat_imgs((img_up_path, conf_up_path, dmap_raw_path, res_up_diff, ref_up), final_res_up)
    else:
        cat_imgs((img_up_path, conf_up_path, dmap_raw_path), final_res_up)

    if output_pngs:
        import cv2
        png_fldr = '%s/output_pngs' % (res_fldr,)
        m_misc.m_makedir(png_fldr)
        depth_png = (dmap * 1000).astype(np.uint16)
        img_in_png = _un_normalize(img_in_raw)
        img_in_png = (img_in_png * 255).astype(np.uint8)
        confMap_png = (confMap_log * 255).astype(np.uint8)
        cv2.imwrite('%s/d_%05d.png' % (png_fldr, batch_idx), depth_png)
        cv2.imwrite('%s/rgb_%05d.png' % (png_fldr, batch_idx), img_in_png)
        cv2.imwrite('%s/conf_%05d.png' % (png_fldr, batch_idx), confMap_png)
        if output_dmap_ref:
            depth_ref_png = (dmap_ref * 1000).astype(np.uint16)
            cv2.imwrite('%s/dref_%05d.png' % (png_fldr, batch_idx), depth_ref_png)
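
# cat_imgs() is used above but not defined in this file. A minimal sketch of what it
# presumably does (horizontally concatenating the saved .png files into one summary
# image), assuming PIL is available; the repo's actual helper may differ.
def _cat_imgs_sketch(img_paths, output_path):
    from PIL import Image
    imgs = [Image.open(p).convert('RGB') for p in img_paths]
    height = min(im.height for im in imgs)
    imgs = [im.resize((int(im.width * height / im.height), height)) for im in imgs]
    canvas = Image.new('RGB', (sum(im.width for im in imgs), height))
    x = 0
    for im in imgs:
        canvas.paste(im, (x, 0))
        x += im.width
    canvas.save(output_path)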
def main():
    import argparse
    print('Parsing the arguments...')
    parser = argparse.ArgumentParser()

    # exp name #
    parser.add_argument('--exp_name', required=True, type=str,
                        help='The name of the experiment. Used to name the folders')

    # about testing #
    parser.add_argument('--img_name_pattern', type=str, default='*.png',
                        help='image name pattern')
    parser.add_argument('--model_path', type=str, default='.',
                        help='The pre-trained model path for KV-net')
    parser.add_argument('--split_file', type=str, default='.',
                        help='The split txt file')
    parser.add_argument('--t_win', type=int, default=2,
                        help='The radius of the temporal window; default=2')
    parser.add_argument('--d_min', type=float, default=0,
                        help='The minimal depth value; default=0')
    parser.add_argument('--d_max', type=float, default=5,
                        help='The maximal depth value; default=5')
    parser.add_argument('--ndepth', type=int, default=64,
                        help='The # of candidate depth values; default=64')
    parser.add_argument('--sigma_soft_max', type=float, default=10.,
                        help='sigma_soft_max; default=10.')
    parser.add_argument('--feature_dim', type=int, default=64,
                        help='The feature dimension for the feature extractor; default=64')

    # about pose #
    parser.add_argument('--intrin_path', type=str, required=True,
                        help='camera intrinsic path, saved as .mat')
    parser.add_argument('--dso_res_path', type=str, default='dso_res/result_dso.txt',
                        help='if using DSO pose, specify the path to the DSO results. '
                             'Should be a .txt file')
    parser.add_argument('--opt_next_frame', action='store_true', help='')
    parser.add_argument('--use_gt_R', action='store_true', help='')
    parser.add_argument('--use_gt_t', action='store_true', help='')
    parser.add_argument('--use_dso_R', action='store_true', help='')
    parser.add_argument('--use_dso_t', action='store_true', help='')
    parser.add_argument('--min_frame_idx', type=int, help=' ', default=0)
    parser.add_argument('--max_frame_idx', type=int, help=' ', default=10000)
    parser.add_argument('--refresh_frames', type=int, help=' ', default=1000)
    parser.add_argument('--LBA_max_iter', type=int, help=' ')
    parser.add_argument('--opt_r', type=int, default=1, help=' ')
    parser.add_argument('--opt_t', type=int, default=1, help=' ')
    parser.add_argument('--LBA_step', type=float, help=' ')
    parser.add_argument('--frame_interv', type=int, default=5, help=' ')

    # about dataset #
    parser.add_argument('--dataset', type=str, default='7scenes',
                        help='Dataset name: {scanNet, 7scenes, single_folder}')
    parser.add_argument('--dataset_path', type=str, default='.',
                        help='Path to the dataset')

    # about output #
    parser.add_argument('--output_pngs', action='store_true', help='if output pngs')

    # para config. #
    args = parser.parse_args()
    exp_name = args.exp_name
    dataset_name = args.dataset
    t_win_r = args.t_win
    nDepth = args.ndepth
    d_candi = np.linspace(args.d_min, args.d_max, nDepth)
    sigma_soft_max = args.sigma_soft_max
    dnet_feature_dim = args.feature_dim
    frame_interv = args.frame_interv
    d_candi_dmap_ref = d_candi
    nDepth_dmap_ref = nDepth

    # Initialize data-loader, model and optimizer #

    # ===== Dataset selection ======== #
    dataset_path = args.dataset_path
    if dataset_name == 'scanNet':
        # deal with 1-frame scanNet data #
        import mdataloader.scanNet as dl_scanNet
        dataset_init = dl_scanNet.ScanNet_dataset
        split_txt = './mdataloader/scanNet_split/scannet_val.txt' \
            if args.split_file == '.' else args.split_file
        if not dataset_path == '.':
            # if the path is specified, we assume we are using 1-frame-interval scanNet video #
            fun_get_paths = lambda traj_indx: dl_scanNet.get_paths_1frame(
                traj_indx, database_path_base=dataset_path, split_txt=split_txt)
            dat_indx_step = 5  # pick this value to make sure the camera baseline is big enough
        else:
            fun_get_paths = lambda traj_indx: dl_scanNet.get_paths(
                traj_indx, frame_interv=5, split_txt=split_txt)
            dat_indx_step = 1
        img_size = [384, 256]
        # trajectory index for training #
        n_scenes, _, _, _, _ = fun_get_paths(0)
        traj_Indx = np.arange(0, n_scenes)

    elif dataset_name == '7scenes':
        # 7 scenes video #
        import mdataloader.dl_7scenes as dl_7scenes
        img_size = [384, 256]
        dataset_init = dl_7scenes.SevenScenesDataset
        dat_indx_step = 5  # pick this value to make sure the camera baseline is big enough
        # trajectory index for training #
        split_file = None if args.split_file == '.' else args.split_file
        fun_get_paths = lambda traj_indx: dl_7scenes.get_paths_1frame(
            traj_indx, database_path_base=dataset_path, split_txt=split_file,)

    elif dataset_name == 'single_folder':
        # images in a single folder specified by the user #
        import mdataloader.mdata as mdata
        img_size = [384, 256]
        dataset_init = mdata.mData
        dat_indx_step = 5  # pick this value to make sure the camera baseline is big enough
        fun_get_paths = lambda traj_indx: mdata.get_paths_1frame(
            traj_indx, dataset_path, args.img_name_pattern)
        traj_Indx = [0]  # dummy

    fldr_path, img_paths, dmap_paths, poses, intrin_path = fun_get_paths(traj_Indx[0])

    if dataset_name == 'single_folder':
        intrin_path = args.intrin_path

    dataset = dataset_init(True, img_paths, dmap_paths, poses,
                           intrin_path=intrin_path, img_size=img_size, digitize=True,
                           d_candi=d_candi_dmap_ref, resize_dmap=.25,)
    dataset_Himgsize = dataset_init(True, img_paths, dmap_paths, poses,
                                    intrin_path=intrin_path, img_size=img_size, digitize=True,
                                    d_candi=d_candi_dmap_ref, resize_dmap=.5,)
    dataset_imgsize = dataset_init(True, img_paths, dmap_paths, poses,
                                   intrin_path=intrin_path, img_size=img_size, digitize=True,
                                   d_candi=d_candi_dmap_ref, resize_dmap=1,)
    # ================================ #

    print('Initializing the KV-Net')
    model_KVnet = m_kvnet.KVNET(
        feature_dim=dnet_feature_dim,
        cam_intrinsics=dataset.cam_intrinsics,
        d_candi=d_candi,
        sigma_soft_max=sigma_soft_max,
        KVNet_feature_dim=dnet_feature_dim,
        d_upsample_ratio_KV_net=None,
        t_win_r=t_win_r,
        if_refined=True)

    model_KVnet = torch.nn.DataParallel(model_KVnet)
    model_KVnet.cuda()

    model_path_KV = args.model_path
    print('loading KV_net at %s' % (model_path_KV))
    utils_model.load_pretrained_model(model_KVnet, model_path_KV)
    print('Done')

    for traj_idx in traj_Indx:
        scene_path_info = []
        print('Getting the paths for traj_%d' % (traj_idx))
        fldr_path, img_seq_paths, dmap_seq_paths, poses, intrin_path = fun_get_paths(traj_idx)
        res_fldr = '../results/%s/traj_%d' % (exp_name, traj_idx)
        m_misc.m_makedir(res_fldr)
        dataset.set_paths(img_seq_paths, dmap_seq_paths, poses)

        if dataset_name == 'scanNet':
            # the camera intrinsics may be slightly different for different trajectories in scanNet #
            dataset.get_cam_intrinsics(intrin_path)

        print('Done')
        if args.min_frame_idx > 0:
            frame_idxs = np.arange(args.min_frame_idx - t_win_r, args.min_frame_idx + t_win_r)
            dat_array = [dataset[idx] for idx in frame_idxs]
        else:
            dat_array = [dataset[idx] for idx in range(t_win_r * 2 + 1)]

        DMaps_meas = []
        dso_res_path = args.dso_res_path
        print('init initial pose from DSO estimations ...')
        traj_extMs = init_traj_extMs(traj_len=len(dataset), dso_res_path=dso_res_path,
                                     if_filter=True, min_idx=args.min_frame_idx,
                                     max_idx=args.max_frame_idx)
        traj_extMs_init = copy_list(traj_extMs)
        traj_length = min(len(dataset), len(traj_extMs))
        first_frame = True

        for frame_cnt, ref_indx in enumerate(
                range(t_win_r * dat_indx_step + args.min_frame_idx,
                      traj_length - t_win_r * dat_indx_step - dat_indx_step)):
            # ref_indx: the frame index for the reference frame #

            # Read ref. and src. data in the local time window #
            ref_dat, src_dats = m_misc.split_frame_list(dat_array, t_win_r)
            src_frame_idx = [idx for idx in range(
                ref_indx - t_win_r * dat_indx_step, ref_indx, dat_indx_step)] + \
                [idx for idx in range(
                    ref_indx + dat_indx_step, ref_indx + t_win_r * dat_indx_step + 1, dat_indx_step)]

            valid_seq = dso_io.valid_poses(traj_extMs, src_frame_idx)

            # only look at a subset of frames #
            if ref_indx < args.min_frame_idx:
                valid_seq = False
            if ref_indx > args.max_frame_idx or \
                    ref_indx >= traj_length - t_win_r * dat_indx_step - dat_indx_step:
                break
            if frame_cnt == 0 or valid_seq is False:
                BVs_predict = None

            # refresh #
            if ref_indx % args.refresh_frames == 0:
                print('REFRESH !')
                BVs_predict = None
                BVs_predict_in = None
                first_frame = True
                traj_extMs = copy_list(traj_extMs_init)

            if valid_seq:  # if the sequence does not contain invalid pose estimations
                # Get poses #
                src_cam_extMs = [traj_extMs[i] for i in src_frame_idx]
                ref_cam_extM = traj_extMs[ref_indx]
                src_cam_poses = [warp_homo.get_rel_extrinsicM(ref_cam_extM, src_cam_extM_)
                                 for src_cam_extM_ in src_cam_extMs]
                src_cam_poses = [torch.from_numpy(pose.astype(np.float32)).cuda().unsqueeze(0)
                                 for pose in src_cam_poses]

                # Load the GT pose if available #
                if 'extM' in dataset[0]:
                    src_cam_extMs_ref = [dataset[i]['extM'] for i in src_frame_idx]
                    ref_cam_extM_ref = dataset[ref_indx]['extM']
                    src_cam_poses_ref = [
                        warp_homo.get_rel_extrinsicM(ref_cam_extM_ref, src_cam_extM_)
                        for src_cam_extM_ in src_cam_extMs_ref]
                    src_cam_poses_ref = [
                        torch.from_numpy(pose.astype(np.float32)).cuda().unsqueeze(0)
                        for pose in src_cam_poses_ref]

                # -- Determine the scale, mapping from DSO scale to our working scale -- #
                if frame_cnt == 0 or BVs_predict is None:  # the first window for the traj.
                    _, t_norm_single = get_fb(src_cam_poses, dataset.cam_intrinsics,
                                              src_cam_pose_next=None)
                    # We need to heuristically determine scale_ without using the GT pose;
                    # only the last of the three candidate scales below takes effect #
                    t_norms = get_t_norms(traj_extMs, dat_indx_step)
                    scale_ = d_candi.max() / (dataset.cam_intrinsics['focal_length']
                                              * np.array(t_norm_single).max() / 2)
                    scale_ = d_candi.max() / (dataset.cam_intrinsics['focal_length']
                                              * np.array(t_norms).max())
                    scale_ = d_candi.max() / (dataset.cam_intrinsics['focal_length']
                                              * np.array(t_norms).mean() / 2)
                    rescale_traj_t(traj_extMs, scale_)
                    traj_extMs_dso = copy_list(traj_extMs)

                    # Get poses #
                    src_cam_extMs = [traj_extMs[i] for i in src_frame_idx]
                    ref_cam_extM = traj_extMs[ref_indx]
                    src_cam_poses = [warp_homo.get_rel_extrinsicM(ref_cam_extM, src_cam_extM_)
                                     for src_cam_extM_ in src_cam_extMs]
                    src_cam_poses = [torch.from_numpy(pose.astype(np.float32)).cuda().unsqueeze(0)
                                     for pose in src_cam_poses]

                # src_cam_poses size: N V 4 4 #
                src_cam_poses = torch.cat(src_cam_poses, dim=0).unsqueeze(0)
                src_frames = [m_misc.get_entries_list_dict(src_dats, 'img')]

                cam_pose_next = traj_extMs[ref_indx + 1]
                cam_pose_next = torch.FloatTensor(
                    warp_homo.get_rel_extrinsicM(traj_extMs[ref_indx], cam_pose_next)).cuda()

                BVs_predict_in = None if frame_cnt == 0 or BVs_predict is None else BVs_predict

                BVs_measure, BVs_predict = test_KVNet.test(
                    model_KVnet, d_candi,
                    Ref_Dats=[ref_dat],
                    Src_Dats=[src_dats],
                    Cam_Intrinsics=[dataset.cam_intrinsics],
                    t_win_r=t_win_r,
                    Src_CamPoses=src_cam_poses,
                    BV_predict=BVs_predict_in,
                    R_net=True,
                    cam_pose_next=cam_pose_next,
                    ref_indx=ref_indx)

                # export_res.export_res_refineNet(ref_dat, BVs_measure, d_candi_dmap_ref,
                #                                 res_fldr, ref_indx, save_mat=True,
                #                                 output_pngs=args.output_pngs,
                #                                 output_dmap_ref=False)
                export_res.export_res_img(ref_dat, BVs_measure, d_candi_dmap_ref,
                                          res_fldr, frame_cnt)
                scene_path_info.append([frame_cnt, dataset[ref_indx]['img_path']])

                # UPDATE dat_array #
                if dat_indx_step > 1:
                    # use one-interval video and the frame interval is larger than 5 #
                    print('updating array ...')
                    dat_array = update_dat_array(dat_array, dataset, data_interv=1,
                                                 frame_interv=5, ref_indx=ref_indx,
                                                 t_win_r=t_win_r)
                    print('done')
                else:
                    dat_array.pop(0)
                    new_dat = dataset[ref_indx + t_win_r + 1]
                    dat_array.append(new_dat)

                # OPTIMIZE POSES #
                idx_ref_ = ref_indx + 1
                cam_pose_nextframe = traj_extMs[idx_ref_]
                cam_pose_nextframe = torch.FloatTensor(
                    warp_homo.get_rel_extrinsicM(traj_extMs[ref_indx], cam_pose_nextframe)).cuda()

                # get depth and confidence map #
                BV_tmp_ = warp_homo.resample_vol_cuda(
                    src_vol=BVs_measure,
                    rel_extM=cam_pose_nextframe.inverse(),
                    cam_intrinsic=dataset_imgsize.cam_intrinsics,
                    d_candi=d_candi, d_candi_new=d_candi,
                    padding_value=math.log(1. / float(len(d_candi)))
                ).clamp(max=0, min=-1000.)
                dmap_ref = m_misc.depth_val_regression(BVs_measure, d_candi, BV_log=True).squeeze()
                conf_map_ref, _ = torch.max(BVs_measure.squeeze(), dim=0)
                dmap_kf = m_misc.depth_val_regression(BV_tmp_.unsqueeze(0), d_candi,
                                                      BV_log=True).squeeze()
                conf_map_kf, _ = torch.max(BV_tmp_.squeeze(), dim=0)

                # setup optimization #
                cams_intrin = [dataset.cam_intrinsics,
                               dataset_Himgsize.cam_intrinsics,
                               dataset_imgsize.cam_intrinsics]
                dw_scales = [4, 2, 1]
                LBA_max_iter = args.LBA_max_iter
                LBA_step = args.LBA_step
                if LBA_max_iter <= 1:  # do not do optimization
                    LBA_step = 0.
                opt_vars = [args.opt_r, args.opt_t]

                # initialization for the first time window #
                if first_frame:
                    first_frame = False
                    # optimize the pose for all frames within the window #
                    if LBA_max_iter <= 1:
                        # for debugging: using GT pose initialization #
                        rel_pose_inits_all_frame, srcs_idx_all_frame = m_misc.get_twin_rel_pose(
                            traj_extMs, idx_ref_, t_win_r * dat_indx_step, 1,
                            use_gt_R=True, use_gt_t=True, dataset=dataset,
                            add_noise_gt=False, noise_sigmas=None)
                    else:
                        rel_pose_inits_all_frame, srcs_idx_all_frame = m_misc.get_twin_rel_pose(
                            traj_extMs, ref_indx, t_win_r * dat_indx_step, 1,
                            use_gt_R=False, use_gt_t=False, dataset=dataset,)

                    # opt. #
                    img_ref = dataset[ref_indx]['img']
                    imgs_src = [dataset[i]['img'] for i in srcs_idx_all_frame]
                    conf_map_ref = torch.exp(conf_map_ref).squeeze() ** 2
                    rel_pose_opt = opt_pose_numerical.local_BA_direct(
                        img_ref, imgs_src,
                        dmap_ref.unsqueeze(0).unsqueeze(0),
                        conf_map_ref.unsqueeze(0).unsqueeze(0),
                        cams_intrin, dw_scales, rel_pose_inits_all_frame,
                        max_iter=LBA_max_iter, step=LBA_step, opt_vars=opt_vars)

                    # update #
                    for idx, srcidx in enumerate(srcs_idx_all_frame):
                        traj_extMs[srcidx] = np.matmul(rel_pose_opt[idx].cpu().numpy(),
                                                       traj_extMs[ref_indx])

                # for next frame #
                if LBA_max_iter <= 1:
                    # for debugging: using GT pose init. #
                    rel_pose_opt, srcs_idx = m_misc.get_twin_rel_pose(
                        traj_extMs, idx_ref_, t_win_r, dat_indx_step,
                        use_gt_R=True, use_gt_t=True, dataset=dataset,
                        add_noise_gt=False, noise_sigmas=None,)
                else:
                    rel_pose_inits, srcs_idx = m_misc.get_twin_rel_pose(
                        traj_extMs, idx_ref_, t_win_r, dat_indx_step,
                        use_gt_R=args.use_gt_R, use_dso_R=args.use_dso_R,
                        use_gt_t=args.use_gt_t, use_dso_t=args.use_dso_t,
                        dataset=dataset, traj_extMs_dso=traj_extMs_dso,
                        opt_next_frame=args.opt_next_frame)
                    img_ref = dataset[idx_ref_]['img']
                    _, src_dats_opt = m_misc.split_frame_list(dat_array, t_win_r)
                    imgs_src = [dat_['img'] for dat_ in src_dats_opt]
                    img_ref = dataset[idx_ref_]['img']
                    imgs_src = [dataset[i] for i in srcs_idx]
                    imgs_src = [img_['img'] for img_ in imgs_src]

                    # opt. #
                    conf_map_kf = torch.exp(conf_map_kf).squeeze() ** 2
                    rel_pose_opt = opt_pose_numerical.local_BA_direct_parallel(
                        img_ref, imgs_src,
                        dmap_kf.unsqueeze(0).unsqueeze(0),
                        conf_map_kf.unsqueeze(0).unsqueeze(0),
                        cams_intrin, dw_scales, rel_pose_inits,
                        max_iter=LBA_max_iter, step=LBA_step, opt_vars=opt_vars)

                # update #
                print('idx_ref_: %d' % (idx_ref_))
                print('srcs_idx : ')
                print(srcs_idx)
                print('updating pose ...')
                for idx, srcidx in enumerate(srcs_idx):
                    traj_extMs[srcidx] = np.matmul(rel_pose_opt[idx].cpu().numpy(),
                                                   traj_extMs[idx_ref_])
                print('done')

            else:  # if the sequence contains invalid pose estimations
                BVs_predict = None
                print('frame_cnt :%d, include invalid poses' % (frame_cnt))

                # UPDATE dat_array #
                if dat_indx_step > 1:
                    # use one-interval video and the frame interval is larger than 5 #
                    print('updating array ...')
                    dat_array = update_dat_array(dat_array, dataset, data_interv=1,
                                                 frame_interv=5, ref_indx=ref_indx,
                                                 t_win_r=t_win_r)
                    print('done')
                else:
                    dat_array.pop(0)
                    new_dat = dataset[ref_indx + t_win_r + 1]
                    dat_array.append(new_dat)

        m_misc.save_ScenePathInfo('%s/scene_path_info.txt' % (res_fldr), scene_path_info)
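
# Example invocation of main(), illustrative only: the script name and all paths are
# placeholders; the flags and values come from the argparse definitions above
# (--LBA_max_iter and --LBA_step have no defaults, so they must be given explicitly).
#
#   python test_KVNet_release.py \
#       --exp_name demo --dataset 7scenes --dataset_path /path/to/7scenes \
#       --model_path /path/to/kvnet_checkpoint.tar --intrin_path /path/to/intrinsics.mat \
#       --t_win 2 --d_min 0 --d_max 5 --ndepth 64 --sigma_soft_max 10 \
#       --LBA_max_iter 10 --LBA_step .05 --output_pngs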
def forward(self, ref_frame, src_frames, src_cam_poses, BatchIdx,
            cam_intrinsics=None, BV_predict=None, mGPU=False,
            IntMs=None, unit_ray_Ms_2D=None):
    r'''
    Inputs:
        ref_frame - NCHW format tensor on GPU, N = 1
        src_frames - NVCHW: V - # of source views, N = 1
        src_cam_poses - N x V x 4 x 4 - relative cam poses, N = 1
        BatchIdx - e.g. for 4 GPUs: [0, 1, 2, 3]; used for indexing the list inputs
                   for multi-GPU training
        cam_intrinsics - list of cam_intrinsics dicts
        BV_predict - NDHW tensor, the predicted BV from the last reference frame, N = 1

    Outputs:
        dmap_cur_refined, dmap_kv_refined, BV_cur, BV_KV
        If refinement operates on the DPV, then dmap_cur_refined and dmap_kv_refined
        are refined DPVs.

    NOTE:
        1. ref_frame, src_frames and src_cam_poses should be put on the GPU before
           running the forward pass.
        2. The purpose of enforcing N = 1 is for multi-GPU running.
    '''
    if isinstance(BV_predict, torch.Tensor):
        if m_misc.valid_dpv(BV_predict):
            assert BV_predict.shape[0] == 1

    # D-Net #
    if (self.if_refined is False) or (self.if_refined is True and self.refineNet_name != 'DPV'):
        BV_cur = self.d_net(ref_frame, src_frames, src_cam_poses,
                            BV_predict=None, debug_ipdb=False)
    else:
        BV_cur, d_net_features = self.d_net(ref_frame, src_frames, src_cam_poses,
                                            BV_predict=None, debug_ipdb=False)
        d_net_features.append(ref_frame)

    if self.if_refined:
        dmap_cur_lowres = m_misc.depth_val_regression(BV_cur, self.d_candi,
                                                      BV_log=True).unsqueeze(0)
        if self.refineNet_name == 'DGF':
            dmap_cur_refined = self.r_net(dmap_cur_lowres, ref_frame)
        elif self.refineNet_name == 'DPV':
            dmap_cur_refined = self.r_net(torch.exp(BV_cur), img_features=d_net_features)
    else:
        dmap_cur_refined = -1

    if not isinstance(BV_predict, torch.Tensor):
        # If this is the first time window, then return only BV_cur #
        return dmap_cur_refined, dmap_cur_refined, BV_cur, BV_cur
    elif not m_misc.valid_dpv(BV_predict):
        return dmap_cur_refined, dmap_cur_refined, BV_cur, BV_cur
    else:
        # KV-Net #
        down_sample_rate = ref_frame.shape[3] / BV_cur.shape[3]
        ref_frame_dw = F.avg_pool2d(ref_frame, int(down_sample_rate)).cuda()
        src_frames_dw = [
            F.avg_pool2d(src_frame_.unsqueeze(0), int(down_sample_rate)).cuda()
            for src_frame_ in src_frames.squeeze(0)]
        Rs_src = [pose[:3, :3] for pose in src_cam_poses.squeeze(0)]
        ts_src = [pose[:3, 3] for pose in src_cam_poses.squeeze(0)]

        # Warp the src. frames to the ref. view #
        if mGPU:
            WARPED_src_frames = warp_homo.warp_img_feats_mgpu(
                src_frames_dw, self.d_candi, Rs_src, ts_src, IntMs, unit_ray_Ms_2D)
        else:
            cam_intrin = cam_intrinsics[int(BatchIdx)]
            WARPED_src_frames = warp_homo.warp_img_feats_v3(
                src_frames_dw, self.d_candi, Rs_src, ts_src, cam_intrin,)

        ref_frame_dw_rep = torch.transpose(
            ref_frame_dw.repeat([len(self.d_candi), 1, 1, 1]), 0, 1)

        # Input to the KV-net #
        kvnet_in_vol = torch.cat(
            (torch.cat(tuple(WARPED_src_frames), dim=0),
             ref_frame_dw_rep,
             BV_cur - BV_predict), dim=0).unsqueeze(0)

        # Run the KV-net #
        BV_gain = self.kv_net(kvnet_in_vol)

        # Add back to BV_predict #
        DPV = torch.squeeze(BV_gain, dim=1) + BV_predict
        DPV = F.log_softmax(DPV, dim=1)

        if self.if_refined:
            dmap_lowres = m_misc.depth_val_regression(DPV, self.d_candi,
                                                      BV_log=True).unsqueeze(0)
            if self.refineNet_name == 'DGF':
                dmap_refined = self.r_net(dmap_lowres, ref_frame)
            elif self.refineNet_name == 'DPV':
                dmap_refined = self.r_net(torch.exp(DPV), img_features=d_net_features)
        else:
            dmap_refined = -1

        return dmap_cur_refined, dmap_refined, BV_cur, DPV
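
# A hypothetical caller-side sketch of forward()'s output convention, based on the
# code above: when self.if_refined is False the first two outputs are the integer
# sentinel -1, and when BV_predict is absent or invalid the D-Net volume BV_cur is
# returned in place of the KV-Net volume. The helper below is illustrative only.
def _unpack_forward_outputs(outputs, d_candi):
    dmap_cur_refined, dmap_refined, BV_cur, BV_kv = outputs
    # regress a low-res depth map from the (log-probability) KV volume
    dmap_lowres = m_misc.depth_val_regression(BV_kv, d_candi, BV_log=True)
    if isinstance(dmap_refined, torch.Tensor):
        return dmap_lowres, dmap_refined   # refined output available
    return dmap_lowres, None               # refinement disabled (sentinel -1)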