def init_temporal_network(self):
    """Build the modules used once training moves to multi-frame inputs.

    Creates (or aliases to the reference-frame versions) the flow network
    and the warped-previous-image embedding, then flags the temporal stage
    as initialized.
    """
    flow_cfg, data_cfg = self.flow_cfg, self.data_cfg
    emb_cfg, gen_cfg = self.emb_cfg, self.gen_cfg
    print("initialize temporal network")

    # A dedicated flow network for previous frames is required unless the
    # reference-flow network can be reused (2-frame G with warp_ref on).
    self.sep_prev_flownet = (flow_cfg.sep_prev_flow
                             or self.num_frames_G != 2
                             or not flow_cfg.warp_ref)
    if self.sep_prev_flownet:
        self.flow_network_temp = FlowGenerator(flow_cfg, data_cfg,
                                               self.num_frames_G)
    else:
        self.flow_network_temp = self.flow_network_ref

    # Same reuse decision for the embedding of the warped previous image.
    self.sep_prev_embedding = (emb_cfg.sep_warp_embed
                               or not flow_cfg.warp_ref)
    if self.sep_prev_embedding:
        num_img_channels = data_utils.get_paired_input_image_channel_number(
            self.data_cfg)
        # +1 input channel on top of the image channels — presumably the
        # flow-confidence mask; TODO(review) confirm against the caller.
        self.prev_image_embedding = LabelEmbedding(
            gen_cfg, emb_cfg, data_cfg, num_img_channels + 1)
    else:
        self.prev_image_embedding = self.ref_image_embedding

    # NOTE: 'initalized' [sic] — the misspelled attribute name is kept
    # unchanged for backward compatibility with external readers.
    self.flow_temp_is_initalized = True
    self.temporal_initialized = True
def __init__(self, dis_cfg, data_cfg):
    """Assemble the main image discriminator plus optional additional and
    temporal multi-patch discriminators.

    Args:
        dis_cfg: Discriminator configuration (image / temporal /
            additional_discriminators sub-configs).
        data_cfg: Data configuration used to derive channel counts.
    """
    super().__init__()
    self.data_cfg = data_cfg

    label_ch = get_paired_input_label_channel_number(data_cfg)
    if label_ch == 0:
        label_ch = getattr(data_cfg, 'label_channels', 1)
    image_ch = get_paired_input_image_channel_number(data_cfg)
    self.num_frames_D = data_cfg.num_frames_D
    self.num_scales = get_nested_attr(dis_cfg, "temporal.num_scales", 0)

    # Main discriminator input is label + image; in the few-shot setting
    # the channel count is doubled (a second label/image pair is fed in).
    self.use_few_shot = 'few_shot' in data_cfg.type
    main_in_ch = label_ch + image_ch
    if self.use_few_shot:
        main_in_ch *= 2
    self.net_D = MultiPatchDiscriminator(dis_cfg.image, main_in_ch)

    # Optional extra discriminators, one per named entry in the config.
    self.add_dis_cfg = getattr(dis_cfg, 'additional_discriminators', None)
    if self.add_dis_cfg is not None:
        for name in self.add_dis_cfg:
            sub_cfg = self.add_dis_cfg[name]
            sub_in_ch = image_ch * (2 if self.use_few_shot else 1)
            self.add_sublayer('net_D_' + name,
                              MultiPatchDiscriminator(sub_cfg, sub_in_ch))

    # Temporal discriminators see num_frames_D images stacked on channels,
    # one discriminator per temporal scale.
    self.num_netDT_input_channels = image_ch * self.num_frames_D
    for n in range(self.num_scales):
        self.add_sublayer(
            'net_DT%d' % n,
            MultiPatchDiscriminator(dis_cfg.temporal,
                                    self.num_netDT_input_channels))
    self.has_fg = getattr(data_cfg, 'has_foreground', False)
