def save_img(img, pth):
    """Write an image to disk as a PNG file.

    Args:
      img: jnp.ndarray, [height, width, channels]. Values are clipped to
        [0, 1] and scaled by 255 before being cast to uint8 and written.
      pth: string, destination path of the image.
    """
    with file_utils.open_file(pth, "wb") as imgout:
        # Clamp to the displayable range, then quantize to 8 bits per channel.
        clipped = np.clip(img, 0., 1.)
        quantized = (clipped * 255.).astype(jnp.uint8)
        pil_image = Image.fromarray(np.array(quantized))
        pil_image.save(imgout, "PNG")
def _load_images(self, imgdir):
    """Load every image of a directory into one uint8 array.

    Files are visited in sorted filename order and kept only when their name
    ends with "JPG", "jpg" or "png".

    Args:
      imgdir: string, directory that contains the images.

    Returns:
      np.ndarray of dtype uint8 with the individual images stacked along a
      new last axis.

    Raises:
      ValueError: if `imgdir` does not exist or holds no matching image.
    """
    if not file_utils.file_exists(imgdir):
        raise ValueError("Image folder {} doesn't exist.".format(imgdir))

    imgfiles = [
        path.join(imgdir, f)
        for f in sorted(file_utils.listdir(imgdir))
        if f.endswith(("JPG", "jpg", "png"))
    ]
    if not imgfiles:
        # np.stack([]) would also raise ValueError, but with a message that
        # does not mention which folder was empty.
        raise ValueError(
            "Image folder {} contains no JPG/jpg/png files.".format(imgdir))

    images = []
    for imgfile in tqdm.tqdm(imgfiles):
        with file_utils.open_file(imgfile, "rb") as imgin:
            images.append(np.array(Image.open(imgin), dtype=np.uint8))
    return np.stack(images, axis=-1)
def _get_suffix(self, args):
    """Return the suffix of the resized image directory and the resize factor.

    Three cases are handled, in this order:
      * `args.dataset.factor > 0`: suffix is "_<int(factor)>" and the factor is
        returned unchanged.
      * `args.dataset.image_height > 0`: the first image (sorted by name) of
        `<data_dir>/images` is opened to read its shape; the factor is its
        height divided by `image_height` and the suffix is
        "_<width>x<image_height>".
      * otherwise: empty suffix and a factor of 1.

    Args:
      args: configuration exposing `dataset.factor`, `dataset.image_height`
        and `dataset.data_dir`.

    Returns:
      Tuple `(imgdir_suffix, factor)`.
    """
    if args.dataset.factor > 0:
        return "_{}".format(int(args.dataset.factor)), args.dataset.factor

    if not args.dataset.image_height > 0:
        return "", 1

    # Derive the factor from the size of the first full-resolution image.
    img_path = path.join(args.dataset.data_dir, "images")
    candidates = [
        path.join(args.dataset.data_dir, "images", name)
        for name in sorted(file_utils.listdir(img_path))
        if name.endswith(("JPG", "jpg", "png"))
    ]
    first_image = candidates[0]
    with file_utils.open_file(first_image, "rb") as imgin:
        sh = np.array(Image.open(imgin), dtype=np.uint8).shape

    target_height = args.dataset.image_height
    factor = sh[0] / float(target_height)
    width = int(sh[1] / factor)
    return "_{}x{}".format(width, target_height), factor
