def _next_test(self):
    """Return the next evaluation example, cycling through them in order.

    The stored position ``self.it`` advances by one on every call and wraps
    modulo ``self.n_examples``.

    Returns:
      A dict whose "rays" entry holds every ray field indexed at the current
      position. When ``self.render_path`` is falsy, the rays are taken from
      ``self.rays`` and the dict also carries the matching "pixels" image.
      Otherwise, the rays are taken from ``self.render_rays`` and no pixels
      are included.
    """
    current = self.it
    self.it = (current + 1) % self.n_examples

    def take_current(field):
        # Select the current example from one ray field.
        return field[current]

    if not self.render_path:
        pixels = self.images[current]
        rays = utils.namedtuple_map(take_current, self.rays)
        return {"pixels": pixels, "rays": rays}
    return {"rays": utils.namedtuple_map(take_current, self.render_rays)}
def _train_init(self, args):
    """Load the training data and lay it out for the configured batching mode.

    Args:
      args: hyper-parameter namespace. ``args.batching`` selects the layout,
        and ``args`` is forwarded to ``self._load_renderings``.

    Raises:
      NotImplementedError: if ``args.batching`` is neither "all_images" nor
        "single_image". This is raised before any data is loaded.
    """
    if args.batching not in ("all_images", "single_image"):
        # Fail fast: reject an unknown strategy before doing the (potentially
        # slow) rendering load and ray generation.
        raise NotImplementedError(
            f"{args.batching} batching strategy is not implemented.")
    self._load_renderings(args)
    self._generate_rays()
    if args.batching == "all_images":
        # Flatten the ray and image dimensions together: one row per ray.
        self.images = self.images.reshape([-1, 3])
        self.rays = utils.namedtuple_map(
            lambda r: r.reshape([-1, r.shape[-1]]), self.rays)
    else:  # "single_image"
        # Keep one row per image, each holding self.resolution rays.
        self.images = self.images.reshape([-1, self.resolution, 3])
        self.rays = utils.namedtuple_map(
            lambda r: r.reshape([-1, self.resolution, r.shape[-1]]), self.rays)
def _next_train(self):
    """Draw a random training batch of pixels and their matching rays.

    With "all_images" batching, ``self.batch_size`` rays are drawn uniformly
    (with replacement) from the flattened pool of all rays. With
    "single_image" batching, one image index is drawn first, and then
    ``self.batch_size`` rays are drawn from that image only.

    Returns:
      A dict with "pixels" and "rays" entries for the sampled batch.

    Raises:
      NotImplementedError: if ``self.batching`` is not a known strategy.
    """
    strategy = self.batching
    if strategy == "all_images":
        total_rays = self.rays[0].shape[0]
        picks = np.random.randint(0, total_rays, (self.batch_size, ))
        pixels = self.images[picks]
        rays = utils.namedtuple_map(lambda field: field[picks], self.rays)
    elif strategy == "single_image":
        # Draw the image first so the RNG stream is consumed in a fixed order.
        chosen = np.random.randint(0, self.n_examples, ())
        rays_per_image = self.rays[0][0].shape[0]
        picks = np.random.randint(0, rays_per_image, (self.batch_size, ))
        pixels = self.images[chosen][picks]
        rays = utils.namedtuple_map(
            lambda field: field[chosen][picks], self.rays)
    else:
        raise NotImplementedError(
            f"{strategy} batching strategy is not implemented.")
    return {"pixels": pixels, "rays": rays}
def construct_nerf(key, example_batch, args):
    """Construct a Neural Radiance Field and initialize its variables.

    Args:
      key: jnp.ndarray. Random number generator key. It is split into three
        keys: one for ``model.init`` and two passed as ``rng_0``/``rng_1``.
      example_batch: dict, an example of a batch of data. Only
        ``example_batch["rays"]`` is read. The first element of every ray
        field is used as the input when initializing the model.
      args: FLAGS class. Hyperparameters of nerf.

    Returns:
      model: NerfModel configured from ``args``.
      init_variables: the value returned by ``model.init(...)``.

    Raises:
      NotImplementedError: if the rgb activation produces values outside
        [0, 1], or the sigma activation produces negative values.
    """
    net_activation = nn.relu
    rgb_activation = nn.sigmoid
    sigma_activation = nn.relu

    def _activation_name(fn):
        # The activations are chosen above, not read from args, so name them
        # from the function itself when reporting a failed check.
        return getattr(fn, "__name__", repr(fn))

    # Check that rgb_activation always produces outputs in [0, 1] and that
    # sigma_activation always produces non-negative outputs, probing inputs
    # of very large positive and negative magnitude.
    probe = jnp.exp(jnp.linspace(-90, 90, 1024))
    probe = jnp.concatenate([-probe[::-1], probe], 0)
    rgb = rgb_activation(probe)
    if jnp.any(rgb < 0) or jnp.any(rgb > 1):
        raise NotImplementedError(
            "Choice of rgb_activation `{}` produces colors outside of [0, 1]".
            format(_activation_name(rgb_activation)))
    sigma = sigma_activation(probe)
    if jnp.any(sigma < 0):
        raise NotImplementedError(
            "Choice of sigma_activation `{}` produces negative densities".
            format(_activation_name(sigma_activation)))

    model = NerfModel(
        min_deg_point=args.min_deg_point,
        max_deg_point=args.max_deg_point,
        deg_view=args.deg_view,
        num_coarse_samples=args.num_coarse_samples,
        num_fine_samples=args.num_fine_samples,
        use_viewdirs=args.use_viewdirs,
        near=args.near,
        far=args.far,
        noise_std=args.noise_std,
        white_bkgd=args.white_bkgd,
        net_depth=args.net_depth,
        net_width=args.net_width,
        num_viewdir_channels=args.num_viewdir_channels,
        viewdir_net_depth=args.viewdir_net_depth,
        viewdir_net_width=args.viewdir_net_width,
        skip_layer=args.skip_layer,
        num_rgb_channels=args.num_rgb_channels,
        num_sigma_channels=args.num_sigma_channels,
        lindisp=args.lindisp,
        net_activation=net_activation,
        rgb_activation=rgb_activation,
        sigma_activation=sigma_activation,
        legacy_posenc_order=args.legacy_posenc_order)

    rays = example_batch["rays"]
    key1, key2, key3 = random.split(key, num=3)
    init_variables = model.init(
        key1,
        rng_0=key2,
        rng_1=key3,
        # Initialize with only the first element of each ray field.
        rays=utils.namedtuple_map(lambda r: r[0], rays),
        randomized=args.randomized)
    return model, init_variables