def cache_nerf_dataset(args):
    """Pre-compute ray bundles for a NeRF dataset and save them with ``torch.save``.

    Training split: for each training image, ``args.num_variations`` random
    subsets of ``args.num_random_rays`` rays are saved to
    ``<savedir>/train/<img>_<variation>.data`` (both indices zero-padded to 4
    digits). When ``args.sample_all`` is set, all rays of the image are saved
    once to ``<savedir>/train/<img>.data`` instead. Each file holds a dict with
    ``height``, ``width``, ``focal_length``, ``ray_bundle`` (ray origins and
    directions stacked along dim 0) and ``target`` (the matching pixels).

    Validation split: every validation image is saved with all of its rays to
    ``<savedir>/val/<img>.data`` as a dict with ``height``, ``width``,
    ``focal_length``, ``ray_origins``, ``ray_directions`` and ``target``.

    An empty ``<savedir>/test`` directory is created as well; nothing is
    written to it here.

    Args:
        args: Parsed command-line namespace. Fields read: ``type`` ("blender"
            or "llff"), ``datapath``, ``blender_half_res``, ``blender_stride``,
            ``llff_downsample_factor``, ``llffhold``, ``savedir``,
            ``randomseed``, ``num_variations``, ``sample_all`` and
            ``num_random_rays``.

    Raises:
        ValueError: If ``args.type`` is neither "blender" nor "llff".
    """
    images, poses, render_poses, hwf = None, None, None, None
    i_train, i_val, i_test = None, None, None
    if args.type == "blender":
        images, poses, render_poses, hwf, i_split = load_blender_data(
            args.datapath,
            half_res=args.blender_half_res,
            testskip=args.blender_stride,
        )
        i_train, i_val, i_test = i_split
    elif args.type == "llff":
        images, poses, bds, render_poses, i_test = load_llff_data(
            args.datapath, factor=args.llff_downsample_factor
        )
        hwf = poses[0, :3, -1]
        poses = poses[:, :3, :4]
        if not isinstance(i_test, list):
            i_test = [i_test]
        if args.llffhold > 0:
            i_test = np.arange(images.shape[0])[:: args.llffhold]
        # The held-out (test) views double as the validation views.
        i_val = i_test
        i_train = np.array(
            [
                i
                for i in np.arange(images.shape[0])
                if i not in i_test and i not in i_val
            ]
        )
    else:
        raise ValueError(f"Unsupported dataset type: {args.type!r}")

    # Height/width are used with torch.arange and as pixel indices below, so
    # they must be Python ints (for LLFF they are read out of the pose array).
    H, W, focal = hwf
    H, W = int(H), int(W)
    hwf = [H, W, focal]

    images = torch.from_numpy(images)
    poses = torch.from_numpy(poses)

    # Device on which to run.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    for split in ("train", "val", "test"):
        os.makedirs(os.path.join(args.savedir, split), exist_ok=True)

    np.random.seed(args.randomseed)

    for img_idx in tqdm(i_train):
        # The full ray bundle and the pixel grid depend only on the image, so
        # compute them once and share them across all variations.
        img_target = images[img_idx].to(device)
        pose_target = poses[img_idx, :3, :4].to(device)
        all_origins, all_directions = get_ray_bundle(H, W, focal, pose_target)
        coords = torch.stack(
            meshgrid_xy(torch.arange(H).to(device), torch.arange(W).to(device)),
            dim=-1,
        ).reshape((-1, 2))

        for j in range(args.num_variations):
            if args.sample_all:
                ray_origins, ray_directions = all_origins, all_directions
                target_s = img_target
                file_name = str(img_idx).zfill(4) + ".data"
            else:
                picked = np.random.choice(
                    coords.shape[0], size=args.num_random_rays, replace=False
                )
                picked = coords[picked]
                rows, cols = picked[:, 0], picked[:, 1]
                ray_origins = all_origins[rows, cols, :]
                ray_directions = all_directions[rows, cols, :]
                target_s = img_target[rows, cols, :]
                # One file per variation: a shared file name would make every
                # variation overwrite the previous one.
                file_name = (
                    str(img_idx).zfill(4) + "_" + str(j).zfill(4) + ".data"
                )

            batch_rays = torch.stack([ray_origins, ray_directions], dim=0)
            cache_dict = {
                "height": H,
                "width": W,
                "focal_length": focal,
                "ray_bundle": batch_rays.detach().cpu(),
                "target": target_s.detach().cpu(),
            }
            torch.save(cache_dict, os.path.join(args.savedir, "train", file_name))

            if args.sample_all:
                # Every ray is already stored; more variations would be identical.
                break

    for img_idx in tqdm(i_val):
        img_target = images[img_idx].to(device)
        pose_target = poses[img_idx, :3, :4].to(device)
        ray_origins, ray_directions = get_ray_bundle(H, W, focal, pose_target)
        cache_dict = {
            "height": H,
            "width": W,
            "focal_length": focal,
            "ray_origins": ray_origins.detach().cpu(),
            "ray_directions": ray_directions.detach().cpu(),
            "target": img_target.detach().cpu(),
        }
        save_path = os.path.join(args.savedir, "val", str(img_idx).zfill(4) + ".data")
        torch.save(cache_dict, save_path)
