def _load_renderings(self, args): """Load images from disk.""" if args.render_path: raise ValueError( "render_path cannot be used for the blender dataset.") with utils.open_file( path.join(args.data_dir, "transforms_{}.json".format(self.split)), "r") as fp: meta = json.load(fp) images = [] cams = [] for i in range(len(meta["frames"])): frame = meta["frames"][i] fname = os.path.join(args.data_dir, frame["file_path"] + ".png") with utils.open_file(fname, "rb") as imgin: image = np.array(Image.open(imgin), dtype=np.float32) / 255. if args.factor == 2: [halfres_h, halfres_w] = [hw // 2 for hw in image.shape[:2]] image = cv2.resize(image, (halfres_w, halfres_h), interpolation=cv2.INTER_AREA) elif args.factor > 0: raise ValueError( "Blender dataset only supports factor=0 or 2, {} " "set.".format(args.factor)) cams.append(frame["transform_matrix"]) images.append(image) self.images = np.stack(images, axis=0) if args.white_bkgd: self.images = ( self.images[Ellipsis, :3] * self.images[Ellipsis, -1:] + (1. - self.images[Ellipsis, -1:])) else: self.images = self.images[Ellipsis, :3] self.h, self.w = self.images.shape[1:3] self.resolution = self.h * self.w self.camtoworlds = np.stack(cams, axis=0) camera_angle_x = float(meta["camera_angle_x"]) self.focal = .5 * self.w / np.tan(.5 * camera_angle_x) self.n_examples = self.images.shape[0]
def _load_renderings(self, args):
    """Load LLFF-style images and poses from disk and keep the active split.

    Images come from ``<data_dir>/images`` (or ``images_<factor>`` when
    ``args.factor`` is positive) and poses/bounds from
    ``<data_dir>/poses_bounds.npy``. Results are stored on ``self``
    (``images``, ``camtoworlds``, ``focal``, ``h``, ``w``, ``resolution``,
    ``spherify``, ``n_examples``).

    Args:
        args: Configuration object. Reads ``factor``, ``data_dir``,
            ``spherify``, ``llffhold`` and ``render_path``.

    Raises:
        ValueError: If the image folder does not exist.
        RuntimeError: If the number of images and poses differ.
    """
    # Pick the image folder; a positive factor selects "images_<factor>".
    if args.factor > 0:
        factor = args.factor
        folder_name = "images" + "_{}".format(args.factor)
    else:
        factor = 1
        folder_name = "images"
    imgdir = path.join(args.data_dir, folder_name)
    if not utils.file_exists(imgdir):
        raise ValueError("Image folder {} doesn't exist.".format(imgdir))

    def read_image(filename):
        # Decode one image file into a float32 array scaled by 1/255.
        with utils.open_file(filename, "rb") as imgin:
            return np.array(Image.open(imgin), dtype=np.float32) / 255.

    frames = [
        read_image(path.join(imgdir, name))
        for name in sorted(utils.listdir(imgdir))
        if name.endswith(("JPG", "jpg", "png"))
    ]
    # The image index is the LAST axis for now; it moves to axis 0 below.
    images = np.stack(frames, axis=-1)

    # Poses and bounds share one array: the last two columns are the bounds.
    pose_file = path.join(args.data_dir, "poses_bounds.npy")
    with utils.open_file(pose_file, "rb") as fp:
        poses_arr = np.load(fp)
    poses = poses_arr[:, :-2].reshape([-1, 3, 5]).transpose([1, 2, 0])
    bds = poses_arr[:, -2:].transpose([1, 0])
    n_poses = poses.shape[-1]
    n_images = images.shape[-1]
    if n_poses != n_images:
        raise RuntimeError("Mismatch between imgs {} and poses {}".format(
            n_images, n_poses))

    # Overwrite the stored height/width with the loaded image size and scale
    # the third entry of that column by the downsampling factor.
    poses[:2, 4, :] = np.array(images.shape[:2]).reshape([2, 1])
    poses[2, 4, :] = poses[2, 4, :] * 1. / factor

    # Reorder the rotation columns, then move the per-image axis to the front.
    reordered = [poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]]
    poses = np.concatenate(reordered, 1)
    poses = np.moveaxis(poses, -1, 0).astype(np.float32)
    images = np.moveaxis(images, -1, 0)
    bds = np.moveaxis(bds, -1, 0).astype(np.float32)

    # Rescale translations and bounds by a default bound factor.
    scale = 1. / (bds.min() * .75)
    poses[:, :3, 3] *= scale
    bds *= scale

    poses = self._recenter_poses(poses)

    # Build the render path: spherical when requested, otherwise a spiral
    # (test split only).
    if args.spherify:
        poses = self._generate_spherical_poses(poses, bds)
        self.spherify = True
    else:
        self.spherify = False
        if self.split == "test":
            self._generate_spiral_poses(poses, bds)

    # Every `llffhold`-th image is held out for testing; the rest train.
    n_total = images.shape[0]
    i_test = np.arange(n_total)[::args.llffhold]
    i_train = np.array(
        [i for i in np.arange(int(n_total)) if i not in i_test])
    indices = i_train if self.split == "train" else i_test
    images = images[indices]
    poses = poses[indices]

    self.images = images
    self.camtoworlds = poses[:, :3, :4]
    self.focal = poses[0, -1, -1]
    self.h, self.w = images.shape[1:3]
    self.resolution = self.h * self.w
    if args.render_path:
        self.n_examples = self.render_poses.shape[0]
    else:
        self.n_examples = images.shape[0]
