Пример #1
0
    def setup(self):
        """Builds the encoders, the NeRF MLPs and the optional warp field.

        Always created:
          * ``point_encoder`` / ``viewdir_encoder``: ``SinusoidalEncoder``s
            vmapped with ``num_batch_dims`` of 2 and 1 respectively.
          * ``nerf_coarse``: the coarse ``NerfMLP``.

        Created only when enabled:
          * ``appearance_encoder`` if ``use_appearance_metadata``.
          * ``camera_encoder`` if ``use_camera_metadata``.
          * ``nerf_fine`` if ``num_fine_samples > 0`` (``None`` otherwise).
          * ``warp_field`` if ``use_warp``.

        The optional encoders and ``warp_field`` are not assigned at all when
        disabled, so they should only be accessed behind the same flags.
        """
        # NOTE(review): 2 batch dims for points vs. 1 for view directions
        # presumably mirrors (rays, samples) vs. (rays,) inputs -- confirm
        # against the callers.
        self.point_encoder = model_utils.vmap_module(
            modules.SinusoidalEncoder,
            num_batch_dims=2)(num_freqs=self.num_nerf_point_freqs)
        self.viewdir_encoder = model_utils.vmap_module(
            modules.SinusoidalEncoder,
            num_batch_dims=1)(num_freqs=self.num_nerf_viewdir_freqs)
        if self.use_appearance_metadata:
            self.appearance_encoder = glo.GloEncoder(
                num_embeddings=self.num_appearance_embeddings,
                features=self.num_appearance_features,
            )
        if self.use_camera_metadata:
            self.camera_encoder = glo.GloEncoder(
                num_embeddings=self.num_camera_embeddings,
                features=self.num_camera_features,
            )

        def make_nerf_mlp():
            # The coarse and fine networks share one architecture; building
            # both through this factory keeps their hyperparameters in sync.
            return modules.NerfMLP(
                nerf_trunk_depth=self.nerf_trunk_depth,
                nerf_trunk_width=self.nerf_trunk_width,
                nerf_condition_depth=self.nerf_condition_depth,
                nerf_condition_width=self.nerf_condition_width,
                activation=self.activation,
                skips=self.nerf_skips,
                alpha_channels=self.alpha_channels,
                rgb_channels=self.rgb_channels,
            )

        self.nerf_coarse = make_nerf_mlp()
        # Each factory call builds a separate NerfMLP instance, so the fine
        # network is a distinct submodule from the coarse one.
        self.nerf_fine = make_nerf_mlp() if self.num_fine_samples > 0 else None

        if self.use_warp:
            # warp_kwargs are forwarded verbatim to the warp-field factory.
            self.warp_field = warping.create_warp_field(
                field_type=self.warp_field_type,
                num_freqs=self.num_warp_freqs,
                num_embeddings=self.num_warp_embeddings,
                num_features=self.num_warp_features,
                num_batch_dims=2,
                **self.warp_kwargs)
