Exemplo n.º 1
0
    def setup(self):
        """Creates this module's submodules.

        Assigns, in this order:

        * ``points_encoder``: a ``modules.AnnealedSinusoidalEncoder`` built
          from ``num_freqs``, ``min_freq_log2``, ``max_freq_log2`` and
          ``use_identity_map``.
        * ``metadata_encoder``: a ``glo.GloEncoder`` when
          ``metadata_encoder_type`` is ``'glo'``, or a ``modules.TimeEncoder``
          when it is ``'time'``.
        * ``trunk``: a ``modules.MLP`` of depth ``trunk_depth`` and width
          ``trunk_width`` (with ``skips=self.skips``).
        * ``branches``: a dict of 3-output-channel ``modules.MLP`` heads keyed
          ``'w'`` and ``'v'``, plus ``'p'`` when ``use_pivot`` is set and
          ``'t'`` when ``use_translation`` is set.

        Raises:
            ValueError: if ``metadata_encoder_type`` is neither ``'glo'`` nor
                ``'time'``.
        """
        self.points_encoder = modules.AnnealedSinusoidalEncoder(
            num_freqs=self.num_freqs,
            min_freq_log2=self.min_freq_log2,
            max_freq_log2=self.max_freq_log2,
            use_identity=self.use_identity_map,
        )

        encoder_type = self.metadata_encoder_type
        if encoder_type == 'glo':
            self.metadata_encoder = glo.GloEncoder(
                num_embeddings=self.num_embeddings,
                features=self.num_embedding_features,
            )
        elif encoder_type == 'time':
            self.metadata_encoder = modules.TimeEncoder(
                num_freqs=self.metadata_encoder_num_freqs,
                features=self.num_embedding_features,
            )
        else:
            raise ValueError(
                f'Unknown metadata encoder type {self.metadata_encoder_type}')

        activation = self.activation
        hidden_init = self.default_init

        self.trunk = modules.MLP(
            depth=self.trunk_depth,
            width=self.trunk_width,
            hidden_activation=activation,
            hidden_init=hidden_init,
            skips=self.skips,
        )

        def make_head(depth, width, output_init):
            # Every head shares the trunk's hidden activation and hidden
            # initializer and emits 3 channels; only the depth, the width and
            # the output initializer vary between heads.
            return modules.MLP(
                depth=depth,
                width=width,
                hidden_activation=activation,
                hidden_init=hidden_init,
                output_init=output_init,
                output_channels=3,
            )

        # NOTE(review): 'v' is built from the pivot_* settings rather than
        # dedicated ones. Confirm this sharing is intentional.
        branches = {
            'w': make_head(self.rotation_depth, self.rotation_width,
                           self.rotation_init),
            'v': make_head(self.pivot_depth, self.pivot_width,
                           self.pivot_init),
        }
        if self.use_pivot:
            branches['p'] = make_head(self.pivot_depth, self.pivot_width,
                                      self.pivot_init)
        if self.use_translation:
            branches['t'] = make_head(self.translation_depth,
                                      self.translation_width,
                                      self.translation_init)

        # Assemble the dict locally and assign it to ``self`` in one step,
        # rather than mutating ``self.branches`` in place.
        # See https://github.com/google/flax/issues/524.
        self.branches = branches