def main(unused_argv):
    """Evaluate checkpoints from ``FLAGS.train_dir`` on the test split.

    Restores the latest checkpoint, renders every test example, and (unless
    ``FLAGS.render_path`` is set) computes PSNR, SSIM and LPIPS against the
    ground-truth pixels. Results are optionally written to disk
    (``FLAGS.save_output``) and to TensorBoard (when not ``FLAGS.eval_once``).
    The loop repeats for each new checkpoint until ``FLAGS.eval_once`` is set
    or the restored step reaches ``FLAGS.max_steps``.

    Args:
        unused_argv: Ignored.

    Raises:
        ValueError: If ``FLAGS.train_dir`` or ``FLAGS.data_dir`` is unset.
    """
    import time  # Local import: only needed for the checkpoint polling delay.

    # Seconds to wait before re-checking for a new checkpoint.
    poll_interval_secs = 5

    # Hide the GPUs and TPUs from TF so it does not reserve memory on them for
    # LPIPS computation or dataset loading.
    tf.config.experimental.set_visible_devices([], "GPU")
    tf.config.experimental.set_visible_devices([], "TPU")

    rng = random.PRNGKey(20200823)

    if FLAGS.config is not None:
        utils.update_flags(FLAGS)
    if FLAGS.train_dir is None:
        raise ValueError("train_dir must be set. None set now.")
    if FLAGS.data_dir is None:
        raise ValueError("data_dir must be set. None set now.")

    dataset = datasets.get_dataset("test", FLAGS)
    rng, key = random.split(rng)
    model, init_variables = models.get_model(key, dataset.peek(), FLAGS)
    optimizer = flax.optim.Adam(FLAGS.lr_init).create(init_variables)
    state = utils.TrainState(optimizer=optimizer)
    del optimizer, init_variables

    lpips_model = tf_hub.load(LPIPS_TFHUB_PATH)

    # Rendering is forced to be deterministic even if training was randomized,
    # as this eliminates "speckle" artifacts.
    def render_fn(variables, key_0, key_1, rays):
        return jax.lax.all_gather(
            model.apply(variables, key_0, key_1, rays, False),
            axis_name="batch")

    # pmap over only the data input.
    render_pfn = jax.pmap(
        render_fn,
        in_axes=(None, None, None, 0),
        donate_argnums=3,
        axis_name="batch",
    )

    # Compiling to the CPU because it's faster and more accurate.
    ssim_fn = jax.jit(
        functools.partial(utils.compute_ssim, max_val=1.), backend="cpu")

    last_step = 0
    out_dir = path.join(
        FLAGS.train_dir,
        "path_renders" if FLAGS.render_path else "test_preds")
    if not FLAGS.eval_once:
        summary_writer = tensorboard.SummaryWriter(
            path.join(FLAGS.train_dir, "eval"))

    while True:
        state = checkpoints.restore_checkpoint(FLAGS.train_dir, state)
        step = int(state.optimizer.state.step)
        if step <= last_step:
            # No new checkpoint yet. Back off instead of spinning on the
            # checkpoint directory in a tight loop.
            time.sleep(poll_interval_secs)
            continue
        if FLAGS.save_output and (not utils.isdir(out_dir)):
            utils.makedirs(out_dir)

        psnr_values = []
        ssim_values = []
        lpips_values = []
        if not FLAGS.eval_once:
            # One randomly chosen example is logged to TensorBoard per step.
            showcase_index = np.random.randint(0, dataset.size)

        for idx in range(dataset.size):
            print(f"Evaluating {idx+1}/{dataset.size}")
            batch = next(dataset)
            pred_color, pred_disp, pred_acc = utils.render_image(
                functools.partial(render_pfn, state.optimizer.target),
                batch["rays"],
                rng,
                FLAGS.dataset == "llff",
                chunk=FLAGS.chunk)
            if jax.host_id() != 0:  # Only record via host 0.
                continue
            if not FLAGS.eval_once and idx == showcase_index:
                showcase_color = pred_color
                showcase_disp = pred_disp
                showcase_acc = pred_acc
                if not FLAGS.render_path:
                    showcase_gt = batch["pixels"]
            if not FLAGS.render_path:
                psnr = utils.compute_psnr(
                    ((pred_color - batch["pixels"])**2).mean())
                ssim = ssim_fn(pred_color, batch["pixels"])
                lpips = compute_lpips(
                    pred_color, batch["pixels"], lpips_model)
                print(f"PSNR = {psnr:.4f}, SSIM = {ssim:.4f}")
                psnr_values.append(float(psnr))
                ssim_values.append(float(ssim))
                lpips_values.append(float(lpips))
            if FLAGS.save_output:
                utils.save_img(
                    pred_color,
                    path.join(out_dir, "{:03d}.png".format(idx)))
                utils.save_img(
                    pred_disp[Ellipsis, 0],
                    path.join(out_dir, "disp_{:03d}.png".format(idx)))

        # (per-step file stem, mean file stem / summary tag, values)
        metrics = (
            ("psnrs", "psnr", psnr_values),
            ("ssims", "ssim", ssim_values),
            ("lpips", "lpips", lpips_values),
        )

        if (not FLAGS.eval_once) and (jax.host_id() == 0):
            summary_writer.image("pred_color", showcase_color, step)
            summary_writer.image("pred_disp", showcase_disp, step)
            summary_writer.image("pred_acc", showcase_acc, step)
            if not FLAGS.render_path:
                for _, tag, values in metrics:
                    summary_writer.scalar(
                        tag, np.mean(np.array(values)), step)
                summary_writer.image("target", showcase_gt, step)

        if (FLAGS.save_output and (not FLAGS.render_path)
                and (jax.host_id() == 0)):
            # Per-image values for this step, then the means.
            for per_step_stem, _, values in metrics:
                with utils.open_file(
                        path.join(out_dir, f"{per_step_stem}_{step}.txt"),
                        "w") as f:
                    f.write(" ".join([str(v) for v in values]))
            for _, mean_stem, values in metrics:
                with utils.open_file(
                        path.join(out_dir, f"{mean_stem}.txt"), "w") as f:
                    f.write("{}".format(np.mean(np.array(values))))

        if FLAGS.eval_once:
            break
        if step >= FLAGS.max_steps:
            break
        last_step = step
