def __init__(self, dis_cfg, data_cfg):
    """Create the image, additional and temporal patch discriminators.

    Args:
        dis_cfg: Discriminator config. ``dis_cfg.image`` configures the
            main discriminator, the optional
            ``dis_cfg.additional_discriminators`` maps names to configs
            of extra discriminators, and ``dis_cfg.temporal`` configures
            the temporal ones (``temporal.num_scales`` of them,
            default 0).
        data_cfg: Data config. ``num_frames_D`` and ``type`` are read
            directly; ``label_channels`` (default 1) and
            ``has_foreground`` (default False) are optional.
    """
    super().__init__()
    self.data_cfg = data_cfg

    label_channels = get_paired_input_label_channel_number(data_cfg)
    if label_channels == 0:
        # The helper reported zero label channels; fall back to the
        # explicitly configured count.
        label_channels = getattr(data_cfg, 'label_channels', 1)
    image_channels = get_paired_input_image_channel_number(data_cfg)

    self.num_frames_D = data_cfg.num_frames_D
    self.num_scales = get_nested_attr(dis_cfg, 'temporal.num_scales', 0)

    # Few-shot datasets double the channel count of every per-frame
    # discriminator input.
    self.use_few_shot = 'few_shot' in data_cfg.type
    few_shot_factor = 2 if self.use_few_shot else 1

    # Main (label + image) discriminator.
    self.net_D = MultiPatchDiscriminator(
        dis_cfg.image, (label_channels + image_channels) * few_shot_factor)

    # Optional extra discriminators that only look at image channels.
    self.add_dis_cfg = getattr(dis_cfg, 'additional_discriminators', None)
    if self.add_dis_cfg is not None:
        extra_channels = image_channels * few_shot_factor
        for name in self.add_dis_cfg:
            extra_net = MultiPatchDiscriminator(self.add_dis_cfg[name],
                                                extra_channels)
            setattr(self, 'net_D_' + name, extra_net)

    # Temporal discriminator.
    self.num_netDT_input_channels = image_channels * self.num_frames_D
    for scale in range(self.num_scales):
        temporal_net = MultiPatchDiscriminator(
            dis_cfg.temporal, self.num_netDT_input_channels)
        setattr(self, 'net_DT%d' % scale, temporal_net)

    self.has_fg = getattr(data_cfg, 'has_foreground', False)
def __init__(self, gen_cfg, data_cfg, num_input_channels, padding_mode,
             base_conv_block, base_res_block):
    """Assemble the downsample / residual / upsample stack.

    Args:
        gen_cfg: Generator config. Optional fields: ``num_filters``
            (default 64), ``num_downsamples`` (default 4) and
            ``num_res_blocks`` (default 9).
        data_cfg: Data config used to look up the image channel count.
        num_input_channels: Channel count of the network input.
        padding_mode: Padding mode of the final output convolution.
        base_conv_block: Callable building a convolution block.
        base_res_block: Callable building a residual block.
    """
    super(GlobalGenerator, self).__init__()
    out_channels = get_paired_input_image_channel_number(data_cfg)
    base_width = getattr(gen_cfg, 'num_filters', 64)
    num_levels = getattr(gen_cfg, 'num_downsamples', 4)
    num_res = getattr(gen_cfg, 'num_res_blocks', 9)

    # First layer.
    layers = [base_conv_block(num_input_channels, base_width,
                              kernel_size=7, padding=3)]

    # Downsample: the channel width doubles at every stride-2 block.
    for level in range(num_levels):
        width = base_width * (2 ** level)
        layers.append(
            base_conv_block(width, width * 2, 3, padding=1, stride=2))

    # ResNet blocks at the coarsest resolution.
    width = base_width * (2 ** num_levels)
    for _ in range(num_res):
        layers.append(base_res_block(width, width, 3, padding=1))

    # Upsample: mirror the downsampling path.
    for level in reversed(range(num_levels)):
        width = base_width * (2 ** level)
        layers.append(NearestUpsample(scale_factor=2))
        layers.append(base_conv_block(width * 2, width, 3, padding=1))

    # Output convolution.
    layers.append(Conv2dBlock(base_width, out_channels, 7, padding=3,
                              padding_mode=padding_mode,
                              nonlinearity='tanh'))
    self.model = nn.Sequential(*layers)
