示例#1
0
    def _next_test(self):
        """Return the next test example, cycling through them in order.

        Examples are visited deterministically, not randomly. The cursor
        ``self.it`` is read, then advanced and wrapped modulo
        ``self.n_examples``.

        Returns:
          When ``self.render_path`` is truthy, a dict with only ``"rays"``,
          taken from ``self.render_rays``. Otherwise, a dict with
          ``"pixels"`` (from ``self.images``) and ``"rays"`` (from
          ``self.rays``). Every ray field is indexed at the current cursor.
        """
        current = self.it
        self.it = (current + 1) % self.n_examples

        def pick(field):
            return field[current]

        if not self.render_path:
            return {
                "pixels": self.images[current],
                "rays": utils.namedtuple_map(pick, self.rays),
            }
        return {"rays": utils.namedtuple_map(pick, self.render_rays)}
示例#2
0
    def _train_init(self, args):
        """Initialize training."""
        self._load_renderings(args)
        self._generate_rays()

        if args.batching == "all_images":
            # flatten the ray and image dimension together.
            self.images = self.images.reshape([-1, 3])
            self.rays = utils.namedtuple_map(
                lambda r: r.reshape([-1, r.shape[-1]]), self.rays)
        elif args.batching == "single_image":
            self.images = self.images.reshape([-1, self.resolution, 3])
            self.rays = utils.namedtuple_map(
                lambda r: r.reshape([-1, self.resolution, r.shape[-1]]),
                self.rays)
        else:
            raise NotImplementedError(
                f"{args.batching} batching strategy is not implemented.")
示例#3
0
    def _next_train(self):
        """Sample next training batch."""

        if self.batching == "all_images":
            ray_indices = np.random.randint(0, self.rays[0].shape[0],
                                            (self.batch_size, ))
            batch_pixels = self.images[ray_indices]
            batch_rays = utils.namedtuple_map(lambda r: r[ray_indices],
                                              self.rays)
        elif self.batching == "single_image":
            image_index = np.random.randint(0, self.n_examples, ())
            ray_indices = np.random.randint(0, self.rays[0][0].shape[0],
                                            (self.batch_size, ))
            batch_pixels = self.images[image_index][ray_indices]
            batch_rays = utils.namedtuple_map(
                lambda r: r[image_index][ray_indices], self.rays)
        else:
            raise NotImplementedError(
                f"{self.batching} batching strategy is not implemented.")
        return {"pixels": batch_pixels, "rays": batch_rays}
示例#4
0
def construct_nerf(key, example_batch, args):
    """Construct a Neural Radiance Field and initialize its variables.

    The network, RGB and density activations are fixed (ReLU, sigmoid and
    ReLU respectively). Before the model is built, they are probed on inputs
    of extreme magnitude to confirm that colors stay in [0, 1] and densities
    stay non-negative.

    Args:
      key: jnp.ndarray. Random number generator key. It is split into three
        keys: one for ``model.init`` and two passed as ``rng_0``/``rng_1``.
      example_batch: dict, an example batch of data. Only ``"rays"`` is read,
        and only the first entry along the leading axis of each ray field is
        used for initialization.
      args: FLAGS class. Hyperparameters of nerf.

    Returns:
      model: NerfModel. The model constructed from ``args``.
      init_variables: the value returned by ``model.init`` (presumably the
        model's initial parameter/state collections).

    Raises:
      NotImplementedError: if the RGB activation produces values outside
        [0, 1] or the density activation produces negative values.
    """
    net_activation = nn.relu
    rgb_activation = nn.sigmoid
    sigma_activation = nn.relu

    def _activation_name(fn):
        # The activations are fixed above rather than read from ``args``.
        # Error messages therefore name the function that was actually checked.
        return getattr(fn, "__name__", repr(fn))

    # Assert that rgb_activation always produces outputs in [0, 1], and
    # sigma_activation always produces non-negative outputs. The probe covers
    # magnitudes from exp(-90) to exp(90) with both signs. With 32-bit floats
    # (JAX's default unless x64 is enabled), the largest magnitudes overflow
    # to +/-inf, so infinite inputs are exercised as well.
    x = jnp.exp(jnp.linspace(-90, 90, 1024))
    x = jnp.concatenate([-x[::-1], x], 0)

    rgb = rgb_activation(x)
    if jnp.any(rgb < 0) or jnp.any(rgb > 1):
        raise NotImplementedError(
            "Choice of rgb_activation `{}` produces colors outside of [0, 1]".
            format(_activation_name(rgb_activation)))

    sigma = sigma_activation(x)
    if jnp.any(sigma < 0):
        raise NotImplementedError(
            "Choice of sigma_activation `{}` produces negative densities".
            format(_activation_name(sigma_activation)))

    model = NerfModel(min_deg_point=args.min_deg_point,
                      max_deg_point=args.max_deg_point,
                      deg_view=args.deg_view,
                      num_coarse_samples=args.num_coarse_samples,
                      num_fine_samples=args.num_fine_samples,
                      use_viewdirs=args.use_viewdirs,
                      near=args.near,
                      far=args.far,
                      noise_std=args.noise_std,
                      white_bkgd=args.white_bkgd,
                      net_depth=args.net_depth,
                      net_width=args.net_width,
                      num_viewdir_channels=args.num_viewdir_channels,
                      viewdir_net_depth=args.viewdir_net_depth,
                      viewdir_net_width=args.viewdir_net_width,
                      skip_layer=args.skip_layer,
                      num_rgb_channels=args.num_rgb_channels,
                      num_sigma_channels=args.num_sigma_channels,
                      lindisp=args.lindisp,
                      net_activation=net_activation,
                      rgb_activation=rgb_activation,
                      sigma_activation=sigma_activation,
                      legacy_posenc_order=args.legacy_posenc_order)
    rays = example_batch["rays"]
    key1, key2, key3 = random.split(key, num=3)

    # Initialize using the first slice along the leading axis of every ray
    # field. NOTE(review): that axis is presumably a per-device axis added by
    # the caller; confirm.
    init_variables = model.init(key1,
                                rng_0=key2,
                                rng_1=key3,
                                rays=utils.namedtuple_map(
                                    lambda x: x[0], rays),
                                randomized=args.randomized)

    return model, init_variables