def __init__(self, dis_cfg, data_cfg):
    """Multi-resolution patch discriminator.

    Args:
        dis_cfg: Discriminator config holding kernel/filter/layer counts
            and normalization choices.
        data_cfg: Data config used to derive input channel counts.
    """
    super(Discriminator, self).__init__()
    print('Multi-resolution patch discriminator initialization.')
    # We assume the first datum is the ground truth image.
    image_channels = get_paired_input_image_channel_number(data_cfg)
    # Calculate number of channels in the input label.
    num_labels = get_paired_input_label_channel_number(data_cfg)

    # Build the discriminator.
    kernel_size = getattr(dis_cfg, 'kernel_size', 3)
    num_filters = getattr(dis_cfg, 'num_filters', 128)
    # BUG FIX: the attribute key was misspelled as 'max-num_filters'
    # (hyphen can never appear in an attribute name), so any configured
    # max_num_filters was silently ignored and the default 512 always won.
    max_num_filters = getattr(dis_cfg, 'max_num_filters', 512)
    num_discriminators = getattr(dis_cfg, 'num_discriminators', 2)
    num_layers = getattr(dis_cfg, 'num_layers', 5)
    activation_norm_type = getattr(dis_cfg, 'activation_norm_type', 'none')
    weight_norm_type = getattr(dis_cfg, 'weight_norm_type', 'spectral')
    print('\tBase filter number: %d' % num_filters)
    print('\tNumber of discriminators: %d' % num_discriminators)
    print('\tNumber of layers in a discriminator: %d' % num_layers)
    print('\tWeight norm type: %s' % weight_norm_type)

    # Discriminator input is the image concatenated with its label map.
    num_input_channels = image_channels + num_labels
    self.model = MultiResPatchDiscriminator(num_discriminators,
                                            kernel_size,
                                            num_input_channels,
                                            num_filters,
                                            num_layers,
                                            max_num_filters,
                                            activation_norm_type,
                                            weight_norm_type)
    print("Done with the Multi-resolution patch discriminator initialization.")
