def apply(self, rng_0, rng_1, origins, directions, viewdirs,
          num_coarse_samples, num_fine_samples, use_viewdirs, near, far,
          noise_std, net_depth, net_width, net_depth_condition,
          net_width_condition, net_activation, skip_layer, num_rgb_channels,
          num_sigma_channels, randomized, white_bkgd, deg_point, deg_view,
          lindisp, rgb_activation, sigma_activation, legacy_posenc_order):
  """Nerf Model: render rays with a coarse pass and an optional fine pass.

  The coarse pass places `num_coarse_samples` points on each ray between
  `near` and `far`. When `num_fine_samples > 0`, a second pass resamples each
  ray using the coarse rendering weights. Both passes call an MLP built from
  the same network hyper-parameters.

  Args:
    rng_0: jnp.ndarray, PRNG key consumed by the coarse pass.
    rng_1: jnp.ndarray, PRNG key consumed by the fine pass.
    origins: jnp.ndarray(float32), [batch_size, 3], ray origins.
    directions: jnp.ndarray(float32), [batch_size, 3], ray directions.
    viewdirs: jnp.ndarray(float32), [batch_size, 3], per-ray viewing
      direction. Only differs from `directions` when NDC rays are used.
    num_coarse_samples: int, samples per ray for the coarse pass.
    num_fine_samples: int, samples per ray for the fine pass (0 disables it).
    use_viewdirs: bool, whether the MLP is conditioned on viewdirs.
    near: float, near clip.
    far: float, far clip.
    noise_std: float, std dev of the noise added to the raw density.
    net_depth: int, depth of the first part of the MLP.
    net_width: int, width of the first part of the MLP.
    net_depth_condition: int, depth of the second part of the MLP.
    net_width_condition: int, width of the second part of the MLP.
    net_activation: function, activation used inside the MLP.
    skip_layer: int, a skip connection is added every `skip_layer` layers.
    num_rgb_channels: int, number of RGB channels.
    num_sigma_channels: int, number of density channels.
    randomized: bool, use randomized stratified sampling.
    white_bkgd: bool, composite onto a white background.
    deg_point: degree of the positional encoding for sample positions.
    deg_view: degree of the positional encoding for viewdirs.
    lindisp: bool, sample linearly in disparity instead of depth.
    rgb_activation: function, maps raw MLP output to RGB.
    sigma_activation: function, maps raw MLP output to density.
    legacy_posenc_order: bool, keep the original tf positional-encoding order.

  Returns:
    ret: list, [(rgb_coarse, disp_coarse, acc_coarse), (rgb, disp, acc)]; the
      second tuple is only present when num_fine_samples > 0.
  """
  make_mlp = functools.partial(
      model_utils.MLP,
      net_depth=net_depth,
      net_width=net_width,
      net_depth_condition=net_depth_condition,
      net_width_condition=net_width_condition,
      net_activation=net_activation,
      skip_layer=skip_layer,
      num_rgb_channels=num_rgb_channels,
      num_sigma_channels=num_sigma_channels)

  # Two successive splits of rng_0: one subkey for sampling, one for noise.
  sample_key, rng_0 = random.split(rng_0)
  noise_key, rng_0 = random.split(rng_0)

  if use_viewdirs:
    # Conditioning depends only on direction, so rescale to unit length.
    unit_dirs = viewdirs / jnp.linalg.norm(viewdirs, axis=-1, keepdims=True)
    viewdirs_enc = model_utils.posenc(unit_dirs, deg_view,
                                      legacy_posenc_order)

  def shade(points, t_vals, noise_rng):
    """Encode `points`, query the MLP, regularize density, volume-render."""
    points_enc = model_utils.posenc(points, deg_point, legacy_posenc_order)
    if use_viewdirs:
      raw_rgb, raw_sigma = make_mlp(points_enc, viewdirs_enc)
    else:
      raw_rgb, raw_sigma = make_mlp(points_enc)
    raw_sigma = model_utils.add_gaussian_noise(noise_rng, raw_sigma,
                                               noise_std, randomized)
    rgb = rgb_activation(raw_rgb)
    sigma = sigma_activation(raw_sigma)
    return model_utils.volumetric_rendering(
        rgb, sigma, t_vals, directions, white_bkgd=white_bkgd)

  # Coarse pass: stratified samples along each ray.
  z_vals, samples = model_utils.sample_along_rays(
      sample_key, origins, directions, num_coarse_samples, near, far,
      randomized, lindisp)
  comp_rgb, disp, acc, weights = shade(samples, z_vals, noise_key)
  ret = [(comp_rgb, disp, acc)]

  if num_fine_samples <= 0:
    return ret

  # Fine pass: resample according to the coarse weights.
  fine_sample_key, rng_1 = random.split(rng_1)
  fine_noise_key, rng_1 = random.split(rng_1)
  # NOTE(review): presumably the midpoints act as bin edges for the interior
  # weights inside sample_pdf — confirm against model_utils.
  z_mids = 0.5 * (z_vals[..., 1:] + z_vals[..., :-1])
  z_vals, samples = model_utils.sample_pdf(
      fine_sample_key, z_mids, weights[..., 1:-1], origins, directions,
      z_vals, num_fine_samples, randomized)
  fine_rgb, fine_disp, fine_acc, _ = shade(samples, z_vals, fine_noise_key)
  ret.append((fine_rgb, fine_disp, fine_acc))
  return ret