def evaluate(config, workdir):
    """Evaluation function.

    Repeatedly restores the latest checkpoint found in `workdir` and, for each
    checkpoint whose step is newer than the last one evaluated, renders every
    example of the test dataset. On host 0 it additionally:
      * computes PSNR and SSIM against the target pixels (unless
        `config.dataset.render_path` is set),
      * saves the rendered color/disparity images when
        `config.eval.save_output` is set,
      * writes TensorBoard summaries unless `config.eval.eval_once` is set,
      * writes per-image and mean metric text files next to the images.

    The loop ends after one evaluated checkpoint when `config.eval.eval_once`
    is set, or once the restored step reaches `config.train.max_steps`. While
    no newer checkpoint is available the loop sleeps between polls instead of
    restoring the checkpoint in a tight loop.

    Args:
      config: Configuration to use. `config.dataset.data_dir` is overwritten
        with `<base_dir>/<scene>`.
      workdir: Working directory that holds the checkpoints; outputs are
        written beneath it.
    """
    import time  # Local import: only needed for checkpoint polling.

    # Delay between two checkpoint polls when nothing new was found.
    poll_interval_secs = 10

    # Hide the GPUs and TPUs from TF so it does not reserve memory on them for
    # LPIPS computation or dataset loading.
    tf.config.experimental.set_visible_devices([], "GPU")
    tf.config.experimental.set_visible_devices([], "TPU")

    rng = jax.random.PRNGKey(config.seed)

    #--------------------------------------------------------------------------
    # Build input pipeline.
    # The second key of this split is unused, but the split itself is kept so
    # that `rng` (and every key derived from it below) stays the same.
    rng, _ = jax.random.split(rng)
    config.dataset.data_dir = os.path.join(config.dataset.base_dir,
                                           config.dataset.scene)
    train_ds, test_ds = datasets.create_dataset(config)
    example_batch = train_ds.peek()

    rng, key = jax.random.split(rng)

    #--------------------------------------------------------------------------
    # Initialize model.
    learning_rate_fn = train_utils.create_learning_rate_fn(config)
    model, state = models.create_train_state(
        config,
        key,
        learning_rate_fn=learning_rate_fn,
        example_batch=example_batch)

    #--------------------------------------------------------------------------
    # Get the rendering function. Rendering is forced to be deterministic even
    # if training is randomized.
    render_pfn = render_utils.get_render_function(
        model, config, randomized=False)

    last_step = 0
    out_dir = os.path.join(
        workdir,
        "path_renders" if config.dataset.render_path else "test_preds")

    if not config.eval.eval_once:
        # Prepare Metric Writers
        summary_writer = tensorboard.SummaryWriter(
            os.path.join(workdir, "eval"))

    while True:
        state = checkpoints.restore_checkpoint(workdir, state)
        step = int(state.step)
        if step <= last_step:
            # Nothing new to evaluate yet: wait before polling again rather
            # than hammering the checkpoint directory.
            time.sleep(poll_interval_secs)
            continue

        if config.eval.save_output and (not file_utils.isdir(out_dir)):
            file_utils.makedirs(out_dir)

        psnr_values = []
        ssim_values = []

        if not config.eval.eval_once:
            # One randomly chosen test example is logged to TensorBoard.
            showcase_index = np.random.randint(0, test_ds.size)

        for idx in range(test_ds.size):
            logging.info("Evaluating [%d / %d].", idx, test_ds.size)
            batch = next(test_ds)
            test_pixels = batch.target_view.rgb
            if test_pixels is not None:
                test_pixels = model_utils.uint2float(test_pixels)

            #-----------------------------------------------------------
            # Render Image
            variables = {"params": state.params}
            pred_color, pred_disp, pred_acc = render_utils.render_image(
                functools.partial(render_pfn, variables),
                batch,
                rng,
                render_utils.normalize_disp(config.dataset.name),
                chunk=config.eval.chunk)

            # Every host takes part in rendering; only host 0 scores and
            # saves the result.
            if jax.process_index() != 0:
                continue

            #-----------------------------------------------------------
            # Get showcase example for logging
            if not config.eval.eval_once and idx == showcase_index:
                showcase_color = pred_color
                showcase_disp = pred_disp
                showcase_acc = pred_acc
                if not config.dataset.render_path:
                    showcase_gt = test_pixels

            #-----------------------------------------------------------
            # If ground-truth pixels are available, evaluate
            if not config.dataset.render_path:
                psnr = float(
                    model_utils.compute_psnr(
                        ((pred_color - test_pixels)**2).mean()))
                ssim = float(
                    skmetrics.structural_similarity(
                        pred_color.astype(np.float32),
                        test_pixels.astype(np.float32),
                        win_size=11,
                        multichannel=True,
                        gaussian_weights=True))
                logging.info("PSNR = %.4f, SSIM = %.4f", psnr, ssim)
                psnr_values.append(psnr)
                ssim_values.append(ssim)

            #-----------------------------------------------------------
            # Save generated image
            if config.eval.save_output:
                model_utils.save_img(
                    pred_color,
                    os.path.join(out_dir, "{:03d}.png".format(idx)))
                if pred_disp is not None:
                    model_utils.save_img(
                        pred_disp[Ellipsis, 0],
                        os.path.join(out_dir, "disp_{:03d}.png".format(idx)))

        on_host0 = jax.process_index() == 0
        has_gt = not config.dataset.render_path
        if on_host0 and has_gt:
            mean_psnr = np.mean(np.array(psnr_values))
            mean_ssim = np.mean(np.array(ssim_values))

        #-----------------------------------------------------------
        # Log the showcase example and the mean metrics
        if (not config.eval.eval_once) and on_host0:
            summary_writer.image("pred_color", showcase_color, step)
            if showcase_disp is not None:
                summary_writer.image("pred_disp", showcase_disp, step)
            if showcase_acc is not None:
                summary_writer.image("pred_acc", showcase_acc, step)
            if has_gt:
                summary_writer.scalar("eval_metric/psnr", mean_psnr, step)
                summary_writer.scalar("eval_metric/ssim", mean_ssim, step)
                summary_writer.image("target", showcase_gt, step)

        #-----------------------------------------------------------
        # Save the metric to file
        if config.eval.save_output and has_gt and on_host0:
            with file_utils.open_file(
                    os.path.join(out_dir, f"psnrs_{step}.txt"), "w") as f:
                f.write(" ".join([str(v) for v in psnr_values]))
            with file_utils.open_file(
                    os.path.join(out_dir, f"ssims_{step}.txt"), "w") as f:
                f.write(" ".join([str(v) for v in ssim_values]))
            with file_utils.open_file(
                    os.path.join(out_dir, "psnr.txt"), "w") as f:
                f.write("{}".format(mean_psnr))
            with file_utils.open_file(
                    os.path.join(out_dir, "ssim.txt"), "w") as f:
                f.write("{}".format(mean_ssim))

        if config.eval.eval_once:
            break
        if int(step) >= config.train.max_steps:
            break
        last_step = step

    logging.info("Finishing evaluation at step %d", last_step)