def __init__(self, gen_cfg, data_cfg, num_input_channels, num_filters,
             padding_mode, base_conv_block, base_res_block,
             output_img=False):
    """Build the downsampling and upsampling halves of the enhancer.

    Args:
        gen_cfg: Generator config; ``num_res_blocks`` is optional
            (default 3).
        data_cfg: Data config used to look up the image channel count.
        num_input_channels: Channel count of the network input.
        num_filters: Channel width of the first convolution.
        padding_mode: Padding mode of the optional output convolution.
        base_conv_block: Callable building a convolution block.
        base_res_block: Callable building a residual block.
        output_img: When true, append a final tanh convolution that maps
            to the image channel count.
    """
    super(LocalEnhancer, self).__init__()
    num_res = getattr(gen_cfg, 'num_res_blocks', 3)
    out_channels = get_paired_input_image_channel_number(data_cfg)
    wide = num_filters * 2

    # Downsample.
    down_layers = [
        base_conv_block(num_input_channels, num_filters, 7, padding=3),
        base_conv_block(num_filters, wide, 3, stride=2, padding=1),
    ]

    # Residual blocks.
    up_layers = []
    for _ in range(num_res):
        up_layers.append(base_res_block(wide, wide, 3, padding=1))

    # Upsample.
    up_layers.append(NearestUpsample(scale_factor=2))
    up_layers.append(base_conv_block(wide, num_filters, 3, padding=1))

    # Final convolution.
    if output_img:
        up_layers.append(Conv2dBlock(num_filters, out_channels, 7,
                                     padding=3,
                                     padding_mode=padding_mode,
                                     nonlinearity='tanh'))

    self.model_downsample = nn.Sequential(*down_layers)
    self.model_upsample = nn.Sequential(*up_layers)
def __init__(self, dis_cfg, data_cfg):
    """Wrap a ``MultiResPatchDiscriminator`` configured from ``dis_cfg``.

    Args:
        dis_cfg: Discriminator config. All fields are optional:
            ``kernel_size`` (3), ``num_filters`` (128),
            ``max_num_filters`` (512), ``num_discriminators`` (2),
            ``num_layers`` (5), ``activation_norm_type`` ('none') and
            ``weight_norm_type`` ('spectral').
        data_cfg: Data config used to look up image and label channel
            counts.
    """
    super(Discriminator, self).__init__()
    print('Multi-resolution patch discriminator initialization.')

    # We assume the first datum is the ground truth image.
    num_image_ch = get_paired_input_image_channel_number(data_cfg)
    # Calculate number of channels in the input label.
    num_label_ch = get_paired_input_label_channel_number(data_cfg)

    # Hyper-parameters of the discriminator.
    k_size = getattr(dis_cfg, 'kernel_size', 3)
    base_filters = getattr(dis_cfg, 'num_filters', 128)
    filter_cap = getattr(dis_cfg, 'max_num_filters', 512)
    n_discriminators = getattr(dis_cfg, 'num_discriminators', 2)
    n_layers = getattr(dis_cfg, 'num_layers', 5)
    act_norm = getattr(dis_cfg, 'activation_norm_type', 'none')
    w_norm = getattr(dis_cfg, 'weight_norm_type', 'spectral')

    print('\tBase filter number: %d' % base_filters)
    print('\tNumber of discriminators: %d' % n_discriminators)
    print('\tNumber of layers in a discriminator: %d' % n_layers)
    print('\tWeight norm type: %s' % w_norm)

    # The discriminator input has image and label channels combined.
    self.model = MultiResPatchDiscriminator(
        n_discriminators,
        k_size,
        num_image_ch + num_label_ch,
        base_filters,
        n_layers,
        filter_cap,
        act_norm,
        w_norm)
    print('Done with the Multi-resolution patch '
          'discriminator initialization.')
