Exemplo n.º 1
0
    def apply(self, rng_0, rng_1, origins, directions, viewdirs,
              num_coarse_samples, num_fine_samples, use_viewdirs, near, far,
              noise_std, net_depth, net_width, net_depth_condition,
              net_width_condition, net_activation, skip_layer,
              num_rgb_channels, num_sigma_channels, randomized, white_bkgd,
              deg_point, deg_view, lindisp, rgb_activation, sigma_activation,
              legacy_posenc_order):
        """Render rays with a coarse NeRF pass and an optional fine pass.

        Args:
          rng_0: jnp.ndarray, random number generator for coarse model
            sampling.
          rng_1: jnp.ndarray, random number generator for fine model sampling.
          origins: jnp.ndarray(float32), [batch_size, 3], each ray origin.
          directions: jnp.ndarray(float32), [batch_size, 3], each ray
            direction.
          viewdirs: jnp.ndarray(float32), [batch_size, 3], the viewing
            direction for each ray. This is only used if NDC rays are used, as
            otherwise `directions` is equal to viewdirs.
          num_coarse_samples: int, the number of samples for coarse nerf.
          num_fine_samples: int, the number of samples for fine nerf.
          use_viewdirs: bool, use viewdirs as a condition.
          near: float, near clip.
          far: float, far clip.
          noise_std: float, std dev of noise added to regularize sigma output.
          net_depth: int, the depth of the first part of MLP.
          net_width: int, the width of the first part of MLP.
          net_depth_condition: int, the depth of the second part of MLP.
          net_width_condition: int, the width of the second part of MLP.
          net_activation: function, the activation function used within the
            MLP.
          skip_layer: int, add a skip connection to the output vector of every
            skip_layer layers.
          num_rgb_channels: int, the number of RGB channels.
          num_sigma_channels: int, the number of density channels.
          randomized: bool, use randomized stratified sampling.
          white_bkgd: bool, use white background.
          deg_point: degree of positional encoding for positions.
          deg_view: degree of positional encoding for viewdirs.
          lindisp: bool, sampling linearly in disparity rather than depth if
            true.
          rgb_activation: function, the activation used to generate RGB.
          sigma_activation: function, the activation used to generate density.
          legacy_posenc_order: bool, keep the same ordering as the original tf
            code.

        Returns:
          A list of (rgb, disp, acc) tuples: the coarse result first, followed
          by the fine result when `num_fine_samples > 0`.
        """
        mlp_kwargs = dict(
            net_depth=net_depth,
            net_width=net_width,
            net_depth_condition=net_depth_condition,
            net_width_condition=net_width_condition,
            net_activation=net_activation,
            skip_layer=skip_layer,
            num_rgb_channels=num_rgb_channels,
            num_sigma_channels=num_sigma_channels,
        )
        run_mlp = functools.partial(model_utils.MLP, **mlp_kwargs)

        def shade(noise_rng, points_enc, depths, condition):
            """Predict rgb/sigma for encoded samples and composite per ray."""
            raw_rgb, raw_sigma = run_mlp(points_enc, *condition)
            # A fresh key is split off for the density-regularizing noise.
            noise_key, _ = random.split(noise_rng)
            raw_sigma = model_utils.add_gaussian_noise(
                noise_key, raw_sigma, noise_std, randomized)
            return model_utils.volumetric_rendering(
                rgb_activation(raw_rgb),
                sigma_activation(raw_sigma),
                depths,
                directions,
                white_bkgd=white_bkgd,
            )

        # Coarse pass: stratified sampling along rays.
        sample_key, rng_0 = random.split(rng_0)
        coarse_z, coarse_points = model_utils.sample_along_rays(
            sample_key, origins, directions, num_coarse_samples, near, far,
            randomized, lindisp)
        coarse_enc = model_utils.posenc(coarse_points, deg_point,
                                        legacy_posenc_order)

        # The view-direction encoding (if any) is shared by both passes; it is
        # kept as a tuple so it can be splatted after the sample encoding.
        if use_viewdirs:
            unit_viewdirs = viewdirs / jnp.linalg.norm(
                viewdirs, axis=-1, keepdims=True)
            condition = (model_utils.posenc(unit_viewdirs, deg_view,
                                            legacy_posenc_order),)
        else:
            condition = ()

        coarse_rgb, coarse_disp, coarse_acc, coarse_weights = shade(
            rng_0, coarse_enc, coarse_z, condition)
        outputs = [(coarse_rgb, coarse_disp, coarse_acc)]
        if num_fine_samples <= 0:
            return outputs

        # Fine pass: hierarchical sampling driven by the coarse weights,
        # using the midpoints between consecutive coarse depths as bins.
        bin_centers = 0.5 * (coarse_z[..., 1:] + coarse_z[..., :-1])
        sample_key, rng_1 = random.split(rng_1)
        fine_z, fine_points = model_utils.sample_pdf(
            sample_key,
            bin_centers,
            coarse_weights[..., 1:-1],
            origins,
            directions,
            coarse_z,
            num_fine_samples,
            randomized,
        )
        fine_enc = model_utils.posenc(fine_points, deg_point,
                                      legacy_posenc_order)
        fine_rgb, fine_disp, fine_acc, _ = shade(
            rng_1, fine_enc, fine_z, condition)
        outputs.append((fine_rgb, fine_disp, fine_acc))
        return outputs
