Esempio n. 1
0
def viewdir_fn(viewdir_mlp_model, viewdir_mlp_params, rgb_features, viewdirs,
               scene_params):
  """Evaluates the per-ray view-dependence MLP and returns color residuals.

  The ray directions are positionally encoded, concatenated with the per-ray
  color/feature vector, and passed through the view-dependence MLP. The first
  MLP output is squashed with a sigmoid before being returned.

  Args:
    viewdir_mlp_model: A nerf.model_utils.MLP that predicts the per-ray
      view-dependent residual color.
    viewdir_mlp_params: A dict containing the MLP parameters for the per-ray
      view-dependence MLP.
    rgb_features: A [H, W, 7] JAX tensor containing the composited color and
      computed feature vector for each ray.
    viewdirs: A [H, W, 3] JAX tensor of ray directions.
    scene_params: A dict for scene specific params; the keys '_deg_view' and
      '_legacy_posenc_order' are read here.

  Returns:
    A [H, W, 3] JAX tensor with the view-dependent RGB residual for each ray.
  """
  max_view_degree = scene_params['_deg_view']
  legacy_order = scene_params['_legacy_posenc_order']
  encoded_dirs = model_utils.posenc(viewdirs, 0, max_view_degree, legacy_order)

  # Join the encoded direction with the per-ray features along the channel
  # axis, then insert a singleton axis just before the channels.
  mlp_input = jnp.expand_dims(
      jnp.concatenate([encoded_dirs, rgb_features], axis=-1), -2)

  # Only the first of the two MLP outputs is used.
  raw_residual, _ = viewdir_mlp_model.apply(viewdir_mlp_params, mlp_input)
  return nn.sigmoid(raw_residual)
Esempio n. 2
0
 def inner_pmap(params, samples):
   """Runs the MLP on the given samples and gathers results over 'batch'.

   This is an inner function because only JAX types can be passed to a pmap;
   the model, the activations and the scene parameters are taken from the
   enclosing scope instead of being passed as arguments.

   Args:
     params: The MLP parameters forwarded to `mlp_model.apply`.
     samples: The sample positions; they are positionally encoded before being
       fed to the MLP.

   Returns:
     The result of `lax.all_gather` (axis name 'batch') applied to the tuple
     of activated MLP outputs `(rgb_features, sigma)`.
   """
   min_degree = scene_params['_min_deg_point']
   max_degree = scene_params['_max_deg_point']
   legacy_order = scene_params['_legacy_posenc_order']
   encoded = model_utils.posenc(samples, min_degree, max_degree, legacy_order)

   raw_features, raw_density = mlp_model.apply(params, encoded)
   activated = (rgb_activation(raw_features), sigma_activation(raw_density))
   return lax.all_gather(activated, axis_name='batch')
