def setup(self):
    """Create the encoders, the NeRF MLPs and, if enabled, the warp field.

    Attributes assigned here:
      * ``point_encoder`` / ``viewdir_encoder``: ``SinusoidalEncoder``
        instances wrapped by ``model_utils.vmap_module`` with 2 and 1 batch
        dimensions respectively.
      * ``appearance_encoder`` / ``camera_encoder``: ``GloEncoder`` instances,
        created only when the matching ``use_*_metadata`` flag is set.
      * ``nerf_coarse``: always created.
      * ``nerf_fine``: a second ``NerfMLP`` with the same configuration when
        ``num_fine_samples > 0``, otherwise ``None``.
      * ``warp_field``: created only when ``use_warp`` is set.
    """
    point_encoder_cls = model_utils.vmap_module(
        modules.SinusoidalEncoder, num_batch_dims=2)
    self.point_encoder = point_encoder_cls(
        num_freqs=self.num_nerf_point_freqs)

    viewdir_encoder_cls = model_utils.vmap_module(
        modules.SinusoidalEncoder, num_batch_dims=1)
    self.viewdir_encoder = viewdir_encoder_cls(
        num_freqs=self.num_nerf_viewdir_freqs)

    if self.use_appearance_metadata:
        self.appearance_encoder = glo.GloEncoder(
            num_embeddings=self.num_appearance_embeddings,
            features=self.num_appearance_features,
        )
    if self.use_camera_metadata:
        self.camera_encoder = glo.GloEncoder(
            num_embeddings=self.num_camera_embeddings,
            features=self.num_camera_features,
        )

    # The coarse and fine networks are configured identically, so the
    # keyword arguments are collected once and reused.
    nerf_mlp_kwargs = dict(
        nerf_trunk_depth=self.nerf_trunk_depth,
        nerf_trunk_width=self.nerf_trunk_width,
        nerf_condition_depth=self.nerf_condition_depth,
        nerf_condition_width=self.nerf_condition_width,
        activation=self.activation,
        skips=self.nerf_skips,
        alpha_channels=self.alpha_channels,
        rgb_channels=self.rgb_channels,
    )
    self.nerf_coarse = modules.NerfMLP(**nerf_mlp_kwargs)
    self.nerf_fine = (
        modules.NerfMLP(**nerf_mlp_kwargs)
        if self.num_fine_samples > 0
        else None
    )

    if self.use_warp:
        self.warp_field = warping.create_warp_field(
            field_type=self.warp_field_type,
            num_freqs=self.num_warp_freqs,
            num_embeddings=self.num_warp_embeddings,
            num_features=self.num_warp_features,
            num_batch_dims=2,
            **self.warp_kwargs)