def __init__(self, dis_cfg, data_cfg):
    """Build the patch discriminators and the FPSE discriminator.

    Args:
        dis_cfg: Discriminator config. All fields are optional:
            ``kernel_size`` (3), ``num_filters`` (128),
            ``max_num_filters`` (512), ``num_discriminators`` (2),
            ``num_layers`` (5), ``activation_norm_type`` ('none'),
            ``weight_norm_type`` ('spectral'), ``fpse_kernel_size`` (3)
            and ``fpse_activation_norm_type`` ('none').
        data_cfg: Data config. ``type`` selects whether the label channel
            count is computed with ``video=True``.
    """
    super(Discriminator, self).__init__()
    print('Multi-resolution patch discriminator initialization.')

    # We assume the first datum is the ground truth image.
    num_image_ch = get_paired_input_image_channel_number(data_cfg)

    # Calculate number of channels in the input label.
    is_paired_video = data_cfg.type == 'imaginaire.datasets.paired_videos'
    if is_paired_video:
        num_label_ch = get_paired_input_label_channel_number(data_cfg,
                                                             video=True)
    else:
        num_label_ch = get_paired_input_label_channel_number(data_cfg)

    # Hyper-parameters of the patch discriminators.
    k_size = getattr(dis_cfg, 'kernel_size', 3)
    base_filters = getattr(dis_cfg, 'num_filters', 128)
    filter_cap = getattr(dis_cfg, 'max_num_filters', 512)
    n_discriminators = getattr(dis_cfg, 'num_discriminators', 2)
    n_layers = getattr(dis_cfg, 'num_layers', 5)
    act_norm = getattr(dis_cfg, 'activation_norm_type', 'none')
    w_norm = getattr(dis_cfg, 'weight_norm_type', 'spectral')

    print('\tBase filter number: %d' % base_filters)
    print('\tNumber of discriminators: %d' % n_discriminators)
    print('\tNumber of layers in a discriminator: %d' % n_layers)
    print('\tWeight norm type: %s' % w_norm)

    # One identically configured patch discriminator per entry.
    in_channels = num_image_ch + num_label_ch
    self.discriminators = nn.ModuleList()
    for _ in range(n_discriminators):
        self.discriminators.append(
            NLayerPatchDiscriminator(k_size, in_channels, base_filters,
                                     n_layers, filter_cap, act_norm,
                                     w_norm))
    print('Done with the Multi-resolution patch '
          'discriminator initialization.')

    # FPSE discriminator; it receives image and label channel counts
    # separately.
    fpse_k_size = getattr(dis_cfg, 'fpse_kernel_size', 3)
    fpse_act_norm = getattr(dis_cfg, 'fpse_activation_norm_type', 'none')
    self.fpse_discriminator = FPSEDiscriminator(
        num_image_ch, num_label_ch, base_filters, fpse_k_size, w_norm,
        fpse_act_norm)
def __init__(self, enc_cfg, data_cfg):
    """Build the feature encoder and its per-label cluster buffers.

    Args:
        enc_cfg: Encoder config. All fields are optional:
            ``num_feat_channels`` (3), ``num_clusters`` (10),
            ``num_filters`` (64), ``num_downsamples`` (4),
            ``weight_norm_type`` ('none'), ``activation_norm_type``
            ('instance') and ``padding_mode`` ('reflect').
        data_cfg: Data config used to look up image and label channel
            counts.
    """
    super(Encoder, self).__init__()
    label_nc = get_paired_input_label_channel_number(data_cfg)
    # Read the feature width once, with its default, so that the buffers
    # and the output convolution always agree and a config without
    # ``num_feat_channels`` no longer raises before the default applies.
    self.num_feat_channels = getattr(enc_cfg, 'num_feat_channels', 3)
    n_clusters = getattr(enc_cfg, 'num_clusters', 10)

    # One zero-initialised (n_clusters, num_feat_channels) buffer per
    # label channel.
    for i in range(label_nc):
        self.register_buffer(
            'cluster_%d' % i,
            torch.zeros(n_clusters, self.num_feat_channels,
                        dtype=torch.float32))

    num_img_channels = get_paired_input_image_channel_number(data_cfg)
    num_filters = getattr(enc_cfg, 'num_filters', 64)
    num_downsamples = getattr(enc_cfg, 'num_downsamples', 4)
    weight_norm_type = getattr(enc_cfg, 'weight_norm_type', 'none')
    activation_norm_type = getattr(enc_cfg, 'activation_norm_type',
                                   'instance')
    padding_mode = getattr(enc_cfg, 'padding_mode', 'reflect')
    base_conv_block = partial(Conv2dBlock,
                              padding_mode=padding_mode,
                              weight_norm_type=weight_norm_type,
                              activation_norm_type=activation_norm_type,
                              nonlinearity='relu')

    model = [base_conv_block(num_img_channels, num_filters, 7, padding=3)]

    # Downsample.
    for i in range(num_downsamples):
        ch = num_filters * (2 ** i)
        model += [base_conv_block(ch, ch * 2, 3, stride=2, padding=1)]

    # Upsample.
    for i in reversed(range(num_downsamples)):
        ch = num_filters * (2 ** i)
        model += [NearestUpsample(scale_factor=2),
                  base_conv_block(ch * 2, ch, 3, padding=1)]

    # Output convolution producing the feature channels.
    model += [Conv2dBlock(num_filters, self.num_feat_channels, 7,
                          padding=3, padding_mode=padding_mode,
                          nonlinearity='tanh')]
    self.model = nn.Sequential(*model)
