Example #1
def main(args):
    my_devices = torch.device('cuda:' + str(args.gpu_id))
    '''Create Folders'''
    exp_root_dir = Path(os.path.join('./logs/nerfmm', args.scene_name))
    exp_root_dir.mkdir(parents=True, exist_ok=True)
    experiment_dir = Path(os.path.join(exp_root_dir, gen_detail_name(args)))
    experiment_dir.mkdir(parents=True, exist_ok=True)
    shutil.copy('./models/nerf_models.py', experiment_dir)
    shutil.copy('./models/intrinsics.py', experiment_dir)
    shutil.copy('./models/poses.py', experiment_dir)
    shutil.copy('./tasks/nerfmm/train.py', experiment_dir)

    if args.store_pose_history:
        pose_history_dir = Path(os.path.join(experiment_dir, 'pose_history'))
        pose_history_dir.mkdir(parents=True, exist_ok=True)
    '''LOG'''
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    file_handler = logging.FileHandler(os.path.join(experiment_dir, 'log.txt'))
    file_handler.setLevel(logging.INFO)
    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(logging.WARNING)
    logger.addHandler(file_handler)
    logger.addHandler(stream_handler)
    logger.info(args)
    '''Summary Writer'''
    writer = SummaryWriter(log_dir=str(experiment_dir))
    '''Data Loading'''
    scene_train = DataLoaderWithCOLMAP(base_dir=args.base_dir,
                                       scene_name=args.scene_name,
                                       data_type='train',
                                       res_ratio=args.resize_ratio,
                                       num_img_to_load=args.train_img_num,
                                       skip=args.train_skip,
                                       use_ndc=args.use_ndc)

    # The COLMAP eval poses are not in the camera space we learned, so we can
    # only check novel view synthesis (NVS) with a 4x4 identity pose.
    eval_c2ws = torch.eye(4).unsqueeze(0).float()  # (1, 4, 4)

    print('Train with {0:6d} images.'.format(scene_train.imgs.shape[0]))
    '''Model Loading'''
    pos_enc_in_dims = (2 * args.pos_enc_levels +
                       int(args.pos_enc_inc_in)) * 3  # (2L + 0 or 1) * 3
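    # e.g. with pos_enc_levels = 10 and pos_enc_inc_in = True this is
    # (2 * 10 + 1) * 3 = 63, the xyz input width used by the original NeRF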
    if args.use_dir_enc:
        dir_enc_in_dims = (2 * args.dir_enc_levels +
                           int(args.dir_enc_inc_in)) * 3  # (2L + 0 or 1) * 3
    else:
        dir_enc_in_dims = 0

    model = OfficialNerf(pos_enc_in_dims, dir_enc_in_dims, args.hidden_dims)
    if args.multi_gpu:
        model = torch.nn.DataParallel(model).to(device=my_devices)
    else:
        model = model.to(device=my_devices)

    # learn focal parameter
    if args.start_refine_focal_epoch > -1:
        focal_net = LearnFocal(scene_train.H,
                               scene_train.W,
                               args.learn_focal,
                               args.fx_only,
                               order=args.focal_order,
                               init_focal=scene_train.focal)
    else:
        focal_net = LearnFocal(scene_train.H,
                               scene_train.W,
                               args.learn_focal,
                               args.fx_only,
                               order=args.focal_order)
    if args.multi_gpu:
        focal_net = torch.nn.DataParallel(focal_net).to(device=my_devices)
    else:
        focal_net = focal_net.to(device=my_devices)

    # learn pose for each image
    if args.start_refine_pose_epoch > -1:
        pose_param_net = LearnPose(scene_train.N_imgs, args.learn_R,
                                   args.learn_t, scene_train.c2ws)
    else:
        pose_param_net = LearnPose(scene_train.N_imgs, args.learn_R,
                                   args.learn_t, None)
    if args.multi_gpu:
        pose_param_net = torch.nn.DataParallel(pose_param_net).to(
            device=my_devices)
    else:
        pose_param_net = pose_param_net.to(device=my_devices)
    '''Set Optimiser'''
    optimizer_nerf = torch.optim.Adam(model.parameters(), lr=args.nerf_lr)
    optimizer_focal = torch.optim.Adam(focal_net.parameters(),
                                       lr=args.focal_lr)
    optimizer_pose = torch.optim.Adam(pose_param_net.parameters(),
                                      lr=args.pose_lr)

    scheduler_nerf = torch.optim.lr_scheduler.MultiStepLR(
        optimizer_nerf,
        milestones=args.nerf_milestones,
        gamma=args.nerf_lr_gamma)
    scheduler_focal = torch.optim.lr_scheduler.MultiStepLR(
        optimizer_focal,
        milestones=args.focal_milestones,
        gamma=args.focal_lr_gamma)
    scheduler_pose = torch.optim.lr_scheduler.MultiStepLR(
        optimizer_pose,
        milestones=args.pose_milestones,
        gamma=args.pose_lr_gamma)
    '''Training'''
    for epoch_i in tqdm(range(args.epoch), desc='epochs'):
        rgb_act_fn = torch.sigmoid
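        # sigmoid keeps the predicted RGB within [0, 1]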
        train_epoch_losses = train_one_epoch(scene_train, optimizer_nerf,
                                             optimizer_focal, optimizer_pose,
                                             model, focal_net, pose_param_net,
                                             my_devices, args, rgb_act_fn,
                                             epoch_i)
        train_L2_loss = train_epoch_losses['L2']
        scheduler_nerf.step()
        scheduler_focal.step()
        scheduler_pose.step()

        train_psnr = mse2psnr(train_L2_loss)
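        # mse2psnr presumably implements the standard PSNR for [0, 1] images:
        # psnr = -10 * log10(mse)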
        writer.add_scalar('train/mse', train_L2_loss, epoch_i)
        writer.add_scalar('train/psnr', train_psnr, epoch_i)
        writer.add_scalar('train/lr', scheduler_nerf.get_last_lr()[0], epoch_i)
        logger.info('{0:6d} ep: Train: L2 loss: {1:.4f}, PSNR: {2:.3f}'.format(
            epoch_i, train_L2_loss, train_psnr))
        tqdm.write('{0:6d} ep: Train: L2 loss: {1:.4f}, PSNR: {2:.3f}'.format(
            epoch_i, train_L2_loss, train_psnr))

        pose_history_milestone = list(range(0, 100, 5)) + list(
            range(100, 1000, 100)) + list(range(1000, 10000, 1000))
        if epoch_i in pose_history_milestone:
            with torch.no_grad():
                if args.store_pose_history:
                    store_current_pose(pose_param_net, pose_history_dir,
                                       epoch_i)

        if epoch_i % args.eval_cam_interval == 0 and epoch_i > 0:
            with torch.no_grad():
                eval_stats_tran, eval_stats_rot, eval_stats_scale = eval_one_epoch_traj(
                    scene_train, pose_param_net)
                writer.add_scalar('eval/traj/translation',
                                  eval_stats_tran['mean'], epoch_i)
                writer.add_scalar('eval/traj/rotation', eval_stats_rot['mean'],
                                  epoch_i)
                writer.add_scalar('eval/traj/scale', eval_stats_scale['mean'],
                                  epoch_i)

                logger.info(
                    '{0:6d} ep Traj Err: translation: {1:.6f}, rotation: {2:.2f} deg, scale: {3:.2f}'
                    .format(epoch_i, eval_stats_tran['mean'],
                            eval_stats_rot['mean'], eval_stats_scale['mean']))
                tqdm.write(
                    '{0:6d} ep Traj Err: translation: {1:.6f}, rotation: {2:.2f} deg, scale: {3:.2f}'
                    .format(epoch_i, eval_stats_tran['mean'],
                            eval_stats_rot['mean'], eval_stats_scale['mean']))

                fxfy = focal_net(0)
                tqdm.write(
                    'Est fx: {0:.2f}, fy {1:.2f}, COLMAP focal: {2:.2f}'.
                    format(fxfy[0].item(), fxfy[1].item(), scene_train.focal))
                logger.info(
                    'Est fx: {0:.2f}, fy {1:.2f}, COLMAP focal: {2:.2f}'.
                    format(fxfy[0].item(), fxfy[1].item(), scene_train.focal))
                if torch.is_tensor(fxfy):
                    L1_focal = torch.abs(fxfy -
                                         scene_train.focal).mean().item()
                else:
                    L1_focal = np.abs(fxfy - scene_train.focal).mean()
                writer.add_scalar('eval/L1_focal', L1_focal, epoch_i)

        if epoch_i % args.eval_img_interval == 0 and epoch_i > 0:
            with torch.no_grad():
                eval_one_epoch_img(eval_c2ws, scene_train, model, focal_net,
                                   pose_param_net, my_devices, args, epoch_i,
                                   writer, rgb_act_fn)

                # save the latest model.
                save_checkpoint(epoch_i,
                                model,
                                optimizer_nerf,
                                experiment_dir,
                                ckpt_name='latest_nerf')
                save_checkpoint(epoch_i,
                                focal_net,
                                optimizer_focal,
                                experiment_dir,
                                ckpt_name='latest_focal')
                save_checkpoint(epoch_i,
                                pose_param_net,
                                optimizer_pose,
                                experiment_dir,
                                ckpt_name='latest_pose')
    return
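
Note: models/poses.py is copied into the experiment dir above but not reproduced in this listing. As a rough, self-contained sketch of what a per-image pose module like LearnPose does — one learnable axis-angle rotation and translation per image, composed with an optional initial pose — see below. PosePerImage, vec2skew and exp_so3 are illustrative names, not the repo's API, and details may differ.

import torch
import torch.nn as nn

def vec2skew(v):
    """(3,) axis-angle vector -> (3, 3) skew-symmetric matrix."""
    zero = torch.zeros(1, dtype=v.dtype, device=v.device)
    return torch.stack([
        torch.cat([zero, -v[2:3], v[1:2]]),
        torch.cat([v[2:3], zero, -v[0:1]]),
        torch.cat([-v[1:2], v[0:1], zero]),
    ], dim=0)

def exp_so3(r):
    """Rodrigues' formula: axis-angle (3,) -> rotation matrix (3, 3)."""
    skew = vec2skew(r)
    theta = r.norm() + 1e-15  # avoid division by zero at r = 0
    eye = torch.eye(3, dtype=r.dtype, device=r.device)
    return (eye
            + (torch.sin(theta) / theta) * skew
            + ((1.0 - torch.cos(theta)) / theta ** 2) * (skew @ skew))

class PosePerImage(nn.Module):
    """One learnable SE(3) refinement per training image."""

    def __init__(self, num_cams, learn_R=True, learn_t=True, init_c2w=None):
        super().__init__()
        if init_c2w is None:
            init_c2w = torch.eye(4).expand(num_cams, 4, 4).clone()
        self.register_buffer('init_c2w', init_c2w)
        self.r = nn.Parameter(torch.zeros(num_cams, 3), requires_grad=learn_R)
        self.t = nn.Parameter(torch.zeros(num_cams, 3), requires_grad=learn_t)

    def forward(self, cam_id):
        c2w = torch.eye(4, dtype=self.r.dtype, device=self.r.device)
        c2w[:3, :3] = exp_so3(self.r[cam_id])
        c2w[:3, 3] = self.t[cam_id]
        return c2w @ self.init_c2w[cam_id]  # refine on top of the init pose

This also explains the two branches above: with start_refine_pose_epoch > -1 the COLMAP c2ws serve as init_c2w and the per-image parameters only refine them; otherwise poses are learned from identity.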
Example #2
def main(args):
    my_devices = torch.device('cuda:' + str(args.gpu_id))
    '''Create Folders'''
    test_dir = Path(os.path.join(args.ckpt_dir, 'render_spiral'))
    img_out_dir = Path(os.path.join(test_dir, 'img_out'))
    depth_out_dir = Path(os.path.join(test_dir, 'depth_out'))
    video_out_dir = Path(os.path.join(test_dir, 'video_out'))
    test_dir.mkdir(parents=True, exist_ok=True)
    img_out_dir.mkdir(parents=True, exist_ok=True)
    depth_out_dir.mkdir(parents=True, exist_ok=True)
    video_out_dir.mkdir(parents=True, exist_ok=True)
    '''Scene Meta'''
    scene_train = DataLoaderWithCOLMAP(base_dir=args.base_dir,
                                       scene_name=args.scene_name,
                                       data_type='train',
                                       res_ratio=args.resize_ratio,
                                       num_img_to_load=args.train_img_num,
                                       skip=args.train_skip,
                                       use_ndc=args.use_ndc,
                                       load_img=False)

    H, W = scene_train.H, scene_train.W
    colmap_focal = scene_train.focal
    near, far = scene_train.near, scene_train.far

    print('Intrinsic: H: {0:4d}, W: {1:4d}, COLMAP focal {2:.2f}.'.format(
        H, W, colmap_focal))
    print('near: {0:.1f}, far: {1:.1f}.'.format(near, far))
    '''Model Loading'''
    pos_enc_in_dims = (2 * args.pos_enc_levels +
                       int(args.pos_enc_inc_in)) * 3  # (2L + 0 or 1) * 3
    if args.use_dir_enc:
        dir_enc_in_dims = (2 * args.dir_enc_levels +
                           int(args.dir_enc_inc_in)) * 3  # (2L + 0 or 1) * 3
    else:
        dir_enc_in_dims = 0

    model = OfficialNerf(pos_enc_in_dims, dir_enc_in_dims, args.hidden_dims)
    if args.multi_gpu:
        model = torch.nn.DataParallel(model).to(device=my_devices)
    else:
        model = model.to(device=my_devices)
    model = load_ckpt_to_net(os.path.join(args.ckpt_dir, 'latest_nerf.pth'),
                             model,
                             map_location=my_devices)

    if args.init_focal_colmap:
        focal_net = LearnFocal(H,
                               W,
                               args.learn_focal,
                               args.fx_only,
                               order=args.focal_order,
                               init_focal=colmap_focal)
    else:
        focal_net = LearnFocal(H,
                               W,
                               args.learn_focal,
                               args.fx_only,
                               order=args.focal_order)
    if args.multi_gpu:
        focal_net = torch.nn.DataParallel(focal_net).to(device=my_devices)
    else:
        focal_net = focal_net.to(device=my_devices)
    # do not load the learned focal if we use the COLMAP focal
    if not args.init_focal_colmap:
        focal_net = load_ckpt_to_net(os.path.join(args.ckpt_dir,
                                                  'latest_focal.pth'),
                                     focal_net,
                                     map_location=my_devices)
    fxfy = focal_net(0)
    print('COLMAP focal: {0:.2f}, learned fx: {1:.2f}, fy: {2:.2f}'.format(
        colmap_focal, fxfy[0].item(), fxfy[1].item()))

    pose_param_net = LearnPose(scene_train.N_imgs, args.learn_R, args.learn_t,
                               None)
    if args.multi_gpu:
        pose_param_net = torch.nn.DataParallel(pose_param_net).to(
            device=my_devices)
    else:
        pose_param_net = pose_param_net.to(device=my_devices)
    pose_param_net = load_ckpt_to_net(os.path.join(args.ckpt_dir,
                                                   'latest_pose.pth'),
                                      pose_param_net,
                                      map_location=my_devices)

    learned_poses = torch.stack(
        [pose_param_net(i) for i in range(scene_train.N_imgs)])
    '''Generate camera traj'''
    # This spiral camera traj code is modified from https://github.com/kwea123/nerf_pl.
    N_novel_imgs = args.N_img_per_circle * args.N_circle_traj
    # Hardcoded: numerically close to the formula given in the original repo.
    # Mathematically, if near=1 and far=infinity, this number converges to 4.
    focus_depth = 3.5
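    # For reference (an assumption based on the upstream repos): the original
    # NeRF uses focus = 1 / ((1 - dt) / near + dt / far) with dt = 0.75, which
    # tends to 1 / 0.25 = 4 as near -> 1 and far -> infinity.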
    radii = np.percentile(np.abs(learned_poses.cpu().numpy()[:, :3, 3]),
                          args.spiral_mag_percent,
                          axis=0)  # (3,)
    radii *= np.array(args.spiral_axis_scale)
    c2ws = create_spiral_poses(radii,
                               focus_depth,
                               n_circle=args.N_circle_traj,
                               n_poses=N_novel_imgs)
    c2ws = torch.from_numpy(c2ws).float()  # (N, 3, 4)
    c2ws = convert3x4_4x4(c2ws)  # (N, 4, 4)
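    # convert3x4_4x4 presumably appends the homogeneous bottom row, e.g.:
    #   bottom = torch.tensor([0., 0., 0., 1.]).expand(c2ws.shape[0], 1, 4)
    #   c2ws = torch.cat([c2ws, bottom], dim=1)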
    '''Render'''
    result = test_one_epoch(H, W, focal_net, c2ws, near, far, model,
                            my_devices, args)
    imgs = result['imgs']
    depths = result['depths']
    '''Write to folder'''
    imgs = (imgs.cpu().numpy() * 255).astype(np.uint8)
    depths = (depths.cpu().numpy() * 200).astype(np.uint8)  # far plane is 1.0 in NDC, so *200 stays below 255

    for i in range(c2ws.shape[0]):
        imageio.imwrite(os.path.join(img_out_dir,
                                     str(i).zfill(4) + '.png'), imgs[i])
        imageio.imwrite(os.path.join(depth_out_dir,
                                     str(i).zfill(4) + '.png'), depths[i])

    imageio.mimwrite(os.path.join(video_out_dir, 'img.mp4'),
                     imgs,
                     fps=30,
                     quality=9)
    imageio.mimwrite(os.path.join(video_out_dir, 'depth.mp4'),
                     depths,
                     fps=30,
                     quality=9)

    imageio.mimwrite(os.path.join(video_out_dir, 'img.gif'), imgs, fps=30)
    imageio.mimwrite(os.path.join(video_out_dir, 'depth.gif'), depths, fps=30)

    return
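
LearnFocal (models/intrinsics.py) is likewise not reproduced in this listing. Below is a minimal guess at its core trick, assuming order=2 means the focal length is a squared multiplier of the image width so it stays positive during optimisation; FocalSketch is an illustrative name and the repo's details may differ.

import torch
import torch.nn as nn

class FocalSketch(nn.Module):
    """Learn fx, fy as (squared) multipliers of the image width."""

    def __init__(self, W, order=2, init_focal=None, req_grad=True):
        super().__init__()
        self.W, self.order = W, order
        # pick the parameter so that param ** order * W == init_focal
        init = 1.0 if init_focal is None else (init_focal / W) ** (1.0 / order)
        self.fx = nn.Parameter(torch.tensor(init, dtype=torch.float32),
                               requires_grad=req_grad)
        self.fy = nn.Parameter(torch.tensor(init, dtype=torch.float32),
                               requires_grad=req_grad)

    def forward(self, i=None):  # dummy arg mirrors the focal_net(0) call sites
        return torch.stack([self.fx ** self.order * self.W,
                            self.fy ** self.order * self.W])

An even power guarantees a positive focal without constrained optimisation, and init_focal lets COLMAP's estimate serve as a starting point, which matches how init_focal_colmap is used above.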
Example #3
def main(args):
    my_devices = torch.device('cuda:' + str(args.gpu_id))
    '''Create Folders'''
    exp_root_dir = Path(os.path.join('./logs/any_folder', args.scene_name))
    exp_root_dir.mkdir(parents=True, exist_ok=True)
    experiment_dir = Path(os.path.join(exp_root_dir, gen_detail_name(args)))
    experiment_dir.mkdir(parents=True, exist_ok=True)
    shutil.copy('./models/nerf_models.py', experiment_dir)
    shutil.copy('./models/intrinsics.py', experiment_dir)
    shutil.copy('./models/poses.py', experiment_dir)
    shutil.copy('./tasks/any_folder/train.py', experiment_dir)
    '''LOG'''
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    file_handler = logging.FileHandler(os.path.join(experiment_dir, 'log.txt'))
    file_handler.setLevel(logging.INFO)
    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(logging.WARNING)
    logger.addHandler(file_handler)
    logger.addHandler(stream_handler)
    logger.info(args)
    '''Summary Writer'''
    writer = SummaryWriter(log_dir=str(experiment_dir))
    '''Data Loading'''
    scene_train = DataLoaderAnyFolder(base_dir=args.base_dir,
                                      scene_name=args.scene_name,
                                      res_ratio=args.resize_ratio,
                                      num_img_to_load=args.train_img_num,
                                      start=args.train_start,
                                      end=args.train_end,
                                      skip=args.train_skip,
                                      load_sorted=args.train_load_sorted)

    print('Train with {0:6d} images.'.format(scene_train.imgs.shape[0]))

    # We have no eval poses in this any_folder task, so we evaluate with a 4x4 identity pose.
    eval_c2ws = torch.eye(4).unsqueeze(0).float()  # (1, 4, 4)
    '''Model Loading'''
    pos_enc_in_dims = (2 * args.pos_enc_levels +
                       int(args.pos_enc_inc_in)) * 3  # (2L + 0 or 1) * 3
    if args.use_dir_enc:
        dir_enc_in_dims = (2 * args.dir_enc_levels +
                           int(args.dir_enc_inc_in)) * 3  # (2L + 0 or 1) * 3
    else:
        dir_enc_in_dims = 0

    model = OfficialNerf(pos_enc_in_dims, dir_enc_in_dims, args.hidden_dims)
    if args.multi_gpu:
        model = torch.nn.DataParallel(model).to(device=my_devices)
    else:
        model = model.to(device=my_devices)

    # learn focal parameter
    focal_net = LearnFocal(scene_train.H,
                           scene_train.W,
                           args.learn_focal,
                           args.fx_only,
                           order=args.focal_order)
    if args.multi_gpu:
        focal_net = torch.nn.DataParallel(focal_net).to(device=my_devices)
    else:
        focal_net = focal_net.to(device=my_devices)

    # learn pose for each image
    pose_param_net = LearnPose(scene_train.N_imgs, args.learn_R, args.learn_t,
                               None)
    if args.multi_gpu:
        pose_param_net = torch.nn.DataParallel(pose_param_net).to(
            device=my_devices)
    else:
        pose_param_net = pose_param_net.to(device=my_devices)
    '''Set Optimiser'''
    optimizer_nerf = torch.optim.Adam(model.parameters(), lr=args.nerf_lr)
    optimizer_focal = torch.optim.Adam(focal_net.parameters(),
                                       lr=args.focal_lr)
    optimizer_pose = torch.optim.Adam(pose_param_net.parameters(),
                                      lr=args.pose_lr)

    scheduler_nerf = torch.optim.lr_scheduler.MultiStepLR(
        optimizer_nerf,
        milestones=args.nerf_milestones,
        gamma=args.nerf_lr_gamma)
    scheduler_focal = torch.optim.lr_scheduler.MultiStepLR(
        optimizer_focal,
        milestones=args.focal_milestones,
        gamma=args.focal_lr_gamma)
    scheduler_pose = torch.optim.lr_scheduler.MultiStepLR(
        optimizer_pose,
        milestones=args.pose_milestones,
        gamma=args.pose_lr_gamma)
    '''Training'''
    for epoch_i in tqdm(range(args.epoch), desc='epochs'):
        rgb_act_fn = torch.sigmoid
        train_epoch_losses = train_one_epoch(scene_train, optimizer_nerf,
                                             optimizer_focal, optimizer_pose,
                                             model, focal_net, pose_param_net,
                                             my_devices, args, rgb_act_fn)
        train_L2_loss = train_epoch_losses['L2']
        scheduler_nerf.step()
        scheduler_focal.step()
        scheduler_pose.step()

        train_psnr = mse2psnr(train_L2_loss)
        writer.add_scalar('train/mse', train_L2_loss, epoch_i)
        writer.add_scalar('train/psnr', train_psnr, epoch_i)
        writer.add_scalar('train/lr', scheduler_nerf.get_last_lr()[0], epoch_i)
        logger.info('{0:6d} ep: Train: L2 loss: {1:.4f}, PSNR: {2:.3f}'.format(
            epoch_i, train_L2_loss, train_psnr))
        tqdm.write('{0:6d} ep: Train: L2 loss: {1:.4f}, PSNR: {2:.3f}'.format(
            epoch_i, train_L2_loss, train_psnr))

        if epoch_i % args.eval_interval == 0 and epoch_i > 0:
            with torch.no_grad():
                eval_one_epoch(eval_c2ws, scene_train, model, focal_net,
                               pose_param_net, my_devices, args, epoch_i,
                               writer, rgb_act_fn)

                fxfy = focal_net(0)
                tqdm.write('Est fx: {0:.2f}, fy {1:.2f}'.format(
                    fxfy[0].item(), fxfy[1].item()))
                logger.info('Est fx: {0:.2f}, fy {1:.2f}'.format(
                    fxfy[0].item(), fxfy[1].item()))

                # save the latest model
                save_checkpoint(epoch_i,
                                model,
                                optimizer_nerf,
                                experiment_dir,
                                ckpt_name='latest_nerf')
                save_checkpoint(epoch_i,
                                focal_net,
                                optimizer_focal,
                                experiment_dir,
                                ckpt_name='latest_focal')
                save_checkpoint(epoch_i,
                                pose_param_net,
                                optimizer_pose,
                                experiment_dir,
                                ckpt_name='latest_pose')
    return
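
For context on what the learned fxfy and per-image c2w feed into: ray generation for a pinhole camera is the only place intrinsics and pose enter volume rendering. A standard sketch follows (not necessarily the repo's exact helper):

import torch

def get_rays(H, W, fx, fy, c2w):
    """Per-pixel rays in world space, standard NeRF convention:
    x right, y up, camera looks down -z."""
    j, i = torch.meshgrid(torch.arange(H, dtype=torch.float32),
                          torch.arange(W, dtype=torch.float32),
                          indexing='ij')
    dirs = torch.stack([(i - 0.5 * W) / fx,        # x: right
                        -(j - 0.5 * H) / fy,       # y: up (image j goes down)
                        -torch.ones_like(i)], -1)  # z: camera faces -z
    rays_d = dirs @ c2w[:3, :3].T                  # rotate into world space
    rays_o = c2w[:3, 3].expand_as(rays_d)          # all rays share the centre
    return rays_o, rays_d                          # each (H, W, 3)

Because rays are differentiable in fx, fy and c2w, the photometric loss can drive all three optimisers above. It also shows why eval_c2ws = torch.eye(4) is a reasonable NVS probe: rays are simply cast from the origin of the learned world frame.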
Example #4
File: eval.py  Project: liuguoyou/nerfmm
def main(args):
    my_devices = torch.device('cuda:' + str(args.gpu_id))

    '''Create Folders'''
    test_dir = Path(os.path.join(args.ckpt_dir, 'render_' + args.type_to_eval))
    img_out_dir = Path(os.path.join(test_dir, 'img_out'))
    depth_out_dir = Path(os.path.join(test_dir, 'depth_out'))
    video_out_dir = Path(os.path.join(test_dir, 'video_out'))
    eval_pose_out_dir = Path(os.path.join(test_dir, 'eval_pose_out'))
    test_dir.mkdir(parents=True, exist_ok=True)
    img_out_dir.mkdir(parents=True, exist_ok=True)
    depth_out_dir.mkdir(parents=True, exist_ok=True)
    video_out_dir.mkdir(parents=True, exist_ok=True)
    eval_pose_out_dir.mkdir(parents=True, exist_ok=True)

    '''LOG'''
    logger = logging.getLogger()
    logger.setLevel(logging.WARNING)
    file_handler = logging.FileHandler(os.path.join(eval_pose_out_dir, 'log.txt'))
    file_handler.setLevel(logging.INFO)
    logger.addHandler(file_handler)
    logger.info(args)

    '''Summary Writer'''
    writer = SummaryWriter(log_dir=str(eval_pose_out_dir))

    '''Load data'''
    scene_train = DataLoaderWithCOLMAP(base_dir=args.base_dir,
                                       scene_name=args.scene_name,
                                       data_type='train',
                                       res_ratio=args.resize_ratio,
                                       num_img_to_load=args.train_img_num,
                                       skip=args.train_skip,
                                       use_ndc=args.use_ndc,
                                       load_img=args.type_to_eval == 'train')  # only load images when evaluating the train set

    print('Intrinsic: H: {0:4d}, W: {1:4d}, COLMAP focal {2:.2f}.'.format(scene_train.H, scene_train.W, scene_train.focal))

    if args.type_to_eval == 'train':
        scene_eval = scene_train
    else:
        scene_eval = DataLoaderWithCOLMAP(base_dir=args.base_dir,
                                          scene_name=args.scene_name,
                                          data_type='val',
                                          res_ratio=args.resize_ratio,
                                          num_img_to_load=args.eval_img_num,
                                          skip=args.eval_skip,
                                          use_ndc=args.use_ndc)

    '''Model Loading'''
    pos_enc_in_dims = (2 * args.pos_enc_levels + int(args.pos_enc_inc_in)) * 3  # (2L + 0 or 1) * 3
    if args.use_dir_enc:
        dir_enc_in_dims = (2 * args.dir_enc_levels + int(args.dir_enc_inc_in)) * 3  # (2L + 0 or 1) * 3
    else:
        dir_enc_in_dims = 0

    model = OfficialNerf(pos_enc_in_dims, dir_enc_in_dims, args.hidden_dims)
    if args.multi_gpu:
        model = torch.nn.DataParallel(model).to(device=my_devices)
    else:
        model = model.to(device=my_devices)
    model = load_ckpt_to_net(os.path.join(args.ckpt_dir, 'latest_nerf.pth'), model, map_location=my_devices)

    if args.init_focal_from == 'colmap':
        focal_net = LearnFocal(scene_train.H, scene_train.W, args.learn_focal, args.fx_only, order=args.focal_order, init_focal=scene_train.focal)
    else:
        focal_net = LearnFocal(scene_train.H, scene_train.W, args.learn_focal, args.fx_only, order=args.focal_order)
    if args.multi_gpu:
        focal_net = torch.nn.DataParallel(focal_net).to(device=my_devices)
    else:
        focal_net = focal_net.to(device=my_devices)
    # load the learned focal only when no initial focal was provided
    if args.init_focal_from == 'none':
        focal_net = load_ckpt_to_net(os.path.join(args.ckpt_dir, 'latest_focal.pth'), focal_net, map_location=my_devices)
    fxfy = focal_net(0)
    if 'blender/' in args.scene_name:
        print('GT: fx {0:.2f} fy {1:.2f}, learned: fx {2:.2f}, fy {3:.2f}, COLMAP: {4:.2f}'.format(
            scene_train.gt_fx, scene_train.gt_fy, fxfy[0].item(), fxfy[1].item(), scene_train.focal))
    else:
        print('COLMAP: {0:.2f}, learned: fx {1:.2f}, fy {2:.2f}, '.format(
            scene_train.focal, fxfy[0].item(), fxfy[1].item()))

    if args.init_pose_from == 'colmap':
        learned_pose_param_net = LearnPose(scene_train.N_imgs, args.learn_R, args.learn_t, scene_train.c2ws)
    else:
        learned_pose_param_net = LearnPose(scene_train.N_imgs, args.learn_R, args.learn_t, None)
    if args.multi_gpu:
        learned_pose_param_net = torch.nn.DataParallel(learned_pose_param_net).to(device=my_devices)
    else:
        learned_pose_param_net = learned_pose_param_net.to(device=my_devices)
    learned_pose_param_net = load_ckpt_to_net(os.path.join(args.ckpt_dir, 'latest_pose.pth'), learned_pose_param_net,
                                              map_location=my_devices)

    # We optimise poses for the validation images while freezing the learned
    # focal length and the trained nerf model. This step is only required when
    # we compute evaluation metrics, as the learned pose space differs from
    # the COLMAP pose space.
    if args.type_to_eval == 'train':
        eval_pose_param_net = learned_pose_param_net
    else:
        with torch.no_grad():
            # compute a scale between the learned traj and the COLMAP traj
            init_c2ws = scene_eval.c2ws
            learned_c2ws_train = torch.stack([learned_pose_param_net(i) for i in range(scene_train.N_imgs)])  # (N, 4, 4)
            colmap_c2ws_train = scene_train.c2ws  # (N, 4, 4)
            init_c2ws, scale_colmap2est = align_scale_c2b_use_a2b(colmap_c2ws_train, learned_c2ws_train, init_c2ws)
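            # A plausible reading of align_scale_c2b_use_a2b: estimate the
            # scale mapping traj a (COLMAP) onto traj b (learned), e.g. via
            # mean camera-centre norms, then apply it to traj c's translations:
            #   scale = b[:, :3, 3].norm(dim=1).mean() / a[:, :3, 3].norm(dim=1).mean()
            #   c[:, :3, 3] *= scale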

        eval_pose_param_net = LearnPose(scene_eval.N_imgs, args.opt_eval_R, args.opt_eval_t, init_c2ws)
        if args.multi_gpu:
            eval_pose_param_net = torch.nn.DataParallel(eval_pose_param_net).to(device=my_devices)
        else:
            eval_pose_param_net = eval_pose_param_net.to(device=my_devices)

    '''Set Optimiser'''
    optimizer_eval_pose = torch.optim.Adam(eval_pose_param_net.parameters(), lr=args.opt_eval_lr)
    scheduler_eval_pose = torch.optim.lr_scheduler.MultiStepLR(optimizer_eval_pose,
                                                               milestones=args.eval_pose_milestones,
                                                               gamma=args.eval_pose_lr_gamma)

    '''Optimise eval poses'''
    if args.type_to_eval != 'train':
        for epoch_i in tqdm(range(args.opt_pose_epoch), desc='optimising eval'):
            mean_losses = opt_eval_pose_one_epoch(model, focal_net, eval_pose_param_net, scene_eval, optimizer_eval_pose,
                                                  my_devices)
            opt_L2_loss = mean_losses['L2']
            opt_pose_psnr = mse2psnr(opt_L2_loss)
            scheduler_eval_pose.step()

            writer.add_scalar('opt/mse', opt_L2_loss, epoch_i)
            writer.add_scalar('opt/psnr', opt_pose_psnr, epoch_i)

            logger.info('{0:6d} ep: Opt: L2 loss: {1:.4f}, PSNR: {2:.3f}'.format(epoch_i, opt_L2_loss, opt_pose_psnr))
            tqdm.write('{0:6d} ep: Opt: L2 loss: {1:.4f}, PSNR: {2:.3f}'.format(epoch_i, opt_L2_loss, opt_pose_psnr))

    with torch.no_grad():
        '''Compute ATE'''
        stats_tran, stats_rot, stats_scale = eval_one_epoch_traj(scene_train, learned_pose_param_net)
        print('------------------ ATE statistic ------------------')
        print('Traj Err: translation: {0:.6f}, rotation: {1:.2f} deg, scale: {2:.2f}'.format(stats_tran['mean'],
                                                                                             stats_rot['mean'],
                                                                                             stats_scale['mean']))
        print('-------------------------------------------------')

        '''Final Render'''
        result = eval_one_epoch_img(scene_eval, model, focal_net, eval_pose_param_net, my_devices, args, logger)
        imgs = result['imgs']
        depths = result['depths']

        '''Write to folder'''
        imgs = (imgs.cpu().numpy() * 255).astype(np.uint8)
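        # the *10 below is presumably chosen to fit typical (non-NDC) depth
        # values into uint8 range (cf. the *200 used for NDC depths above)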
        depths = (depths.cpu().numpy() * 10).astype(np.uint8)

        for i in range(scene_eval.N_imgs):
            imageio.imwrite(os.path.join(img_out_dir, str(i).zfill(4) + '.png'), imgs[i])
            imageio.imwrite(os.path.join(depth_out_dir, str(i).zfill(4) + '.png'), depths[i])

        imageio.mimwrite(os.path.join(video_out_dir, 'img.mp4'), imgs, fps=30, quality=9)
        imageio.mimwrite(os.path.join(video_out_dir, 'depth.mp4'), depths, fps=30, quality=9)

        imageio.mimwrite(os.path.join(video_out_dir, 'img.gif'), imgs, fps=30)
        imageio.mimwrite(os.path.join(video_out_dir, 'depth.gif'), depths, fps=30)
        return