Пример #2
0
    def setup(self):
        """Creates the point/metadata encoders, the trunk and the output heads.

        ``metadata_encoder_type`` must be ``'glo'`` (``GloEncoder``) or
        ``'time'`` (``TimeEncoder``); anything else raises ``ValueError``.

        ``self.branches`` always holds the ``'w'`` and ``'v'`` heads and gains
        ``'p'`` / ``'t'`` heads when ``use_pivot`` / ``use_translation`` are
        set. Each head is an ``MLP`` with three output channels.
        """
        self.points_encoder = modules.AnnealedSinusoidalEncoder(
            num_freqs=self.num_freqs,
            min_freq_log2=self.min_freq_log2,
            max_freq_log2=self.max_freq_log2,
            use_identity=self.use_identity_map)

        if self.metadata_encoder_type == 'glo':
            self.metadata_encoder = glo.GloEncoder(
                num_embeddings=self.num_embeddings,
                features=self.num_embedding_features)
        elif self.metadata_encoder_type == 'time':
            self.metadata_encoder = modules.TimeEncoder(
                num_freqs=self.metadata_encoder_num_freqs,
                features=self.num_embedding_features)
        else:
            raise ValueError(
                f'Unknown metadata encoder type {self.metadata_encoder_type}')

        self.trunk = modules.MLP(
            depth=self.trunk_depth,
            width=self.trunk_width,
            hidden_activation=self.activation,
            hidden_init=self.default_init,
            skips=self.skips)

        def make_head(depth, width, output_init):
            # All heads share the activation and hidden init and emit 3 values.
            return modules.MLP(
                depth=depth,
                width=width,
                hidden_activation=self.activation,
                hidden_init=self.default_init,
                output_init=output_init,
                output_channels=3)

        heads = {
            'w': make_head(self.rotation_depth, self.rotation_width,
                           self.rotation_init),
            # NOTE(review): 'v' reuses the pivot_* hyperparameters even though
            # a separate 'p' pivot head exists below; confirm this is intended.
            'v': make_head(self.pivot_depth, self.pivot_width,
                           self.pivot_init),
        }
        if self.use_pivot:
            heads['p'] = make_head(self.pivot_depth, self.pivot_width,
                                   self.pivot_init)
        if self.use_translation:
            heads['t'] = make_head(self.translation_depth,
                                   self.translation_width,
                                   self.translation_init)

        # Build the dict locally and assign it to the module in one statement
        # rather than mutating a module attribute;
        # see https://github.com/google/flax/issues/524.
        self.branches = heads
Пример #3
0
  def setup(self):
    """Builds the optional warp field, the encoders and the NeRF MLP dict.

    ``self.nerf_mlps`` always has a ``'coarse'`` entry and gains a ``'fine'``
    entry only when ``num_fine_samples > 0``. ``warp_field``,
    ``appearance_encoder`` and ``camera_encoder`` are assigned only when
    ``use_warp``, ``use_appearance_metadata`` and ``use_camera_metadata`` are
    set, respectively.
    """
    if self.use_warp:
      # The model itself is passed as the first argument; presumably
      # create_warp_field is a staticmethod/factory attribute -- confirm.
      self.warp_field = self.create_warp_field(self, num_batch_dims=2)

    self.point_encoder = model_utils.vmap_module(
        modules.SinusoidalEncoder, num_batch_dims=2)(
            num_freqs=self.num_nerf_point_freqs)
    self.viewdir_encoder = model_utils.vmap_module(
        modules.SinusoidalEncoder, num_batch_dims=1)(
            num_freqs=self.num_nerf_viewdir_freqs)
    if self.use_appearance_metadata:
      self.appearance_encoder = glo.GloEncoder(
          num_embeddings=self.num_appearance_embeddings,
          features=self.num_appearance_features)
    if self.use_camera_metadata:
      self.camera_encoder = glo.GloEncoder(
          num_embeddings=self.num_camera_embeddings,
          features=self.num_camera_features)

    def make_nerf_mlp():
      # Coarse and fine networks share one architecture; a single factory
      # keeps their hyperparameters from drifting apart.
      return modules.NerfMLP(
          trunk_depth=self.nerf_trunk_depth,
          trunk_width=self.nerf_trunk_width,
          rgb_branch_depth=self.nerf_rgb_branch_depth,
          rgb_branch_width=self.nerf_rgb_branch_width,
          activation=self.activation,
          skips=self.nerf_skips,
          alpha_channels=self.alpha_channels,
          rgb_channels=self.rgb_channels)

    nerf_mlps = {'coarse': make_nerf_mlp()}
    if self.num_fine_samples > 0:
      nerf_mlps['fine'] = make_nerf_mlp()
    # Build the dict locally and assign it once instead of mutating a module
    # attribute; see https://github.com/google/flax/issues/524.
    self.nerf_mlps = nerf_mlps