def main():
    """Train a coarse (and optionally fine) NeRF model described by a YAML config.

    Command-line arguments:
        --config: path to the (.yml) experiment config (required).
        --load-checkpoint: optional checkpoint to resume training from.

    If ``cfg.dataset.cachedir`` exists, training and validation rays are read
    from the ``train/*.data`` and ``val/*.data`` files inside it. Otherwise the
    Blender or LLFF dataset at ``cfg.dataset.basedir`` is loaded and rays are
    sampled on the fly. TensorBoard scalars and images, a copy of the config
    and periodic checkpoints are written to
    ``<cfg.experiment.logdir>/<cfg.experiment.id>``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config", type=str, required=True, help="Path to (.yml) config file."
    )
    parser.add_argument(
        "--load-checkpoint",
        type=str,
        default="",
        help="Path to load saved checkpoint from.",
    )
    configargs = parser.parse_args()

    # Read config file.
    with open(configargs.config, "r") as f:
        cfg = CfgNode(yaml.load(f, Loader=yaml.FullLoader))

    # (Optional:) enable this to track autograd issues when debugging:
    # torch.autograd.set_detect_anomaly(True)

    # If a pre-cached dataset is available, skip the dataloader.
    USE_CACHED_DATASET = False
    train_paths, validation_paths = None, None
    images, poses, render_poses, hwf, i_split = None, None, None, None, None
    H, W, focal, i_train, i_val, i_test = None, None, None, None, None, None
    if hasattr(cfg.dataset, "cachedir") and os.path.exists(cfg.dataset.cachedir):
        train_paths = glob.glob(os.path.join(cfg.dataset.cachedir, "train", "*.data"))
        validation_paths = glob.glob(
            os.path.join(cfg.dataset.cachedir, "val", "*.data")
        )
        USE_CACHED_DATASET = True
    else:
        # Load dataset
        dataset_type = cfg.dataset.type.lower()
        if dataset_type == "blender":
            images, poses, render_poses, hwf, i_split = load_blender_data(
                cfg.dataset.basedir,
                half_res=cfg.dataset.half_res,
                testskip=cfg.dataset.testskip,
            )
            i_train, i_val, i_test = i_split
            H, W, focal = hwf
            H, W = int(H), int(W)
            hwf = [H, W, focal]
            if cfg.nerf.train.white_background:
                # Composite onto white, using the last channel as alpha.
                images = images[..., :3] * images[..., -1:] + (1.0 - images[..., -1:])
        elif dataset_type == "llff":
            images, poses, bds, render_poses, i_test = load_llff_data(
                cfg.dataset.basedir, factor=cfg.dataset.downsample_factor
            )
            hwf = poses[0, :3, -1]
            poses = poses[:, :3, :4]
            if not isinstance(i_test, list):
                i_test = [i_test]
            if cfg.dataset.llffhold > 0:
                i_test = np.arange(images.shape[0])[:: cfg.dataset.llffhold]
            # The held-out (test) views double as the validation views.
            i_val = i_test
            i_train = np.array(
                [
                    i
                    for i in np.arange(images.shape[0])
                    if (i not in i_test and i not in i_val)
                ]
            )
            H, W, focal = hwf
            H, W = int(H), int(W)
            hwf = [H, W, focal]
        images = torch.from_numpy(images)
        poses = torch.from_numpy(poses)

    # Seed experiment for repeatability
    seed = cfg.experiment.randomseed
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Device on which to run.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    encode_position_fn = get_embedding_function(
        num_encoding_functions=cfg.models.coarse.num_encoding_fn_xyz,
        include_input=cfg.models.coarse.include_input_xyz,
        log_sampling=cfg.models.coarse.log_sampling_xyz,
    )

    encode_direction_fn = None
    if cfg.models.coarse.use_viewdirs:
        encode_direction_fn = get_embedding_function(
            num_encoding_functions=cfg.models.coarse.num_encoding_fn_dir,
            include_input=cfg.models.coarse.include_input_dir,
            log_sampling=cfg.models.coarse.log_sampling_dir,
        )

    # Initialize a coarse-resolution model.
    model_coarse = getattr(models, cfg.models.coarse.type)(
        num_encoding_fn_xyz=cfg.models.coarse.num_encoding_fn_xyz,
        num_encoding_fn_dir=cfg.models.coarse.num_encoding_fn_dir,
        include_input_xyz=cfg.models.coarse.include_input_xyz,
        include_input_dir=cfg.models.coarse.include_input_dir,
        use_viewdirs=cfg.models.coarse.use_viewdirs,
    )
    model_coarse.to(device)
    # If a fine-resolution model is specified, initialize it.
    model_fine = None
    if hasattr(cfg.models, "fine"):
        model_fine = getattr(models, cfg.models.fine.type)(
            num_encoding_fn_xyz=cfg.models.fine.num_encoding_fn_xyz,
            num_encoding_fn_dir=cfg.models.fine.num_encoding_fn_dir,
            include_input_xyz=cfg.models.fine.include_input_xyz,
            include_input_dir=cfg.models.fine.include_input_dir,
            use_viewdirs=cfg.models.fine.use_viewdirs,
        )
        model_fine.to(device)

    # Initialize optimizer.
    trainable_parameters = list(model_coarse.parameters())
    if model_fine is not None:
        trainable_parameters += list(model_fine.parameters())
    optimizer = getattr(torch.optim, cfg.optimizer.type)(
        trainable_parameters, lr=cfg.optimizer.lr
    )

    # Setup logging.
    logdir = os.path.join(cfg.experiment.logdir, cfg.experiment.id)
    os.makedirs(logdir, exist_ok=True)
    writer = SummaryWriter(logdir)
    # Write out config parameters.
    with open(os.path.join(logdir, "config.yml"), "w") as f:
        f.write(cfg.dump())

    # By default, start at iteration 0 (unless a checkpoint is specified).
    start_iter = 0

    # Load an existing checkpoint, if a path is specified.
    if os.path.isfile(configargs.load_checkpoint):
        # map_location lets a checkpoint saved on a GPU be resumed on a CPU.
        checkpoint = torch.load(configargs.load_checkpoint, map_location=device)
        model_coarse.load_state_dict(checkpoint["model_coarse_state_dict"])
        if checkpoint.get("model_fine_state_dict"):
            if model_fine is None:
                raise RuntimeError(
                    "The checkpoint contains a fine-level model, but the config "
                    "does not define cfg.models.fine."
                )
            model_fine.load_state_dict(checkpoint["model_fine_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        start_iter = checkpoint["iter"]

    # TODO: Prepare raybatch tensor if batching random rays

    for i in trange(start_iter, cfg.experiment.train_iters):

        model_coarse.train()
        if model_fine is not None:
            model_fine.train()

        rgb_coarse, rgb_fine = None, None
        target_ray_values = None
        if USE_CACHED_DATASET:
            datafile = np.random.choice(train_paths)
            cache_dict = torch.load(datafile)
            ray_bundle = cache_dict["ray_bundle"].to(device)
            ray_origins, ray_directions = (
                ray_bundle[0].reshape((-1, 3)),
                ray_bundle[1].reshape((-1, 3)),
            )
            target_ray_values = cache_dict["target"][..., :3].reshape((-1, 3))
            select_inds = np.random.choice(
                ray_origins.shape[0],
                size=(cfg.nerf.train.num_random_rays),
                replace=False,
            )
            ray_origins, ray_directions = (
                ray_origins[select_inds],
                ray_directions[select_inds],
            )
            target_ray_values = target_ray_values[select_inds].to(device)
            rgb_coarse, _, _, rgb_fine, _, _ = run_one_iter_of_nerf(
                cache_dict["height"],
                cache_dict["width"],
                cache_dict["focal_length"],
                model_coarse,
                model_fine,
                ray_origins,
                ray_directions,
                cfg,
                mode="train",
                encode_position_fn=encode_position_fn,
                encode_direction_fn=encode_direction_fn,
            )
        else:
            img_idx = np.random.choice(i_train)
            img_target = images[img_idx].to(device)
            pose_target = poses[img_idx, :3, :4].to(device)
            ray_origins, ray_directions = get_ray_bundle(H, W, focal, pose_target)
            coords = torch.stack(
                meshgrid_xy(torch.arange(H).to(device), torch.arange(W).to(device)),
                dim=-1,
            )
            coords = coords.reshape((-1, 2))
            select_inds = np.random.choice(
                coords.shape[0], size=(cfg.nerf.train.num_random_rays), replace=False
            )
            select_inds = coords[select_inds]
            ray_origins = ray_origins[select_inds[:, 0], select_inds[:, 1], :]
            ray_directions = ray_directions[select_inds[:, 0], select_inds[:, 1], :]
            target_s = img_target[select_inds[:, 0], select_inds[:, 1], :]

            rgb_coarse, _, _, rgb_fine, _, _ = run_one_iter_of_nerf(
                H,
                W,
                focal,
                model_coarse,
                model_fine,
                ray_origins,
                ray_directions,
                cfg,
                mode="train",
                encode_position_fn=encode_position_fn,
                encode_direction_fn=encode_direction_fn,
            )
            target_ray_values = target_s

        coarse_loss = torch.nn.functional.mse_loss(
            rgb_coarse[..., :3], target_ray_values[..., :3]
        )
        fine_loss = None
        if rgb_fine is not None:
            fine_loss = torch.nn.functional.mse_loss(
                rgb_fine[..., :3], target_ray_values[..., :3]
            )
        loss = coarse_loss + (fine_loss if fine_loss is not None else 0.0)
        loss.backward()
        psnr = mse2psnr(loss.item())
        optimizer.step()
        optimizer.zero_grad()

        # Learning rate updates: exponential decay by lr_decay_factor every
        # (lr_decay * 1000) iterations.
        num_decay_steps = cfg.scheduler.lr_decay * 1000
        lr_new = cfg.optimizer.lr * (
            cfg.scheduler.lr_decay_factor ** (i / num_decay_steps)
        )
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr_new

        if i % cfg.experiment.print_every == 0 or i == cfg.experiment.train_iters - 1:
            tqdm.write(
                "[TRAIN] Iter: "
                + str(i)
                + " Loss: "
                + str(loss.item())
                + " PSNR: "
                + str(psnr)
            )
        writer.add_scalar("train/loss", loss.item(), i)
        writer.add_scalar("train/coarse_loss", coarse_loss.item(), i)
        if rgb_fine is not None:
            writer.add_scalar("train/fine_loss", fine_loss.item(), i)
        writer.add_scalar("train/psnr", psnr, i)

        # Validation
        if (
            i % cfg.experiment.validate_every == 0
            or i == cfg.experiment.train_iters - 1
        ):
            tqdm.write("[VAL] =======> Iter: " + str(i))
            model_coarse.eval()
            if model_fine is not None:
                model_fine.eval()

            start = time.time()
            with torch.no_grad():
                rgb_coarse, rgb_fine = None, None
                target_ray_values = None
                if USE_CACHED_DATASET:
                    datafile = np.random.choice(validation_paths)
                    cache_dict = torch.load(datafile)
                    rgb_coarse, _, _, rgb_fine, _, _ = run_one_iter_of_nerf(
                        cache_dict["height"],
                        cache_dict["width"],
                        cache_dict["focal_length"],
                        model_coarse,
                        model_fine,
                        cache_dict["ray_origins"].to(device),
                        cache_dict["ray_directions"].to(device),
                        cfg,
                        mode="validation",
                        encode_position_fn=encode_position_fn,
                        encode_direction_fn=encode_direction_fn,
                    )
                    target_ray_values = cache_dict["target"].to(device)
                else:
                    img_idx = np.random.choice(i_val)
                    img_target = images[img_idx].to(device)
                    pose_target = poses[img_idx, :3, :4].to(device)
                    ray_origins, ray_directions = get_ray_bundle(
                        H, W, focal, pose_target
                    )
                    rgb_coarse, _, _, rgb_fine, _, _ = run_one_iter_of_nerf(
                        H,
                        W,
                        focal,
                        model_coarse,
                        model_fine,
                        ray_origins,
                        ray_directions,
                        cfg,
                        mode="validation",
                        encode_position_fn=encode_position_fn,
                        encode_direction_fn=encode_direction_fn,
                    )
                    target_ray_values = img_target
                coarse_loss = img2mse(rgb_coarse[..., :3], target_ray_values[..., :3])
                fine_loss = 0.0
                if rgb_fine is not None:
                    fine_loss = img2mse(rgb_fine[..., :3], target_ray_values[..., :3])
                loss = coarse_loss + fine_loss
                psnr = mse2psnr(loss.item())
                writer.add_scalar("validation/loss", loss.item(), i)
                writer.add_scalar("validation/coarse_loss", coarse_loss.item(), i)
                writer.add_scalar("validation/psnr", psnr, i)
                writer.add_image(
                    "validation/rgb_coarse", cast_to_image(rgb_coarse[..., :3]), i
                )
                if rgb_fine is not None:
                    writer.add_image(
                        "validation/rgb_fine", cast_to_image(rgb_fine[..., :3]), i
                    )
                    writer.add_scalar("validation/fine_loss", fine_loss.item(), i)
                writer.add_image(
                    "validation/img_target",
                    cast_to_image(target_ray_values[..., :3]),
                    i,
                )
                tqdm.write(
                    "Validation loss: "
                    + str(loss.item())
                    + " Validation PSNR: "
                    + str(psnr)
                    + " Time: "
                    + str(time.time() - start)
                )

        if i % cfg.experiment.save_every == 0 or i == cfg.experiment.train_iters - 1:
            checkpoint_dict = {
                "iter": i,
                "model_coarse_state_dict": model_coarse.state_dict(),
                "model_fine_state_dict": None
                if model_fine is None
                else model_fine.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "loss": loss,
                "psnr": psnr,
            }
            torch.save(
                checkpoint_dict,
                os.path.join(logdir, "checkpoint" + str(i).zfill(5) + ".ckpt"),
            )
            tqdm.write("================== Saved Checkpoint =================")

    # Flush pending TensorBoard events to disk.
    writer.close()
    print("Done!")
def main():
    """Export the NeRF models restored from a training checkpoint to ONNX.

    Command-line arguments:
        --config: path to the (.yml) experiment config (required).
        --load-checkpoint: checkpoint whose weights are exported.

    Writes ``coarse_model.onnx`` and, when a fine model is configured,
    ``fine_model.onnx`` next to the config file. The config is also dumped to
    ``<cfg.experiment.logdir>/<cfg.experiment.id>/config.yml``. Only the model
    weights are needed, so the dataset and optimizer are not loaded.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config", type=str, required=True, help="Path to (.yml) config file."
    )
    parser.add_argument(
        "--load-checkpoint",
        type=str,
        default="",
        help="Path to load saved checkpoint from.",
    )
    configargs = parser.parse_args()

    # Read config file.
    with open(configargs.config, "r") as f:
        cfg = CfgNode(yaml.load(f, Loader=yaml.FullLoader))

    # Device on which to run; the models, checkpoint tensors and dummy inputs
    # must all live on this same device for export.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Initialize a coarse-resolution model.
    model_coarse = getattr(models, cfg.models.coarse.type)(
        num_encoding_fn_xyz=cfg.models.coarse.num_encoding_fn_xyz,
        num_encoding_fn_dir=cfg.models.coarse.num_encoding_fn_dir,
        include_input_xyz=cfg.models.coarse.include_input_xyz,
        include_input_dir=cfg.models.coarse.include_input_dir,
        use_viewdirs=cfg.models.coarse.use_viewdirs,
    )
    model_coarse.to(device)
    # If a fine-resolution model is specified, initialize it.
    model_fine = None
    if hasattr(cfg.models, "fine"):
        model_fine = getattr(models, cfg.models.fine.type)(
            num_encoding_fn_xyz=cfg.models.fine.num_encoding_fn_xyz,
            num_encoding_fn_dir=cfg.models.fine.num_encoding_fn_dir,
            include_input_xyz=cfg.models.fine.include_input_xyz,
            include_input_dir=cfg.models.fine.include_input_dir,
            use_viewdirs=cfg.models.fine.use_viewdirs,
        )
        model_fine.to(device)

    # Setup logging.
    logdir = os.path.join(cfg.experiment.logdir, cfg.experiment.id)
    os.makedirs(logdir, exist_ok=True)
    # Write out config parameters.
    with open(os.path.join(logdir, "config.yml"), "w") as f:
        f.write(cfg.dump())

    def _input_width(model_cfg):
        # Width of one input row: encoded xyz followed by the encoded view
        # direction. Each part is 2 * 3 * num_encoding_fn wide, plus the raw 3
        # coordinates when include_input is set.
        # NOTE(review): assumes the model consumes exactly this concatenated
        # layout; confirm against the model definitions.
        dim_xyz = (3 if model_cfg.include_input_xyz else 0) + 2 * 3 * model_cfg.num_encoding_fn_xyz
        dim_dir = (3 if model_cfg.include_input_dir else 0) + 2 * 3 * model_cfg.num_encoding_fn_dir
        return dim_xyz + dim_dir

    checkpoint_path = os.path.abspath(configargs.load_checkpoint)
    # isfile rather than exists: abspath("") is the working directory, which
    # always exists, so the default empty argument must not count as found.
    if os.path.isfile(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model_coarse.load_state_dict(checkpoint["model_coarse_state_dict"])
        if checkpoint.get("model_fine_state_dict"):
            if model_fine is None:
                print(
                    "The checkpoint has a fine-level model, but the config does "
                    "not define one; only the coarse model will be exported."
                )
            else:
                model_fine.load_state_dict(checkpoint["model_fine_state_dict"])

        out_folder, _ = ntpath.split(configargs.config)

        coarse_input = torch.zeros(
            (1, _input_width(cfg.models.coarse)), dtype=torch.float
        ).to(device)
        torch.onnx.export(
            model_coarse,
            coarse_input,
            os.path.join(out_folder, "coarse_model.onnx"),
            verbose=False,
            input_names=["input"],
        )
        if model_fine is not None:
            fine_input = torch.zeros(
                (1, _input_width(cfg.models.fine)), dtype=torch.float
            ).to(device)
            torch.onnx.export(
                model_fine,
                fine_input,
                os.path.join(out_folder, "fine_model.onnx"),
                input_names=["input"],
            )
    else:
        print("Couldn't find the checkpoint file at {}".format(checkpoint_path))

    print("Done!")
def main():
    """Render every pose in the dataset's ``render_poses`` with a trained NeRF.

    Command-line arguments:
        --config: path to the (.yml) experiment config (required).
        --checkpoint: trained checkpoint to evaluate (required).
        --savedir: if given, each rendered RGB image is written there as
            ``NNNN.png``.
        --save-disparity-image: also write disparity images to
            ``<savedir>/disparity`` (only when --savedir is given).

    Prints the running average render time per image.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config", type=str, required=True, help="Path to (.yml) config file."
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="Checkpoint / pre-trained model to evaluate.",
    )
    parser.add_argument(
        "--savedir", type=str, help="Save images to this directory, if specified."
    )
    parser.add_argument(
        "--save-disparity-image", action="store_true", help="Save disparity images too."
    )
    configargs = parser.parse_args()

    # Read config file.
    with open(configargs.config, "r") as f:
        cfg = CfgNode(yaml.load(f, Loader=yaml.FullLoader))

    images, poses, render_poses, hwf = None, None, None, None
    i_train, i_val, i_test = None, None, None
    if cfg.dataset.type.lower() == "blender":
        # Load blender dataset
        images, poses, render_poses, hwf, i_split = load_blender_data(
            cfg.dataset.basedir,
            half_res=cfg.dataset.half_res,
            testskip=cfg.dataset.testskip,
        )
        i_train, i_val, i_test = i_split
        H, W, focal = hwf
        H, W = int(H), int(W)
        # Mutable list so the checkpoint overrides below can assign into it.
        hwf = [H, W, focal]
    elif cfg.dataset.type.lower() == "llff":
        # Load LLFF dataset
        images, poses, bds, render_poses, i_test = load_llff_data(
            cfg.dataset.basedir,
            factor=cfg.dataset.downsample_factor,
        )
        hwf = poses[0, :3, -1]
        H, W, focal = hwf
        hwf = [int(H), int(W), focal]
        render_poses = torch.from_numpy(render_poses)

    # Device on which to run.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    encode_position_fn = get_embedding_function(
        num_encoding_functions=cfg.models.coarse.num_encoding_fn_xyz,
        include_input=cfg.models.coarse.include_input_xyz,
        log_sampling=cfg.models.coarse.log_sampling_xyz,
    )

    encode_direction_fn = None
    if cfg.models.coarse.use_viewdirs:
        encode_direction_fn = get_embedding_function(
            num_encoding_functions=cfg.models.coarse.num_encoding_fn_dir,
            include_input=cfg.models.coarse.include_input_dir,
            log_sampling=cfg.models.coarse.log_sampling_dir,
        )

    # Initialize a coarse resolution model.
    model_coarse = getattr(models, cfg.models.coarse.type)(
        num_encoding_fn_xyz=cfg.models.coarse.num_encoding_fn_xyz,
        num_encoding_fn_dir=cfg.models.coarse.num_encoding_fn_dir,
        include_input_xyz=cfg.models.coarse.include_input_xyz,
        include_input_dir=cfg.models.coarse.include_input_dir,
        use_viewdirs=cfg.models.coarse.use_viewdirs,
    )
    model_coarse.to(device)

    # If a fine-resolution model is specified, initialize it.
    model_fine = None
    if hasattr(cfg.models, "fine"):
        model_fine = getattr(models, cfg.models.fine.type)(
            num_encoding_fn_xyz=cfg.models.fine.num_encoding_fn_xyz,
            num_encoding_fn_dir=cfg.models.fine.num_encoding_fn_dir,
            include_input_xyz=cfg.models.fine.include_input_xyz,
            include_input_dir=cfg.models.fine.include_input_dir,
            use_viewdirs=cfg.models.fine.use_viewdirs,
        )
        model_fine.to(device)

    # map_location lets a checkpoint saved on a GPU be evaluated on a CPU.
    checkpoint = torch.load(configargs.checkpoint, map_location=device)
    model_coarse.load_state_dict(checkpoint["model_coarse_state_dict"])
    if checkpoint["model_fine_state_dict"]:
        try:
            model_fine.load_state_dict(checkpoint["model_fine_state_dict"])
        except (AttributeError, RuntimeError):
            # AttributeError: no fine model configured (model_fine is None).
            # RuntimeError: state-dict keys/shapes do not match the model.
            print(
                "The checkpoint has a fine-level model, but it could "
                "not be loaded (possibly due to a mismatched config file)."
            )
    # The checkpoint may override the image size / focal length.
    for slot, key in enumerate(("height", "width", "focal_length")):
        if key in checkpoint:
            hwf[slot] = checkpoint[key]

    model_coarse.eval()
    if model_fine is not None:
        model_fine.eval()

    render_poses = render_poses.float().to(device)

    # Create directory to save images to (only if one was requested).
    if configargs.savedir:
        os.makedirs(configargs.savedir, exist_ok=True)
        if configargs.save_disparity_image:
            os.makedirs(os.path.join(configargs.savedir, "disparity"), exist_ok=True)

    # Evaluation loop
    times_per_image = []
    for i, pose in enumerate(tqdm(render_poses)):
        start = time.time()
        rgb, disp = None, None
        with torch.no_grad():
            pose = pose[:3, :4]
            ray_origins, ray_directions = get_ray_bundle(hwf[0], hwf[1], hwf[2], pose)
            rgb_coarse, disp_coarse, _, rgb_fine, disp_fine, _ = run_one_iter_of_nerf(
                hwf[0],
                hwf[1],
                hwf[2],
                model_coarse,
                model_fine,
                ray_origins,
                ray_directions,
                cfg,
                mode="validation",
                encode_position_fn=encode_position_fn,
                encode_direction_fn=encode_direction_fn,
            )
            # Prefer the fine-level output when a fine model produced one.
            rgb = rgb_fine if rgb_fine is not None else rgb_coarse
            if configargs.save_disparity_image:
                disp = disp_fine if disp_fine is not None else disp_coarse
        times_per_image.append(time.time() - start)
        if configargs.savedir:
            savefile = os.path.join(configargs.savedir, f"{i:04d}.png")
            imageio.imwrite(
                savefile, cast_to_image(rgb[..., :3], cfg.dataset.type.lower())
            )
            if configargs.save_disparity_image:
                savefile = os.path.join(configargs.savedir, "disparity", f"{i:04d}.png")
                imageio.imwrite(savefile, cast_to_disparity_image(disp))
        tqdm.write(f"Avg time per image: {sum(times_per_image) / (i + 1)}")
def main():
    """Evaluate a trained NeRF on the dataset's test views and report PSNR.

    Command-line arguments:
        --config: path to the (.yml) experiment config (required).
        --checkpoint: trained checkpoint to evaluate (required).
        --dual-render: forwarded to ``run_one_iter_of_nerf`` as ``dual_render``.
        --gpu-id: CUDA device index to use when CUDA is available.

    For each test view, prints the PSNR of the final prediction (the fine
    output when available, otherwise the coarse one) alongside the per-level
    PSNRs. Finally prints the mean PSNR of the final predictions.
    """
    import numpy as np

    # Config options:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config", type=str, required=True, help="Path to (.yml) config file."
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="Checkpoint / pre-trained model to evaluate.",
    )
    parser.add_argument("--dual-render", action="store_true", default=False)
    parser.add_argument(
        "--gpu-id", type=int, default=0, help="id of the CUDA GPU to use (default: 0)"
    )
    configargs = parser.parse_args()

    with open(configargs.config, "r") as f:
        cfg = CfgNode(yaml.load(f, Loader=yaml.FullLoader))

    # Dataset:
    images, poses, render_poses, hwf = None, None, None, None
    i_train, i_val, i_test = None, None, None
    if cfg.dataset.type.lower() == "blender":
        images, poses, render_poses, hwf, i_split = load_blender_data(
            cfg.dataset.basedir,
            half_res=cfg.dataset.half_res,
            testskip=cfg.dataset.testskip,
        )
        i_train, i_val, i_test = i_split
        H, W, focal = hwf
        H, W = int(H), int(W)
        hwf = [H, W, focal]
        if cfg.nerf.train.white_background:
            # Composite onto white, using the last channel as alpha.
            images = images[..., :3] * images[..., -1:] + (1.0 - images[..., -1:])
    elif cfg.dataset.type.lower() == "llff":
        images, poses, bds, render_poses, i_test = load_llff_data(
            cfg.dataset.basedir,
            factor=cfg.dataset.downsample_factor,
        )
        hwf = poses[0, :3, -1]
        H, W, focal = hwf
        hwf = [int(H), int(W), focal]
        render_poses = torch.from_numpy(render_poses)
    images = torch.from_numpy(images)
    poses = torch.from_numpy(poses)

    # Hardware
    device = "cpu"
    if torch.cuda.is_available():
        torch.cuda.set_device(configargs.gpu_id)
        device = "cuda"

    # Model
    encode_position_fn = get_embedding_function(
        num_encoding_functions=cfg.models.coarse.num_encoding_fn_xyz,
        include_input=cfg.models.coarse.include_input_xyz,
        log_sampling=cfg.models.coarse.log_sampling_xyz,
    )

    encode_direction_fn = None
    if cfg.models.coarse.use_viewdirs:
        encode_direction_fn = get_embedding_function(
            num_encoding_functions=cfg.models.coarse.num_encoding_fn_dir,
            include_input=cfg.models.coarse.include_input_dir,
            log_sampling=cfg.models.coarse.log_sampling_dir,
        )

    model_coarse = getattr(models, cfg.models.coarse.type)(
        num_encoding_fn_xyz=cfg.models.coarse.num_encoding_fn_xyz,
        num_encoding_fn_dir=cfg.models.coarse.num_encoding_fn_dir,
        include_input_xyz=cfg.models.coarse.include_input_xyz,
        include_input_dir=cfg.models.coarse.include_input_dir,
        use_viewdirs=cfg.models.coarse.use_viewdirs,
    )
    model_coarse.to(device)

    model_fine = None
    if hasattr(cfg.models, "fine"):
        model_fine = getattr(models, cfg.models.fine.type)(
            num_encoding_fn_xyz=cfg.models.fine.num_encoding_fn_xyz,
            num_encoding_fn_dir=cfg.models.fine.num_encoding_fn_dir,
            include_input_xyz=cfg.models.fine.include_input_xyz,
            include_input_dir=cfg.models.fine.include_input_dir,
            use_viewdirs=cfg.models.fine.use_viewdirs,
        )
        model_fine.to(device)

    # Load checkpoint
    checkpoint = torch.load(configargs.checkpoint, map_location=device)
    model_coarse.load_state_dict(checkpoint["model_coarse_state_dict"])
    if checkpoint["model_fine_state_dict"]:
        try:
            model_fine.load_state_dict(checkpoint["model_fine_state_dict"])
        except (AttributeError, RuntimeError):
            # AttributeError: no fine model configured (model_fine is None).
            # RuntimeError: state-dict keys/shapes do not match the model.
            print(
                "The checkpoint has a fine-level model, but it could "
                "not be loaded (possibly due to a mismatched config file)."
            )
    # The checkpoint may override the image size / focal length.
    for slot, key in enumerate(("height", "width", "focal_length")):
        if key in checkpoint:
            hwf[slot] = checkpoint[key]

    # Prepare model and data
    model_coarse.eval()
    if model_fine is not None:
        model_fine.eval()

    render_poses = render_poses.float().to(device)

    print("Dual render?", configargs.dual_render)

    # i_test may be an array of indices (blender split) or a single index;
    # normalise to 1-D so each element is exactly one view.
    test_indices = np.atleast_1d(i_test)

    # Evaluation loop
    with torch.no_grad():
        final_psnrs = []
        for count, i in enumerate(test_indices):
            print(f"Test sample {count + 1}/{len(test_indices)}...")
            img_target = images[i].to(device)
            pose_target = poses[i, :3, :4].to(device)
            # Use hwf (possibly overridden by the checkpoint, ints for H/W)
            # rather than the raw dataset values.
            ray_origins, ray_directions = get_ray_bundle(
                hwf[0], hwf[1], hwf[2], pose_target
            )
            rgb_coarse, _, _, rgb_fine, _, _ = run_one_iter_of_nerf(
                hwf[0],
                hwf[1],
                hwf[2],
                model_coarse,
                model_fine,
                ray_origins,
                ray_directions,
                cfg,
                mode="validation",
                encode_position_fn=encode_position_fn,
                encode_direction_fn=encode_direction_fn,
                dual_render=configargs.dual_render,
            )
            target_ray_values = img_target
            coarse_loss = img2mse(rgb_coarse[..., :3], target_ray_values[..., :3])
            psnr_coarse = mse2psnr(coarse_loss.item())
            if rgb_fine is not None:
                fine_loss = img2mse(rgb_fine[..., :3], target_ray_values[..., :3])
                psnr_fine = mse2psnr(fine_loss.item())
                loss = fine_loss
            else:
                psnr_fine = None
                loss = coarse_loss
            # PSNR of the final prediction (fine if present, else coarse).
            psnr = mse2psnr(loss.item())
            print(f"\t Loss at sample: {psnr} (f:{psnr_fine}, c:{psnr_coarse})")
            final_psnrs.append(psnr)

    if final_psnrs:
        print(f"Validation PSNR: {sum(final_psnrs) / len(final_psnrs)}.")
    else:
        print("No test samples to evaluate.")