def setup(self):
    """Create this model's submodules.

    Builds, in order:
      * sinusoidal encoders for points (vmapped over 2 batch dims) and for
        view directions (vmapped over 1 batch dim);
      * GLO appearance / camera encoders, each only when its metadata flag
        is enabled;
      * a coarse NeRF MLP, plus a fine NeRF MLP when ``num_fine_samples`` is
        positive (otherwise ``nerf_fine`` is set to ``None``);
      * a warp field (vmapped over 2 batch dims) when ``use_warp`` is set.

    NOTE(review): this is presumably a Flax ``nn.Module.setup`` hook of the
    enclosing model class -- TODO confirm.
    """
    point_encoder_cls = model_utils.vmap_module(
        modules.SinusoidalEncoder, num_batch_dims=2)
    self.point_encoder = point_encoder_cls(
        num_freqs=self.num_nerf_point_freqs)

    viewdir_encoder_cls = model_utils.vmap_module(
        modules.SinusoidalEncoder, num_batch_dims=1)
    self.viewdir_encoder = viewdir_encoder_cls(
        num_freqs=self.num_nerf_viewdir_freqs)

    if self.use_appearance_metadata:
        self.appearance_encoder = glo.GloEncoder(
            num_embeddings=self.num_appearance_embeddings,
            features=self.num_appearance_features,
        )
    if self.use_camera_metadata:
        self.camera_encoder = glo.GloEncoder(
            num_embeddings=self.num_camera_embeddings,
            features=self.num_camera_features,
        )

    # Coarse and fine MLPs are configured identically; build the keyword
    # arguments once and reuse them for both.
    mlp_config = {
        'nerf_trunk_depth': self.nerf_trunk_depth,
        'nerf_trunk_width': self.nerf_trunk_width,
        'nerf_condition_depth': self.nerf_condition_depth,
        'nerf_condition_width': self.nerf_condition_width,
        'activation': self.activation,
        'skips': self.nerf_skips,
        'alpha_channels': self.alpha_channels,
        'rgb_channels': self.rgb_channels,
    }
    self.nerf_coarse = modules.NerfMLP(**mlp_config)
    self.nerf_fine = (
        modules.NerfMLP(**mlp_config) if self.num_fine_samples > 0 else None)

    if self.use_warp:
        self.warp_field = warping.create_warp_field(
            field_type=self.warp_field_type,
            num_freqs=self.num_warp_freqs,
            num_embeddings=self.num_warp_embeddings,
            num_features=self.num_warp_features,
            num_batch_dims=2,
            **self.warp_kwargs)
def create_warp_field(
        field_type: str,
        num_freqs: int,
        num_embeddings: int,
        num_features: int,
        num_batch_dims: int,
        **kwargs):
    """Factory function for warp fields.

    Args:
      field_type: Which warp field to build; either ``'translation'``
        (``TranslationField``) or ``'se3'`` (``SE3Field``).
      num_freqs: Forwarded to the field as ``num_freqs``.
      num_embeddings: Forwarded to the field as ``num_embeddings``.
      num_features: Forwarded to the field as ``num_embedding_features``.
      num_batch_dims: Number of leading batch dimensions the field class is
        vmapped over via ``model_utils.vmap_module``.
      **kwargs: Extra keyword arguments forwarded to the field constructor.

    Returns:
      An instance of the vmapped warp-field class.

    Raises:
      ValueError: If ``field_type`` is not one of the supported names.
    """
    if field_type not in ('translation', 'se3'):
        raise ValueError(f'Unknown warp field type: {field_type!r}')
    field_cls = TranslationField if field_type == 'translation' else SE3Field

    # Call arguments are (points, metadata, alpha, return_jacobian,
    # metadata_encoded): only points and metadata are mapped along axis 0,
    # the remaining arguments are passed through unmapped (None).
    batched_field_cls = model_utils.vmap_module(
        field_cls,
        num_batch_dims=num_batch_dims,
        in_axes=(0, 0, None, None, None))
    return batched_field_cls(
        num_freqs=num_freqs,
        num_embeddings=num_embeddings,
        num_embedding_features=num_features,
        **kwargs)
def setup(self):
    """Create this model's submodules.

    Builds, in order:
      * a warp field (via ``self.create_warp_field`` with 2 batch dims) when
        ``use_warp`` is set;
      * sinusoidal encoders for points (vmapped over 2 batch dims) and for
        view directions (vmapped over 1 batch dim);
      * GLO appearance / camera encoders, each only when its metadata flag
        is enabled;
      * ``self.nerf_mlps``, a dict holding a ``'coarse'`` NeRF MLP and, when
        ``num_fine_samples`` is positive, an identically configured
        ``'fine'`` one.

    NOTE(review): presumably a Flax ``nn.Module.setup`` hook -- TODO confirm.
    """
    if self.use_warp:
        # NOTE(review): `self` is passed explicitly, which suggests
        # `create_warp_field` is a static helper taking the model as its
        # first argument -- TODO confirm against the enclosing class.
        self.warp_field = self.create_warp_field(self, num_batch_dims=2)

    point_encoder_cls = model_utils.vmap_module(
        modules.SinusoidalEncoder, num_batch_dims=2)
    self.point_encoder = point_encoder_cls(
        num_freqs=self.num_nerf_point_freqs)

    viewdir_encoder_cls = model_utils.vmap_module(
        modules.SinusoidalEncoder, num_batch_dims=1)
    self.viewdir_encoder = viewdir_encoder_cls(
        num_freqs=self.num_nerf_viewdir_freqs)

    if self.use_appearance_metadata:
        self.appearance_encoder = glo.GloEncoder(
            num_embeddings=self.num_appearance_embeddings,
            features=self.num_appearance_features)
    if self.use_camera_metadata:
        self.camera_encoder = glo.GloEncoder(
            num_embeddings=self.num_camera_embeddings,
            features=self.num_camera_features)

    # Every NeRF MLP shares one configuration; build it once.
    mlp_config = {
        'trunk_depth': self.nerf_trunk_depth,
        'trunk_width': self.nerf_trunk_width,
        'rgb_branch_depth': self.nerf_rgb_branch_depth,
        'rgb_branch_width': self.nerf_rgb_branch_width,
        'activation': self.activation,
        'skips': self.nerf_skips,
        'alpha_channels': self.alpha_channels,
        'rgb_channels': self.rgb_channels,
    }
    mlp_names = ['coarse']
    if self.num_fine_samples > 0:
        mlp_names.append('fine')
    self.nerf_mlps = {
        name: modules.NerfMLP(**mlp_config) for name in mlp_names
    }