Пример #4
0
    def setup(self):
        """Builds the point encoder, the metadata encoder(s) and the MLP.

        ``metadata_encoder_type`` selects the metadata encoding:
          * ``'glo'``: ``self.metadata_encoder`` is a ``glo.GloEncoder``.
          * ``'time'``: ``self.metadata_encoder`` is a ``modules.TimeEncoder``.
          * ``'blend'``: both ``self.glo_encoder`` and ``self.time_encoder``
            are created; ``self.metadata_encoder`` is not assigned.

        Raises:
          ValueError: if ``metadata_encoder_type`` is not one of the above.
        """
        self.points_encoder = modules.AnnealedSinusoidalEncoder(
            num_freqs=self.num_freqs,
            min_freq_log2=self.min_freq_log2,
            max_freq_log2=self.max_freq_log2,
            use_identity=self.use_identity_map)

        if self.metadata_encoder_type == 'glo':
            self.metadata_encoder = glo.GloEncoder(
                num_embeddings=self.num_embeddings,
                features=self.num_embedding_features)
        elif self.metadata_encoder_type == 'time':
            self.metadata_encoder = modules.TimeEncoder(
                num_freqs=self.metadata_encoder_num_freqs,
                features=self.num_embedding_features)
        elif self.metadata_encoder_type == 'blend':
            # Both encoders share the same feature width.
            self.glo_encoder = glo.GloEncoder(
                num_embeddings=self.num_embeddings,
                features=self.num_embedding_features)
            self.time_encoder = modules.TimeEncoder(
                num_freqs=self.metadata_encoder_num_freqs,
                features=self.num_embedding_features)
        else:
            raise ValueError(
                f'Unknown metadata encoder type {self.metadata_encoder_type}')

        # The MLP emits a 3-channel output.
        output_dims = 3
        self.mlp = modules.MLP(width=self.hidden_channels,
                               depth=self.depth,
                               skips=self.skips,
                               hidden_init=self.hidden_init,
                               output_init=self.output_init,
                               output_channels=output_dims)
Пример #5
0
    def setup(self):
        """Creates the encoders, the shared trunk and the output heads.

        ``self.branches`` always contains a ``"rotation"`` head; ``"pivot"``
        and ``"translation"`` heads are added when ``use_pivot`` /
        ``use_translation`` are set. Every head is an ``MLP`` with three
        output channels.
        """
        self.points_encoder = modules.AnnealedSinusoidalEncoder(
            num_freqs=self.num_freqs, max_freq_log2=self.max_freq_log2)
        self.metadata_encoder = glo.GloEncoder(
            num_embeddings=self.num_embeddings,
            features=self.num_embedding_features)
        self.trunk = MLP(
            depth=self.trunk_depth,
            width=self.trunk_width,
            hidden_activation=self.activation,
            hidden_init=self.default_init,
            skips=self.skips,
        )

        def make_head(depth, width, output_init):
            # All heads share activation and hidden init and emit 3 channels.
            return MLP(
                depth=depth,
                width=width,
                hidden_activation=self.activation,
                hidden_init=self.default_init,
                output_init=output_init,
                output_channels=3,
            )

        heads = {
            "rotation": make_head(
                self.rotation_depth, self.rotation_width, self.rotation_init),
        }
        if self.use_pivot:
            heads["pivot"] = make_head(
                self.pivot_depth, self.pivot_width, self.pivot_init)
        if self.use_translation:
            heads["translation"] = make_head(
                self.translation_depth,
                self.translation_width,
                self.translation_init,
            )

        # Assemble the dict locally and assign it to the module exactly once;
        # see https://github.com/google/flax/issues/524.
        self.branches = heads
Пример #6
0
 def setup(self):
   """Builds the point encoder, the GLO metadata encoder and the MLP.

   The MLP is configured with ``output_channels=3``.
   """
   self.points_encoder = modules.AnnealedSinusoidalEncoder(
       num_freqs=self.num_freqs, max_freq_log2=self.max_freq_log2)
   self.metadata_encoder = glo.GloEncoder(
       num_embeddings=self.num_embeddings,
       features=self.num_embedding_features)
   # The MLP emits a 3-channel output.
   output_dims = 3
   self.mlp = MLP(
       width=self.hidden_channels,
       depth=self.depth,
       skips=self.skips,
       hidden_init=self.hidden_init,
       output_init=self.output_init,
       output_channels=output_dims)