Example #1
0
  def _next_test(self):
    """Return the next test example, cycling through examples in order.

    A cursor (`self.it`) selects the example; it is advanced on every call
    and wraps back to 0 after `self.n_examples` steps.

    Returns:
      A dict with a "rays" entry holding each ray field indexed at the
      cursor position. When `self.render_path` is falsy the rays come from
      `self.rays` and the matching "pixels" entry from `self.images` is
      included; otherwise only rays from `self.render_rays` are returned.
    """
    current = self.it
    # Advance the round-robin cursor before building the result.
    # NOTE(review): the wrap uses self.n_examples in both modes; confirm
    # render_rays has at least that many entries when render_path is set.
    self.it = (current + 1) % self.n_examples

    def take(field):
      return field[current]

    if not self.render_path:
      return {
          "pixels": self.images[current],
          "rays": utils.namedtuple_map(take, self.rays),
      }
    return {"rays": utils.namedtuple_map(take, self.render_rays)}
Example #2
0
    def _train_init(self, args):
        """Load training renderings and reshape them for batch sampling.

        With `args.image_batching` set, pixels and rays from all images are
        merged into one flat pool (one row per ray). Otherwise a leading
        per-image axis is kept and each image is flattened to
        `self.resolution` rows.
        """
        self._load_renderings(args)
        self._generate_rays()

        if not args.image_batching:
            # Keep the image axis; flatten each image's pixels/rays.
            self.images = self.images.reshape([-1, self.resolution, 3])
            self.rays = utils.namedtuple_map(
                lambda field: field.reshape(
                    [-1, self.resolution, field.shape[-1]]),
                self.rays)
            return

        # Collapse the image and pixel axes together into a single pool.
        self.images = self.images.reshape([-1, 3])
        self.rays = utils.namedtuple_map(
            lambda field: field.reshape([-1, field.shape[-1]]), self.rays)
Example #3
0
  def _train_init(self, args):
    """Initialize training."""
    self._load_renderings(args)
    self._generate_rays()

    if args.batching == "all_images":
      # flatten the ray and image dimension together.
      self.images = self.images.reshape([-1, 3])
      self.rays = utils.namedtuple_map(lambda r: r.reshape([-1, r.shape[-1]]),
                                       self.rays)
    elif args.batching == "single_image":
      self.images = self.images.reshape([-1, self.resolution, 3])
      self.rays = utils.namedtuple_map(
          lambda r: r.reshape([-1, self.resolution, r.shape[-1]]), self.rays)
    else:
      raise NotImplementedError(
          f"{args.batching} batching strategy is not implemented.")
Example #4
0
    def _next_train(self):
        """Draw a random training batch of rays and their target pixels.

        With `self.image_batching` set, `self.batch_size` ray indices are
        drawn (independently, via np.random.randint) from the flattened pool
        of all images. Otherwise one image index is drawn first and the ray
        indices are drawn within that image.

        Returns:
            dict with "pixels" and "rays" entries for the sampled rays.
        """
        if not self.image_batching:
            img = np.random.randint(0, self.n_examples, ())
            picks = np.random.randint(0, self.rays[0][0].shape[0],
                                      (self.batch_size, ))
            pixels = self.images[img][picks]
            rays = utils.namedtuple_map(lambda field: field[img][picks],
                                        self.rays)
        else:
            picks = np.random.randint(0, self.rays[0].shape[0],
                                      (self.batch_size, ))
            pixels = self.images[picks]
            rays = utils.namedtuple_map(lambda field: field[picks], self.rays)
        return {"pixels": pixels, "rays": rays}
