def evaluation(cfg, mlp, data):
    expid = cfg.experiment.id
    logdir = cfg.experiment.logdir
    save_dir = os.path.join(logdir, expid, 'rendered')
    os.makedirs(save_dir, exist_ok=True)
    with torch.no_grad():
        for i in trange(data.render_poses.shape[0]):
            pose = data.render_poses[i]
            _, ray_ori, ray_dir = get_ray_bundle(data.height, data.width,
                                                 data.focal, pose)
            img_shape = ray_ori.shape
            ray_ori = ray_ori.reshape((-1, ray_ori.shape[-1]))
            ray_dir = ray_dir.reshape((-1, ray_dir.shape[-1]))
            rgb_coarse, rgb_fine = nerf_iteration(mlp,
                                                  cfg,
                                                  torch.cat([ray_ori, ray_dir], dim=-1),
                                                  data.near,
                                                  data.far,
                                                  mode='validation')
            # Prefer the fine-network output when a fine model was run.
            if rgb_fine is not None:
                rgb = rgb_fine.reshape(img_shape)
            else:
                rgb = rgb_coarse.reshape(img_shape)
            save_file = os.path.join(save_dir, f"{i:04d}.png")
            img = np.array(to_pil_image(rgb.permute(2, 0, 1).cpu()))
            imageio.imwrite(save_file, img)
def get_valid_img(self):
    idx = np.random.choice(self.i_val)
    img = self.images[idx]
    pose = self.poses[idx]
    _, ray_ori, ray_dir = get_ray_bundle(self.height, self.width, self.focal, pose)
    ray_ori = ray_ori.reshape((-1, ray_ori.shape[-1]))
    ray_dir = ray_dir.reshape((-1, ray_dir.shape[-1]))
    rays = torch.cat([ray_ori, ray_dir], dim=-1)
    return rays, img
def convert_poses_to_rays(poses, H, W, focal):
    ray_origins = []
    ray_directions = []
    for pose in poses:
        chunk_ray_origins, chunk_ray_directions = get_ray_bundle(H, W, focal, pose)
        ray_origins.append(chunk_ray_origins)
        ray_directions.append(chunk_ray_directions)
    ray_origins = torch.stack(ray_origins, 0)
    ray_directions = torch.stack(ray_directions, 0)
    return ray_origins, ray_directions
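# `get_ray_bundle` itself is not shown in this section, and its callers
# disagree on the signature (some expect (coords, origins, directions),
# others just (origins, directions)). For reference, a minimal sketch of the
# common two-return NeRF-style version, assuming a pinhole camera that looks
# down the -z axis; the codebase's own implementation may differ.
import torch

def meshgrid_xy_sketch(t1, t2):
    # np.meshgrid(..., indexing="xy") analogue for two 1-D tensors.
    ii, jj = torch.meshgrid(t1, t2)
    return ii.transpose(-1, -2), jj.transpose(-1, -2)

def get_ray_bundle_sketch(height, width, focal_length, tform_cam2world):
    # Pixel grid in image coordinates ("xy" indexing).
    ii, jj = meshgrid_xy_sketch(
        torch.arange(width, dtype=tform_cam2world.dtype,
                     device=tform_cam2world.device),
        torch.arange(height, dtype=tform_cam2world.dtype,
                     device=tform_cam2world.device),
    )
    # Per-pixel viewing directions in camera space.
    directions = torch.stack(
        [
            (ii - width * 0.5) / focal_length,
            -(jj - height * 0.5) / focal_length,
            -torch.ones_like(ii),
        ],
        dim=-1,
    )
    # Rotate directions into world space; all rays share the camera origin.
    ray_directions = torch.sum(
        directions[..., None, :] * tform_cam2world[:3, :3], dim=-1)
    ray_origins = tform_cam2world[:3, -1].expand(ray_directions.shape)
    return ray_origins, ray_directions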
def run_one_iter_of_tinynerf(
    height,
    width,
    focal_length,
    tform_cam2world,
    near_thresh,
    far_thresh,
    depth_samples_per_ray,
    encoding_function,
    get_minibatches_function,
    chunksize,
    model,
    encoding_function_args,
):
    # Get the "bundle" of rays through all image pixels.
    ray_origins, ray_directions = get_ray_bundle(height, width, focal_length,
                                                 tform_cam2world)

    # Sample query points along each ray.
    query_points, depth_values = compute_query_points_from_rays(
        ray_origins, ray_directions, near_thresh, far_thresh,
        depth_samples_per_ray)

    # "Flatten" the query points.
    flattened_query_points = query_points.reshape((-1, 3))

    # Encode the query points (default: positional encoding).
    encoded_points = encoding_function(flattened_query_points,
                                       encoding_function_args)

    # Split the encoded points into "chunks", run the model on all chunks, and
    # concatenate the results (to avoid out-of-memory issues).
    batches = get_minibatches_function(encoded_points, chunksize=chunksize)
    predictions = []
    for batch in batches:
        predictions.append(model(batch))
    radiance_field_flattened = torch.cat(predictions, dim=0)

    # "Unflatten" to obtain the radiance field.
    unflattened_shape = list(query_points.shape[:-1]) + [4]
    radiance_field = torch.reshape(radiance_field_flattened, unflattened_shape)

    # Perform differentiable volume rendering to re-synthesize the RGB image.
    rgb_predicted, _, _ = render_volume_density(radiance_field, ray_origins,
                                                depth_values)

    return rgb_predicted
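# `run_one_iter_of_tinynerf` delegates to several helpers that are not shown
# in this section. Minimal sketches following the standard tiny-NeRF
# formulation; assumptions: the encoding function takes the number of
# frequencies as its second argument, and the model outputs (r, g, b, sigma)
# per sample. The codebase's own versions may differ in details.
import torch

def positional_encoding(tensor, num_encoding_functions=6, include_input=True,
                        log_sampling=True):
    # Map x to [x, sin(2^0 x), cos(2^0 x), ..., sin(2^(L-1) x), cos(2^(L-1) x)].
    encoding = [tensor] if include_input else []
    if log_sampling:
        frequency_bands = 2.0 ** torch.linspace(
            0.0, num_encoding_functions - 1, num_encoding_functions,
            dtype=tensor.dtype, device=tensor.device)
    else:
        frequency_bands = torch.linspace(
            1.0, 2.0 ** (num_encoding_functions - 1), num_encoding_functions,
            dtype=tensor.dtype, device=tensor.device)
    for freq in frequency_bands:
        encoding.append(torch.sin(tensor * freq))
        encoding.append(torch.cos(tensor * freq))
    return torch.cat(encoding, dim=-1)

def get_minibatches(inputs, chunksize=1024 * 8):
    # Split a flat batch into chunks to bound peak memory.
    return [inputs[i:i + chunksize] for i in range(0, inputs.shape[0], chunksize)]

def compute_query_points_from_rays(ray_origins, ray_directions, near, far,
                                   num_samples):
    # Sample `num_samples` points per ray, uniformly spaced in depth.
    depth_values = torch.linspace(near, far, num_samples,
                                  dtype=ray_origins.dtype,
                                  device=ray_origins.device)
    depth_values = depth_values.expand(list(ray_origins.shape[:-1]) + [num_samples])
    query_points = (ray_origins[..., None, :]
                    + ray_directions[..., None, :] * depth_values[..., :, None])
    return query_points, depth_values

def cumprod_exclusive(tensor):
    # Exclusive cumulative product along the last dim (first element = 1).
    cumprod = torch.cumprod(tensor, dim=-1)
    cumprod = torch.roll(cumprod, 1, dims=-1)
    cumprod[..., 0] = 1.0
    return cumprod

def render_volume_density(radiance_field, ray_origins, depth_values):
    # Alpha-composite per-sample (rgb, sigma) into per-ray color, depth, opacity.
    sigma_a = torch.nn.functional.relu(radiance_field[..., 3])
    rgb = torch.sigmoid(radiance_field[..., :3])
    one_e_10 = torch.tensor([1e10], dtype=ray_origins.dtype,
                            device=ray_origins.device)
    dists = torch.cat((depth_values[..., 1:] - depth_values[..., :-1],
                       one_e_10.expand(depth_values[..., :1].shape)), dim=-1)
    alpha = 1.0 - torch.exp(-sigma_a * dists)
    weights = alpha * cumprod_exclusive(1.0 - alpha + 1e-10)
    rgb_map = (weights[..., None] * rgb).sum(dim=-2)
    depth_map = (weights * depth_values).sum(dim=-1)
    acc_map = weights.sum(dim=-1)
    return rgb_map, depth_map, acc_map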
def test_get_ray_bundle():
    tform_cam2world = np.eye(4, 4, dtype=np.float32)
    tform_cam2world[0, 3] = 2.0
    tform_cam2world[1, 3] = -3.0
    tform_cam2world[2, 3] = 5.0
    # Compare both return values (ray origins and ray directions).
    for i in range(0, 2):
        jax_fn = lambda x: get_ray_bundle(10, 10, 0.3, x)[i]
        torch_fn = lambda x: get_ray_bundle_torch(10, 10, 0.3, x)[i]
        jo, to, djos, dtos = run_and_grad(jax_fn, torch_fn, (0, ),
                                          tform_cam2world)
        assert np.allclose(jo, to, rtol=1e-3, atol=1e-5)
        assert all(
            np.allclose(djo, dto, rtol=1e-3, atol=1e-5)
            for djo, dto in zip(djos, dtos))
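# The test above relies on `run_and_grad`, which is not shown in this section.
# A plausible sketch, assuming it evaluates the JAX and PyTorch implementations
# on the same inputs and compares gradients of a summed output w.r.t. the
# arguments named in `argnums`; the actual helper may differ.
import jax
import jax.numpy as jnp
import numpy as np
import torch

def run_and_grad(jax_fn, torch_fn, argnums, *args):
    jax_args = [jnp.asarray(a) for a in args]
    torch_args = [
        torch.tensor(a, requires_grad=(k in argnums))
        for k, a in enumerate(args)
    ]
    # Forward passes through both implementations.
    jax_out = jax_fn(*jax_args)
    torch_out = torch_fn(*torch_args)
    # Gradients of sum(output) w.r.t. the selected arguments.
    jax_grads = jax.grad(lambda *a: jax_fn(*a).sum(), argnums=argnums)(*jax_args)
    torch_out.sum().backward()
    torch_grads = [torch_args[k].grad.numpy() for k in argnums]
    return (np.asarray(jax_out), torch_out.detach().numpy(),
            [np.asarray(g) for g in jax_grads], torch_grads)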
def get_train_batch(self, size=1024):
    idx = np.random.choice(self.i_train)
    img = self.images[idx]
    pose = self.poses[idx]
    coords, ray_ori, ray_dir = get_ray_bundle(self.height, self.width,
                                              self.focal, pose)
    # Take a random subset of all rays.
    ray_idx = np.random.choice(coords.shape[0], size=size, replace=False)
    coords = coords[ray_idx]
    ray_ori = ray_ori[coords[:, 0], coords[:, 1]]
    ray_dir = ray_dir[coords[:, 0], coords[:, 1]]
    img = img[coords[:, 0], coords[:, 1]]
    rays = torch.cat([ray_ori, ray_dir], dim=-1)
    return rays, img
def cache_dataset(self):
    # TODO(0): testskip = args.blender_stride, offset for a small dataset
    # Unpack the data bundle.
    bundle = self.load_dataset().to(self.device)
    # Coordinates to sample from.
    self.init_sampling(bundle.hwf)
    for img_idx in trange(bundle.size):
        # Create a data chunk bundle.
        sample = bundle[img_idx]
        sample.ray_origins, sample.ray_directions = get_ray_bundle(
            *sample.hwf, sample.poses)
        if self.cfg.dataset.use_ndc:
            # Use normalized device coordinates.
            sample.ndc()
        if (self.cfg.dataset.caching.sample_all
                or self.type == DatasetType.VALIDATION):
            self.save_dataset(sample, img_idx)
        else:
            raise NotImplementedError
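# `sample.ndc()` presumably wraps the normalized-device-coordinate ray
# transform used by the original NeRF code for forward-facing (LLFF) scenes.
# For reference, a sketch of that standard transform; whether the bundle
# object applies exactly this is an assumption.
import torch

def ndc_rays(H, W, focal, near, rays_o, rays_d):
    # Shift ray origins onto the near plane.
    t = -(near + rays_o[..., 2]) / rays_d[..., 2]
    rays_o = rays_o + t[..., None] * rays_d
    # Project origins and directions into NDC space.
    o0 = -1.0 / (W / (2.0 * focal)) * rays_o[..., 0] / rays_o[..., 2]
    o1 = -1.0 / (H / (2.0 * focal)) * rays_o[..., 1] / rays_o[..., 2]
    o2 = 1.0 + 2.0 * near / rays_o[..., 2]
    d0 = -1.0 / (W / (2.0 * focal)) * (rays_d[..., 0] / rays_d[..., 2]
                                       - rays_o[..., 0] / rays_o[..., 2])
    d1 = -1.0 / (H / (2.0 * focal)) * (rays_d[..., 1] / rays_d[..., 2]
                                       - rays_o[..., 1] / rays_o[..., 2])
    d2 = -2.0 * near / rays_o[..., 2]
    rays_o = torch.stack([o0, o1, o2], -1)
    rays_d = torch.stack([d0, d1, d2], -1)
    return rays_o, rays_d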
def cache_nerf_dataset(args):
    images, poses, render_poses, hwf = None, None, None, None
    i_train, i_val, i_test = None, None, None

    if args.type == "blender":
        images, poses, render_poses, hwf, i_split = load_blender_data(
            args.datapath,
            half_res=args.blender_half_res,
            testskip=args.blender_stride)
        i_train, i_val, i_test = i_split
        H, W, focal = hwf
        H, W = int(H), int(W)
        hwf = [H, W, focal]
    elif args.type == "llff":
        images, poses, bds, render_poses, i_test = load_llff_data(
            args.datapath, factor=args.llff_downsample_factor)
        hwf = poses[0, :3, -1]
        poses = poses[:, :3, :4]
        if not isinstance(i_test, list):
            i_test = [i_test]
        if args.llffhold > 0:
            i_test = np.arange(images.shape[0])[::args.llffhold]
        i_val = i_test
        i_train = np.array([
            i for i in np.arange(images.shape[0])
            if (i not in i_test and i not in i_val)
        ])
        H, W, focal = hwf
        hwf = [int(H), int(W), focal]

    images = torch.from_numpy(images)
    poses = torch.from_numpy(poses)

    # Device on which to run.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    os.makedirs(os.path.join(args.savedir, "train"), exist_ok=True)
    os.makedirs(os.path.join(args.savedir, "val"), exist_ok=True)
    os.makedirs(os.path.join(args.savedir, "test"), exist_ok=True)
    np.random.seed(args.randomseed)

    for img_idx in tqdm(i_train):
        for j in range(args.num_variations):
            img_target = images[img_idx].to(device)
            pose_target = poses[img_idx, :3, :4].to(device)
            ray_origins, ray_directions = get_ray_bundle(H, W, focal, pose_target)
            coords = torch.stack(
                meshgrid_xy(torch.arange(H).to(device),
                            torch.arange(W).to(device)),
                dim=-1,
            )
            coords = coords.reshape((-1, 2))
            if args.sample_all is False:
                select_inds = np.random.choice(
                    coords.shape[0], size=args.num_random_rays, replace=False)
                select_inds = coords[select_inds]
                ray_origins = ray_origins[select_inds[:, 0], select_inds[:, 1], :]
                ray_directions = ray_directions[select_inds[:, 0],
                                                select_inds[:, 1], :]
                target_s = img_target[select_inds[:, 0], select_inds[:, 1], :]
                # One cache file per (image, variation) pair.
                save_path = os.path.join(
                    args.savedir, "train",
                    str(img_idx).zfill(4) + "_" + str(j).zfill(4) + ".data")
            else:
                target_s = img_target
                save_path = os.path.join(args.savedir, "train",
                                         str(img_idx).zfill(4) + ".data")
            batch_rays = torch.stack([ray_origins, ray_directions], dim=0)
            cache_dict = {
                "height": H,
                "width": W,
                "focal_length": focal,
                "ray_bundle": batch_rays.detach().cpu(),
                "target": target_s.detach().cpu(),
            }
            torch.save(cache_dict, save_path)
            # When caching full images, one variation per image suffices.
            if args.sample_all is True:
                break

    for img_idx in tqdm(i_val):
        img_target = images[img_idx].to(device)
        pose_target = poses[img_idx, :3, :4].to(device)
        ray_origins, ray_directions = get_ray_bundle(H, W, focal, pose_target)
        cache_dict = {
            "height": H,
            "width": W,
            "focal_length": focal,
            "ray_origins": ray_origins.detach().cpu(),
            "ray_directions": ray_directions.detach().cpu(),
            "target": img_target.detach().cpu(),
        }
        save_path = os.path.join(args.savedir, "val",
                                 str(img_idx).zfill(4) + ".data")
        torch.save(cache_dict, save_path)
def export_ray_trace(model_coarse, model_fine, config_args, cfg,
                     encode_position_fn, encode_direction_fn, device):
    # Mesh extraction parameters.
    samples_dimen_y = 8
    samples_dimen_x = 4
    plane_near = 0
    plane_far = 4.0
    img_size = 800
    step_size = 2
    dist_threshold = 0.002
    prob_threshold = 0.6

    # Data accumulators.
    vertices, triangles, normals, diffuse = [], [], [], []
    render_poses = torch.stack([
        torch.from_numpy(pose_spherical(angleY, angleX, plane_far)).float()
        for angleY in np.linspace(-180, 180, samples_dimen_y, endpoint=False)
        for angleX in np.linspace(-90, 90, samples_dimen_x, endpoint=True)
    ], dim=0)
    hwf = [img_size, img_size, 1111.1111]
    grid = get_grid(img_size)

    for i, pose in enumerate(tqdm(render_poses)):
        pose = pose[:3, :4].to(device)

        # Ray origins & directions.
        ray_origins, ray_directions = get_ray_bundle(hwf[0], hwf[1], hwf[2], pose)
        # cfg.nerf['validation']['num_coarse'] = 64
        # cfg.nerf['validation']['num_fine'] = 64
        with torch.no_grad():
            _, _, _, rgb_fine, _, depth_fine = run_one_iter_of_nerf(
                hwf[0],
                hwf[1],
                hwf[2],
                model_coarse,
                model_fine,
                ray_origins,
                ray_directions,
                cfg,
                mode="validation",
                encode_position_fn=encode_position_fn,
                encode_direction_fn=encode_direction_fn,
            )

        # Apply a nearest-neighbor agreement mask: keep surface points whose
        # depth agrees with most of a (2*step_size+1)^2 pixel neighborhood.
        surface_points = ray_origins + ray_directions * depth_fine[..., None]
        acc = []
        initial_values = surface_points[grid.T[0], grid.T[1]].view(
            img_size, img_size, -1)
        size = step_size * 2 + 1
        size_samples = size**2 - 1
        for a in range(-step_size, step_size + 1):
            for b in range(-step_size, step_size + 1):
                offset = torch.tensor([a, b])
                new_grid = grid + offset
                new_grid = new_grid.clamp(0, img_size - 1)
                new_grid = surface_points[new_grid.T[0], new_grid.T[1]].view(
                    img_size, img_size, -1)
                new_grid_s = ((new_grid - initial_values)**2).sum(-1) < dist_threshold
                acc.append(new_grid_s)
        new_mask = torch.stack(acc, -1).sum(-1).squeeze(-1) > size_samples * prob_threshold
        dep_mask = depth_fine > 0
        mask = new_mask * dep_mask

        ray_origins, ray_directions, depth_fine = (
            ray_origins[mask], ray_directions[mask], depth_fine[mask])
        rgb_fine = rgb_fine[mask]
        surface_points = ray_origins + ray_directions * depth_fine[..., None]

        vertices.append(surface_points.view(-1, 3).cpu().detach())
        normals.append((-ray_directions).view(-1, 3).cpu().detach())
        diffuse.append(rgb_fine.view(-1, 3).cpu().detach())

    # Gather the whole diffuse map, vertices, and normals.
    diffuse_fine = torch.cat(diffuse, dim=0).numpy()
    vertices_fine = torch.cat(vertices, dim=0).numpy()
    normals_fine = torch.cat(normals, dim=0).numpy()

    # Export the model.
    # export_obj(vertices_fine, [], diffuse_fine, normals_fine, "lego-sampling.obj")
    export_ply(vertices_fine, diffuse_fine, normals_fine, "lego-sampling.ply")
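# `pose_spherical` above is not defined in this section. A sketch of the
# standard NeRF spherical-pose helper (numpy version, since the caller wraps
# the result in torch.from_numpy); the repo's own version may differ.
import numpy as np

def trans_t(t):
    return np.array([[1, 0, 0, 0], [0, 1, 0, 0],
                     [0, 0, 1, t], [0, 0, 0, 1]], dtype=np.float32)

def rot_phi(phi):
    return np.array([[1, 0, 0, 0],
                     [0, np.cos(phi), -np.sin(phi), 0],
                     [0, np.sin(phi), np.cos(phi), 0],
                     [0, 0, 0, 1]], dtype=np.float32)

def rot_theta(th):
    return np.array([[np.cos(th), 0, -np.sin(th), 0],
                     [0, 1, 0, 0],
                     [np.sin(th), 0, np.cos(th), 0],
                     [0, 0, 0, 1]], dtype=np.float32)

def pose_spherical(theta, phi, radius):
    # Camera-to-world pose on a sphere of the given radius (angles in degrees).
    c2w = trans_t(radius)
    c2w = rot_phi(phi / 180.0 * np.pi) @ c2w
    c2w = rot_theta(theta / 180.0 * np.pi) @ c2w
    c2w = np.array([[-1, 0, 0, 0], [0, 0, 1, 0],
                    [0, 1, 0, 0], [0, 0, 0, 1]], dtype=np.float32) @ c2w
    return c2w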
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True,
                        help="Path to (.yml) config file.")
    parser.add_argument(
        "--load-checkpoint",
        type=str,
        default="",
        help="Path to load saved checkpoint from.",
    )
    configargs = parser.parse_args()

    # Read config file.
    cfg = None
    with open(configargs.config, "r") as f:
        cfg_dict = yaml.load(f, Loader=yaml.FullLoader)
        cfg = CfgNode(cfg_dict)

    # # (Optional:) enable this to track autograd issues when debugging
    # torch.autograd.set_detect_anomaly(True)

    # If a pre-cached dataset is available, skip the dataloader.
    USE_CACHED_DATASET = False
    train_paths, validation_paths = None, None
    images, poses, render_poses, hwf, i_split = None, None, None, None, None
    H, W, focal, i_train, i_val, i_test = None, None, None, None, None, None
    if hasattr(cfg.dataset, "cachedir") and os.path.exists(cfg.dataset.cachedir):
        train_paths = glob.glob(
            os.path.join(cfg.dataset.cachedir, "train", "*.data"))
        validation_paths = glob.glob(
            os.path.join(cfg.dataset.cachedir, "val", "*.data"))
        USE_CACHED_DATASET = True
    else:
        # Load dataset
        images, poses, render_poses, hwf = None, None, None, None
        if cfg.dataset.type.lower() == "blender":
            images, poses, render_poses, hwf, i_split = load_blender_data(
                cfg.dataset.basedir,
                half_res=cfg.dataset.half_res,
                testskip=cfg.dataset.testskip,
            )
            i_train, i_val, i_test = i_split
            H, W, focal = hwf
            H, W = int(H), int(W)
            hwf = [H, W, focal]
            if cfg.nerf.train.white_background:
                # Composite RGBA onto a white background.
                images = images[..., :3] * images[..., -1:] + (1.0 - images[..., -1:])
        elif cfg.dataset.type.lower() == "llff":
            images, poses, bds, render_poses, i_test = load_llff_data(
                cfg.dataset.basedir, factor=cfg.dataset.downsample_factor)
            hwf = poses[0, :3, -1]
            poses = poses[:, :3, :4]
            if not isinstance(i_test, list):
                i_test = [i_test]
            if cfg.dataset.llffhold > 0:
                i_test = np.arange(images.shape[0])[::cfg.dataset.llffhold]
            i_val = i_test
            i_train = np.array([
                i for i in np.arange(images.shape[0])
                if (i not in i_test and i not in i_val)
            ])
            H, W, focal = hwf
            H, W = int(H), int(W)
            hwf = [H, W, focal]
        images = torch.from_numpy(images)
        poses = torch.from_numpy(poses)

    # Seed experiment for repeatability.
    seed = cfg.experiment.randomseed
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Device on which to run.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    encode_position_fn = get_embedding_function(
        num_encoding_functions=cfg.models.coarse.num_encoding_fn_xyz,
        include_input=cfg.models.coarse.include_input_xyz,
        log_sampling=cfg.models.coarse.log_sampling_xyz,
    )

    encode_direction_fn = None
    if cfg.models.coarse.use_viewdirs:
        encode_direction_fn = get_embedding_function(
            num_encoding_functions=cfg.models.coarse.num_encoding_fn_dir,
            include_input=cfg.models.coarse.include_input_dir,
            log_sampling=cfg.models.coarse.log_sampling_dir,
        )

    # Initialize a coarse-resolution model.
    model_coarse = getattr(models, cfg.models.coarse.type)(
        num_encoding_fn_xyz=cfg.models.coarse.num_encoding_fn_xyz,
        num_encoding_fn_dir=cfg.models.coarse.num_encoding_fn_dir,
        include_input_xyz=cfg.models.coarse.include_input_xyz,
        include_input_dir=cfg.models.coarse.include_input_dir,
        use_viewdirs=cfg.models.coarse.use_viewdirs,
    )
    model_coarse.to(device)

    # If a fine-resolution model is specified, initialize it.
    model_fine = None
    if hasattr(cfg.models, "fine"):
        model_fine = getattr(models, cfg.models.fine.type)(
            num_encoding_fn_xyz=cfg.models.fine.num_encoding_fn_xyz,
            num_encoding_fn_dir=cfg.models.fine.num_encoding_fn_dir,
            include_input_xyz=cfg.models.fine.include_input_xyz,
            include_input_dir=cfg.models.fine.include_input_dir,
            use_viewdirs=cfg.models.fine.use_viewdirs,
        )
        model_fine.to(device)

    # Initialize optimizer.
    trainable_parameters = list(model_coarse.parameters())
    if model_fine is not None:
        trainable_parameters += list(model_fine.parameters())
    optimizer = getattr(torch.optim, cfg.optimizer.type)(trainable_parameters,
                                                         lr=cfg.optimizer.lr)

    # Setup logging.
    logdir = os.path.join(cfg.experiment.logdir, cfg.experiment.id)
    os.makedirs(logdir, exist_ok=True)
    writer = SummaryWriter(logdir)
    # Write out config parameters.
    with open(os.path.join(logdir, "config.yml"), "w") as f:
        f.write(cfg.dump())

    # By default, start at iteration 0 (unless a checkpoint is specified).
    start_iter = 0

    # Load an existing checkpoint, if a path is specified.
    if os.path.exists(configargs.load_checkpoint):
        checkpoint = torch.load(configargs.load_checkpoint)
        model_coarse.load_state_dict(checkpoint["model_coarse_state_dict"])
        if checkpoint["model_fine_state_dict"]:
            model_fine.load_state_dict(checkpoint["model_fine_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        start_iter = checkpoint["iter"]

    # TODO: Prepare raybatch tensor if batching random rays

    for i in trange(start_iter, cfg.experiment.train_iters):
        model_coarse.train()
        if model_fine:
            model_fine.train()

        rgb_coarse, rgb_fine = None, None
        target_ray_values = None
        if USE_CACHED_DATASET:
            datafile = np.random.choice(train_paths)
            cache_dict = torch.load(datafile)
            ray_bundle = cache_dict["ray_bundle"].to(device)
            ray_origins, ray_directions = (
                ray_bundle[0].reshape((-1, 3)),
                ray_bundle[1].reshape((-1, 3)),
            )
            target_ray_values = cache_dict["target"][..., :3].reshape((-1, 3))
            select_inds = np.random.choice(
                ray_origins.shape[0],
                size=cfg.nerf.train.num_random_rays,
                replace=False,
            )
            ray_origins, ray_directions = (
                ray_origins[select_inds],
                ray_directions[select_inds],
            )
            target_ray_values = target_ray_values[select_inds].to(device)
            rgb_coarse, _, _, rgb_fine, _, _ = run_one_iter_of_nerf(
                cache_dict["height"],
                cache_dict["width"],
                cache_dict["focal_length"],
                model_coarse,
                model_fine,
                ray_origins,
                ray_directions,
                cfg,
                mode="train",
                encode_position_fn=encode_position_fn,
                encode_direction_fn=encode_direction_fn,
            )
        else:
            img_idx = np.random.choice(i_train)
            img_target = images[img_idx].to(device)
            pose_target = poses[img_idx, :3, :4].to(device)
            ray_origins, ray_directions = get_ray_bundle(H, W, focal, pose_target)
            coords = torch.stack(
                meshgrid_xy(torch.arange(H).to(device),
                            torch.arange(W).to(device)),
                dim=-1,
            )
            coords = coords.reshape((-1, 2))
            select_inds = np.random.choice(coords.shape[0],
                                           size=cfg.nerf.train.num_random_rays,
                                           replace=False)
            select_inds = coords[select_inds]
            ray_origins = ray_origins[select_inds[:, 0], select_inds[:, 1], :]
            ray_directions = ray_directions[select_inds[:, 0], select_inds[:, 1], :]
            target_s = img_target[select_inds[:, 0], select_inds[:, 1], :]
            rgb_coarse, _, _, rgb_fine, _, _ = run_one_iter_of_nerf(
                H,
                W,
                focal,
                model_coarse,
                model_fine,
                ray_origins,
                ray_directions,
                cfg,
                mode="train",
                encode_position_fn=encode_position_fn,
                encode_direction_fn=encode_direction_fn,
            )
            target_ray_values = target_s

        coarse_loss = torch.nn.functional.mse_loss(rgb_coarse[..., :3],
                                                   target_ray_values[..., :3])
        fine_loss = None
        if rgb_fine is not None:
            fine_loss = torch.nn.functional.mse_loss(rgb_fine[..., :3],
                                                     target_ray_values[..., :3])
        loss = coarse_loss + (fine_loss if fine_loss is not None else 0.0)
        loss.backward()
        psnr = mse2psnr(loss.item())
        optimizer.step()
        optimizer.zero_grad()

        # Learning rate updates (exponential decay).
        num_decay_steps = cfg.scheduler.lr_decay * 1000
        lr_new = cfg.optimizer.lr * (cfg.scheduler.lr_decay_factor
                                     **(i / num_decay_steps))
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr_new

        if i % cfg.experiment.print_every == 0 or i == cfg.experiment.train_iters - 1:
            tqdm.write("[TRAIN] Iter: " + str(i) + " Loss: " + str(loss.item()) +
                       " PSNR: " + str(psnr))
        writer.add_scalar("train/loss", loss.item(), i)
        writer.add_scalar("train/coarse_loss", coarse_loss.item(), i)
        if rgb_fine is not None:
            writer.add_scalar("train/fine_loss", fine_loss.item(), i)
        writer.add_scalar("train/psnr", psnr, i)

        # Validation
        if (i % cfg.experiment.validate_every == 0
                or i == cfg.experiment.train_iters - 1):
            tqdm.write("[VAL] =======> Iter: " + str(i))
            model_coarse.eval()
            if model_fine:
                model_fine.eval()

            start = time.time()
            with torch.no_grad():
                rgb_coarse, rgb_fine = None, None
                target_ray_values = None
                if USE_CACHED_DATASET:
                    datafile = np.random.choice(validation_paths)
                    cache_dict = torch.load(datafile)
                    rgb_coarse, _, _, rgb_fine, _, _ = run_one_iter_of_nerf(
                        cache_dict["height"],
                        cache_dict["width"],
                        cache_dict["focal_length"],
                        model_coarse,
                        model_fine,
                        cache_dict["ray_origins"].to(device),
                        cache_dict["ray_directions"].to(device),
                        cfg,
                        mode="validation",
                        encode_position_fn=encode_position_fn,
                        encode_direction_fn=encode_direction_fn,
                    )
                    target_ray_values = cache_dict["target"].to(device)
                else:
                    img_idx = np.random.choice(i_val)
                    img_target = images[img_idx].to(device)
                    pose_target = poses[img_idx, :3, :4].to(device)
                    ray_origins, ray_directions = get_ray_bundle(
                        H, W, focal, pose_target)
                    rgb_coarse, _, _, rgb_fine, _, _ = run_one_iter_of_nerf(
                        H,
                        W,
                        focal,
                        model_coarse,
                        model_fine,
                        ray_origins,
                        ray_directions,
                        cfg,
                        mode="validation",
                        encode_position_fn=encode_position_fn,
                        encode_direction_fn=encode_direction_fn,
                    )
                    target_ray_values = img_target
                coarse_loss = img2mse(rgb_coarse[..., :3],
                                      target_ray_values[..., :3])
                fine_loss = 0.0
                if rgb_fine is not None:
                    fine_loss = img2mse(rgb_fine[..., :3],
                                        target_ray_values[..., :3])
                loss = coarse_loss + fine_loss
                psnr = mse2psnr(loss.item())
                writer.add_scalar("validation/loss", loss.item(), i)
                writer.add_scalar("validation/coarse_loss", coarse_loss.item(), i)
                writer.add_scalar("validation/psnr", psnr, i)
                writer.add_image("validation/rgb_coarse",
                                 cast_to_image(rgb_coarse[..., :3]), i)
                if rgb_fine is not None:
                    writer.add_image("validation/rgb_fine",
                                     cast_to_image(rgb_fine[..., :3]), i)
                    writer.add_scalar("validation/fine_loss", fine_loss.item(), i)
                writer.add_image(
                    "validation/img_target",
                    cast_to_image(target_ray_values[..., :3]),
                    i,
                )
                tqdm.write("Validation loss: " + str(loss.item()) +
                           " Validation PSNR: " + str(psnr) + " Time: " +
                           str(time.time() - start))

        if i % cfg.experiment.save_every == 0 or i == cfg.experiment.train_iters - 1:
            checkpoint_dict = {
                "iter": i,
                "model_coarse_state_dict": model_coarse.state_dict(),
                "model_fine_state_dict":
                    None if not model_fine else model_fine.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "loss": loss,
                "psnr": psnr,
            }
            torch.save(
                checkpoint_dict,
                os.path.join(logdir, "checkpoint" + str(i).zfill(5) + ".ckpt"),
            )
            tqdm.write("================== Saved Checkpoint =================")

    print("Done!")
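# The training and evaluation loops in this section use `img2mse` and
# `mse2psnr` without defining them. A minimal sketch of the usual definitions
# (PSNR in dB, for images normalized to [0, 1]); the codebase's own helpers
# may guard edge cases differently.
import math
import torch

def img2mse(img_src, img_tgt):
    return torch.nn.functional.mse_loss(img_src, img_tgt)

def mse2psnr(mse):
    # `mse` is a Python float here (the scripts pass loss.item()).
    return -10.0 * math.log10(max(mse, 1e-10))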
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True,
                        help="Path to (.yml) config file.")
    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="Checkpoint / pre-trained model to evaluate.",
    )
    parser.add_argument("--savedir", type=str,
                        help="Save images to this directory, if specified.")
    parser.add_argument("--save-disparity-image", action="store_true",
                        help="Save disparity images too.")
    configargs = parser.parse_args()

    # Read config file.
    cfg = None
    with open(configargs.config, "r") as f:
        cfg_dict = yaml.load(f, Loader=yaml.FullLoader)
        cfg = CfgNode(cfg_dict)

    images, poses, render_poses, hwf = None, None, None, None
    i_train, i_val, i_test = None, None, None
    if cfg.dataset.type.lower() == "blender":
        # Load blender dataset
        images, poses, render_poses, hwf, i_split = load_blender_data(
            cfg.dataset.basedir,
            half_res=cfg.dataset.half_res,
            testskip=cfg.dataset.testskip,
        )
        i_train, i_val, i_test = i_split
        H, W, focal = hwf
        H, W = int(H), int(W)
    elif cfg.dataset.type.lower() == "llff":
        # Load LLFF dataset
        images, poses, bds, render_poses, i_test = load_llff_data(
            cfg.dataset.basedir,
            factor=cfg.dataset.downsample_factor,
        )
        hwf = poses[0, :3, -1]
        H, W, focal = hwf
        hwf = [int(H), int(W), focal]
        render_poses = torch.from_numpy(render_poses)

    # Device on which to run.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    encode_position_fn = get_embedding_function(
        num_encoding_functions=cfg.models.coarse.num_encoding_fn_xyz,
        include_input=cfg.models.coarse.include_input_xyz,
        log_sampling=cfg.models.coarse.log_sampling_xyz)

    encode_direction_fn = None
    if cfg.models.coarse.use_viewdirs:
        encode_direction_fn = get_embedding_function(
            num_encoding_functions=cfg.models.coarse.num_encoding_fn_dir,
            include_input=cfg.models.coarse.include_input_dir,
            log_sampling=cfg.models.coarse.log_sampling_dir,
        )

    # Initialize a coarse resolution model.
    model_coarse = getattr(models, cfg.models.coarse.type)(
        num_encoding_fn_xyz=cfg.models.coarse.num_encoding_fn_xyz,
        num_encoding_fn_dir=cfg.models.coarse.num_encoding_fn_dir,
        include_input_xyz=cfg.models.coarse.include_input_xyz,
        include_input_dir=cfg.models.coarse.include_input_dir,
        use_viewdirs=cfg.models.coarse.use_viewdirs,
    )
    model_coarse.to(device)

    # If a fine-resolution model is specified, initialize it.
    model_fine = None
    if hasattr(cfg.models, "fine"):
        model_fine = getattr(models, cfg.models.fine.type)(
            num_encoding_fn_xyz=cfg.models.fine.num_encoding_fn_xyz,
            num_encoding_fn_dir=cfg.models.fine.num_encoding_fn_dir,
            include_input_xyz=cfg.models.fine.include_input_xyz,
            include_input_dir=cfg.models.fine.include_input_dir,
            use_viewdirs=cfg.models.fine.use_viewdirs,
        )
        model_fine.to(device)

    checkpoint = torch.load(configargs.checkpoint)
    model_coarse.load_state_dict(checkpoint["model_coarse_state_dict"])
    if checkpoint["model_fine_state_dict"]:
        try:
            model_fine.load_state_dict(checkpoint["model_fine_state_dict"])
        except Exception:
            print("The checkpoint has a fine-level model, but it could "
                  "not be loaded (possibly due to a mismatched config file).")
    if "height" in checkpoint.keys():
        hwf[0] = checkpoint["height"]
    if "width" in checkpoint.keys():
        hwf[1] = checkpoint["width"]
    if "focal_length" in checkpoint.keys():
        hwf[2] = checkpoint["focal_length"]

    model_coarse.eval()
    if model_fine:
        model_fine.eval()

    render_poses = render_poses.float().to(device)

    # Create directory to save images to.
    os.makedirs(configargs.savedir, exist_ok=True)
    if configargs.save_disparity_image:
        os.makedirs(os.path.join(configargs.savedir, "disparity"), exist_ok=True)

    # Evaluation loop
    times_per_image = []
    for i, pose in enumerate(tqdm(render_poses)):
        start = time.time()
        rgb = None
        disp = None
        with torch.no_grad():
            pose = pose[:3, :4]
            ray_origins, ray_directions = get_ray_bundle(hwf[0], hwf[1], hwf[2], pose)
            rgb_coarse, disp_coarse, _, rgb_fine, disp_fine, _ = run_one_iter_of_nerf(
                hwf[0],
                hwf[1],
                hwf[2],
                model_coarse,
                model_fine,
                ray_origins,
                ray_directions,
                cfg,
                mode="validation",
                encode_position_fn=encode_position_fn,
                encode_direction_fn=encode_direction_fn,
            )
            rgb = rgb_fine if rgb_fine is not None else rgb_coarse
            if configargs.save_disparity_image:
                disp = disp_fine if disp_fine is not None else disp_coarse
        times_per_image.append(time.time() - start)
        if configargs.savedir:
            savefile = os.path.join(configargs.savedir, f"{i:04d}.png")
            imageio.imwrite(savefile,
                            cast_to_image(rgb[..., :3], cfg.dataset.type.lower()))
            if configargs.save_disparity_image:
                savefile = os.path.join(configargs.savedir, "disparity",
                                        f"{i:04d}.png")
                imageio.imwrite(savefile, cast_to_disparity_image(disp))
        tqdm.write(f"Avg time per image: {sum(times_per_image) / (i + 1)}")
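# `cast_to_image` / `cast_to_disparity_image` feed both imageio and
# TensorBoard output but are not defined in this section. A plausible sketch
# returning HWC uint8 arrays; the repo's own version may instead transpose to
# CHW for TensorBoard's add_image, and may treat dataset types differently.
import numpy as np

def cast_to_image(tensor, dataset_type="blender"):
    # (H, W, 3) float tensor in [0, 1] -> uint8 image array.
    img = np.clip(tensor.detach().cpu().numpy(), 0.0, 1.0)
    return (255.0 * img).astype(np.uint8)

def cast_to_disparity_image(tensor):
    # Normalize disparity to [0, 1] before quantizing to uint8.
    img = tensor.detach().cpu().numpy()
    img = (img - img.min()) / max(img.max() - img.min(), 1e-10)
    return (255.0 * img).astype(np.uint8)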
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True,
                        help="Path to (.yml) config file.")
    parser.add_argument(
        "--load-checkpoint",
        type=str,
        default="",
        help="Path to load saved checkpoint from.",
    )
    configargs = parser.parse_args()

    # Read config file.
    cfg = None
    with open(configargs.config, "r") as f:
        cfg_dict = yaml.load(f, Loader=yaml.FullLoader)
        cfg = CfgNode(cfg_dict)

    # # (Optional:) enable this to track autograd issues when debugging
    # torch.autograd.set_detect_anomaly(True)

    # If a pre-cached dataset is available, skip the dataloader.
    USE_CACHED_DATASET = False
    train_paths, validation_paths = None, None
    images, poses, render_poses, hwf, i_split, expressions = (
        None, None, None, None, None, None)
    H, W, focal, i_train, i_val, i_test = None, None, None, None, None, None
    if hasattr(cfg.dataset, "cachedir") and os.path.exists(cfg.dataset.cachedir):
        train_paths = glob.glob(
            os.path.join(cfg.dataset.cachedir, "train", "*.data"))
        validation_paths = glob.glob(
            os.path.join(cfg.dataset.cachedir, "val", "*.data"))
        USE_CACHED_DATASET = True
    else:
        # Load dataset
        images, poses, render_poses, hwf, expressions = (None, None, None,
                                                         None, None)
        if cfg.dataset.type.lower() == "blender":
            images, poses, render_poses, hwf, i_split, expressions, _, bboxs = load_flame_data(
                cfg.dataset.basedir,
                half_res=cfg.dataset.half_res,
                testskip=cfg.dataset.testskip,
            )
            i_train, i_val, i_test = i_split
            H, W, focal = hwf
            H, W = int(H), int(W)
            hwf = [H, W, focal]
            if cfg.nerf.train.white_background:
                # Composite RGBA onto a white background.
                images = images[..., :3] * images[..., -1:] + (1.0 - images[..., -1:])
        print("done loading data")

    # Seed experiment for repeatability.
    seed = cfg.experiment.randomseed
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Device on which to run.
    if torch.cuda.is_available():
        device = "cuda"  # + ":" + str(cfg.experiment.device)
    else:
        device = "cpu"

    encode_position_fn = get_embedding_function(
        num_encoding_functions=cfg.models.coarse.num_encoding_fn_xyz,
        include_input=cfg.models.coarse.include_input_xyz,
        log_sampling=cfg.models.coarse.log_sampling_xyz,
    )

    encode_direction_fn = None
    if cfg.models.coarse.use_viewdirs:
        encode_direction_fn = get_embedding_function(
            num_encoding_functions=cfg.models.coarse.num_encoding_fn_dir,
            include_input=cfg.models.coarse.include_input_dir,
            log_sampling=cfg.models.coarse.log_sampling_dir,
        )

    # Initialize a coarse-resolution model.
    model_coarse = getattr(models, cfg.models.coarse.type)(
        num_encoding_fn_xyz=cfg.models.coarse.num_encoding_fn_xyz,
        num_encoding_fn_dir=cfg.models.coarse.num_encoding_fn_dir,
        include_input_xyz=cfg.models.coarse.include_input_xyz,
        include_input_dir=cfg.models.coarse.include_input_dir,
        use_viewdirs=cfg.models.coarse.use_viewdirs,
        num_layers=cfg.models.coarse.num_layers,
        hidden_size=cfg.models.coarse.hidden_size,
        include_expression=True)
    model_coarse.to(device)

    # If a fine-resolution model is specified, initialize it.
    model_fine = None
    if hasattr(cfg.models, "fine"):
        model_fine = getattr(models, cfg.models.fine.type)(
            num_encoding_fn_xyz=cfg.models.fine.num_encoding_fn_xyz,
            num_encoding_fn_dir=cfg.models.fine.num_encoding_fn_dir,
            include_input_xyz=cfg.models.fine.include_input_xyz,
            include_input_dir=cfg.models.fine.include_input_dir,
            use_viewdirs=cfg.models.fine.use_viewdirs,
            num_layers=cfg.models.coarse.num_layers,
            hidden_size=cfg.models.coarse.hidden_size,
            include_expression=True)
        model_fine.to(device)

    ###################################
    # Experiment switches.
    train_background = False
    supervised_train_background = False
    blur_background = False
    train_latent_codes = True
    disable_expressions = False  # True to disable expressions
    disable_latent_codes = False  # True to disable latent codes
    fixed_background = True  # False to disable the background
    regularize_latent_codes = True  # True to add the latent-code loss; False for most experiments
    ###################################
    supervised_train_background = (train_background
                                   and supervised_train_background)

    # Initialize the background from the average training image.
    if train_background:
        with torch.no_grad():
            avg_img = torch.mean(images[i_train], axis=0)
            # Optionally blur the background initialization.
            if blur_background:
                avg_img = avg_img.permute(2, 0, 1).unsqueeze(0)
                smoother = GaussianSmoothing(channels=3, kernel_size=11, sigma=11)
                print("smoothed background initialization. shape ", avg_img.shape)
                avg_img = smoother(avg_img).squeeze(0).permute(1, 2, 0)
            # avg_img = torch.zeros(H, W, 3)
            # avg_img = torch.rand(H, W, 3)
            # avg_img = 0.5 * (torch.rand(H, W, 3) + torch.mean(images[i_train], axis=0))
        background = avg_img.clone().detach().to(device)
        background.requires_grad = True

    if fixed_background:
        # Load the ground-truth background to condition on.
        print("loading GT background to condition on")
        from PIL import Image
        background = Image.open(
            os.path.join(cfg.dataset.basedir, 'bg', '00050.png'))
        background.thumbnail((H, W))
        background = torch.from_numpy(
            np.array(background).astype(np.float32)).to(device)
        background = background / 255
        print("bg shape", background.shape)
        print("should be ", images[i_train][0].shape)
        assert background.shape == images[i_train][0].shape
    else:
        background = None

    # Initialize optimizer.
    trainable_parameters = list(model_coarse.parameters())
    if model_fine is not None:
        trainable_parameters += list(model_fine.parameters())
    if train_background:
        # The background is added below as its own parameter group.
        print("background.is_leaf ", background.is_leaf, background.device)

    if train_latent_codes:
        latent_codes = torch.zeros(len(i_train), 32, device=device)
        print("initialized latent codes with shape %d X %d" %
              (latent_codes.shape[0], latent_codes.shape[1]))
        if not disable_latent_codes:
            trainable_parameters.append(latent_codes)
            latent_codes.requires_grad = True

    if train_background:
        optimizer = getattr(torch.optim, cfg.optimizer.type)(
            [{'params': trainable_parameters},
             {'params': background, 'lr': cfg.optimizer.lr}],
            lr=cfg.optimizer.lr)
    else:
        optimizer = getattr(torch.optim, cfg.optimizer.type)(
            [{'params': trainable_parameters},
             {'params': background, 'lr': cfg.optimizer.lr}],  # obsolete, but kept so old checkpoints resume
            lr=cfg.optimizer.lr)

    # Setup logging.
    logdir = os.path.join(cfg.experiment.logdir, cfg.experiment.id)
    os.makedirs(logdir, exist_ok=True)
    writer = SummaryWriter(logdir)
    # Write out config parameters.
    with open(os.path.join(logdir, "config.yml"), "w") as f:
        f.write(cfg.dump())

    # By default, start at iteration 0 (unless a checkpoint is specified).
    start_iter = 0

    # Load an existing checkpoint, if a path is specified.
    if os.path.exists(configargs.load_checkpoint):
        checkpoint = torch.load(configargs.load_checkpoint)
        model_coarse.load_state_dict(checkpoint["model_coarse_state_dict"])
        if checkpoint["model_fine_state_dict"]:
            model_fine.load_state_dict(checkpoint["model_fine_state_dict"])
        if checkpoint["background"] is not None:
            print("loaded bg from checkpoint")
            background = torch.nn.Parameter(checkpoint['background'].to(device))
        if checkpoint["latent_codes"] is not None:
            print("loaded latent codes from checkpoint")
            latent_codes = torch.nn.Parameter(checkpoint['latent_codes'].to(device))
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        start_iter = checkpoint["iter"]

    # TODO: Prepare raybatch tensor if batching random rays

    # Prepare importance-sampling maps: sample inside each bounding box with
    # probability p, outside with probability 1 - p.
    ray_importance_sampling_maps = []
    p = 0.9
    print("computing bounding-box probability maps")
    for i in i_train:
        bbox = bboxs[i]
        probs = np.zeros((H, W))
        probs.fill(1 - p)
        probs[bbox[0]:bbox[1], bbox[2]:bbox[3]] = p
        probs = (1 / probs.sum()) * probs
        ray_importance_sampling_maps.append(probs.reshape(-1))

    print("Starting loop")
    for i in trange(start_iter, cfg.experiment.train_iters):
        model_coarse.train()
        if model_fine:
            model_fine.train()

        rgb_coarse, rgb_fine = None, None
        target_ray_values = None
        background_ray_values = None
        if USE_CACHED_DATASET:
            datafile = np.random.choice(train_paths)
            cache_dict = torch.load(datafile)
            ray_bundle = cache_dict["ray_bundle"].to(device)
            ray_origins, ray_directions = (
                ray_bundle[0].reshape((-1, 3)),
                ray_bundle[1].reshape((-1, 3)),
            )
            target_ray_values = cache_dict["target"][..., :3].reshape((-1, 3))
            select_inds = np.random.choice(
                ray_origins.shape[0],
                size=cfg.nerf.train.num_random_rays,
                replace=False,
            )
            ray_origins, ray_directions = (
                ray_origins[select_inds],
                ray_directions[select_inds],
            )
            target_ray_values = target_ray_values[select_inds].to(device)
            rgb_coarse, _, _, rgb_fine, _, _ = run_one_iter_of_nerf(
                cache_dict["height"],
                cache_dict["width"],
                cache_dict["focal_length"],
                model_coarse,
                model_fine,
                ray_origins,
                ray_directions,
                cfg,
                mode="train",
                encode_position_fn=encode_position_fn,
                encode_direction_fn=encode_direction_fn,
                expressions=expressions)
        else:
            img_idx = np.random.choice(i_train)
            img_target = images[img_idx].to(device)
            pose_target = poses[img_idx, :3, :4].to(device)
            if not disable_expressions:
                expression_target = expressions[img_idx].to(device)  # vector
            else:  # zero expression
                expression_target = torch.zeros(76, device=device)
            if not disable_latent_codes:
                latent_code = (latent_codes[img_idx].to(device)
                               if train_latent_codes else None)
            else:
                latent_code = torch.zeros(32, device=device)

            ray_origins, ray_directions = get_ray_bundle(H, W, focal, pose_target)
            coords = torch.stack(
                meshgrid_xy(torch.arange(H).to(device),
                            torch.arange(W).to(device)),
                dim=-1,
            )
            coords = coords.reshape((-1, 2))
            # Importance sampling: most rays land inside the bounding box
            # (probability p), the rest fall elsewhere in the image.
            select_inds = np.random.choice(
                coords.shape[0],
                size=cfg.nerf.train.num_random_rays,
                replace=False,
                p=ray_importance_sampling_maps[img_idx])
            select_inds = coords[select_inds]
            ray_origins = ray_origins[select_inds[:, 0], select_inds[:, 1], :]
            ray_directions = ray_directions[select_inds[:, 0], select_inds[:, 1], :]
            target_s = img_target[select_inds[:, 0], select_inds[:, 1], :]
            background_ray_values = (
                background[select_inds[:, 0], select_inds[:, 1], :]
                if (train_background or fixed_background) else None)

            rgb_coarse, _, _, rgb_fine, _, _, weights = run_one_iter_of_nerf(
                H,
                W,
                focal,
                model_coarse,
                model_fine,
                ray_origins,
                ray_directions,
                cfg,
                mode="train",
                encode_position_fn=encode_position_fn,
                encode_direction_fn=encode_direction_fn,
                expressions=expression_target,
                background_prior=background_ray_values,
                latent_code=latent_code if not disable_latent_codes else
                torch.zeros(32, device=device))
            target_ray_values = target_s

        coarse_loss = torch.nn.functional.mse_loss(rgb_coarse[..., :3],
                                                   target_ray_values[..., :3])
        fine_loss = None
        if rgb_fine is not None:
            fine_loss = torch.nn.functional.mse_loss(rgb_fine[..., :3],
                                                     target_ray_values[..., :3])

        latent_code_loss = torch.zeros(1, device=device)
        if train_latent_codes and not disable_latent_codes:
            latent_code_loss = torch.norm(latent_code) * 0.0005

        background_loss = torch.zeros(1, device=device)
        if supervised_train_background:
            background_loss = torch.nn.functional.mse_loss(
                background_ray_values[..., :3],
                target_ray_values[..., :3],
                reduction='none').sum(1)
            background_loss = torch.mean(background_loss * weights) * 0.001

        loss = coarse_loss + (fine_loss if fine_loss is not None else 0.0)
        psnr = mse2psnr(loss.item())
        loss = loss + (latent_code_loss * 10 if regularize_latent_codes else 0.0)
        loss_total = loss + (background_loss if supervised_train_background else 0.0)
        loss_total.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Learning rate updates (exponential decay).
        num_decay_steps = cfg.scheduler.lr_decay * 1000
        lr_new = cfg.optimizer.lr * (cfg.scheduler.lr_decay_factor
                                     **(i / num_decay_steps))
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr_new

        if i % cfg.experiment.print_every == 0 or i == cfg.experiment.train_iters - 1:
            tqdm.write("[TRAIN] Iter: " + str(i) + " Loss: " + str(loss.item()) +
                       " BG Loss: " + str(background_loss.item()) + " PSNR: " +
                       str(psnr) + " LatentReg: " + str(latent_code_loss.item()))
        # writer.add_scalar("train/loss", loss.item(), i)
        if train_latent_codes:
            writer.add_scalar("train/code_loss", latent_code_loss.item(), i)
        if supervised_train_background:
            writer.add_scalar("train/bg_loss", background_loss.item(), i)
        writer.add_scalar("train/coarse_loss", coarse_loss.item(), i)
        if rgb_fine is not None:
            writer.add_scalar("train/fine_loss", fine_loss.item(), i)
        writer.add_scalar("train/psnr", psnr, i)

        # Validation. Note that the trailing `and False` disables the
        # final-iteration trigger (`or` binds looser than `and`).
        if (i % cfg.experiment.validate_every == 0
                or i == cfg.experiment.train_iters - 1 and False):
            # torch.cuda.empty_cache()
            tqdm.write("[VAL] =======> Iter: " + str(i))
            model_coarse.eval()
            if model_fine:
                model_fine.eval()

            start = time.time()
            with torch.no_grad():
                rgb_coarse, rgb_fine = None, None
                target_ray_values = None
                if USE_CACHED_DATASET:
                    datafile = np.random.choice(validation_paths)
                    cache_dict = torch.load(datafile)
                    rgb_coarse, _, _, rgb_fine, _, weights = run_one_iter_of_nerf(
                        cache_dict["height"],
                        cache_dict["width"],
                        cache_dict["focal_length"],
                        model_coarse,
                        model_fine,
                        cache_dict["ray_origins"].to(device),
                        cache_dict["ray_directions"].to(device),
                        cfg,
                        mode="validation",
                        encode_position_fn=encode_position_fn,
                        encode_direction_fn=encode_direction_fn,
                        expressions=expression_target,
                        latent_code=torch.zeros(32, device=device))
                    target_ray_values = cache_dict["target"].to(device)
                else:
                    # Validate on a slice of the validation set.
                    loss = 0
                    for img_idx in i_val[:2]:
                        img_target = images[img_idx].to(device)
                        pose_target = poses[img_idx, :3, :4].to(device)
                        ray_origins, ray_directions = get_ray_bundle(
                            H, W, focal, pose_target)
                        rgb_coarse, _, _, rgb_fine, _, _, weights = run_one_iter_of_nerf(
                            H,
                            W,
                            focal,
                            model_coarse,
                            model_fine,
                            ray_origins,
                            ray_directions,
                            cfg,
                            mode="validation",
                            encode_position_fn=encode_position_fn,
                            encode_direction_fn=encode_direction_fn,
                            expressions=expression_target,
                            background_prior=background.view(-1, 3) if
                            (train_background or fixed_background) else None,
                            latent_code=torch.zeros(32).to(device)
                            if train_latent_codes or disable_latent_codes else None,
                        )
                        target_ray_values = img_target
                        coarse_loss = img2mse(rgb_coarse[..., :3],
                                              target_ray_values[..., :3])
                        curr_loss, curr_fine_loss = 0.0, 0.0
                        if rgb_fine is not None:
                            curr_fine_loss = img2mse(rgb_fine[..., :3],
                                                     target_ray_values[..., :3])
                            curr_loss = curr_fine_loss
                        else:
                            curr_loss = coarse_loss
                        loss += curr_loss + curr_fine_loss
                    loss /= len(i_val[:2])
                psnr = mse2psnr(loss.item())
                writer.add_scalar("validation/loss", loss.item(), i)
                writer.add_scalar("validation/coarse_loss", coarse_loss.item(), i)
                writer.add_scalar("validation/psnr", psnr, i)
                writer.add_image("validation/rgb_coarse",
                                 cast_to_image(rgb_coarse[..., :3]), i)
                if rgb_fine is not None:
                    writer.add_image("validation/rgb_fine",
                                     cast_to_image(rgb_fine[..., :3]), i)
                    writer.add_scalar("validation/fine_loss",
                                      curr_fine_loss.item(), i)
                writer.add_image(
                    "validation/img_target",
                    cast_to_image(target_ray_values[..., :3]),
                    i,
                )
                if train_background or fixed_background:
                    writer.add_image("validation/background",
                                     cast_to_image(background[..., :3]), i)
                    writer.add_image("validation/weights",
                                     weights.detach().cpu().numpy(), i,
                                     dataformats='HW')
                tqdm.write("Validation loss: " + str(loss.item()) +
                           " Validation PSNR: " + str(psnr) + " Time: " +
                           str(time.time() - start))

        if i % cfg.experiment.save_every == 0 or i == cfg.experiment.train_iters - 1:
            checkpoint_dict = {
                "iter": i,
                "model_coarse_state_dict": model_coarse.state_dict(),
                "model_fine_state_dict":
                    None if not model_fine else model_fine.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "loss": loss,
                "psnr": psnr,
                "background":
                    None if not (train_background or fixed_background)
                    else background.data,
                "latent_codes": None if not train_latent_codes else latent_codes.data,
            }
            torch.save(
                checkpoint_dict,
                os.path.join(logdir, "checkpoint" + str(i).zfill(5) + ".ckpt"),
            )
            tqdm.write("================== Saved Checkpoint =================")

    print("Done!")
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True,
                        help="Path to (.yml) config file.")
    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="Checkpoint / pre-trained model to evaluate.",
    )
    parser.add_argument("--savedir", type=str, default='./renders/',
                        help="Save images to this directory, if specified.")
    parser.add_argument("--save-disparity-image", action="store_true",
                        help="Save disparity images too.")
    parser.add_argument("--save-error-image", action="store_true",
                        help="Save photometric error visualization.")
    configargs = parser.parse_args()

    # Read config file.
    cfg = None
    with open(configargs.config, "r") as f:
        cfg_dict = yaml.load(f, Loader=yaml.FullLoader)
        cfg = CfgNode(cfg_dict)

    images, poses, render_poses, hwf = None, None, None, None
    i_train, i_val, i_test = None, None, None
    if cfg.dataset.type.lower() == "blender":
        # Load blender dataset
        images, poses, render_poses, hwf, i_split, expressions, _, _ = load_flame_data(
            cfg.dataset.basedir,
            half_res=cfg.dataset.half_res,
            testskip=cfg.dataset.testskip,
            test=True)
        # i_train, i_val, i_test = i_split
        i_test = i_split
        H, W, focal = hwf
        H, W = int(H), int(W)
    elif cfg.dataset.type.lower() == "llff":
        # Load LLFF dataset
        images, poses, bds, render_poses, i_test = load_llff_data(
            cfg.dataset.basedir,
            factor=cfg.dataset.downsample_factor,
        )
        hwf = poses[0, :3, -1]
        H, W, focal = hwf
        hwf = [int(H), int(W), focal]
        render_poses = torch.from_numpy(render_poses)

    # Device on which to run.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    encode_position_fn = get_embedding_function(
        num_encoding_functions=cfg.models.coarse.num_encoding_fn_xyz,
        include_input=cfg.models.coarse.include_input_xyz,
        log_sampling=cfg.models.coarse.log_sampling_xyz,
    )

    encode_direction_fn = None
    if cfg.models.coarse.use_viewdirs:
        encode_direction_fn = get_embedding_function(
            num_encoding_functions=cfg.models.coarse.num_encoding_fn_dir,
            include_input=cfg.models.coarse.include_input_dir,
            log_sampling=cfg.models.coarse.log_sampling_dir,
        )

    # Initialize a coarse resolution model.
    model_coarse = getattr(models, cfg.models.coarse.type)(
        num_encoding_fn_xyz=cfg.models.coarse.num_encoding_fn_xyz,
        num_encoding_fn_dir=cfg.models.coarse.num_encoding_fn_dir,
        include_input_xyz=cfg.models.coarse.include_input_xyz,
        include_input_dir=cfg.models.coarse.include_input_dir,
        use_viewdirs=cfg.models.coarse.use_viewdirs,
        num_layers=cfg.models.coarse.num_layers,
        hidden_size=cfg.models.coarse.hidden_size,
        include_expression=True)
    model_coarse.to(device)

    # If a fine-resolution model is specified, initialize it.
    model_fine = None
    if hasattr(cfg.models, "fine"):
        model_fine = getattr(models, cfg.models.fine.type)(
            num_encoding_fn_xyz=cfg.models.fine.num_encoding_fn_xyz,
            num_encoding_fn_dir=cfg.models.fine.num_encoding_fn_dir,
            include_input_xyz=cfg.models.fine.include_input_xyz,
            include_input_dir=cfg.models.fine.include_input_dir,
            use_viewdirs=cfg.models.fine.use_viewdirs,
            num_layers=cfg.models.coarse.num_layers,
            hidden_size=cfg.models.coarse.hidden_size,
            include_expression=True)
        model_fine.to(device)

    checkpoint = torch.load(configargs.checkpoint)
    model_coarse.load_state_dict(checkpoint["model_coarse_state_dict"])
    if checkpoint["model_fine_state_dict"]:
        try:
            model_fine.load_state_dict(checkpoint["model_fine_state_dict"])
        except Exception:
            print("The checkpoint has a fine-level model, but it could "
                  "not be loaded (possibly due to a mismatched config file).")
    if "height" in checkpoint.keys():
        hwf[0] = checkpoint["height"]
    if "width" in checkpoint.keys():
        hwf[1] = checkpoint["width"]
    if "focal_length" in checkpoint.keys():
        hwf[2] = checkpoint["focal_length"]
    if "background" in checkpoint.keys():
        background = checkpoint["background"]
        if background is not None:
            print("loaded background with shape ", background.shape)
            background = background.to(device)
    if "latent_codes" in checkpoint.keys():
        latent_codes = checkpoint["latent_codes"]
        use_latent_code = False
        if latent_codes is not None:
            use_latent_code = True
            latent_codes = latent_codes.to(device)
            print("loading index map for latent codes...")
            idx_map = np.load(cfg.dataset.basedir + "/index_map.npy").astype(int)
            print("loaded latent codes from checkpoint, with shape ",
                  latent_codes.shape)

    model_coarse.eval()
    if model_fine:
        model_fine.eval()

    replace_background = True
    if replace_background:
        from PIL import Image
        background = Image.open(cfg.dataset.basedir + '/bg/00050.png')
        background.thumbnail((H, W))
        background = torch.from_numpy(
            np.array(background).astype(np.float32)).to(device)
        background = background / 255
        print('loaded custom background of shape', background.shape)
        # background = torch.ones_like(background)  # white background

    render_poses = render_poses.float().to(device)

    # Create directories to save images to.
    os.makedirs(configargs.savedir, exist_ok=True)
    if configargs.save_disparity_image:
        os.makedirs(os.path.join(configargs.savedir, "disparity"), exist_ok=True)
    if configargs.save_error_image:
        os.makedirs(os.path.join(configargs.savedir, "error"), exist_ok=True)
    os.makedirs(os.path.join(configargs.savedir, "normals"), exist_ok=True)

    # Evaluation loop
    times_per_image = []
    render_poses = poses[i_test].float().to(device)
    render_expressions = expressions[i_test].float().to(device)
    index_of_image_after_train_shuffle = 0
    # render_expressions = render_expressions[[300]]  # TODO: render a specific expression

    #######################
    # Ablation switches.
    no_background = False
    no_expressions = False
    no_lcode = False
    nerf = False
    frontalize = False
    interpolate_mouth = False
    #######################
    if nerf:
        # Plain NeRF: no background prior, expressions, or latent codes.
        no_background = True
        no_expressions = True
        no_lcode = True
    if no_background:
        background = None
    if no_expressions:
        render_expressions = torch.zeros_like(render_expressions,
                                              device=render_expressions.device)
    if no_lcode:
        use_latent_code = True
        latent_codes = torch.zeros(5000, 32, device=device)

    for i, expression in enumerate(tqdm(render_expressions)):
        # if i % 25 != 0:  # TODO: generate only every 25th image
        #     continue
        start = time.time()
        rgb = None
        disp = None
        with torch.no_grad():
            pose = render_poses[i]
            if interpolate_mouth:
                # Sweep the mouth expression component over a fixed frame.
                pose = render_poses[241]
                expression = render_expressions[241].clone()
                expression[68] = torch.arange(-1, 1, 2 / 150, device=device)[i]
            if frontalize:
                pose = render_poses[0]

            ablate = 'view_dir'
            if ablate == 'expression':
                pose = render_poses[100]
            elif ablate == 'latent_code':
                pose = render_poses[100]
                expression = render_expressions[100]
                if idx_map[100 + i, 1] >= 0:
                    index_of_image_after_train_shuffle = idx_map[100 + i, 1]
            elif ablate == 'view_dir':
                pose = render_poses[100]
                expression = render_expressions[100]
                _, ray_directions_ablation = get_ray_bundle(
                    hwf[0], hwf[1], hwf[2], render_poses[240 + i][:3, :4])

            pose = pose[:3, :4]
            if use_latent_code:
                if idx_map[i, 1] >= 0:
                    index_of_image_after_train_shuffle = idx_map[i, 1]
            # index_of_image_after_train_shuffle = 10  # TODO: fixes the latent code
            index_of_image_after_train_shuffle = idx_map[10, 1]  # TODO: fixes the latent code; use this if not ablating!

            latent_code = (latent_codes[index_of_image_after_train_shuffle].to(device)
                           if use_latent_code else None)
            ray_origins, ray_directions = get_ray_bundle(hwf[0], hwf[1], hwf[2], pose)
            rgb_coarse, disp_coarse, _, rgb_fine, disp_fine, _, weights = run_one_iter_of_nerf(
                hwf[0],
                hwf[1],
                hwf[2],
                model_coarse,
                model_fine,
                ray_origins,
                ray_directions,
                cfg,
                mode="validation",
                encode_position_fn=encode_position_fn,
                encode_direction_fn=encode_direction_fn,
                expressions=expression,
                background_prior=background.view(-1, 3)
                if (background is not None) else None,
                latent_code=latent_code,
                ray_directions_ablation=ray_directions_ablation)
            rgb = rgb_fine if rgb_fine is not None else rgb_coarse

            normals = torch_normal_map(disp_fine, focal, weights, clean=True)
            # normals = normal_map_from_depth_map_backproject(disp_fine.cpu().numpy())
            save_plt_image(
                normals.cpu().numpy().astype('uint8'),
                os.path.join(configargs.savedir, 'normals', f"{i:04d}.png"))
            if configargs.save_disparity_image:
                disp = disp_fine if disp_fine is not None else disp_coarse

        times_per_image.append(time.time() - start)
        if configargs.savedir:
            savefile = os.path.join(configargs.savedir, f"{i:04d}.png")
            imageio.imwrite(savefile,
                            cast_to_image(rgb[..., :3], cfg.dataset.type.lower()))
            if configargs.save_disparity_image:
                savefile = os.path.join(configargs.savedir, "disparity",
                                        f"{i:04d}.png")
                imageio.imwrite(savefile, cast_to_disparity_image(disp))
            if configargs.save_error_image:
                savefile = os.path.join(configargs.savedir, "error",
                                        f"{i:04d}.png")
                GT = images[i_test][i]
                fig = error_image(GT, rgb.cpu().numpy())
                plt.savefig(savefile, pad_inches=0, bbox_inches='tight', dpi=54)
        tqdm.write(f"Avg time per image: {sum(times_per_image) / (i + 1)}")
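# `error_image` is not defined in this section. A hypothetical stand-in that
# renders per-pixel photometric error as a matplotlib figure; the function
# name matches the caller above, but the colormap and scaling are assumptions,
# not the repo's actual helper.
import matplotlib.pyplot as plt
import numpy as np

def error_image(gt, pred):
    # Mean absolute per-pixel error over the RGB channels.
    err = np.abs(np.asarray(gt)[..., :3] - pred[..., :3]).mean(-1)
    fig = plt.figure(frameon=False)
    plt.imshow(err, cmap="jet")
    plt.axis("off")
    return fig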
def main(): parser = argparse.ArgumentParser() parser.add_argument("--config", type=str, required=True, help="Path to (.yml) config file.") parser.add_argument( "--load-checkpoint", type=str, default="", help="Path to load saved checkpoint from.", ) configargs = parser.parse_args() # Read config file. with open(configargs.config, "r") as f: cfg_dict = yaml.load(f, Loader=yaml.FullLoader) cfg = CfgNode(cfg_dict) # # (Optional:) enable this to track autograd issues when debugging # torch.autograd.set_detect_anomaly(True) # Load dataset if cfg.dataset.type.lower() == "llff": images, poses, _, _, i_test = load_llff_data( cfg.dataset.basedir, factor=cfg.dataset.downsample_factor) hwf = poses[0, :3, -1] poses = poses[:, :3, :4] if not isinstance(i_test, list): i_test = [i_test] if cfg.dataset.llffhold > 0: i_test = np.arange(images.shape[0])[::cfg.dataset.llffhold] i_val = i_test i_train = np.array([ i for i in np.arange(images.shape[0]) if (i not in i_test and i not in i_val) ]) H, W, focal = hwf H, W = int(H), int(W) images = torch.from_numpy(images) poses = torch.from_numpy(poses) USE_HR_LR = False # Load LR images if hasattr(cfg.dataset, "relative_lr_factor"): assert hasattr(cfg.dataset, "hr_fps") assert hasattr(cfg.dataset, "hr_frequency") USE_HR_LR = True images_lr, poses_lr, _, _, i_test = load_llff_data( cfg.dataset.basedir, factor=cfg.dataset.downsample_factor * cfg.dataset.relative_lr_factor) hwf = poses_lr[0, :3, -1] poses_lr = poses_lr[:, :3, :4] if not isinstance(i_test, list): i_test = [i_test] if cfg.dataset.llffhold > 0: i_test = np.arange(images_lr.shape[0])[::cfg.dataset.llffhold] i_train_lr = np.array([ i for i in np.arange(images_lr.shape[0]) if (i not in i_test) ]) H_lr, W_lr, focal_lr = hwf H_lr, W_lr = int(H_lr), int(W_lr) images_lr = torch.from_numpy(images_lr) poses_lr = torch.from_numpy(poses_lr) # Expose only some HR images i_train = i_train[::cfg.dataset.hr_fps] print(f'LR summary: N={i_train_lr.shape[0]}, ' f'resolution={H_lr}x{W_lr}, focal-length={focal_lr}') print(f'HR summary: N={i_train.shape[0]}, ' f'resolution={H}x{W}, focal-length={focal}') # Seed experiment for repeatability seed = cfg.experiment.randomseed np.random.seed(seed) torch.manual_seed(seed) # Device on which to run. device = f"cuda:{cfg.experiment.gpu_id}" encode_position_fn = get_embedding_function( num_encoding_functions=cfg.models.coarse.num_encoding_fn_xyz, include_input=cfg.models.coarse.include_input_xyz, log_sampling=cfg.models.coarse.log_sampling_xyz, ) encode_direction_fn = None if cfg.models.coarse.use_viewdirs: encode_direction_fn = get_embedding_function( num_encoding_functions=cfg.models.coarse.num_encoding_fn_dir, include_input=cfg.models.coarse.include_input_dir, log_sampling=cfg.models.coarse.log_sampling_dir, ) # Initialize a coarse and fine resolution model. model_coarse = getattr(models, cfg.models.coarse.type)( num_encoding_fn_xyz=cfg.models.coarse.num_encoding_fn_xyz, num_encoding_fn_dir=cfg.models.coarse.num_encoding_fn_dir, include_input_xyz=cfg.models.coarse.include_input_xyz, include_input_dir=cfg.models.coarse.include_input_dir, use_viewdirs=cfg.models.coarse.use_viewdirs, ) model_coarse.to(device) model_fine = getattr(models, cfg.models.fine.type)( num_encoding_fn_xyz=cfg.models.fine.num_encoding_fn_xyz, num_encoding_fn_dir=cfg.models.fine.num_encoding_fn_dir, include_input_xyz=cfg.models.fine.include_input_xyz, include_input_dir=cfg.models.fine.include_input_dir, use_viewdirs=cfg.models.fine.use_viewdirs, ) model_fine.to(device) # Initialize optimizer. 
    trainable_parameters = list(model_coarse.parameters()) + list(
        model_fine.parameters())
    optimizer = getattr(torch.optim, cfg.optimizer.type)(trainable_parameters,
                                                         lr=cfg.optimizer.lr)

    # Set up logging.
    logdir = os.path.join(cfg.experiment.logdir, cfg.experiment.id)
    os.makedirs(logdir, exist_ok=True)
    writer = SummaryWriter(logdir)
    # Write out config parameters.
    with open(os.path.join(logdir, "config.yml"), "w") as f:
        f.write(cfg.dump())

    # By default, start at iteration 0 (unless a checkpoint is specified).
    start_iter = 0

    # Load an existing checkpoint, if a path is specified.
    if os.path.exists(configargs.load_checkpoint):
        checkpoint = torch.load(configargs.load_checkpoint)
        model_coarse.load_state_dict(checkpoint["model_coarse_state_dict"])
        model_fine.load_state_dict(checkpoint["model_fine_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        start_iter = checkpoint["iter"]

    # TODO: Prepare ray-batch tensor if batching random rays.
    for i in trange(start_iter, cfg.experiment.train_iters):
        model_coarse.train()
        model_fine.train()

        if USE_HR_LR and i % cfg.dataset.hr_frequency != 0:
            H_iter, W_iter, focal_iter = H_lr, W_lr, focal_lr
            i_train_iter = i_train_lr
            images_iter, poses_iter = images_lr, poses_lr
        else:
            H_iter, W_iter, focal_iter = H, W, focal
            i_train_iter = i_train
            images_iter, poses_iter = images, poses

        img_idx = np.random.choice(i_train_iter)
        img_target = images_iter[img_idx].to(device)
        pose_target = poses_iter[img_idx, :3, :4].to(device)
        ray_origins, ray_directions = get_ray_bundle(H_iter, W_iter,
                                                     focal_iter, pose_target)
        coords = torch.stack(
            meshgrid_xy(
                torch.arange(H_iter).to(device),
                torch.arange(W_iter).to(device)),
            dim=-1,
        )
        coords = coords.reshape((-1, 2))
        select_inds = np.random.choice(coords.shape[0],
                                       size=cfg.nerf.train.num_random_rays,
                                       replace=False)
        select_inds = coords[select_inds]
        ray_origins = ray_origins[select_inds[:, 0], select_inds[:, 1], :]
        ray_directions = ray_directions[select_inds[:, 0],
                                        select_inds[:, 1], :]
        target_s = img_target[select_inds[:, 0], select_inds[:, 1], :]

        rgb_coarse, _, _, rgb_fine, _, _ = run_one_iter_of_nerf(
            H_iter,
            W_iter,
            focal_iter,
            model_coarse,
            model_fine,
            ray_origins,
            ray_directions,
            cfg,
            mode="train",
            encode_position_fn=encode_position_fn,
            encode_direction_fn=encode_direction_fn,
        )
        target_ray_values = target_s

        coarse_loss = torch.nn.functional.mse_loss(rgb_coarse[..., :3],
                                                   target_ray_values[..., :3])
        fine_loss = torch.nn.functional.mse_loss(rgb_fine[..., :3],
                                                 target_ray_values[..., :3])
        loss = coarse_loss + fine_loss
        loss.backward()
        psnr = mse2psnr(loss.item())
        optimizer.step()
        optimizer.zero_grad()

        # Learning rate updates.
        num_decay_steps = cfg.scheduler.lr_decay * 1000
        lr_new = cfg.optimizer.lr * (
            cfg.scheduler.lr_decay_factor**(i / num_decay_steps))
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr_new

        if i % cfg.experiment.print_every == 0 or i == cfg.experiment.train_iters - 1:
            tqdm.write("[TRAIN] Iter: " + str(i) + " Loss: " +
                       str(loss.item()) + " PSNR: " + str(psnr))
        writer.add_scalar("train/loss", loss.item(), i)
        writer.add_scalar("train/coarse_loss", coarse_loss.item(), i)
        writer.add_scalar("train/fine_loss", fine_loss.item(), i)
        writer.add_scalar("train/psnr", psnr, i)

        # Validation.
        if (i % cfg.experiment.validate_every == 0
                or i == cfg.experiment.train_iters - 1):
            tqdm.write("[VAL] =======> Iter: " + str(i))
            model_coarse.eval()
            model_fine.eval()

            start = time.time()
            with torch.no_grad():
                img_idx = np.random.choice(i_val)
                img_target = images[img_idx].to(device)
                pose_target = poses[img_idx, :3, :4].to(device)
                ray_origins, ray_directions = get_ray_bundle(
                    H, W, focal, pose_target)
                rgb_coarse, _, _, rgb_fine, _, _ = run_one_iter_of_nerf(
                    H,
                    W,
                    focal,
                    model_coarse,
                    model_fine,
                    ray_origins,
                    ray_directions,
                    cfg,
                    mode="validation",
                    encode_position_fn=encode_position_fn,
                    encode_direction_fn=encode_direction_fn,
                )
                target_ray_values = img_target
                coarse_loss = img2mse(rgb_coarse[..., :3],
                                      target_ray_values[..., :3])
                fine_loss = img2mse(rgb_fine[..., :3],
                                    target_ray_values[..., :3])
                loss = coarse_loss + fine_loss
                psnr = mse2psnr(loss.item())
                writer.add_scalar("validation/loss", loss.item(), i)
                writer.add_scalar("validation/coarse_loss",
                                  coarse_loss.item(), i)
                writer.add_scalar("validation/fine_loss", fine_loss.item(), i)
                writer.add_scalar("validation/psnr", psnr, i)
                writer.add_image("validation/rgb_coarse",
                                 cast_to_image(rgb_coarse[..., :3]), i)
                writer.add_image("validation/rgb_fine",
                                 cast_to_image(rgb_fine[..., :3]), i)
                writer.add_image(
                    "validation/img_target",
                    cast_to_image(target_ray_values[..., :3]),
                    i,
                )
                tqdm.write("Validation loss: " + str(loss.item()) +
                           " Validation PSNR: " + str(psnr) + " Time: " +
                           str(time.time() - start))

        if i % cfg.experiment.save_every == 0 or i == cfg.experiment.train_iters - 1:
            checkpoint_dict = {
                "iter": i,
                "model_coarse_state_dict": model_coarse.state_dict(),
                "model_fine_state_dict": model_fine.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "loss": loss,
                "psnr": psnr,
            }
            torch.save(
                checkpoint_dict,
                os.path.join(logdir, "checkpoint" + str(i).zfill(5) + ".ckpt"),
            )
            tqdm.write("================== Saved Checkpoint =================")

    print("Done!")
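The training loop above relies on two helpers, mse2psnr and cast_to_image, that are referenced but not shown. Below is a minimal sketch under the usual NeRF conventions (pixel values in [0, 1]); the exact definitions in this codebase may differ. Note the two scripts appear to use different variants of cast_to_image: the training script feeds a (3, H, W) array to TensorBoard's add_image, while the evaluation script presumably uses an (H, W, 3) variant for imageio.imwrite. The optional dataset_type argument here is accepted only because one call site passes it.

import math

import torch


def mse2psnr(mse):
    # PSNR in dB for images with values in [0, 1]; the clamp avoids log10(0).
    return -10.0 * math.log10(max(mse, 1e-10))


def cast_to_image(tensor, dataset_type=None):
    # Cast an (H, W, 3) float tensor in [0, 1] to a (3, H, W) uint8 array,
    # the layout TensorBoard's add_image expects. `dataset_type` is ignored
    # in this sketch.
    img = (255.0 * tensor.clamp(0.0, 1.0)).to(torch.uint8)
    return img.permute(2, 0, 1).detach().cpu().numpy()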
def main():
    # Config options:
    parser = argparse.ArgumentParser()
    parser.add_argument("--config",
                        type=str,
                        required=True,
                        help="Path to (.yml) config file.")
    parser.add_argument("--checkpoint",
                        type=str,
                        required=True,
                        help="Checkpoint / pre-trained model to evaluate.")
    parser.add_argument("--dual-render", action='store_true', default=False)
    parser.add_argument('--gpu-id',
                        type=int,
                        default=0,
                        help="id of the CUDA GPU to use (default: 0)")
    configargs = parser.parse_args()

    cfg = None
    with open(configargs.config, "r") as f:
        cfg_dict = yaml.load(f, Loader=yaml.FullLoader)
        cfg = CfgNode(cfg_dict)

    # Dataset:
    images, poses, render_poses, hwf = None, None, None, None
    i_train, i_val, i_test = None, None, None
    if cfg.dataset.type.lower() == "blender":
        images, poses, render_poses, hwf, i_split = load_blender_data(
            cfg.dataset.basedir,
            half_res=cfg.dataset.half_res,
            testskip=cfg.dataset.testskip,
        )
        i_train, i_val, i_test = i_split
        H, W, focal = hwf
        H, W = int(H), int(W)
        hwf = [H, W, focal]
        if cfg.nerf.train.white_background:
            images = images[..., :3] * images[..., -1:] + (
                1.0 - images[..., -1:])
    elif cfg.dataset.type.lower() == "llff":
        images, poses, bds, render_poses, i_test = load_llff_data(
            cfg.dataset.basedir,
            factor=cfg.dataset.downsample_factor,
        )
        hwf = poses[0, :3, -1]
        H, W, focal = hwf
        hwf = [int(H), int(W), focal]
    render_poses = torch.from_numpy(render_poses)
    images = torch.from_numpy(images)
    poses = torch.from_numpy(poses)

    # Hardware:
    device = "cpu"
    if torch.cuda.is_available():
        torch.cuda.set_device(configargs.gpu_id)
        device = "cuda"

    # Model:
    encode_position_fn = get_embedding_function(
        num_encoding_functions=cfg.models.coarse.num_encoding_fn_xyz,
        include_input=cfg.models.coarse.include_input_xyz,
        log_sampling=cfg.models.coarse.log_sampling_xyz)
    encode_direction_fn = None
    if cfg.models.coarse.use_viewdirs:
        encode_direction_fn = get_embedding_function(
            num_encoding_functions=cfg.models.coarse.num_encoding_fn_dir,
            include_input=cfg.models.coarse.include_input_dir,
            log_sampling=cfg.models.coarse.log_sampling_dir)
    model_coarse = getattr(models, cfg.models.coarse.type)(
        num_encoding_fn_xyz=cfg.models.coarse.num_encoding_fn_xyz,
        num_encoding_fn_dir=cfg.models.coarse.num_encoding_fn_dir,
        include_input_xyz=cfg.models.coarse.include_input_xyz,
        include_input_dir=cfg.models.coarse.include_input_dir,
        use_viewdirs=cfg.models.coarse.use_viewdirs)
    model_coarse.to(device)
    model_fine = None
    if hasattr(cfg.models, "fine"):
        model_fine = getattr(models, cfg.models.fine.type)(
            num_encoding_fn_xyz=cfg.models.fine.num_encoding_fn_xyz,
            num_encoding_fn_dir=cfg.models.fine.num_encoding_fn_dir,
            include_input_xyz=cfg.models.fine.include_input_xyz,
            include_input_dir=cfg.models.fine.include_input_dir,
            use_viewdirs=cfg.models.fine.use_viewdirs)
        model_fine.to(device)

    # Load checkpoint.
    checkpoint = torch.load(configargs.checkpoint, map_location=device)
    model_coarse.load_state_dict(checkpoint["model_coarse_state_dict"])
    if checkpoint["model_fine_state_dict"]:
        try:
            model_fine.load_state_dict(checkpoint["model_fine_state_dict"])
        except Exception:
            print("The checkpoint has a fine-level model, but it could "
                  "not be loaded (possibly due to a mismatched config file).")
    if "height" in checkpoint.keys():
        hwf[0] = checkpoint["height"]
    if "width" in checkpoint.keys():
        hwf[1] = checkpoint["width"]
    if "focal_length" in checkpoint.keys():
        hwf[2] = checkpoint["focal_length"]

    # Prepare model and data.
    model_coarse.eval()
    if model_fine:
        model_fine.eval()

    render_poses = render_poses.float().to(device)
    print("Dual render?", configargs.dual_render)
    # Evaluation loop.
    with torch.no_grad():
        fine_psnrs = []
        if not isinstance(i_test, list):
            i_test = [i_test]
        for sample_num, i in enumerate(i_test):
            print(f"Test sample {sample_num + 1}/{len(i_test)}...")
            img_target = images[i].to(device)
            pose_target = poses[i, :3, :4].to(device)
            ray_origins, ray_directions = get_ray_bundle(
                H, W, focal, pose_target)
            rgb_coarse, _, _, rgb_fine, _, _ = run_one_iter_of_nerf(
                H,
                W,
                focal,
                model_coarse,
                model_fine,
                ray_origins,
                ray_directions,
                cfg,
                mode="validation",
                encode_position_fn=encode_position_fn,
                encode_direction_fn=encode_direction_fn,
                dual_render=configargs.dual_render)
            target_ray_values = img_target

            coarse_loss = img2mse(rgb_coarse[..., :3],
                                  target_ray_values[..., :3])
            fine_loss = 0.0
            if rgb_fine is not None:
                fine_loss = img2mse(rgb_fine[..., :3],
                                    target_ray_values[..., :3])
            loss = coarse_loss + fine_loss
            psnr = mse2psnr(loss.item())
            psnr_coarse = mse2psnr(coarse_loss.item())
            psnr_fine = (mse2psnr(fine_loss.item())
                         if rgb_fine is not None else None)
            print(f"\t PSNR at sample: {psnr} "
                  f"(f:{psnr_fine}, c:{psnr_coarse})")
            if psnr_fine is not None:
                fine_psnrs.append(psnr_fine)

    if fine_psnrs:
        print(f"Validation PSNR: {sum(fine_psnrs) / len(fine_psnrs)}.")
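For completeness, img2mse is used throughout the validation code but never shown; a one-line sketch, assuming it is plain mean-squared error between two images (which matches how its result is fed to mse2psnr):

import torch


def img2mse(img_src, img_tgt):
    # Mean-squared error between a predicted and a target image.
    return torch.nn.functional.mse_loss(img_src, img_tgt)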