def main():
    """Load a trained NeRF checkpoint and pass the models to ``export_ray_trace``.

    Command-line arguments:
        --config: path to the YAML experiment config (required).
        --base-dir: optional base-dir override; not read in this function,
            it is forwarded to ``export_ray_trace`` as part of the parsed args.
        --checkpoint: checkpoint file holding the model state dicts (required).
        --save-dir: optional output directory (forwarded, not read here).
        --iso-level: iso-level to be queried (default 32; forwarded).
        --cache-mesh / --no-cache-mesh: toggles ``cache_mesh`` (default True).

    The coarse model is always built; a fine model is built only when the
    config has a ``models.fine`` section. Both are put in eval mode before
    being handed to ``export_ray_trace``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True, help="Path to (.yml) config file.")
    parser.add_argument(
        "--base-dir",
        type=str,
        required=False,
        help="Override the default base dir.",
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="Checkpoint / pre-trained model to evaluate.",
    )
    parser.add_argument("--save-dir", type=str, help="Save mesh to this directory, if specified.")
    parser.add_argument("--iso-level", type=float, help="Iso-Level to be queried", default=32)
    parser.add_argument('--cache-mesh', dest='cache_mesh', action='store_true')
    parser.add_argument('--no-cache-mesh', dest='cache_mesh', action='store_false')
    parser.set_defaults(cache_mesh=True)
    config_args = parser.parse_args()

    # Read config file.
    # NOTE(review): yaml.FullLoader is not meant for untrusted input; fine for
    # local experiment configs, but do not point this at files you do not own.
    with open(config_args.config, "r") as f:
        cfg_dict = yaml.load(f, Loader=yaml.FullLoader)
        cfg = CfgNode(cfg_dict)

    # Device on which to run.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    encode_position_fn = get_embedding_function(
        num_encoding_functions=cfg.models.coarse.num_encoding_fn_xyz,
        include_input=cfg.models.coarse.include_input_xyz,
        log_sampling=cfg.models.coarse.log_sampling_xyz,
    )

    encode_direction_fn = None
    if cfg.models.coarse.use_viewdirs:
        encode_direction_fn = get_embedding_function(
            num_encoding_functions=cfg.models.coarse.num_encoding_fn_dir,
            include_input=cfg.models.coarse.include_input_dir,
            log_sampling=cfg.models.coarse.log_sampling_dir,
        )

    # Initialize a coarse resolution model.
    model_coarse = getattr(models, cfg.models.coarse.type)(
        num_encoding_fn_xyz=cfg.models.coarse.num_encoding_fn_xyz,
        num_encoding_fn_dir=cfg.models.coarse.num_encoding_fn_dir,
        include_input_xyz=cfg.models.coarse.include_input_xyz,
        include_input_dir=cfg.models.coarse.include_input_dir,
        use_viewdirs=cfg.models.coarse.use_viewdirs,
    )
    model_coarse.to(device)

    # If a fine-resolution model is specified, initialize it.
    model_fine = None
    if hasattr(cfg.models, "fine"):
        model_fine = getattr(models, cfg.models.fine.type)(
            num_encoding_fn_xyz=cfg.models.fine.num_encoding_fn_xyz,
            num_encoding_fn_dir=cfg.models.fine.num_encoding_fn_dir,
            include_input_xyz=cfg.models.fine.include_input_xyz,
            include_input_dir=cfg.models.fine.include_input_dir,
            use_viewdirs=cfg.models.fine.use_viewdirs,
        )
        model_fine.to(device)

    # map_location lets a checkpoint saved on a GPU be restored on a CPU-only
    # machine (and vice versa) instead of failing during deserialization.
    checkpoint = torch.load(config_args.checkpoint, map_location=device)
    model_coarse.load_state_dict(checkpoint["model_coarse_state_dict"])

    fine_state = checkpoint["model_fine_state_dict"]
    if fine_state:
        if model_fine is None:
            print("The checkpoint has a fine-level model, but the config "
                  "does not define one; skipping it.")
        else:
            # Loading the fine weights stays best-effort, but the reason for
            # a failure is reported and KeyboardInterrupt is not swallowed.
            try:
                model_fine.load_state_dict(fine_state)
            except Exception as err:
                print("The checkpoint has a fine-level model, but it could "
                      "not be loaded (possibly due to a mismatched config "
                      f"file): {err}")

    model_coarse.eval()
    if model_fine is not None:
        model_fine.eval()

    export_ray_trace(model_coarse, model_fine, config_args, cfg,
                     encode_position_fn, encode_direction_fn, device)
def main():
    """Render every pose in ``render_poses`` with a trained NeRF checkpoint.

    Command-line arguments:
        --config: path to the YAML experiment config (required).
        --checkpoint: checkpoint file holding the model state dicts (required).
        --savedir: optional directory; when given, one ``NNNN.png`` is written
            per rendered pose.
        --save-disparity-image: also write disparity images to
            ``<savedir>/disparity`` (only takes effect when --savedir is set).

    Raises:
        ValueError: if ``cfg.dataset.type`` is neither "blender" nor "llff".
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True, help="Path to (.yml) config file.")
    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="Checkpoint / pre-trained model to evaluate.",
    )
    parser.add_argument("--savedir", type=str, help="Save images to this directory, if specified.")
    parser.add_argument("--save-disparity-image", action="store_true",
                        help="Save disparity images too.")
    configargs = parser.parse_args()

    # Read config file.
    # NOTE(review): yaml.FullLoader is not meant for untrusted input.
    with open(configargs.config, "r") as f:
        cfg_dict = yaml.load(f, Loader=yaml.FullLoader)
        cfg = CfgNode(cfg_dict)

    dataset_type = cfg.dataset.type.lower()
    if dataset_type == "blender":
        # Load blender dataset; only the render poses and intrinsics are used.
        _, _, render_poses, hwf, _ = load_blender_data(
            cfg.dataset.basedir,
            half_res=cfg.dataset.half_res,
            testskip=cfg.dataset.testskip,
        )
        H, W, focal = hwf
        # Keep hwf as a mutable list with integer H/W (same as the llff
        # branch) so the checkpoint overrides below can assign into it.
        hwf = [int(H), int(W), focal]
    elif dataset_type == "llff":
        # Load LLFF dataset; intrinsics are read from the last pose column.
        _, poses, _, render_poses, _ = load_llff_data(
            cfg.dataset.basedir, factor=cfg.dataset.downsample_factor,
        )
        H, W, focal = poses[0, :3, -1]
        hwf = [int(H), int(W), focal]
        render_poses = torch.from_numpy(render_poses)
    else:
        # Fail early with a clear message instead of a later error on None.
        raise ValueError(f"Unsupported dataset type: {cfg.dataset.type!r}")

    # Device on which to run.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    encode_position_fn = get_embedding_function(
        num_encoding_functions=cfg.models.coarse.num_encoding_fn_xyz,
        include_input=cfg.models.coarse.include_input_xyz,
        log_sampling=cfg.models.coarse.log_sampling_xyz,
    )

    encode_direction_fn = None
    if cfg.models.coarse.use_viewdirs:
        encode_direction_fn = get_embedding_function(
            num_encoding_functions=cfg.models.coarse.num_encoding_fn_dir,
            include_input=cfg.models.coarse.include_input_dir,
            log_sampling=cfg.models.coarse.log_sampling_dir,
        )

    # Initialize a coarse resolution model.
    model_coarse = getattr(models, cfg.models.coarse.type)(
        num_encoding_fn_xyz=cfg.models.coarse.num_encoding_fn_xyz,
        num_encoding_fn_dir=cfg.models.coarse.num_encoding_fn_dir,
        include_input_xyz=cfg.models.coarse.include_input_xyz,
        include_input_dir=cfg.models.coarse.include_input_dir,
        use_viewdirs=cfg.models.coarse.use_viewdirs,
    )
    model_coarse.to(device)

    # If a fine-resolution model is specified, initialize it.
    model_fine = None
    if hasattr(cfg.models, "fine"):
        model_fine = getattr(models, cfg.models.fine.type)(
            num_encoding_fn_xyz=cfg.models.fine.num_encoding_fn_xyz,
            num_encoding_fn_dir=cfg.models.fine.num_encoding_fn_dir,
            include_input_xyz=cfg.models.fine.include_input_xyz,
            include_input_dir=cfg.models.fine.include_input_dir,
            use_viewdirs=cfg.models.fine.use_viewdirs,
        )
        model_fine.to(device)

    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    checkpoint = torch.load(configargs.checkpoint, map_location=device)
    model_coarse.load_state_dict(checkpoint["model_coarse_state_dict"])

    fine_state = checkpoint["model_fine_state_dict"]
    if fine_state:
        if model_fine is None:
            print("The checkpoint has a fine-level model, but the config "
                  "does not define one; skipping it.")
        else:
            # Best-effort load, but report why it failed.
            try:
                model_fine.load_state_dict(fine_state)
            except Exception as err:
                print("The checkpoint has a fine-level model, but it could "
                      "not be loaded (possibly due to a mismatched config "
                      f"file): {err}")

    # Intrinsics stored in the checkpoint take precedence over the dataset's.
    if "height" in checkpoint:
        hwf[0] = checkpoint["height"]
    if "width" in checkpoint:
        hwf[1] = checkpoint["width"]
    if "focal_length" in checkpoint:
        hwf[2] = checkpoint["focal_length"]

    model_coarse.eval()
    if model_fine is not None:
        model_fine.eval()

    render_poses = render_poses.float().to(device)

    # Create directories to save images to. --savedir is optional, so only
    # touch the filesystem when it was actually provided.
    save_disparity = bool(configargs.savedir) and configargs.save_disparity_image
    if configargs.savedir:
        os.makedirs(configargs.savedir, exist_ok=True)
        if save_disparity:
            os.makedirs(os.path.join(configargs.savedir, "disparity"), exist_ok=True)

    # Evaluation loop
    times_per_image = []
    for i, pose in enumerate(tqdm(render_poses)):
        start = time.time()
        rgb, disp = None, None
        with torch.no_grad():
            pose = pose[:3, :4]
            ray_origins, ray_directions = get_ray_bundle(hwf[0], hwf[1], hwf[2], pose)
            rgb_coarse, disp_coarse, _, rgb_fine, disp_fine, _ = run_one_iter_of_nerf(
                hwf[0],
                hwf[1],
                hwf[2],
                model_coarse,
                model_fine,
                ray_origins,
                ray_directions,
                cfg,
                mode="validation",
                encode_position_fn=encode_position_fn,
                encode_direction_fn=encode_direction_fn,
            )
            # Prefer the fine network's output whenever it is available.
            rgb = rgb_fine if rgb_fine is not None else rgb_coarse
            if configargs.save_disparity_image:
                disp = disp_fine if disp_fine is not None else disp_coarse
        times_per_image.append(time.time() - start)

        if configargs.savedir:
            savefile = os.path.join(configargs.savedir, f"{i:04d}.png")
            imageio.imwrite(
                savefile, cast_to_image(rgb[..., :3], cfg.dataset.type.lower()))
            if save_disparity:
                savefile = os.path.join(configargs.savedir, "disparity", f"{i:04d}.png")
                imageio.imwrite(savefile, cast_to_disparity_image(disp))
        tqdm.write(f"Avg time per image: {sum(times_per_image) / (i + 1)}")