def __init__(self, gen_cfg, data_cfg):
    """Build the image generation network and the temporal network.

    Args:
        gen_cfg: Generator config. ``flow``, ``activation_norm_type`` and
            ``activation_norm_params`` are read directly. Optional
            fields: ``num_layers`` (7), ``num_downsamples_img`` (4),
            ``num_filters`` (32), ``max_num_filters`` (1024),
            ``kernel_size`` (3), ``style_dims`` (256),
            ``use_segmap_as_input`` (False), ``embed`` (None) and
            ``weight_norm_type`` ('spectral').
        data_cfg: Data config. ``num_frames_G`` and
            ``val.augmentations`` are read directly;
            ``for_pose_dataset`` is optional.

    Raises:
        ValueError: If the validation augmentations define neither
            ``center_crop_h_w`` nor ``resize_h_w``.
    """
    super().__init__()
    self.gen_cfg = gen_cfg
    self.data_cfg = data_cfg
    self.num_frames_G = data_cfg.num_frames_G
    # Number of residual blocks in generator.
    self.num_layers = num_layers = getattr(gen_cfg, 'num_layers', 7)
    # Number of downsamplings for previous frame.
    self.num_downsamples_img = getattr(gen_cfg, 'num_downsamples_img', 4)
    # Number of filters in the first layer.
    self.num_filters = num_filters = getattr(gen_cfg, 'num_filters', 32)
    self.max_num_filters = getattr(gen_cfg, 'max_num_filters', 1024)
    self.kernel_size = kernel_size = getattr(gen_cfg, 'kernel_size', 3)
    padding = kernel_size // 2

    # For pose dataset.
    self.is_pose_data = hasattr(data_cfg, 'for_pose_dataset')
    if self.is_pose_data:
        pose_cfg = data_cfg.for_pose_dataset
        self.pose_type = getattr(pose_cfg, 'pose_type', 'both')
        self.remove_face_labels = getattr(pose_cfg, 'remove_face_labels',
                                          False)

    # Input data params.
    self.num_input_channels = num_input_channels = \
        get_paired_input_label_channel_number(data_cfg)
    num_img_channels = get_paired_input_image_channel_number(data_cfg)

    # The output size comes from the validation augmentations, given as
    # an "h,w" string.
    aug_cfg = data_cfg.val.augmentations
    if hasattr(aug_cfg, 'center_crop_h_w'):
        crop_h_w = aug_cfg.center_crop_h_w
    elif hasattr(aug_cfg, 'resize_h_w'):
        crop_h_w = aug_cfg.resize_h_w
    else:
        raise ValueError('Need to specify output size.')
    crop_h, crop_w = crop_h_w.split(',')
    crop_h, crop_w = int(crop_h), int(crop_w)
    # Spatial size at the bottle neck of generator.
    self.sh = crop_h // (2 ** num_layers)
    self.sw = crop_w // (2 ** num_layers)

    # Noise vector dimension.
    self.z_dim = getattr(gen_cfg, 'style_dims', 256)
    self.use_segmap_as_input = \
        getattr(gen_cfg, 'use_segmap_as_input', False)

    # Label / image embedding network.
    self.emb_cfg = emb_cfg = getattr(gen_cfg, 'embed', None)
    # Default to the boolean True (the previous default was the string
    # 'True', which was only truthy by accident).
    self.use_embed = getattr(emb_cfg, 'use_embed', True)
    self.num_downsamples_embed = getattr(emb_cfg, 'num_downsamples', 5)
    if self.use_embed:
        self.label_embedding = LabelEmbedder(emb_cfg, num_input_channels)

    # Flow network.
    self.flow_cfg = flow_cfg = gen_cfg.flow
    # Use SPADE to combine warped and hallucinated frames instead of
    # linear combination.
    # NOTE(review): when ``multi_spade_combine`` is present this stores
    # the config object itself and relies on its truthiness.
    self.spade_combine = getattr(flow_cfg, 'multi_spade_combine', True)
    # Number of layers to perform multi-spade combine.
    self.num_multi_spade_layers = getattr(flow_cfg.multi_spade_combine,
                                          'num_layers', 3)
    # At beginning of training, only train an image generator.
    self.temporal_initialized = False
    # Whether to output hallucinated frame (when training temporal
    # network) for additional loss.
    self.generate_raw_output = False

    # Image generation network.
    weight_norm_type = getattr(gen_cfg, 'weight_norm_type', 'spectral')
    activation_norm_type = gen_cfg.activation_norm_type
    activation_norm_params = gen_cfg.activation_norm_params
    if self.use_embed and \
            not hasattr(activation_norm_params, 'num_filters'):
        activation_norm_params.num_filters = 0
    nonlinearity = 'leakyrelu'

    self.base_res_block = base_res_block = partial(
        Res2dBlock, kernel_size=kernel_size, padding=padding,
        weight_norm_type=weight_norm_type,
        activation_norm_type=activation_norm_type,
        activation_norm_params=activation_norm_params,
        nonlinearity=nonlinearity, order='NACNAC')

    # Upsampling residual blocks. The shared ``activation_norm_params``
    # object is mutated before each block is built, so the order of
    # these statements matters.
    for i in range(num_layers, -1, -1):
        activation_norm_params.cond_dims = self.get_cond_dims(i)
        activation_norm_params.partial = self.get_partial(i) if hasattr(
            self, 'get_partial') else False
        layer = base_res_block(self.get_num_filters(i + 1),
                               self.get_num_filters(i))
        setattr(self, 'up_%d' % i, layer)

    # Final conv layer.
    self.conv_img = Conv2dBlock(num_filters, num_img_channels,
                                kernel_size, padding=padding,
                                nonlinearity=nonlinearity, order='AC')

    # Width of the first (coarsest) feature map, capped at the maximum.
    num_filters = min(self.max_num_filters,
                      num_filters * (2 ** (self.num_layers + 1)))
    if self.use_segmap_as_input:
        self.fc = Conv2dBlock(num_input_channels, num_filters,
                              kernel_size=3, padding=1)
    else:
        self.fc = LinearBlock(self.z_dim, num_filters * self.sh * self.sw)

    # Misc.
    self.downsample = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
    self.upsample = partial(F.interpolate, scale_factor=2)
    self.init_temporal_network()