def setup(self):
    """Create the point encoder, metadata encoder, trunk MLP and branch MLPs.

    The metadata encoder is chosen by ``metadata_encoder_type``: ``'glo'``
    gives a ``GloEncoder`` and ``'time'`` gives a ``TimeEncoder``.

    The ``'w'`` and ``'v'`` branches are always created; ``'p'`` and ``'t'``
    are added only when ``use_pivot`` / ``use_translation`` are set. Every
    branch has 3 output channels.

    Raises:
      ValueError: if ``metadata_encoder_type`` is neither ``'glo'`` nor
        ``'time'``.
    """
    self.points_encoder = modules.AnnealedSinusoidalEncoder(
        num_freqs=self.num_freqs,
        min_freq_log2=self.min_freq_log2,
        max_freq_log2=self.max_freq_log2,
        use_identity=self.use_identity_map)

    if self.metadata_encoder_type == 'glo':
        self.metadata_encoder = glo.GloEncoder(
            num_embeddings=self.num_embeddings,
            features=self.num_embedding_features)
    elif self.metadata_encoder_type == 'time':
        self.metadata_encoder = modules.TimeEncoder(
            num_freqs=self.metadata_encoder_num_freqs,
            features=self.num_embedding_features)
    else:
        raise ValueError(
            f'Unknown metadata encoder type {self.metadata_encoder_type}')

    self.trunk = modules.MLP(
        depth=self.trunk_depth,
        width=self.trunk_width,
        hidden_activation=self.activation,
        hidden_init=self.default_init,
        skips=self.skips)

    # Settings shared by every output branch.
    branch_kwargs = dict(
        hidden_activation=self.activation,
        hidden_init=self.default_init,
        output_channels=3)

    heads = {}
    heads['w'] = modules.MLP(
        depth=self.rotation_depth,
        width=self.rotation_width,
        output_init=self.rotation_init,
        **branch_kwargs)
    # NOTE(review): the 'v' branch is configured from the pivot_* settings,
    # exactly like 'p' below -- confirm this is intended.
    heads['v'] = modules.MLP(
        depth=self.pivot_depth,
        width=self.pivot_width,
        output_init=self.pivot_init,
        **branch_kwargs)
    if self.use_pivot:
        heads['p'] = modules.MLP(
            depth=self.pivot_depth,
            width=self.pivot_width,
            output_init=self.pivot_init,
            **branch_kwargs)
    if self.use_translation:
        heads['t'] = modules.MLP(
            depth=self.translation_depth,
            width=self.translation_width,
            output_init=self.translation_init,
            **branch_kwargs)

    # The dict is filled locally and bound to the attribute in one
    # assignment rather than mutated after being attached to the module.
    # See https://github.com/google/flax/issues/524.
    self.branches = heads
def setup(self):
    """Create the optional warp field, the encoders and the NeRF MLPs.

    ``nerf_mlps`` is a dict that always contains a ``'coarse'`` entry and
    gains a ``'fine'`` entry, configured identically, when
    ``num_fine_samples > 0``.
    """
    if self.use_warp:
        # `create_warp_field` is invoked with this module passed explicitly
        # as its first positional argument.
        self.warp_field = self.create_warp_field(self, num_batch_dims=2)

    point_encoder_cls = model_utils.vmap_module(
        modules.SinusoidalEncoder, num_batch_dims=2)
    self.point_encoder = point_encoder_cls(
        num_freqs=self.num_nerf_point_freqs)

    viewdir_encoder_cls = model_utils.vmap_module(
        modules.SinusoidalEncoder, num_batch_dims=1)
    self.viewdir_encoder = viewdir_encoder_cls(
        num_freqs=self.num_nerf_viewdir_freqs)

    if self.use_appearance_metadata:
        self.appearance_encoder = glo.GloEncoder(
            num_embeddings=self.num_appearance_embeddings,
            features=self.num_appearance_features)
    if self.use_camera_metadata:
        self.camera_encoder = glo.GloEncoder(
            num_embeddings=self.num_camera_embeddings,
            features=self.num_camera_features)

    # Both levels share one configuration.
    mlp_kwargs = dict(
        trunk_depth=self.nerf_trunk_depth,
        trunk_width=self.nerf_trunk_width,
        rgb_branch_depth=self.nerf_rgb_branch_depth,
        rgb_branch_width=self.nerf_rgb_branch_width,
        activation=self.activation,
        skips=self.nerf_skips,
        alpha_channels=self.alpha_channels,
        rgb_channels=self.rgb_channels)

    # Fill a local dict first and attach it to the module in one assignment.
    levels = {'coarse': modules.NerfMLP(**mlp_kwargs)}
    if self.num_fine_samples > 0:
        levels['fine'] = modules.NerfMLP(**mlp_kwargs)
    self.nerf_mlps = levels