def main():
    """Train the coarse (and optional fine) NeRF model described by a config.

    Command-line arguments:
        --config: path to the YAML experiment config (required).
        --load-checkpoint: optional checkpoint path to resume from.

    Data comes either from a pre-cached ray dataset (when
    ``cfg.dataset.cachedir`` exists on disk) or from the blender / llff
    loaders. Scalars and validation images go to a TensorBoard
    ``SummaryWriter`` in ``<logdir>/<experiment id>``, where checkpoints are
    also written every ``cfg.experiment.save_every`` iterations and on the
    final iteration.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True, help="Path to (.yml) config file.")
    parser.add_argument(
        "--load-checkpoint",
        type=str,
        default="",
        help="Path to load saved checkpoint from.",
    )
    configargs = parser.parse_args()

    # Read config file.
    # NOTE(review): yaml.FullLoader is not meant for untrusted input.
    with open(configargs.config, "r") as f:
        cfg_dict = yaml.load(f, Loader=yaml.FullLoader)
        cfg = CfgNode(cfg_dict)

    # # (Optional:) enable this to track autograd issues when debugging
    # torch.autograd.set_detect_anomaly(True)

    # If a pre-cached dataset is available, skip the dataloader.
    USE_CACHED_DATASET = False
    train_paths, validation_paths = None, None
    images, poses, render_poses, hwf, i_split = None, None, None, None, None
    H, W, focal, i_train, i_val, i_test = None, None, None, None, None, None
    if hasattr(cfg.dataset, "cachedir") and os.path.exists(cfg.dataset.cachedir):
        train_paths = glob.glob(os.path.join(cfg.dataset.cachedir, "train", "*.data"))
        validation_paths = glob.glob(os.path.join(cfg.dataset.cachedir, "val", "*.data"))
        USE_CACHED_DATASET = True
    else:
        # Load dataset
        if cfg.dataset.type.lower() == "blender":
            images, poses, render_poses, hwf, i_split = load_blender_data(
                cfg.dataset.basedir,
                half_res=cfg.dataset.half_res,
                testskip=cfg.dataset.testskip,
            )
            i_train, i_val, i_test = i_split
            H, W, focal = hwf
            H, W = int(H), int(W)
            hwf = [H, W, focal]
            if cfg.nerf.train.white_background:
                # Composite RGB over white using the last channel as alpha.
                images = images[..., :3] * images[..., -1:] + (1. - images[..., -1:])
        elif cfg.dataset.type.lower() == "llff":
            images, poses, bds, render_poses, i_test = load_llff_data(
                cfg.dataset.basedir, factor=cfg.dataset.downsample_factor)
            hwf = poses[0, :3, -1]
            poses = poses[:, :3, :4]
            if not isinstance(i_test, list):
                i_test = [i_test]
            if cfg.dataset.llffhold > 0:
                # Hold out every llffhold-th image for testing/validation.
                i_test = np.arange(images.shape[0])[::cfg.dataset.llffhold]
            i_val = i_test
            i_train = np.array([
                i for i in np.arange(images.shape[0])
                if (i not in i_test and i not in i_val)
            ])
            H, W, focal = hwf
            H, W = int(H), int(W)
            hwf = [H, W, focal]
            images = torch.from_numpy(images)
            poses = torch.from_numpy(poses)

    # Seed experiment for repeatability
    seed = cfg.experiment.randomseed
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Device on which to run.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    encode_position_fn = get_embedding_function(
        num_encoding_functions=cfg.models.coarse.num_encoding_fn_xyz,
        include_input=cfg.models.coarse.include_input_xyz,
        log_sampling=cfg.models.coarse.log_sampling_xyz,
    )

    encode_direction_fn = None
    if cfg.models.coarse.use_viewdirs:
        encode_direction_fn = get_embedding_function(
            num_encoding_functions=cfg.models.coarse.num_encoding_fn_dir,
            include_input=cfg.models.coarse.include_input_dir,
            log_sampling=cfg.models.coarse.log_sampling_dir,
        )

    # Initialize a coarse-resolution model.
    model_coarse = getattr(models, cfg.models.coarse.type)(
        num_encoding_fn_xyz=cfg.models.coarse.num_encoding_fn_xyz,
        num_encoding_fn_dir=cfg.models.coarse.num_encoding_fn_dir,
        include_input_xyz=cfg.models.coarse.include_input_xyz,
        include_input_dir=cfg.models.coarse.include_input_dir,
        use_viewdirs=cfg.models.coarse.use_viewdirs,
    )
    model_coarse.to(device)

    # If a fine-resolution model is specified, initialize it.
    model_fine = None
    if hasattr(cfg.models, "fine"):
        model_fine = getattr(models, cfg.models.fine.type)(
            num_encoding_fn_xyz=cfg.models.fine.num_encoding_fn_xyz,
            num_encoding_fn_dir=cfg.models.fine.num_encoding_fn_dir,
            include_input_xyz=cfg.models.fine.include_input_xyz,
            include_input_dir=cfg.models.fine.include_input_dir,
            use_viewdirs=cfg.models.fine.use_viewdirs,
        )
        model_fine.to(device)

    # Initialize optimizer.
    trainable_parameters = list(model_coarse.parameters())
    if model_fine is not None:
        trainable_parameters += list(model_fine.parameters())
    optimizer = getattr(torch.optim, cfg.optimizer.type)(trainable_parameters,
                                                         lr=cfg.optimizer.lr)

    # Setup logging.
    logdir = os.path.join(cfg.experiment.logdir, cfg.experiment.id)
    os.makedirs(logdir, exist_ok=True)
    writer = SummaryWriter(logdir)
    # Write out config parameters.
    with open(os.path.join(logdir, "config.yml"), "w") as f:
        f.write(cfg.dump())  # cfg, f, default_flow_style=False)

    # By default, start at iteration 0 (unless a checkpoint is specified).
    start_iter = 0

    # Load an existing checkpoint, if a path is specified. The default ""
    # does not exist, so nothing is loaded unless the flag was given.
    if os.path.exists(configargs.load_checkpoint):
        # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
        checkpoint = torch.load(configargs.load_checkpoint, map_location=device)
        model_coarse.load_state_dict(checkpoint["model_coarse_state_dict"])
        if checkpoint["model_fine_state_dict"]:
            model_fine.load_state_dict(checkpoint["model_fine_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        start_iter = checkpoint["iter"]

    # # TODO: Prepare raybatch tensor if batching random rays

    # Loop-invariant: number of iterations over which lr decays by one factor.
    num_decay_steps = cfg.scheduler.lr_decay * 1000
    last_iter = cfg.experiment.train_iters - 1

    for i in trange(start_iter, cfg.experiment.train_iters):

        # Both networks must be back in training mode after a validation pass.
        model_coarse.train()
        if model_fine is not None:
            model_fine.train()

        rgb_coarse, rgb_fine = None, None
        target_ray_values = None
        if USE_CACHED_DATASET:
            datafile = np.random.choice(train_paths)
            cache_dict = torch.load(datafile)
            ray_bundle = cache_dict["ray_bundle"].to(device)
            ray_origins, ray_directions = (
                ray_bundle[0].reshape((-1, 3)),
                ray_bundle[1].reshape((-1, 3)),
            )
            target_ray_values = cache_dict["target"][..., :3].reshape((-1, 3))
            # Sample a random subset of the cached rays for this step.
            select_inds = np.random.choice(
                ray_origins.shape[0],
                size=(cfg.nerf.train.num_random_rays),
                replace=False,
            )
            ray_origins, ray_directions = (
                ray_origins[select_inds],
                ray_directions[select_inds],
            )
            target_ray_values = target_ray_values[select_inds].to(device)
            # ray_bundle = torch.stack([ray_origins, ray_directions], dim=0).to(device)

            rgb_coarse, _, _, rgb_fine, _, _ = run_one_iter_of_nerf(
                cache_dict["height"],
                cache_dict["width"],
                cache_dict["focal_length"],
                model_coarse,
                model_fine,
                ray_origins,
                ray_directions,
                cfg,
                mode="train",
                encode_position_fn=encode_position_fn,
                encode_direction_fn=encode_direction_fn,
            )
        else:
            img_idx = np.random.choice(i_train)
            img_target = images[img_idx].to(device)
            pose_target = poses[img_idx, :3, :4].to(device)
            ray_origins, ray_directions = get_ray_bundle(H, W, focal, pose_target)
            coords = torch.stack(
                meshgrid_xy(torch.arange(H).to(device), torch.arange(W).to(device)),
                dim=-1,
            )
            coords = coords.reshape((-1, 2))
            # Pick random pixel coordinates and gather the matching rays and
            # target colours.
            select_inds = np.random.choice(
                coords.shape[0], size=(cfg.nerf.train.num_random_rays), replace=False)
            select_inds = coords[select_inds]
            ray_origins = ray_origins[select_inds[:, 0], select_inds[:, 1], :]
            ray_directions = ray_directions[select_inds[:, 0], select_inds[:, 1], :]
            # batch_rays = torch.stack([ray_origins, ray_directions], dim=0)
            target_s = img_target[select_inds[:, 0], select_inds[:, 1], :]

            rgb_coarse, _, _, rgb_fine, _, _ = run_one_iter_of_nerf(
                H,
                W,
                focal,
                model_coarse,
                model_fine,
                ray_origins,
                ray_directions,
                cfg,
                mode="train",
                encode_position_fn=encode_position_fn,
                encode_direction_fn=encode_direction_fn,
            )
            target_ray_values = target_s

        coarse_loss = torch.nn.functional.mse_loss(rgb_coarse[..., :3],
                                                   target_ray_values[..., :3])
        fine_loss = None
        if rgb_fine is not None:
            fine_loss = torch.nn.functional.mse_loss(
                rgb_fine[..., :3], target_ray_values[..., :3])
        # loss = torch.nn.functional.mse_loss(rgb_pred[..., :3], target_s[..., :3])
        loss = coarse_loss + (fine_loss if fine_loss is not None else 0.0)
        loss.backward()
        psnr = mse2psnr(loss.item())
        optimizer.step()
        optimizer.zero_grad()

        # Learning rate updates (exponential decay applied to every group).
        lr_new = cfg.optimizer.lr * (cfg.scheduler.lr_decay_factor
                                     ** (i / num_decay_steps))
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr_new

        if i % cfg.experiment.print_every == 0 or i == last_iter:
            tqdm.write("[TRAIN] Iter: " + str(i) + " Loss: " + str(loss.item()) +
                       " PSNR: " + str(psnr))
        writer.add_scalar("train/loss", loss.item(), i)
        writer.add_scalar("train/coarse_loss", coarse_loss.item(), i)
        if rgb_fine is not None:
            writer.add_scalar("train/fine_loss", fine_loss.item(), i)
        writer.add_scalar("train/psnr", psnr, i)

        # Validation
        if i % cfg.experiment.validate_every == 0 or i == last_iter:
            tqdm.write("[VAL] =======> Iter: " + str(i))
            model_coarse.eval()
            if model_fine is not None:
                model_fine.eval()

            start = time.time()
            with torch.no_grad():
                rgb_coarse, rgb_fine = None, None
                target_ray_values = None
                if USE_CACHED_DATASET:
                    datafile = np.random.choice(validation_paths)
                    cache_dict = torch.load(datafile)
                    rgb_coarse, _, _, rgb_fine, _, _ = run_one_iter_of_nerf(
                        cache_dict["height"],
                        cache_dict["width"],
                        cache_dict["focal_length"],
                        model_coarse,
                        model_fine,
                        cache_dict["ray_origins"].to(device),
                        cache_dict["ray_directions"].to(device),
                        cfg,
                        mode="validation",
                        encode_position_fn=encode_position_fn,
                        encode_direction_fn=encode_direction_fn,
                    )
                    target_ray_values = cache_dict["target"].to(device)
                else:
                    img_idx = np.random.choice(i_val)
                    img_target = images[img_idx].to(device)
                    pose_target = poses[img_idx, :3, :4].to(device)
                    ray_origins, ray_directions = get_ray_bundle(
                        H, W, focal, pose_target)
                    rgb_coarse, _, _, rgb_fine, _, _ = run_one_iter_of_nerf(
                        H,
                        W,
                        focal,
                        model_coarse,
                        model_fine,
                        ray_origins,
                        ray_directions,
                        cfg,
                        mode="validation",
                        encode_position_fn=encode_position_fn,
                        encode_direction_fn=encode_direction_fn,
                    )
                    target_ray_values = img_target
                coarse_loss = img2mse(rgb_coarse[..., :3], target_ray_values[..., :3])
                fine_loss = 0.0
                if rgb_fine is not None:
                    fine_loss = img2mse(rgb_fine[..., :3], target_ray_values[..., :3])
                loss = coarse_loss + fine_loss
                psnr = mse2psnr(loss.item())
                writer.add_scalar("validation/loss", loss.item(), i)
                writer.add_scalar("validation/coarse_loss", coarse_loss.item(), i)
                # Tag spelled consistently with the other "validation/*" scalars
                # so the PSNR curve lands in the same TensorBoard group.
                writer.add_scalar("validation/psnr", psnr, i)
                writer.add_image("validation/rgb_coarse",
                                 cast_to_image(rgb_coarse[..., :3]), i)
                if rgb_fine is not None:
                    writer.add_image("validation/rgb_fine",
                                     cast_to_image(rgb_fine[..., :3]), i)
                    writer.add_scalar("validation/fine_loss", fine_loss.item(), i)
                writer.add_image(
                    "validation/img_target",
                    cast_to_image(target_ray_values[..., :3]),
                    i,
                )
                tqdm.write("Validation loss: " + str(loss.item()) +
                           " Validation PSNR: " + str(psnr) + " Time: " +
                           str(time.time() - start))

        if i % cfg.experiment.save_every == 0 or i == last_iter:
            # NOTE: "loss"/"psnr" hold the most recently computed values, which
            # are the validation ones when validation ran on this iteration.
            checkpoint_dict = {
                "iter": i,
                "model_coarse_state_dict": model_coarse.state_dict(),
                "model_fine_state_dict":
                    None if model_fine is None else model_fine.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "loss": loss,
                "psnr": psnr,
            }
            torch.save(
                checkpoint_dict,
                os.path.join(logdir, "checkpoint" + str(i).zfill(5) + ".ckpt"),
            )
            tqdm.write("================== Saved Checkpoint =================")

    print("Done!")
def main():
    """Export the trained coarse (and optional fine) NeRF model to ONNX.

    Command-line arguments:
        --config: path to the YAML experiment config (required). The ONNX
            files are written next to this file.
        --load-checkpoint: checkpoint file to restore weights from. Nothing is
            exported when it does not point to an existing file.

    Outputs ``coarse_model.onnx`` and, when a fine model is configured,
    ``fine_model.onnx``. The experiment config is also dumped to
    ``<logdir>/<experiment id>/config.yml``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config", type=str, required=True, help="Path to (.yml) config file."
    )
    parser.add_argument(
        "--load-checkpoint",
        type=str,
        default="",
        help="Path to load saved checkpoint from.",
    )
    configargs = parser.parse_args()

    # Read config file.
    # NOTE(review): yaml.FullLoader is not meant for untrusted input.
    with open(configargs.config, "r") as f:
        cfg_dict = yaml.load(f, Loader=yaml.FullLoader)
        cfg = CfgNode(cfg_dict)

    # # (Optional:) enable this to track autograd issues when debugging
    # torch.autograd.set_detect_anomaly(True)

    # If a pre-cached dataset is available, skip the dataloader.
    # NOTE(review): none of the loaded data is used by the export below; the
    # loading is kept so this script behaves like the training entry point.
    USE_CACHED_DATASET = False
    train_paths, validation_paths = None, None
    images, poses, render_poses, hwf, i_split = None, None, None, None, None
    H, W, focal, i_train, i_val, i_test = None, None, None, None, None, None
    if hasattr(cfg.dataset, "cachedir") and os.path.exists(cfg.dataset.cachedir):
        train_paths = glob.glob(os.path.join(cfg.dataset.cachedir, "train", "*.data"))
        validation_paths = glob.glob(
            os.path.join(cfg.dataset.cachedir, "val", "*.data")
        )
        USE_CACHED_DATASET = True
    else:
        # Load dataset
        if cfg.dataset.type.lower() == "blender":
            images, poses, render_poses, hwf, i_split = load_blender_data(
                cfg.dataset.basedir,
                half_res=cfg.dataset.half_res,
                testskip=cfg.dataset.testskip,
            )
            i_train, i_val, i_test = i_split
            H, W, focal = hwf
            H, W = int(H), int(W)
            hwf = [H, W, focal]
            if cfg.nerf.train.white_background:
                images = images[..., :3] * images[..., -1:] + (1.0 - images[..., -1:])
        elif cfg.dataset.type.lower() == "llff":
            images, poses, bds, render_poses, i_test = load_llff_data(
                cfg.dataset.basedir, factor=cfg.dataset.downsample_factor
            )
            hwf = poses[0, :3, -1]
            poses = poses[:, :3, :4]
            if not isinstance(i_test, list):
                i_test = [i_test]
            if cfg.dataset.llffhold > 0:
                i_test = np.arange(images.shape[0])[:: cfg.dataset.llffhold]
            i_val = i_test
            i_train = np.array(
                [
                    i
                    for i in np.arange(images.shape[0])
                    if (i not in i_test and i not in i_val)
                ]
            )
            H, W, focal = hwf
            H, W = int(H), int(W)
            hwf = [H, W, focal]
            images = torch.from_numpy(images)
            poses = torch.from_numpy(poses)

    # Seed experiment for repeatability
    seed = cfg.experiment.randomseed
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Device on which to run.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    encode_position_fn = get_embedding_function(
        num_encoding_functions=cfg.models.coarse.num_encoding_fn_xyz,
        include_input=cfg.models.coarse.include_input_xyz,
        log_sampling=cfg.models.coarse.log_sampling_xyz,
    )

    encode_direction_fn = None
    if cfg.models.coarse.use_viewdirs:
        encode_direction_fn = get_embedding_function(
            num_encoding_functions=cfg.models.coarse.num_encoding_fn_dir,
            include_input=cfg.models.coarse.include_input_dir,
            log_sampling=cfg.models.coarse.log_sampling_dir,
        )

    # Initialize a coarse-resolution model.
    model_coarse = getattr(models, cfg.models.coarse.type)(
        num_encoding_fn_xyz=cfg.models.coarse.num_encoding_fn_xyz,
        num_encoding_fn_dir=cfg.models.coarse.num_encoding_fn_dir,
        include_input_xyz=cfg.models.coarse.include_input_xyz,
        include_input_dir=cfg.models.coarse.include_input_dir,
        use_viewdirs=cfg.models.coarse.use_viewdirs,
    )
    model_coarse.to(device)

    # If a fine-resolution model is specified, initialize it.
    model_fine = None
    if hasattr(cfg.models, "fine"):
        model_fine = getattr(models, cfg.models.fine.type)(
            num_encoding_fn_xyz=cfg.models.fine.num_encoding_fn_xyz,
            num_encoding_fn_dir=cfg.models.fine.num_encoding_fn_dir,
            include_input_xyz=cfg.models.fine.include_input_xyz,
            include_input_dir=cfg.models.fine.include_input_dir,
            use_viewdirs=cfg.models.fine.use_viewdirs,
        )
        model_fine.to(device)

    # Initialize optimizer.
    trainable_parameters = list(model_coarse.parameters())
    if model_fine is not None:
        trainable_parameters += list(model_fine.parameters())
    optimizer = getattr(torch.optim, cfg.optimizer.type)(
        trainable_parameters, lr=cfg.optimizer.lr
    )

    # Setup logging.
    logdir = os.path.join(cfg.experiment.logdir, cfg.experiment.id)
    os.makedirs(logdir, exist_ok=True)
    # Write out config parameters.
    with open(os.path.join(logdir, "config.yml"), "w") as f:
        f.write(cfg.dump())  # cfg, f, default_flow_style=False)

    # Load an existing checkpoint, if a path is specified.
    # isfile (not exists) is required here: abspath("") is the current
    # directory, which always exists, so the default empty path used to fall
    # through to torch.load("").
    checkpoint_path = os.path.abspath(configargs.load_checkpoint)
    if os.path.isfile(checkpoint_path):
        # Restore onto the device the models already live on, so the export
        # also works on machines without a GPU.
        checkpoint = torch.load(configargs.load_checkpoint, map_location=device)
        model_coarse.load_state_dict(checkpoint["model_coarse_state_dict"])
        if checkpoint["model_fine_state_dict"] and model_fine is not None:
            model_fine.load_state_dict(checkpoint["model_fine_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

        # Width of the encoded position / direction vectors.
        # NOTE(review): assumes the raw 3-vector is included in both encodings
        # (include_input_* true) and that the model takes them concatenated
        # as a single input -- confirm against the model definition.
        dim_xyz = 3 + 2 * 3 * cfg.models.coarse.num_encoding_fn_xyz
        dim_dir = 3 + 2 * 3 * cfg.models.coarse.num_encoding_fn_dir
        dummy_input = torch.zeros((1, dim_dir + dim_xyz), dtype=torch.float).to(device)

        out_folder, _ = ntpath.split(configargs.config)
        torch.onnx.export(
            model_coarse,
            dummy_input,
            os.path.join(out_folder, "coarse_model.onnx"),
            verbose=False,
            input_names=["input"],
        )
        # The fine model is optional; only export it when it was configured.
        if model_fine is not None:
            torch.onnx.export(
                model_fine,
                dummy_input,
                os.path.join(out_folder, "fine_model.onnx"),
                input_names=["input"],
            )
    else:
        print("Couldn't find the checkpoint file at {}".format(checkpoint_path))

    print("Done!")
def main(): parser = argparse.ArgumentParser() parser.add_argument("--config", type=str, required=True, help="Path to (.yml) config file.") parser.add_argument( "--load-checkpoint", type=str, default="", help="Path to load saved checkpoint from.", ) configargs = parser.parse_args() # Read config file. cfg = None with open(configargs.config, "r") as f: cfg_dict = yaml.load(f, Loader=yaml.FullLoader) cfg = CfgNode(cfg_dict) # # (Optional:) enable this to track autograd issues when debugging # torch.autograd.set_detect_anomaly(True) # If a pre-cached dataset is available, skip the dataloader. USE_CACHED_DATASET = False train_paths, validation_paths = None, None images, poses, render_poses, hwf, i_split, expressions = None, None, None, None, None, None H, W, focal, i_train, i_val, i_test = None, None, None, None, None, None if hasattr(cfg.dataset, "cachedir") and os.path.exists( cfg.dataset.cachedir): train_paths = glob.glob( os.path.join(cfg.dataset.cachedir, "train", "*.data")) validation_paths = glob.glob( os.path.join(cfg.dataset.cachedir, "val", "*.data")) USE_CACHED_DATASET = True else: # Load dataset images, poses, render_poses, hwf, expressions = None, None, None, None, None if cfg.dataset.type.lower() == "blender": images, poses, render_poses, hwf, i_split, expressions, _, bboxs = load_flame_data( cfg.dataset.basedir, half_res=cfg.dataset.half_res, testskip=cfg.dataset.testskip, ) i_train, i_val, i_test = i_split H, W, focal = hwf H, W = int(H), int(W) hwf = [H, W, focal] if cfg.nerf.train.white_background: images = images[..., :3] * images[..., -1:] + ( 1.0 - images[..., -1:]) print("done loading data") # Seed experiment for repeatability seed = cfg.experiment.randomseed np.random.seed(seed) torch.manual_seed(seed) # Device on which to run. 
if torch.cuda.is_available(): device = "cuda" #+ ":" + str(cfg.experiment.device) else: device = "cpu" encode_position_fn = get_embedding_function( num_encoding_functions=cfg.models.coarse.num_encoding_fn_xyz, include_input=cfg.models.coarse.include_input_xyz, log_sampling=cfg.models.coarse.log_sampling_xyz, ) encode_direction_fn = None if cfg.models.coarse.use_viewdirs: encode_direction_fn = get_embedding_function( num_encoding_functions=cfg.models.coarse.num_encoding_fn_dir, include_input=cfg.models.coarse.include_input_dir, log_sampling=cfg.models.coarse.log_sampling_dir, ) # Initialize a coarse-resolution model. model_coarse = getattr(models, cfg.models.coarse.type)( num_encoding_fn_xyz=cfg.models.coarse.num_encoding_fn_xyz, num_encoding_fn_dir=cfg.models.coarse.num_encoding_fn_dir, include_input_xyz=cfg.models.coarse.include_input_xyz, include_input_dir=cfg.models.coarse.include_input_dir, use_viewdirs=cfg.models.coarse.use_viewdirs, num_layers=cfg.models.coarse.num_layers, hidden_size=cfg.models.coarse.hidden_size, include_expression=True) model_coarse.to(device) # If a fine-resolution model is specified, initialize it. 
model_fine = None if hasattr(cfg.models, "fine"): model_fine = getattr(models, cfg.models.fine.type)( num_encoding_fn_xyz=cfg.models.fine.num_encoding_fn_xyz, num_encoding_fn_dir=cfg.models.fine.num_encoding_fn_dir, include_input_xyz=cfg.models.fine.include_input_xyz, include_input_dir=cfg.models.fine.include_input_dir, use_viewdirs=cfg.models.fine.use_viewdirs, num_layers=cfg.models.coarse.num_layers, hidden_size=cfg.models.coarse.hidden_size, include_expression=True) model_fine.to(device) ################################### ################################### train_background = False supervised_train_background = False blur_background = False train_latent_codes = True disable_expressions = False # True to disable expressions disable_latent_codes = False # True to disable latent codes fixed_background = True # Do False to disable BG regularize_latent_codes = True # True to add latent code LOSS, false for most experiments ################################### ################################### supervised_train_background = train_background and supervised_train_background # Avg background #images[i_train] if train_background: with torch.no_grad(): avg_img = torch.mean(images[i_train], axis=0) # Blur Background: if blur_background: avg_img = avg_img.permute(2, 0, 1) avg_img = avg_img.unsqueeze(0) smoother = GaussianSmoothing(channels=3, kernel_size=11, sigma=11) print("smoothed background initialization. 
shape ", avg_img.shape) avg_img = smoother(avg_img).squeeze(0).permute(1, 2, 0) #avg_img = torch.zeros(H,W,3) #avg_img = torch.rand(H,W,3) #avg_img = 0.5*(torch.rand(H,W,3) + torch.mean(images[i_train],axis=0)) background = torch.tensor(avg_img, device=device) background.requires_grad = True if fixed_background: # load GT background print("loading GT background to condition on") from PIL import Image background = Image.open( os.path.join(cfg.dataset.basedir, 'bg', '00050.png')) background.thumbnail((H, W)) background = torch.from_numpy(np.array(background).astype( np.float32)).to(device) background = background / 255 print("bg shape", background.shape) print("should be ", images[i_train][0].shape) assert background.shape == images[i_train][0].shape else: background = None # Initialize optimizer. trainable_parameters = list(model_coarse.parameters()) if model_fine is not None: trainable_parameters += list(model_fine.parameters()) if train_background: #background.requires_grad = True #trainable_parameters.append(background) # add it later when init optimizer for different lr print("background.is_leaf ", background.is_leaf, background.device) if train_latent_codes: latent_codes = torch.zeros(len(i_train), 32, device=device) print("initialized latent codes with shape %d X %d" % (latent_codes.shape[0], latent_codes.shape[1])) if not disable_latent_codes: trainable_parameters.append(latent_codes) latent_codes.requires_grad = True if train_background: optimizer = getattr(torch.optim, cfg.optimizer.type)([{ 'params': trainable_parameters }, { 'params': background, 'lr': cfg.optimizer.lr }], lr=cfg.optimizer.lr) else: optimizer = getattr(torch.optim, cfg.optimizer.type)( [{ 'params': trainable_parameters }, { 'params': background, 'lr': cfg.optimizer.lr }], # this is obsolete but need for continuing training lr=cfg.optimizer.lr) # Setup logging. 
logdir = os.path.join(cfg.experiment.logdir, cfg.experiment.id) os.makedirs(logdir, exist_ok=True) writer = SummaryWriter(logdir) # Write out config parameters. with open(os.path.join(logdir, "config.yml"), "w") as f: f.write(cfg.dump()) # cfg, f, default_flow_style=False) # By default, start at iteration 0 (unless a checkpoint is specified). start_iter = 0 # Load an existing checkpoint, if a path is specified. if os.path.exists(configargs.load_checkpoint): checkpoint = torch.load(configargs.load_checkpoint) model_coarse.load_state_dict(checkpoint["model_coarse_state_dict"]) if checkpoint["model_fine_state_dict"]: model_fine.load_state_dict(checkpoint["model_fine_state_dict"]) if checkpoint["background"] is not None: print("loaded bg from checkpoint") background = torch.nn.Parameter( checkpoint['background'].to(device)) if checkpoint["latent_codes"] is not None: print("loaded latent codes from checkpoint") latent_codes = torch.nn.Parameter( checkpoint['latent_codes'].to(device)) optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) start_iter = checkpoint["iter"] # # TODO: Prepare raybatch tensor if batching random rays # Prepare importance sampling maps ray_importance_sampling_maps = [] p = 0.9 print("computing boundix boxes probability maps") for i in i_train: bbox = bboxs[i] probs = np.zeros((H, W)) probs.fill(1 - p) probs[bbox[0]:bbox[1], bbox[2]:bbox[3]] = p probs = (1 / probs.sum()) * probs ray_importance_sampling_maps.append(probs.reshape(-1)) print("Starting loop") for i in trange(start_iter, cfg.experiment.train_iters): model_coarse.train() if model_fine: model_coarse.train() rgb_coarse, rgb_fine = None, None target_ray_values = None background_ray_values = None if USE_CACHED_DATASET: datafile = np.random.choice(train_paths) cache_dict = torch.load(datafile) ray_bundle = cache_dict["ray_bundle"].to(device) ray_origins, ray_directions = ( ray_bundle[0].reshape((-1, 3)), ray_bundle[1].reshape((-1, 3)), ) target_ray_values = cache_dict["target"][..., 
:3].reshape((-1, 3)) select_inds = np.random.choice( ray_origins.shape[0], size=(cfg.nerf.train.num_random_rays), replace=False, ) ray_origins, ray_directions = ( ray_origins[select_inds], ray_directions[select_inds], ) target_ray_values = target_ray_values[select_inds].to(device) #target_ray_values = target_ray_values[select_inds].to(device) # ray_bundle = torch.stack([ray_origins, ray_directions], dim=0).to(device) rgb_coarse, _, _, rgb_fine, _, _ = run_one_iter_of_nerf( cache_dict["height"], cache_dict["width"], cache_dict["focal_length"], model_coarse, model_fine, ray_origins, ray_directions, cfg, mode="train", encode_position_fn=encode_position_fn, encode_direction_fn=encode_direction_fn, expressions=expressions) else: img_idx = np.random.choice(i_train) img_target = images[img_idx].to(device) pose_target = poses[img_idx, :3, :4].to(device) if not disable_expressions: expression_target = expressions[img_idx].to(device) # vector else: # zero expr expression_target = torch.zeros(76, device=device) #bbox = bboxs[img_idx] if not disable_latent_codes: latent_code = latent_codes[img_idx].to( device) if train_latent_codes else None else: latent_codes = torch.zeros(32, device=device) #latent_code = torch.zeros(32).to(device) ray_origins, ray_directions = get_ray_bundle( H, W, focal, pose_target) coords = torch.stack( meshgrid_xy( torch.arange(H).to(device), torch.arange(W).to(device)), dim=-1, ) # Only randomly choose rays that are in the bounding box ! 
# coords = torch.stack( # meshgrid_xy(torch.arange(bbox[0],bbox[1]).to(device), torch.arange(bbox[2],bbox[3]).to(device)), # dim=-1, # ) coords = coords.reshape((-1, 2)) # select_inds = np.random.choice( # coords.shape[0], size=(cfg.nerf.train.num_random_rays), replace=False # ) # Use importance sampling to sample mainly in the bbox with prob p select_inds = np.random.choice( coords.shape[0], size=(cfg.nerf.train.num_random_rays), replace=False, p=ray_importance_sampling_maps[img_idx]) select_inds = coords[select_inds] ray_origins = ray_origins[select_inds[:, 0], select_inds[:, 1], :] ray_directions = ray_directions[select_inds[:, 0], select_inds[:, 1], :] #dump_rays(ray_origins, ray_directions) # batch_rays = torch.stack([ray_origins, ray_directions], dim=0) target_s = img_target[select_inds[:, 0], select_inds[:, 1], :] background_ray_values = background[select_inds[:, 0], select_inds[:, 1], :] if ( train_background or fixed_background) else None #if i<10000: # background_ray_values = None #background_ray_values = None then = time.time() rgb_coarse, _, _, rgb_fine, _, _, weights = run_one_iter_of_nerf( H, W, focal, model_coarse, model_fine, ray_origins, ray_directions, cfg, mode="train", encode_position_fn=encode_position_fn, encode_direction_fn=encode_direction_fn, expressions=expression_target, background_prior=background_ray_values, latent_code=latent_code if not disable_latent_codes else torch.zeros(32, device=device)) target_ray_values = target_s coarse_loss = torch.nn.functional.mse_loss(rgb_coarse[..., :3], target_ray_values[..., :3]) fine_loss = None if rgb_fine is not None: fine_loss = torch.nn.functional.mse_loss( rgb_fine[..., :3], target_ray_values[..., :3]) # loss = torch.nn.functional.mse_loss(rgb_pred[..., :3], target_s[..., :3]) loss = 0.0 # if fine_loss is not None: # loss = fine_loss # else: # loss = coarse_loss latent_code_loss = torch.zeros(1, device=device) if train_latent_codes and not disable_latent_codes: latent_code_loss = 
torch.norm(latent_code) * 0.0005 #latent_code_loss = torch.zeros(1) background_loss = torch.zeros(1, device=device) if supervised_train_background: background_loss = torch.nn.functional.mse_loss( background_ray_values[..., :3], target_ray_values[..., :3], reduction='none').sum(1) background_loss = torch.mean(background_loss * weights) * 0.001 loss = coarse_loss + (fine_loss if fine_loss is not None else 0.0) psnr = mse2psnr(loss.item()) #loss_total = loss #+ (latent_code_loss if latent_code_loss is not None else 0.0) loss = loss + (latent_code_loss * 10 if regularize_latent_codes else 0.0) loss_total = loss + (background_loss if supervised_train_background is not None else 0.0) #loss.backward() loss_total.backward() #psnr = mse2psnr(loss.item()) optimizer.step() optimizer.zero_grad() # Learning rate updates num_decay_steps = cfg.scheduler.lr_decay * 1000 lr_new = cfg.optimizer.lr * (cfg.scheduler.lr_decay_factor **(i / num_decay_steps)) for param_group in optimizer.param_groups: param_group["lr"] = lr_new if i % cfg.experiment.print_every == 0 or i == cfg.experiment.train_iters - 1: tqdm.write("[TRAIN] Iter: " + str(i) + " Loss: " + str(loss.item()) + " BG Loss: " + str(background_loss.item()) + " PSNR: " + str(psnr) + " LatentReg: " + str(latent_code_loss.item())) #writer.add_scalar("train/loss", loss.item(), i) if train_latent_codes: writer.add_scalar("train/code_loss", latent_code_loss.item(), i) if supervised_train_background: writer.add_scalar("train/bg_loss", background_loss.item(), i) writer.add_scalar("train/coarse_loss", coarse_loss.item(), i) if rgb_fine is not None: writer.add_scalar("train/fine_loss", fine_loss.item(), i) writer.add_scalar("train/psnr", psnr, i) # Validation if (i % cfg.experiment.validate_every == 0 or i == cfg.experiment.train_iters - 1 and False): #torch.cuda.empty_cache() tqdm.write("[VAL] =======> Iter: " + str(i)) model_coarse.eval() if model_fine: model_coarse.eval() start = time.time() with torch.no_grad(): rgb_coarse, rgb_fine 
= None, None target_ray_values = None if USE_CACHED_DATASET: datafile = np.random.choice(validation_paths) cache_dict = torch.load(datafile) rgb_coarse, _, _, rgb_fine, _, weights = run_one_iter_of_nerf( cache_dict["height"], cache_dict["width"], cache_dict["focal_length"], model_coarse, model_fine, cache_dict["ray_origins"].to(device), cache_dict["ray_directions"].to(device), cfg, mode="validation", encode_position_fn=encode_position_fn, encode_direction_fn=encode_direction_fn, expressions=expression_target, latent_code=torch.zeros(32, device=device)) target_ray_values = cache_dict["target"].to(device) else: # Do all validation set... loss = 0 for img_idx in i_val[:2]: img_target = images[img_idx].to(device) #tqdm.set_description('val im %d' % img_idx) #tqdm.refresh() # to show immediately the update # # save val image for debug ### DEBUG #### # #GT = target_ray_values[..., :3] # import PIL.Image # #img = GT.permute(2, 0, 1) # # Conver to PIL Image and then np.array (output shape: (H, W, 3)) # #im_numpy = img_target.detach().cpu().numpy() # #im_numpy = np.array(torchvision.transforms.ToPILImage()(img_target.detach().cpu())) # # # im = PIL.Image.fromarray(im_numpy) # im = img_target # im = im.permute(2, 0, 1) # img = np.array(torchvision.transforms.ToPILImage()(im.detach().cpu())) # im = PIL.Image.fromarray(img) # im.save('val_im_target_debug.png') # ### DEBUG #### END pose_target = poses[img_idx, :3, :4].to(device) ray_origins, ray_directions = get_ray_bundle( H, W, focal, pose_target) rgb_coarse, _, _, rgb_fine, _, _, weights = run_one_iter_of_nerf( H, W, focal, model_coarse, model_fine, ray_origins, ray_directions, cfg, mode="validation", encode_position_fn=encode_position_fn, encode_direction_fn=encode_direction_fn, expressions=expression_target, background_prior=background.view(-1, 3) if (train_background or fixed_background) else None, latent_code=torch.zeros(32).to(device) if train_latent_codes or disable_latent_codes else None, ) #print("did one val") 
target_ray_values = img_target coarse_loss = img2mse(rgb_coarse[..., :3], target_ray_values[..., :3]) curr_loss, curr_fine_loss = 0.0, 0.0 if rgb_fine is not None: curr_fine_loss = img2mse( rgb_fine[..., :3], target_ray_values[..., :3]) curr_loss = curr_fine_loss else: curr_loss = coarse_loss loss += curr_loss + curr_fine_loss loss /= len(i_val) psnr = mse2psnr(loss.item()) writer.add_scalar("validation/loss", loss.item(), i) writer.add_scalar("validation/coarse_loss", coarse_loss.item(), i) writer.add_scalar("validation/psnr", psnr, i) writer.add_image("validation/rgb_coarse", cast_to_image(rgb_coarse[..., :3]), i) if rgb_fine is not None: writer.add_image("validation/rgb_fine", cast_to_image(rgb_fine[..., :3]), i) writer.add_scalar("validation/fine_loss", fine_loss.item(), i) writer.add_image( "validation/img_target", cast_to_image(target_ray_values[..., :3]), i, ) if train_background or fixed_background: writer.add_image("validation/background", cast_to_image(background[..., :3]), i) writer.add_image("validation/weights", (weights.detach().cpu().numpy()), i, dataformats='HW') tqdm.write("Validation loss: " + str(loss.item()) + " Validation PSNR: " + str(psnr) + " Time: " + str(time.time() - start)) #gpu_profile(frame=sys._getframe(), event='line', arg=None) if i % cfg.experiment.save_every == 0 or i == cfg.experiment.train_iters - 1: checkpoint_dict = { "iter": i, "model_coarse_state_dict": model_coarse.state_dict(), "model_fine_state_dict": None if not model_fine else model_fine.state_dict(), "optimizer_state_dict": optimizer.state_dict(), "loss": loss, "psnr": psnr, "background": None if not (train_background or fixed_background) else background.data, "latent_codes": None if not train_latent_codes else latent_codes.data } torch.save( checkpoint_dict, os.path.join(logdir, "checkpoint" + str(i).zfill(5) + ".ckpt"), ) tqdm.write("================== Saved Checkpoint =================") print("Done!")
def main():
    """Render the test split of a trained expression-conditioned NeRF.

    Command-line arguments:
        --config: path to the YAML experiment config.
        --checkpoint: checkpoint holding the model weights (and optionally
            ``background`` / ``latent_codes`` / ``height`` / ``width`` /
            ``focal_length`` entries).
        --savedir: output directory for the rendered RGB frames.
        --save-disparity-image: also write disparity images.
        --save-error-image: also write photometric error visualisations.

    For every test expression one frame is rendered and written to
    ``savedir``; a normal map is always written to ``savedir/normals``.

    Raises:
        ValueError: if the dataset type is unknown or the dataset does not
            provide per-frame expressions (they are required for rendering).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config",
                        type=str,
                        required=True,
                        help="Path to (.yml) config file.")
    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="Checkpoint / pre-trained model to evaluate.",
    )
    parser.add_argument("--savedir",
                        type=str,
                        default='./renders/',
                        help="Save images to this directory, if specified.")
    parser.add_argument("--save-disparity-image",
                        action="store_true",
                        help="Save disparity images too.")
    parser.add_argument("--save-error-image",
                        action="store_true",
                        help="Save photometric error visualization")
    configargs = parser.parse_args()

    # Read config file.
    # NOTE(review): yaml.load is not intended for untrusted input; only pass
    # config files you control.
    with open(configargs.config, "r") as f:
        cfg_dict = yaml.load(f, Loader=yaml.FullLoader)
        cfg = CfgNode(cfg_dict)

    images, poses, hwf = None, None, None
    i_test = None
    expressions = None
    dataset_type = cfg.dataset.type.lower()
    if dataset_type == "blender":
        # Load blender-style dataset; with test=True the split returned is
        # used directly as the test indices.
        images, poses, _, hwf, i_split, expressions, _, _ = load_flame_data(
            cfg.dataset.basedir,
            half_res=cfg.dataset.half_res,
            testskip=cfg.dataset.testskip,
            test=True)
        i_test = i_split
        H, W, focal = hwf
        H, W = int(H), int(W)
        # Keep hwf a mutable list so checkpoint overrides below can assign
        # into it.
        hwf = [H, W, focal]
    elif dataset_type == "llff":
        # Load LLFF dataset
        images, poses, _, _, i_test = load_llff_data(
            cfg.dataset.basedir,
            factor=cfg.dataset.downsample_factor,
        )
        hwf = poses[0, :3, -1]
        H, W, focal = hwf
        H, W = int(H), int(W)
        hwf = [H, W, focal]
    else:
        raise ValueError(f"Unsupported dataset type: {cfg.dataset.type!r}")

    if expressions is None:
        # The render loop below iterates over per-frame expressions; without
        # them the original code failed later with a NameError.
        raise ValueError(
            "This renderer needs per-frame expressions, which the "
            f"{cfg.dataset.type!r} loader does not provide.")

    # Device on which to run.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    encode_position_fn = get_embedding_function(
        num_encoding_functions=cfg.models.coarse.num_encoding_fn_xyz,
        include_input=cfg.models.coarse.include_input_xyz,
        log_sampling=cfg.models.coarse.log_sampling_xyz,
    )

    encode_direction_fn = None
    if cfg.models.coarse.use_viewdirs:
        encode_direction_fn = get_embedding_function(
            num_encoding_functions=cfg.models.coarse.num_encoding_fn_dir,
            include_input=cfg.models.coarse.include_input_dir,
            log_sampling=cfg.models.coarse.log_sampling_dir,
        )

    # Initialize a coarse resolution model.
    model_coarse = getattr(models, cfg.models.coarse.type)(
        num_encoding_fn_xyz=cfg.models.coarse.num_encoding_fn_xyz,
        num_encoding_fn_dir=cfg.models.coarse.num_encoding_fn_dir,
        include_input_xyz=cfg.models.coarse.include_input_xyz,
        include_input_dir=cfg.models.coarse.include_input_dir,
        use_viewdirs=cfg.models.coarse.use_viewdirs,
        num_layers=cfg.models.coarse.num_layers,
        hidden_size=cfg.models.coarse.hidden_size,
        include_expression=True)
    model_coarse.to(device)

    # If a fine-resolution model is specified, initialize it.
    model_fine = None
    if hasattr(cfg.models, "fine"):
        # NOTE(review): num_layers / hidden_size are read from the *coarse*
        # section here, as in the original - confirm this is intended.
        model_fine = getattr(models, cfg.models.fine.type)(
            num_encoding_fn_xyz=cfg.models.fine.num_encoding_fn_xyz,
            num_encoding_fn_dir=cfg.models.fine.num_encoding_fn_dir,
            include_input_xyz=cfg.models.fine.include_input_xyz,
            include_input_dir=cfg.models.fine.include_input_dir,
            use_viewdirs=cfg.models.fine.use_viewdirs,
            num_layers=cfg.models.coarse.num_layers,
            hidden_size=cfg.models.coarse.hidden_size,
            include_expression=True)
        model_fine.to(device)

    # map_location lets a checkpoint saved on GPU be restored on whatever
    # device this run uses.
    checkpoint = torch.load(configargs.checkpoint, map_location=device)
    model_coarse.load_state_dict(checkpoint["model_coarse_state_dict"])
    if checkpoint["model_fine_state_dict"]:
        try:
            model_fine.load_state_dict(checkpoint["model_fine_state_dict"])
        except Exception as err:  # best effort: keep rendering with coarse only
            print("The checkpoint has a fine-level model, but it could "
                  "not be loaded (possibly due to a mismatched config file.")
            print("Reason:", err)

    # Camera intrinsics stored in the checkpoint take precedence.
    if "height" in checkpoint:
        hwf[0] = checkpoint["height"]
    if "width" in checkpoint:
        hwf[1] = checkpoint["width"]
    if "focal_length" in checkpoint:
        hwf[2] = checkpoint["focal_length"]

    # Optional learned background. Tensor.to() is not in-place, so the
    # result has to be re-bound.
    background = checkpoint.get("background")
    if background is not None:
        print("loaded background with shape ", background.shape)
        background = background.to(device)

    # Optional per-frame latent codes plus the index map used to look them up.
    latent_codes = checkpoint.get("latent_codes")
    use_latent_code = latent_codes is not None
    idx_map = None
    if use_latent_code:
        latent_codes = latent_codes.to(device)
        print("loading index map for latent codes...")
        idx_map = np.load(cfg.dataset.basedir + "/index_map.npy").astype(int)
        print("loaded latent codes from checkpoint, with shape ",
              latent_codes.shape)

    model_coarse.eval()
    if model_fine:
        model_fine.eval()

    replace_background = True
    if replace_background:
        from PIL import Image
        background = Image.open(cfg.dataset.basedir + '/bg/00050.png')
        # PIL sizes are (width, height); the original passed (H, W).
        background.thumbnail((W, H))
        background = torch.from_numpy(
            np.array(background).astype(float)).to(device)
        background = background / 255
        print('loaded custom background of shape', background.shape)

    # Create directories to save images to.
    os.makedirs(configargs.savedir, exist_ok=True)
    if configargs.save_disparity_image:
        os.makedirs(os.path.join(configargs.savedir, "disparity"),
                    exist_ok=True)
    if configargs.save_error_image:
        os.makedirs(os.path.join(configargs.savedir, "error"), exist_ok=True)
    # Normal maps are written for every frame, so this folder always exists.
    os.makedirs(os.path.join(configargs.savedir, "normals"), exist_ok=True)

    # Evaluation inputs: poses and expressions of the test frames.
    times_per_image = []
    render_poses = poses[i_test].float().to(device)
    render_expressions = expressions[i_test].float().to(device)
    # Ground-truth frames are only needed for the error visualisation;
    # index them once instead of once per rendered frame.
    test_images = images[i_test] if configargs.save_error_image else None
    index_of_image_after_train_shuffle = 0

    ####################### hard-coded experiment switches
    no_background = False
    no_expressions = False
    no_lcode = False
    nerf = False
    frontalize = False
    interpolate_mouth = False
    ablate = 'view_dir'
    #######################

    if nerf:
        no_background = True
        no_expressions = True
        no_lcode = True
    if no_background:
        background = None
    if no_expressions:
        render_expressions = torch.zeros_like(
            render_expressions, device=render_expressions.device)
    if no_lcode:
        use_latent_code = True
        latent_codes = torch.zeros(5000, 32, device=device)

    for i, expression in enumerate(tqdm(render_expressions)):
        start = time.time()
        with torch.no_grad():
            pose = render_poses[i]

            if interpolate_mouth:
                # Fix pose/expression of frame 241 and sweep coefficient 68
                # from -1 to 1 over 150 frames.
                pose = render_poses[241]
                expression = render_expressions[241].clone()
                expression[68] = torch.arange(-1, 1, 2 / 150,
                                              device=device)[i]
            if frontalize:
                pose = render_poses[0]

            # Only set by the 'view_dir' ablation.
            # NOTE(review): assumes run_one_iter_of_nerf accepts None here -
            # confirm against its signature.
            ray_directions_ablation = None
            if ablate == 'expression':
                pose = render_poses[100]
            elif ablate == 'latent_code':
                pose = render_poses[100]
                expression = render_expressions[100]
                if idx_map is not None and idx_map[100 + i, 1] >= 0:
                    index_of_image_after_train_shuffle = idx_map[100 + i, 1]
            elif ablate == 'view_dir':
                # Fixed pose/expression; view directions come from another
                # frame's camera.
                pose = render_poses[100]
                expression = render_expressions[100]
                _, ray_directions_ablation = get_ray_bundle(
                    hwf[0], hwf[1], hwf[2], render_poses[240 + i][:3, :4])

            pose = pose[:3, :4]

            if use_latent_code and idx_map is not None:
                # The latent code is pinned to training frame 10 for every
                # rendered frame (the original's per-frame lookup was
                # immediately overwritten by this assignment).
                index_of_image_after_train_shuffle = idx_map[10, 1]

            latent_code = None
            if use_latent_code:
                latent_code = latent_codes[
                    index_of_image_after_train_shuffle].to(device)

            ray_origins, ray_directions = get_ray_bundle(
                hwf[0], hwf[1], hwf[2], pose)
            (rgb_coarse, disp_coarse, _, rgb_fine, disp_fine, _,
             weights) = run_one_iter_of_nerf(
                 hwf[0],
                 hwf[1],
                 hwf[2],
                 model_coarse,
                 model_fine,
                 ray_origins,
                 ray_directions,
                 cfg,
                 mode="validation",
                 encode_position_fn=encode_position_fn,
                 encode_direction_fn=encode_direction_fn,
                 expressions=expression,
                 background_prior=background.view(-1, 3) if
                 (background is not None) else None,
                 latent_code=latent_code,
                 ray_directions_ablation=ray_directions_ablation)

            # Prefer the fine outputs, fall back to coarse when there is no
            # fine model.
            rgb = rgb_fine if rgb_fine is not None else rgb_coarse
            disp = disp_fine if disp_fine is not None else disp_coarse

            normals = torch_normal_map(disp, focal, weights, clean=True)
            save_plt_image(
                normals.cpu().numpy().astype('uint8'),
                os.path.join(configargs.savedir, 'normals', f"{i:04d}.png"))

        times_per_image.append(time.time() - start)

        if configargs.savedir:
            savefile = os.path.join(configargs.savedir, f"{i:04d}.png")
            imageio.imwrite(
                savefile, cast_to_image(rgb[..., :3],
                                        cfg.dataset.type.lower()))
            if configargs.save_disparity_image:
                savefile = os.path.join(configargs.savedir, "disparity",
                                        f"{i:04d}.png")
                imageio.imwrite(savefile, cast_to_disparity_image(disp))
            if configargs.save_error_image:
                savefile = os.path.join(configargs.savedir, "error",
                                        f"{i:04d}.png")
                GT = test_images[i]
                error_image(GT, rgb.cpu().numpy())
                plt.savefig(savefile,
                            pad_inches=0,
                            bbox_inches='tight',
                            dpi=54)

        tqdm.write(
            f"Avg time per image: {sum(times_per_image) / len(times_per_image)}"
        )
def main():
    """Train a coarse + fine NeRF on an LLFF dataset.

    Command-line arguments:
        --config: path to the YAML experiment config.
        --load-checkpoint: optional checkpoint to resume from.

    When ``cfg.dataset.relative_lr_factor`` is present the dataset is loaded
    a second time at a lower resolution; high-resolution frames are then
    only used on iterations that are multiples of ``cfg.dataset.hr_frequency``
    and only every ``cfg.dataset.hr_fps``-th training frame is exposed at
    high resolution.

    Scalars and validation images are logged to TensorBoard and checkpoints
    are written to ``<cfg.experiment.logdir>/<cfg.experiment.id>``.

    Raises:
        ValueError: if the dataset type is not ``llff`` or the HR/LR options
            are only partially specified.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config",
                        type=str,
                        required=True,
                        help="Path to (.yml) config file.")
    parser.add_argument(
        "--load-checkpoint",
        type=str,
        default="",
        help="Path to load saved checkpoint from.",
    )
    configargs = parser.parse_args()

    # Read config file.
    with open(configargs.config, "r") as f:
        cfg_dict = yaml.load(f, Loader=yaml.FullLoader)
        cfg = CfgNode(cfg_dict)

    # # (Optional:) enable this to track autograd issues when debugging
    # torch.autograd.set_detect_anomaly(True)

    # Load dataset. Only the LLFF loader is wired up in this script; fail
    # with a clear message instead of a NameError further down.
    if cfg.dataset.type.lower() != "llff":
        raise ValueError(
            f"Unsupported dataset type: {cfg.dataset.type!r} (expected 'llff')"
        )

    images, poses, _, _, i_test = load_llff_data(
        cfg.dataset.basedir, factor=cfg.dataset.downsample_factor)
    hwf = poses[0, :3, -1]
    poses = poses[:, :3, :4]
    if not isinstance(i_test, list):
        i_test = [i_test]
    if cfg.dataset.llffhold > 0:
        i_test = np.arange(images.shape[0])[::cfg.dataset.llffhold]
    i_val = i_test
    # Validation reuses the test indices, so one held-out set is enough.
    held_out = {int(idx) for idx in np.asarray(i_test).ravel()}
    i_train = np.array(
        [i for i in np.arange(images.shape[0]) if i not in held_out])
    H, W, focal = hwf
    H, W = int(H), int(W)
    images = torch.from_numpy(images)
    poses = torch.from_numpy(poses)

    USE_HR_LR = False
    # Load LR images
    if hasattr(cfg.dataset, "relative_lr_factor"):
        for required in ("hr_fps", "hr_frequency"):
            if not hasattr(cfg.dataset, required):
                raise ValueError(
                    "cfg.dataset.relative_lr_factor requires "
                    f"cfg.dataset.{required} to be set as well")
        USE_HR_LR = True
        images_lr, poses_lr, _, _, i_test_lr = load_llff_data(
            cfg.dataset.basedir,
            factor=cfg.dataset.downsample_factor *
            cfg.dataset.relative_lr_factor)
        hwf_lr = poses_lr[0, :3, -1]
        poses_lr = poses_lr[:, :3, :4]
        if not isinstance(i_test_lr, list):
            i_test_lr = [i_test_lr]
        if cfg.dataset.llffhold > 0:
            i_test_lr = np.arange(images_lr.shape[0])[::cfg.dataset.llffhold]
        held_out_lr = {int(idx) for idx in np.asarray(i_test_lr).ravel()}
        i_train_lr = np.array([
            i for i in np.arange(images_lr.shape[0]) if i not in held_out_lr
        ])
        H_lr, W_lr, focal_lr = hwf_lr
        H_lr, W_lr = int(H_lr), int(W_lr)
        images_lr = torch.from_numpy(images_lr)
        poses_lr = torch.from_numpy(poses_lr)

        # Expose only some HR images
        i_train = i_train[::cfg.dataset.hr_fps]

        print(f'LR summary: N={i_train_lr.shape[0]}, '
              f'resolution={H_lr}x{W_lr}, focal-length={focal_lr}')
    print(f'HR summary: N={i_train.shape[0]}, '
          f'resolution={H}x{W}, focal-length={focal}')

    # Seed experiment for repeatability
    seed = cfg.experiment.randomseed
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Device on which to run (falls back to CPU when CUDA is unavailable).
    if torch.cuda.is_available():
        device = f"cuda:{cfg.experiment.gpu_id}"
    else:
        device = "cpu"

    encode_position_fn = get_embedding_function(
        num_encoding_functions=cfg.models.coarse.num_encoding_fn_xyz,
        include_input=cfg.models.coarse.include_input_xyz,
        log_sampling=cfg.models.coarse.log_sampling_xyz,
    )

    encode_direction_fn = None
    if cfg.models.coarse.use_viewdirs:
        encode_direction_fn = get_embedding_function(
            num_encoding_functions=cfg.models.coarse.num_encoding_fn_dir,
            include_input=cfg.models.coarse.include_input_dir,
            log_sampling=cfg.models.coarse.log_sampling_dir,
        )

    # Initialize a coarse and fine resolution model.
    model_coarse = getattr(models, cfg.models.coarse.type)(
        num_encoding_fn_xyz=cfg.models.coarse.num_encoding_fn_xyz,
        num_encoding_fn_dir=cfg.models.coarse.num_encoding_fn_dir,
        include_input_xyz=cfg.models.coarse.include_input_xyz,
        include_input_dir=cfg.models.coarse.include_input_dir,
        use_viewdirs=cfg.models.coarse.use_viewdirs,
    )
    model_coarse.to(device)
    model_fine = getattr(models, cfg.models.fine.type)(
        num_encoding_fn_xyz=cfg.models.fine.num_encoding_fn_xyz,
        num_encoding_fn_dir=cfg.models.fine.num_encoding_fn_dir,
        include_input_xyz=cfg.models.fine.include_input_xyz,
        include_input_dir=cfg.models.fine.include_input_dir,
        use_viewdirs=cfg.models.fine.use_viewdirs,
    )
    model_fine.to(device)

    # Initialize optimizer.
    trainable_parameters = list(model_coarse.parameters()) + list(
        model_fine.parameters())
    optimizer = getattr(torch.optim,
                        cfg.optimizer.type)(trainable_parameters,
                                            lr=cfg.optimizer.lr)

    # Setup logging.
    logdir = os.path.join(cfg.experiment.logdir, cfg.experiment.id)
    os.makedirs(logdir, exist_ok=True)
    writer = SummaryWriter(logdir)
    # Write out config parameters.
    with open(os.path.join(logdir, "config.yml"), "w") as f:
        f.write(cfg.dump())

    # By default, start at iteration 0 (unless a checkpoint is specified).
    start_iter = 0

    # Load an existing checkpoint, if a path is specified.
    if os.path.exists(configargs.load_checkpoint):
        checkpoint = torch.load(configargs.load_checkpoint,
                                map_location=device)
        model_coarse.load_state_dict(checkpoint["model_coarse_state_dict"])
        model_fine.load_state_dict(checkpoint["model_fine_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        start_iter = checkpoint["iter"]

    last_iter = cfg.experiment.train_iters - 1

    # try/finally makes sure pending TensorBoard events are flushed and the
    # writer is closed even if training is interrupted.
    try:
        for i in trange(start_iter, cfg.experiment.train_iters):

            model_coarse.train()
            model_fine.train()

            # Pick the resolution for this iteration.
            if USE_HR_LR and i % cfg.dataset.hr_frequency != 0:
                H_iter, W_iter, focal_iter = H_lr, W_lr, focal_lr
                i_train_iter = i_train_lr
                images_iter, poses_iter = images_lr, poses_lr
            else:
                H_iter, W_iter, focal_iter = H, W, focal
                i_train_iter = i_train
                images_iter, poses_iter = images, poses

            # Sample one training image and a random batch of its rays.
            img_idx = np.random.choice(i_train_iter)
            img_target = images_iter[img_idx].to(device)
            pose_target = poses_iter[img_idx, :3, :4].to(device)
            ray_origins, ray_directions = get_ray_bundle(
                H_iter, W_iter, focal_iter, pose_target)
            coords = torch.stack(
                meshgrid_xy(
                    torch.arange(H_iter).to(device),
                    torch.arange(W_iter).to(device)),
                dim=-1,
            )
            coords = coords.reshape((-1, 2))
            select_inds = np.random.choice(
                coords.shape[0],
                size=(cfg.nerf.train.num_random_rays),
                replace=False)
            select_inds = coords[select_inds]
            ray_origins = ray_origins[select_inds[:, 0], select_inds[:, 1], :]
            ray_directions = ray_directions[select_inds[:, 0],
                                            select_inds[:, 1], :]
            target_s = img_target[select_inds[:, 0], select_inds[:, 1], :]

            rgb_coarse, _, _, rgb_fine, _, _ = run_one_iter_of_nerf(
                H_iter,
                W_iter,
                focal_iter,
                model_coarse,
                model_fine,
                ray_origins,
                ray_directions,
                cfg,
                mode="train",
                encode_position_fn=encode_position_fn,
                encode_direction_fn=encode_direction_fn,
            )
            target_ray_values = target_s

            coarse_loss = torch.nn.functional.mse_loss(
                rgb_coarse[..., :3], target_ray_values[..., :3])
            fine_loss = torch.nn.functional.mse_loss(
                rgb_fine[..., :3], target_ray_values[..., :3])
            loss = coarse_loss + fine_loss
            loss.backward()
            psnr = mse2psnr(loss.item())
            optimizer.step()
            optimizer.zero_grad()

            # Learning rate updates: exponential decay with iteration count.
            num_decay_steps = cfg.scheduler.lr_decay * 1000
            lr_new = cfg.optimizer.lr * (cfg.scheduler.lr_decay_factor
                                         **(i / num_decay_steps))
            for param_group in optimizer.param_groups:
                param_group["lr"] = lr_new

            if i % cfg.experiment.print_every == 0 or i == last_iter:
                tqdm.write("[TRAIN] Iter: " + str(i) + " Loss: " +
                           str(loss.item()) + " PSNR: " + str(psnr))
            writer.add_scalar("train/loss", loss.item(), i)
            writer.add_scalar("train/coarse_loss", coarse_loss.item(), i)
            writer.add_scalar("train/fine_loss", fine_loss.item(), i)
            writer.add_scalar("train/psnr", psnr, i)

            # Validation (always at full resolution, on one random view).
            if i % cfg.experiment.validate_every == 0 or i == last_iter:
                tqdm.write("[VAL] =======> Iter: " + str(i))
                model_coarse.eval()
                model_fine.eval()

                start = time.time()
                with torch.no_grad():
                    img_idx = np.random.choice(i_val)
                    img_target = images[img_idx].to(device)
                    pose_target = poses[img_idx, :3, :4].to(device)
                    ray_origins, ray_directions = get_ray_bundle(
                        H, W, focal, pose_target)
                    rgb_coarse, _, _, rgb_fine, _, _ = run_one_iter_of_nerf(
                        H,
                        W,
                        focal,
                        model_coarse,
                        model_fine,
                        ray_origins,
                        ray_directions,
                        cfg,
                        mode="validation",
                        encode_position_fn=encode_position_fn,
                        encode_direction_fn=encode_direction_fn,
                    )
                    target_ray_values = img_target
                    coarse_loss = img2mse(rgb_coarse[..., :3],
                                          target_ray_values[..., :3])
                    fine_loss = img2mse(rgb_fine[..., :3],
                                        target_ray_values[..., :3])
                    loss = coarse_loss + fine_loss
                    psnr = mse2psnr(loss.item())
                    writer.add_scalar("validation/loss", loss.item(), i)
                    writer.add_scalar("validation/coarse_loss",
                                      coarse_loss.item(), i)
                    writer.add_scalar("validation/fine_loss",
                                      fine_loss.item(), i)
                    # Tag was misspelled "validataion/psnr", which put the
                    # curve in its own TensorBoard group.
                    writer.add_scalar("validation/psnr", psnr, i)
                    writer.add_image("validation/rgb_coarse",
                                     cast_to_image(rgb_coarse[..., :3]), i)
                    writer.add_image("validation/rgb_fine",
                                     cast_to_image(rgb_fine[..., :3]), i)
                    writer.add_image(
                        "validation/img_target",
                        cast_to_image(target_ray_values[..., :3]),
                        i,
                    )
                    tqdm.write("Validation loss: " + str(loss.item()) +
                               " Validation PSNR: " + str(psnr) + " Time: " +
                               str(time.time() - start))

            if i % cfg.experiment.save_every == 0 or i == last_iter:
                # "loss"/"psnr" hold the most recent values, i.e. validation
                # numbers when validation ran on this iteration.
                checkpoint_dict = {
                    "iter": i,
                    "model_coarse_state_dict": model_coarse.state_dict(),
                    "model_fine_state_dict": model_fine.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "loss": loss,
                    "psnr": psnr,
                }
                torch.save(
                    checkpoint_dict,
                    os.path.join(logdir,
                                 "checkpoint" + str(i).zfill(5) + ".ckpt"),
                )
                tqdm.write(
                    "================== Saved Checkpoint =================")
    finally:
        writer.close()

    print("Done!")
def main():
    """Evaluate a trained NeRF on the test views and report PSNR.

    Command-line arguments:
        --config: path to the YAML experiment config.
        --checkpoint: checkpoint holding the model weights.
        --dual-render: forwarded to ``run_one_iter_of_nerf`` as
            ``dual_render``.
        --gpu-id: CUDA device index to use when CUDA is available.

    For every test view the coarse (and, if configured, fine) render is
    compared against the ground-truth image; per-view PSNRs are printed and
    the mean over all views is printed at the end. The mean uses the fine
    PSNR, or the coarse PSNR when no fine render is produced.

    Raises:
        ValueError: if ``cfg.dataset.type`` is neither ``blender`` nor
            ``llff``.
    """
    # Config options:
    parser = argparse.ArgumentParser()
    parser.add_argument("--config",
                        type=str,
                        required=True,
                        help="Path to (.yml) config file.")
    parser.add_argument("--checkpoint",
                        type=str,
                        required=True,
                        help="Checkpoint / pre-trained model to evaluate.")
    parser.add_argument("--dual-render", action='store_true', default=False)
    parser.add_argument('--gpu-id',
                        type=int,
                        default=0,
                        help="id of the CUDA GPU to use (default: 0)")
    configargs = parser.parse_args()

    with open(configargs.config, "r") as f:
        cfg_dict = yaml.load(f, Loader=yaml.FullLoader)
        cfg = CfgNode(cfg_dict)

    # Dataset:
    dataset_type = cfg.dataset.type.lower()
    if dataset_type == "blender":
        images, poses, _, hwf, i_split = load_blender_data(
            cfg.dataset.basedir,
            half_res=cfg.dataset.half_res,
            testskip=cfg.dataset.testskip,
        )
        _, _, i_test = i_split
        H, W, focal = hwf
        H, W = int(H), int(W)
        hwf = [H, W, focal]
        if cfg.nerf.train.white_background:
            # Composite RGBA onto white using the alpha channel.
            images = images[..., :3] * images[..., -1:] + (1.0 -
                                                           images[..., -1:])
    elif dataset_type == "llff":
        images, poses, _, _, i_test = load_llff_data(
            cfg.dataset.basedir,
            factor=cfg.dataset.downsample_factor,
        )
        hwf = poses[0, :3, -1]
        H, W, focal = hwf
        H, W = int(H), int(W)
        hwf = [H, W, focal]
        images = torch.from_numpy(images)
        poses = torch.from_numpy(poses)
    else:
        raise ValueError(f"Unsupported dataset type: {cfg.dataset.type!r}")

    # Hardware
    device = "cpu"
    if torch.cuda.is_available():
        torch.cuda.set_device(configargs.gpu_id)
        device = "cuda"

    # Model
    encode_position_fn = get_embedding_function(
        num_encoding_functions=cfg.models.coarse.num_encoding_fn_xyz,
        include_input=cfg.models.coarse.include_input_xyz,
        log_sampling=cfg.models.coarse.log_sampling_xyz)
    encode_direction_fn = None
    if cfg.models.coarse.use_viewdirs:
        encode_direction_fn = get_embedding_function(
            num_encoding_functions=cfg.models.coarse.num_encoding_fn_dir,
            include_input=cfg.models.coarse.include_input_dir,
            log_sampling=cfg.models.coarse.log_sampling_dir)
    model_coarse = getattr(models, cfg.models.coarse.type)(
        num_encoding_fn_xyz=cfg.models.coarse.num_encoding_fn_xyz,
        num_encoding_fn_dir=cfg.models.coarse.num_encoding_fn_dir,
        include_input_xyz=cfg.models.coarse.include_input_xyz,
        include_input_dir=cfg.models.coarse.include_input_dir,
        use_viewdirs=cfg.models.coarse.use_viewdirs)
    model_coarse.to(device)
    model_fine = None
    if hasattr(cfg.models, "fine"):
        model_fine = getattr(models, cfg.models.fine.type)(
            num_encoding_fn_xyz=cfg.models.fine.num_encoding_fn_xyz,
            num_encoding_fn_dir=cfg.models.fine.num_encoding_fn_dir,
            include_input_xyz=cfg.models.fine.include_input_xyz,
            include_input_dir=cfg.models.fine.include_input_dir,
            use_viewdirs=cfg.models.fine.use_viewdirs)
        model_fine.to(device)

    # Load checkpoint
    checkpoint = torch.load(configargs.checkpoint, map_location=device)
    model_coarse.load_state_dict(checkpoint["model_coarse_state_dict"])
    if checkpoint["model_fine_state_dict"]:
        try:
            model_fine.load_state_dict(checkpoint["model_fine_state_dict"])
        except Exception as err:  # best effort: evaluate with coarse only
            print("The checkpoint has a fine-level model, but it could "
                  "not be loaded (possibly due to a mismatched config file.")
            print("Reason:", err)
    # NOTE: these overrides update hwf only; the evaluation loop below
    # renders with the dataset's H, W, focal so renders match the targets.
    if "height" in checkpoint:
        hwf[0] = checkpoint["height"]
    if "width" in checkpoint:
        hwf[1] = checkpoint["width"]
    if "focal_length" in checkpoint:
        hwf[2] = checkpoint["focal_length"]

    # Prepare model and data
    model_coarse.eval()
    if model_fine:
        model_fine.eval()

    print("Dual render?", configargs.dual_render)

    # Normalise the test indices to a flat list: a sequence/array of indices
    # is iterated element-wise, a single scalar index becomes a one-element
    # list. (The original wrapped index arrays as a whole, so the loop ran
    # once over the entire array.)
    try:
        test_indices = list(i_test)
    except TypeError:
        test_indices = [i_test]

    # Evaluation loop
    fine_psnrs = []
    with torch.no_grad():
        for sample_no, i in enumerate(test_indices, start=1):
            print(f"Test sample {sample_no}/{len(test_indices)}...")

            img_target = images[i].to(device)
            pose_target = poses[i, :3, :4].to(device)
            ray_origins, ray_directions = get_ray_bundle(
                H, W, focal, pose_target)
            rgb_coarse, _, _, rgb_fine, _, _ = run_one_iter_of_nerf(
                H,
                W,
                focal,
                model_coarse,
                model_fine,
                ray_origins,
                ray_directions,
                cfg,
                mode="validation",
                encode_position_fn=encode_position_fn,
                encode_direction_fn=encode_direction_fn,
                dual_render=configargs.dual_render)
            target_ray_values = img_target

            coarse_loss = img2mse(rgb_coarse[..., :3],
                                  target_ray_values[..., :3])
            psnr_coarse = mse2psnr(coarse_loss.item())
            if rgb_fine is not None:
                fine_loss = img2mse(rgb_fine[..., :3],
                                    target_ray_values[..., :3])
                psnr_fine = mse2psnr(fine_loss.item())
                loss = coarse_loss + fine_loss
            else:
                # No fine render: do not evaluate a PSNR on a zero
                # placeholder loss.
                psnr_fine = None
                loss = coarse_loss
            psnr = mse2psnr(loss.item())

            print(
                f"\t Loss at sample: {psnr} (f:{psnr_fine}, c:{psnr_coarse})")
            fine_psnrs.append(
                psnr_fine if psnr_fine is not None else psnr_coarse)

    if fine_psnrs:
        print(f"Validation PSNR: {sum(fine_psnrs) / len(fine_psnrs)}.")
    else:
        print("Validation PSNR: n/a (no test samples).")