def __init__(self, flow_cfg, data_cfg):
    """Build the label/image encoders, residual trunk and output heads.

    Args:
        flow_cfg: Flow network config. All fields are optional:
            ``num_filters`` (32), ``max_num_filters`` (1024),
            ``num_downsamples`` (5), ``kernel_size`` (3),
            ``num_res_blocks`` (6), ``flow_output_multiplier`` (20),
            ``activation_norm_type`` ('sync_batch') and
            ``weight_norm_type`` ('spectral').
        data_cfg: Data config. ``num_frames_G`` is read directly and the
            label / image channel counts are looked up from it.
    """
    super().__init__()
    label_channels = get_paired_input_label_channel_number(data_cfg)
    prev_img_channels = get_paired_input_image_channel_number(data_cfg)
    num_frames = data_cfg.num_frames_G  # Num. of input frames.

    self.num_filters = base_width = getattr(flow_cfg, 'num_filters', 32)
    self.max_num_filters = getattr(flow_cfg, 'max_num_filters', 1024)
    num_levels = getattr(flow_cfg, 'num_downsamples', 5)
    k_size = getattr(flow_cfg, 'kernel_size', 3)
    pad = k_size // 2
    self.num_res_blocks = getattr(flow_cfg, 'num_res_blocks', 6)
    # Multiplier on the flow output.
    self.flow_output_multiplier = getattr(flow_cfg,
                                          'flow_output_multiplier', 20)

    act_norm = getattr(flow_cfg, 'activation_norm_type', 'sync_batch')
    w_norm = getattr(flow_cfg, 'weight_norm_type', 'spectral')

    conv_block = partial(Conv2dBlock,
                         kernel_size=k_size,
                         padding=pad,
                         weight_norm_type=w_norm,
                         activation_norm_type=act_norm,
                         nonlinearity='leakyrelu')

    # Will downsample the labels and prev frames separately, then combine.
    label_path = [conv_block(label_channels * num_frames, base_width)]
    image_path = [conv_block(prev_img_channels * (num_frames - 1),
                             base_width)]
    for level in range(num_levels):
        label_path.append(conv_block(self.get_num_filters(level),
                                     self.get_num_filters(level + 1),
                                     stride=2))
        image_path.append(conv_block(self.get_num_filters(level),
                                     self.get_num_filters(level + 1),
                                     stride=2))

    # Resnet blocks.
    trunk = []
    trunk_width = self.get_num_filters(num_levels)
    for _ in range(self.num_res_blocks):
        trunk.append(Res2dBlock(trunk_width, trunk_width, k_size,
                                padding=pad,
                                weight_norm_type=w_norm,
                                activation_norm_type=act_norm,
                                order='CNACN'))

    # Upsample.
    decoder = []
    for level in reversed(range(num_levels)):
        decoder.append(nn.Upsample(scale_factor=2))
        decoder.append(conv_block(self.get_num_filters(level + 1),
                                  self.get_num_filters(level)))

    # Output heads: a 2-channel flow and a 1-channel sigmoid mask.
    flow_head = [Conv2dBlock(base_width, 2, k_size, padding=pad)]
    mask_head = [Conv2dBlock(base_width, 1, k_size, padding=pad,
                             nonlinearity='sigmoid')]

    self.down_lbl = nn.Sequential(*label_path)
    self.down_img = nn.Sequential(*image_path)
    self.res_flow = nn.Sequential(*trunk)
    self.up_flow = nn.Sequential(*decoder)
    self.conv_flow = nn.Sequential(*flow_head)
    self.conv_mask = nn.Sequential(*mask_head)