Example #5
0
  def _next_train(self):
    """Sample next training batch."""

    if self.batching == "all_images":
      ray_indices = np.random.randint(0, self.rays[0].shape[0],
                                      (self.batch_size,))
      batch_pixels = self.images[ray_indices]
      batch_rays = utils.namedtuple_map(lambda r: r[ray_indices], self.rays)
    elif self.batching == "single_image":
      image_index = np.random.randint(0, self.n_examples, ())
      ray_indices = np.random.randint(0, self.rays[0][0].shape[0],
                                      (self.batch_size,))
      batch_pixels = self.images[image_index][ray_indices]
      batch_rays = utils.namedtuple_map(lambda r: r[image_index][ray_indices],
                                        self.rays)
    else:
      raise NotImplementedError(
          f"{self.batching} batching strategy is not implemented.")
    return {"pixels": batch_pixels, "rays": batch_rays}
Example #6
0
def construct_nerf(key, example_batch, args):
    """Construct a Neural Radiance Field and initialize its variables.

    The configured rgb and sigma activations are first probed on a wide,
    symmetric range of inputs. Activations that can emit colors outside
    [0, 1] or negative densities are rejected before any model is built.

    Args:
        key: jnp.ndarray. Random number generator key; split into three keys
            for `model.init` and its two rng streams.
        example_batch: dict, an example batch of data. Only its "rays" entry
            is used, with every ray field sliced at index 0 along its leading
            axis.
        args: FLAGS class. Hyperparameters of nerf.

    Returns:
        model: NerfModel. The model built from `args`.
        init_variables: the value returned by `model.init` for the example
            rays (presumably the Flax variables collection — confirm against
            the NerfModel definition).

    Raises:
        NotImplementedError: if rgb_activation produces values outside
            [0, 1], or sigma_activation produces negative values, on the
            probe inputs.
    """
    net_activation = getattr(nn, str(args.net_activation))
    rgb_activation = getattr(nn, str(args.rgb_activation))
    sigma_activation = getattr(nn, str(args.sigma_activation))

    # Assert that rgb_activation always produces outputs in [0, 1], and
    # sigma_activation always produces non-negative outputs. The probe spans
    # magnitudes exp(-90)..exp(90) with both signs; under JAX's default
    # float32 the largest magnitudes overflow to +/-inf, which also exercises
    # the activations' behavior at infinity.
    probe = jnp.exp(jnp.linspace(-90, 90, 1024))
    probe = jnp.concatenate([-probe[::-1], probe], 0)

    rgb = rgb_activation(probe)
    if jnp.any(rgb < 0) or jnp.any(rgb > 1):
        raise NotImplementedError(
            f"Choice of rgb_activation `{args.rgb_activation}` produces "
            "colors outside of [0, 1]")

    sigma = sigma_activation(probe)
    if jnp.any(sigma < 0):
        raise NotImplementedError(
            f"Choice of sigma_activation `{args.sigma_activation}` produces "
            "negative densities")

    model = NerfModel(min_deg_point=args.min_deg_point,
                      max_deg_point=args.max_deg_point,
                      deg_view=args.deg_view,
                      num_coarse_samples=args.num_coarse_samples,
                      num_fine_samples=args.num_fine_samples,
                      use_viewdirs=args.use_viewdirs,
                      near=args.near,
                      far=args.far,
                      noise_std=args.noise_std,
                      white_bkgd=args.white_bkgd,
                      net_depth=args.net_depth,
                      net_width=args.net_width,
                      net_depth_condition=args.net_depth_condition,
                      net_width_condition=args.net_width_condition,
                      skip_layer=args.skip_layer,
                      num_rgb_channels=args.num_rgb_channels,
                      num_sigma_channels=args.num_sigma_channels,
                      lindisp=args.lindisp,
                      net_activation=net_activation,
                      rgb_activation=rgb_activation,
                      sigma_activation=sigma_activation,
                      legacy_posenc_order=args.legacy_posenc_order)
    rays = example_batch["rays"]
    key1, key2, key3 = random.split(key, num=3)

    # Initialize with a single slice of the example rays; only shapes matter.
    init_variables = model.init(key1,
                                rng_0=key2,
                                rng_1=key3,
                                rays=utils.namedtuple_map(
                                    lambda r: r[0], rays),
                                randomized=args.randomized)

    return model, init_variables