Exemplo n.º 2
0
    def __call__(self, rng_0, rng_1, rays, randomized):
        """Render rays with a coarse NeRF pass and an optional fine pass.

        Args:
          rng_0: jnp.ndarray, random number generator for coarse model
            sampling.
          rng_1: jnp.ndarray, random number generator for fine model sampling.
          rays: util.Rays, a namedtuple of ray origins, directions, and
            viewdirs.
          randomized: bool, use randomized stratified sampling.

        Returns:
          A list of (rgb, disp, acc) tuples: the coarse result first, followed
          by the fine result when `self.num_fine_samples > 0`.
        """
        # Both MLPs are built from the same configuration.
        mlp_config = dict(
            net_depth=self.net_depth,
            net_width=self.net_width,
            net_depth_condition=self.net_depth_condition,
            net_width_condition=self.net_width_condition,
            net_activation=self.net_activation,
            skip_layer=self.skip_layer,
            num_rgb_channels=self.num_rgb_channels,
            num_sigma_channels=self.num_sigma_channels,
        )

        def encode_points(points):
            """Positionally encode sample positions."""
            return model_utils.posenc(
                points,
                self.min_deg_point,
                self.max_deg_point,
                self.legacy_posenc_order,
            )

        def shade(mlp, noise_rng, points_enc, depths, condition):
            """Predict rgb/sigma for encoded samples and composite per ray."""
            raw_rgb, raw_sigma = mlp(points_enc, *condition)
            # A fresh key is split off for the density-regularizing noise.
            noise_key, _ = random.split(noise_rng)
            raw_sigma = model_utils.add_gaussian_noise(
                noise_key,
                raw_sigma,
                self.noise_std,
                randomized,
            )
            return model_utils.volumetric_rendering(
                self.rgb_activation(raw_rgb),
                self.sigma_activation(raw_sigma),
                depths,
                rays.directions,
                white_bkgd=self.white_bkgd,
            )

        # Coarse pass: stratified sampling along rays.
        sample_key, rng_0 = random.split(rng_0)
        coarse_z, coarse_points = model_utils.sample_along_rays(
            sample_key,
            rays.origins,
            rays.directions,
            self.num_coarse_samples,
            self.near,
            self.far,
            randomized,
            self.lindisp,
        )
        coarse_enc = encode_points(coarse_points)

        # Construct the "coarse" MLP.
        coarse_mlp = model_utils.MLP(**mlp_config)

        # The view-direction encoding (if any) is shared by both passes; it is
        # kept as a tuple so it can be splatted after the sample encoding.
        if self.use_viewdirs:
            condition = (model_utils.posenc(
                rays.viewdirs,
                0,
                self.deg_view,
                self.legacy_posenc_order,
            ),)
        else:
            condition = ()

        coarse_rgb, coarse_disp, coarse_acc, coarse_weights = shade(
            coarse_mlp, rng_0, coarse_enc, coarse_z, condition)
        outputs = [(coarse_rgb, coarse_disp, coarse_acc)]
        if self.num_fine_samples <= 0:
            return outputs

        # Fine pass: hierarchical sampling driven by the coarse weights,
        # using the midpoints between consecutive coarse depths as bins.
        bin_centers = 0.5 * (coarse_z[..., 1:] + coarse_z[..., :-1])
        sample_key, rng_1 = random.split(rng_1)
        fine_z, fine_points = model_utils.sample_pdf(
            sample_key,
            bin_centers,
            coarse_weights[..., 1:-1],
            rays.origins,
            rays.directions,
            coarse_z,
            self.num_fine_samples,
            randomized,
        )
        fine_enc = encode_points(fine_points)

        # Construct the "fine" MLP.
        fine_mlp = model_utils.MLP(**mlp_config)

        fine_rgb, fine_disp, fine_acc, _ = shade(
            fine_mlp, rng_1, fine_enc, fine_z, condition)
        outputs.append((fine_rgb, fine_disp, fine_acc))
        return outputs
