def __init__(self, config: Config):
    """Build a plain fully-connected network and store it in ``self.model``.

    The stack is: ``in_features -> layer_sizes[0] -> ... -> layer_sizes[-1]``
    with an activation after every one of those linear layers, followed by a
    final linear layer to ``out_features`` that has no activation.

    Args:
        config: Reads ``in_features``, ``out_features``, ``layer_sizes``,
            ``has_bias`` (bias flag for every linear layer) and ``activation``
            (forwarded to ``create_activation``).
    """
    super(SimpleINR, self).__init__()

    hidden_sizes = list(config.layer_sizes)
    # Input width followed by every hidden width; consecutive pairs give the
    # (fan_in, fan_out) of each activated linear layer.
    widths = [config.in_features] + hidden_sizes

    modules = []
    for fan_in, fan_out in zip(widths[:-1], widths[1:]):
        modules.append(nn.Linear(fan_in, fan_out, bias=config.has_bias))
        modules.append(create_activation(config.activation))

    # Output head: linear only, no activation.
    modules.append(
        nn.Linear(config.layer_sizes[-1], config.out_features, bias=config.has_bias))

    self.model = nn.Sequential(*modules)
def __init__(self, config: Config):
    """Build the MLP that follows the Fourier features, plus the basis matrix.

    The first linear layer takes ``config.num_fourier_feats * 2`` inputs. Each
    hidden step is a ``Linear + activation`` pair wrapped in its own
    ``nn.Sequential``; when ``config.residual.enabled`` is set that pair is
    further wrapped in ``LinearResidual``. The last layer is a bare linear map
    to ``config.out_features``.

    Args:
        config: Reads ``num_fourier_feats``, ``in_features``, ``out_features``,
            ``layer_sizes``, ``has_bias``, ``activation``, ``residual``,
            ``scale`` and ``learnable_basis``.
    """
    super(FourierINR, self).__init__()

    use_bias = config.has_bias
    hidden = config.layer_sizes

    # Stem: Fourier features -> first hidden width.
    stem = [
        nn.Linear(config.num_fourier_feats * 2, hidden[0], bias=use_bias),
        create_activation(config.activation),
    ]

    # Hidden steps, one sub-module per consecutive pair of widths.
    body = []
    for k in range(1, len(hidden)):
        step = nn.Sequential(
            nn.Linear(hidden[k - 1], hidden[k], bias=use_bias),
            create_activation(config.activation),
        )
        if config.residual.enabled:
            step = LinearResidual(config.residual, step)
        body.append(step)

    head = nn.Linear(hidden[-1], config.out_features, bias=use_bias)
    self.model = nn.Sequential(*stem, *body, head)

    # Initializing the basis: scaled Gaussian matrix of shape
    # (num_fourier_feats, in_features); trainable only if learnable_basis is set.
    gaussian = torch.randn(config.num_fourier_feats, config.in_features)
    self.basis_matrix = nn.Parameter(
        config.scale * gaussian, requires_grad=config.learnable_basis)
def init_model(self):
    """Assemble the sine-activated network and store it in ``self.model``.

    Layout: coordinate transform + sine (``w0_initial``), then one transform +
    sine (``w0``) per consecutive pair in ``layer_sizes``, then a ``'linear'``
    transform to ``config.data.num_img_channels`` followed by the configured
    output activation. Every activation is wrapped in ``INRProxy``.
    """
    inr = self.config.hp.inr
    sizes = inr.layer_sizes

    # First layer is of full control; its layer type comes from the config.
    # create_transform's result is extended in place, as a list of modules.
    modules = self.create_transform(
        inr.coord_dim,
        sizes[0],
        layer_type=inr.coords_layer_type,
        is_coord_layer=True,
    )
    modules.append(INRProxy(self.create_sine(inr.w0_initial)))

    # Middle layers are large so they are controlled via AdaIN
    for k in range(1, len(sizes)):
        modules.extend(
            self.create_transform(
                sizes[k - 1], sizes[k], layer_type=inr.hid_layer_type))
        modules.append(INRProxy(self.create_sine(inr.w0)))

    # The last layer is small so let's also control it fully
    modules.extend(
        self.create_transform(
            sizes[-1], self.config.data.num_img_channels, layer_type='linear'))
    modules.append(INRProxy(create_activation(inr.output_activation)))

    self.model = nn.Sequential(*modules)