Esempio n. 3
0
    def __call__(self, rng_0, rng_1, rays, randomized):
        """Renders rays with a coarse pass and, optionally, a fine pass.

        The coarse pass samples points along each ray, evaluates the coarse MLP
        and volume-renders the result. When `self.num_fine_samples > 0`, a fine
        pass resamples along the rays using the coarse weights and repeats the
        procedure with a separate set of MLPs. When `self.use_viewdirs` is set,
        each pass additionally feeds the composited color and features (plus the
        encoded view direction) to a view-dependence MLP whose activated output
        is added to the composited color.

        Args:
          rng_0: jnp.ndarray, random number generator for coarse model sampling.
          rng_1: jnp.ndarray, random number generator for fine model sampling.
          rays: util.Rays, a namedtuple of ray origins, directions, and viewdirs.
          randomized: bool, use randomized stratified sampling.

        Returns:
          ret: list of 6-tuples
            `(rgb, disp, acc, sigma, features, rgb_residual)`.
            The first entry is the coarse-pass result. If
            `self.num_fine_samples > 0`, a second entry holding the fine-pass
            result is appended; its `sigma` slot contains `reg_sigma`, i.e. the
            activated density of the fine MLP evaluated at the *coarse* sample
            positions, not the density at the fine samples. Without view
            directions, `features` and `rgb_residual` are zero arrays shaped
            like the composited color.
        """
        # Stratified sampling along rays
        # `key` is consumed by the sampler; `rng_0` is carried forward for the
        # next split below.
        key, rng_0 = random.split(rng_0)
        z_vals, coarse_samples = model_utils.sample_along_rays(
            key,
            rays.origins,
            rays.directions,
            self.num_coarse_samples,
            self.near,
            self.far,
            randomized,
            self.lindisp,
        )
        coarse_samples_enc = model_utils.posenc(
            coarse_samples,
            self.min_deg_point,
            self.max_deg_point,
            self.legacy_posenc_order,
        )

        # Construct the "coarse" MLP.
        # Its "rgb" head is widened to num_rgb_channels + num_viewdir_channels:
        # the slicing further down treats the first num_rgb_channels outputs as
        # the diffuse color and the remaining ones as the feature channels.
        coarse_mlp = model_utils.MLP(
            net_depth=self.net_depth,
            net_width=self.net_width,
            net_activation=self.net_activation,
            skip_layer=self.skip_layer,
            num_rgb_channels=self.num_rgb_channels + self.num_viewdir_channels,
            num_sigma_channels=self.num_sigma_channels)

        # Point attribute predictions
        if self.use_viewdirs:
            # `viewdirs_enc` is only bound in this branch; it is reused by the
            # fine pass below under the same `use_viewdirs` condition.
            viewdirs_enc = model_utils.posenc(
                rays.viewdirs,
                0,
                self.deg_view,
                self.legacy_posenc_order,
            )
            raw_features_and_rgb, raw_sigma = coarse_mlp(coarse_samples_enc)
        else:
            # NOTE(review): the MLP is built with the widened head regardless of
            # `use_viewdirs`, so `raw_rgb` presumably carries the extra feature
            # channels here as well - confirm against model_utils.MLP.
            raw_rgb, raw_sigma = coarse_mlp(coarse_samples_enc)

        # Add noises to regularize the density predictions if needed
        key, rng_0 = random.split(rng_0)
        raw_sigma = model_utils.add_gaussian_noise(
            key,
            raw_sigma,
            self.noise_std,
            randomized,
        )
        sigma = self.sigma_activation(raw_sigma)

        if self.use_viewdirs:
            # View-dependence MLP for the coarse pass; only its first output is
            # used below.
            coarse_viewdir_mlp = model_utils.MLP(
                net_depth=self.viewdir_net_depth,
                net_width=self.viewdir_net_width,
                net_activation=self.net_activation,
                skip_layer=self.skip_layer,
                num_rgb_channels=self.num_rgb_channels,
                num_sigma_channels=self.num_sigma_channels)

            # Overcomposite the features to get an encoding for the features.
            # NOTE(review): the coarse pass composites the *raw* feature
            # channels, whereas the fine pass applies `rgb_activation` to them
            # first - confirm this asymmetry is intended.
            comp_features, _, _, _ = model_utils.volumetric_rendering(
                raw_features_and_rgb[Ellipsis, self.num_rgb_channels:(
                    self.num_rgb_channels + self.num_viewdir_channels)],
                sigma,
                z_vals,
                rays.directions,
                white_bkgd=False,
            )
            # Only the first num_rgb_channels composited feature channels are
            # reported in the returned tuple; the MLP input below uses all of
            # `comp_features`.
            features = comp_features[Ellipsis, 0:self.num_rgb_channels]

            # Diffuse color: activated first num_rgb_channels of the MLP output,
            # composited along the ray.
            diffuse_rgb = self.rgb_activation(
                raw_features_and_rgb[Ellipsis, 0:self.num_rgb_channels])
            comp_rgb, disp, acc, weights = model_utils.volumetric_rendering(
                diffuse_rgb,
                sigma,
                z_vals,
                rays.directions,
                white_bkgd=self.white_bkgd,
            )

            # Per-ray input of the view-dependence MLP: encoded direction,
            # composited color and composited features, with a singleton axis
            # inserted before the channel axis.
            viewdirs_enc_features = jnp.concatenate(
                [viewdirs_enc, comp_rgb, comp_features], axis=-1)
            viewdirs_enc_features = jnp.expand_dims(viewdirs_enc_features, -2)
            raw_comp_rgb_residual, _ = coarse_viewdir_mlp(
                viewdirs_enc_features)

            # Reshape the MLP output to the leading shape of `comp_features`
            # with a last dimension of 3.
            # NOTE(review): the 3 is hard-coded; presumably it must equal
            # `self.num_rgb_channels` for the reshape and the addition below to
            # work - confirm.
            output_shape = list(comp_features.shape)
            output_shape[-1] = 3
            raw_comp_rgb_residual = raw_comp_rgb_residual.reshape(output_shape)
            rgb_residual = self.rgb_activation(raw_comp_rgb_residual)
            # Final coarse color = composited diffuse color + view-dependent
            # residual.
            comp_rgb += rgb_residual
        else:
            rgb = self.rgb_activation(raw_rgb)
            # Volumetric rendering.
            comp_rgb, disp, acc, weights = model_utils.volumetric_rendering(
                rgb,
                sigma,
                z_vals,
                rays.directions,
                white_bkgd=self.white_bkgd,
            )
            # Zero placeholders so the returned tuple has the same layout as in
            # the view-dependent branch.
            features = jnp.zeros_like(comp_rgb)
            rgb_residual = jnp.zeros_like(comp_rgb)

        # Coarse-pass result; `sigma` here is the activated coarse density.
        ret = [
            (comp_rgb, disp, acc, sigma, features, rgb_residual),
        ]
        # Hierarchical sampling based on coarse predictions
        if self.num_fine_samples > 0:
            # Midpoints between consecutive coarse z values; together with the
            # coarse weights (first and last entry dropped) they are handed to
            # `sample_pdf`. `z_vals` is rebound to the values it returns.
            z_vals_mid = .5 * (z_vals[Ellipsis, 1:] + z_vals[Ellipsis, :-1])
            key, rng_1 = random.split(rng_1)
            z_vals, fine_samples = model_utils.sample_pdf(
                key,
                z_vals_mid,
                weights[Ellipsis, 1:-1],
                rays.origins,
                rays.directions,
                z_vals,
                self.num_fine_samples,
                randomized,
            )
            fine_samples_enc = model_utils.posenc(
                fine_samples,
                self.min_deg_point,
                self.max_deg_point,
                self.legacy_posenc_order,
            )

            # Construct the "fine" MLP.
            # Same configuration as the coarse MLP, but a separate instance.
            fine_mlp = model_utils.MLP(
                net_depth=self.net_depth,
                net_width=self.net_width,
                net_activation=self.net_activation,
                skip_layer=self.skip_layer,
                num_rgb_channels=self.num_rgb_channels +
                self.num_viewdir_channels,
                num_sigma_channels=self.num_sigma_channels)

            if self.use_viewdirs:
                raw_features_and_rgb, raw_sigma = fine_mlp(fine_samples_enc)
            else:
                raw_rgb, raw_sigma = fine_mlp(fine_samples_enc)

            # Same density-noise regularization as in the coarse pass, drawn
            # from the fine RNG stream.
            key, rng_1 = random.split(rng_1)
            raw_sigma = model_utils.add_gaussian_noise(
                key,
                raw_sigma,
                self.noise_std,
                randomized,
            )
            sigma = self.sigma_activation(raw_sigma)

            # Evaluate the fine MLP a second time, on the coarse sample
            # positions, keeping only the density (no noise is added here).
            # This `reg_sigma` is what the fine entry of `ret` reports in its
            # sigma slot.
            _, raw_reg_sigma = fine_mlp(coarse_samples_enc)
            reg_sigma = self.sigma_activation(raw_reg_sigma)

            if self.use_viewdirs:
                # View-dependence MLP for the fine pass (separate instance from
                # the coarse one).
                fine_viewdir_mlp = model_utils.MLP(
                    net_depth=self.viewdir_net_depth,
                    net_width=self.viewdir_net_width,
                    net_activation=self.net_activation,
                    skip_layer=self.skip_layer,
                    num_rgb_channels=self.num_rgb_channels,
                    num_sigma_channels=self.num_sigma_channels)

                # Overcomposite the features to get an encoding for the features.
                # Unlike the coarse pass, the activation is applied to all
                # channels (color and features) before compositing.
                features_and_rgb = self.rgb_activation(raw_features_and_rgb)
                features = features_and_rgb[Ellipsis, self.num_rgb_channels:(
                    self.num_rgb_channels + self.num_viewdir_channels)]

                comp_features, _, _, _ = model_utils.volumetric_rendering(
                    features,
                    sigma,
                    z_vals,
                    rays.directions,
                    white_bkgd=False,
                )
                # `features` is rebound: from here on it is the first
                # num_rgb_channels channels of the composited features.
                features = comp_features[Ellipsis, 0:self.num_rgb_channels]

                diffuse_rgb = features_and_rgb[Ellipsis,
                                               0:self.num_rgb_channels]
                comp_rgb, disp, acc, weights = model_utils.volumetric_rendering(
                    diffuse_rgb,
                    sigma,
                    z_vals,
                    rays.directions,
                    white_bkgd=self.white_bkgd,
                )

                # Same view-dependence MLP input layout as in the coarse pass.
                viewdirs_enc_features = jnp.concatenate(
                    [viewdirs_enc, comp_rgb, comp_features], axis=-1)
                viewdirs_enc_features = jnp.expand_dims(
                    viewdirs_enc_features, -2)
                raw_comp_rgb_residual, _ = fine_viewdir_mlp(
                    viewdirs_enc_features)

                # NOTE(review): hard-coded 3 as in the coarse pass - presumably
                # equal to `self.num_rgb_channels`; confirm.
                output_shape = list(comp_features.shape)
                output_shape[-1] = 3
                raw_comp_rgb_residual = raw_comp_rgb_residual.reshape(
                    output_shape)
                rgb_residual = self.rgb_activation(raw_comp_rgb_residual)
                comp_rgb += rgb_residual
            else:
                rgb = self.rgb_activation(raw_rgb)
                # Volumetric rendering.
                comp_rgb, disp, acc, weights = model_utils.volumetric_rendering(
                    rgb,
                    sigma,
                    z_vals,
                    rays.directions,
                    white_bkgd=self.white_bkgd,
                )
                # Zero placeholders, as in the coarse pass.
                features = jnp.zeros_like(comp_rgb)
                rgb_residual = jnp.zeros_like(comp_rgb)

            # Fine-pass result; note `reg_sigma` (not `sigma`) in the fourth
            # slot.
            ret.append(
                (comp_rgb, disp, acc, reg_sigma, features, rgb_residual))
        return ret