Exemplo n.º 3
0
  def apply(self, key_0, key_1, rays, n_samples, n_fine_samples, use_viewdirs,
            near, far, noise_std, net_depth, net_width, net_depth_condition,
            net_width_condition, activation, skip_layer, alpha_channel,
            rgb_channel, randomized, white_bkgd, deg_point, deg_view, lindisp):
    """Nerf Model.

    Args:
      key_0: jnp.ndarray, random number generator for coarse model sampling.
      key_1: jnp.ndarray, random number generator for fine model sampling.
      rays: jnp.ndarray(float32), [batch_size, 6/9], each ray is a 6-d vector
        where the first 3 dimensions represent the ray origin and the last 3
        dimensions represent the unnormalized ray direction. Note that if ndc
        rays are used, rays are 9-d where the extra 3-dimensional vector is the
        view direction before transformed to ndc rays.
      n_samples: int, the number of samples for coarse nerf.
      n_fine_samples: int, the number of samples for fine nerf.
      use_viewdirs: bool, use viewdirs as a condition.
      near: float, near clip.
      far: float, far clip.
      noise_std: float, std dev of noise added to regularize sigma output.
      net_depth: int, the depth of the first part of MLP.
      net_width: int, the width of the first part of MLP.
      net_depth_condition: int, the depth of the second part of MLP.
      net_width_condition: int, the width of the second part of MLP.
      activation: function, the activation function used in the MLP.
      skip_layer: int, add a skip connection to the output vector of every
        skip_layer layers.
      alpha_channel: int, the number of alpha_channels.
      rgb_channel: int, the number of rgb_channels.
      randomized: bool, use randomized stratified sampling.
      white_bkgd: bool, use white background.
      deg_point: degree of positional encoding for positions.
      deg_view: degree of positional encoding for viewdirs.
      lindisp: bool, sampling linearly in disparity rather than depth if true.

    Returns:
      ret: list, [(rgb_coarse, disp_coarse, acc_coarse), (rgb, disp, acc)].
        The coarse tuple always comes first; the fine tuple is only appended
        when n_fine_samples > 0.
    """
    # Imported locally so this method does not depend on a module-level
    # `random` name being bound to jax.random.
    from jax import random as jax_random

    # Hyperparameters shared by the coarse and the fine MLP. The fine pass
    # must receive them too, otherwise it silently runs with the MLP defaults
    # instead of the configured architecture.
    mlp_kwargs = dict(
        net_depth=net_depth,
        net_width=net_width,
        net_depth_condition=net_depth_condition,
        net_width_condition=net_width_condition,
        activation=activation,
        skip_layer=skip_layer,
        alpha_channel=alpha_channel,
        rgb_channel=rgb_channel,
    )

    # Extract viewdirs from the ray array
    if rays.shape[-1] > 6:  # viewdirs different from rays_d
      viewdirs = rays[Ellipsis, -3:]
      rays = rays[Ellipsis, :-3]
    else:  # viewdirs are normalized rays_d
      viewdirs = rays[Ellipsis, 3:6]

    # A PRNG key must not be consumed twice: derive independent keys for the
    # stratified sampling and for the density noise.
    sample_key_0, noise_key_0 = jax_random.split(key_0)

    # Stratified sampling along rays
    z_vals, samples = model_utils.sample_along_rays(sample_key_0, rays,
                                                    n_samples, near, far,
                                                    randomized, lindisp)
    samples = model_utils.posenc(samples, deg_point)
    # Point attribute predictions
    if use_viewdirs:
      norms = jnp.linalg.norm(viewdirs, axis=-1, keepdims=True)
      viewdirs = viewdirs / norms
      # From here on `viewdirs` holds the encoded directions; the fine pass
      # below reuses this encoding.
      viewdirs = model_utils.posenc(viewdirs, deg_view)
      raw = model_utils.MLP(samples, viewdirs, **mlp_kwargs)
    else:
      raw = model_utils.MLP(samples, **mlp_kwargs)
    # Add noises to regularize the density predictions if needed
    raw = model_utils.noise_regularize(noise_key_0, raw, noise_std, randomized)
    # Volumetric rendering.
    rgb, disp, acc, weights = model_utils.volumetric_rendering(
        raw,
        z_vals,
        rays[Ellipsis, 3:6],
        white_bkgd=white_bkgd,
    )
    ret = [
        (rgb, disp, acc),
    ]
    # Hierarchical sampling based on coarse predictions
    if n_fine_samples > 0:
      sample_key_1, noise_key_1 = jax_random.split(key_1)
      # Midpoints between consecutive coarse depths serve as the bins.
      z_vals_mid = .5 * (z_vals[Ellipsis, 1:] + z_vals[Ellipsis, :-1])
      z_vals, samples = model_utils.sample_pdf(
          sample_key_1,
          z_vals_mid,
          weights[Ellipsis, 1:-1],
          rays,
          z_vals,
          n_fine_samples,
          randomized,
      )
      samples = model_utils.posenc(samples, deg_point)
      if use_viewdirs:
        raw = model_utils.MLP(samples, viewdirs, **mlp_kwargs)
      else:
        raw = model_utils.MLP(samples, **mlp_kwargs)
      raw = model_utils.noise_regularize(noise_key_1, raw, noise_std,
                                         randomized)
      rgb, disp, acc, unused_weights = model_utils.volumetric_rendering(
          raw,
          z_vals,
          rays[Ellipsis, 3:6],
          white_bkgd=white_bkgd,
      )
      ret.append((rgb, disp, acc))
    return ret