Exemple #1
0
    def setup(self):
        """Instantiate the encoders, the NeRF MLPs and the optional warp field.

        Attributes assigned on ``self``:
          * ``point_encoder`` / ``viewdir_encoder``: ``SinusoidalEncoder``
            wrapped by ``model_utils.vmap_module`` with 2 and 1 batch
            dimensions respectively.
          * ``appearance_encoder`` / ``camera_encoder``: ``GloEncoder``
            instances, created only when the matching ``use_*_metadata``
            flag is truthy.
          * ``nerf_coarse``: always created.
          * ``nerf_fine``: created with the same settings as the coarse MLP
            when ``num_fine_samples > 0``; otherwise set to ``None``.
          * ``warp_field``: created only when ``use_warp`` is truthy.
        """
        # Positional encoders for sample points and view directions.
        batched_point_encoder = model_utils.vmap_module(
            modules.SinusoidalEncoder, num_batch_dims=2)
        self.point_encoder = batched_point_encoder(
            num_freqs=self.num_nerf_point_freqs)
        batched_viewdir_encoder = model_utils.vmap_module(
            modules.SinusoidalEncoder, num_batch_dims=1)
        self.viewdir_encoder = batched_viewdir_encoder(
            num_freqs=self.num_nerf_viewdir_freqs)

        # Optional per-item embedding tables.
        if self.use_appearance_metadata:
            self.appearance_encoder = glo.GloEncoder(
                features=self.num_appearance_features,
                num_embeddings=self.num_appearance_embeddings)
        if self.use_camera_metadata:
            self.camera_encoder = glo.GloEncoder(
                features=self.num_camera_features,
                num_embeddings=self.num_camera_embeddings)

        # The coarse and fine networks take exactly the same configuration,
        # so it is collected once and reused for both.
        mlp_config = dict(
            nerf_trunk_depth=self.nerf_trunk_depth,
            nerf_trunk_width=self.nerf_trunk_width,
            nerf_condition_depth=self.nerf_condition_depth,
            nerf_condition_width=self.nerf_condition_width,
            activation=self.activation,
            skips=self.nerf_skips,
            alpha_channels=self.alpha_channels,
            rgb_channels=self.rgb_channels,
        )
        self.nerf_coarse = modules.NerfMLP(**mlp_config)
        # No fine network is built when fine sampling is disabled.
        self.nerf_fine = (
            modules.NerfMLP(**mlp_config)
            if self.num_fine_samples > 0 else None)

        if self.use_warp:
            self.warp_field = warping.create_warp_field(
                field_type=self.warp_field_type,
                num_freqs=self.num_warp_freqs,
                num_embeddings=self.num_warp_embeddings,
                num_features=self.num_warp_features,
                num_batch_dims=2,
                **self.warp_kwargs)
Exemple #2
0
def create_warp_field(
    field_type: str,
    num_freqs: int,
    num_embeddings: int,
    num_features: int,
    num_batch_dims: int,
    **kwargs):
  """Builds a batched warp field module of the requested kind.

  Args:
    field_type: 'translation' selects `TranslationField`; 'se3' selects
      `SE3Field`.
    num_freqs: forwarded to the warp field constructor as `num_freqs`.
    num_embeddings: forwarded to the constructor as `num_embeddings`.
    num_features: forwarded to the constructor as `num_embedding_features`.
    num_batch_dims: passed to `model_utils.vmap_module` as `num_batch_dims`.
    **kwargs: additional keyword arguments forwarded to the constructor.

  Returns:
    An instance of the class returned by `model_utils.vmap_module` for the
    selected warp field type.

  Raises:
    ValueError: if `field_type` is neither 'translation' nor 'se3'.
  """
  # Reject unsupported types up front so nothing else is evaluated for them.
  if field_type not in ('translation', 'se3'):
    raise ValueError(f'Unknown warp field type: {field_type!r}')
  field_cls = TranslationField if field_type == 'translation' else SE3Field

  # Positional inputs of the wrapped module, in order:
  # (points, metadata, alpha, return_jacobian, metadata_encoded).
  # NOTE(review): presumably in_axes follows the jax.vmap convention, i.e.
  # only points and metadata are mapped over - confirm in model_utils.
  batched_field_cls = model_utils.vmap_module(
      field_cls,
      in_axes=(0, 0, None, None, None),
      num_batch_dims=num_batch_dims)

  return batched_field_cls(
      num_freqs=num_freqs,
      num_embeddings=num_embeddings,
      num_embedding_features=num_features,
      **kwargs)
Exemple #3
0
  def setup(self):
    """Instantiates the warp field, the encoders and the NeRF MLPs.

    Attributes assigned on `self`:
      * `warp_field`: created only when `use_warp` is truthy.
      * `point_encoder` / `viewdir_encoder`: `SinusoidalEncoder` wrapped by
        `model_utils.vmap_module` with 2 and 1 batch dimensions respectively.
      * `appearance_encoder` / `camera_encoder`: `GloEncoder` instances,
        created only when the matching `use_*_metadata` flag is truthy.
      * `nerf_mlps`: dict with a 'coarse' `NerfMLP`, plus a 'fine' one built
        from the same settings when `num_fine_samples > 0`.
    """
    if self.use_warp:
      # `self` is passed explicitly on top of the attribute access.
      # NOTE(review): presumably create_warp_field is a staticmethod that
      # takes the model as its first argument - confirm in the class body.
      self.warp_field = self.create_warp_field(self, num_batch_dims=2)

    # Positional encoders for sample points and view directions.
    batched_point_encoder = model_utils.vmap_module(
        modules.SinusoidalEncoder, num_batch_dims=2)
    self.point_encoder = batched_point_encoder(
        num_freqs=self.num_nerf_point_freqs)
    batched_viewdir_encoder = model_utils.vmap_module(
        modules.SinusoidalEncoder, num_batch_dims=1)
    self.viewdir_encoder = batched_viewdir_encoder(
        num_freqs=self.num_nerf_viewdir_freqs)

    # Optional per-item embedding tables.
    if self.use_appearance_metadata:
      self.appearance_encoder = glo.GloEncoder(
          features=self.num_appearance_features,
          num_embeddings=self.num_appearance_embeddings)
    if self.use_camera_metadata:
      self.camera_encoder = glo.GloEncoder(
          features=self.num_camera_features,
          num_embeddings=self.num_camera_embeddings)

    # Both sampling levels use the same MLP configuration.
    mlp_config = dict(
        trunk_depth=self.nerf_trunk_depth,
        trunk_width=self.nerf_trunk_width,
        rgb_branch_depth=self.nerf_rgb_branch_depth,
        rgb_branch_width=self.nerf_rgb_branch_width,
        activation=self.activation,
        skips=self.nerf_skips,
        alpha_channels=self.alpha_channels,
        rgb_channels=self.rgb_channels)

    # 'coarse' is always present; 'fine' only when fine samples are requested.
    levels = ['coarse']
    if self.num_fine_samples > 0:
      levels.append('fine')
    # The dict is fully built before it is attached to self.
    self.nerf_mlps = {
        level: modules.NerfMLP(**mlp_config) for level in levels
    }