def setup(self):
    """Creates the submodules used by this field.

    Submodules built here:
      * ``points_encoder``: an annealed sinusoidal encoder for input points.
      * ``metadata_encoder``: a GLO embedding (type ``'glo'``) or a time
        encoder (type ``'time'``), chosen by ``self.metadata_encoder_type``.
      * ``trunk``: a shared MLP configured with ``self.skips``.
      * ``branches``: a dict of 3-channel head MLPs -- ``'w'`` and ``'v'``
        always, ``'p'`` when ``self.use_pivot`` and ``'t'`` when
        ``self.use_translation``.

    Raises:
      ValueError: if ``self.metadata_encoder_type`` is neither ``'glo'`` nor
        ``'time'``.
    """
    self.points_encoder = modules.AnnealedSinusoidalEncoder(
        num_freqs=self.num_freqs,
        min_freq_log2=self.min_freq_log2,
        max_freq_log2=self.max_freq_log2,
        use_identity=self.use_identity_map)

    encoder_type = self.metadata_encoder_type
    if encoder_type == 'glo':
        self.metadata_encoder = glo.GloEncoder(
            num_embeddings=self.num_embeddings,
            features=self.num_embedding_features)
    elif encoder_type == 'time':
        self.metadata_encoder = modules.TimeEncoder(
            num_freqs=self.metadata_encoder_num_freqs,
            features=self.num_embedding_features)
    else:
        raise ValueError(
            f'Unknown metadata encoder type {self.metadata_encoder_type}')

    self.trunk = modules.MLP(
        depth=self.trunk_depth,
        width=self.trunk_width,
        hidden_activation=self.activation,
        hidden_init=self.default_init,
        skips=self.skips)

    def make_head(depth, width, output_init):
        # Every head uses the module's activation and default hidden init and
        # produces a 3-channel output; only size and output init vary.
        return modules.MLP(
            depth=depth,
            width=width,
            hidden_activation=self.activation,
            hidden_init=self.default_init,
            output_init=output_init,
            output_channels=3)

    # NOTE(review): 'v' is sized/initialized from the pivot_* settings rather
    # than dedicated ones -- confirm this sharing is intended.
    branches = {
        'w': make_head(self.rotation_depth, self.rotation_width,
                       self.rotation_init),
        'v': make_head(self.pivot_depth, self.pivot_width, self.pivot_init),
    }
    if self.use_pivot:
        branches['p'] = make_head(
            self.pivot_depth, self.pivot_width, self.pivot_init)
    if self.use_translation:
        branches['t'] = make_head(
            self.translation_depth, self.translation_width,
            self.translation_init)

    # The dict is assembled locally and assigned to the module exactly once,
    # instead of mutating a module attribute in place.
    # See https://github.com/google/flax/issues/524.
    self.branches = branches
def setup(self):
    """Creates the encoders, the trunk MLP and the output-head MLPs.

    ``self.branches`` always contains a ``"rotation"`` head. A ``"pivot"``
    head is added when ``self.use_pivot`` is set, and a ``"translation"``
    head when ``self.use_translation`` is set. Each head emits 3 channels and
    is sized by the ``<name>_depth`` / ``<name>_width`` / ``<name>_init``
    attributes that match its key.
    """
    self.points_encoder = modules.AnnealedSinusoidalEncoder(
        num_freqs=self.num_freqs,
        max_freq_log2=self.max_freq_log2,
    )
    self.metadata_encoder = glo.GloEncoder(
        num_embeddings=self.num_embeddings,
        features=self.num_embedding_features,
    )
    self.trunk = MLP(
        depth=self.trunk_depth,
        width=self.trunk_width,
        hidden_activation=self.activation,
        hidden_init=self.default_init,
        skips=self.skips,
    )

    def build_head(name):
        # A head's hyperparameters are the attributes prefixed by its key,
        # e.g. "pivot" -> pivot_depth, pivot_width, pivot_init.
        return MLP(
            depth=getattr(self, f"{name}_depth"),
            width=getattr(self, f"{name}_width"),
            hidden_activation=self.activation,
            hidden_init=self.default_init,
            output_init=getattr(self, f"{name}_init"),
            output_channels=3,
        )

    # Optional heads are listed only when enabled, so their attributes are
    # read only in that case.
    head_names = ["rotation"]
    if self.use_pivot:
        head_names.append("pivot")
    if self.use_translation:
        head_names.append("translation")

    # The dict is assembled locally and assigned to the module exactly once,
    # instead of mutating a module attribute in place.
    # See https://github.com/google/flax/issues/524.
    self.branches = {name: build_head(name) for name in head_names}
def setup(self):
    """Creates the point encoder, the GLO metadata encoder and the MLP.

    The MLP is configured with ``self.hidden_channels`` width,
    ``self.depth`` layers and ``self.skips``, and produces 3 output channels
    (presumably a 3-D vector per point -- confirm against the caller).
    """
    self.points_encoder = modules.AnnealedSinusoidalEncoder(
        num_freqs=self.num_freqs,
        max_freq_log2=self.max_freq_log2,
    )
    self.metadata_encoder = glo.GloEncoder(
        num_embeddings=self.num_embeddings,
        features=self.num_embedding_features,
    )
    self.mlp = MLP(
        width=self.hidden_channels,
        depth=self.depth,
        skips=self.skips,
        hidden_init=self.hidden_init,
        output_init=self.output_init,
        output_channels=3,
    )
def setup(self):
    """Creates the point encoder, the metadata encoder(s) and the MLP.

    ``self.metadata_encoder_type`` selects what is built:
      * ``'glo'``: ``self.metadata_encoder`` is a GLO embedding.
      * ``'time'``: ``self.metadata_encoder`` is a time encoder.
      * ``'blend'``: both are built, stored as ``self.glo_encoder`` and
        ``self.time_encoder`` (no ``self.metadata_encoder`` is set).

    The MLP produces 3 output channels.

    Raises:
      ValueError: if ``self.metadata_encoder_type`` is not one of ``'glo'``,
        ``'time'`` or ``'blend'``.
    """
    self.points_encoder = modules.AnnealedSinusoidalEncoder(
        num_freqs=self.num_freqs,
        min_freq_log2=self.min_freq_log2,
        max_freq_log2=self.max_freq_log2,
        use_identity=self.use_identity_map)

    def new_glo_encoder():
        return glo.GloEncoder(
            num_embeddings=self.num_embeddings,
            features=self.num_embedding_features)

    def new_time_encoder():
        return modules.TimeEncoder(
            num_freqs=self.metadata_encoder_num_freqs,
            features=self.num_embedding_features)

    kind = self.metadata_encoder_type
    if kind == 'glo':
        self.metadata_encoder = new_glo_encoder()
    elif kind == 'time':
        self.metadata_encoder = new_time_encoder()
    elif kind == 'blend':
        # Both encoders are kept as separate submodules.
        self.glo_encoder = new_glo_encoder()
        self.time_encoder = new_time_encoder()
    else:
        raise ValueError(
            f'Unknown metadata encoder type {self.metadata_encoder_type}')

    self.mlp = modules.MLP(
        width=self.hidden_channels,
        depth=self.depth,
        skips=self.skips,
        hidden_init=self.hidden_init,
        output_init=self.output_init,
        output_channels=3)