def __init__(self, gen_cfg, data_cfg):
    """Build the reference encoder: an image-feature pyramid (down+up conv
    pairs), an optional parallel label branch, and optional attention
    across reference frames.

    Args:
        gen_cfg: Generator config (num_filters, max_num_filters,
            num_downsamples, hyper sub-config, ...).
        data_cfg: Data config; `data_cfg.num_input_channels` is set here
            as a side effect for downstream modules.
    """
    super().__init__()
    hyper_cfg = gen_cfg.hyper
    base_filters = gen_cfg.num_filters
    self.max_num_filters = gen_cfg.max_num_filters
    self.num_downsamples = n_down = gen_cfg.num_downsamples
    # Channels at each pyramid level, doubling per level and capped.
    filters_per_layer = [min(self.max_num_filters, base_filters * (2 ** i))
                         for i in range(n_down + 2)]
    self.num_filters_each_layer = filters_per_layer

    kernel_size = getattr(gen_cfg.activation_norm_params, 'kernel_size', 1)
    activation_norm_type = getattr(hyper_cfg, 'activation_norm_type',
                                   'instance')
    weight_norm_type = getattr(hyper_cfg, 'weight_norm_type', 'spectral')

    # How reference labels are injected: concatenated with the reference
    # image and/or processed by a separate, multiplied-in label branch.
    self.concat_ref_label = 'concat' in hyper_cfg.method_to_use_ref_labels
    self.mul_ref_label = 'mul' in hyper_cfg.method_to_use_ref_labels

    label_ch = data_utils.get_paired_input_label_channel_number(data_cfg)
    if label_ch == 0:
        label_ch = getattr(data_cfg, 'label_channels', 1)
    elif misc_utils.get_nested_attr(data_cfg, 'for_pose_dataset.pose_type',
                                    'both') == 'open':
        # Openpose-only pose input uses 3 fewer label channels.
        label_ch -= 3
    data_cfg.num_input_channels = label_ch
    image_ch = data_utils.get_paired_input_image_channel_number(data_cfg)
    # Reference branch input: image, plus labels when concatenating.
    ref_in_ch = image_ch + (label_ch if self.concat_ref_label else 0)

    conv_2d_block = partial(Conv2dBlock,
                            kernel_size=kernel_size,
                            padding=(kernel_size // 2),
                            weight_norm_type=weight_norm_type,
                            activation_norm_type=activation_norm_type,
                            nonlinearity='leakyrelu')

    self.ref_img_first = conv_2d_block(ref_in_ch, base_filters)
    if self.mul_ref_label:
        self.ref_label_first = conv_2d_block(label_ch, base_filters)

    # One (down, up) conv pair per level for the image branch, and a
    # mirrored pair for the label branch when 'mul' is enabled.
    for level in range(n_down):
        in_ch = filters_per_layer[level]
        out_ch = filters_per_layer[level + 1]
        self.add_sublayer('ref_img_down_%d' % level,
                          conv_2d_block(in_ch, out_ch, stride=2))
        self.add_sublayer('ref_img_up_%d' % level,
                          conv_2d_block(out_ch, in_ch))
        if self.mul_ref_label:
            self.add_sublayer('ref_label_down_%d' % level,
                              conv_2d_block(in_ch, out_ch, stride=2))
            self.add_sublayer('ref_label_up_%d' % level,
                              conv_2d_block(out_ch, in_ch))

    if hasattr(hyper_cfg, 'attention'):
        self.num_downsample_atn = misc_utils.get_and_setattr(
            hyper_cfg.attention, 'num_downsamples', 2)
        # Attention over references is only built when K > 1 shots exist.
        if data_cfg.initial_few_shot_K > 1:
            self.attention_module = AttentionModule(
                hyper_cfg, data_cfg, conv_2d_block, filters_per_layer)
    else:
        self.num_downsample_atn = 0
def __init__(self, gen_cfg, data_cfg):
    """Top-level few-shot generator wrapper.

    Builds the main generator branch, the reference encoder, the weight
    generator, the label embedding, and (optionally) the reference-warp
    flow modules. Temporal modules are deferred to
    `init_temporal_network`.

    Args:
        gen_cfg: Generator config (flow, hyper, embed sub-configs).
        data_cfg: Data config.
    """
    super().__init__()
    self.gen_cfg = gen_cfg
    self.data_cfg = data_cfg
    self.flow_cfg = flow_cfg = gen_cfg.flow
    self.emb_cfg = emb_cfg = flow_cfg.multi_spade_combine.embed
    hyper_cfg = gen_cfg.hyper
    self.use_hyper_embed = hyper_cfg.is_hyper_embed
    self.num_frames_G = data_cfg.num_frames_G

    image_ch = data_utils.get_paired_input_image_channel_number(data_cfg)
    label_ch = data_utils.get_paired_input_label_channel_number(data_cfg)
    if label_ch == 0:
        label_ch = getattr(data_cfg, 'label_channels', 1)
    elif misc_utils.get_nested_attr(data_cfg, 'for_pose_dataset.pose_type',
                                    'both') == 'open':
        # Openpose-only pose input uses 3 fewer label channels.
        label_ch -= 3

    # Number of layers that use the multi-SPADE combination.
    self.num_multi_spade_layers = getattr(flow_cfg.multi_spade_combine,
                                          'num_layers', 3)
    # Whether to also produce raw output for additional losses.
    self.generate_raw_output = getattr(flow_cfg, 'generate_raw_output',
                                       False)

    # Pose-dataset specific options.
    self.is_pose_data = hasattr(data_cfg, 'for_pose_dataset')
    if self.is_pose_data:
        pose_cfg = data_cfg.for_pose_dataset
        self.pose_type = getattr(pose_cfg, 'pose_type', 'both')
        self.remove_face_labels = getattr(pose_cfg, 'remove_face_labels',
                                          False)

    # Sub-networks.
    self.main_generator = Generator(gen_cfg, data_cfg)
    self.reference_encoder = ReferenceEncoder(gen_cfg, data_cfg)
    self.weight_generator = WeightGenerator(gen_cfg, data_cfg)
    self.label_embedding = LabelEmbedding(gen_cfg, gen_cfg.embed, data_cfg,
                                          label_ch, num_hyper_layers=-1)

    # Optionally warp the reference image and combine it with the
    # synthesized frame (flow + embedding of the warped image).
    self.warp_ref = getattr(flow_cfg, 'warp_ref', True)
    if self.warp_ref:
        self.flow_network_ref = FlowGenerator(flow_cfg, data_cfg, 2)
        # +1 channel on top of the image — presumably the flow-confidence
        # mask; TODO(review) confirm against the forward pass.
        self.ref_image_embedding = LabelEmbedding(gen_cfg, emb_cfg,
                                                  data_cfg, image_ch + 1)

    # Training begins image-only; temporal modules are built lazily,
    # unless the config asks for them up front.
    self.temporal_initialized = False
    if getattr(gen_cfg, 'init_temporal', False):
        self.init_temporal_network()
def __init__(self, gen_cfg, data_cfg):
    """Main generator branch: a stack of (optionally hyper-weighted)
    SPADE residual up-blocks followed by a final image convolution.

    Args:
        gen_cfg: Generator config. NOTE: `gen_cfg.max_num_filters` and the
            hyper-layer count are written back into the config here as a
            side effect (via get_and_setattr / direct assignment).
        data_cfg: Data config.
    """
    super().__init__()
    self.gen_cfg = gen_cfg
    self.data_cfg = data_cfg
    self.num_frames_G = data_cfg.num_frames_G
    self.flow_cfg = flow_cfg = gen_cfg.flow
    image_ch = data_utils.get_paired_input_image_channel_number(data_cfg)

    self.num_downsamples = n_down = misc_utils.get_and_setattr(
        gen_cfg, 'num_downsamples', 5)
    conv_kernel_size = misc_utils.get_and_setattr(gen_cfg, 'kernel_size', 3)
    base_filters = misc_utils.get_and_setattr(gen_cfg, 'num_filters', 32)
    cfg_max_filters = misc_utils.get_and_setattr(gen_cfg, 'max_num_filters',
                                                 1024)
    # Cap the maximum width at the deepest level's natural width and
    # write the effective value back into the config.
    self.max_num_filters = gen_cfg.max_num_filters = min(
        cfg_max_filters, base_filters * (2 ** n_down))
    # Filter count for each layer of the main branch.
    filters_per_layer = [min(self.max_num_filters, base_filters * (2 ** i))
                         for i in range(n_down + 2)]

    # Hyper normalization / convolution settings.
    hyper_cfg = gen_cfg.hyper
    # Adaptive weight generation for SPADE layers.
    self.use_hyper_spade = hyper_cfg.is_hyper_spade
    # Adaptive weights for main-branch convolutions.
    self.use_hyper_conv = hyper_cfg.is_hyper_conv
    self.num_hyper_layers = getattr(hyper_cfg, 'num_hyper_layers', 4)
    if self.num_hyper_layers == -1:
        # -1 means "all layers".
        self.num_hyper_layers = n_down
    gen_cfg.hyper.num_hyper_layers = self.num_hyper_layers

    # Layers that combine several label maps via multi-SPADE.
    self.num_multi_spade_layers = getattr(flow_cfg.multi_spade_combine,
                                          'num_layers', 3)
    # Whether to also emit the raw output for additional losses.
    self.generate_raw_output = getattr(flow_cfg, 'generate_raw_output',
                                       False)

    # Main branch image generation.
    padding = conv_kernel_size // 2
    activation_norm_type = misc_utils.get_and_setattr(
        gen_cfg, 'activation_norm_type', 'sync_batch')
    weight_norm_type = misc_utils.get_and_setattr(
        gen_cfg, 'weight_norm_type', 'spectral')
    activation_norm_params = misc_utils.get_and_setattr(
        gen_cfg, 'activation_norm_params', None)

    # Conditioning channel sizes per SPADE module: the multi-SPADE layers
    # take three conditioning maps, the remaining layers just one.
    spade_in_channels = []
    for i in range(n_down + 1):
        if i >= self.num_multi_spade_layers:
            spade_in_channels.append([filters_per_layer[i]])
        else:
            spade_in_channels.append([filters_per_layer[i]] * 3)

    order = getattr(gen_cfg.hyper, 'hyper_block_order', 'NAC')
    # Build up-blocks from the deepest level to the shallowest.
    for i in reversed(range(n_down + 1)):
        activation_norm_params.cond_dims = spade_in_channels[i]
        hyper_conv = self.use_hyper_conv and i < self.num_hyper_layers
        hyper_norm = self.use_hyper_spade and i < self.num_hyper_layers
        self.add_sublayer(
            'up_%d' % i,
            HyperRes2dBlock(filters_per_layer[i + 1], filters_per_layer[i],
                            conv_kernel_size, padding=padding,
                            weight_norm_type=weight_norm_type,
                            activation_norm_type=activation_norm_type,
                            activation_norm_params=activation_norm_params,
                            order=order * 2,
                            is_hyper_conv=hyper_conv,
                            is_hyper_norm=hyper_norm))

    # Final convolution mapping features to image channels.
    self.conv_img = Conv2dBlock(base_filters, image_ch, conv_kernel_size,
                                padding=padding, nonlinearity='leakyrelu',
                                order='AC')
    self.upsample = partial(L.image_resize, scale=2)
def __init__(self, flow_cfg, data_cfg, num_frames):
    """Flow/mask estimation network: conv downsampling, residual trunk,
    conv upsampling, and separate flow / occlusion-mask heads.

    Args:
        flow_cfg: Flow network config (depth, widths, norm types, ...).
        data_cfg: Data config; provides `num_input_channels` (label
            channels) and the image channel count.
        num_frames: Number of frames fed to the network; the input packs
            `num_frames` label maps plus `num_frames - 1` previous images.
    """
    super().__init__()
    label_ch = data_cfg.num_input_channels
    if label_ch == 0:
        label_ch = 1
    prev_img_ch = data_utils.get_paired_input_image_channel_number(data_cfg)

    n_down = getattr(flow_cfg, 'num_downsamples', 3)
    kernel_size = getattr(flow_cfg, 'kernel_size', 3)
    padding = kernel_size // 2
    n_res_blocks = getattr(flow_cfg, 'num_blocks', 6)
    base_filters = getattr(flow_cfg, 'num_filters', 32)
    max_filters = getattr(flow_cfg, 'max_num_filters', 1024)
    # Width per level: doubles each downsample, capped at max_filters.
    filters = [min(max_filters, base_filters * (2 ** i))
               for i in range(n_down + 1)]

    # Raw flow output is scaled by this factor at inference.
    self.flow_output_multiplier = getattr(flow_cfg,
                                          'flow_output_multiplier', 20)
    # Whether the mask head gets its own (deep-copied) upsampling path.
    self.sep_up_mask = getattr(flow_cfg, 'sep_up_mask', False)
    activation_norm_type = getattr(flow_cfg, 'activation_norm_type',
                                   'sync_batch')
    weight_norm_type = getattr(flow_cfg, 'weight_norm_type', 'spectral')

    conv_block = partial(Conv2dBlock,
                         kernel_size=kernel_size,
                         padding=padding,
                         weight_norm_type=weight_norm_type,
                         activation_norm_type=activation_norm_type,
                         nonlinearity='leakyrelu')

    # Packed input: all label maps plus the previous images.
    total_in_ch = label_ch * num_frames + prev_img_ch * (num_frames - 1)

    # Encoder: first conv then strided downsampling convs.
    down_flow = [('0', conv_block(total_in_ch, base_filters))]
    for i in range(n_down):
        down_flow.append((str(i + 1),
                          conv_block(filters[i], filters[i + 1], stride=2)))

    # Residual trunk at the bottleneck width.
    bottleneck_ch = filters[n_down]
    res_flow = [(str(i),
                 Res2dBlock(bottleneck_ch, bottleneck_ch, kernel_size,
                            padding=padding,
                            weight_norm_type=weight_norm_type,
                            activation_norm_type=activation_norm_type,
                            order='NACNAC'))
                for i in range(n_res_blocks)]

    # Decoder: alternating upsample / conv pairs back to base width.
    up_flow = []
    for idx, level in enumerate(reversed(range(n_down))):
        up_flow.append((str(idx * 2), Upsample(scale=2)))
        up_flow.append((str(idx * 2 + 1),
                        conv_block(filters[level + 1], filters[level])))

    # Output heads: 2-channel flow and 1-channel sigmoid occlusion mask.
    conv_flow = [('0', Conv2dBlock(base_filters, 2, kernel_size,
                                   padding=padding))]
    conv_mask = [('0', Conv2dBlock(base_filters, 1, kernel_size,
                                   padding=padding,
                                   nonlinearity='sigmoid'))]

    self.down_flow = dg.Sequential(*down_flow)
    self.res_flow = dg.Sequential(*res_flow)
    self.up_flow = dg.Sequential(*up_flow)
    if self.sep_up_mask:
        # Independent decoder weights for the mask path.
        self.up_mask = dg.Sequential(*copy.deepcopy(up_flow))
    self.conv_flow = dg.Sequential(*conv_flow)
    self.conv_mask = dg.Sequential(*conv_mask)