def setup(self):
    """Create the point encoder, the metadata encoder(s) and the output MLP.

    ``metadata_encoder_type`` selects what is created:
      * ``'glo'``: ``metadata_encoder`` is a ``GloEncoder``.
      * ``'time'``: ``metadata_encoder`` is a ``TimeEncoder``.
      * ``'blend'``: both ``glo_encoder`` and ``time_encoder`` are created;
        ``metadata_encoder`` is not assigned in this mode.

    Raises:
      ValueError: for any other ``metadata_encoder_type``.
    """
    self.points_encoder = modules.AnnealedSinusoidalEncoder(
        num_freqs=self.num_freqs,
        min_freq_log2=self.min_freq_log2,
        max_freq_log2=self.max_freq_log2,
        use_identity=self.use_identity_map)

    if self.metadata_encoder_type == 'glo':
        self.metadata_encoder = glo.GloEncoder(
            num_embeddings=self.num_embeddings,
            features=self.num_embedding_features)
    elif self.metadata_encoder_type == 'time':
        self.metadata_encoder = modules.TimeEncoder(
            num_freqs=self.metadata_encoder_num_freqs,
            features=self.num_embedding_features)
    elif self.metadata_encoder_type == 'blend':
        self.glo_encoder = glo.GloEncoder(
            num_embeddings=self.num_embeddings,
            features=self.num_embedding_features)
        self.time_encoder = modules.TimeEncoder(
            num_freqs=self.metadata_encoder_num_freqs,
            features=self.num_embedding_features)
    else:
        raise ValueError(
            f'Unknown metadata encoder type {self.metadata_encoder_type}')

    # The MLP has three output channels.
    self.mlp = modules.MLP(
        width=self.hidden_channels,
        depth=self.depth,
        skips=self.skips,
        hidden_init=self.hidden_init,
        output_init=self.output_init,
        output_channels=3)
def setup(self):
    """Create the point encoder, GLO metadata encoder, trunk and branches.

    The ``"rotation"`` branch is always created; ``"pivot"`` and
    ``"translation"`` are added only when ``use_pivot`` /
    ``use_translation`` are set. Every branch has 3 output channels.
    """
    self.points_encoder = modules.AnnealedSinusoidalEncoder(
        num_freqs=self.num_freqs,
        max_freq_log2=self.max_freq_log2)
    self.metadata_encoder = glo.GloEncoder(
        num_embeddings=self.num_embeddings,
        features=self.num_embedding_features)

    self.trunk = MLP(
        depth=self.trunk_depth,
        width=self.trunk_width,
        hidden_activation=self.activation,
        hidden_init=self.default_init,
        skips=self.skips,
    )

    # Settings shared by every output branch.
    branch_kwargs = dict(
        hidden_activation=self.activation,
        hidden_init=self.default_init,
        output_channels=3,
    )

    heads = {}
    heads["rotation"] = MLP(
        depth=self.rotation_depth,
        width=self.rotation_width,
        output_init=self.rotation_init,
        **branch_kwargs,
    )
    if self.use_pivot:
        heads["pivot"] = MLP(
            depth=self.pivot_depth,
            width=self.pivot_width,
            output_init=self.pivot_init,
            **branch_kwargs,
        )
    if self.use_translation:
        heads["translation"] = MLP(
            depth=self.translation_depth,
            width=self.translation_width,
            output_init=self.translation_init,
            **branch_kwargs,
        )

    # The dict is filled locally and bound to the attribute in one
    # assignment rather than mutated after being attached to the module.
    # See https://github.com/google/flax/issues/524.
    self.branches = heads
def setup(self):
    """Create the point encoder, the GLO metadata encoder and the output MLP.

    The MLP is built from ``hidden_channels``, ``depth``, ``skips`` and the
    two initializers, and has three output channels.
    """
    self.points_encoder = modules.AnnealedSinusoidalEncoder(
        num_freqs=self.num_freqs,
        max_freq_log2=self.max_freq_log2,
    )
    self.metadata_encoder = glo.GloEncoder(
        num_embeddings=self.num_embeddings,
        features=self.num_embedding_features,
    )
    self.mlp = MLP(
        width=self.hidden_channels,
        depth=self.depth,
        skips=self.skips,
        hidden_init=self.hidden_init,
        output_init=self.output_init,
        output_channels=3,
    )