def init_temporal_network(self, cfg_init=None):
    r"""Create the previous-frame downsampling network and the flow network.

    Used when training starts to process multiple frames.

    Args:
        cfg_init: Weight initialization config exposing ``type`` and
            ``gain`` attributes. When ``None`` the new modules keep
            their default initialization.
    """
    def _init_weights(module):
        # Apply the requested initialization, if any, and hand the
        # module back so the call can be used inline.
        if cfg_init is not None:
            module.apply(weights_init(cfg_init.type, cfg_init.gain))
        return module

    # Number of image downsamplings for the previous frame.
    n_down = self.num_downsamples_img
    # Number of residual blocks for the previous frame, rounded up to an
    # even count.
    self.num_res_blocks = int(
        np.ceil((self.num_layers - n_down) / 2.0) * 2)

    # First conv layer.
    image_channels = get_paired_input_image_channel_number(self.data_cfg)
    self.down_first = Conv2dBlock(image_channels,
                                  self.num_filters,
                                  self.kernel_size,
                                  padding=self.kernel_size // 2)
    _init_weights(self.down_first)

    # Downsampling residual blocks. ``cond_dims`` on the shared norm
    # params must be updated before each block is constructed.
    norm_params = self.gen_cfg.activation_norm_params
    for level in range(n_down + 1):
        norm_params.cond_dims = self.get_cond_dims(level)
        block = self.base_res_block(self.get_num_filters(level),
                                    self.get_num_filters(level + 1))
        setattr(self, 'down_%d' % level, _init_weights(block))

    # Additional residual blocks.
    res_width = self.get_num_filters(n_down + 1)
    norm_params.cond_dims = self.get_cond_dims(n_down + 1)
    for idx in range(self.num_res_blocks):
        block = self.base_res_block(res_width, res_width)
        setattr(self, 'res_%d' % idx, _init_weights(block))

    # Flow network.
    flow_cfg = self.flow_cfg
    self.temporal_initialized = True
    self.generate_raw_output = getattr(flow_cfg, 'generate_raw_output',
                                       False) and self.spade_combine
    self.flow_network_temp = FlowGenerator(flow_cfg, self.data_cfg)
    _init_weights(self.flow_network_temp)

    # Embedding network for the previous image, only when SPADE combine
    # is enabled.
    self.spade_combine = getattr(flow_cfg, 'multi_spade_combine', True)
    if self.spade_combine:
        embed_cfg = flow_cfg.multi_spade_combine.embed
        image_channels = get_paired_input_image_channel_number(
            self.data_cfg)
        self.img_prev_embedding = LabelEmbedder(embed_cfg,
                                                image_channels + 1)
        _init_weights(self.img_prev_embedding)
def __init__(self, gen_cfg, data_cfg):
    """Build the SPADE generator and, when styles are used, its encoder.

    Args:
        gen_cfg: Generator config. All fields are optional:
            ``num_filters`` (128), ``kernel_size`` (3),
            ``weight_norm_type`` ('spectral'), ``style_dims`` (None),
            ``attribute_dims``, ``skip_activation_norm`` (True),
            ``activation_norm_params``, ``global_adaptive_norm_type``
            ('sync_batch'), ``use_posenc_in_input_layer`` (True) and
            ``style_enc``.
        data_cfg: Data config. ``train.augmentations`` provides the crop
            size; image and label channel counts are looked up from it.
    """
    super(Generator, self).__init__()
    print('SPADE generator initialization.')

    def _fill_default(namespace, key, value):
        # Assign ``value`` only when the config does not define ``key``.
        if not hasattr(namespace, key):
            setattr(namespace, key, value)

    # We assume the first datum is the ground truth image.
    num_image_ch = get_paired_input_image_channel_number(data_cfg)
    # Calculate number of channels in the input label.
    num_label_ch = get_paired_input_label_channel_number(data_cfg)
    crop_h, crop_w = get_crop_h_w(data_cfg.train.augmentations)

    # Build the generator.
    small_side = crop_w if crop_w < crop_h else crop_h
    base_filters = getattr(gen_cfg, 'num_filters', 128)
    k_size = getattr(gen_cfg, 'kernel_size', 3)
    w_norm = getattr(gen_cfg, 'weight_norm_type', 'spectral')

    # Conditioning width is the sum of the style and attribute dims.
    cond_dims = 0
    # Check whether we use the style code.
    style_dims = getattr(gen_cfg, 'style_dims', None)
    self.style_dims = style_dims
    self.use_style = style_dims is not None
    if self.use_style:
        print('\tStyle code dimensions: %d' % style_dims)
        cond_dims += style_dims
    # Check whether we use the attribute code.
    self.use_attribute = hasattr(gen_cfg, 'attribute_dims')
    if self.use_attribute:
        self.attribute_dims = gen_cfg.attribute_dims
        cond_dims += gen_cfg.attribute_dims

    self.use_style_encoder = self.use_style or self.use_attribute

    print('\tBase filter number: %d' % base_filters)
    print('\tConvolution kernel size: %d' % k_size)
    print('\tWeight norm type: %s' % w_norm)

    skip_act_norm = getattr(gen_cfg, 'skip_activation_norm', True)

    # Fill in any missing activation norm parameters.
    norm_params = getattr(gen_cfg, 'activation_norm_params', None)
    if norm_params is None:
        norm_params = types.SimpleNamespace()
    _fill_default(norm_params, 'num_filters', 128)
    _fill_default(norm_params, 'kernel_size', 3)
    _fill_default(norm_params, 'activation_norm_type', 'sync_batch')
    _fill_default(norm_params, 'separate_projection', False)
    if not hasattr(norm_params, 'activation_norm_params'):
        norm_params.activation_norm_params = types.SimpleNamespace()
        norm_params.activation_norm_params.affine = True
    # ``cond_dims`` is always overwritten with the label channel count.
    setattr(norm_params, 'cond_dims', num_label_ch)
    _fill_default(norm_params, 'weight_norm_type', w_norm)

    global_norm = getattr(gen_cfg, 'global_adaptive_norm_type',
                          'sync_batch')
    use_posenc = getattr(gen_cfg, 'use_posenc_in_input_layer', True)
    print(norm_params)

    self.spade_generator = SPADEGenerator(
        num_label_ch,
        small_side,
        num_image_ch,
        base_filters,
        k_size,
        cond_dims,
        norm_params,
        w_norm,
        global_norm,
        skip_act_norm,
        use_posenc,
        self.use_style_encoder)

    if self.use_style:
        # Build the encoder.
        enc_cfg = getattr(gen_cfg, 'style_enc', None)
        if enc_cfg is None:
            enc_cfg = types.SimpleNamespace()
        _fill_default(enc_cfg, 'num_filters', 128)
        _fill_default(enc_cfg, 'kernel_size', 3)
        _fill_default(enc_cfg, 'weight_norm_type', w_norm)
        # These two are always taken from the generator setup.
        setattr(enc_cfg, 'input_image_channels', num_image_ch)
        setattr(enc_cfg, 'style_dims', style_dims)
        self.style_encoder = StyleEncoder(enc_cfg)

    self.z = None
    print('Done with the SPADE generator initialization.')