def loss_fn(f_rng, cp, fp, image_id):
    H, W, focal = (
        intrinsics["train"].height,
        intrinsics["train"].width,
        intrinsics["train"].focal_length,
    )
    ray_origins, ray_directions, target_s = sampler(
        images["train"][image_id],
        poses["train"][image_id],
        intrinsics["train"],
        f_rng[0],
        config.dataset.sampler,
    )
    _, rendered_images = run_one_iter_of_nerf(
        H,
        W,
        focal,
        functools.partial(model_coarse.apply, cp),
        functools.partial(model_fine.apply, fp),
        ray_origins,
        ray_directions,
        config.nerf.train,
        config.nerf.model,
        config.dataset.projection,
        f_rng[1],
        False,
    )
    rgb_coarse, _, _, rgb_fine, _, _ = (
        rendered_images[..., :3],
        rendered_images[..., 3:4],
        rendered_images[..., 4:5],
        rendered_images[..., 5:8],
        rendered_images[..., 8:9],
        rendered_images[..., 9:10],
    )
    coarse_loss = jnp.mean(((target_s[..., :3] - rgb_coarse)**2.0).flatten())
    loss = coarse_loss
    # Initialize fine_loss so the Losses tuple below is always well defined,
    # even when no fine-network samples are configured (the original code
    # raised a NameError in that case).
    fine_loss = 0.0
    if config.nerf.train.num_fine > 0:
        fine_loss = jnp.mean(
            ((target_s[..., :3] - rgb_fine)**2.0).flatten())
        loss = loss + fine_loss
    return loss, Losses(coarse_loss=coarse_loss, fine_loss=fine_loss)
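# A minimal sketch of how `loss_fn` above could be driven from a jitted
# training step. The use of optax, the `tx`/`opt_state` names, and the
# Adam learning rate are assumptions for illustration, not part of the
# original code; it also assumes `images["train"]` is a JAX array so the
# traced `image_id` index works under jit.
import jax
import optax

tx = optax.adam(learning_rate=5e-4)  # assumed optimizer
# opt_state = tx.init((cp, fp))      # initialize once, outside the loop


@jax.jit
def train_step(f_rng, cp, fp, opt_state, image_id):
    # has_aux=True because loss_fn returns (loss, Losses(...)).
    (loss, losses), grads = jax.value_and_grad(
        loss_fn, argnums=(1, 2), has_aux=True)(f_rng, cp, fp, image_id)
    # grads is a (grad_cp, grad_fp) tuple matching the (cp, fp) pytree.
    updates, opt_state = tx.update(grads, opt_state)
    cp, fp = optax.apply_updates((cp, fp), updates)
    return cp, fp, opt_state, loss, losses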
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config",
                        type=str,
                        required=True,
                        help="Path to (.yml) config file.")
    parser.add_argument(
        "--load-checkpoint",
        type=str,
        default="",
        help="Path to load saved checkpoint from.",
    )
    configargs = parser.parse_args()

    # Read config file.
    with open(configargs.config, "r") as f:
        cfg_dict = yaml.load(f, Loader=yaml.FullLoader)
        cfg = CfgNode(cfg_dict)

    # # (Optional:) enable this to track autograd issues when debugging
    # torch.autograd.set_detect_anomaly(True)

    # Load dataset
    if cfg.dataset.type.lower() == "llff":
        images, poses, _, _, i_test = load_llff_data(
            cfg.dataset.basedir, factor=cfg.dataset.downsample_factor)
        hwf = poses[0, :3, -1]
        poses = poses[:, :3, :4]
        if not isinstance(i_test, list):
            i_test = [i_test]
        if cfg.dataset.llffhold > 0:
            i_test = np.arange(images.shape[0])[::cfg.dataset.llffhold]
        i_val = i_test
        i_train = np.array([
            i for i in np.arange(images.shape[0])
            if (i not in i_test and i not in i_val)
        ])
        H, W, focal = hwf
        H, W = int(H), int(W)
        images = torch.from_numpy(images)
        poses = torch.from_numpy(poses)

    USE_HR_LR = False
    # Load LR images
    if hasattr(cfg.dataset, "relative_lr_factor"):
        assert hasattr(cfg.dataset, "hr_fps")
        assert hasattr(cfg.dataset, "hr_frequency")
        USE_HR_LR = True
        images_lr, poses_lr, _, _, i_test = load_llff_data(
            cfg.dataset.basedir,
            factor=cfg.dataset.downsample_factor *
            cfg.dataset.relative_lr_factor)
        hwf = poses_lr[0, :3, -1]
        poses_lr = poses_lr[:, :3, :4]
        if not isinstance(i_test, list):
            i_test = [i_test]
        if cfg.dataset.llffhold > 0:
            i_test = np.arange(images_lr.shape[0])[::cfg.dataset.llffhold]
        i_train_lr = np.array([
            i for i in np.arange(images_lr.shape[0]) if (i not in i_test)
        ])
        H_lr, W_lr, focal_lr = hwf
        H_lr, W_lr = int(H_lr), int(W_lr)
        images_lr = torch.from_numpy(images_lr)
        poses_lr = torch.from_numpy(poses_lr)

        # Expose only some HR images
        i_train = i_train[::cfg.dataset.hr_fps]
        print(f'LR summary: N={i_train_lr.shape[0]}, '
              f'resolution={H_lr}x{W_lr}, focal-length={focal_lr}')
        print(f'HR summary: N={i_train.shape[0]}, '
              f'resolution={H}x{W}, focal-length={focal}')

    # Seed experiment for repeatability
    seed = cfg.experiment.randomseed
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Device on which to run.
    device = f"cuda:{cfg.experiment.gpu_id}"

    encode_position_fn = get_embedding_function(
        num_encoding_functions=cfg.models.coarse.num_encoding_fn_xyz,
        include_input=cfg.models.coarse.include_input_xyz,
        log_sampling=cfg.models.coarse.log_sampling_xyz,
    )

    encode_direction_fn = None
    if cfg.models.coarse.use_viewdirs:
        encode_direction_fn = get_embedding_function(
            num_encoding_functions=cfg.models.coarse.num_encoding_fn_dir,
            include_input=cfg.models.coarse.include_input_dir,
            log_sampling=cfg.models.coarse.log_sampling_dir,
        )

    # Initialize a coarse and a fine resolution model.
    model_coarse = getattr(models, cfg.models.coarse.type)(
        num_encoding_fn_xyz=cfg.models.coarse.num_encoding_fn_xyz,
        num_encoding_fn_dir=cfg.models.coarse.num_encoding_fn_dir,
        include_input_xyz=cfg.models.coarse.include_input_xyz,
        include_input_dir=cfg.models.coarse.include_input_dir,
        use_viewdirs=cfg.models.coarse.use_viewdirs,
    )
    model_coarse.to(device)
    model_fine = getattr(models, cfg.models.fine.type)(
        num_encoding_fn_xyz=cfg.models.fine.num_encoding_fn_xyz,
        num_encoding_fn_dir=cfg.models.fine.num_encoding_fn_dir,
        include_input_xyz=cfg.models.fine.include_input_xyz,
        include_input_dir=cfg.models.fine.include_input_dir,
        use_viewdirs=cfg.models.fine.use_viewdirs,
    )
    model_fine.to(device)

    # Initialize optimizer.
    trainable_parameters = list(model_coarse.parameters()) + list(
        model_fine.parameters())
    optimizer = getattr(torch.optim, cfg.optimizer.type)(trainable_parameters,
                                                         lr=cfg.optimizer.lr)

    # Setup logging.
    logdir = os.path.join(cfg.experiment.logdir, cfg.experiment.id)
    os.makedirs(logdir, exist_ok=True)
    writer = SummaryWriter(logdir)
    # Write out config parameters.
    with open(os.path.join(logdir, "config.yml"), "w") as f:
        f.write(cfg.dump())  # cfg, f, default_flow_style=False)

    # By default, start at iteration 0 (unless a checkpoint is specified).
    start_iter = 0

    # Load an existing checkpoint, if a path is specified.
    if os.path.exists(configargs.load_checkpoint):
        checkpoint = torch.load(configargs.load_checkpoint)
        model_coarse.load_state_dict(checkpoint["model_coarse_state_dict"])
        model_fine.load_state_dict(checkpoint["model_fine_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        start_iter = checkpoint["iter"]

    # TODO: Prepare raybatch tensor if batching random rays

    for i in trange(start_iter, cfg.experiment.train_iters):
        model_coarse.train()
        model_fine.train()

        # Alternate between low- and high-resolution training images.
        if USE_HR_LR and i % cfg.dataset.hr_frequency != 0:
            H_iter, W_iter, focal_iter = H_lr, W_lr, focal_lr
            i_train_iter = i_train_lr
            images_iter, poses_iter = images_lr, poses_lr
        else:
            H_iter, W_iter, focal_iter = H, W, focal
            i_train_iter = i_train
            images_iter, poses_iter = images, poses

        img_idx = np.random.choice(i_train_iter)
        img_target = images_iter[img_idx].to(device)
        pose_target = poses_iter[img_idx, :3, :4].to(device)
        ray_origins, ray_directions = get_ray_bundle(H_iter, W_iter,
                                                     focal_iter, pose_target)
        coords = torch.stack(
            meshgrid_xy(
                torch.arange(H_iter).to(device),
                torch.arange(W_iter).to(device)),
            dim=-1,
        )
        coords = coords.reshape((-1, 2))
        # Sample a random subset of pixels (rays) for this iteration.
        select_inds = np.random.choice(coords.shape[0],
                                       size=(cfg.nerf.train.num_random_rays),
                                       replace=False)
        select_inds = coords[select_inds]
        ray_origins = ray_origins[select_inds[:, 0], select_inds[:, 1], :]
        ray_directions = ray_directions[select_inds[:, 0],
                                        select_inds[:, 1], :]
        target_s = img_target[select_inds[:, 0], select_inds[:, 1], :]

        rgb_coarse, _, _, rgb_fine, _, _ = run_one_iter_of_nerf(
            H_iter,
            W_iter,
            focal_iter,
            model_coarse,
            model_fine,
            ray_origins,
            ray_directions,
            cfg,
            mode="train",
            encode_position_fn=encode_position_fn,
            encode_direction_fn=encode_direction_fn,
        )
        target_ray_values = target_s

        coarse_loss = torch.nn.functional.mse_loss(rgb_coarse[..., :3],
                                                   target_ray_values[..., :3])
        fine_loss = torch.nn.functional.mse_loss(rgb_fine[..., :3],
                                                 target_ray_values[..., :3])
        loss = coarse_loss + fine_loss
        loss.backward()
        psnr = mse2psnr(loss.item())
        optimizer.step()
        optimizer.zero_grad()

        # Learning rate updates: exponential decay over lr_decay * 1000 steps.
        num_decay_steps = cfg.scheduler.lr_decay * 1000
        lr_new = cfg.optimizer.lr * (cfg.scheduler.lr_decay_factor
                                     **(i / num_decay_steps))
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr_new

        if i % cfg.experiment.print_every == 0 or i == cfg.experiment.train_iters - 1:
            tqdm.write("[TRAIN] Iter: " + str(i) + " Loss: " +
                       str(loss.item()) + " PSNR: " + str(psnr))
        writer.add_scalar("train/loss", loss.item(), i)
        writer.add_scalar("train/coarse_loss", coarse_loss.item(), i)
        writer.add_scalar("train/fine_loss", fine_loss.item(), i)
        writer.add_scalar("train/psnr", psnr, i)

        # Validation
        if (i % cfg.experiment.validate_every == 0
                or i == cfg.experiment.train_iters - 1):
            tqdm.write("[VAL] =======> Iter: " + str(i))
            model_coarse.eval()
            model_fine.eval()

            start = time.time()
            with torch.no_grad():
                img_idx = np.random.choice(i_val)
                img_target = images[img_idx].to(device)
                pose_target = poses[img_idx, :3, :4].to(device)
                ray_origins, ray_directions = get_ray_bundle(
                    H, W, focal, pose_target)
                rgb_coarse, _, _, rgb_fine, _, _ = run_one_iter_of_nerf(
                    H,
                    W,
                    focal,
                    model_coarse,
                    model_fine,
                    ray_origins,
                    ray_directions,
                    cfg,
                    mode="validation",
                    encode_position_fn=encode_position_fn,
                    encode_direction_fn=encode_direction_fn,
                )
                target_ray_values = img_target
                coarse_loss = img2mse(rgb_coarse[..., :3],
                                      target_ray_values[..., :3])
                fine_loss = img2mse(rgb_fine[..., :3],
                                    target_ray_values[..., :3])
                loss = coarse_loss + fine_loss
                psnr = mse2psnr(loss.item())
                writer.add_scalar("validation/loss", loss.item(), i)
                writer.add_scalar("validation/coarse_loss",
                                  coarse_loss.item(), i)
                writer.add_scalar("validation/fine_loss", fine_loss.item(), i)
                writer.add_scalar("validation/psnr", psnr, i)
                writer.add_image("validation/rgb_coarse",
                                 cast_to_image(rgb_coarse[..., :3]), i)
                writer.add_image("validation/rgb_fine",
                                 cast_to_image(rgb_fine[..., :3]), i)
                writer.add_image(
                    "validation/img_target",
                    cast_to_image(target_ray_values[..., :3]),
                    i,
                )
                tqdm.write("Validation loss: " + str(loss.item()) +
                           " Validation PSNR: " + str(psnr) + " Time: " +
                           str(time.time() - start))

        if i % cfg.experiment.save_every == 0 or i == cfg.experiment.train_iters - 1:
            checkpoint_dict = {
                "iter": i,
                "model_coarse_state_dict": model_coarse.state_dict(),
                "model_fine_state_dict": model_fine.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "loss": loss,
                "psnr": psnr,
            }
            torch.save(
                checkpoint_dict,
                os.path.join(logdir, "checkpoint" + str(i).zfill(5) + ".ckpt"),
            )
            tqdm.write("================== Saved Checkpoint =================")

    print("Done!")
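# The loops above rely on the repository's `img2mse` and `mse2psnr`
# helpers, whose sources are not shown here. A minimal sketch assuming
# the standard definitions (mean-squared error, and PSNR = -10 *
# log10(MSE)); not necessarily the exact implementation used here:
import math

import torch


def img2mse(src, tgt):
    # Mean-squared photometric error between rendered and target pixels.
    return torch.nn.functional.mse_loss(src, tgt)


def mse2psnr(mse):
    # Called with a Python float above (e.g. `loss.item()`).
    return -10.0 * math.log10(mse) if mse > 0 else float("inf")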
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config",
                        type=str,
                        required=True,
                        help="Path to (.yml) config file.")
    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="Checkpoint / pre-trained model to evaluate.",
    )
    parser.add_argument("--savedir",
                        type=str,
                        default='./renders/',
                        help="Save images to this directory, if specified.")
    parser.add_argument("--save-disparity-image",
                        action="store_true",
                        help="Save disparity images too.")
    parser.add_argument("--save-error-image",
                        action="store_true",
                        help="Save photometric error visualization.")
    configargs = parser.parse_args()

    # Read config file.
    cfg = None
    with open(configargs.config, "r") as f:
        cfg_dict = yaml.load(f, Loader=yaml.FullLoader)
        cfg = CfgNode(cfg_dict)

    images, poses, render_poses, hwf = None, None, None, None
    i_train, i_val, i_test = None, None, None
    if cfg.dataset.type.lower() == "blender":
        # Load blender dataset
        images, poses, render_poses, hwf, i_split, expressions, _, _ = load_flame_data(
            cfg.dataset.basedir,
            half_res=cfg.dataset.half_res,
            testskip=cfg.dataset.testskip,
            test=True)
        # i_train, i_val, i_test = i_split
        i_test = i_split
        H, W, focal = hwf
        H, W = int(H), int(W)
    elif cfg.dataset.type.lower() == "llff":
        # Load LLFF dataset
        images, poses, bds, render_poses, i_test = load_llff_data(
            cfg.dataset.basedir,
            factor=cfg.dataset.downsample_factor,
        )
        hwf = poses[0, :3, -1]
        H, W, focal = hwf
        hwf = [int(H), int(W), focal]
        render_poses = torch.from_numpy(render_poses)

    # Device on which to run.
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"

    encode_position_fn = get_embedding_function(
        num_encoding_functions=cfg.models.coarse.num_encoding_fn_xyz,
        include_input=cfg.models.coarse.include_input_xyz,
        log_sampling=cfg.models.coarse.log_sampling_xyz,
    )

    encode_direction_fn = None
    if cfg.models.coarse.use_viewdirs:
        encode_direction_fn = get_embedding_function(
            num_encoding_functions=cfg.models.coarse.num_encoding_fn_dir,
            include_input=cfg.models.coarse.include_input_dir,
            log_sampling=cfg.models.coarse.log_sampling_dir,
        )

    # Initialize a coarse resolution model.
    model_coarse = getattr(models, cfg.models.coarse.type)(
        num_encoding_fn_xyz=cfg.models.coarse.num_encoding_fn_xyz,
        num_encoding_fn_dir=cfg.models.coarse.num_encoding_fn_dir,
        include_input_xyz=cfg.models.coarse.include_input_xyz,
        include_input_dir=cfg.models.coarse.include_input_dir,
        use_viewdirs=cfg.models.coarse.use_viewdirs,
        num_layers=cfg.models.coarse.num_layers,
        hidden_size=cfg.models.coarse.hidden_size,
        include_expression=True)
    model_coarse.to(device)

    # If a fine-resolution model is specified, initialize it.
    model_fine = None
    if hasattr(cfg.models, "fine"):
        model_fine = getattr(models, cfg.models.fine.type)(
            num_encoding_fn_xyz=cfg.models.fine.num_encoding_fn_xyz,
            num_encoding_fn_dir=cfg.models.fine.num_encoding_fn_dir,
            include_input_xyz=cfg.models.fine.include_input_xyz,
            include_input_dir=cfg.models.fine.include_input_dir,
            use_viewdirs=cfg.models.fine.use_viewdirs,
            num_layers=cfg.models.coarse.num_layers,
            hidden_size=cfg.models.coarse.hidden_size,
            include_expression=True)
        model_fine.to(device)

    checkpoint = torch.load(configargs.checkpoint)
    model_coarse.load_state_dict(checkpoint["model_coarse_state_dict"])
    if checkpoint["model_fine_state_dict"]:
        try:
            model_fine.load_state_dict(checkpoint["model_fine_state_dict"])
        except Exception:
            print("The checkpoint has a fine-level model, but it could "
                  "not be loaded (possibly due to a mismatched config file).")
    if "height" in checkpoint.keys():
        hwf[0] = checkpoint["height"]
    if "width" in checkpoint.keys():
        hwf[1] = checkpoint["width"]
    if "focal_length" in checkpoint.keys():
        hwf[2] = checkpoint["focal_length"]
    if "background" in checkpoint.keys():
        background = checkpoint["background"]
        if background is not None:
            print("loaded background with shape ", background.shape)
            # Tensor.to() is not in-place; the result must be reassigned.
            background = background.to(device)
    if "latent_codes" in checkpoint.keys():
        latent_codes = checkpoint["latent_codes"]
        use_latent_code = False
        if latent_codes is not None:
            use_latent_code = True
            latent_codes = latent_codes.to(device)
            print("loading index map for latent codes...")
            idx_map = np.load(cfg.dataset.basedir +
                              "/index_map.npy").astype(int)
            print("loaded latent codes from checkpoint, with shape ",
                  latent_codes.shape)

    model_coarse.eval()
    if model_fine:
        model_fine.eval()

    replace_background = True
    if replace_background:
        from PIL import Image
        # background = Image.open('./view.png')
        background = Image.open(cfg.dataset.basedir + '/bg/00050.png')
        # background = Image.open("./real_data/andrei_dvp/" + '/bg/00050.png')
        background.thumbnail((H, W))
        background = torch.from_numpy(
            np.array(background).astype(float)).to(device)
        background = background / 255
        print('loaded custom background of shape', background.shape)
        # background = torch.ones_like(background)
        # background.permute(2,0,1)

    render_poses = render_poses.float().to(device)

    # Create directory to save images to.
    os.makedirs(configargs.savedir, exist_ok=True)
    if configargs.save_disparity_image:
        os.makedirs(os.path.join(configargs.savedir, "disparity"),
                    exist_ok=True)
    if configargs.save_error_image:
        os.makedirs(os.path.join(configargs.savedir, "error"), exist_ok=True)
    os.makedirs(os.path.join(configargs.savedir, "normals"), exist_ok=True)

    # Evaluation loop
    times_per_image = []

    # render_poses = render_poses.float().to(device)
    render_poses = poses[i_test].float().to(device)
    # expressions = torch.arange(-6,6,0.5).float().to(device)
    render_expressions = expressions[i_test].float().to(device)
    # avg_img = torch.mean(images[i_train], axis=0)
    # avg_img = torch.ones_like(avg_img)
    # pose = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])
    # for i, pose in enumerate(tqdm(render_poses)):

    index_of_image_after_train_shuffle = 0
    # render_expressions = render_expressions[[300]]  # TODO render specific expression

    # Debug / ablation switches ####################
    no_background = False
    no_expressions = False
    no_lcode = False
    nerf = False
    frontalize = False
    interpolate_mouth = False
    ################################################
    if nerf:
        no_background = True
        no_expressions = True
        no_lcode = True
    if no_background:
        background = None
    if no_expressions:
        render_expressions = torch.zeros_like(
            render_expressions, device=render_expressions.device)
    if no_lcode:
        use_latent_code = True
        latent_codes = torch.zeros(5000, 32, device=device)

    # Only set by the view-direction ablation below; None otherwise, so the
    # renderer call does not hit an undefined name.
    ray_directions_ablation = None

    for i, expression in enumerate(tqdm(render_expressions)):
        # for i in range(75, 151):
        # if i % 25 != 0:  # TODO generate only every 25th image
        # if i != 511:
        #     continue
        start = time.time()
        rgb, disp = None, None
        with torch.no_grad():
            pose = render_poses[i]
            if interpolate_mouth:
                frame_id = 241
                num_images = 150
                pose = render_poses[241]
                expression = render_expressions[241].clone()
                expression[68] = torch.arange(-1, 1, 2 / 150, device=device)[i]

            if frontalize:
                pose = render_poses[0]
            # pose = render_poses[300]  # TODO fixes pose
            # expression = render_expressions[0]  # TODO fixes expression
            # expression = torch.zeros_like(expression).to(device)

            ablate = 'view_dir'
            if ablate == 'expression':
                pose = render_poses[100]
            elif ablate == 'latent_code':
                pose = render_poses[100]
                expression = render_expressions[100]
                if idx_map[100 + i, 1] >= 0:
                    # print("found latent code for this image")
                    index_of_image_after_train_shuffle = idx_map[100 + i, 1]
            elif ablate == 'view_dir':
                pose = render_poses[100]
                expression = render_expressions[100]
                _, ray_directions_ablation = get_ray_bundle(
                    hwf[0], hwf[1], hwf[2], render_poses[240 + i][:3, :4])

            pose = pose[:3, :4]
            # pose = torch.from_numpy(np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]))

            if use_latent_code:
                if idx_map[i, 1] >= 0:
                    # print("found latent code for this image")
                    index_of_image_after_train_shuffle = idx_map[i, 1]
            # index_of_image_after_train_shuffle = 10  # TODO fixes latent code
            # index_of_image_after_train_shuffle = idx_map[84, 1]  # TODO fixes latent code v2 for andrei
            index_of_image_after_train_shuffle = idx_map[
                10, 1]  # TODO fixes latent code - USE THIS if not ablating!
            latent_code = latent_codes[index_of_image_after_train_shuffle].to(
                device) if use_latent_code else None
            # latent_code = torch.mean(latent_codes)

            ray_origins, ray_directions = get_ray_bundle(
                hwf[0], hwf[1], hwf[2], pose)
            rgb_coarse, disp_coarse, _, rgb_fine, disp_fine, _, weights = run_one_iter_of_nerf(
                hwf[0],
                hwf[1],
                hwf[2],
                model_coarse,
                model_fine,
                ray_origins,
                ray_directions,
                cfg,
                mode="validation",
                encode_position_fn=encode_position_fn,
                encode_direction_fn=encode_direction_fn,
                expressions=expression,
                background_prior=background.view(-1, 3) if
                (background is not None) else None,
                # background_prior=torch.ones_like(background).view(-1, 3),  # White background
                latent_code=latent_code,
                ray_directions_ablation=ray_directions_ablation)
            rgb = rgb_fine if rgb_fine is not None else rgb_coarse

            normals = torch_normal_map(disp_fine, focal, weights, clean=True)
            # normals = normal_map_from_depth_map_backproject(disp_fine.cpu().numpy())
            save_plt_image(
                normals.cpu().numpy().astype('uint8'),
                os.path.join(configargs.savedir, 'normals', f"{i:04d}.png"))

            # if configargs.save_disparity_image:
            if False:
                disp = disp_fine if disp_fine is not None else disp_coarse
                normals = normal_map_from_depth_map_backproject(
                    disp_fine.cpu().numpy())
                save_plt_image(
                    normals.astype('uint8'),
                    os.path.join(configargs.savedir, 'normals',
                                 f"{i:04d}.png"))

        # if configargs.save_normal_image:
        #     normal_map_from_depth_map_backproject(disp_fine.cpu().numpy())
        # rgb[torch.where(weights > 0.25)] = 1.0
        # rgb[torch.where(weights > 0.1)] = (rgb * weights + (torch.ones_like(weights) - weights) * torch.ones_like(weights))

        times_per_image.append(time.time() - start)
        if configargs.savedir:
            savefile = os.path.join(configargs.savedir, f"{i:04d}.png")
            imageio.imwrite(
                savefile, cast_to_image(rgb[..., :3],
                                        cfg.dataset.type.lower()))
            if configargs.save_disparity_image:
                savefile = os.path.join(configargs.savedir, "disparity",
                                        f"{i:04d}.png")
                imageio.imwrite(savefile, cast_to_disparity_image(disp_fine))
            if configargs.save_error_image:
                savefile = os.path.join(configargs.savedir, "error",
                                        f"{i:04d}.png")
                GT = images[i_test][i]
                fig = error_image(GT, rgb.cpu().numpy())
                # imageio.imwrite(savefile, cast_to_disparity_image(disp))
                plt.savefig(savefile,
                            pad_inches=0,
                            bbox_inches='tight',
                            dpi=54)
        tqdm.write(f"Avg time per image: {sum(times_per_image) / (i + 1)}")
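# `torch_normal_map` and `normal_map_from_depth_map_backproject` are
# repository helpers whose sources are not shown here. A common recipe for
# normals from a depth map is to back-project pixels to camera space and
# cross the finite-difference tangents; a minimal numpy sketch under that
# assumption (illustrative, not the original implementation):
import numpy as np


def normals_from_depth(depth, focal):
    """depth: (H, W) array; returns a uint8 (H, W, 3) normal map."""
    H, W = depth.shape
    xs, ys = np.meshgrid(np.arange(W) - W / 2, np.arange(H) - H / 2)
    # Back-project every pixel to a 3D camera-space point (pinhole model).
    pts = np.stack([xs * depth / focal, ys * depth / focal, depth], axis=-1)
    # Screen-space tangent vectors via finite differences.
    dx = np.gradient(pts, axis=1)
    dy = np.gradient(pts, axis=0)
    # The normal is the (normalized) cross product of the tangents.
    n = np.cross(dx, dy)
    n /= np.linalg.norm(n, axis=-1, keepdims=True) + 1e-8
    return ((n * 0.5 + 0.5) * 255).astype(np.uint8)  # [-1, 1] -> [0, 255]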
def main():
    # Config options:
    parser = argparse.ArgumentParser()
    parser.add_argument("--config",
                        type=str,
                        required=True,
                        help="Path to (.yml) config file.")
    parser.add_argument("--checkpoint",
                        type=str,
                        required=True,
                        help="Checkpoint / pre-trained model to evaluate.")
    parser.add_argument("--dual-render", action='store_true', default=False)
    parser.add_argument('--gpu-id',
                        type=int,
                        default=0,
                        help="id of the CUDA GPU to use (default: 0)")
    configargs = parser.parse_args()

    cfg = None
    with open(configargs.config, "r") as f:
        cfg_dict = yaml.load(f, Loader=yaml.FullLoader)
        cfg = CfgNode(cfg_dict)

    # Dataset:
    images, poses, render_poses, hwf = None, None, None, None
    i_train, i_val, i_test = None, None, None
    if cfg.dataset.type.lower() == "blender":
        images, poses, render_poses, hwf, i_split = load_blender_data(
            cfg.dataset.basedir,
            half_res=cfg.dataset.half_res,
            testskip=cfg.dataset.testskip,
        )
        i_train, i_val, i_test = i_split
        H, W, focal = hwf
        H, W = int(H), int(W)
        hwf = [H, W, focal]
        if cfg.nerf.train.white_background:
            images = images[..., :3] * images[..., -1:] + (1.0 -
                                                           images[..., -1:])
    elif cfg.dataset.type.lower() == "llff":
        images, poses, bds, render_poses, i_test = load_llff_data(
            cfg.dataset.basedir,
            factor=cfg.dataset.downsample_factor,
        )
        hwf = poses[0, :3, -1]
        H, W, focal = hwf
        hwf = [int(H), int(W), focal]
        render_poses = torch.from_numpy(render_poses)
        images = torch.from_numpy(images)
        poses = torch.from_numpy(poses)

    # Hardware
    device = "cpu"
    if torch.cuda.is_available():
        torch.cuda.set_device(configargs.gpu_id)
        device = "cuda"

    # Model
    encode_position_fn = get_embedding_function(
        num_encoding_functions=cfg.models.coarse.num_encoding_fn_xyz,
        include_input=cfg.models.coarse.include_input_xyz,
        log_sampling=cfg.models.coarse.log_sampling_xyz)

    encode_direction_fn = None
    if cfg.models.coarse.use_viewdirs:
        encode_direction_fn = get_embedding_function(
            num_encoding_functions=cfg.models.coarse.num_encoding_fn_dir,
            include_input=cfg.models.coarse.include_input_dir,
            log_sampling=cfg.models.coarse.log_sampling_dir)

    model_coarse = getattr(models, cfg.models.coarse.type)(
        num_encoding_fn_xyz=cfg.models.coarse.num_encoding_fn_xyz,
        num_encoding_fn_dir=cfg.models.coarse.num_encoding_fn_dir,
        include_input_xyz=cfg.models.coarse.include_input_xyz,
        include_input_dir=cfg.models.coarse.include_input_dir,
        use_viewdirs=cfg.models.coarse.use_viewdirs)
    model_coarse.to(device)

    model_fine = None
    if hasattr(cfg.models, "fine"):
        model_fine = getattr(models, cfg.models.fine.type)(
            num_encoding_fn_xyz=cfg.models.fine.num_encoding_fn_xyz,
            num_encoding_fn_dir=cfg.models.fine.num_encoding_fn_dir,
            include_input_xyz=cfg.models.fine.include_input_xyz,
            include_input_dir=cfg.models.fine.include_input_dir,
            use_viewdirs=cfg.models.fine.use_viewdirs)
        model_fine.to(device)

    # Load checkpoint
    checkpoint = torch.load(configargs.checkpoint, map_location=device)
    model_coarse.load_state_dict(checkpoint["model_coarse_state_dict"])
    if checkpoint["model_fine_state_dict"]:
        try:
            model_fine.load_state_dict(checkpoint["model_fine_state_dict"])
        except Exception:
            print("The checkpoint has a fine-level model, but it could "
                  "not be loaded (possibly due to a mismatched config file).")
    if "height" in checkpoint.keys():
        hwf[0] = checkpoint["height"]
    if "width" in checkpoint.keys():
        hwf[1] = checkpoint["width"]
    if "focal_length" in checkpoint.keys():
        hwf[2] = checkpoint["focal_length"]

    # Prepare model and data
    model_coarse.eval()
    if model_fine:
        model_fine.eval()

    render_poses = render_poses.float().to(device)
    print("Dual render?", configargs.dual_render)

    # Evaluation loop
    with torch.no_grad():
        fine_psnrs = []
        if type(i_test) != list:
            i_test = [i_test]
        for i in i_test:
            print(f"Test sample {i + 1 - i_test[0]}/{i_test[-1] - i_test[0]}...")
            img_target = images[i].to(device)
            pose_target = poses[i, :3, :4].to(device)
            ray_origins, ray_directions = get_ray_bundle(
                H, W, focal, pose_target)
            rgb_coarse, _, _, rgb_fine, _, _ = run_one_iter_of_nerf(
                H,
                W,
                focal,
                model_coarse,
                model_fine,
                ray_origins,
                ray_directions,
                cfg,
                mode="validation",
                encode_position_fn=encode_position_fn,
                encode_direction_fn=encode_direction_fn,
                dual_render=configargs.dual_render)

            target_ray_values = img_target
            coarse_loss = img2mse(rgb_coarse[..., :3],
                                  target_ray_values[..., :3])
            loss, fine_loss = 0.0, 0.0
            if rgb_fine is not None:
                fine_loss = img2mse(rgb_fine[..., :3],
                                    target_ray_values[..., :3])
                # Report the fine loss when a fine model is present. The
                # original code unconditionally overwrote this branch with
                # `loss = coarse_loss + fine_loss`; that leftover line has
                # been removed so the branch takes effect.
                loss = fine_loss
            else:
                loss = coarse_loss
            psnr = mse2psnr(loss.item())
            psnr_coarse = mse2psnr(coarse_loss)
            psnr_fine = mse2psnr(fine_loss)
            print(f"\t Loss at sample: {psnr} (f:{psnr_fine}, c:{psnr_coarse})")
            fine_psnrs.append(psnr_fine)
    print(f"Validation PSNR: {sum(fine_psnrs) / len(fine_psnrs)}.")
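# Example invocations, assuming the training and evaluation entry points
# above live in files named train.py and eval.py (the file names and the
# config/checkpoint paths are assumptions; the flags match the argparse
# definitions above, and the checkpoint name matches the
# "checkpoint" + str(i).zfill(5) + ".ckpt" save pattern):
#
#   python train.py --config config/experiment.yml
#   python eval.py --config config/experiment.yml \
#       --checkpoint logs/experiment/checkpoint99999.ckpt \
#       --gpu-id 0 --dual-render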