def init_model(self):
    """Assemble the 'sines_cosines'-embedded network and store it in ``self.model``.

    The coordinate transform outputs ``layer_sizes[0] // 2`` features and is
    followed by the ``'sines_cosines'`` activation (presumably doubling the
    width back to ``layer_sizes[0]`` -- confirm against that activation).
    Hidden transforms follow; with ``skip_coords`` every hidden transform
    after the first takes ``layer_sizes[0]`` extra inputs and the hidden stack
    is wrapped in ``INRInputSkip``. With ``residual`` each hidden transform is
    wrapped in ``INRResidual``. A ``'linear'`` transform to
    ``config.data.num_img_channels`` and the output activation close the model.
    """
    inr = self.config.hp.inr
    sizes = inr.layer_sizes
    skip_coords = inr.get('skip_coords')

    modules = self.create_transform(
        inr.coord_dim,
        sizes[0] // 2,
        layer_type=inr.coords_layer_type,
        is_coord_layer=True,
    )
    modules.append(INRProxy(create_activation('sines_cosines')))

    hidden = []
    for k in range(len(sizes) - 1):
        fan_in = sizes[k]
        if skip_coords and k > 0:
            # Room for the coordinate features fed to later hidden layers.
            fan_in = fan_in + sizes[0]

        stage = self.create_transform(
            fan_in, sizes[k + 1], layer_type=inr.hid_layer_type)
        stage.append(INRProxy(create_activation(inr.activation)))

        if inr.residual:
            hidden.append(INRResidual(INRSequential(*stage)))
        else:
            hidden.extend(stage)

    if skip_coords:
        modules.append(INRInputSkip(*hidden))
    else:
        modules.extend(hidden)

    modules.extend(
        self.create_transform(
            sizes[-1], self.config.data.num_img_channels, 'linear'))
    modules.append(INRProxy(create_activation(inr.output_activation)))

    self.model = nn.Sequential(*modules)
def __init__(self, in_features: int, out_features: int, config: Config,
             is_first_layer: bool):
    """Build a ``linear -> [batch norm] -> activation`` block with an optional residual flag.

    Args:
        in_features: Input width of the linear layer.
        out_features: Output width of the linear layer (and of the batch norm).
        config: Reads ``equalized_lr`` (use ``EqualLinear`` instead of
            ``nn.Linear``), ``residual``, ``main_branch_weight``, ``has_bn``,
            ``activation`` and ``activation_kwargs``.
        is_first_layer: When true the block is never residual.

    Attributes set:
        residual: True only when input and output widths match,
            ``config.residual`` is truthy and this is not the first layer.
        main_branch_weight: Learnable parameter created only when ``residual``
            is True, initialised from ``config.main_branch_weight``.
        transform: ``nn.Sequential`` holding the layers described above.
    """
    super().__init__()

    linear_cls = EqualLinear if config.equalized_lr else nn.Linear
    layers = [linear_cls(in_features, out_features)]

    if in_features == out_features and config.residual and not is_first_layer:
        self.residual = True
        weight = torch.tensor(config.main_branch_weight)
        # An integer (or bool) config value such as `1` yields an integer
        # tensor, and nn.Parameter rejects those because only floating-point
        # or complex tensors can require gradients. Promote such values to the
        # default float dtype; float inputs are left exactly as they were.
        if not (weight.is_floating_point() or weight.is_complex()):
            weight = weight.to(torch.get_default_dtype())
        self.main_branch_weight = nn.Parameter(weight)
    else:
        self.residual = False

    if config.has_bn:
        layers.append(nn.BatchNorm1d(out_features))

    layers.append(
        create_activation(config.activation, **config.activation_kwargs))

    self.transform = nn.Sequential(*layers)