def __call__(self, rng_0, rng_1, rays, randomized):
  """Nerf Model: render `rays` with a coarse MLP and an optional fine MLP.

  Args:
    rng_0: jnp.ndarray, PRNG key consumed by the coarse pass.
    rng_1: jnp.ndarray, PRNG key consumed by the fine pass.
    rays: util.Rays, namedtuple with ray origins, directions and viewdirs.
    randomized: bool, use randomized stratified sampling.

  Returns:
    ret: list, [(rgb_coarse, disp_coarse, acc_coarse), (rgb, disp, acc)]; the
      second tuple is only present when self.num_fine_samples > 0.
  """
  # Two successive splits of rng_0: one subkey for sampling, one for noise.
  sample_key, rng_0 = random.split(rng_0)
  noise_key, rng_0 = random.split(rng_0)

  def new_mlp():
    """Build a model_utils.MLP from this module's network fields."""
    return model_utils.MLP(
        net_depth=self.net_depth,
        net_width=self.net_width,
        net_depth_condition=self.net_depth_condition,
        net_width_condition=self.net_width_condition,
        net_activation=self.net_activation,
        skip_layer=self.skip_layer,
        num_rgb_channels=self.num_rgb_channels,
        num_sigma_channels=self.num_sigma_channels)

  def encode_points(points):
    """Positionally encode sample positions."""
    return model_utils.posenc(points, self.min_deg_point,
                              self.max_deg_point, self.legacy_posenc_order)

  def shade(mlp, points_enc, t_vals, noise_rng):
    """Query `mlp`, regularize density, activate, and volume-render."""
    if self.use_viewdirs:
      raw_rgb, raw_sigma = mlp(points_enc, viewdirs_enc)
    else:
      raw_rgb, raw_sigma = mlp(points_enc)
    raw_sigma = model_utils.add_gaussian_noise(noise_rng, raw_sigma,
                                               self.noise_std, randomized)
    rgb = self.rgb_activation(raw_rgb)
    sigma = self.sigma_activation(raw_sigma)
    return model_utils.volumetric_rendering(
        rgb, sigma, t_vals, rays.directions, white_bkgd=self.white_bkgd)

  # Coarse pass: stratified samples along each ray.
  z_vals, samples = model_utils.sample_along_rays(
      sample_key, rays.origins, rays.directions, self.num_coarse_samples,
      self.near, self.far, randomized, self.lindisp)
  coarse_enc = encode_points(samples)
  coarse_mlp = new_mlp()  # constructed before the fine MLP
  if self.use_viewdirs:
    # NOTE(review): rays.viewdirs is encoded as-is (no normalization here);
    # presumably the caller supplies unit vectors — confirm.
    viewdirs_enc = model_utils.posenc(rays.viewdirs, 0, self.deg_view,
                                      self.legacy_posenc_order)
  comp_rgb, disp, acc, weights = shade(coarse_mlp, coarse_enc, z_vals,
                                       noise_key)
  ret = [(comp_rgb, disp, acc)]

  if self.num_fine_samples <= 0:
    return ret

  # Fine pass: resample according to the coarse weights.
  fine_sample_key, rng_1 = random.split(rng_1)
  fine_noise_key, rng_1 = random.split(rng_1)
  z_mids = 0.5 * (z_vals[..., 1:] + z_vals[..., :-1])
  z_vals, samples = model_utils.sample_pdf(
      fine_sample_key, z_mids, weights[..., 1:-1], rays.origins,
      rays.directions, z_vals, self.num_fine_samples, randomized)
  fine_enc = encode_points(samples)
  fine_mlp = new_mlp()
  fine_rgb, fine_disp, fine_acc, _ = shade(fine_mlp, fine_enc, z_vals,
                                           fine_noise_key)
  ret.append((fine_rgb, fine_disp, fine_acc))
  return ret
