def main(args):
    my_devices = torch.device('cuda:' + str(args.gpu_id))

    '''Create Folders'''
    exp_root_dir = Path(os.path.join('./logs/nerfmm', args.scene_name))
    exp_root_dir.mkdir(parents=True, exist_ok=True)
    experiment_dir = Path(os.path.join(exp_root_dir, gen_detail_name(args)))
    experiment_dir.mkdir(parents=True, exist_ok=True)
    shutil.copy('./models/nerf_models.py', experiment_dir)
    shutil.copy('./models/intrinsics.py', experiment_dir)
    shutil.copy('./models/poses.py', experiment_dir)
    shutil.copy('./tasks/nerfmm/train.py', experiment_dir)

    if args.store_pose_history:
        pose_history_dir = Path(os.path.join(experiment_dir, 'pose_history'))
        pose_history_dir.mkdir(parents=True, exist_ok=True)

    '''LOG'''
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    file_handler = logging.FileHandler(os.path.join(experiment_dir, 'log.txt'))
    file_handler.setLevel(logging.INFO)
    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(logging.WARNING)
    logger.addHandler(file_handler)
    logger.addHandler(stream_handler)
    logger.info(args)

    '''Summary Writer'''
    writer = SummaryWriter(log_dir=str(experiment_dir))

    '''Data Loading'''
    scene_train = DataLoaderWithCOLMAP(base_dir=args.base_dir,
                                       scene_name=args.scene_name,
                                       data_type='train',
                                       res_ratio=args.resize_ratio,
                                       num_img_to_load=args.train_img_num,
                                       skip=args.train_skip,
                                       use_ndc=args.use_ndc)

    # The COLMAP eval poses are not in the camera space we learned, so we can
    # only check NVS with a 4x4 identity pose.
    eval_c2ws = torch.eye(4).unsqueeze(0).float()  # (1, 4, 4)

    print('Train with {0:6d} images.'.format(scene_train.imgs.shape[0]))

    '''Model Loading'''
    pos_enc_in_dims = (2 * args.pos_enc_levels + int(args.pos_enc_inc_in)) * 3  # (2L + 0 or 1) * 3
    if args.use_dir_enc:
        dir_enc_in_dims = (2 * args.dir_enc_levels + int(args.dir_enc_inc_in)) * 3  # (2L + 0 or 1) * 3
    else:
        dir_enc_in_dims = 0

    model = OfficialNerf(pos_enc_in_dims, dir_enc_in_dims, args.hidden_dims)
    if args.multi_gpu:
        model = torch.nn.DataParallel(model).to(device=my_devices)
    else:
        model = model.to(device=my_devices)

    # learn focal parameter
    if args.start_refine_focal_epoch > -1:
        focal_net = LearnFocal(scene_train.H, scene_train.W, args.learn_focal,
                               args.fx_only, order=args.focal_order, init_focal=scene_train.focal)
    else:
        focal_net = LearnFocal(scene_train.H, scene_train.W, args.learn_focal,
                               args.fx_only, order=args.focal_order)
    if args.multi_gpu:
        focal_net = torch.nn.DataParallel(focal_net).to(device=my_devices)
    else:
        focal_net = focal_net.to(device=my_devices)

    # learn pose for each image
    if args.start_refine_pose_epoch > -1:
        pose_param_net = LearnPose(scene_train.N_imgs, args.learn_R, args.learn_t, scene_train.c2ws)
    else:
        pose_param_net = LearnPose(scene_train.N_imgs, args.learn_R, args.learn_t, None)
    if args.multi_gpu:
        pose_param_net = torch.nn.DataParallel(pose_param_net).to(device=my_devices)
    else:
        pose_param_net = pose_param_net.to(device=my_devices)

    '''Set Optimiser'''
    optimizer_nerf = torch.optim.Adam(model.parameters(), lr=args.nerf_lr)
    optimizer_focal = torch.optim.Adam(focal_net.parameters(), lr=args.focal_lr)
    optimizer_pose = torch.optim.Adam(pose_param_net.parameters(), lr=args.pose_lr)

    scheduler_nerf = torch.optim.lr_scheduler.MultiStepLR(optimizer_nerf,
                                                          milestones=args.nerf_milestones,
                                                          gamma=args.nerf_lr_gamma)
    scheduler_focal = torch.optim.lr_scheduler.MultiStepLR(optimizer_focal,
                                                           milestones=args.focal_milestones,
                                                           gamma=args.focal_lr_gamma)
    scheduler_pose = torch.optim.lr_scheduler.MultiStepLR(optimizer_pose,
                                                          milestones=args.pose_milestones,
                                                          gamma=args.pose_lr_gamma)
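    # A minimal illustration (hypothetical helper, not used anywhere in this
    # file) of the frequency encoding implied by pos_enc_in_dims above: each
    # 3-vector x is mapped to [sin(2^k * x), cos(2^k * x)] for k = 0..L-1,
    # optionally with the raw input prepended, giving (2L + 0 or 1) * 3 channels.
    def _pos_enc_sketch(x, levels, include_input):
        feats = [x] if include_input else []
        for k in range(levels):
            feats.append(torch.sin((2.0 ** k) * x))
            feats.append(torch.cos((2.0 ** k) * x))
        return torch.cat(feats, dim=-1)  # (..., (2 * levels + include_input) * 3)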
    '''Training'''
    for epoch_i in tqdm(range(args.epoch), desc='epochs'):
        rgb_act_fn = torch.sigmoid
        train_epoch_losses = train_one_epoch(scene_train, optimizer_nerf, optimizer_focal, optimizer_pose,
                                             model, focal_net, pose_param_net, my_devices, args,
                                             rgb_act_fn, epoch_i)
        train_L2_loss = train_epoch_losses['L2']
        scheduler_nerf.step()
        scheduler_focal.step()
        scheduler_pose.step()

        train_psnr = mse2psnr(train_L2_loss)
        writer.add_scalar('train/mse', train_L2_loss, epoch_i)
        writer.add_scalar('train/psnr', train_psnr, epoch_i)
        # get_last_lr() is the documented accessor; get_lr() is deprecated
        # outside of scheduler.step() and can report incorrectly scaled values.
        writer.add_scalar('train/lr', scheduler_nerf.get_last_lr()[0], epoch_i)
        logger.info('{0:6d} ep: Train: L2 loss: {1:.4f}, PSNR: {2:.3f}'.format(epoch_i, train_L2_loss, train_psnr))
        tqdm.write('{0:6d} ep: Train: L2 loss: {1:.4f}, PSNR: {2:.3f}'.format(epoch_i, train_L2_loss, train_psnr))

        pose_history_milestone = list(range(0, 100, 5)) + list(range(100, 1000, 100)) + list(range(1000, 10000, 1000))
        if epoch_i in pose_history_milestone:
            with torch.no_grad():
                if args.store_pose_history:
                    store_current_pose(pose_param_net, pose_history_dir, epoch_i)

        if epoch_i % args.eval_cam_interval == 0 and epoch_i > 0:
            with torch.no_grad():
                eval_stats_tran, eval_stats_rot, eval_stats_scale = eval_one_epoch_traj(scene_train, pose_param_net)
                writer.add_scalar('eval/traj/translation', eval_stats_tran['mean'], epoch_i)
                writer.add_scalar('eval/traj/rotation', eval_stats_rot['mean'], epoch_i)
                writer.add_scalar('eval/traj/scale', eval_stats_scale['mean'], epoch_i)
                logger.info('{0:6d} ep Traj Err: translation: {1:.6f}, rotation: {2:.2f} deg, scale: {3:.2f}'.format(
                    epoch_i, eval_stats_tran['mean'], eval_stats_rot['mean'], eval_stats_scale['mean']))
                tqdm.write('{0:6d} ep Traj Err: translation: {1:.6f}, rotation: {2:.2f} deg, scale: {3:.2f}'.format(
                    epoch_i, eval_stats_tran['mean'], eval_stats_rot['mean'], eval_stats_scale['mean']))

                fxfy = focal_net(0)
                tqdm.write('Est fx: {0:.2f}, fy {1:.2f}, COLMAP focal: {2:.2f}'.format(
                    fxfy[0].item(), fxfy[1].item(), scene_train.focal))
                logger.info('Est fx: {0:.2f}, fy {1:.2f}, COLMAP focal: {2:.2f}'.format(
                    fxfy[0].item(), fxfy[1].item(), scene_train.focal))
                if torch.is_tensor(fxfy):
                    L1_focal = torch.abs(fxfy - scene_train.focal).mean().item()
                else:
                    L1_focal = np.abs(fxfy - scene_train.focal).mean()
                writer.add_scalar('eval/L1_focal', L1_focal, epoch_i)

        if epoch_i % args.eval_img_interval == 0 and epoch_i > 0:
            with torch.no_grad():
                eval_one_epoch_img(eval_c2ws, scene_train, model, focal_net, pose_param_net, my_devices,
                                   args, epoch_i, writer, rgb_act_fn)

                # save the latest model.
                save_checkpoint(epoch_i, model, optimizer_nerf, experiment_dir, ckpt_name='latest_nerf')
                save_checkpoint(epoch_i, focal_net, optimizer_focal, experiment_dir, ckpt_name='latest_focal')
                save_checkpoint(epoch_i, pose_param_net, optimizer_pose, experiment_dir, ckpt_name='latest_pose')
    return
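# A minimal sketch of the mse2psnr conversion used in the loop above, assuming
# the repo's helper follows the standard definition PSNR = -10 * log10(MSE)
# for images scaled to [0, 1] (the actual implementation may differ, e.g. by
# operating on torch tensors):
def mse2psnr_sketch(mse):
    import math
    return -10.0 * math.log10(float(mse))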
def main(args):
    my_devices = torch.device('cpu')

    '''Create Folders'''
    pose_out_dir = Path(os.path.join(args.ckpt_dir, 'pose_out'))
    pose_out_dir.mkdir(parents=True, exist_ok=True)

    '''Get COLMAP poses'''
    scene_train = DataLoaderWithCOLMAP(base_dir=args.base_dir,
                                       scene_name=args.scene_name,
                                       data_type='train',
                                       res_ratio=args.resize_ratio,
                                       num_img_to_load=args.train_img_num,
                                       skip=args.train_skip,
                                       use_ndc=True,
                                       load_img=False)

    # scale colmap poses to unit sphere
    ts_colmap = scene_train.c2ws[:, :3, 3]  # (N, 3)
    scene_train.c2ws[:, :3, 3] /= pts_dist_max(ts_colmap)
    scene_train.c2ws[:, :3, 3] *= 2.0

    '''Load scene meta'''
    H, W = scene_train.H, scene_train.W
    colmap_focal = scene_train.focal

    print('Intrinsic: H: {0:4d}, W: {1:4d}, COLMAP focal {2:.2f}.'.format(H, W, colmap_focal))

    '''Model Loading'''
    if args.init_focal_colmap:
        focal_net = LearnFocal(H, W, args.learn_focal, args.fx_only, order=args.focal_order, init_focal=colmap_focal)
    else:
        focal_net = LearnFocal(H, W, args.learn_focal, args.fx_only, order=args.focal_order)
        # only load the learned focal if we do not init with the colmap focal
        focal_net = load_ckpt_to_net(os.path.join(args.ckpt_dir, 'latest_focal.pth'), focal_net,
                                     map_location=my_devices)
    fxfy = focal_net(0)
    print('COLMAP focal: {0:.2f}, learned fx: {1:.2f}, fy: {2:.2f}'.format(colmap_focal, fxfy[0].item(), fxfy[1].item()))

    if args.init_pose_colmap:
        pose_param_net = LearnPose(scene_train.N_imgs, args.learn_R, args.learn_t, scene_train.c2ws)
    else:
        pose_param_net = LearnPose(scene_train.N_imgs, args.learn_R, args.learn_t, None)
    pose_param_net = load_ckpt_to_net(os.path.join(args.ckpt_dir, 'latest_pose.pth'), pose_param_net,
                                      map_location=my_devices)

    '''Get all poses in (N, 4, 4)'''
    c2ws_est = torch.stack([pose_param_net(i) for i in range(scene_train.N_imgs)])  # (N, 4, 4)
    c2ws_cmp = scene_train.c2ws  # (N, 4, 4)

    # scale estimated poses to unit sphere
    ts_est = c2ws_est[:, :3, 3]  # (N, 3)
    c2ws_est[:, :3, 3] /= pts_dist_max(ts_est)
    c2ws_est[:, :3, 3] *= 2.0

    '''Define camera frustums'''
    frustum_length = 0.1
    est_traj_color = np.array([39, 125, 161], dtype=np.float32) / 255
    cmp_traj_color = np.array([249, 65, 68], dtype=np.float32) / 255

    '''Align est traj to colmap traj'''
    c2ws_est_to_draw_align2cmp = c2ws_est.clone()
    if args.ATE_align:  # Align learned poses to colmap poses
        c2ws_est_aligned = align_ate_c2b_use_a2b(c2ws_est, c2ws_cmp)  # (N, 4, 4)
        c2ws_est_to_draw_align2cmp = c2ws_est_aligned

        # compute ate
        stats_tran_est, stats_rot_est, _ = compute_ate(c2ws_est_aligned, c2ws_cmp, align_a2b=None)
        print('From est to colmap: tran err {0:.3f}, rot err {1:.2f}'.format(stats_tran_est['mean'],
                                                                             stats_rot_est['mean']))

    frustum_est_list = draw_camera_frustum_geometry(c2ws_est_to_draw_align2cmp.cpu().numpy(), H, W,
                                                    fxfy[0], fxfy[1], frustum_length, est_traj_color)
    frustum_colmap_list = draw_camera_frustum_geometry(c2ws_cmp.cpu().numpy(), H, W,
                                                       colmap_focal, colmap_focal, frustum_length, cmp_traj_color)

    geometry_to_draw = []
    geometry_to_draw.append(frustum_est_list)
    geometry_to_draw.append(frustum_colmap_list)

    '''o3d for line drawing'''
    t_est_list = c2ws_est_to_draw_align2cmp[:, :3, 3]
    t_cmp_list = c2ws_cmp[:, :3, 3]

    '''line set to note pose correspondence between two trajs'''
    line_points = torch.cat([t_est_list, t_cmp_list], dim=0).cpu().numpy()  # (2N, 3)
    line_ends = [[i, i + scene_train.N_imgs] for i in range(scene_train.N_imgs)]  # (N, 2) connect two end points.
    # line_color = np.zeros((scene_train.N_imgs, 3), dtype=np.float32)
    # line_color[:, 0] = 1.0

    line_set = o3d.geometry.LineSet()
    line_set.points = o3d.utility.Vector3dVector(line_points)
    line_set.lines = o3d.utility.Vector2iVector(line_ends)
    # line_set.colors = o3d.utility.Vector3dVector(line_color)

    geometry_to_draw.append(line_set)

    o3d.visualization.draw_geometries(geometry_to_draw)
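# For orientation, a minimal sketch of how a single camera frustum wireframe
# can be assembled as an open3d LineSet, similar in spirit to the
# draw_camera_frustum_geometry helper called above. Everything here (name,
# signature, the z-forward camera convention) is an assumption for
# illustration, not the repo's actual implementation.
def frustum_line_set_sketch(c2w, H, W, fx, fy, frustum_length, color):
    """c2w: (4, 4) numpy camera-to-world matrix; returns an o3d.geometry.LineSet."""
    # Image-plane corners back-projected to depth frustum_length, plus the
    # camera centre at the origin.
    half_w = 0.5 * W / fx * frustum_length
    half_h = 0.5 * H / fy * frustum_length
    corners_cam = np.array([[0.0, 0.0, 0.0],
                            [-half_w, -half_h, frustum_length],
                            [half_w, -half_h, frustum_length],
                            [half_w, half_h, frustum_length],
                            [-half_w, half_h, frustum_length]])
    corners_world = (c2w[:3, :3] @ corners_cam.T).T + c2w[:3, 3]  # (5, 3)
    lines = [[0, 1], [0, 2], [0, 3], [0, 4],   # centre to each image corner
             [1, 2], [2, 3], [3, 4], [4, 1]]   # image-plane rectangle
    line_set = o3d.geometry.LineSet()
    line_set.points = o3d.utility.Vector3dVector(corners_world)
    line_set.lines = o3d.utility.Vector2iVector(lines)
    line_set.colors = o3d.utility.Vector3dVector(np.tile(color, (len(lines), 1)))
    return line_set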
def main(args):
    my_devices = torch.device('cuda:' + str(args.gpu_id))

    '''Create Folders'''
    test_dir = Path(os.path.join(args.ckpt_dir, 'render_spiral'))
    img_out_dir = Path(os.path.join(test_dir, 'img_out'))
    depth_out_dir = Path(os.path.join(test_dir, 'depth_out'))
    video_out_dir = Path(os.path.join(test_dir, 'video_out'))
    test_dir.mkdir(parents=True, exist_ok=True)
    img_out_dir.mkdir(parents=True, exist_ok=True)
    depth_out_dir.mkdir(parents=True, exist_ok=True)
    video_out_dir.mkdir(parents=True, exist_ok=True)

    '''Scene Meta'''
    scene_train = DataLoaderWithCOLMAP(base_dir=args.base_dir,
                                       scene_name=args.scene_name,
                                       data_type='train',
                                       res_ratio=args.resize_ratio,
                                       num_img_to_load=args.train_img_num,
                                       skip=args.train_skip,
                                       use_ndc=args.use_ndc,
                                       load_img=False)

    H, W = scene_train.H, scene_train.W
    colmap_focal = scene_train.focal
    near, far = scene_train.near, scene_train.far

    print('Intrinsic: H: {0:4d}, W: {1:4d}, COLMAP focal {2:.2f}.'.format(H, W, colmap_focal))
    print('near: {0:.1f}, far: {1:.1f}.'.format(near, far))

    '''Model Loading'''
    pos_enc_in_dims = (2 * args.pos_enc_levels + int(args.pos_enc_inc_in)) * 3  # (2L + 0 or 1) * 3
    if args.use_dir_enc:
        dir_enc_in_dims = (2 * args.dir_enc_levels + int(args.dir_enc_inc_in)) * 3  # (2L + 0 or 1) * 3
    else:
        dir_enc_in_dims = 0

    model = OfficialNerf(pos_enc_in_dims, dir_enc_in_dims, args.hidden_dims)
    if args.multi_gpu:
        model = torch.nn.DataParallel(model).to(device=my_devices)
    else:
        model = model.to(device=my_devices)
    model = load_ckpt_to_net(os.path.join(args.ckpt_dir, 'latest_nerf.pth'), model, map_location=my_devices)

    if args.init_focal_colmap:
        focal_net = LearnFocal(H, W, args.learn_focal, args.fx_only, order=args.focal_order, init_focal=colmap_focal)
    else:
        focal_net = LearnFocal(H, W, args.learn_focal, args.fx_only, order=args.focal_order)
    if args.multi_gpu:
        focal_net = torch.nn.DataParallel(focal_net).to(device=my_devices)
    else:
        focal_net = focal_net.to(device=my_devices)
    # do not load learned focal if we use colmap focal
    if not args.init_focal_colmap:
        focal_net = load_ckpt_to_net(os.path.join(args.ckpt_dir, 'latest_focal.pth'), focal_net,
                                     map_location=my_devices)
    fxfy = focal_net(0)
    print('COLMAP focal: {0:.2f}, learned fx: {1:.2f}, fy: {2:.2f}'.format(colmap_focal, fxfy[0].item(), fxfy[1].item()))

    pose_param_net = LearnPose(scene_train.N_imgs, args.learn_R, args.learn_t, None)
    if args.multi_gpu:
        pose_param_net = torch.nn.DataParallel(pose_param_net).to(device=my_devices)
    else:
        pose_param_net = pose_param_net.to(device=my_devices)
    pose_param_net = load_ckpt_to_net(os.path.join(args.ckpt_dir, 'latest_pose.pth'), pose_param_net,
                                      map_location=my_devices)

    learned_poses = torch.stack([pose_param_net(i) for i in range(scene_train.N_imgs)])

    '''Generate camera traj'''
    # This spiral camera traj code is modified from https://github.com/kwea123/nerf_pl.
    # focus_depth is hardcoded: it is numerically close to the formula given in
    # the original repo, and mathematically, if near=1 and far=infinity, that
    # formula converges to 4.
    N_novel_imgs = args.N_img_per_circle * args.N_circle_traj
    focus_depth = 3.5
    radii = np.percentile(np.abs(learned_poses.cpu().numpy()[:, :3, 3]), args.spiral_mag_percent, axis=0)  # (3,)
    radii *= np.array(args.spiral_axis_scale)
    c2ws = create_spiral_poses(radii, focus_depth, n_circle=args.N_circle_traj, n_poses=N_novel_imgs)
    c2ws = torch.from_numpy(c2ws).float()  # (N, 3, 4)
    c2ws = convert3x4_4x4(c2ws)  # (N, 4, 4)

    '''Render'''
    result = test_one_epoch(H, W, focal_net, c2ws, near, far, model, my_devices, args)
    imgs = result['imgs']
    depths = result['depths']

    '''Write to folder'''
    imgs = (imgs.cpu().numpy() * 255).astype(np.uint8)
    depths = (depths.cpu().numpy() * 200).astype(np.uint8)  # far is 1.0 in NDC

    for i in range(c2ws.shape[0]):
        imageio.imwrite(os.path.join(img_out_dir, str(i).zfill(4) + '.png'), imgs[i])
        imageio.imwrite(os.path.join(depth_out_dir, str(i).zfill(4) + '.png'), depths[i])

    imageio.mimwrite(os.path.join(video_out_dir, 'img.mp4'), imgs, fps=30, quality=9)
    imageio.mimwrite(os.path.join(video_out_dir, 'depth.mp4'), depths, fps=30, quality=9)
    imageio.mimwrite(os.path.join(video_out_dir, 'img.gif'), imgs, fps=30)
    imageio.mimwrite(os.path.join(video_out_dir, 'depth.gif'), depths, fps=30)

    return
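# The spiral trajectory above comes from create_spiral_poses (adapted from
# https://github.com/kwea123/nerf_pl). Below is a rough sketch of that style of
# spiral, assuming an OpenGL-style camera that always looks at a point
# focus_depth in front of the origin; the repo's actual version and argument
# handling may differ.
def create_spiral_poses_sketch(radii, focus_depth, n_circle=2, n_poses=120):
    def normalize(v):
        return v / np.linalg.norm(v)

    poses = []
    for t in np.linspace(0, n_circle * 2.0 * np.pi, n_poses + 1)[:-1]:
        # Camera centre travelling on an ellipse with per-axis radii,
        # oscillating slightly along the third axis.
        center = np.array([np.cos(t), -np.sin(t), -np.sin(0.5 * t)]) * radii
        z = normalize(center - np.array([0.0, 0.0, -focus_depth]))  # viewing direction
        y_ = np.array([0.0, 1.0, 0.0])                              # up-vector hint
        x = normalize(np.cross(y_, z))
        y = np.cross(z, x)
        poses.append(np.stack([x, y, z, center], axis=1))  # (3, 4) c2w
    return np.stack(poses, axis=0)  # (N, 3, 4)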
def main(args):
    my_devices = torch.device('cuda:' + str(args.gpu_id))

    '''Create Folders'''
    exp_root_dir = Path(os.path.join('./logs/any_folder', args.scene_name))
    exp_root_dir.mkdir(parents=True, exist_ok=True)
    experiment_dir = Path(os.path.join(exp_root_dir, gen_detail_name(args)))
    experiment_dir.mkdir(parents=True, exist_ok=True)
    shutil.copy('./models/nerf_models.py', experiment_dir)
    shutil.copy('./models/intrinsics.py', experiment_dir)
    shutil.copy('./models/poses.py', experiment_dir)
    shutil.copy('./tasks/any_folder/train.py', experiment_dir)

    '''LOG'''
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    file_handler = logging.FileHandler(os.path.join(experiment_dir, 'log.txt'))
    file_handler.setLevel(logging.INFO)
    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(logging.WARNING)
    logger.addHandler(file_handler)
    logger.addHandler(stream_handler)
    logger.info(args)

    '''Summary Writer'''
    writer = SummaryWriter(log_dir=str(experiment_dir))

    '''Data Loading'''
    scene_train = DataLoaderAnyFolder(base_dir=args.base_dir,
                                      scene_name=args.scene_name,
                                      res_ratio=args.resize_ratio,
                                      num_img_to_load=args.train_img_num,
                                      start=args.train_start,
                                      end=args.train_end,
                                      skip=args.train_skip,
                                      load_sorted=args.train_load_sorted)

    print('Train with {0:6d} images.'.format(scene_train.imgs.shape[0]))

    # We have no eval pose in this any_folder task. Eval with a 4x4 identity pose.
    eval_c2ws = torch.eye(4).unsqueeze(0).float()  # (1, 4, 4)

    '''Model Loading'''
    pos_enc_in_dims = (2 * args.pos_enc_levels + int(args.pos_enc_inc_in)) * 3  # (2L + 0 or 1) * 3
    if args.use_dir_enc:
        dir_enc_in_dims = (2 * args.dir_enc_levels + int(args.dir_enc_inc_in)) * 3  # (2L + 0 or 1) * 3
    else:
        dir_enc_in_dims = 0

    model = OfficialNerf(pos_enc_in_dims, dir_enc_in_dims, args.hidden_dims)
    if args.multi_gpu:
        model = torch.nn.DataParallel(model).to(device=my_devices)
    else:
        model = model.to(device=my_devices)

    # learn focal parameter
    focal_net = LearnFocal(scene_train.H, scene_train.W, args.learn_focal, args.fx_only, order=args.focal_order)
    if args.multi_gpu:
        focal_net = torch.nn.DataParallel(focal_net).to(device=my_devices)
    else:
        focal_net = focal_net.to(device=my_devices)

    # learn pose for each image
    pose_param_net = LearnPose(scene_train.N_imgs, args.learn_R, args.learn_t, None)
    if args.multi_gpu:
        pose_param_net = torch.nn.DataParallel(pose_param_net).to(device=my_devices)
    else:
        pose_param_net = pose_param_net.to(device=my_devices)

    '''Set Optimiser'''
    optimizer_nerf = torch.optim.Adam(model.parameters(), lr=args.nerf_lr)
    optimizer_focal = torch.optim.Adam(focal_net.parameters(), lr=args.focal_lr)
    optimizer_pose = torch.optim.Adam(pose_param_net.parameters(), lr=args.pose_lr)

    scheduler_nerf = torch.optim.lr_scheduler.MultiStepLR(optimizer_nerf,
                                                          milestones=args.nerf_milestones,
                                                          gamma=args.nerf_lr_gamma)
    scheduler_focal = torch.optim.lr_scheduler.MultiStepLR(optimizer_focal,
                                                           milestones=args.focal_milestones,
                                                           gamma=args.focal_lr_gamma)
    scheduler_pose = torch.optim.lr_scheduler.MultiStepLR(optimizer_pose,
                                                          milestones=args.pose_milestones,
                                                          gamma=args.pose_lr_gamma)

    '''Training'''
    for epoch_i in tqdm(range(args.epoch), desc='epochs'):
        rgb_act_fn = torch.sigmoid
        train_epoch_losses = train_one_epoch(scene_train, optimizer_nerf, optimizer_focal, optimizer_pose,
                                             model, focal_net, pose_param_net, my_devices, args, rgb_act_fn)
        train_L2_loss = train_epoch_losses['L2']
        scheduler_nerf.step()
        scheduler_focal.step()
        scheduler_pose.step()

        train_psnr = mse2psnr(train_L2_loss)
        writer.add_scalar('train/mse', train_L2_loss, epoch_i)
        writer.add_scalar('train/psnr', train_psnr, epoch_i)
        # get_last_lr() is the documented accessor; get_lr() is deprecated
        # outside of scheduler.step() and can report incorrectly scaled values.
        writer.add_scalar('train/lr', scheduler_nerf.get_last_lr()[0], epoch_i)
        logger.info('{0:6d} ep: Train: L2 loss: {1:.4f}, PSNR: {2:.3f}'.format(epoch_i, train_L2_loss, train_psnr))
        tqdm.write('{0:6d} ep: Train: L2 loss: {1:.4f}, PSNR: {2:.3f}'.format(epoch_i, train_L2_loss, train_psnr))

        if epoch_i % args.eval_interval == 0 and epoch_i > 0:
            with torch.no_grad():
                eval_one_epoch(eval_c2ws, scene_train, model, focal_net, pose_param_net, my_devices,
                               args, epoch_i, writer, rgb_act_fn)
                fxfy = focal_net(0)
                tqdm.write('Est fx: {0:.2f}, fy {1:.2f}'.format(fxfy[0].item(), fxfy[1].item()))
                logger.info('Est fx: {0:.2f}, fy {1:.2f}'.format(fxfy[0].item(), fxfy[1].item()))

                # save the latest model
                save_checkpoint(epoch_i, model, optimizer_nerf, experiment_dir, ckpt_name='latest_nerf')
                save_checkpoint(epoch_i, focal_net, optimizer_focal, experiment_dir, ckpt_name='latest_focal')
                save_checkpoint(epoch_i, pose_param_net, optimizer_pose, experiment_dir, ckpt_name='latest_pose')
    return
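# A plausible sketch of the save_checkpoint helper used above, assuming it
# stores the epoch plus model and optimizer state dicts under the given name
# (the key names and the DataParallel unwrapping are assumptions, not the
# repo's actual implementation):
def save_checkpoint_sketch(epoch, net, optimizer, out_dir, ckpt_name):
    # Unwrap DataParallel so the saved state dict has plain parameter names.
    net_to_save = net.module if isinstance(net, torch.nn.DataParallel) else net
    torch.save({'epoch': epoch,
                'model_state_dict': net_to_save.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()},
               os.path.join(str(out_dir), ckpt_name + '.pth'))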
def main(args):
    my_devices = torch.device('cuda:' + str(args.gpu_id))

    '''Create Folders'''
    test_dir = Path(os.path.join(args.ckpt_dir, 'render_' + args.type_to_eval))
    img_out_dir = Path(os.path.join(test_dir, 'img_out'))
    depth_out_dir = Path(os.path.join(test_dir, 'depth_out'))
    video_out_dir = Path(os.path.join(test_dir, 'video_out'))
    eval_pose_out_dir = Path(os.path.join(test_dir, 'eval_pose_out'))
    test_dir.mkdir(parents=True, exist_ok=True)
    img_out_dir.mkdir(parents=True, exist_ok=True)
    depth_out_dir.mkdir(parents=True, exist_ok=True)
    video_out_dir.mkdir(parents=True, exist_ok=True)
    eval_pose_out_dir.mkdir(parents=True, exist_ok=True)

    '''LOG'''
    logger = logging.getLogger()
    logger.setLevel(logging.WARNING)
    file_handler = logging.FileHandler(os.path.join(eval_pose_out_dir, 'log.txt'))
    file_handler.setLevel(logging.INFO)
    logger.addHandler(file_handler)
    logger.info(args)

    '''Summary Writer'''
    writer = SummaryWriter(log_dir=str(eval_pose_out_dir))

    '''Load data'''
    scene_train = DataLoaderWithCOLMAP(base_dir=args.base_dir,
                                       scene_name=args.scene_name,
                                       data_type='train',
                                       res_ratio=args.resize_ratio,
                                       num_img_to_load=args.train_img_num,
                                       skip=args.train_skip,
                                       use_ndc=args.use_ndc,
                                       load_img=(args.type_to_eval == 'train'))  # only load imgs if we eval the train set

    print('Intrinsic: H: {0:4d}, W: {1:4d}, GT focal {2:.2f}.'.format(scene_train.H, scene_train.W, scene_train.focal))

    if args.type_to_eval == 'train':
        scene_eval = scene_train
    else:
        scene_eval = DataLoaderWithCOLMAP(base_dir=args.base_dir,
                                          scene_name=args.scene_name,
                                          data_type='val',
                                          res_ratio=args.resize_ratio,
                                          num_img_to_load=args.eval_img_num,
                                          skip=args.eval_skip,
                                          use_ndc=args.use_ndc)

    '''Model Loading'''
    pos_enc_in_dims = (2 * args.pos_enc_levels + int(args.pos_enc_inc_in)) * 3  # (2L + 0 or 1) * 3
    if args.use_dir_enc:
        dir_enc_in_dims = (2 * args.dir_enc_levels + int(args.dir_enc_inc_in)) * 3  # (2L + 0 or 1) * 3
    else:
        dir_enc_in_dims = 0

    model = OfficialNerf(pos_enc_in_dims, dir_enc_in_dims, args.hidden_dims)
    if args.multi_gpu:
        model = torch.nn.DataParallel(model).to(device=my_devices)
    else:
        model = model.to(device=my_devices)
    model = load_ckpt_to_net(os.path.join(args.ckpt_dir, 'latest_nerf.pth'), model, map_location=my_devices)

    if args.init_focal_from == 'colmap':
        focal_net = LearnFocal(scene_train.H, scene_train.W, args.learn_focal, args.fx_only,
                               order=args.focal_order, init_focal=scene_train.focal)
    else:
        focal_net = LearnFocal(scene_train.H, scene_train.W, args.learn_focal, args.fx_only,
                               order=args.focal_order)
    if args.multi_gpu:
        focal_net = torch.nn.DataParallel(focal_net).to(device=my_devices)
    else:
        focal_net = focal_net.to(device=my_devices)
    # only load the learned focal if we did not init it from elsewhere
    if args.init_focal_from == 'none':
        focal_net = load_ckpt_to_net(os.path.join(args.ckpt_dir, 'latest_focal.pth'), focal_net,
                                     map_location=my_devices)
    fxfy = focal_net(0)
    if 'blender/' in args.scene_name:
        print('GT: fx {0:.2f} fy {1:.2f}, learned: fx {2:.2f}, fy {3:.2f}, COLMAP: {4:.2f}'.format(
            scene_train.gt_fx, scene_train.gt_fy, fxfy[0].item(), fxfy[1].item(), scene_train.focal))
    else:
        print('COLMAP: {0:.2f}, learned: fx {1:.2f}, fy {2:.2f}'.format(
            scene_train.focal, fxfy[0].item(), fxfy[1].item()))

    if args.init_pose_from == 'colmap':
        learned_pose_param_net = LearnPose(scene_train.N_imgs, args.learn_R, args.learn_t, scene_train.c2ws)
    else:
        learned_pose_param_net = LearnPose(scene_train.N_imgs, args.learn_R, args.learn_t, None)
    if args.multi_gpu:
        learned_pose_param_net = torch.nn.DataParallel(learned_pose_param_net).to(device=my_devices)
    else:
        learned_pose_param_net = learned_pose_param_net.to(device=my_devices)
    learned_pose_param_net = load_ckpt_to_net(os.path.join(args.ckpt_dir, 'latest_pose.pth'),
                                              learned_pose_param_net, map_location=my_devices)

    # We optimise poses for the validation images while freezing the learned focal length and the
    # trained nerf model. This step is only required when we compute evaluation metrics, as the
    # space of learned poses is different from the space of colmap poses.
    if args.type_to_eval == 'train':
        eval_pose_param_net = learned_pose_param_net
    else:
        with torch.no_grad():
            # compute a scale between the learned traj and the colmap traj
            init_c2ws = scene_eval.c2ws
            learned_c2ws_train = torch.stack([learned_pose_param_net(i) for i in range(scene_train.N_imgs)])  # (N, 4, 4)
            colmap_c2ws_train = scene_train.c2ws  # (N, 4, 4)
            init_c2ws, scale_colmap2est = align_scale_c2b_use_a2b(colmap_c2ws_train, learned_c2ws_train, init_c2ws)
        eval_pose_param_net = LearnPose(scene_eval.N_imgs, args.opt_eval_R, args.opt_eval_t, init_c2ws)

    if args.multi_gpu:
        eval_pose_param_net = torch.nn.DataParallel(eval_pose_param_net).to(device=my_devices)
    else:
        eval_pose_param_net = eval_pose_param_net.to(device=my_devices)

    '''Set Optimiser'''
    optimizer_eval_pose = torch.optim.Adam(eval_pose_param_net.parameters(), lr=args.opt_eval_lr)
    scheduler_eval_pose = torch.optim.lr_scheduler.MultiStepLR(optimizer_eval_pose,
                                                               milestones=args.eval_pose_milestones,
                                                               gamma=args.eval_pose_lr_gamma)

    '''Optimise eval poses'''
    if args.type_to_eval != 'train':
        for epoch_i in tqdm(range(args.opt_pose_epoch), desc='optimising eval'):
            mean_losses = opt_eval_pose_one_epoch(model, focal_net, eval_pose_param_net, scene_eval,
                                                  optimizer_eval_pose, my_devices)
            opt_L2_loss = mean_losses['L2']
            opt_pose_psnr = mse2psnr(opt_L2_loss)
            scheduler_eval_pose.step()

            writer.add_scalar('opt/mse', opt_L2_loss, epoch_i)
            writer.add_scalar('opt/psnr', opt_pose_psnr, epoch_i)

            logger.info('{0:6d} ep: Opt: L2 loss: {1:.4f}, PSNR: {2:.3f}'.format(epoch_i, opt_L2_loss, opt_pose_psnr))
            tqdm.write('{0:6d} ep: Opt: L2 loss: {1:.4f}, PSNR: {2:.3f}'.format(epoch_i, opt_L2_loss, opt_pose_psnr))

    with torch.no_grad():
        '''Compute ATE'''
        stats_tran, stats_rot, stats_scale = eval_one_epoch_traj(scene_train, learned_pose_param_net)
        print('------------------ ATE statistic ------------------')
        print('Traj Err: translation: {0:.6f}, rotation: {1:.2f} deg, scale: {2:.2f}'.format(
            stats_tran['mean'], stats_rot['mean'], stats_scale['mean']))
        print('----------------------------------------------------')

        '''Final Render'''
        result = eval_one_epoch_img(scene_eval, model, focal_net, eval_pose_param_net, my_devices, args, logger)
        imgs = result['imgs']
        depths = result['depths']

    '''Write to folder'''
    imgs = (imgs.cpu().numpy() * 255).astype(np.uint8)
    depths = (depths.cpu().numpy() * 10).astype(np.uint8)

    for i in range(scene_eval.N_imgs):
        imageio.imwrite(os.path.join(img_out_dir, str(i).zfill(4) + '.png'), imgs[i])
        imageio.imwrite(os.path.join(depth_out_dir, str(i).zfill(4) + '.png'), depths[i])

    imageio.mimwrite(os.path.join(video_out_dir, 'img.mp4'), imgs, fps=30, quality=9)
    imageio.mimwrite(os.path.join(video_out_dir, 'depth.mp4'), depths, fps=30, quality=9)
    imageio.mimwrite(os.path.join(video_out_dir, 'img.gif'), imgs, fps=30)
    imageio.mimwrite(os.path.join(video_out_dir, 'depth.gif'), depths, fps=30)

    return
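# A matching sketch for the load side, assuming checkpoints store a
# 'model_state_dict' entry as in save_checkpoint_sketch above (the key name
# and the DataParallel handling are assumptions about load_ckpt_to_net, not
# the repo's actual code):
def load_ckpt_to_net_sketch(ckpt_path, net, map_location=None):
    ckpt = torch.load(ckpt_path, map_location=map_location)
    # Unwrap DataParallel so parameter names match a plainly saved state dict.
    net_to_load = net.module if isinstance(net, torch.nn.DataParallel) else net
    net_to_load.load_state_dict(ckpt['model_state_dict'])
    return net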