Exemplo n.º 2
0
    def setup(self):
        """Creates this module's submodules.

        Assigns ``points_encoder`` (a ``modules.AnnealedSinusoidalEncoder``),
        ``metadata_encoder`` (a ``glo.GloEncoder``), ``trunk`` (an ``MLP``)
        and ``branches``. ``branches`` is a dict of 3-output-channel ``MLP``
        heads. It always contains ``"rotation"``, adds ``"pivot"`` when
        ``use_pivot`` is set, and adds ``"translation"`` when
        ``use_translation`` is set.
        """
        self.points_encoder = modules.AnnealedSinusoidalEncoder(
            num_freqs=self.num_freqs,
            max_freq_log2=self.max_freq_log2,
        )
        self.metadata_encoder = glo.GloEncoder(
            num_embeddings=self.num_embeddings,
            features=self.num_embedding_features,
        )
        self.trunk = MLP(
            depth=self.trunk_depth,
            width=self.trunk_width,
            hidden_activation=self.activation,
            hidden_init=self.default_init,
            skips=self.skips,
        )

        # One (key, depth, width, output_init) tuple per head, listed in the
        # order the heads are created.
        head_specs = [
            ("rotation", self.rotation_depth, self.rotation_width,
             self.rotation_init),
        ]
        if self.use_pivot:
            head_specs.append(
                ("pivot", self.pivot_depth, self.pivot_width, self.pivot_init))
        if self.use_translation:
            head_specs.append(("translation", self.translation_depth,
                               self.translation_width, self.translation_init))

        # Build the complete dict first and assign it to ``self`` once,
        # rather than mutating ``self.branches`` in place.
        # See https://github.com/google/flax/issues/524.
        self.branches = {
            key: MLP(
                depth=depth,
                width=width,
                hidden_activation=self.activation,
                hidden_init=self.default_init,
                output_init=output_init,
                output_channels=3,
            )
            for key, depth, width, output_init in head_specs
        }
Exemplo n.º 3
0
 def setup(self):
   """Creates this module's submodules.

   Assigns ``points_encoder`` (a ``modules.AnnealedSinusoidalEncoder``
   configured by ``num_freqs`` and ``max_freq_log2``), ``metadata_encoder``
   (a ``glo.GloEncoder`` with ``num_embeddings`` entries of size
   ``num_embedding_features``) and ``mlp`` (an ``MLP`` with 3 output
   channels).
   """
   self.points_encoder = modules.AnnealedSinusoidalEncoder(
       num_freqs=self.num_freqs,
       max_freq_log2=self.max_freq_log2,
   )
   self.metadata_encoder = glo.GloEncoder(
       num_embeddings=self.num_embeddings,
       features=self.num_embedding_features,
   )
   # The output width is fixed at 3 channels.
   self.mlp = MLP(
       depth=self.depth,
       width=self.hidden_channels,
       skips=self.skips,
       hidden_init=self.hidden_init,
       output_init=self.output_init,
       output_channels=3,
   )
Exemplo n.º 4
0
    def setup(self):
        """Creates this module's submodules.

        Assigns ``points_encoder`` (a ``modules.AnnealedSinusoidalEncoder``),
        then the metadata encoder(s) selected by ``metadata_encoder_type``,
        then ``mlp`` (a ``modules.MLP`` with 3 output channels).

        ``metadata_encoder_type`` selects:

        * ``'glo'``: ``metadata_encoder`` is a ``glo.GloEncoder``.
        * ``'time'``: ``metadata_encoder`` is a ``modules.TimeEncoder``.
        * ``'blend'``: both encoders are created, as ``glo_encoder`` and
          ``time_encoder``. ``metadata_encoder`` is not assigned in this mode.

        Raises:
            ValueError: for any other ``metadata_encoder_type``.
        """
        self.points_encoder = modules.AnnealedSinusoidalEncoder(
            num_freqs=self.num_freqs,
            min_freq_log2=self.min_freq_log2,
            max_freq_log2=self.max_freq_log2,
            use_identity=self.use_identity_map,
        )

        def make_glo_encoder():
            return glo.GloEncoder(
                num_embeddings=self.num_embeddings,
                features=self.num_embedding_features,
            )

        def make_time_encoder():
            return modules.TimeEncoder(
                num_freqs=self.metadata_encoder_num_freqs,
                features=self.num_embedding_features,
            )

        encoder_type = self.metadata_encoder_type
        if encoder_type == 'glo':
            self.metadata_encoder = make_glo_encoder()
        elif encoder_type == 'time':
            self.metadata_encoder = make_time_encoder()
        elif encoder_type == 'blend':
            # Both encoders are kept as separate submodules.
            self.glo_encoder = make_glo_encoder()
            self.time_encoder = make_time_encoder()
        else:
            raise ValueError(
                f'Unknown metadata encoder type {self.metadata_encoder_type}')

        # The output width is fixed at 3 channels.
        self.mlp = modules.MLP(
            depth=self.depth,
            width=self.hidden_channels,
            skips=self.skips,
            hidden_init=self.hidden_init,
            output_init=self.output_init,
            output_channels=3,
        )