def apply(self, key_0, key_1, rays, n_samples, n_fine_samples, use_viewdirs,
          near, far, noise_std, net_depth, net_width, net_depth_condition,
          net_width_condition, activation, skip_layer, alpha_channel,
          rgb_channel, randomized, white_bkgd, deg_point, deg_view, lindisp):
  """Nerf Model.

  Renders a batch of rays with a coarse MLP. When `n_fine_samples > 0`, a
  second ("fine") MLP is evaluated at points resampled from the coarse
  rendering weights. Both MLPs are built from the same network
  hyper-parameters.

  Args:
    key_0: jnp.ndarray, PRNG key for the coarse pass (split internally into
      a sampling key and a noise key).
    key_1: jnp.ndarray, PRNG key for the fine pass (split likewise).
    rays: jnp.ndarray(float32), [batch_size, 6/9]. The first 3 dimensions
      are the ray origin and the next 3 the unnormalized ray direction. If
      NDC rays are used, rays are 9-d and the extra 3 dimensions hold the
      view direction before the NDC transformation.
    n_samples: int, the number of samples for the coarse nerf.
    n_fine_samples: int, the number of samples for the fine nerf.
    use_viewdirs: bool, use viewdirs as a condition.
    near: float, near clip.
    far: float, far clip.
    noise_std: float, std dev of noise added to regularize sigma output.
    net_depth: int, the depth of the first part of the MLP.
    net_width: int, the width of the first part of the MLP.
    net_depth_condition: int, the depth of the second part of the MLP.
    net_width_condition: int, the width of the second part of the MLP.
    activation: function, the activation function used in the MLP.
    skip_layer: int, add a skip connection to the output vector of every
      skip_layer layers.
    alpha_channel: int, the number of alpha channels.
    rgb_channel: int, the number of rgb channels.
    randomized: bool, use randomized stratified sampling.
    white_bkgd: bool, use white background.
    deg_point: degree of positional encoding for positions.
    deg_view: degree of positional encoding for viewdirs.
    lindisp: bool, sampling linearly in disparity rather than depth if true.

  Returns:
    ret: list, [(rgb_coarse, disp_coarse, acc_coarse), (rgb, disp, acc)]; the
      fine tuple is only present when n_fine_samples > 0.
  """
  from jax import random  # pylint: disable=g-import-not-at-top

  # Both the coarse and the fine MLP must use the configured architecture.
  mlp_kwargs = dict(
      net_depth=net_depth,
      net_width=net_width,
      net_depth_condition=net_depth_condition,
      net_width_condition=net_width_condition,
      activation=activation,
      skip_layer=skip_layer,
      alpha_channel=alpha_channel,
      rgb_channel=rgb_channel,
  )

  # Separate the view directions so `rays` is [..., 6] from here on.
  if rays.shape[-1] > 6:
    # NDC rays: the trailing 3 dims carry the pre-NDC view direction.
    viewdirs = rays[..., -3:]
    rays = rays[..., :-3]
  else:
    # The ray direction doubles as the view direction (normalized below).
    viewdirs = rays[..., 3:6]

  if use_viewdirs:
    norms = jnp.linalg.norm(viewdirs, axis=-1, keepdims=True)
    viewdirs_enc = model_utils.posenc(viewdirs / norms, deg_view)

  def query_mlp(points_enc):
    """Run model_utils.MLP on encoded points (plus viewdirs if enabled)."""
    if use_viewdirs:
      return model_utils.MLP(points_enc, viewdirs_enc, **mlp_kwargs)
    return model_utils.MLP(points_enc, **mlp_kwargs)

  # A JAX PRNG key must not be reused: give sampling and noise independent
  # subkeys instead of feeding key_0 to both.
  coarse_sample_key, coarse_noise_key = random.split(key_0)

  # Stratified sampling along rays, then the coarse network.
  z_vals, samples = model_utils.sample_along_rays(
      coarse_sample_key, rays, n_samples, near, far, randomized, lindisp)
  raw = query_mlp(model_utils.posenc(samples, deg_point))
  # Add noise to regularize the density predictions if needed.
  raw = model_utils.noise_regularize(coarse_noise_key, raw, noise_std,
                                     randomized)
  rgb, disp, acc, weights = model_utils.volumetric_rendering(
      raw,
      z_vals,
      rays[..., 3:6],
      white_bkgd=white_bkgd,
  )
  ret = [(rgb, disp, acc)]

  # Hierarchical sampling based on coarse predictions.
  if n_fine_samples > 0:
    fine_sample_key, fine_noise_key = random.split(key_1)
    z_vals_mid = 0.5 * (z_vals[..., 1:] + z_vals[..., :-1])
    z_vals, samples = model_utils.sample_pdf(
        fine_sample_key,
        z_vals_mid,
        weights[..., 1:-1],
        rays,
        z_vals,
        n_fine_samples,
        randomized,
    )
    raw = query_mlp(model_utils.posenc(samples, deg_point))
    raw = model_utils.noise_regularize(fine_noise_key, raw, noise_std,
                                       randomized)
    rgb, disp, acc, _ = model_utils.volumetric_rendering(
        raw,
        z_vals,
        rays[..., 3:6],
        white_bkgd=white_bkgd,
    )
    ret.append((rgb, disp, acc))
  return ret