def init_model(self):
    """Assemble the multi-resolution block stack and store it in ``self.model``.

    One block is built per entry of ``res_configs`` (``num_blocks`` entries,
    either generated by ``create_res_config`` or looked up in
    ``resolutions_params`` by image size). Each block is an ``INRModuleDict``
    with four entries:

    * ``coord_embedder`` -- ``INRFourierFeats`` + ``'sines_cosines'`` activation;
    * ``transform`` -- ``n_layers`` hidden transforms with optional pixel norm
      and noise injection, wrapped in ``INRCoordsSkip`` when ``skip_coords``
      is set and in ``INRSequential`` otherwise;
    * ``connector`` -- ``INRResConnector`` from the previous block (identity
      for the first block);
    * ``to_rgb`` -- ``INRToRGB`` when the config asks for it or the block is
      the last one, identity otherwise.

    ``self.model`` is an ``INRModuleDict`` keyed ``b_0``, ``b_1``, ...
    The chosen per-block settings are printed to stdout.
    """
    inr = self.config.hp.inr
    num_blocks = inr.num_blocks
    last_block_idx = num_blocks - 1

    if inr.res_increase_scheme.enabled:
        res_configs = [self.create_res_config(i) for i in range(num_blocks)]
    else:
        resolutions = self.generate_img_sizes(
            self.config.data.target_img_size)
        res_configs = [
            inr.resolutions_params[resolutions[i]] for i in range(num_blocks)
        ]

    print('resolution:', [c.resolution for c in res_configs])
    print('dim:', [c.dim for c in res_configs])
    print('num_learnable_coord_feats:',
          [c.num_learnable_coord_feats for c in res_configs])
    print('to_rgb:', [c.to_rgb for c in res_configs])

    # Count the to-RGB heads that are actually created below: the last block
    # always gets one, even if its config does not request it. Counting only
    # the `to_rgb` flags would under-scale the init std (and divide by zero
    # when no config sets the flag).
    num_to_rgb_blocks = sum(
        1 for idx, c in enumerate(res_configs)
        if c.to_rgb or idx == last_block_idx)

    blocks = []
    for i, res_config in enumerate(res_configs):
        # 1. Creating coord fourier feat embedders for each resolution
        coord_embedder = INRSequential(
            INRFourierFeats(res_config),
            INRProxy(create_activation('sines_cosines')))
        # Twice the number of Fourier features reported by the embedder.
        coord_feat_dim = coord_embedder[0].get_num_feats() * 2

        # 2. Main branch. The first block does not need any wiring,
        # but later blocks are connected to their predecessor.
        if i > 0:
            # Different-resolution blocks are wired together with the connector
            connector_layers = [
                INRResConnector(
                    res_configs[i - 1].dim,
                    coord_feat_dim,
                    res_config.dim,
                    inr.upsampling_mode,
                    **inr.module_kwargs.se_factorized),
            ]
            if inr.use_pixel_norm:
                connector_layers.append(INRPixelNorm())
            connector_layers.append(
                INRProxy(
                    create_activation(inr.activation,
                                      **inr.activation_kwargs)))
            connector = INRSequential(*connector_layers)
        else:
            connector = INRIdentity()

        transform_layers = []
        for j in range(res_config.n_layers):
            if i == 0 and j == 0:
                # Since we do not have previous feat dims
                input_size = coord_feat_dim
            elif inr.skip_coords:
                input_size = coord_feat_dim + res_config.dim
            else:
                input_size = res_config.dim

            transform_layers.extend(
                self.create_transform(
                    input_size,
                    res_config.dim,
                    layer_type=inr.hid_layer_type))

            if inr.use_pixel_norm:
                transform_layers.append(INRPixelNorm())
            if inr.use_noise:
                transform_layers.append(INRNoiseInjection())

            transform_layers.append(
                INRProxy(
                    create_activation(inr.activation,
                                      **inr.activation_kwargs)))

        if res_config.to_rgb or i == last_block_idx:
            to_rgb_weight_std = self.compute_weight_std(
                res_config.dim, is_coord_layer=False)
            to_rgb_bias_std = self.compute_bias_std(
                res_config.dim, is_coord_layer=False)

            if inr.additionaly_scale_to_rgb:
                # Shrink the init std by sqrt(number of to-RGB heads).
                to_rgb_weight_std /= np.sqrt(num_to_rgb_blocks)
                to_rgb_bias_std /= np.sqrt(num_to_rgb_blocks)

            to_rgb = INRToRGB(
                res_config.dim,
                inr.to_rgb_activation,
                inr.upsampling_mode,
                to_rgb_weight_std,
                to_rgb_bias_std)
        else:
            to_rgb = INRIdentity()

        if inr.skip_coords:
            transform = INRCoordsSkip(*transform_layers,
                                      concat_to_the_first=i > 0)
        else:
            transform = INRSequential(*transform_layers)

        blocks.append(
            INRModuleDict({
                'coord_embedder': coord_embedder,
                'transform': transform,
                'connector': connector,
                'to_rgb': to_rgb,
            }))

    self.model = INRModuleDict({f'b_{i}': b for i, b in enumerate(blocks)})