def main(unused_argv):
    """Evaluate checkpoints from ``FLAGS.train_dir`` on the test split.

    Restores the latest checkpoint, renders every test example, and (unless
    ``FLAGS.render_path`` is set) computes PSNR against the ground-truth
    pixels. Results are optionally written to disk (``FLAGS.save_output``)
    and to TensorBoard (when not ``FLAGS.eval_once``). The loop repeats for
    each new checkpoint until ``FLAGS.eval_once`` is set or the restored step
    reaches ``FLAGS.max_steps``.

    Args:
        unused_argv: Ignored.

    Raises:
        ValueError: If ``FLAGS.train_dir`` or ``FLAGS.data_dir`` is unset.
    """
    import time  # Local import: only needed for the checkpoint polling delay.

    # Seconds to wait before re-checking for a new checkpoint.
    poll_interval_secs = 5

    rng = random.PRNGKey(20200823)

    if FLAGS.config is not None:
        utils.update_flags(FLAGS)
    if FLAGS.train_dir is None:
        raise ValueError("train_dir must be set. None set now.")
    if FLAGS.data_dir is None:
        raise ValueError("data_dir must be set. None set now.")

    dataset = datasets.get_dataset("test", FLAGS)
    rng, key = random.split(rng)
    model, init_variables = models.get_model(key, dataset.peek(), FLAGS)
    optimizer = flax.optim.Adam(FLAGS.lr_init).create(init_variables)
    state = utils.TrainState(optimizer=optimizer)
    del optimizer, init_variables

    # Rendering is forced to be deterministic even if training was randomized,
    # as this eliminates "speckle" artifacts.
    def render_fn(variables, key_0, key_1, rays):
        return jax.lax.all_gather(
            model.apply(variables, key_0, key_1, *rays, False),
            axis_name="batch")

    # pmap over only the data input.
    render_pfn = jax.pmap(
        render_fn,
        in_axes=(None, None, None, 0),
        donate_argnums=3,
        axis_name="batch",
    )

    last_step = 0
    out_dir = path.join(
        FLAGS.train_dir,
        "path_renders" if FLAGS.render_path else "test_preds")
    if not FLAGS.eval_once:
        summary_writer = tensorboard.SummaryWriter(
            path.join(FLAGS.train_dir, "eval"))

    while True:
        state = checkpoints.restore_checkpoint(FLAGS.train_dir, state)
        step = int(state.optimizer.state.step)
        if step <= last_step:
            # No new checkpoint yet. Back off instead of spinning on the
            # checkpoint directory in a tight loop.
            time.sleep(poll_interval_secs)
            continue
        if FLAGS.save_output and (not utils.isdir(out_dir)):
            utils.makedirs(out_dir)

        psnrs = []
        if not FLAGS.eval_once:
            # One randomly chosen example is logged to TensorBoard per step.
            showcase_index = np.random.randint(0, dataset.size)

        for idx in range(dataset.size):
            print(f"Evaluating {idx+1}/{dataset.size}")
            batch = next(dataset)
            pred_color, pred_disp, pred_acc = utils.render_image(
                functools.partial(render_pfn, state.optimizer.target),
                batch["rays"],
                rng,
                FLAGS.dataset == "llff",
                chunk=FLAGS.chunk)
            if jax.host_id() != 0:  # Only record via host 0.
                continue
            if not FLAGS.eval_once and idx == showcase_index:
                showcase_color = pred_color
                showcase_disp = pred_disp
                showcase_acc = pred_acc
                if not FLAGS.render_path:
                    showcase_gt = batch["pixels"]
            if not FLAGS.render_path:
                psnr = utils.compute_psnr(
                    ((pred_color - batch["pixels"])**2).mean())
                print(f" PSNR = {psnr:.4f}")
                psnrs.append(float(psnr))
            if FLAGS.save_output:
                utils.save_img(
                    pred_color,
                    path.join(out_dir, "{:03d}.png".format(idx)))
                utils.save_img(
                    pred_disp[Ellipsis, 0],
                    path.join(out_dir, "disp_{:03d}.png".format(idx)))

        if (not FLAGS.eval_once) and (jax.host_id() == 0):
            summary_writer.image("pred_color", showcase_color, step)
            summary_writer.image("pred_disp", showcase_disp, step)
            summary_writer.image("pred_acc", showcase_acc, step)
            if not FLAGS.render_path:
                summary_writer.scalar(
                    "psnr", np.mean(np.array(psnrs)), step)
                summary_writer.image("target", showcase_gt, step)

        if (FLAGS.save_output and (not FLAGS.render_path)
                and (jax.host_id() == 0)):
            with utils.open_file(path.join(out_dir, "psnr.txt"), "w") as pout:
                pout.write("{}".format(np.mean(np.array(psnrs))))

        if FLAGS.eval_once:
            break
        if step >= FLAGS.max_steps:
            break
        last_step = step