def _load_renderings(self, args):
    """Load images and camera information.

    Reads the images of the (possibly resized) image directory together with
    `poses_bounds.npy` and `hwf_cxcy.npy` from `args.dataset.data_dir`,
    rescales and recenters the poses, and keeps the subset of views that
    belongs to `self.split`.

    Side effects visible in this method:
      * sets `self.intrinsic_arr`, `self.min_depth`, `self.max_depth`,
        `self.images`, `self.camtoworlds`, `self.focal`, `self.h`, `self.w`,
        `self.resolution`, `self.n_examples` and `self.intrinsic_matrix`;
      * overwrites `args.model.near` and `args.model.far` with the scene's
        minimum and maximum depth bound;
      * calls `self._generate_spiral_poses` when the split is "test".

    Args:
      args: configuration exposing `dataset.data_dir`, `dataset.llffhold`,
        `dataset.render_path` and `model.near` / `model.far`, plus whatever
        `self._get_suffix` reads.
    """
    #-------------------------------------------
    # Load images.
    #-------------------------------------------
    imgdir_suffix, factor = self._get_suffix(args)
    imgdir = path.join(args.dataset.data_dir, "images" + imgdir_suffix)
    images = self._load_images(imgdir)

    #-------------------------------------------
    # Load poses and bds.
    #-------------------------------------------
    with file_utils.open_file(
            path.join(args.dataset.data_dir, "poses_bounds.npy"), "rb") as fp:
        poses_arr = np.load(fp)

    # Get the intrinsic matrix
    with file_utils.open_file(
            path.join(args.dataset.data_dir, "hwf_cxcy.npy"), "rb") as fp:
        self.intrinsic_arr = np.load(fp)

    # Update the intrinsic matrix to account for resizing. Every entry of the
    # array is divided by the resize factor.
    self.intrinsic_arr = self.intrinsic_arr * 1. / factor

    # poses_arr contains an array consisting of a 3x4 pose matrix and
    # 2 depth bounds for each image. The pose matrix contains [R t] as the
    # left 3x4 matrix.
    # poses_arr has shape (..., 14) {3x4 + 2}
    poses = poses_arr[:, :-2].reshape([-1, 3, 4]).transpose([1, 2, 0])
    bds = poses_arr[:, -2:].transpose([1, 0])

    # Convert R matrix from the form [down right back] to [right up back]
    poses = np.concatenate(
        [poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1)

    # Transpose such that the first dimension is number of images
    images = np.moveaxis(images, -1, 0)
    poses = np.moveaxis(poses, -1, 0).astype(np.float32)
    bds = np.moveaxis(bds, -1, 0).astype(np.float32)

    # Rescale translations and bounds by the same factor; afterwards the
    # smallest depth bound equals 1 / 0.75.
    scale = 1. / (bds.min() * .75)
    poses[:, :3, 3] *= scale
    bds *= scale
    poses = self._recenter_poses(poses)

    # Get the min and max depth of the scene
    self.min_depth = bds.min()
    self.max_depth = bds.max()
    # Use this to set the near and far plane
    args.model.near = self.min_depth.item()
    args.model.far = self.max_depth.item()

    # NOTE(review): `self.render_poses` is read further down but never
    # assigned in this method; presumably `_generate_spiral_poses` sets it —
    # confirm.
    if self.split == "test":
        self._generate_spiral_poses(poses, bds)

    # Select the split: every `llffhold`-th image (starting at index 0) is a
    # test view, all remaining images are training views.
    i_test = np.arange(images.shape[0])[::args.dataset.llffhold]
    i_train = np.array(
        [i for i in np.arange(int(images.shape[0])) if i not in i_test])

    if self.split == "train":
        indices = i_train
    else:
        indices = i_test
    images = images[indices]
    poses = poses[indices]

    self.images = images
    self.camtoworlds = poses[:, :3, :4]

    # intrinsic arr has H, W, fx, fy, cx, cy
    self.focal = self.intrinsic_arr[2][0]
    self.h, self.w = images.shape[1:3]
    self.resolution = self.h * self.w

    # When rendering a path on the test split, the number of examples is the
    # number of render poses rather than the number of held-out images.
    if args.dataset.render_path and self.split == "test":
        self.n_examples = self.render_poses.shape[0]
    else:
        self.n_examples = images.shape[0]

    # Build the 3x4 intrinsic matrix; note the negated fy, cx and cy entries.
    _, _, fx, fy, cx, cy = self.intrinsic_arr[:, 0]
    self.intrinsic_matrix = np.array([[fx, 0, -cx, 0], [0, -fy, -cy, 0],
                                      [0, 0, 1, 0]]).astype(np.float32)
def train_and_evaluate(config, workdir):
    """Runs a training and evaluation loop.

    Each step pulls a prefetched batch and runs the pmapped `train_step`.
    Periodically (and always on the last step) the loop writes the accumulated
    training metrics, renders one evaluation example and logs its PSNR/SSIM
    and images on host 0, and saves a checkpoint plus `train_logs.json`.

    Args:
      config: Configuration to use. `config.dataset.data_dir` is overwritten
        with `<base_dir>/<scene>`.
      workdir: Working directory for checkpoints and TF summaries. If this
        contains checkpoint training will be resumed from the latest
        checkpoint.

    Raises:
      ValueError: if `config.dataset.batch_size` is not divisible by the
        number of devices.
    """
    if config.dataset.batch_size % jax.device_count() != 0:
        raise ValueError(
            "Batch size must be divisible by the number of devices.")

    tf.io.gfile.makedirs(workdir)
    # Deterministic training.
    rng = jax.random.PRNGKey(config.seed)
    # Shift the numpy random seed by process_index() to shuffle data loaded
    # by different hosts
    np.random.seed(20201473 + jax.process_index())

    #--------------------------------------------------------------------------
    # Build input pipeline.
    # The second key of this split is unused, but the split itself is kept so
    # that `rng` (and every key derived from it below) stays the same.
    rng, _ = jax.random.split(rng)
    config.dataset.data_dir = os.path.join(config.dataset.base_dir,
                                           config.dataset.scene)
    train_ds, eval_ds = datasets.create_dataset(config)
    example_batch = train_ds.peek()

    #--------------------------------------------------------------------------
    # Learning rate schedule.
    num_train_steps = config.train.max_steps
    if num_train_steps == -1:
        # `evaluate` in this file reads the dataset size as a plain attribute
        # (`range(test_ds.size)`), so accept both an attribute and a method.
        ds_size = train_ds.size
        num_train_steps = ds_size() if callable(ds_size) else ds_size
    steps_per_epoch = num_train_steps // config.train.num_epochs
    logging.info("num_train_steps=%d, steps_per_epoch=%d", num_train_steps,
                 steps_per_epoch)
    learning_rate_fn = train_utils.create_learning_rate_fn(config)

    #--------------------------------------------------------------------------
    # Initialize model.
    rng, model_rng = jax.random.split(rng)
    model, state = models.create_train_state(
        config,
        model_rng,
        learning_rate_fn=learning_rate_fn,
        example_batch=example_batch,
    )

    #--------------------------------------------------------------------------
    # Set up checkpointing of the model and the input pipeline.
    state = checkpoints.restore_checkpoint(workdir, state)
    initial_step = int(state.step) + 1

    #--------------------------------------------------------------------------
    # Distribute training.
    state = flax_utils.replicate(state)
    p_train_step = jax.pmap(
        functools.partial(
            train_step,
            model=model,
            learning_rate_fn=learning_rate_fn,
            weight_decay=config.train.weight_decay,
            config=config,
        ),
        axis_name="batch",
    )

    # Get distributed rendering function
    render_pfn = render_utils.get_render_function(
        model=model,
        config=config,
        randomized=False,  # No randomization for evaluation.
    )

    #--------------------------------------------------------------------------
    # Prepare Metric Writers
    writer = metric_writers.create_default_writer(
        workdir, just_logging=jax.process_index() > 0)
    if initial_step == 1:
        writer.write_hparams(dict(config))

    logging.info("Starting training loop at step %d.", initial_step)
    hooks = []
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=num_train_steps, writer=writer)
    if jax.process_index() == 0:
        hooks += [
            report_progress,
        ]
    train_metrics = None

    # Prefetch_buffer_size = 6 x batch_size
    ptrain_ds = flax.jax_utils.prefetch_to_device(train_ds, 6)
    n_local_devices = jax.local_device_count()
    # Make random seed separate across hosts.
    # NOTE(review): this adds an int to the raw key array; it is left as is
    # to keep the random streams of existing runs unchanged.
    rng = rng + jax.process_index()
    keys = jax.random.split(rng, n_local_devices)  # For pmapping RNG keys.

    with metric_writers.ensure_flushes(writer):
        for step in range(initial_step, num_train_steps + 1):
            # `step` is a Python integer. `state.step` is JAX integer on the
            # GPU/TPU devices.
            is_last_step = step == num_train_steps

            with jax.profiler.StepTraceAnnotation("train", step_num=step):
                batch = next(ptrain_ds)
                state, metrics_update, keys = p_train_step(
                    rng=keys, state=state, batch=batch)
                metric_update = flax_utils.unreplicate(metrics_update)
                train_metrics = (
                    metric_update if train_metrics is None else
                    train_metrics.merge(metric_update))

            # Quick indication that training is happening.
            logging.log_first_n(logging.INFO, "Finished training step %d.", 5,
                                step)
            for h in hooks:
                h(step)

            if step % config.train.log_loss_every_steps == 0 or is_last_step:
                writer.write_scalars(step, train_metrics.compute())
                train_metrics = None

            if step % config.train.render_every_steps == 0 or is_last_step:
                test_batch = next(eval_ds)
                test_pixels = model_utils.uint2float(
                    test_batch.target_view.rgb)  # extract for evaluation
                with report_progress.timed("eval"):
                    pred_color, pred_disp, pred_acc = eval_step(
                        state, keys[0], test_batch, render_pfn, config)
                #--------------------------------------------------------------
                # Log metrics and images for host 0
                #--------------------------------------------------------------
                if jax.process_index() == 0:
                    psnr = model_utils.compute_psnr(
                        ((pred_color - test_pixels)**2).mean())
                    # `gaussian_weights` (plural) is the keyword that
                    # `evaluate` passes; the former `gaussian_weight` spelling
                    # did not select Gaussian weighting.
                    ssim = skmetrics.structural_similarity(
                        pred_color.astype(np.float32),
                        test_pixels.astype(np.float32),
                        win_size=11,
                        multichannel=True,
                        gaussian_weights=True)
                    writer.write_scalars(
                        step, {
                            "train_eval/test_psnr": psnr,
                            "train_eval/test_ssim": ssim,
                        })
                    writer.write_images(
                        step, {
                            "test_pred_color": pred_color[None, :],
                            "test_target": test_pixels[None, :]
                        })
                    if pred_disp is not None:
                        writer.write_images(
                            step, {"test_pred_disp": pred_disp[None, :]})
                    if pred_acc is not None:
                        writer.write_images(
                            step, {"test_pred_acc": pred_acc[None, :]})
                #--------------------------------------------------------------

            if (jax.process_index() == 0) and (
                    step % config.train.checkpoint_every_steps == 0 or
                    is_last_step):
                # Write the metrics of the latest training step to file
                with file_utils.open_file(
                        os.path.join(workdir, "train_logs.json"), "w") as f:
                    log_dict = {
                        k: v.item()
                        for k, v in metric_update.compute().items()
                    }
                    f.write(json.dumps(log_dict))

                with report_progress.timed("checkpoint"):
                    # Keep only the first replica of the replicated state.
                    state_to_save = jax.device_get(
                        jax.tree_util.tree_map(lambda x: x[0], state))
                    checkpoints.save_checkpoint(
                        workdir, state_to_save, step, keep=100)

    logging.info("Finishing training at step %d", num_train_steps)