def on_initialize(self):
    """Build the XSeg network and, when training, its graph and generators.

    Chooses the tensor data format ("NCHW" when exporting, or when at least
    one device is configured and debug mode is off; "NHWC" otherwise),
    initializes ``nn`` with it, creates ``self.model`` (an ``XSegNet`` at a
    fixed resolution of 256) and resets the iteration counter when
    ``self.pretrain_just_disabled`` is set.

    When ``self.is_training`` is true it additionally:

    * sets the batch size to ``device_count * max(1, batch // device_count)``
      so that every device receives an equally sized slice,
    * builds one loss/gradient tower per device and a single optimizer
      update op from the averaged gradients,
    * publishes ``self.train(input_np, target_np)``, which runs one
      optimizer step and returns the evaluated loss, and
      ``self.view(input_np)``, which returns ``[pred]``,
    * registers the training sample generators (one pretraining generator,
      or the labelled src/dst generator plus plain src and dst generators).
    """
    device_config = nn.getCurrentDeviceConfig()
    use_nchw = self.is_exporting or (
        len(device_config.devices) != 0 and not self.is_debug())
    self.model_data_format = "NCHW" if use_nchw else "NHWC"
    nn.initialize(data_format=self.model_data_format)
    tf = nn.tf

    # The device list is deliberately re-read after nn.initialize(), exactly
    # as the previous revision did.
    devices = nn.getCurrentDeviceConfig().devices

    self.resolution = resolution = 256

    face_types = {
        'h': FaceType.HALF,
        'mf': FaceType.MID_FULL,
        'f': FaceType.FULL,
        'wf': FaceType.WHOLE_FACE,
        'head': FaceType.HEAD,
    }
    self.face_type = face_types[self.options['face_type']]

    place_model_on_cpu = len(devices) == 0
    models_opt_device = (
        '/CPU:0' if place_model_on_cpu else nn.tf_default_device_name)

    # Initializing model classes
    self.model = XSegNet(
        name='XSeg',
        resolution=resolution,
        load_weights=not self.is_first_run(),
        weights_file_root=self.get_model_root_path(),
        training=True,
        place_model_on_cpu=place_model_on_cpu,
        optimizer=nn.RMSprop(lr=0.0001, lr_dropout=0.3, name='opt'),
        data_format=nn.data_format)

    self.pretrain = self.options['pretrain']
    if self.pretrain_just_disabled:
        self.set_iter(0)

    if self.is_training:
        # Adjust batch size for multiple GPU: every device gets the same
        # number of samples, and at least one.
        gpu_count = max(1, len(devices))
        bs_per_gpu = max(1, self.get_batch_size() // gpu_count)
        self.set_batch_size(gpu_count * bs_per_gpu)

        # Compute losses per GPU
        gpu_pred_list = []
        gpu_losses = []
        gpu_loss_gvs = []

        for gpu_id in range(gpu_count):
            if len(devices) != 0:
                tower_device = f'/{devices[gpu_id].tf_dev_type}:{gpu_id}'
            else:
                tower_device = '/CPU:0'

            with tf.device(tower_device):
                with tf.device('/CPU:0'):
                    # Slice on CPU, otherwise all batch data will be
                    # transferred to the GPU first.
                    batch_slice = slice(gpu_id * bs_per_gpu,
                                        (gpu_id + 1) * bs_per_gpu)
                    gpu_input_t = self.model.input_t[batch_slice, :, :, :]
                    gpu_target_t = self.model.target_t[batch_slice, :, :, :]

                # process model tensors
                gpu_pred_logits_t, gpu_pred_t = self.model.flow(
                    gpu_input_t, pretrain=self.pretrain)
                gpu_pred_list.append(gpu_pred_t)

                if self.pretrain:
                    # Structural loss at two filter sizes
                    gpu_loss = tf.reduce_mean(
                        5 * nn.dssim(gpu_target_t, gpu_pred_t, max_val=1.0,
                                     filter_size=int(resolution / 11.6)),
                        axis=[1])
                    gpu_loss += tf.reduce_mean(
                        5 * nn.dssim(gpu_target_t, gpu_pred_t, max_val=1.0,
                                     filter_size=int(resolution / 23.2)),
                        axis=[1])
                    # Pixel loss
                    gpu_loss += tf.reduce_mean(
                        10 * tf.square(gpu_target_t - gpu_pred_t),
                        axis=[1, 2, 3])
                else:
                    gpu_loss = tf.reduce_mean(
                        tf.nn.sigmoid_cross_entropy_with_logits(
                            labels=gpu_target_t, logits=gpu_pred_logits_t),
                        axis=[1, 2, 3])

                gpu_losses.append(gpu_loss)
                gpu_loss_gvs.append(
                    nn.gradients(gpu_loss, self.model.get_weights()))

        # Average losses and gradients, and create optimizer update ops.
        # NOTE: an earlier revision pinned this section to '/CPU:0' as a
        # temporary fix for a training freeze reported "from 2.4.0, but
        # 2.3.1 was ok"; that workaround is currently disabled.
        with tf.device(models_opt_device):
            pred = tf.concat(gpu_pred_list, 0)
            loss = tf.concat(gpu_losses, 0)
            loss_gv_op = self.model.opt.get_update_op(
                nn.average_gv_list(gpu_loss_gvs))

        # Initializing training and view functions. Pretraining and regular
        # training previously had two byte-identical copies of `train`; one
        # definition serves both.
        def train(input_np, target_np):
            loss_value, _ = nn.tf_sess.run(
                [loss, loss_gv_op],
                feed_dict={
                    self.model.input_t: input_np,
                    self.model.target_t: target_np,
                })
            return loss_value

        self.train = train

        def view(input_np):
            return nn.tf_sess.run(
                [pred], feed_dict={self.model.input_t: input_np})

        self.view = view

        # initializing sample generators
        cpu_count = min(multiprocessing.cpu_count(), 8)

        if self.pretrain:
            pretrain_gen = SampleGeneratorFace(
                self.get_pretraining_data_path(),
                debug=self.is_debug(),
                batch_size=self.get_batch_size(),
                sample_process_options=SampleProcessor.Options(
                    random_flip=True),
                output_sample_types=[
                    {
                        'sample_type': SampleProcessor.SampleType.FACE_IMAGE,
                        'warp': True,
                        'transform': True,
                        'channel_type': SampleProcessor.ChannelType.BGR,
                        'face_type': self.face_type,
                        'data_format': nn.data_format,
                        'resolution': resolution,
                    },
                    {
                        'sample_type': SampleProcessor.SampleType.FACE_IMAGE,
                        'warp': True,
                        'transform': True,
                        'channel_type': SampleProcessor.ChannelType.G,
                        'face_type': self.face_type,
                        'data_format': nn.data_format,
                        'resolution': resolution,
                    },
                ],
                uniform_yaw_distribution=False,
                generators_count=cpu_count)
            self.set_training_data_generators([pretrain_gen])
        else:
            # Each of the three generators gets half of the worker budget.
            workers_per_generator = cpu_count // 2

            def plain_bgr_sample_types():
                """Return a fresh un-warped, un-transformed BGR sample spec."""
                return [
                    {
                        'sample_type': SampleProcessor.SampleType.FACE_IMAGE,
                        'warp': False,
                        'transform': False,
                        'channel_type': SampleProcessor.ChannelType.BGR,
                        'border_replicate': False,
                        'face_type': self.face_type,
                        'data_format': nn.data_format,
                        'resolution': resolution,
                    },
                ]

            srcdst_generator = SampleGeneratorFaceXSeg(
                [self.training_data_src_path, self.training_data_dst_path],
                debug=self.is_debug(),
                batch_size=self.get_batch_size(),
                resolution=resolution,
                face_type=self.face_type,
                generators_count=workers_per_generator,
                data_format=nn.data_format)

            src_generator = SampleGeneratorFace(
                self.training_data_src_path,
                debug=self.is_debug(),
                batch_size=self.get_batch_size(),
                sample_process_options=SampleProcessor.Options(
                    random_flip=False),
                output_sample_types=plain_bgr_sample_types(),
                generators_count=workers_per_generator,
                raise_on_no_data=False)

            dst_generator = SampleGeneratorFace(
                self.training_data_dst_path,
                debug=self.is_debug(),
                batch_size=self.get_batch_size(),
                sample_process_options=SampleProcessor.Options(
                    random_flip=False),
                output_sample_types=plain_bgr_sample_types(),
                generators_count=workers_per_generator,
                raise_on_no_data=False)

            self.set_training_data_generators(
                [srcdst_generator, src_generator, dst_generator])
def on_initialize(self):
    """Build the face-swap model, its TF graph, weights and sample generators.

    Steps, in the order they run:

    1. Initialize ``nn`` with "NCHW" when at least one device is configured
       and debug mode is off, otherwise "NHWC".
    2. Read the model options (resolution, face type, architecture string,
       layer dimensions, pretrain/GAN/style settings).
    3. Create the input placeholders on the CPU and the 'df' or 'liae'
       sub-models (plus discriminators and optimizers when training) on
       ``models_opt_device``; every object that has weights is appended to
       ``self.model_filename_list`` together with its file name.
    4. When training: build one loss/gradient tower per device, average the
       gradients into optimizer update ops and publish
       ``self.src_dst_train``, ``self.AE_view`` and, depending on the
       options, ``self.D_train`` / ``self.D_src_dst_train``.
       Otherwise: build the inference graph and publish ``self.AE_merge``.
    5. Load or initialize the weights of everything in
       ``self.model_filename_list``.
    6. When training: register the src and dst sample generators and reset
       the per-sample loss histories.
    """
    device_config = nn.getCurrentDeviceConfig()
    devices = device_config.devices
    self.model_data_format = "NCHW" if len(devices) != 0 and not self.is_debug() else "NHWC"
    nn.initialize(data_format=self.model_data_format)
    tf = nn.tf

    self.resolution = resolution = self.options['resolution']
    self.face_type = {'h' : FaceType.HALF, 'mf' : FaceType.MID_FULL, 'f' : FaceType.FULL, 'wf' : FaceType.WHOLE_FACE, 'head' : FaceType.HEAD}[ self.options['face_type'] ]

    eyes_prio = self.options['eyes_prio']

    # The 'archi' option is "<type>" or "<type>-<opts>".
    # NOTE(review): a value with more than one '-' matches neither branch and
    # leaves archi_type/archi_opts unbound.
    archi_split = self.options['archi'].split('-')

    if len(archi_split) == 2:
        archi_type, archi_opts = archi_split
    elif len(archi_split) == 1:
        archi_type, archi_opts = archi_split[0], None

    ae_dims = self.options['ae_dims']
    e_dims = self.options['e_dims']
    d_dims = self.options['d_dims']
    d_mask_dims = self.options['d_mask_dims']
    self.pretrain = self.options['pretrain']
    if self.pretrain_just_disabled:
        self.set_iter(0)

    # GAN and random warp are forced off while pretraining.
    self.gan_power = gan_power = 0.0 if self.pretrain else self.options['gan_power']
    random_warp = False if self.pretrain else self.options['random_warp']

    if self.pretrain:
        # Values shown in place of the stored options while pretraining.
        self.options_show_override['gan_power'] = 0.0
        self.options_show_override['random_warp'] = False
        self.options_show_override['lr_dropout'] = 'n'
        self.options_show_override['face_style_power'] = 0.0
        self.options_show_override['bg_style_power'] = 0.0
        self.options_show_override['uniform_yaw'] = True

    # NOTE(review): the option value read here is immediately replaced by the
    # value from the external `dfl` config ("1" enables masked training).
    masked_training = self.options['masked_training']
    import dfl
    dfl.load_config()
    masked_training = dfl.get_config("masked_training", "1") == "1"
    ct_mode = self.options['ct_mode']
    if ct_mode == 'none':
        ct_mode = None

    # Models/optimizers live on GPU:0 only when a device exists, the option
    # asks for it and we are training; optimizer variables follow that choice.
    models_opt_on_gpu = False if len(devices) == 0 else self.options['models_opt_on_gpu']
    models_opt_device = '/GPU:0' if models_opt_on_gpu and self.is_training else '/CPU:0'
    optimizer_vars_on_cpu = models_opt_device=='/CPU:0'

    input_ch=3
    bgr_shape = nn.get4Dshape(resolution,resolution,input_ch)
    mask_shape = nn.get4Dshape(resolution,resolution,1)
    # Entries are [model_or_optimizer, weights_file_name] pairs.
    self.model_filename_list = []

    with tf.device ('/CPU:0'):
        # Placeholders on CPU
        self.warped_src = tf.placeholder (nn.floatx, bgr_shape)
        self.warped_dst = tf.placeholder (nn.floatx, bgr_shape)

        self.target_src = tf.placeholder (nn.floatx, bgr_shape)
        self.target_dst = tf.placeholder (nn.floatx, bgr_shape)

        self.target_srcm_all = tf.placeholder (nn.floatx, mask_shape)
        self.target_dstm_all = tf.placeholder (nn.floatx, mask_shape)

    # Initializing model classes
    model_archi = nn.DeepFakeArchi(resolution, opts=archi_opts)

    with tf.device (models_opt_device):
        if 'df' in archi_type:
            # 'df': one encoder, one inter, separate src and dst decoders.
            self.encoder = model_archi.Encoder(in_ch=input_ch, e_ch=e_dims, name='encoder')
            encoder_out_ch = self.encoder.compute_output_channels ( (nn.floatx, bgr_shape))

            self.inter = model_archi.Inter (in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims, name='inter')
            inter_out_ch = self.inter.compute_output_channels ( (nn.floatx, (None,encoder_out_ch)))

            self.decoder_src = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, name='decoder_src')
            self.decoder_dst = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, name='decoder_dst')

            self.model_filename_list += [ [self.encoder, 'encoder.npy' ], [self.inter, 'inter.npy' ], [self.decoder_src, 'decoder_src.npy'], [self.decoder_dst, 'decoder_dst.npy'] ]

            if self.is_training:
                if self.options['true_face_power'] != 0:
                    self.code_discriminator = nn.CodeDiscriminator(ae_dims, code_res=model_archi.Inter.get_code_res()*2, name='dis' )
                    self.model_filename_list += [ [self.code_discriminator, 'code_discriminator.npy'] ]

        elif 'liae' in archi_type:
            # 'liae': one encoder, two inters (AB and B), one shared decoder.
            self.encoder = model_archi.Encoder(in_ch=input_ch, e_ch=e_dims, name='encoder')
            encoder_out_ch = self.encoder.compute_output_channels ( (nn.floatx, bgr_shape))

            self.inter_AB = model_archi.Inter(in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims*2, name='inter_AB')
            self.inter_B = model_archi.Inter(in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims*2, name='inter_B')

            inter_AB_out_ch = self.inter_AB.compute_output_channels ( (nn.floatx, (None,encoder_out_ch)))
            inter_B_out_ch = self.inter_B.compute_output_channels ( (nn.floatx, (None,encoder_out_ch)))
            # The decoder consumes the two inter outputs concatenated.
            inters_out_ch = inter_AB_out_ch+inter_B_out_ch
            self.decoder = model_archi.Decoder(in_ch=inters_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, name='decoder')

            self.model_filename_list += [ [self.encoder, 'encoder.npy'], [self.inter_AB, 'inter_AB.npy'], [self.inter_B , 'inter_B.npy'], [self.decoder , 'decoder.npy'] ]

        if self.is_training:
            if gan_power != 0:
                self.D_src = nn.UNetPatchDiscriminator(patch_size=resolution//16, in_ch=input_ch, name="D_src")
                self.model_filename_list += [ [self.D_src, 'D_src_v2.npy'] ]

            # Initialize optimizers
            lr=5e-5
            # lr_dropout of 1.0 is used when the option is off or while pretraining.
            lr_dropout = 0.3 if self.options['lr_dropout'] in ['y','cpu'] and not self.pretrain else 1.0
            clipnorm = 1.0 if self.options['clipgrad'] else 0.0

            if 'df' in archi_type:
                self.src_dst_trainable_weights = self.encoder.get_weights() + self.inter.get_weights() + self.decoder_src.get_weights() + self.decoder_dst.get_weights()
            elif 'liae' in archi_type:
                self.src_dst_trainable_weights = self.encoder.get_weights() + self.inter_AB.get_weights() + self.inter_B.get_weights() + self.decoder.get_weights()

            self.src_dst_opt = nn.RMSprop(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='src_dst_opt')
            self.src_dst_opt.initialize_variables (self.src_dst_trainable_weights, vars_on_cpu=optimizer_vars_on_cpu, lr_dropout_on_cpu=self.options['lr_dropout']=='cpu')
            self.model_filename_list += [ (self.src_dst_opt, 'src_dst_opt.npy') ]

            # NOTE(review): self.code_discriminator is only created in the
            # 'df' branch above; a non-zero true_face_power with a 'liae'
            # architecture would fail here.
            if self.options['true_face_power'] != 0:
                self.D_code_opt = nn.RMSprop(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='D_code_opt')
                self.D_code_opt.initialize_variables ( self.code_discriminator.get_weights(), vars_on_cpu=optimizer_vars_on_cpu, lr_dropout_on_cpu=self.options['lr_dropout']=='cpu')
                self.model_filename_list += [ (self.D_code_opt, 'D_code_opt.npy') ]

            if gan_power != 0:
                self.D_src_dst_opt = nn.RMSprop(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='D_src_dst_opt')
                self.D_src_dst_opt.initialize_variables ( self.D_src.get_weights(), vars_on_cpu=optimizer_vars_on_cpu, lr_dropout_on_cpu=self.options['lr_dropout']=='cpu')#+self.D_src_x2.get_weights()
                self.model_filename_list += [ (self.D_src_dst_opt, 'D_src_v2_opt.npy') ]

    if self.is_training:
        # Adjust batch size for multiple GPU: every device gets the same
        # number of samples, and at least one.
        gpu_count = max(1, len(devices) )
        bs_per_gpu = max(1, self.get_batch_size() // gpu_count)
        self.set_batch_size( gpu_count*bs_per_gpu)

        # Compute losses per GPU
        gpu_pred_src_src_list = []
        gpu_pred_dst_dst_list = []
        gpu_pred_src_dst_list = []
        gpu_pred_src_srcm_list = []
        gpu_pred_dst_dstm_list = []
        gpu_pred_src_dstm_list = []

        gpu_src_losses = []
        gpu_dst_losses = []
        gpu_G_loss_gvs = []
        gpu_D_code_loss_gvs = []
        gpu_D_src_dst_loss_gvs = []

        for gpu_id in range(gpu_count):
            with tf.device( f'/GPU:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ):
                with tf.device(f'/CPU:0'):
                    # slice on CPU, otherwise all batch data will be transferred to GPU first
                    batch_slice = slice( gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu )
                    gpu_warped_src = self.warped_src [batch_slice,:,:,:]
                    gpu_warped_dst = self.warped_dst [batch_slice,:,:,:]
                    gpu_target_src = self.target_src [batch_slice,:,:,:]
                    gpu_target_dst = self.target_dst [batch_slice,:,:,:]
                    gpu_target_srcm_all = self.target_srcm_all[batch_slice,:,:,:]
                    gpu_target_dstm_all = self.target_dstm_all[batch_slice,:,:,:]

                # process model tensors
                if 'df' in archi_type:
                    gpu_src_code = self.inter(self.encoder(gpu_warped_src))
                    gpu_dst_code = self.inter(self.encoder(gpu_warped_dst))
                    gpu_pred_src_src, gpu_pred_src_srcm = self.decoder_src(gpu_src_code)
                    gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)
                    # The swap: dst code decoded by the src decoder.
                    gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code)

                elif 'liae' in archi_type:
                    # src code is the inter_AB output concatenated with itself;
                    # dst code is inter_B followed by inter_AB.
                    gpu_src_code = self.encoder (gpu_warped_src)
                    gpu_src_inter_AB_code = self.inter_AB (gpu_src_code)
                    gpu_src_code = tf.concat([gpu_src_inter_AB_code,gpu_src_inter_AB_code], nn.conv2d_ch_axis )
                    gpu_dst_code = self.encoder (gpu_warped_dst)
                    gpu_dst_inter_B_code = self.inter_B (gpu_dst_code)
                    gpu_dst_inter_AB_code = self.inter_AB (gpu_dst_code)
                    gpu_dst_code = tf.concat([gpu_dst_inter_B_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis )
                    # The swap: dst's inter_AB code concatenated with itself.
                    gpu_src_dst_code = tf.concat([gpu_dst_inter_AB_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis )

                    gpu_pred_src_src, gpu_pred_src_srcm = self.decoder(gpu_src_code)
                    gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder(gpu_dst_code)
                    gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)

                gpu_pred_src_src_list.append(gpu_pred_src_src)
                gpu_pred_dst_dst_list.append(gpu_pred_dst_dst)
                gpu_pred_src_dst_list.append(gpu_pred_src_dst)

                gpu_pred_src_srcm_list.append(gpu_pred_src_srcm)
                gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm)
                gpu_pred_src_dstm_list.append(gpu_pred_src_dstm)

                # unpack masks from one combined mask
                # presumably the combined mask holds 1 for face and 2 for eyes
                # (FULL_FACE_EYES) - TODO confirm against the sample processor
                gpu_target_srcm = tf.clip_by_value (gpu_target_srcm_all, 0, 1)
                gpu_target_dstm = tf.clip_by_value (gpu_target_dstm_all, 0, 1)
                gpu_target_srcm_eyes = tf.clip_by_value (gpu_target_srcm_all-1, 0, 1)
                gpu_target_dstm_eyes = tf.clip_by_value (gpu_target_dstm_all-1, 0, 1)

                gpu_target_srcm_blur = nn.gaussian_blur(gpu_target_srcm, max(1, resolution // 32) )
                gpu_target_srcm_blur = tf.clip_by_value(gpu_target_srcm_blur, 0, 0.5) * 2

                gpu_target_dstm_blur = nn.gaussian_blur(gpu_target_dstm, max(1, resolution // 32) )
                gpu_target_dstm_style_blur = gpu_target_dstm_blur #default style mask is 0.5 on boundary
                gpu_target_dstm_blur = tf.clip_by_value(gpu_target_dstm_blur, 0, 0.5) * 2

                gpu_target_dst_masked = gpu_target_dst*gpu_target_dstm_blur
                gpu_target_dst_style_masked = gpu_target_dst*gpu_target_dstm_style_blur
                gpu_target_dst_style_anti_masked = gpu_target_dst*(1.0 - gpu_target_dstm_style_blur)

                # "_opt" tensors are masked only when masked_training is on.
                gpu_target_src_masked_opt = gpu_target_src*gpu_target_srcm_blur if masked_training else gpu_target_src
                gpu_target_dst_masked_opt = gpu_target_dst_masked if masked_training else gpu_target_dst

                gpu_pred_src_src_masked_opt = gpu_pred_src_src*gpu_target_srcm_blur if masked_training else gpu_pred_src_src
                gpu_pred_dst_dst_masked_opt = gpu_pred_dst_dst*gpu_target_dstm_blur if masked_training else gpu_pred_dst_dst

                gpu_psd_target_dst_style_masked = gpu_pred_src_dst*gpu_target_dstm_style_blur
                gpu_psd_target_dst_style_anti_masked = gpu_pred_src_dst*(1.0 - gpu_target_dstm_style_blur)

                # src loss: DSSIM (one filter size below 256 px, two at or
                # above), pixel MSE, optional eyes term, mask MSE, optional
                # style terms.
                if resolution < 256:
                    gpu_src_loss = tf.reduce_mean ( 10*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
                else:
                    gpu_src_loss = tf.reduce_mean ( 5*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
                    gpu_src_loss += tf.reduce_mean ( 5*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/23.2)), axis=[1])
                gpu_src_loss += tf.reduce_mean ( 10*tf.square ( gpu_target_src_masked_opt - gpu_pred_src_src_masked_opt ), axis=[1,2,3])

                if eyes_prio:
                    gpu_src_loss += tf.reduce_mean ( 300*tf.abs ( gpu_target_src*gpu_target_srcm_eyes - gpu_pred_src_src*gpu_target_srcm_eyes ), axis=[1,2,3])

                gpu_src_loss += tf.reduce_mean ( 10*tf.square( gpu_target_srcm - gpu_pred_src_srcm ),axis=[1,2,3] )

                # Style powers are stored as percentages; style terms are
                # skipped while pretraining.
                face_style_power = self.options['face_style_power'] / 100.0
                if face_style_power != 0 and not self.pretrain:
                    gpu_src_loss += nn.style_loss(gpu_psd_target_dst_style_masked, gpu_target_dst_style_masked, gaussian_blur_radius=resolution//16, loss_weight=10000*face_style_power)

                bg_style_power = self.options['bg_style_power'] / 100.0
                if bg_style_power != 0 and not self.pretrain:
                    gpu_src_loss += tf.reduce_mean( (10*bg_style_power)*nn.dssim( gpu_psd_target_dst_style_anti_masked, gpu_target_dst_style_anti_masked, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
                    gpu_src_loss += tf.reduce_mean( (10*bg_style_power)*tf.square(gpu_psd_target_dst_style_anti_masked - gpu_target_dst_style_anti_masked), axis=[1,2,3] )

                # dst loss: same DSSIM / pixel / eyes / mask terms as src,
                # without the style terms.
                if resolution < 256:
                    gpu_dst_loss = tf.reduce_mean ( 10*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1])
                else:
                    gpu_dst_loss = tf.reduce_mean ( 5*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1])
                    gpu_dst_loss += tf.reduce_mean ( 5*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/23.2) ), axis=[1])
                gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dst_masked_opt- gpu_pred_dst_dst_masked_opt ), axis=[1,2,3])

                if eyes_prio:
                    gpu_dst_loss += tf.reduce_mean ( 300*tf.abs ( gpu_target_dst*gpu_target_dstm_eyes - gpu_pred_dst_dst*gpu_target_dstm_eyes ), axis=[1,2,3])

                gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dstm - gpu_pred_dst_dstm ),axis=[1,2,3] )

                gpu_src_losses += [gpu_src_loss]
                gpu_dst_losses += [gpu_dst_loss]

                # Generator loss starts as src + dst reconstruction loss;
                # adversarial terms are added below when enabled.
                gpu_G_loss = gpu_src_loss + gpu_dst_loss

                def DLoss(labels,logits):
                    # Sigmoid cross-entropy averaged over axes 1..3.
                    return tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits), axis=[1,2,3])

                if self.options['true_face_power'] != 0:
                    # Code discriminator: dst codes are labelled 1, src codes 0;
                    # the generator is rewarded when src codes are scored as 1.
                    gpu_src_code_d = self.code_discriminator( gpu_src_code )
                    gpu_src_code_d_ones = tf.ones_like (gpu_src_code_d)
                    gpu_src_code_d_zeros = tf.zeros_like(gpu_src_code_d)
                    gpu_dst_code_d = self.code_discriminator( gpu_dst_code )
                    # NOTE(review): gpu_dst_code_d_ones is not used below.
                    gpu_dst_code_d_ones = tf.ones_like(gpu_dst_code_d)

                    gpu_G_loss += self.options['true_face_power']*DLoss(gpu_src_code_d_ones, gpu_src_code_d)

                    gpu_D_code_loss = (DLoss(gpu_src_code_d_ones , gpu_dst_code_d) + \
                                       DLoss(gpu_src_code_d_zeros, gpu_src_code_d) ) * 0.5

                    gpu_D_code_loss_gvs += [ nn.gradients (gpu_D_code_loss, self.code_discriminator.get_weights() ) ]

                if gan_power != 0:
                    # D_src returns two outputs per input; target images are
                    # labelled 1 and predictions 0 for the discriminator.
                    gpu_pred_src_src_d, \
                    gpu_pred_src_src_d2 = self.D_src(gpu_pred_src_src_masked_opt)

                    gpu_pred_src_src_d_ones = tf.ones_like (gpu_pred_src_src_d)
                    gpu_pred_src_src_d_zeros = tf.zeros_like(gpu_pred_src_src_d)

                    gpu_pred_src_src_d2_ones = tf.ones_like (gpu_pred_src_src_d2)
                    gpu_pred_src_src_d2_zeros = tf.zeros_like(gpu_pred_src_src_d2)

                    gpu_target_src_d, \
                    gpu_target_src_d2 = self.D_src(gpu_target_src_masked_opt)

                    gpu_target_src_d_ones = tf.ones_like(gpu_target_src_d)
                    gpu_target_src_d2_ones = tf.ones_like(gpu_target_src_d2)

                    gpu_D_src_dst_loss = (DLoss(gpu_target_src_d_ones , gpu_target_src_d) + \
                                          DLoss(gpu_pred_src_src_d_zeros , gpu_pred_src_src_d) ) * 0.5 + \
                                         (DLoss(gpu_target_src_d2_ones , gpu_target_src_d2) + \
                                          DLoss(gpu_pred_src_src_d2_zeros , gpu_pred_src_src_d2) ) * 0.5

                    gpu_D_src_dst_loss_gvs += [ nn.gradients (gpu_D_src_dst_loss, self.D_src.get_weights() ) ]#+self.D_src_x2.get_weights()

                    # Generator term: predictions should be scored as 1.
                    gpu_G_loss += gan_power*(DLoss(gpu_pred_src_src_d_ones, gpu_pred_src_src_d) + \
                                             DLoss(gpu_pred_src_src_d2_ones, gpu_pred_src_src_d2))

                gpu_G_loss_gvs += [ nn.gradients ( gpu_G_loss, self.src_dst_trainable_weights ) ]

        # Average losses and gradients, and create optimizer update ops
        with tf.device (models_opt_device):
            pred_src_src = nn.concat(gpu_pred_src_src_list, 0)
            pred_dst_dst = nn.concat(gpu_pred_dst_dst_list, 0)
            pred_src_dst = nn.concat(gpu_pred_src_dst_list, 0)
            pred_src_srcm = nn.concat(gpu_pred_src_srcm_list, 0)
            pred_dst_dstm = nn.concat(gpu_pred_dst_dstm_list, 0)
            pred_src_dstm = nn.concat(gpu_pred_src_dstm_list, 0)

            src_loss = tf.concat(gpu_src_losses, 0)
            dst_loss = tf.concat(gpu_dst_losses, 0)
            src_dst_loss_gv_op = self.src_dst_opt.get_update_op (nn.average_gv_list (gpu_G_loss_gvs))

            if self.options['true_face_power'] != 0:
                D_loss_gv_op = self.D_code_opt.get_update_op (nn.average_gv_list(gpu_D_code_loss_gvs))

            if gan_power != 0:
                src_D_src_dst_loss_gv_op = self.D_src_dst_opt.get_update_op (nn.average_gv_list(gpu_D_src_dst_loss_gvs) )

        # Initializing training and view functions
        def src_dst_train(warped_src, target_src, target_srcm_all, \
                          warped_dst, target_dst, target_dstm_all):
            # Runs one generator update and returns the (src, dst) losses.
            s, d, _ = nn.tf_sess.run ( [ src_loss, dst_loss, src_dst_loss_gv_op],
                                       feed_dict={self.warped_src :warped_src,
                                                  self.target_src :target_src,
                                                  self.target_srcm_all:target_srcm_all,
                                                  self.warped_dst :warped_dst,
                                                  self.target_dst :target_dst,
                                                  self.target_dstm_all:target_dstm_all,
                                                  })
            return s, d
        self.src_dst_train = src_dst_train

        if self.options['true_face_power'] != 0:
            def D_train(warped_src, warped_dst):
                # Runs one code-discriminator update.
                nn.tf_sess.run ([D_loss_gv_op], feed_dict={self.warped_src: warped_src, self.warped_dst: warped_dst})
            self.D_train = D_train

        if gan_power != 0:
            def D_src_dst_train(warped_src, target_src, target_srcm_all, \
                                warped_dst, target_dst, target_dstm_all):
                # Runs one D_src update.
                nn.tf_sess.run ([src_D_src_dst_loss_gv_op], feed_dict={self.warped_src :warped_src,
                                                                       self.target_src :target_src,
                                                                       self.target_srcm_all:target_srcm_all,
                                                                       self.warped_dst :warped_dst,
                                                                       self.target_dst :target_dst,
                                                                       self.target_dstm_all:target_dstm_all})
            self.D_src_dst_train = D_src_dst_train

        def AE_view(warped_src, warped_dst):
            # Returns the predictions used for previews, in the listed order.
            return nn.tf_sess.run ( [pred_src_src, pred_dst_dst, pred_dst_dstm, pred_src_dst, pred_src_dstm],
                                    feed_dict={self.warped_src:warped_src,
                                               self.warped_dst:warped_dst})
        self.AE_view = AE_view
    else:
        # Initializing merge function
        with tf.device( f'/GPU:0' if len(devices) != 0 else f'/CPU:0'):
            if 'df' in archi_type:
                gpu_dst_code = self.inter(self.encoder(self.warped_dst))
                gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code)
                # Only the mask output of the dst decoder is needed.
                _, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)

            elif 'liae' in archi_type:
                gpu_dst_code = self.encoder (self.warped_dst)
                gpu_dst_inter_B_code = self.inter_B (gpu_dst_code)
                gpu_dst_inter_AB_code = self.inter_AB (gpu_dst_code)
                gpu_dst_code = tf.concat([gpu_dst_inter_B_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis)
                gpu_src_dst_code = tf.concat([gpu_dst_inter_AB_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis)

                gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)
                # Only the mask output for the dst code is needed.
                _, gpu_pred_dst_dstm = self.decoder(gpu_dst_code)

        def AE_merge( warped_dst):
            # Returns [swapped face, dst mask, swapped-face mask].
            return nn.tf_sess.run ( [gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm], feed_dict={self.warped_dst:warped_dst})

        self.AE_merge = AE_merge

    # Loading/initializing all models/optimizers weights
    for model, filename in io.progress_bar_generator(self.model_filename_list, "Initializing models"):
        if self.pretrain_just_disabled:
            # Leaving pretraining: only the inter model(s) are re-initialized
            # unconditionally; everything else is loaded from disk if possible.
            do_init = False
            if 'df' in archi_type:
                if model == self.inter:
                    do_init = True
            elif 'liae' in archi_type:
                if model == self.inter_AB or model == self.inter_B:
                    do_init = True
        else:
            do_init = self.is_first_run()

        if not do_init:
            # Fall back to fresh initialization when loading does not succeed.
            do_init = not model.load_weights( self.get_strpath_storage_for_file(filename) )

        if do_init:
            model.init_weights()

    # initializing sample generators
    if self.is_training:
        # While pretraining both generators read from the pretraining data path.
        training_data_src_path = self.training_data_src_path if not self.pretrain else self.get_pretraining_data_path()
        training_data_dst_path = self.training_data_dst_path if not self.pretrain else self.get_pretraining_data_path()

        # Color-transfer samples come from the dst set, and only outside pretraining.
        random_ct_samples_path=training_data_dst_path if ct_mode is not None and not self.pretrain else None

        cpu_count = min(multiprocessing.cpu_count(), 8)
        src_generators_count = cpu_count // 2
        dst_generators_count = cpu_count // 2
        if ct_mode is not None:
            # The src side gets 1.5x the generators when color transfer is on.
            src_generators_count = int(src_generators_count * 1.5)

        # Each generator yields, in order: (optionally warped) image,
        # un-warped image, FULL_FACE_EYES mask.
        self.set_training_data_generators ([
                SampleGeneratorFace(training_data_src_path, random_ct_samples_path=random_ct_samples_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
                    sample_process_options=SampleProcessor.Options(random_flip=self.random_flip),
                    output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                            {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                            {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE_EYES, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                          ],
                    uniform_yaw_distribution=self.options['uniform_yaw'] or self.pretrain,
                    generators_count=src_generators_count ),

                SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
                    sample_process_options=SampleProcessor.Options(random_flip=self.random_flip),
                    output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                            {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                            {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE_EYES, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                          ],
                    uniform_yaw_distribution=self.options['uniform_yaw'] or self.pretrain,
                    generators_count=dst_generators_count )
                 ])

        self.last_src_samples_loss = []
        self.last_dst_samples_loss = []

        if self.pretrain_just_disabled:
            self.update_sample_for_preview(force_new=True)
def on_initialize(self):
    """Build the TF graph, load/initialise weights and set up sample generators.

    Steps, in order:

    1. Initialise ``nn`` with the NCHW data format and read the user options.
    2. Define the network building blocks as local classes (they close over
       ``tf`` and ``lowest_dense_res``).
    3. Create the input placeholders on the CPU and the encoder / two inter
       models / decoder (plus the GAN discriminator and the optimizers when
       training) on ``models_opt_device``.
    4. When training: build per-device losses and gradients and expose
       ``self.src_dst_train``, ``self.D_src_dst_train`` (only if
       ``gan_power != 0``) and ``self.AE_view``.
       Otherwise: build the merge graph and expose ``self.AE_merge``.
    5. Load every model/optimizer listed in ``self.model_filename_list`` from
       storage, falling back to ``init_weights()``.
    6. When training: register the src and dst sample generators.
    """
    device_config = nn.getCurrentDeviceConfig()
    devices = device_config.devices
    self.model_data_format = "NCHW"
    nn.initialize(data_format=self.model_data_format)
    tf = nn.tf

    self.resolution = resolution = self.options['resolution']
    # The encoder halves the spatial size five times (stride-2 convs), so the
    # feature map feeding the dense layer is resolution / 2**5 on each side.
    lowest_dense_res = self.lowest_dense_res = resolution // 32

    class Downscale(nn.ModelBase):
        """Stride-2 convolution followed by leaky ReLU (slope 0.1)."""

        def __init__(self, in_ch, out_ch, kernel_size=5, **kwargs):
            # Extra keyword arguments are forwarded to nn.ModelBase.__init__,
            # the same way the Inter class below does it.
            self.in_ch = in_ch
            self.out_ch = out_ch
            self.kernel_size = kernel_size
            super().__init__(**kwargs)

        def on_build(self, *args, **kwargs):
            self.conv1 = nn.Conv2D(self.in_ch, self.out_ch, kernel_size=self.kernel_size, strides=2, padding='SAME')

        def forward(self, x):
            x = self.conv1(x)
            x = tf.nn.leaky_relu(x, 0.1)
            return x

        def get_out_ch(self):
            return self.out_ch

    class Upscale(nn.ModelBase):
        """Conv to ``out_ch*4`` channels, leaky ReLU, then depth_to_space by 2."""

        def on_build(self, in_ch, out_ch, kernel_size=3):
            self.conv1 = nn.Conv2D(in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME')

        def forward(self, x):
            x = self.conv1(x)
            x = tf.nn.leaky_relu(x, 0.1)
            x = nn.depth_to_space(x, 2)
            return x

    class ResidualBlock(nn.ModelBase):
        """Two same-width convs with a skip connection added before the last activation."""

        def on_build(self, ch, kernel_size=3):
            self.conv1 = nn.Conv2D(ch, ch, kernel_size=kernel_size, padding='SAME')
            self.conv2 = nn.Conv2D(ch, ch, kernel_size=kernel_size, padding='SAME')

        def forward(self, inp):
            x = self.conv1(inp)
            x = tf.nn.leaky_relu(x, 0.2)
            x = self.conv2(x)
            x = tf.nn.leaky_relu(inp+x, 0.2)
            return x

    class Encoder(nn.ModelBase):
        """Five Downscale stages (two with residual blocks), flatten, pixel-norm, dense to ``ae_ch``."""

        def on_build(self, in_ch, e_ch, ae_ch):
            self.down1 = Downscale(in_ch, e_ch, kernel_size=5)
            self.res1 = ResidualBlock(e_ch)
            self.down2 = Downscale(e_ch, e_ch*2, kernel_size=5)
            self.down3 = Downscale(e_ch*2, e_ch*4, kernel_size=5)
            self.down4 = Downscale(e_ch*4, e_ch*8, kernel_size=5)
            self.down5 = Downscale(e_ch*8, e_ch*8, kernel_size=5)
            self.res5 = ResidualBlock(e_ch*8)
            self.dense1 = nn.Dense(lowest_dense_res*lowest_dense_res*e_ch*8, ae_ch)

        def forward(self, inp):
            x = inp
            x = self.down1(x)
            x = self.res1(x)
            x = self.down2(x)
            x = self.down3(x)
            x = self.down4(x)
            x = self.down5(x)
            x = self.res5(x)
            x = nn.flatten(x)
            x = nn.pixel_norm(x, axes=-1)
            x = self.dense1(x)
            return x

    class Inter(nn.ModelBase):
        """Dense layer mapping the code to a ``lowest_dense_res`` x ``lowest_dense_res`` map with ``ae_out_ch`` channels."""

        def __init__(self, ae_ch, ae_out_ch, **kwargs):
            self.ae_ch, self.ae_out_ch = ae_ch, ae_out_ch
            super().__init__(**kwargs)

        def on_build(self):
            ae_ch, ae_out_ch = self.ae_ch, self.ae_out_ch
            self.dense2 = nn.Dense(ae_ch, lowest_dense_res * lowest_dense_res * ae_out_ch)

        def forward(self, inp):
            x = inp
            x = self.dense2(x)
            x = nn.reshape_4D(x, lowest_dense_res, lowest_dense_res, self.ae_out_ch)
            return x

        def get_out_ch(self):
            return self.ae_out_ch

    class Decoder(nn.ModelBase):
        """Image branch and mask branch sharing the same input code.

        ``forward`` returns ``(x, m)``: a 3-channel image and a 1-channel
        mask, both passed through a sigmoid.
        """

        def on_build(self, in_ch, d_ch, d_mask_ch):
            self.upscale0 = Upscale(in_ch, d_ch*8, kernel_size=3)
            self.upscale1 = Upscale(d_ch*8, d_ch*8, kernel_size=3)
            self.upscale2 = Upscale(d_ch*8, d_ch*4, kernel_size=3)
            self.upscale3 = Upscale(d_ch*4, d_ch*2, kernel_size=3)

            self.res0 = ResidualBlock(d_ch*8, kernel_size=3)
            self.res1 = ResidualBlock(d_ch*8, kernel_size=3)
            self.res2 = ResidualBlock(d_ch*4, kernel_size=3)
            self.res3 = ResidualBlock(d_ch*2, kernel_size=3)

            self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3)
            self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*8, kernel_size=3)
            self.upscalem2 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3)
            self.upscalem3 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3)
            self.upscalem4 = Upscale(d_mask_ch*2, d_mask_ch*1, kernel_size=3)
            self.out_convm = nn.Conv2D(d_mask_ch*1, 1, kernel_size=1, padding='SAME')

            self.out_conv = nn.Conv2D(d_ch*2, 3, kernel_size=1, padding='SAME')
            self.out_conv1 = nn.Conv2D(d_ch*2, 3, kernel_size=3, padding='SAME')
            self.out_conv2 = nn.Conv2D(d_ch*2, 3, kernel_size=3, padding='SAME')
            self.out_conv3 = nn.Conv2D(d_ch*2, 3, kernel_size=3, padding='SAME')

        def forward(self, inp):
            z = inp

            x = self.upscale0(z)
            x = self.res0(x)
            x = self.upscale1(x)
            x = self.res1(x)
            x = self.upscale2(x)
            x = self.res2(x)
            x = self.upscale3(x)
            x = self.res3(x)

            # Four 3-channel heads are concatenated on the channel axis and
            # folded into space (factor 2), giving the final 2x upscale.
            x = tf.nn.sigmoid(nn.depth_to_space(tf.concat((self.out_conv(x),
                                                           self.out_conv1(x),
                                                           self.out_conv2(x),
                                                           self.out_conv3(x)), nn.conv2d_ch_axis), 2))

            m = self.upscalem0(z)
            m = self.upscalem1(m)
            m = self.upscalem2(m)
            m = self.upscalem3(m)
            m = self.upscalem4(m)
            m = tf.nn.sigmoid(self.out_convm(m))
            return x, m

    self.face_type = {'wf': FaceType.WHOLE_FACE,
                      'head': FaceType.HEAD}[self.options['face_type']]

    # 'eyes_prio' is dropped from the options if present; only
    # 'eyes_mouth_prio' is read below.
    if 'eyes_prio' in self.options:
        self.options.pop('eyes_prio')

    eyes_mouth_prio = self.options['eyes_mouth_prio']

    ae_dims = self.ae_dims = self.options['ae_dims']
    e_dims = self.options['e_dims']
    d_dims = self.options['d_dims']
    d_mask_dims = self.options['d_mask_dims']
    morph_factor = self.options['morph_factor']

    pretrain = self.pretrain = self.options['pretrain']
    if self.pretrain_just_disabled:
        self.set_iter(0)

    # Pretraining overrides several options: no GAN, no random warp, flips
    # always on.
    self.gan_power = gan_power = 0.0 if self.pretrain else self.options['gan_power']
    random_warp = False if self.pretrain else self.options['random_warp']
    random_src_flip = self.random_src_flip if not self.pretrain else True
    random_dst_flip = self.random_dst_flip if not self.pretrain else True

    if self.pretrain:
        self.options_show_override['gan_power'] = 0.0
        self.options_show_override['random_warp'] = False
        self.options_show_override['lr_dropout'] = 'n'
        self.options_show_override['uniform_yaw'] = True

    masked_training = self.options['masked_training']
    ct_mode = self.options['ct_mode']
    if ct_mode == 'none':
        ct_mode = None

    models_opt_on_gpu = False if len(devices) == 0 else self.options['models_opt_on_gpu']
    models_opt_device = nn.tf_default_device_name if models_opt_on_gpu and self.is_training else '/CPU:0'
    optimizer_vars_on_cpu = models_opt_device == '/CPU:0'

    input_ch = 3
    bgr_shape = self.bgr_shape = nn.get4Dshape(resolution, resolution, input_ch)
    mask_shape = nn.get4Dshape(resolution, resolution, 1)

    self.model_filename_list = []

    with tf.device('/CPU:0'):
        # Placeholders on CPU
        self.warped_src = tf.placeholder(nn.floatx, bgr_shape, name='warped_src')
        self.warped_dst = tf.placeholder(nn.floatx, bgr_shape, name='warped_dst')

        self.target_src = tf.placeholder(nn.floatx, bgr_shape, name='target_src')
        self.target_dst = tf.placeholder(nn.floatx, bgr_shape, name='target_dst')

        self.target_srcm = tf.placeholder(nn.floatx, mask_shape, name='target_srcm')
        self.target_srcm_em = tf.placeholder(nn.floatx, mask_shape, name='target_srcm_em')
        self.target_dstm = tf.placeholder(nn.floatx, mask_shape, name='target_dstm')
        self.target_dstm_em = tf.placeholder(nn.floatx, mask_shape, name='target_dstm_em')

        self.morph_value_t = tf.placeholder(nn.floatx, (1,), name='morph_value_t')

    # Initializing model classes
    with tf.device(models_opt_device):
        self.encoder = Encoder(in_ch=input_ch, e_ch=e_dims, ae_ch=ae_dims, name='encoder')
        self.inter_src = Inter(ae_ch=ae_dims, ae_out_ch=ae_dims, name='inter_src')
        self.inter_dst = Inter(ae_ch=ae_dims, ae_out_ch=ae_dims, name='inter_dst')
        self.decoder = Decoder(in_ch=ae_dims, d_ch=d_dims, d_mask_ch=d_mask_dims, name='decoder')

        self.model_filename_list += [[self.encoder, 'encoder.npy'],
                                     [self.inter_src, 'inter_src.npy'],
                                     [self.inter_dst, 'inter_dst.npy'],
                                     [self.decoder, 'decoder.npy']]

        if self.is_training:
            if gan_power != 0:
                self.GAN = nn.UNetPatchDiscriminator(patch_size=self.options['gan_patch_size'], in_ch=input_ch, base_ch=self.options['gan_dims'], name="GAN")
                self.model_filename_list += [[self.GAN, 'GAN.npy']]

            # Initialize optimizers
            lr = 5e-5
            lr_dropout = 0.3 if self.options['lr_dropout'] in ['y', 'cpu'] and not self.pretrain else 1.0

            clipnorm = 1.0 if self.options['clipgrad'] else 0.0

            self.all_weights = self.encoder.get_weights() + self.inter_src.get_weights() + self.inter_dst.get_weights() + self.decoder.get_weights()
            # In pretrain mode inter_src is left out of the gradient targets;
            # the optimizer still allocates its variables for all weights.
            if pretrain:
                self.trainable_weights = self.encoder.get_weights() + self.inter_dst.get_weights() + self.decoder.get_weights()
            else:
                self.trainable_weights = self.encoder.get_weights() + self.inter_src.get_weights() + self.inter_dst.get_weights() + self.decoder.get_weights()

            self.src_dst_opt = nn.AdaBelief(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='src_dst_opt')
            self.src_dst_opt.initialize_variables(self.all_weights, vars_on_cpu=optimizer_vars_on_cpu, lr_dropout_on_cpu=self.options['lr_dropout'] == 'cpu')
            self.model_filename_list += [(self.src_dst_opt, 'src_dst_opt.npy')]

            if gan_power != 0:
                self.GAN_opt = nn.AdaBelief(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='GAN_opt')
                self.GAN_opt.initialize_variables(self.GAN.get_weights(), vars_on_cpu=optimizer_vars_on_cpu, lr_dropout_on_cpu=self.options['lr_dropout'] == 'cpu')
                self.model_filename_list += [(self.GAN_opt, 'GAN_opt.npy')]

    if self.is_training:
        # Adjust batch size for multiple GPU: round it down to a multiple of
        # the device count (at least one sample per device).
        gpu_count = max(1, len(devices))
        bs_per_gpu = max(1, self.get_batch_size() // gpu_count)
        self.set_batch_size(gpu_count*bs_per_gpu)

        # Discriminator loss helpers: per-sample sigmoid cross-entropy of the
        # logits against all-ones / all-zeros labels.
        def DLossOnes(logits):
            return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(logits), logits=logits), axis=[1, 2, 3])

        def DLossZeros(logits):
            return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(logits), logits=logits), axis=[1, 2, 3])

        # Compute losses per GPU
        gpu_pred_src_src_list = []
        gpu_pred_dst_dst_list = []
        gpu_pred_src_dst_list = []
        gpu_pred_src_srcm_list = []
        gpu_pred_dst_dstm_list = []
        gpu_pred_src_dstm_list = []

        gpu_src_losses = []
        gpu_dst_losses = []
        gpu_G_loss_gvs = []
        gpu_D_src_dst_loss_gvs = []

        for gpu_id in range(gpu_count):
            with tf.device(f'/{devices[gpu_id].tf_dev_type}:{gpu_id}' if len(devices) != 0 else '/CPU:0'):
                with tf.device('/CPU:0'):
                    # slice on CPU, otherwise all batch data will be transferred to GPU first
                    batch_slice = slice(gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu)
                    gpu_warped_src = self.warped_src[batch_slice, :, :, :]
                    gpu_warped_dst = self.warped_dst[batch_slice, :, :, :]
                    gpu_target_src = self.target_src[batch_slice, :, :, :]
                    gpu_target_dst = self.target_dst[batch_slice, :, :, :]
                    gpu_target_srcm = self.target_srcm[batch_slice, :, :, :]
                    gpu_target_srcm_em = self.target_srcm_em[batch_slice, :, :, :]
                    gpu_target_dstm = self.target_dstm[batch_slice, :, :, :]
                    gpu_target_dstm_em = self.target_dstm_em[batch_slice, :, :, :]

                # process model tensors
                gpu_src_code = self.encoder(gpu_warped_src)
                gpu_dst_code = self.encoder(gpu_warped_dst)

                if pretrain:
                    # Each sample keeps or drops whole code channels via a
                    # random mask of shape [batch, channels, 1, 1]
                    # (assumes nn.random_binomial yields 0/1 values -- TODO confirm).
                    gpu_src_inter_src_code = self.inter_src(gpu_src_code)
                    gpu_dst_inter_dst_code = self.inter_dst(gpu_dst_code)
                    gpu_src_code = gpu_src_inter_src_code * nn.random_binomial([bs_per_gpu, gpu_src_inter_src_code.shape.as_list()[1], 1, 1], p=morph_factor)
                    gpu_dst_code = gpu_src_dst_code = gpu_dst_inter_dst_code * nn.random_binomial([bs_per_gpu, gpu_dst_inter_dst_code.shape.as_list()[1], 1, 1], p=0.25)
                else:
                    gpu_src_inter_src_code = self.inter_src(gpu_src_code)
                    gpu_src_inter_dst_code = self.inter_dst(gpu_src_code)
                    gpu_dst_inter_src_code = self.inter_src(gpu_dst_code)
                    gpu_dst_inter_dst_code = self.inter_dst(gpu_dst_code)

                    # src code: per-channel random mix of the inter_src and
                    # inter_dst outputs, selected by the random mask.
                    inter_rnd_binomial = nn.random_binomial([bs_per_gpu, gpu_src_inter_src_code.shape.as_list()[1], 1, 1], p=morph_factor)
                    gpu_src_code = gpu_src_inter_src_code * inter_rnd_binomial + gpu_src_inter_dst_code * (1-inter_rnd_binomial)
                    gpu_dst_code = gpu_dst_inter_dst_code

                    # Swap code: the first ae_dims*morph_value channels come
                    # from inter_src, the remaining ones from inter_dst.
                    ae_dims_slice = tf.cast(ae_dims*self.morph_value_t[0], tf.int32)
                    gpu_src_dst_code = tf.concat((tf.slice(gpu_dst_inter_src_code, [0, 0, 0, 0], [-1, ae_dims_slice, lowest_dense_res, lowest_dense_res]),
                                                  tf.slice(gpu_dst_inter_dst_code, [0, ae_dims_slice, 0, 0], [-1, ae_dims-ae_dims_slice, lowest_dense_res, lowest_dense_res])), 1)

                gpu_pred_src_src, gpu_pred_src_srcm = self.decoder(gpu_src_code)
                gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder(gpu_dst_code)
                gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)

                gpu_pred_src_src_list.append(gpu_pred_src_src)
                gpu_pred_dst_dst_list.append(gpu_pred_dst_dst)
                gpu_pred_src_dst_list.append(gpu_pred_src_dst)

                gpu_pred_src_srcm_list.append(gpu_pred_src_srcm)
                gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm)
                gpu_pred_src_dstm_list.append(gpu_pred_src_dstm)

                # Blurred target masks, clipped to [0, 0.5] and rescaled by 2.
                gpu_target_srcm_blur = nn.gaussian_blur(gpu_target_srcm, max(1, resolution // 32))
                gpu_target_srcm_blur = tf.clip_by_value(gpu_target_srcm_blur, 0, 0.5) * 2

                gpu_target_dstm_blur = nn.gaussian_blur(gpu_target_dstm, max(1, resolution // 32))
                gpu_target_dstm_blur = tf.clip_by_value(gpu_target_dstm_blur, 0, 0.5) * 2

                gpu_target_dst_anti_masked = gpu_target_dst*(1.0-gpu_target_dstm_blur)
                gpu_target_src_anti_masked = gpu_target_src*(1.0-gpu_target_srcm_blur)

                gpu_target_src_masked_opt = gpu_target_src*gpu_target_srcm_blur if masked_training else gpu_target_src
                gpu_target_dst_masked_opt = gpu_target_dst*gpu_target_dstm_blur if masked_training else gpu_target_dst

                gpu_pred_src_src_masked_opt = gpu_pred_src_src*gpu_target_srcm_blur if masked_training else gpu_pred_src_src
                gpu_pred_src_src_anti_masked = gpu_pred_src_src*(1.0-gpu_target_srcm_blur)
                gpu_pred_dst_dst_masked_opt = gpu_pred_dst_dst*gpu_target_dstm_blur if masked_training else gpu_pred_dst_dst
                gpu_pred_dst_dst_anti_masked = gpu_pred_dst_dst*(1.0-gpu_target_dstm_blur)

                # dst loss: DSSIM (two filter sizes from resolution 256 up),
                # pixel MSE, optional eyes/mouth L1 term, mask MSE, and a
                # small MSE term on the area outside the mask.
                if resolution < 256:
                    gpu_dst_loss = tf.reduce_mean(10*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
                else:
                    gpu_dst_loss = tf.reduce_mean(5*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
                    gpu_dst_loss += tf.reduce_mean(5*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/23.2)), axis=[1])
                gpu_dst_loss += tf.reduce_mean(10*tf.square(gpu_target_dst_masked_opt - gpu_pred_dst_dst_masked_opt), axis=[1, 2, 3])
                if eyes_mouth_prio:
                    gpu_dst_loss += tf.reduce_mean(300*tf.abs(gpu_target_dst*gpu_target_dstm_em - gpu_pred_dst_dst*gpu_target_dstm_em), axis=[1, 2, 3])
                gpu_dst_loss += tf.reduce_mean(10*tf.square(gpu_target_dstm - gpu_pred_dst_dstm), axis=[1, 2, 3])
                gpu_dst_loss += 0.1*tf.reduce_mean(tf.square(gpu_pred_dst_dst_anti_masked-gpu_target_dst_anti_masked), axis=[1, 2, 3])

                gpu_dst_losses += [gpu_dst_loss]

                if not pretrain:
                    if resolution < 256:
                        gpu_src_loss = tf.reduce_mean(10*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
                    else:
                        gpu_src_loss = tf.reduce_mean(5*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
                        gpu_src_loss += tf.reduce_mean(5*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/23.2)), axis=[1])
                    gpu_src_loss += tf.reduce_mean(10*tf.square(gpu_target_src_masked_opt - gpu_pred_src_src_masked_opt), axis=[1, 2, 3])

                    if eyes_mouth_prio:
                        gpu_src_loss += tf.reduce_mean(300*tf.abs(gpu_target_src*gpu_target_srcm_em - gpu_pred_src_src*gpu_target_srcm_em), axis=[1, 2, 3])

                    gpu_src_loss += tf.reduce_mean(10*tf.square(gpu_target_srcm - gpu_pred_src_srcm), axis=[1, 2, 3])
                else:
                    # In pretrain mode the reported src loss mirrors the dst loss.
                    gpu_src_loss = gpu_dst_loss

                gpu_src_losses += [gpu_src_loss]

                if pretrain:
                    gpu_G_loss = gpu_dst_loss
                else:
                    gpu_G_loss = gpu_src_loss + gpu_dst_loss

                if gan_power != 0:
                    gpu_pred_src_src_d, gpu_pred_src_src_d2 = self.GAN(gpu_pred_src_src_masked_opt)
                    gpu_pred_dst_dst_d, gpu_pred_dst_dst_d2 = self.GAN(gpu_pred_dst_dst_masked_opt)
                    gpu_target_src_d, gpu_target_src_d2 = self.GAN(gpu_target_src_masked_opt)
                    gpu_target_dst_d, gpu_target_dst_d2 = self.GAN(gpu_target_dst_masked_opt)

                    # Discriminator: targets labelled one, predictions
                    # labelled zero, averaged over the eight terms.
                    gpu_D_src_dst_loss = (DLossOnes(gpu_target_src_d) + DLossOnes(gpu_target_src_d2) +
                                          DLossZeros(gpu_pred_src_src_d) + DLossZeros(gpu_pred_src_src_d2) +
                                          DLossOnes(gpu_target_dst_d) + DLossOnes(gpu_target_dst_d2) +
                                          DLossZeros(gpu_pred_dst_dst_d) + DLossZeros(gpu_pred_dst_dst_d2)
                                          ) * (1.0 / 8)

                    gpu_D_src_dst_loss_gvs += [nn.gradients(gpu_D_src_dst_loss, self.GAN.get_weights())]

                    # Generator: predictions should be classified as ones.
                    gpu_G_loss += (DLossOnes(gpu_pred_src_src_d) + DLossOnes(gpu_pred_src_src_d2) +
                                   DLossOnes(gpu_pred_dst_dst_d) + DLossOnes(gpu_pred_dst_dst_d2)
                                   ) * gan_power

                    if masked_training:
                        # Minimal src-src-bg rec with total_variation_mse to suppress random bright dots from gan
                        gpu_G_loss += 0.000001*nn.total_variation_mse(gpu_pred_src_src)
                        gpu_G_loss += 0.02*tf.reduce_mean(tf.square(gpu_pred_src_src_anti_masked-gpu_target_src_anti_masked), axis=[1, 2, 3])

                gpu_G_loss_gvs += [nn.gradients(gpu_G_loss, self.trainable_weights)]

        # Average losses and gradients, and create optimizer update ops
        with tf.device('/CPU:0'):
            pred_src_src = nn.concat(gpu_pred_src_src_list, 0)
            pred_dst_dst = nn.concat(gpu_pred_dst_dst_list, 0)
            pred_src_dst = nn.concat(gpu_pred_src_dst_list, 0)
            pred_src_srcm = nn.concat(gpu_pred_src_srcm_list, 0)
            pred_dst_dstm = nn.concat(gpu_pred_dst_dstm_list, 0)
            pred_src_dstm = nn.concat(gpu_pred_src_dstm_list, 0)

        with tf.device(models_opt_device):
            src_loss = tf.concat(gpu_src_losses, 0)
            dst_loss = tf.concat(gpu_dst_losses, 0)
            src_dst_loss_gv_op = self.src_dst_opt.get_update_op(nn.average_gv_list(gpu_G_loss_gvs))

            if gan_power != 0:
                src_D_src_dst_loss_gv_op = self.GAN_opt.get_update_op(nn.average_gv_list(gpu_D_src_dst_loss_gvs))

        # Initializing training and view functions
        def src_dst_train(warped_src, target_src, target_srcm, target_srcm_em,
                          warped_dst, target_dst, target_dstm, target_dstm_em, ):
            """Run one generator update; return the (src_loss, dst_loss) values."""
            s, d, _ = nn.tf_sess.run([src_loss, dst_loss, src_dst_loss_gv_op],
                                     feed_dict={self.warped_src: warped_src,
                                                self.target_src: target_src,
                                                self.target_srcm: target_srcm,
                                                self.target_srcm_em: target_srcm_em,
                                                self.warped_dst: warped_dst,
                                                self.target_dst: target_dst,
                                                self.target_dstm: target_dstm,
                                                self.target_dstm_em: target_dstm_em,
                                                })
            return s, d
        self.src_dst_train = src_dst_train

        if gan_power != 0:
            def D_src_dst_train(warped_src, target_src, target_srcm, target_srcm_em,
                                warped_dst, target_dst, target_dstm, target_dstm_em, ):
                """Run one discriminator update."""
                nn.tf_sess.run([src_D_src_dst_loss_gv_op], feed_dict={self.warped_src: warped_src,
                                                                       self.target_src: target_src,
                                                                       self.target_srcm: target_srcm,
                                                                       self.target_srcm_em: target_srcm_em,
                                                                       self.warped_dst: warped_dst,
                                                                       self.target_dst: target_dst,
                                                                       self.target_dstm: target_dstm,
                                                                       self.target_dstm_em: target_dstm_em})
            self.D_src_dst_train = D_src_dst_train

        def AE_view(warped_src, warped_dst, morph_value):
            """Return the preview predictions for the given inputs and morph value."""
            return nn.tf_sess.run([pred_src_src, pred_dst_dst, pred_dst_dstm, pred_src_dst, pred_src_dstm],
                                  feed_dict={self.warped_src: warped_src, self.warped_dst: warped_dst, self.morph_value_t: [morph_value]})

        self.AE_view = AE_view
    else:
        # Initializing merge function
        with tf.device(nn.tf_default_device_name if len(devices) != 0 else '/CPU:0'):
            gpu_dst_code = self.encoder(self.warped_dst)
            gpu_dst_inter_src_code = self.inter_src(gpu_dst_code)
            gpu_dst_inter_dst_code = self.inter_dst(gpu_dst_code)

            # Same channel-wise splice of the two inter codes as in training.
            ae_dims_slice = tf.cast(ae_dims*self.morph_value_t[0], tf.int32)
            gpu_src_dst_code = tf.concat((tf.slice(gpu_dst_inter_src_code, [0, 0, 0, 0], [-1, ae_dims_slice, lowest_dense_res, lowest_dense_res]),
                                          tf.slice(gpu_dst_inter_dst_code, [0, ae_dims_slice, 0, 0], [-1, ae_dims-ae_dims_slice, lowest_dense_res, lowest_dense_res])), 1)

            gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)
            _, gpu_pred_dst_dstm = self.decoder(gpu_dst_inter_dst_code)

        def AE_merge(warped_dst, morph_value):
            """Return [pred_src_dst, pred_dst_dstm, pred_src_dstm] for the merger."""
            return nn.tf_sess.run([gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm], feed_dict={self.warped_dst: warped_dst, self.morph_value_t: [morph_value]})

        self.AE_merge = AE_merge

    # Loading/initializing all models/optimizers weights
    for model, filename in io.progress_bar_generator(self.model_filename_list, "Initializing models"):
        if self.pretrain_just_disabled:
            # Leaving pretrain: only the two inter models are re-initialised.
            do_init = False
            if model == self.inter_src or model == self.inter_dst:
                do_init = True
        else:
            do_init = self.is_first_run()
            # self.GAN only exists when training with gan_power != 0.
            if self.is_training and gan_power != 0 and model == self.GAN:
                if self.gan_model_changed:
                    do_init = True
        if not do_init:
            # A failed load falls back to fresh initialisation.
            do_init = not model.load_weights(self.get_strpath_storage_for_file(filename))
        if do_init:
            model.init_weights()

    # initializing sample generators
    if self.is_training:
        training_data_src_path = self.training_data_src_path if not self.pretrain else self.get_pretraining_data_path()
        training_data_dst_path = self.training_data_dst_path if not self.pretrain else self.get_pretraining_data_path()

        random_ct_samples_path = training_data_dst_path if ct_mode is not None and not self.pretrain else None

        cpu_count = min(multiprocessing.cpu_count(), 8)
        src_generators_count = cpu_count // 2
        dst_generators_count = cpu_count // 2
        if ct_mode is not None:
            src_generators_count = int(src_generators_count * 1.5)

        # Each generator yields, in order: warped image, target image,
        # full-face mask, eyes/mouth mask -- matching the argument order of
        # src_dst_train.
        self.set_training_data_generators([
            SampleGeneratorFace(training_data_src_path, random_ct_samples_path=random_ct_samples_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
                                sample_process_options=SampleProcessor.Options(random_flip=random_src_flip),
                                output_sample_types=[{'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp': random_warp, 'transform': True, 'channel_type': SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type': self.face_type, 'data_format': nn.data_format, 'resolution': resolution},
                                                     {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp': False, 'transform': True, 'channel_type': SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type': self.face_type, 'data_format': nn.data_format, 'resolution': resolution},
                                                     {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp': False, 'transform': True, 'channel_type': SampleProcessor.ChannelType.G, 'face_mask_type': SampleProcessor.FaceMaskType.FULL_FACE, 'face_type': self.face_type, 'data_format': nn.data_format, 'resolution': resolution},
                                                     {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp': False, 'transform': True, 'channel_type': SampleProcessor.ChannelType.G, 'face_mask_type': SampleProcessor.FaceMaskType.EYES_MOUTH, 'face_type': self.face_type, 'data_format': nn.data_format, 'resolution': resolution},
                                                     ],
                                uniform_yaw_distribution=self.options['uniform_yaw'] or self.pretrain,
                                generators_count=src_generators_count),

            SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
                                sample_process_options=SampleProcessor.Options(random_flip=random_dst_flip),
                                output_sample_types=[{'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp': random_warp, 'transform': True, 'channel_type': SampleProcessor.ChannelType.BGR, 'face_type': self.face_type, 'data_format': nn.data_format, 'resolution': resolution},
                                                     {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp': False, 'transform': True, 'channel_type': SampleProcessor.ChannelType.BGR, 'face_type': self.face_type, 'data_format': nn.data_format, 'resolution': resolution},
                                                     {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp': False, 'transform': True, 'channel_type': SampleProcessor.ChannelType.G, 'face_mask_type': SampleProcessor.FaceMaskType.FULL_FACE, 'face_type': self.face_type, 'data_format': nn.data_format, 'resolution': resolution},
                                                     {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp': False, 'transform': True, 'channel_type': SampleProcessor.ChannelType.G, 'face_mask_type': SampleProcessor.FaceMaskType.EYES_MOUTH, 'face_type': self.face_type, 'data_format': nn.data_format, 'resolution': resolution},
                                                     ],
                                uniform_yaw_distribution=self.options['uniform_yaw'] or self.pretrain,
                                generators_count=dst_generators_count)
        ])

        self.last_src_samples_loss = []
        self.last_dst_samples_loss = []
        if self.pretrain_just_disabled:
            self.update_sample_for_preview(force_new=True)
def on_initialize(self):
    """Build the fixed-configuration (96 px, full-face) graph and data pipeline.

    All hyper-parameters are hard-coded here rather than read from
    ``self.options``. The method:

    1. Initialises ``nn`` (NCHW when a device is present and not debugging,
       otherwise NHWC).
    2. Creates placeholders on the CPU and the encoder / inter / two decoders
       (plus an RMSprop optimizer when training) on ``models_opt_device``.
    3. When training, builds per-device losses and gradients and exposes
       ``self.src_dst_train`` and ``self.AE_view``; otherwise builds the merge
       graph and exposes ``self.AE_merge``.
    4. Loads weights for everything in ``self.model_filename_list``, trying
       ``self.pretrained_model_path`` before falling back to
       ``init_weights()``.
    5. When training, registers the src and dst sample generators.
    """
    device_config = nn.getCurrentDeviceConfig()
    devices = device_config.devices
    self.model_data_format = "NCHW" if len(
        devices) != 0 and not self.is_debug() else "NHWC"
    nn.initialize(data_format=self.model_data_format)
    tf = nn.tf

    # Fixed model configuration.
    resolution = self.resolution = 96
    self.face_type = FaceType.FULL
    ae_dims = 128
    e_dims = 128
    d_dims = 64
    # Pretraining is switched off unconditionally, so every
    # `if self.pretrain...` branch below takes its non-pretrain side.
    self.pretrain = False
    self.pretrain_just_disabled = False

    masked_training = True

    # Models and optimizer live on the GPU only when every device reports
    # at least 4 (total_mem_gb) and we are training.
    models_opt_on_gpu = len(devices) >= 1 and all(
        [dev.total_mem_gb >= 4 for dev in devices])
    models_opt_device = '/GPU:0' if models_opt_on_gpu and self.is_training else '/CPU:0'
    optimizer_vars_on_cpu = models_opt_device == '/CPU:0'

    input_ch = 3
    bgr_shape = nn.get4Dshape(resolution, resolution, input_ch)
    mask_shape = nn.get4Dshape(resolution, resolution, 1)

    # (model, filename) pairs consumed by the weight-loading loop below.
    self.model_filename_list = []

    model_archi = nn.DeepFakeArchi(resolution, mod='quick')

    with tf.device('/CPU:0'):
        # Placeholders on CPU
        self.warped_src = tf.placeholder(nn.floatx, bgr_shape)
        self.warped_dst = tf.placeholder(nn.floatx, bgr_shape)

        self.target_src = tf.placeholder(nn.floatx, bgr_shape)
        self.target_dst = tf.placeholder(nn.floatx, bgr_shape)

        self.target_srcm = tf.placeholder(nn.floatx, mask_shape)
        self.target_dstm = tf.placeholder(nn.floatx, mask_shape)

    # Initializing model classes
    with tf.device(models_opt_device):
        self.encoder = model_archi.Encoder(in_ch=input_ch,
                                           e_ch=e_dims,
                                           name='encoder')
        encoder_out_ch = self.encoder.compute_output_channels(
            (nn.floatx, bgr_shape))

        # One shared inter model; src and dst differ only in their decoder.
        self.inter = model_archi.Inter(in_ch=encoder_out_ch,
                                       ae_ch=ae_dims,
                                       ae_out_ch=ae_dims,
                                       d_ch=d_dims,
                                       name='inter')
        inter_out_ch = self.inter.compute_output_channels(
            (nn.floatx, (None, encoder_out_ch)))

        self.decoder_src = model_archi.Decoder(in_ch=inter_out_ch,
                                               d_ch=d_dims,
                                               name='decoder_src')
        self.decoder_dst = model_archi.Decoder(in_ch=inter_out_ch,
                                               d_ch=d_dims,
                                               name='decoder_dst')

        self.model_filename_list += [[self.encoder, 'encoder.npy'],
                                     [self.inter, 'inter.npy'],
                                     [self.decoder_src, 'decoder_src.npy'],
                                     [self.decoder_dst, 'decoder_dst.npy']]

        if self.is_training:
            self.src_dst_trainable_weights = self.encoder.get_weights(
            ) + self.inter.get_weights() + self.decoder_src.get_weights(
            ) + self.decoder_dst.get_weights()

            # Initialize optimizers
            self.src_dst_opt = nn.RMSprop(lr=2e-4,
                                          lr_dropout=0.3,
                                          name='src_dst_opt')
            self.src_dst_opt.initialize_variables(
                self.src_dst_trainable_weights,
                vars_on_cpu=optimizer_vars_on_cpu)
            self.model_filename_list += [(self.src_dst_opt,
                                          'src_dst_opt.npy')]

    if self.is_training:
        # Adjust batch size for multiple GPU
        # The total batch is derived from a fixed base of 4 samples, not
        # from self.get_batch_size().
        gpu_count = max(1, len(devices))
        bs_per_gpu = max(1, 4 // gpu_count)
        self.set_batch_size(gpu_count * bs_per_gpu)

        # Compute losses per GPU
        gpu_pred_src_src_list = []
        gpu_pred_dst_dst_list = []
        gpu_pred_src_dst_list = []
        gpu_pred_src_srcm_list = []
        gpu_pred_dst_dstm_list = []
        gpu_pred_src_dstm_list = []

        gpu_src_losses = []
        gpu_dst_losses = []
        gpu_src_dst_loss_gvs = []

        for gpu_id in range(gpu_count):
            with tf.device(
                    f'/GPU:{gpu_id}' if len(devices) != 0 else f'/CPU:0'):
                batch_slice = slice(gpu_id * bs_per_gpu,
                                    (gpu_id + 1) * bs_per_gpu)
                with tf.device(f'/CPU:0'):
                    # slice on CPU, otherwise all batch data will be transferred to GPU first
                    gpu_warped_src = self.warped_src[batch_slice, :, :, :]
                    gpu_warped_dst = self.warped_dst[batch_slice, :, :, :]
                    gpu_target_src = self.target_src[batch_slice, :, :, :]
                    gpu_target_dst = self.target_dst[batch_slice, :, :, :]
                    gpu_target_srcm = self.target_srcm[
                        batch_slice, :, :, :]
                    gpu_target_dstm = self.target_dstm[
                        batch_slice, :, :, :]

                # process model tensors
                gpu_src_code = self.inter(self.encoder(gpu_warped_src))
                gpu_dst_code = self.inter(self.encoder(gpu_warped_dst))
                gpu_pred_src_src, gpu_pred_src_srcm = self.decoder_src(
                    gpu_src_code)
                gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder_dst(
                    gpu_dst_code)
                # Swap: dst code decoded with the src decoder.
                gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(
                    gpu_dst_code)

                gpu_pred_src_src_list.append(gpu_pred_src_src)
                gpu_pred_dst_dst_list.append(gpu_pred_dst_dst)
                gpu_pred_src_dst_list.append(gpu_pred_src_dst)

                gpu_pred_src_srcm_list.append(gpu_pred_src_srcm)
                gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm)
                gpu_pred_src_dstm_list.append(gpu_pred_src_dstm)

                gpu_target_srcm_blur = nn.gaussian_blur(
                    gpu_target_srcm, max(1, resolution // 32))
                gpu_target_dstm_blur = nn.gaussian_blur(
                    gpu_target_dstm, max(1, resolution // 32))

                gpu_target_dst_masked = gpu_target_dst * gpu_target_dstm_blur
                # NOTE(review): gpu_target_dst_anti_masked is not read
                # anywhere below in this method.
                gpu_target_dst_anti_masked = gpu_target_dst * (
                    1.0 - gpu_target_dstm_blur)

                gpu_target_src_masked_opt = gpu_target_src * gpu_target_srcm_blur if masked_training else gpu_target_src
                gpu_target_dst_masked_opt = gpu_target_dst_masked if masked_training else gpu_target_dst

                gpu_pred_src_src_masked_opt = gpu_pred_src_src * gpu_target_srcm_blur if masked_training else gpu_pred_src_src
                gpu_pred_dst_dst_masked_opt = gpu_pred_dst_dst * gpu_target_dstm_blur if masked_training else gpu_pred_dst_dst

                # NOTE(review): the two gpu_psd_* tensors are not read
                # anywhere below in this method.
                gpu_psd_target_dst_masked = gpu_pred_src_dst * gpu_target_dstm_blur
                gpu_psd_target_dst_anti_masked = gpu_pred_src_dst * (
                    1.0 - gpu_target_dstm_blur)

                # src loss = DSSIM + pixel MSE on the (masked) image
                # + MSE on the predicted mask.
                gpu_src_loss = tf.reduce_mean(
                    10 * nn.dssim(gpu_target_src_masked_opt,
                                  gpu_pred_src_src_masked_opt,
                                  max_val=1.0,
                                  filter_size=int(resolution / 11.6)),
                    axis=[1])
                gpu_src_loss += tf.reduce_mean(
                    10 * tf.square(gpu_target_src_masked_opt -
                                   gpu_pred_src_src_masked_opt),
                    axis=[1, 2, 3])
                gpu_src_loss += tf.reduce_mean(
                    10 * tf.square(gpu_target_srcm - gpu_pred_src_srcm),
                    axis=[1, 2, 3])

                # dst loss: same three terms for the dst branch.
                gpu_dst_loss = tf.reduce_mean(
                    10 * nn.dssim(gpu_target_dst_masked_opt,
                                  gpu_pred_dst_dst_masked_opt,
                                  max_val=1.0,
                                  filter_size=int(resolution / 11.6)),
                    axis=[1])
                gpu_dst_loss += tf.reduce_mean(
                    10 * tf.square(gpu_target_dst_masked_opt -
                                   gpu_pred_dst_dst_masked_opt),
                    axis=[1, 2, 3])
                gpu_dst_loss += tf.reduce_mean(
                    10 * tf.square(gpu_target_dstm - gpu_pred_dst_dstm),
                    axis=[1, 2, 3])

                gpu_src_losses += [gpu_src_loss]
                gpu_dst_losses += [gpu_dst_loss]

                gpu_G_loss = gpu_src_loss + gpu_dst_loss
                gpu_src_dst_loss_gvs += [
                    nn.gradients(gpu_G_loss,
                                 self.src_dst_trainable_weights)
                ]

        # Average losses and gradients, and create optimizer update ops
        with tf.device(models_opt_device):
            pred_src_src = nn.concat(gpu_pred_src_src_list, 0)
            pred_dst_dst = nn.concat(gpu_pred_dst_dst_list, 0)
            pred_src_dst = nn.concat(gpu_pred_src_dst_list, 0)
            # NOTE(review): pred_src_srcm is not included in AE_view's
            # fetch list below.
            pred_src_srcm = nn.concat(gpu_pred_src_srcm_list, 0)
            pred_dst_dstm = nn.concat(gpu_pred_dst_dstm_list, 0)
            pred_src_dstm = nn.concat(gpu_pred_src_dstm_list, 0)

            src_loss = nn.average_tensor_list(gpu_src_losses)
            dst_loss = nn.average_tensor_list(gpu_dst_losses)
            src_dst_loss_gv = nn.average_gv_list(gpu_src_dst_loss_gvs)
            src_dst_loss_gv_op = self.src_dst_opt.get_update_op(
                src_dst_loss_gv)

        # Initializing training and view functions
        def src_dst_train(warped_src, target_src, target_srcm, \
                          warped_dst, target_dst, target_dstm):
            # Runs one optimizer step and returns the mean src and dst loss.
            s, d, _ = nn.tf_sess.run(
                [src_loss, dst_loss, src_dst_loss_gv_op],
                feed_dict={
                    self.warped_src: warped_src,
                    self.target_src: target_src,
                    self.target_srcm: target_srcm,
                    self.warped_dst: warped_dst,
                    self.target_dst: target_dst,
                    self.target_dstm: target_dstm,
                })
            s = np.mean(s)
            d = np.mean(d)
            return s, d

        self.src_dst_train = src_dst_train

        def AE_view(warped_src, warped_dst):
            # Returns the preview predictions for the given warped inputs.
            return nn.tf_sess.run([
                pred_src_src, pred_dst_dst, pred_dst_dstm, pred_src_dst,
                pred_src_dstm
            ],
                                  feed_dict={
                                      self.warped_src: warped_src,
                                      self.warped_dst: warped_dst
                                  })

        self.AE_view = AE_view
    else:
        # Initializing merge function
        with tf.device(f'/GPU:0' if len(devices) != 0 else f'/CPU:0'):
            gpu_dst_code = self.inter(self.encoder(self.warped_dst))
            gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(
                gpu_dst_code)
            _, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)

        def AE_merge(warped_dst):
            # Returns [pred_src_dst, pred_dst_dstm, pred_src_dstm].
            return nn.tf_sess.run(
                [gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm],
                feed_dict={self.warped_dst: warped_dst})

        self.AE_merge = AE_merge

    # Loading/initializing all models/optimizers weights
    for model, filename in io.progress_bar_generator(
            self.model_filename_list, "Initializing models"):
        if self.pretrain_just_disabled:
            # Not taken here: pretrain_just_disabled is set to False above.
            do_init = False
            if model == self.inter:
                do_init = True
        else:
            do_init = self.is_first_run()

        if not do_init:
            # A failed load from model storage requests initialisation.
            do_init = not model.load_weights(
                self.get_strpath_storage_for_file(filename))

        # Before initialising from scratch, try a file of the same name
        # under the pretrained model path, if one is configured.
        if do_init and self.pretrained_model_path is not None:
            pretrained_filepath = self.pretrained_model_path / filename
            if pretrained_filepath.exists():
                do_init = not model.load_weights(pretrained_filepath)

        if do_init:
            model.init_weights()

    # initializing sample generators
    if self.is_training:
        training_data_src_path = self.training_data_src_path if not self.pretrain else self.get_pretraining_data_path(
        )
        training_data_dst_path = self.training_data_dst_path if not self.pretrain else self.get_pretraining_data_path(
        )

        # At most 8 CPUs are considered, split evenly between src and dst.
        cpu_count = min(multiprocessing.cpu_count(), 8)
        src_generators_count = cpu_count // 2
        dst_generators_count = cpu_count // 2

        # Each generator yields, in order: warped image, target image,
        # full-face mask -- matching the argument order of src_dst_train.
        self.set_training_data_generators([
            SampleGeneratorFace(
                training_data_src_path,
                debug=self.is_debug(),
                batch_size=self.get_batch_size(),
                sample_process_options=SampleProcessor.Options(
                    random_flip=True if self.pretrain else False),
                output_sample_types=[{
                    'sample_type': SampleProcessor.SampleType.FACE_IMAGE,
                    'warp': True,
                    'transform': True,
                    'channel_type': SampleProcessor.ChannelType.BGR,
                    'face_type': self.face_type,
                    'data_format': nn.data_format,
                    'resolution': resolution
                }, {
                    'sample_type': SampleProcessor.SampleType.FACE_IMAGE,
                    'warp': False,
                    'transform': True,
                    'channel_type': SampleProcessor.ChannelType.BGR,
                    'face_type': self.face_type,
                    'data_format': nn.data_format,
                    'resolution': resolution
                }, {
                    'sample_type': SampleProcessor.SampleType.FACE_MASK,
                    'warp': False,
                    'transform': True,
                    'channel_type': SampleProcessor.ChannelType.G,
                    'face_mask_type': SampleProcessor.FaceMaskType.FULL_FACE,
                    'face_type': self.face_type,
                    'data_format': nn.data_format,
                    'resolution': resolution
                }],
                generators_count=src_generators_count),
            SampleGeneratorFace(
                training_data_dst_path,
                debug=self.is_debug(),
                batch_size=self.get_batch_size(),
                sample_process_options=SampleProcessor.Options(
                    random_flip=True if self.pretrain else False),
                output_sample_types=[{
                    'sample_type': SampleProcessor.SampleType.FACE_IMAGE,
                    'warp': True,
                    'transform': True,
                    'channel_type': SampleProcessor.ChannelType.BGR,
                    'face_type': self.face_type,
                    'data_format': nn.data_format,
                    'resolution': resolution
                }, {
                    'sample_type': SampleProcessor.SampleType.FACE_IMAGE,
                    'warp': False,
                    'transform': True,
                    'channel_type': SampleProcessor.ChannelType.BGR,
                    'face_type': self.face_type,
                    'data_format': nn.data_format,
                    'resolution': resolution
                }, {
                    'sample_type': SampleProcessor.SampleType.FACE_MASK,
                    'warp': False,
                    'transform': True,
                    'channel_type': SampleProcessor.ChannelType.G,
                    'face_mask_type': SampleProcessor.FaceMaskType.FULL_FACE,
                    'face_type': self.face_type,
                    'data_format': nn.data_format,
                    'resolution': resolution
                }],
                generators_count=dst_generators_count)
        ])

        self.last_samples = None
def on_initialize(self):
    """Create the TernausNet model and, when training, everything around it.

    Always: initializes ``nn`` with the NHWC data format, fixes the working
    resolution to 256 and the face type to ``FaceType.FULL``, and builds
    ``self.model`` (named ``FANSeg_<face type>``) with an RMSprop optimizer.

    Only when ``self.is_training`` is true: rounds the batch size to a
    multiple of the device count, builds one loss/gradient tower per device,
    installs ``self.train`` / ``self.view`` and registers the source and
    destination sample generators.
    """
    # The device config is queried before and after nn.initialize(); both
    # calls are kept so the sequence of nn calls stays the same.
    device_config = nn.getCurrentDeviceConfig()
    nn.initialize(data_format="NHWC")
    tf = nn.tf

    device_config = nn.getCurrentDeviceConfig()
    devices = device_config.devices
    has_devices = len(devices) != 0

    resolution = 256
    self.resolution = resolution
    self.face_type = FaceType.FULL

    place_model_on_cpu = not has_devices
    if place_model_on_cpu:
        models_opt_device = '/CPU:0'
    else:
        models_opt_device = '/GPU:0'

    # Neither shape is consumed below; the calls are retained unchanged.
    bgr_shape = nn.get4Dshape(resolution, resolution, 3)
    mask_shape = nn.get4Dshape(resolution, resolution, 1)

    # Initializing model classes
    model_name = f'FANSeg_{FaceType.toString(self.face_type)}'
    load_weights = not self.is_first_run()
    weights_root = self.get_model_root_path()
    lr_dropout = 0.3 if self.options['lr_dropout'] else 1.0
    self.model = TernausNet(
        model_name,
        resolution,
        load_weights=load_weights,
        weights_file_root=weights_root,
        training=True,
        place_model_on_cpu=place_model_on_cpu,
        optimizer=nn.RMSprop(lr=0.0001, lr_dropout=lr_dropout, name='opt'))

    # Everything below is training-only.
    if not self.is_training:
        return

    # Adjust batch size for multiple GPU
    gpu_count = max(1, len(devices))
    bs_per_gpu = max(1, self.get_batch_size() // gpu_count)
    self.set_batch_size(gpu_count * bs_per_gpu)

    # Compute losses per GPU
    tower_preds = []
    tower_losses = []
    tower_grads = []
    for gpu_id in range(gpu_count):
        tower_device = f'/GPU:{gpu_id}' if has_devices else '/CPU:0'
        with tf.device(tower_device):
            with tf.device('/CPU:0'):
                # Slice on CPU, otherwise the whole batch would be
                # transferred to the GPU first.
                first = gpu_id * bs_per_gpu
                batch_slice = slice(first, first + bs_per_gpu)
                gpu_input_t = self.model.input_t[batch_slice, :, :, :]
                gpu_target_t = self.model.target_t[batch_slice, :, :, :]

            # process model tensors
            gpu_pred_logits_t, gpu_pred_t = self.model.net([gpu_input_t])
            tower_preds.append(gpu_pred_t)

            cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(
                labels=gpu_target_t, logits=gpu_pred_logits_t)
            gpu_loss = tf.reduce_mean(cross_entropy, axis=[1, 2, 3])
            tower_losses.append(gpu_loss)
            tower_grads.append(
                nn.gradients(gpu_loss, self.model.net_weights))

    # Average losses and gradients, and create optimizer update ops
    with tf.device(models_opt_device):
        pred = nn.concat(tower_preds, 0)
        loss = tf.reduce_mean(tower_losses)
        averaged_grads = nn.average_gv_list(tower_grads)
        loss_gv_op = self.model.opt.get_update_op(averaged_grads)

    # Initializing training and view functions
    def train(input_np, target_np):
        # Runs one optimizer step and returns the fetched loss value.
        feed = {
            self.model.input_t: input_np,
            self.model.target_t: target_np,
        }
        loss_value, _ = nn.tf_sess.run([loss, loss_gv_op], feed_dict=feed)
        return loss_value

    self.train = train

    def view(input_np):
        # Returns a one-element list holding the concatenated predictions.
        feed = {self.model.input_t: input_np}
        return nn.tf_sess.run([pred], feed_dict=feed)

    self.view = view

    # initializing sample generators
    training_data_src_path = self.training_data_src_path
    training_data_dst_path = self.training_data_dst_path

    cpu_count = min(multiprocessing.cpu_count(), 8)
    dst_generators_count = cpu_count // 2
    # The source side gets 1.5x the generators of the destination side.
    src_generators_count = int((cpu_count // 2) * 1.5)

    src_sample_types = [
        {
            'sample_type': SampleProcessor.SampleType.FACE_IMAGE,
            'ct_mode': 'lct',
            'warp': True,
            'transform': True,
            'channel_type': SampleProcessor.ChannelType.BGR,
            'face_type': self.face_type,
            'random_motion_blur': (25, 5),
            'random_gaussian_blur': (25, 5),
            'data_format': nn.data_format,
            'resolution': resolution
        },
        {
            'sample_type': SampleProcessor.SampleType.FACE_MASK,
            'warp': True,
            'transform': True,
            'channel_type': SampleProcessor.ChannelType.G,
            'face_mask_type': SampleProcessor.FaceMaskType.FULL_FACE,
            'face_type': self.face_type,
            'data_format': nn.data_format,
            'resolution': resolution
        },
    ]
    src_generator = SampleGeneratorFace(
        training_data_src_path,
        random_ct_samples_path=training_data_src_path,
        debug=self.is_debug(),
        batch_size=self.get_batch_size(),
        sample_process_options=SampleProcessor.Options(random_flip=True),
        output_sample_types=src_sample_types,
        generators_count=src_generators_count)

    dst_sample_types = [
        {
            'sample_type': SampleProcessor.SampleType.FACE_IMAGE,
            'warp': False,
            'transform': True,
            'channel_type': SampleProcessor.ChannelType.BGR,
            'face_type': self.face_type,
            'data_format': nn.data_format,
            'resolution': resolution
        },
    ]
    dst_generator = SampleGeneratorFace(
        training_data_dst_path,
        debug=self.is_debug(),
        batch_size=self.get_batch_size(),
        sample_process_options=SampleProcessor.Options(random_flip=True),
        output_sample_types=dst_sample_types,
        generators_count=dst_generators_count,
        raise_on_no_data=False)

    # The destination set is created with raise_on_no_data=False; when it
    # reports itself uninitialized the user only gets a hint.
    if not dst_generator.is_initialized():
        io.log_info(
            f"\nTo view the model on unseen faces, place any aligned faces in {training_data_dst_path}.\n"
        )

    self.set_training_data_generators([src_generator, dst_generator])
def on_initialize(self):
    """Build the staged encoder / inter / decoder graph for the current stage.

    The stage index comes from ``self.options['stage']``; the working image
    side is ``2 ** (stage + 2)``. Placeholders, losses and sample generators
    are all built at that stage resolution.

    When ``self.is_training`` is true this creates the per-device loss and
    gradient towers, the optimizer update op, ``self.src_dst_train`` and
    ``self.AE_view``, and registers the sample generators. Otherwise it
    creates ``self.AE_merge`` only. In both modes the weights of every entry
    in ``self.model_filename_list`` are loaded, or initialized when loading
    fails or on the first run.
    """
    nn.initialize()
    tf = nn.tf

    class EncBlock(nn.ModelBase):
        """Two leaky-ReLU convolutions; non-zero levels end with a max-pool."""

        def on_build(self, in_ch, out_ch, level):
            self.zero_level = level == 0
            self.conv1 = nn.Conv2D(in_ch, out_ch, kernel_size=3, padding='SAME')
            # Level 0 uses an unpadded 4x4 kernel instead of pooling.
            self.conv2 = nn.Conv2D(
                out_ch,
                out_ch,
                kernel_size=4 if self.zero_level else 3,
                padding='VALID' if self.zero_level else 'SAME')

        def forward(self, x):
            x = tf.nn.leaky_relu(self.conv1(x), 0.2)
            x = tf.nn.leaky_relu(self.conv2(x), 0.2)
            if not self.zero_level:
                x = nn.max_pool(x)
            return x

    class DecBlock(nn.ModelBase):
        """Counterpart of EncBlock: non-zero levels start with an upsample."""

        def on_build(self, in_ch, out_ch, level):
            self.zero_level = level == 0
            # Level 0 uses a 4x4 kernel with explicit padding of 3.
            self.conv1 = nn.Conv2D(
                in_ch,
                out_ch,
                kernel_size=4 if self.zero_level else 3,
                padding=3 if self.zero_level else 'SAME')
            self.conv2 = nn.Conv2D(out_ch, out_ch, kernel_size=3, padding='SAME')

        def forward(self, x):
            if not self.zero_level:
                x = nn.upsample2d(x)
            x = tf.nn.leaky_relu(self.conv1(x), 0.2)
            x = tf.nn.leaky_relu(self.conv2(x), 0.2)
            return x

    class FromRGB(nn.ModelBase):
        """1x1 convolution from 3 input channels to ``out_ch`` features."""

        def on_build(self, out_ch):
            self.conv1 = nn.Conv2D(3, out_ch, kernel_size=1, padding='SAME')

        def forward(self, x):
            return tf.nn.leaky_relu(self.conv1(x), 0.2)

    class ToRGB(nn.ModelBase):
        """1x1 convolutions producing a 3-channel and a 1-channel sigmoid map."""

        def on_build(self, in_ch):
            self.conv = nn.Conv2D(in_ch, 3, kernel_size=1, padding='SAME')
            self.convm = nn.Conv2D(in_ch, 1, kernel_size=1, padding='SAME')

        def forward(self, x):
            return tf.nn.sigmoid(self.conv(x)), tf.nn.sigmoid(self.convm(x))

    class Encoder(nn.ModelBase):
        """Stack of EncBlocks keyed by level, entered at the requested stage."""

        def on_build(self, e_ch, levels):
            self.enc_blocks = {}
            self.from_rgbs = {}
            self.dense_norm = nn.DenseNorm()

            in_ch = e_ch
            out_ch = in_ch
            # Channels double per level, clipped at 512, walking from the
            # highest level down to level 0.
            for level in range(levels, -1, -1):
                self.max_ch = out_ch = np.clip(out_ch * 2, 0, 512)
                self.enc_blocks[level] = EncBlock(in_ch, out_ch, level)
                self.from_rgbs[level] = FromRGB(in_ch)
                in_ch = out_ch

        def forward(self, inp, stage):
            x = inp
            for level in range(stage, -1, -1):
                if stage in self.enc_blocks:
                    # Only the entry level converts the image to features.
                    if level == stage:
                        x = self.from_rgbs[level](x)
                    x = self.enc_blocks[level](x)

            x = nn.flatten(x)
            x = self.dense_norm(x)
            x = nn.reshape_4D(x, 1, 1, self.max_ch)
            return x

        def get_stage_weights(self, stage):
            # Call internal get_weights in order to initialize inner logic
            self.get_weights()

            weights = []
            for level in range(stage, -1, -1):
                if stage in self.enc_blocks:
                    if level == stage:
                        weights.append(self.from_rgbs[level].get_weights())
                    weights.append(self.enc_blocks[level].get_weights())

            if not weights:
                return []
            # Concatenate the per-layer lists; a single entry is returned as is.
            return sum(weights[1:], weights[0])

    class Decoder(nn.ModelBase):
        """DecBlocks for the levels in ``levels_range`` (inclusive bounds)."""

        def on_build(self, d_ch, total_levels, levels_range):
            self.dec_blocks = {}
            self.to_rgbs = {}

            # Channel count per level, from total_levels down to -1, doubling
            # each step and clipped at 512.
            level_ch = {}
            ch = d_ch
            for level in range(total_levels, -2, -1):
                level_ch[level] = ch
                ch = np.clip(ch * 2, 0, 512)

            out_ch = level_ch[levels_range[1]]
            for level in range(levels_range[1], levels_range[0] - 1, -1):
                in_ch = level_ch[level - 1]
                self.dec_blocks[level] = DecBlock(in_ch, out_ch, level)
                self.to_rgbs[level] = ToRGB(out_ch)
                out_ch = in_ch

        def forward(self, inp, stage):
            x = inp
            for level in range(stage + 1):
                # Levels this decoder does not own are passed through.
                if level in self.dec_blocks:
                    x = self.dec_blocks[level](x)
                    if level == stage:
                        x = self.to_rgbs[level](x)
            return x

        def get_stage_weights(self, stage):
            # Call internal get_weights in order to initialize inner logic
            self.get_weights()

            weights = []
            for level in range(stage + 1):
                if level in self.dec_blocks:
                    weights.append(self.dec_blocks[level].get_weights())
                    if level == stage:
                        weights.append(self.to_rgbs[level].get_weights())

            if not weights:
                return []
            # Concatenate the per-layer lists; a single entry is returned as is.
            return sum(weights[1:], weights[0])

    device_config = nn.getCurrentDeviceConfig()
    devices = device_config.devices

    self.stage = stage = self.options['stage']
    self.start_stage_iter = self.options.get('start_stage_iter', 0)
    self.target_stage_iter = self.options.get('target_stage_iter', 0)

    stage_resolutions = [2**(i + 2) for i in range(self.stage_max + 1)]
    stage_resolution = stage_resolutions[stage]

    ed_dims = 16

    self.pretrain = False
    self.pretrain_just_disabled = False

    masked_training = True

    models_opt_on_gpu = len(devices) == 1 and devices[0].total_mem_gb >= 4
    models_opt_device = '/GPU:0' if models_opt_on_gpu and self.is_training else '/CPU:0'
    optimizer_vars_on_cpu = models_opt_device == '/CPU:0'

    output_nc = 3
    bgr_shape = (stage_resolution, stage_resolution, output_nc)
    mask_shape = (stage_resolution, stage_resolution, 1)

    self.model_filename_list = []

    with tf.device('/CPU:0'):
        # Placeholders on CPU
        self.warped_src = tf.placeholder(tf.float32, (None, ) + bgr_shape)
        self.warped_dst = tf.placeholder(tf.float32, (None, ) + bgr_shape)

        self.target_src = tf.placeholder(tf.float32, (None, ) + bgr_shape)
        self.target_dst = tf.placeholder(tf.float32, (None, ) + bgr_shape)

        self.target_srcm = tf.placeholder(tf.float32, (None, ) + mask_shape)
        self.target_dstm = tf.placeholder(tf.float32, (None, ) + mask_shape)

    # Initializing model classes
    with tf.device(models_opt_device):
        self.encoder = Encoder(e_ch=ed_dims,
                               levels=self.stage_max,
                               name='encoder')
        # "inter" owns levels 0..2; the two decoders own levels 3..stage_max.
        self.inter = Decoder(d_ch=ed_dims,
                             total_levels=self.stage_max,
                             levels_range=[0, 2],
                             name='inter')
        self.decoder_src = Decoder(d_ch=ed_dims,
                                   total_levels=self.stage_max,
                                   levels_range=[3, self.stage_max],
                                   name='decoder_src')
        self.decoder_dst = Decoder(d_ch=ed_dims,
                                   total_levels=self.stage_max,
                                   levels_range=[3, self.stage_max],
                                   name='decoder_dst')

        self.model_filename_list += [[self.encoder, 'encoder.npy'],
                                     [self.inter, 'inter.npy'],
                                     [self.decoder_src, 'decoder_src.npy'],
                                     [self.decoder_dst, 'decoder_dst.npy']]

        if self.is_training:
            self.src_dst_all_weights = (self.encoder.get_weights()
                                        + self.inter.get_weights()
                                        + self.decoder_src.get_weights()
                                        + self.decoder_dst.get_weights())
            # Only the layers reachable at the current stage are trained.
            self.src_dst_trainable_weights = (
                self.encoder.get_stage_weights(stage)
                + self.inter.get_stage_weights(stage)
                + self.decoder_src.get_stage_weights(stage)
                + self.decoder_dst.get_stage_weights(stage))

            # Initialize optimizers
            self.src_dst_opt = nn.RMSprop(lr=2e-4,
                                          lr_dropout=0.3,
                                          name='src_dst_opt')
            self.src_dst_opt.initialize_variables(
                self.src_dst_all_weights, vars_on_cpu=optimizer_vars_on_cpu)
            self.model_filename_list += [(self.src_dst_opt,
                                          'src_dst_opt.npy')]

    if self.is_training:
        # Adjust batch size for multiple GPU
        gpu_count = max(1, len(devices))
        bs_per_gpu = max(1, self.get_batch_size() // gpu_count)
        self.set_batch_size(gpu_count * bs_per_gpu)

        # Compute losses per GPU
        gpu_pred_src_src_list = []
        gpu_pred_dst_dst_list = []
        gpu_pred_src_dst_list = []
        gpu_pred_src_srcm_list = []
        gpu_pred_dst_dstm_list = []
        gpu_pred_src_dstm_list = []

        gpu_src_losses = []
        gpu_dst_losses = []
        gpu_src_dst_loss_gvs = []

        # Blur radius and dssim window follow the stage resolution, which is
        # the size the placeholders above are built with.
        # NOTE(review): dssim_filter_size is 0 for stage resolutions below 16
        # - confirm nn.dssim accepts that.
        mask_blur_radius = max(1, stage_resolution // 32)
        dssim_filter_size = int(stage_resolution / 11.6)

        for gpu_id in range(gpu_count):
            with tf.device(
                    f'/GPU:{gpu_id}' if len(devices) != 0 else '/CPU:0'):
                batch_slice = slice(gpu_id * bs_per_gpu,
                                    (gpu_id + 1) * bs_per_gpu)
                with tf.device('/CPU:0'):
                    # Slice on CPU, otherwise the whole batch would be
                    # transferred to the GPU first.
                    gpu_warped_src = self.warped_src[batch_slice, :, :, :]
                    gpu_warped_dst = self.warped_dst[batch_slice, :, :, :]
                    gpu_target_src = self.target_src[batch_slice, :, :, :]
                    gpu_target_dst = self.target_dst[batch_slice, :, :, :]
                    gpu_target_srcm = self.target_srcm[batch_slice, :, :, :]
                    gpu_target_dstm = self.target_dstm[batch_slice, :, :, :]

                # process model tensors
                gpu_src_code = self.inter(
                    self.encoder(gpu_warped_src, stage), stage)
                gpu_dst_code = self.inter(
                    self.encoder(gpu_warped_dst, stage), stage)
                gpu_pred_src_src, gpu_pred_src_srcm = self.decoder_src(
                    gpu_src_code, stage)
                gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder_dst(
                    gpu_dst_code, stage)
                gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(
                    gpu_dst_code, stage)

                gpu_pred_src_src_list.append(gpu_pred_src_src)
                gpu_pred_dst_dst_list.append(gpu_pred_dst_dst)
                gpu_pred_src_dst_list.append(gpu_pred_src_dst)

                gpu_pred_src_srcm_list.append(gpu_pred_src_srcm)
                gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm)
                gpu_pred_src_dstm_list.append(gpu_pred_src_dstm)

                gpu_target_srcm_blur = nn.gaussian_blur(
                    gpu_target_srcm, mask_blur_radius)
                gpu_target_dstm_blur = nn.gaussian_blur(
                    gpu_target_dstm, mask_blur_radius)

                gpu_target_dst_masked = gpu_target_dst * gpu_target_dstm_blur

                # With masked training both target and prediction are
                # multiplied by the blurred mask before the image losses.
                gpu_target_srcmasked_opt = gpu_target_src * gpu_target_srcm_blur if masked_training else gpu_target_src
                gpu_target_dst_masked_opt = gpu_target_dst_masked if masked_training else gpu_target_dst

                gpu_pred_src_src_masked_opt = gpu_pred_src_src * gpu_target_srcm_blur if masked_training else gpu_pred_src_src
                gpu_pred_dst_dst_masked_opt = gpu_pred_dst_dst * gpu_target_dstm_blur if masked_training else gpu_pred_dst_dst

                # Per-sample loss: dssim + pixel MSE on the image, plus MSE
                # on the predicted mask.
                gpu_src_loss = tf.reduce_mean(
                    10 * nn.dssim(gpu_target_srcmasked_opt,
                                  gpu_pred_src_src_masked_opt,
                                  max_val=1.0,
                                  filter_size=dssim_filter_size),
                    axis=[1])
                gpu_src_loss += tf.reduce_mean(
                    10 * tf.square(gpu_target_srcmasked_opt -
                                   gpu_pred_src_src_masked_opt),
                    axis=[1, 2, 3])
                gpu_src_loss += tf.reduce_mean(
                    tf.square(gpu_target_srcm - gpu_pred_src_srcm),
                    axis=[1, 2, 3])

                gpu_dst_loss = tf.reduce_mean(
                    10 * nn.dssim(gpu_target_dst_masked_opt,
                                  gpu_pred_dst_dst_masked_opt,
                                  max_val=1.0,
                                  filter_size=dssim_filter_size),
                    axis=[1])
                gpu_dst_loss += tf.reduce_mean(
                    10 * tf.square(gpu_target_dst_masked_opt -
                                   gpu_pred_dst_dst_masked_opt),
                    axis=[1, 2, 3])
                gpu_dst_loss += tf.reduce_mean(
                    tf.square(gpu_target_dstm - gpu_pred_dst_dstm),
                    axis=[1, 2, 3])

                gpu_src_losses += [gpu_src_loss]
                gpu_dst_losses += [gpu_dst_loss]

                gpu_src_dst_loss = gpu_src_loss + gpu_dst_loss
                gpu_src_dst_loss_gvs += [
                    nn.gradients(gpu_src_dst_loss,
                                 self.src_dst_trainable_weights)
                ]

        # Average losses and gradients, and create optimizer update ops
        with tf.device(models_opt_device):
            if gpu_count == 1:
                pred_src_src = gpu_pred_src_src_list[0]
                pred_dst_dst = gpu_pred_dst_dst_list[0]
                pred_src_dst = gpu_pred_src_dst_list[0]
                pred_src_srcm = gpu_pred_src_srcm_list[0]
                pred_dst_dstm = gpu_pred_dst_dstm_list[0]
                pred_src_dstm = gpu_pred_src_dstm_list[0]

                src_loss = gpu_src_losses[0]
                dst_loss = gpu_dst_losses[0]
                src_dst_loss_gv = gpu_src_dst_loss_gvs[0]
            else:
                pred_src_src = tf.concat(gpu_pred_src_src_list, 0)
                pred_dst_dst = tf.concat(gpu_pred_dst_dst_list, 0)
                pred_src_dst = tf.concat(gpu_pred_src_dst_list, 0)
                pred_src_srcm = tf.concat(gpu_pred_src_srcm_list, 0)
                pred_dst_dstm = tf.concat(gpu_pred_dst_dstm_list, 0)
                pred_src_dstm = tf.concat(gpu_pred_src_dstm_list, 0)

                src_loss = nn.average_tensor_list(gpu_src_losses)
                dst_loss = nn.average_tensor_list(gpu_dst_losses)
                src_dst_loss_gv = nn.average_gv_list(gpu_src_dst_loss_gvs)

            src_dst_loss_gv_op = self.src_dst_opt.get_update_op(
                src_dst_loss_gv)

        # Initializing training and view functions
        def src_dst_train(warped_src, target_src, target_srcm,
                          warped_dst, target_dst, target_dstm):
            # One optimizer step; returns the mean src and dst losses.
            s, d, _ = nn.tf_sess.run(
                [src_loss, dst_loss, src_dst_loss_gv_op],
                feed_dict={
                    self.warped_src: warped_src,
                    self.target_src: target_src,
                    self.target_srcm: target_srcm,
                    self.warped_dst: warped_dst,
                    self.target_dst: target_dst,
                    self.target_dstm: target_dstm,
                })
            s = np.mean(s)
            d = np.mean(d)
            return s, d

        self.src_dst_train = src_dst_train

        def AE_view(warped_src, warped_dst):
            return nn.tf_sess.run([
                pred_src_src, pred_dst_dst, pred_dst_dstm, pred_src_dst,
                pred_src_dstm
            ],
                                  feed_dict={
                                      self.warped_src: warped_src,
                                      self.warped_dst: warped_dst
                                  })

        self.AE_view = AE_view
    else:
        # Initializing merge function
        with tf.device('/GPU:0' if len(devices) != 0 else '/CPU:0'):
            # The stage is passed positionally to every sub-model, the same
            # way the training branch calls them.
            gpu_dst_code = self.inter(
                self.encoder(self.warped_dst, stage), stage)
            gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(
                gpu_dst_code, stage)
            _, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code, stage)

        def AE_merge(warped_dst):
            return nn.tf_sess.run(
                [gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm],
                feed_dict={self.warped_dst: warped_dst})

        self.AE_merge = AE_merge

    # Loading/initializing all models/optimizers weights
    for model, filename in io.progress_bar_generator(
            self.model_filename_list, "Initializing models"):
        do_init = self.is_first_run()

        if self.pretrain_just_disabled:
            if model == self.inter:
                do_init = True

        if not do_init:
            # Fall back to fresh initialization when loading fails.
            do_init = not model.load_weights(
                self.get_strpath_storage_for_file(filename))

        if do_init:
            model.init_weights()

    # initializing sample generators
    if self.is_training:
        t = SampleProcessor.Types
        face_type = t.FACE_TYPE_FULL

        training_data_src_path = self.training_data_src_path if not self.pretrain else self.get_pretraining_data_path()
        training_data_dst_path = self.training_data_dst_path if not self.pretrain else self.get_pretraining_data_path()

        cpu_count = multiprocessing.cpu_count()
        src_generators_count = cpu_count // 2
        dst_generators_count = cpu_count - src_generators_count

        def make_sample_types():
            # Warped image, unwarped image and mask, all produced at the
            # stage resolution so they match the placeholders above.
            return [{
                'types': (t.IMG_WARPED_TRANSFORMED, face_type, t.MODE_BGR),
                'resolution': stage_resolution,
            }, {
                'types': (t.IMG_TRANSFORMED, face_type, t.MODE_BGR),
                'resolution': stage_resolution,
            }, {
                'types': (t.IMG_TRANSFORMED, face_type,
                          t.MODE_FACE_MASK_ALL_HULL),
                'resolution': stage_resolution
            }]

        self.set_training_data_generators([
            SampleGeneratorFace(
                training_data_src_path,
                debug=self.is_debug(),
                batch_size=self.get_batch_size(),
                sample_process_options=SampleProcessor.Options(
                    random_flip=bool(self.pretrain)),
                output_sample_types=make_sample_types(),
                generators_count=src_generators_count),
            SampleGeneratorFace(
                training_data_dst_path,
                debug=self.is_debug(),
                batch_size=self.get_batch_size(),
                sample_process_options=SampleProcessor.Options(
                    random_flip=bool(self.pretrain)),
                output_sample_types=make_sample_types(),
                generators_count=dst_generators_count)
        ])

        self.last_samples = None
def on_initialize(self): device_config = nn.getCurrentDeviceConfig() devices = device_config.devices self.model_data_format = "NCHW" if len( devices) != 0 and not self.is_debug() else "NHWC" nn.initialize(data_format=self.model_data_format) tf = nn.tf class EncBlock(nn.ModelBase): def on_build(self, in_ch, out_ch, level): self.zero_level = level == 0 self.conv1 = nn.Conv2D(in_ch, out_ch, kernel_size=3, padding='SAME') self.conv2 = nn.Conv2D( out_ch, out_ch, kernel_size=4 if self.zero_level else 3, padding='VALID' if self.zero_level else 'SAME') def forward(self, x): x = tf.nn.leaky_relu(self.conv1(x), 0.2) x = tf.nn.leaky_relu(self.conv2(x), 0.2) if not self.zero_level: x = nn.max_pool(x) #if self.zero_level: return x class DecBlock(nn.ModelBase): def on_build(self, in_ch, out_ch, level): self.zero_level = level == 0 self.conv1 = nn.Conv2D( in_ch, out_ch, kernel_size=4 if self.zero_level else 3, padding=3 if self.zero_level else 'SAME') self.conv2 = nn.Conv2D(out_ch, out_ch, kernel_size=3, padding='SAME') def forward(self, x): if not self.zero_level: x = nn.upsample2d(x) x = tf.nn.leaky_relu(self.conv1(x), 0.2) x = tf.nn.leaky_relu(self.conv2(x), 0.2) return x class InterBlock(nn.ModelBase): def on_build(self, in_ch, out_ch, level): self.zero_level = level == 0 self.dense1 = nn.Dense() def forward(self, x): x = tf.nn.leaky_relu(self.conv1(x), 0.2) x = tf.nn.leaky_relu(self.conv2(x), 0.2) if not self.zero_level: x = nn.max_pool(x) #if self.zero_level: return x class FromRGB(nn.ModelBase): def on_build(self, out_ch): self.conv1 = nn.Conv2D(3, out_ch, kernel_size=1, padding='SAME') def forward(self, x): return tf.nn.leaky_relu(self.conv1(x), 0.2) class ToRGB(nn.ModelBase): def on_build(self, in_ch): self.conv = nn.Conv2D(in_ch, 3, kernel_size=1, padding='SAME') self.convm = nn.Conv2D(in_ch, 1, kernel_size=1, padding='SAME') def forward(self, x): return tf.nn.sigmoid(self.conv(x)), tf.nn.sigmoid( self.convm(x)) ed_dims = 16 ae_res = 4 level_chs = { i - 1: v for i, v in 
enumerate([ np.clip(ed_dims * (2**i), 0, 512) for i in range(self.stage_max + 2) ][::-1]) } ae_ch = level_chs[0] class Encoder(nn.ModelBase): def on_build(self, e_ch, levels): self.enc_blocks = {} self.from_rgbs = {} self.dense_norm = nn.DenseNorm() for level in range(levels, -1, -1): self.from_rgbs[level] = FromRGB(level_chs[level]) if level != 0: self.enc_blocks[level] = EncBlock( level_chs[level], level_chs[level - 1], level) self.ae_dense1 = nn.Dense(ae_res * ae_res * ae_ch, 256) self.ae_dense2 = nn.Dense(256, ae_res * ae_res * ae_ch) def forward(self, stage, inp, prev_inp=None, alpha=None): x = inp for level in range(stage, -1, -1): if stage in self.from_rgbs: if level == stage: x = self.from_rgbs[level](x) elif level == stage - 1: x = x * alpha + self.from_rgbs[level](prev_inp) * ( 1 - alpha) if level != 0: x = self.enc_blocks[level](x) x = nn.flatten(x) x = self.dense_norm(x) x = self.ae_dense1(x) x = self.ae_dense2(x) x = nn.reshape_4D(x, ae_res, ae_res, ae_ch) return x def get_stage_weights(self, stage): self.get_weights() weights = [] for level in range(stage, -1, -1): if stage in self.from_rgbs: if level == stage or level == stage - 1: weights.append(self.from_rgbs[level].get_weights()) if level != 0: weights.append( self.enc_blocks[level].get_weights()) weights.append(self.ae_dense1.get_weights()) weights.append(self.ae_dense2.get_weights()) if len(weights) == 0: return [] elif len(weights) == 1: return weights[0] else: return sum(weights[1:], weights[0]) class Decoder(nn.ModelBase): def on_build(self, levels_range): self.dec_blocks = {} self.to_rgbs = {} for level in range(levels_range[0], levels_range[1] + 1): self.to_rgbs[level] = ToRGB(level_chs[level]) if level != 0: self.dec_blocks[level] = DecBlock( level_chs[level - 1], level_chs[level], level) def forward(self, stage, inp, alpha=None, inter=None): x = inp for level in range(stage + 1): if level in self.to_rgbs: if level == stage and stage > 0: prev_level = level - 1 #prev_x, prev_xm = 
(inter.to_rgbs[prev_level] if inter is not None and prev_level in inter.to_rgbs else self.to_rgbs[prev_level])(x) prev_x, prev_xm = self.to_rgbs[prev_level](x) prev_x = nn.upsample2d(prev_x) prev_xm = nn.upsample2d(prev_xm) if level != 0: x = self.dec_blocks[level](x) if level == stage: x, xm = self.to_rgbs[level](x) if stage > 0: x = x * alpha + prev_x * (1 - alpha) xm = xm * alpha + prev_xm * (1 - alpha) return x, xm return x def get_stage_weights(self, stage): # Call internal get_weights in order to initialize inner logic self.get_weights() weights = [] for level in range(stage + 1): if level in self.to_rgbs: if level != 0: weights.append( self.dec_blocks[level].get_weights()) if level == stage or level == stage - 1: weights.append(self.to_rgbs[level].get_weights()) if len(weights) == 0: return [] elif len(weights) == 1: return weights[0] else: return sum(weights[1:], weights[0]) device_config = nn.getCurrentDeviceConfig() devices = device_config.devices self.stage = stage = self.options['stage'] self.start_stage_iter = self.options.get('start_stage_iter', 0) self.target_stage_iter = self.options.get('target_stage_iter', 0) resolution = self.options['resolution'] stage_resolutions = [2**(i + 2) for i in range(self.stage_max + 1)] stage_resolution = stage_resolutions[stage] prev_stage = stage - 1 if stage != 0 else stage prev_stage_resolution = stage_resolutions[ stage - 1] if stage != 0 else stage_resolution self.pretrain = False self.pretrain_just_disabled = False masked_training = True models_opt_on_gpu = len(devices) == 1 and devices[0].total_mem_gb >= 4 models_opt_device = '/GPU:0' if models_opt_on_gpu and self.is_training else '/CPU:0' optimizer_vars_on_cpu = models_opt_device == '/CPU:0' input_nc = 3 output_nc = 3 prev_bgr_shape = nn.get4Dshape(prev_stage_resolution, prev_stage_resolution, output_nc) bgr_shape = nn.get4Dshape(stage_resolution, stage_resolution, output_nc) mask_shape = nn.get4Dshape(stage_resolution, stage_resolution, 1) 
self.model_filename_list = [] with tf.device('/CPU:0'): #Place holders on CPU self.prev_warped_src = tf.placeholder(tf.float32, prev_bgr_shape) self.warped_src = tf.placeholder(tf.float32, bgr_shape) self.prev_warped_dst = tf.placeholder(tf.float32, prev_bgr_shape) self.warped_dst = tf.placeholder(tf.float32, bgr_shape) self.target_src = tf.placeholder(tf.float32, bgr_shape) self.target_dst = tf.placeholder(tf.float32, bgr_shape) self.target_srcm = tf.placeholder(tf.float32, mask_shape) self.target_dstm = tf.placeholder(tf.float32, mask_shape) self.alpha_t = tf.placeholder(tf.float32, (None, 1, 1, 1)) # Initializing model classes with tf.device(models_opt_device): self.encoder = Encoder(e_ch=ed_dims, levels=self.stage_max, name='encoder') #self.inter = Decoder(d_ch=ed_dims, total_levels=self.stage_max, levels_range=[0,2], name='inter') self.decoder_src = Decoder(levels_range=[0, self.stage_max], name='decoder_src') self.decoder_dst = Decoder(levels_range=[0, self.stage_max], name='decoder_dst') self.model_filename_list += [ [self.encoder, 'encoder.npy'], #[self.inter, 'inter.npy' ], [self.decoder_src, 'decoder_src.npy'], [self.decoder_dst, 'decoder_dst.npy'] ] if self.is_training: self.src_dst_all_weights = self.encoder.get_weights( ) + self.decoder_src.get_weights( ) + self.decoder_dst.get_weights() self.src_dst_trainable_weights = self.encoder.get_stage_weights(stage) \ + self.decoder_src.get_stage_weights(stage) \ + self.decoder_dst.get_stage_weights(stage) # Initialize optimizers self.src_dst_opt = nn.RMSprop(lr=2e-4, lr_dropout=1.0, name='src_dst_opt') self.src_dst_opt.initialize_variables( self.src_dst_all_weights, vars_on_cpu=optimizer_vars_on_cpu) self.model_filename_list += [(self.src_dst_opt, 'src_dst_opt.npy')] if self.is_training: # Adjust batch size for multiple GPU gpu_count = max(1, len(devices)) bs_per_gpu = max(1, self.get_batch_size() // gpu_count) self.set_batch_size(gpu_count * bs_per_gpu) # Compute losses per GPU gpu_pred_src_src_list = [] 
gpu_pred_dst_dst_list = [] gpu_pred_src_dst_list = [] gpu_pred_src_srcm_list = [] gpu_pred_dst_dstm_list = [] gpu_pred_src_dstm_list = [] gpu_src_losses = [] gpu_dst_losses = [] gpu_src_dst_loss_gvs = [] for gpu_id in range(gpu_count): with tf.device( f'/GPU:{gpu_id}' if len(devices) != 0 else f'/CPU:0'): batch_slice = slice(gpu_id * bs_per_gpu, (gpu_id + 1) * bs_per_gpu) with tf.device(f'/CPU:0'): # slice on CPU, otherwise all batch data will be transfered to GPU first gpu_prev_warped_src = self.prev_warped_src[ batch_slice, :, :, :] gpu_warped_src = self.warped_src[batch_slice, :, :, :] gpu_prev_warped_dst = self.prev_warped_dst[ batch_slice, :, :, :] gpu_warped_dst = self.warped_dst[batch_slice, :, :, :] gpu_target_src = self.target_src[batch_slice, :, :, :] gpu_target_dst = self.target_dst[batch_slice, :, :, :] gpu_target_srcm = self.target_srcm[ batch_slice, :, :, :] gpu_target_dstm = self.target_dstm[ batch_slice, :, :, :] gpu_alpha_t = self.alpha_t[batch_slice, :, :, :] # process model tensors #gpu_src_code = self.inter(stage, self.encoder(stage, gpu_warped_src, gpu_prev_warped_src, gpu_alpha_t), gpu_alpha_t ) #gpu_dst_code = self.inter(stage, self.encoder(stage, gpu_warped_dst, gpu_prev_warped_dst, gpu_alpha_t), gpu_alpha_t ) gpu_src_code = self.encoder(stage, gpu_warped_src, gpu_prev_warped_src, gpu_alpha_t) gpu_dst_code = self.encoder(stage, gpu_warped_dst, gpu_prev_warped_dst, gpu_alpha_t) gpu_pred_src_src, gpu_pred_src_srcm = self.decoder_src( stage, gpu_src_code, gpu_alpha_t) #, inter=self.inter) gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder_dst( stage, gpu_dst_code, gpu_alpha_t) #, inter=self.inter) gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src( stage, gpu_dst_code, gpu_alpha_t) #, inter=self.inter) gpu_pred_src_src_list.append(gpu_pred_src_src) gpu_pred_dst_dst_list.append(gpu_pred_dst_dst) gpu_pred_src_dst_list.append(gpu_pred_src_dst) gpu_pred_src_srcm_list.append(gpu_pred_src_srcm) gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm) 
gpu_pred_src_dstm_list.append(gpu_pred_src_dstm) gpu_target_srcm_blur = nn.gaussian_blur( gpu_target_srcm, max(1, stage_resolution // 32)) gpu_target_dstm_blur = nn.gaussian_blur( gpu_target_dstm, max(1, stage_resolution // 32)) gpu_target_dst_masked = gpu_target_dst * gpu_target_dstm_blur gpu_target_dst_anti_masked = gpu_target_dst * ( 1.0 - gpu_target_dstm_blur) gpu_target_srcmasked_opt = gpu_target_src * gpu_target_srcm_blur if masked_training else gpu_target_src gpu_target_dst_masked_opt = gpu_target_dst_masked if masked_training else gpu_target_dst gpu_pred_src_src_masked_opt = gpu_pred_src_src * gpu_target_srcm_blur if masked_training else gpu_pred_src_src gpu_pred_dst_dst_masked_opt = gpu_pred_dst_dst * gpu_target_dstm_blur if masked_training else gpu_pred_dst_dst gpu_psd_target_dst_masked = gpu_pred_src_dst * gpu_target_dstm_blur gpu_psd_target_dst_anti_masked = gpu_pred_src_dst * ( 1.0 - gpu_target_dstm_blur) gpu_src_loss = tf.reduce_mean( 10 * tf.square(gpu_target_srcmasked_opt - gpu_pred_src_src_masked_opt), axis=[1, 2, 3]) gpu_src_loss += tf.reduce_mean( tf.square(gpu_target_srcm - gpu_pred_src_srcm), axis=[1, 2, 3]) if stage_resolution >= 16: gpu_src_loss += tf.reduce_mean( 5 * nn.dssim(gpu_target_srcmasked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(stage_resolution / 11.6)), axis=[1]) if stage_resolution >= 32: gpu_src_loss += tf.reduce_mean( 5 * nn.dssim(gpu_target_srcmasked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(stage_resolution / 23.2)), axis=[1]) gpu_dst_loss = tf.reduce_mean( 10 * tf.square(gpu_target_dst_masked_opt - gpu_pred_dst_dst_masked_opt), axis=[1, 2, 3]) gpu_dst_loss += tf.reduce_mean( tf.square(gpu_target_dstm - gpu_pred_dst_dstm), axis=[1, 2, 3]) if stage_resolution >= 16: gpu_dst_loss += tf.reduce_mean( 5 * nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(stage_resolution / 11.6)), axis=[1]) if stage_resolution >= 32: gpu_dst_loss += 
tf.reduce_mean( 5 * nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(stage_resolution / 23.2)), axis=[1]) gpu_src_losses += [gpu_src_loss] gpu_dst_losses += [gpu_dst_loss] gpu_src_dst_loss = gpu_src_loss + gpu_dst_loss gpu_src_dst_loss_gvs += [ nn.gradients(gpu_src_dst_loss, self.src_dst_trainable_weights) ] # Average losses and gradients, and create optimizer update ops with tf.device(models_opt_device): if gpu_count == 1: pred_src_src = gpu_pred_src_src_list[0] pred_dst_dst = gpu_pred_dst_dst_list[0] pred_src_dst = gpu_pred_src_dst_list[0] pred_src_srcm = gpu_pred_src_srcm_list[0] pred_dst_dstm = gpu_pred_dst_dstm_list[0] pred_src_dstm = gpu_pred_src_dstm_list[0] src_loss = gpu_src_losses[0] dst_loss = gpu_dst_losses[0] src_dst_loss_gv = gpu_src_dst_loss_gvs[0] else: pred_src_src = tf.concat(gpu_pred_src_src_list, 0) pred_dst_dst = tf.concat(gpu_pred_dst_dst_list, 0) pred_src_dst = tf.concat(gpu_pred_src_dst_list, 0) pred_src_srcm = tf.concat(gpu_pred_src_srcm_list, 0) pred_dst_dstm = tf.concat(gpu_pred_dst_dstm_list, 0) pred_src_dstm = tf.concat(gpu_pred_src_dstm_list, 0) src_loss = nn.average_tensor_list(gpu_src_losses) dst_loss = nn.average_tensor_list(gpu_dst_losses) src_dst_loss_gv = nn.average_gv_list(gpu_src_dst_loss_gvs) src_dst_loss_gv_op = self.src_dst_opt.get_update_op( src_dst_loss_gv) # Initializing training and view functions def get_alpha(batch_size): alpha = 0 if self.stage != 0: alpha = (self.iter - self.start_stage_iter) / ( self.target_stage_iter - self.start_stage_iter) alpha = np.clip(alpha, 0, 1) alpha = np.array([alpha], nn.floatx.as_numpy_dtype).reshape( (1, 1, 1, 1)) alpha = np.repeat(alpha, batch_size, 0) return alpha def src_dst_train(prev_warped_src, warped_src, target_src, target_srcm, \ prev_warped_dst, warped_dst, target_dst, target_dstm): s, d, _ = nn.tf_sess.run( [src_loss, dst_loss, src_dst_loss_gv_op], feed_dict={ self.prev_warped_src: prev_warped_src, self.warped_src: warped_src, 
self.target_src: target_src, self.target_srcm: target_srcm, self.prev_warped_dst: prev_warped_dst, self.warped_dst: warped_dst, self.target_dst: target_dst, self.target_dstm: target_dstm, self.alpha_t: get_alpha(prev_warped_src.shape[0]) }) s = np.mean(s) d = np.mean(d) return s, d self.src_dst_train = src_dst_train def AE_view(prev_warped_src, warped_src, prev_warped_dst, warped_dst): return nn.tf_sess.run( [ pred_src_src, pred_dst_dst, pred_dst_dstm, pred_src_dst, pred_src_dstm ], feed_dict={ self.prev_warped_src: prev_warped_src, self.warped_src: warped_src, self.prev_warped_dst: prev_warped_dst, self.warped_dst: warped_dst, self.alpha_t: get_alpha(prev_warped_src.shape[0]) }) self.AE_view = AE_view else: # Initializing merge function with tf.device(f'/GPU:0' if len(devices) != 0 else f'/CPU:0'): gpu_dst_code = self.inter(self.encoder(self.warped_dst)) gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src( gpu_dst_code, stage=stage) _, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code, stage=stage) def AE_merge(warped_dst): return nn.tf_sess.run( [gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm], feed_dict={self.warped_dst: warped_dst}) self.AE_merge = AE_merge # Loading/initializing all models/optimizers weights for model, filename in io.progress_bar_generator( self.model_filename_list, "Initializing models"): do_init = self.is_first_run() if self.pretrain_just_disabled: if model == self.inter: do_init = True if not do_init: do_init = not model.load_weights( self.get_strpath_storage_for_file(filename)) if do_init: model.init_weights() # initializing sample generators if self.is_training: self.face_type = FaceType.FULL training_data_src_path = self.training_data_src_path if not self.pretrain else self.get_pretraining_data_path( ) training_data_dst_path = self.training_data_dst_path if not self.pretrain else self.get_pretraining_data_path( ) cpu_count = multiprocessing.cpu_count() src_generators_count = cpu_count // 2 dst_generators_count = cpu_count - 
src_generators_count self.set_training_data_generators([ SampleGeneratorFace( training_data_src_path, debug=self.is_debug(), batch_size=self.get_batch_size(), sample_process_options=SampleProcessor.Options( random_flip=False), output_sample_types=[ { 'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp': True, 'transform': True, 'channel_type': SampleProcessor.ChannelType.BGR, 'face_type': self.face_type, 'data_format': nn.data_format, 'resolution': resolution, 'nearest_resize_to': prev_stage_resolution }, { 'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp': True, 'transform': True, 'channel_type': SampleProcessor.ChannelType.BGR, 'face_type': self.face_type, 'data_format': nn.data_format, 'resolution': resolution, 'nearest_resize_to': stage_resolution }, { 'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp': False, 'transform': True, 'channel_type': SampleProcessor.ChannelType.BGR, 'face_type': self.face_type, 'data_format': nn.data_format, 'resolution': resolution, 'nearest_resize_to': prev_stage_resolution }, { 'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp': False, 'transform': True, 'channel_type': SampleProcessor.ChannelType.BGR, 'face_type': self.face_type, 'data_format': nn.data_format, 'resolution': resolution, 'nearest_resize_to': stage_resolution }, { 'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp': False, 'transform': True, 'channel_type': SampleProcessor.ChannelType.G, 'face_mask_type': SampleProcessor.FaceMaskType.FULL_FACE, 'face_type': self.face_type, 'data_format': nn.data_format, 'resolution': resolution, 'nearest_resize_to': stage_resolution }, ], generators_count=src_generators_count), SampleGeneratorFace( training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(), sample_process_options=SampleProcessor.Options( random_flip=False), output_sample_types=[ { 'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp': True, 'transform': True, 'channel_type': 
SampleProcessor.ChannelType.BGR, 'face_type': self.face_type, 'data_format': nn.data_format, 'resolution': resolution, 'nearest_resize_to': prev_stage_resolution }, { 'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp': True, 'transform': True, 'channel_type': SampleProcessor.ChannelType.BGR, 'face_type': self.face_type, 'data_format': nn.data_format, 'resolution': resolution, 'nearest_resize_to': stage_resolution }, { 'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp': False, 'transform': True, 'channel_type': SampleProcessor.ChannelType.BGR, 'face_type': self.face_type, 'data_format': nn.data_format, 'resolution': resolution, 'nearest_resize_to': prev_stage_resolution }, { 'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp': False, 'transform': True, 'channel_type': SampleProcessor.ChannelType.BGR, 'face_type': self.face_type, 'data_format': nn.data_format, 'resolution': resolution, 'nearest_resize_to': stage_resolution }, { 'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp': False, 'transform': True, 'channel_type': SampleProcessor.ChannelType.G, 'face_mask_type': SampleProcessor.FaceMaskType.FULL_FACE, 'face_type': self.face_type, 'data_format': nn.data_format, 'resolution': resolution, 'nearest_resize_to': stage_resolution }, ], generators_count=dst_generators_count) ]) self.last_samples = None
def on_initialize(self):
    """Build the graph, load the weights and create the sample generators.

    Steps, in order:

    1. Initialise ``nn`` with the NCHW data format when at least one device
       is available and the model is not in debug mode, NHWC otherwise.
    2. Define the local building blocks (``Downscale``, ``Upscale``,
       ``ResidualBlock``, ``Encoder``, ``Inter``, ``Decoder``) and create
       one shared encoder/inter pair plus a decoder for src and one for dst.
    3. When training: build the per-device losses and gradients, average
       them, and publish ``self.src_dst_train`` and ``self.AE_view``.
       Otherwise publish ``self.AE_merge``.
    4. Load every model/optimizer listed in ``self.model_filename_list``
       from its ``.npy`` file, falling back to the file of the same name
       under ``self.pretrained_model_path`` and finally to
       ``init_weights()``.
    5. When training: register one src and one dst sample generator.

    The resolution (128), ``ae_dims`` (128), ``e_dims`` (16) and the nominal
    batch size (4, split across devices) are fixed in this method rather
    than read from ``self.options``.
    """
    device_config = nn.getCurrentDeviceConfig()
    self.model_data_format = "NCHW" if len(
        device_config.devices) != 0 and not self.is_debug() else "NHWC"
    nn.initialize(data_format=self.model_data_format)
    tf = nn.tf

    conv_kernel_initializer = nn.initializers.ca()

    class Downscale(nn.ModelBase):
        """Stride-2 convolution, optionally followed by leaky ReLU(0.1)."""

        def __init__(self, in_ch, out_ch, kernel_size=3, dilations=1,
                     use_activator=True, **kwargs):
            # The settings are kept on the instance because on_build() reads
            # them from there instead of receiving them as arguments.
            self.in_ch = in_ch
            self.out_ch = out_ch
            self.kernel_size = kernel_size
            self.dilations = dilations
            self.use_activator = use_activator
            # Forward keyword arguments as keywords.  The previous
            # `*kwargs` signature rejected any keyword (e.g. name=...) and
            # forwarded stray positionals instead.
            super().__init__(**kwargs)

        def on_build(self, *args, **kwargs):
            self.conv1 = nn.Conv2D(
                self.in_ch,
                self.out_ch,
                kernel_size=self.kernel_size,
                strides=2,
                padding='SAME',
                dilations=self.dilations,
                kernel_initializer=conv_kernel_initializer)

        def forward(self, x):
            x = self.conv1(x)
            if self.use_activator:
                x = tf.nn.leaky_relu(x, 0.1)
            return x

    class Upscale(nn.ModelBase):
        """Convolution to 4x the channels, leaky ReLU, depth-to-space by 2."""

        def on_build(self, in_ch, out_ch, kernel_size=3):
            self.conv1 = nn.Conv2D(
                in_ch,
                out_ch * 4,
                kernel_size=kernel_size,
                padding='SAME',
                kernel_initializer=conv_kernel_initializer)

        def forward(self, x):
            x = self.conv1(x)
            x = tf.nn.leaky_relu(x, 0.1)
            x = nn.depth_to_space(x, 2)
            return x

    class ResidualBlock(nn.ModelBase):
        """Two convolutions (ch -> ch*mod -> ch) with a skip connection."""

        def on_build(self, ch, mod=1, kernel_size=3):
            self.conv1 = nn.Conv2D(
                ch,
                ch * mod,
                kernel_size=kernel_size,
                padding='SAME',
                kernel_initializer=conv_kernel_initializer)
            self.conv2 = nn.Conv2D(
                ch * mod,
                ch,
                kernel_size=kernel_size,
                padding='SAME',
                kernel_initializer=conv_kernel_initializer)

        def forward(self, inp):
            x = self.conv1(inp)
            x = tf.nn.leaky_relu(x, 0.1)
            x = self.conv2(x)
            x = inp + x
            x = tf.nn.leaky_relu(x, 0.1)
            return x

    class Encoder(nn.ModelBase):
        """Seven Downscale stages, each followed by a ResidualBlock."""

        def on_build(self, in_ch, e_ch):
            # Channels double at every stage: e_ch, 2*e_ch, ..., 64*e_ch.
            self.conv1 = Downscale(in_ch, e_ch)
            self.conv2 = Downscale(e_ch, e_ch * 2)
            self.conv3 = Downscale(e_ch * 2, e_ch * 4)
            self.conv4 = Downscale(e_ch * 4, e_ch * 8)
            self.conv5 = Downscale(e_ch * 8, e_ch * 16)
            self.conv6 = Downscale(e_ch * 16, e_ch * 32)
            self.conv7 = Downscale(e_ch * 32, e_ch * 64)

            self.res1 = ResidualBlock(e_ch)
            self.res2 = ResidualBlock(e_ch * 2)
            self.res3 = ResidualBlock(e_ch * 4)
            self.res4 = ResidualBlock(e_ch * 8)
            self.res5 = ResidualBlock(e_ch * 16)
            self.res6 = ResidualBlock(e_ch * 32)
            self.res7 = ResidualBlock(e_ch * 64)

        def forward(self, inp):
            x = self.conv1(inp)
            x = self.res1(x)
            x = self.conv2(x)
            x = self.res2(x)
            x = self.conv3(x)
            x = self.res3(x)
            x = self.conv4(x)
            x = self.res4(x)
            x = self.conv5(x)
            x = self.res5(x)
            x = self.conv6(x)
            x = self.res6(x)
            x = self.conv7(x)
            x = self.res7(x)
            return x

    class Inter(nn.ModelBase):
        """1x1 bottleneck (in_ch -> 64 -> in_ch) followed by two Upscales."""

        def __init__(self, in_ch, ae_ch, **kwargs):
            # ae_ch is stored but not used when building the layers.
            self.in_ch, self.ae_ch = in_ch, ae_ch
            super().__init__(**kwargs)

        def on_build(self):
            in_ch = self.in_ch

            self.dense_conv1 = nn.Conv2D(
                in_ch,
                64,
                kernel_size=1,
                padding='SAME',
                kernel_initializer=conv_kernel_initializer)
            self.dense_conv2 = nn.Conv2D(
                64,
                in_ch,
                kernel_size=1,
                padding='SAME',
                kernel_initializer=conv_kernel_initializer)

            self.conv7 = Upscale(in_ch, in_ch // 2)
            self.conv6 = Upscale(in_ch // 2, in_ch // 4)

        def forward(self, inp):
            x = inp
            x = self.dense_conv1(x)
            x = self.dense_conv2(x)
            x = self.conv7(x)
            x = self.conv6(x)
            return x

    class Decoder(nn.ModelBase):
        """Image branch (Upscale + 3 ResidualBlocks per stage) and a lighter
        mask branch (Upscale only); both end in a sigmoid."""

        def on_build(self, in_ch):
            # Image branch.
            self.upscale6 = Upscale(in_ch, in_ch // 2)
            self.upscale5 = Upscale(in_ch // 2, in_ch // 4)
            self.upscale4 = Upscale(in_ch // 4, in_ch // 8)
            self.upscale3 = Upscale(in_ch // 8, in_ch // 16)
            self.upscale2 = Upscale(in_ch // 16, in_ch // 32)
            self.out_conv = nn.Conv2D(
                in_ch // 32,
                3,
                kernel_size=1,
                padding='SAME',
                kernel_initializer=conv_kernel_initializer)

            self.res61 = ResidualBlock(in_ch // 2, mod=8)
            self.res62 = ResidualBlock(in_ch // 2, mod=8)
            self.res63 = ResidualBlock(in_ch // 2, mod=8)
            self.res51 = ResidualBlock(in_ch // 4, mod=8)
            self.res52 = ResidualBlock(in_ch // 4, mod=8)
            self.res53 = ResidualBlock(in_ch // 4, mod=8)
            self.res41 = ResidualBlock(in_ch // 8, mod=8)
            self.res42 = ResidualBlock(in_ch // 8, mod=8)
            self.res43 = ResidualBlock(in_ch // 8, mod=8)
            self.res31 = ResidualBlock(in_ch // 16, mod=8)
            self.res32 = ResidualBlock(in_ch // 16, mod=8)
            self.res33 = ResidualBlock(in_ch // 16, mod=8)
            self.res21 = ResidualBlock(in_ch // 32, mod=8)
            self.res22 = ResidualBlock(in_ch // 32, mod=8)
            self.res23 = ResidualBlock(in_ch // 32, mod=8)

            # Mask branch: same number of upscales at half the channel width.
            m_ch = in_ch // 2
            self.upscalem6 = Upscale(in_ch, m_ch // 2)
            self.upscalem5 = Upscale(m_ch // 2, m_ch // 4)
            self.upscalem4 = Upscale(m_ch // 4, m_ch // 8)
            self.upscalem3 = Upscale(m_ch // 8, m_ch // 16)
            self.upscalem2 = Upscale(m_ch // 16, m_ch // 32)
            self.out_convm = nn.Conv2D(
                m_ch // 32,
                1,
                kernel_size=1,
                padding='SAME',
                kernel_initializer=conv_kernel_initializer)

        def forward(self, inp):
            z = inp

            x = self.upscale6(z)
            x = self.res61(x)
            x = self.res62(x)
            x = self.res63(x)
            x = self.upscale5(x)
            x = self.res51(x)
            x = self.res52(x)
            x = self.res53(x)
            x = self.upscale4(x)
            x = self.res41(x)
            x = self.res42(x)
            x = self.res43(x)
            x = self.upscale3(x)
            x = self.res31(x)
            x = self.res32(x)
            x = self.res33(x)
            x = self.upscale2(x)
            x = self.res21(x)
            x = self.res22(x)
            x = self.res23(x)

            y = self.upscalem6(z)
            y = self.upscalem5(y)
            y = self.upscalem4(y)
            y = self.upscalem3(y)
            y = self.upscalem2(y)

            return tf.nn.sigmoid(self.out_conv(x)), \
                   tf.nn.sigmoid(self.out_convm(y))

    # Queried again after nn.initialize(), mirroring the original call order.
    device_config = nn.getCurrentDeviceConfig()
    devices = device_config.devices

    resolution = self.resolution = 128
    ae_dims = 128
    e_dims = 16

    self.pretrain = False
    self.pretrain_just_disabled = False

    masked_training = True

    # Models/optimizer go on the first GPU only when every device reports at
    # least 4 GB and we are training; otherwise they stay on the CPU.
    models_opt_on_gpu = len(devices) >= 1 and all(
        dev.total_mem_gb >= 4 for dev in devices)
    models_opt_device = '/GPU:0' if models_opt_on_gpu and self.is_training else '/CPU:0'
    optimizer_vars_on_cpu = models_opt_device == '/CPU:0'

    input_ch = 3
    bgr_shape = nn.get4Dshape(resolution, resolution, input_ch)
    mask_shape = nn.get4Dshape(resolution, resolution, 1)

    self.model_filename_list = []

    with tf.device('/CPU:0'):
        # Placeholders on CPU
        self.warped_src = tf.placeholder(nn.floatx, bgr_shape)
        self.warped_dst = tf.placeholder(nn.floatx, bgr_shape)

        self.target_src = tf.placeholder(nn.floatx, bgr_shape)
        self.target_dst = tf.placeholder(nn.floatx, bgr_shape)

        self.target_srcm = tf.placeholder(nn.floatx, mask_shape)
        self.target_dstm = tf.placeholder(nn.floatx, mask_shape)

    # Initializing model classes
    with tf.device(models_opt_device):
        self.encoder = Encoder(in_ch=input_ch, e_ch=e_dims, name='encoder')
        self.inter = Inter(in_ch=e_dims * 64, ae_ch=ae_dims, name='inter')
        self.decoder_src = Decoder(in_ch=e_dims * 16, name='decoder_src')
        self.decoder_dst = Decoder(in_ch=e_dims * 16, name='decoder_dst')

        self.model_filename_list += [[self.encoder, 'encoder.npy'],
                                     [self.inter, 'inter.npy'],
                                     [self.decoder_src, 'decoder_src.npy'],
                                     [self.decoder_dst, 'decoder_dst.npy']]

        if self.is_training:
            self.src_dst_trainable_weights = (
                self.encoder.get_weights() + self.inter.get_weights() +
                self.decoder_src.get_weights() +
                self.decoder_dst.get_weights())

            # Initialize optimizers
            self.src_dst_opt = nn.RMSprop(lr=5e-5, name='src_dst_opt')
            self.src_dst_opt.initialize_variables(
                self.src_dst_trainable_weights,
                vars_on_cpu=optimizer_vars_on_cpu)
            self.model_filename_list += [(self.src_dst_opt,
                                          'src_dst_opt.npy')]

    if self.is_training:
        # Adjust batch size for multiple GPU: a nominal batch of 4 is split
        # evenly, with at least one sample per device.
        gpu_count = max(1, len(devices))
        bs_per_gpu = max(1, 4 // gpu_count)
        self.set_batch_size(gpu_count * bs_per_gpu)

        mask_blur_radius = max(1, resolution // 32)
        dssim_filter_size = int(resolution / 11.6)

        def reconstruction_loss(target, pred, target_mask, pred_mask,
                                target_mask_blur):
            # Per-sample loss: DSSIM + MSE on the (optionally masked) image
            # plus MSE on the predicted mask, each weighted by 10.
            if masked_training:
                target = target * target_mask_blur
                pred = pred * target_mask_blur
            loss = tf.reduce_mean(
                10 * nn.dssim(target, pred, max_val=1.0,
                              filter_size=dssim_filter_size),
                axis=[1])
            loss += tf.reduce_mean(10 * tf.square(target - pred),
                                   axis=[1, 2, 3])
            loss += tf.reduce_mean(10 * tf.square(target_mask - pred_mask),
                                   axis=[1, 2, 3])
            return loss

        # Compute losses per GPU
        gpu_pred_src_src_list = []
        gpu_pred_dst_dst_list = []
        gpu_pred_src_dst_list = []
        gpu_pred_src_srcm_list = []
        gpu_pred_dst_dstm_list = []
        gpu_pred_src_dstm_list = []

        gpu_src_losses = []
        gpu_dst_losses = []
        gpu_src_dst_loss_gvs = []

        for gpu_id in range(gpu_count):
            with tf.device(
                    f'/GPU:{gpu_id}' if len(devices) != 0 else '/CPU:0'):
                batch_slice = slice(gpu_id * bs_per_gpu,
                                    (gpu_id + 1) * bs_per_gpu)
                with tf.device('/CPU:0'):
                    # Slice on CPU, otherwise the whole batch would be
                    # transferred to the GPU first.
                    gpu_warped_src = self.warped_src[batch_slice, :, :, :]
                    gpu_warped_dst = self.warped_dst[batch_slice, :, :, :]
                    gpu_target_src = self.target_src[batch_slice, :, :, :]
                    gpu_target_dst = self.target_dst[batch_slice, :, :, :]
                    gpu_target_srcm = self.target_srcm[batch_slice, :, :, :]
                    gpu_target_dstm = self.target_dstm[batch_slice, :, :, :]

                # process model tensors
                gpu_src_code = self.inter(self.encoder(gpu_warped_src))
                gpu_dst_code = self.inter(self.encoder(gpu_warped_dst))
                gpu_pred_src_src, gpu_pred_src_srcm = self.decoder_src(
                    gpu_src_code)
                gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder_dst(
                    gpu_dst_code)
                # The swap: dst code decoded by the src decoder.
                gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(
                    gpu_dst_code)

                gpu_pred_src_src_list.append(gpu_pred_src_src)
                gpu_pred_dst_dst_list.append(gpu_pred_dst_dst)
                gpu_pred_src_dst_list.append(gpu_pred_src_dst)

                gpu_pred_src_srcm_list.append(gpu_pred_src_srcm)
                gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm)
                gpu_pred_src_dstm_list.append(gpu_pred_src_dstm)

                gpu_target_srcm_blur = nn.gaussian_blur(
                    gpu_target_srcm, mask_blur_radius)
                gpu_target_dstm_blur = nn.gaussian_blur(
                    gpu_target_dstm, mask_blur_radius)

                gpu_src_loss = reconstruction_loss(
                    gpu_target_src, gpu_pred_src_src, gpu_target_srcm,
                    gpu_pred_src_srcm, gpu_target_srcm_blur)
                gpu_dst_loss = reconstruction_loss(
                    gpu_target_dst, gpu_pred_dst_dst, gpu_target_dstm,
                    gpu_pred_dst_dstm, gpu_target_dstm_blur)

                gpu_src_losses += [gpu_src_loss]
                gpu_dst_losses += [gpu_dst_loss]

                gpu_G_loss = gpu_src_loss + gpu_dst_loss
                gpu_src_dst_loss_gvs += [
                    nn.gradients(gpu_G_loss, self.src_dst_trainable_weights)
                ]

        # Average losses and gradients, and create optimizer update ops
        with tf.device(models_opt_device):
            pred_src_src = nn.concat(gpu_pred_src_src_list, 0)
            pred_dst_dst = nn.concat(gpu_pred_dst_dst_list, 0)
            pred_src_dst = nn.concat(gpu_pred_src_dst_list, 0)
            pred_src_srcm = nn.concat(gpu_pred_src_srcm_list, 0)
            pred_dst_dstm = nn.concat(gpu_pred_dst_dstm_list, 0)
            pred_src_dstm = nn.concat(gpu_pred_src_dstm_list, 0)

            src_loss = nn.average_tensor_list(gpu_src_losses)
            dst_loss = nn.average_tensor_list(gpu_dst_losses)
            src_dst_loss_gv = nn.average_gv_list(gpu_src_dst_loss_gvs)
            src_dst_loss_gv_op = self.src_dst_opt.get_update_op(
                src_dst_loss_gv)

        # Initializing training and view functions
        def src_dst_train(warped_src, target_src, target_srcm,
                          warped_dst, target_dst, target_dstm):
            # Run one optimizer step and return the mean src and dst losses.
            s, d, _ = nn.tf_sess.run(
                [src_loss, dst_loss, src_dst_loss_gv_op],
                feed_dict={
                    self.warped_src: warped_src,
                    self.target_src: target_src,
                    self.target_srcm: target_srcm,
                    self.warped_dst: warped_dst,
                    self.target_dst: target_dst,
                    self.target_dstm: target_dstm,
                })
            s = np.mean(s)
            d = np.mean(d)
            return s, d

        self.src_dst_train = src_dst_train

        def AE_view(warped_src, warped_dst):
            # Evaluate the predictions used for previews (no weight update).
            return nn.tf_sess.run([
                pred_src_src, pred_dst_dst, pred_dst_dstm, pred_src_dst,
                pred_src_dstm
            ],
                                  feed_dict={
                                      self.warped_src: warped_src,
                                      self.warped_dst: warped_dst
                                  })

        self.AE_view = AE_view
    else:
        # Initializing merge function
        with tf.device('/GPU:0' if len(devices) != 0 else '/CPU:0'):
            gpu_dst_code = self.inter(self.encoder(self.warped_dst))
            gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(
                gpu_dst_code)
            _, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)

        def AE_merge(warped_dst):
            # Returns [swapped face, dst mask, swapped-face mask].
            return nn.tf_sess.run(
                [gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm],
                feed_dict={self.warped_dst: warped_dst})

        self.AE_merge = AE_merge

    # Loading/initializing all models/optimizers weights
    for model, filename in io.progress_bar_generator(
            self.model_filename_list, "Initializing models"):
        if self.pretrain_just_disabled:
            # Only `inter` is re-initialised when pretraining was just
            # switched off; everything else is loaded from disk.
            do_init = False
            if model == self.inter:
                do_init = True
        else:
            do_init = self.is_first_run()

        if not do_init:
            do_init = not model.load_weights(
                self.get_strpath_storage_for_file(filename))

        # Fall back to the pretrained weights before random initialisation.
        if do_init and self.pretrained_model_path is not None:
            pretrained_filepath = self.pretrained_model_path / filename
            if pretrained_filepath.exists():
                do_init = not model.load_weights(pretrained_filepath)

        if do_init:
            model.init_weights()

    # initializing sample generators
    if self.is_training:
        t = SampleProcessor.Types
        face_type = t.FACE_TYPE_FULL

        training_data_src_path = self.training_data_src_path if not self.pretrain else self.get_pretraining_data_path()
        training_data_dst_path = self.training_data_dst_path if not self.pretrain else self.get_pretraining_data_path()

        # At most 8 CPUs are used, half for each generator.
        cpu_count = min(multiprocessing.cpu_count(), 8)
        src_generators_count = cpu_count // 2
        dst_generators_count = cpu_count // 2

        def make_generator(samples_path, generators_count):
            # One generator yielding warped image, target image, target mask.
            sample_modes = (
                (t.IMG_WARPED_TRANSFORMED, t.MODE_BGR),
                (t.IMG_TRANSFORMED, t.MODE_BGR),
                (t.IMG_TRANSFORMED, t.MODE_FACE_MASK_ALL_HULL),
            )
            return SampleGeneratorFace(
                samples_path,
                debug=self.is_debug(),
                batch_size=self.get_batch_size(),
                sample_process_options=SampleProcessor.Options(
                    random_flip=bool(self.pretrain)),
                output_sample_types=[{
                    'types': (img_type, face_type, mode),
                    'data_format': nn.data_format,
                    'resolution': resolution,
                } for img_type, mode in sample_modes],
                generators_count=generators_count)

        self.set_training_data_generators([
            make_generator(training_data_src_path, src_generators_count),
            make_generator(training_data_dst_path, dst_generators_count),
        ])

        self.last_samples = None
def on_initialize(self): device_config = nn.getCurrentDeviceConfig() devices = device_config.devices self.model_data_format = "NCHW" nn.initialize(data_format=self.model_data_format) tf = nn.tf input_ch=3 resolution = self.resolution = self.options['resolution'] e_dims = self.options['e_dims'] ae_dims = self.options['ae_dims'] inter_dims = self.inter_dims = self.options['inter_dims'] inter_res = self.inter_res = resolution // 32 d_dims = self.options['d_dims'] d_mask_dims = self.options['d_mask_dims'] face_type = self.face_type = {'f' : FaceType.FULL, 'wf' : FaceType.WHOLE_FACE, 'head' : FaceType.HEAD}[ self.options['face_type'] ] morph_factor = self.options['morph_factor'] gan_power = self.gan_power = self.options['gan_power'] random_warp = self.options['random_warp'] blur_out_mask = self.options['blur_out_mask'] ct_mode = self.options['ct_mode'] if ct_mode == 'none': ct_mode = None use_fp16 = False if self.is_exporting: use_fp16 = io.input_bool ("Export quantized?", False, help_message='Makes the exported model faster. 
If you have problems, disable this option.') conv_dtype = tf.float16 if use_fp16 else tf.float32 class Downscale(nn.ModelBase): def on_build(self, in_ch, out_ch, kernel_size=5 ): self.conv1 = nn.Conv2D( in_ch, out_ch, kernel_size=kernel_size, strides=2, padding='SAME', dtype=conv_dtype) def forward(self, x): return tf.nn.leaky_relu(self.conv1(x), 0.1) class Upscale(nn.ModelBase): def on_build(self, in_ch, out_ch, kernel_size=3 ): self.conv1 = nn.Conv2D(in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype) def forward(self, x): x = nn.depth_to_space(tf.nn.leaky_relu(self.conv1(x), 0.1), 2) return x class ResidualBlock(nn.ModelBase): def on_build(self, ch, kernel_size=3 ): self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype) self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype) def forward(self, inp): x = self.conv1(inp) x = tf.nn.leaky_relu(x, 0.2) x = self.conv2(x) x = tf.nn.leaky_relu(inp+x, 0.2) return x class Encoder(nn.ModelBase): def on_build(self): self.down1 = Downscale(input_ch, e_dims, kernel_size=5) self.res1 = ResidualBlock(e_dims) self.down2 = Downscale(e_dims, e_dims*2, kernel_size=5) self.down3 = Downscale(e_dims*2, e_dims*4, kernel_size=5) self.down4 = Downscale(e_dims*4, e_dims*8, kernel_size=5) self.down5 = Downscale(e_dims*8, e_dims*8, kernel_size=5) self.res5 = ResidualBlock(e_dims*8) self.dense1 = nn.Dense( (( resolution//(2**5) )**2) * e_dims*8, ae_dims ) def forward(self, x): if use_fp16: x = tf.cast(x, tf.float16) x = self.down1(x) x = self.res1(x) x = self.down2(x) x = self.down3(x) x = self.down4(x) x = self.down5(x) x = self.res5(x) if use_fp16: x = tf.cast(x, tf.float32) x = nn.pixel_norm(nn.flatten(x), axes=-1) x = self.dense1(x) return x class Inter(nn.ModelBase): def on_build(self): self.dense2 = nn.Dense(ae_dims, inter_res * inter_res * inter_dims) def forward(self, inp): x = inp x = self.dense2(x) x = nn.reshape_4D (x, inter_res, 
inter_res, inter_dims) return x class Decoder(nn.ModelBase): def on_build(self ): self.upscale0 = Upscale(inter_dims, d_dims*8, kernel_size=3) self.upscale1 = Upscale(d_dims*8, d_dims*8, kernel_size=3) self.upscale2 = Upscale(d_dims*8, d_dims*4, kernel_size=3) self.upscale3 = Upscale(d_dims*4, d_dims*2, kernel_size=3) self.res0 = ResidualBlock(d_dims*8, kernel_size=3) self.res1 = ResidualBlock(d_dims*8, kernel_size=3) self.res2 = ResidualBlock(d_dims*4, kernel_size=3) self.res3 = ResidualBlock(d_dims*2, kernel_size=3) self.upscalem0 = Upscale(inter_dims, d_mask_dims*8, kernel_size=3) self.upscalem1 = Upscale(d_mask_dims*8, d_mask_dims*8, kernel_size=3) self.upscalem2 = Upscale(d_mask_dims*8, d_mask_dims*4, kernel_size=3) self.upscalem3 = Upscale(d_mask_dims*4, d_mask_dims*2, kernel_size=3) self.upscalem4 = Upscale(d_mask_dims*2, d_mask_dims*1, kernel_size=3) self.out_convm = nn.Conv2D( d_mask_dims*1, 1, kernel_size=1, padding='SAME', dtype=conv_dtype) self.out_conv = nn.Conv2D( d_dims*2, 3, kernel_size=1, padding='SAME', dtype=conv_dtype) self.out_conv1 = nn.Conv2D( d_dims*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype) self.out_conv2 = nn.Conv2D( d_dims*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype) self.out_conv3 = nn.Conv2D( d_dims*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype) def forward(self, z): if use_fp16: z = tf.cast(z, tf.float16) x = self.upscale0(z) x = self.res0(x) x = self.upscale1(x) x = self.res1(x) x = self.upscale2(x) x = self.res2(x) x = self.upscale3(x) x = self.res3(x) x = tf.nn.sigmoid( nn.depth_to_space(tf.concat( (self.out_conv(x), self.out_conv1(x), self.out_conv2(x), self.out_conv3(x)), nn.conv2d_ch_axis), 2) ) m = self.upscalem0(z) m = self.upscalem1(m) m = self.upscalem2(m) m = self.upscalem3(m) m = self.upscalem4(m) m = tf.nn.sigmoid(self.out_convm(m)) if use_fp16: x = tf.cast(x, tf.float32) m = tf.cast(m, tf.float32) return x, m models_opt_on_gpu = False if len(devices) == 0 else 
self.options['models_opt_on_gpu'] models_opt_device = nn.tf_default_device_name if models_opt_on_gpu and self.is_training else '/CPU:0' optimizer_vars_on_cpu = models_opt_device=='/CPU:0' bgr_shape = self.bgr_shape = nn.get4Dshape(resolution,resolution,input_ch) mask_shape = nn.get4Dshape(resolution,resolution,1) self.model_filename_list = [] with tf.device ('/CPU:0'): #Place holders on CPU self.warped_src = tf.placeholder (nn.floatx, bgr_shape, name='warped_src') self.warped_dst = tf.placeholder (nn.floatx, bgr_shape, name='warped_dst') self.target_src = tf.placeholder (nn.floatx, bgr_shape, name='target_src') self.target_dst = tf.placeholder (nn.floatx, bgr_shape, name='target_dst') self.target_srcm = tf.placeholder (nn.floatx, mask_shape, name='target_srcm') self.target_srcm_em = tf.placeholder (nn.floatx, mask_shape, name='target_srcm_em') self.target_dstm = tf.placeholder (nn.floatx, mask_shape, name='target_dstm') self.target_dstm_em = tf.placeholder (nn.floatx, mask_shape, name='target_dstm_em') self.morph_value_t = tf.placeholder (nn.floatx, (1,), name='morph_value_t') # Initializing model classes with tf.device (models_opt_device): self.encoder = Encoder(name='encoder') self.inter_src = Inter(name='inter_src') self.inter_dst = Inter(name='inter_dst') self.decoder = Decoder(name='decoder') self.model_filename_list += [ [self.encoder, 'encoder.npy'], [self.inter_src, 'inter_src.npy'], [self.inter_dst , 'inter_dst.npy'], [self.decoder , 'decoder.npy'] ] if self.is_training: # Initialize optimizers clipnorm = 1.0 if self.options['clipgrad'] else 0.0 lr_dropout = 0.3 if self.options['lr_dropout'] in ['y','cpu'] else 1.0 self.G_weights = self.encoder.get_weights() + self.decoder.get_weights() #if random_warp: # self.G_weights += self.inter_src.get_weights() + self.inter_dst.get_weights() self.src_dst_opt = nn.AdaBelief(lr=5e-5, lr_dropout=lr_dropout, clipnorm=clipnorm, name='src_dst_opt') self.src_dst_opt.initialize_variables (self.G_weights, 
vars_on_cpu=optimizer_vars_on_cpu) self.model_filename_list += [ (self.src_dst_opt, 'src_dst_opt.npy') ] if gan_power != 0: self.GAN = nn.UNetPatchDiscriminator(patch_size=self.options['gan_patch_size'], in_ch=input_ch, base_ch=self.options['gan_dims'], name="GAN") self.GAN_opt = nn.AdaBelief(lr=5e-5, lr_dropout=lr_dropout, clipnorm=clipnorm, name='GAN_opt') self.GAN_opt.initialize_variables ( self.GAN.get_weights(), vars_on_cpu=optimizer_vars_on_cpu) self.model_filename_list += [ [self.GAN, 'GAN.npy'], [self.GAN_opt, 'GAN_opt.npy'] ] if self.is_training: # Adjust batch size for multiple GPU gpu_count = max(1, len(devices) ) bs_per_gpu = max(1, self.get_batch_size() // gpu_count) self.set_batch_size( gpu_count*bs_per_gpu) # Compute losses per GPU gpu_pred_src_src_list = [] gpu_pred_dst_dst_list = [] gpu_pred_src_dst_list = [] gpu_pred_src_srcm_list = [] gpu_pred_dst_dstm_list = [] gpu_pred_src_dstm_list = [] gpu_src_losses = [] gpu_dst_losses = [] gpu_G_loss_gradients = [] gpu_GAN_loss_gradients = [] def DLossOnes(logits): return tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(logits), logits=logits), axis=[1,2,3]) def DLossZeros(logits): return tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(logits), logits=logits), axis=[1,2,3]) for gpu_id in range(gpu_count): with tf.device( f'/{devices[gpu_id].tf_dev_type}:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ): with tf.device(f'/CPU:0'): # slice on CPU, otherwise all batch data will be transfered to GPU first batch_slice = slice( gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu ) gpu_warped_src = self.warped_src [batch_slice,:,:,:] gpu_warped_dst = self.warped_dst [batch_slice,:,:,:] gpu_target_src = self.target_src [batch_slice,:,:,:] gpu_target_dst = self.target_dst [batch_slice,:,:,:] gpu_target_srcm = self.target_srcm[batch_slice,:,:,:] gpu_target_srcm_em = self.target_srcm_em[batch_slice,:,:,:] gpu_target_dstm = self.target_dstm[batch_slice,:,:,:] 
gpu_target_dstm_em = self.target_dstm_em[batch_slice,:,:,:] # process model tensors gpu_src_code = self.encoder (gpu_warped_src) gpu_dst_code = self.encoder (gpu_warped_dst) gpu_src_inter_src_code, gpu_src_inter_dst_code = self.inter_src (gpu_src_code), self.inter_dst (gpu_src_code) gpu_dst_inter_src_code, gpu_dst_inter_dst_code = self.inter_src (gpu_dst_code), self.inter_dst (gpu_dst_code) inter_dims_bin = int(inter_dims*morph_factor) with tf.device(f'/CPU:0'): inter_rnd_binomial = tf.stack([tf.random.shuffle(tf.concat([tf.tile(tf.constant([1], tf.float32), ( inter_dims_bin, )), tf.tile(tf.constant([0], tf.float32), ( inter_dims-inter_dims_bin, ))], 0 )) for _ in range(bs_per_gpu)], 0) inter_rnd_binomial = tf.stop_gradient(inter_rnd_binomial[...,None,None]) gpu_src_code = gpu_src_inter_src_code * inter_rnd_binomial + gpu_src_inter_dst_code * (1-inter_rnd_binomial) gpu_dst_code = gpu_dst_inter_dst_code inter_dims_slice = tf.cast(inter_dims*self.morph_value_t[0], tf.int32) gpu_src_dst_code = tf.concat( (tf.slice(gpu_dst_inter_src_code, [0,0,0,0], [-1, inter_dims_slice , inter_res, inter_res]), tf.slice(gpu_dst_inter_dst_code, [0,inter_dims_slice,0,0], [-1,inter_dims-inter_dims_slice, inter_res,inter_res]) ), 1 ) gpu_pred_src_src, gpu_pred_src_srcm = self.decoder(gpu_src_code) gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder(gpu_dst_code) gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code) gpu_pred_src_src_list.append(gpu_pred_src_src), gpu_pred_src_srcm_list.append(gpu_pred_src_srcm) gpu_pred_dst_dst_list.append(gpu_pred_dst_dst), gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm) gpu_pred_src_dst_list.append(gpu_pred_src_dst), gpu_pred_src_dstm_list.append(gpu_pred_src_dstm) gpu_target_srcm_anti = 1-gpu_target_srcm gpu_target_dstm_anti = 1-gpu_target_dstm gpu_target_srcm_gblur = nn.gaussian_blur(gpu_target_srcm, resolution // 32) gpu_target_dstm_gblur = nn.gaussian_blur(gpu_target_dstm, resolution // 32) gpu_target_srcm_blur = 
tf.clip_by_value(gpu_target_srcm_gblur, 0, 0.5) * 2 gpu_target_dstm_blur = tf.clip_by_value(gpu_target_dstm_gblur, 0, 0.5) * 2 gpu_target_srcm_anti_blur = 1.0-gpu_target_srcm_blur gpu_target_dstm_anti_blur = 1.0-gpu_target_dstm_blur if blur_out_mask: sigma = resolution / 128 x = nn.gaussian_blur(gpu_target_src*gpu_target_srcm_anti, sigma) y = 1-nn.gaussian_blur(gpu_target_srcm, sigma) y = tf.where(tf.equal(y, 0), tf.ones_like(y), y) gpu_target_src = gpu_target_src*gpu_target_srcm + (x/y)*gpu_target_srcm_anti x = nn.gaussian_blur(gpu_target_dst*gpu_target_dstm_anti, sigma) y = 1-nn.gaussian_blur(gpu_target_dstm, sigma) y = tf.where(tf.equal(y, 0), tf.ones_like(y), y) gpu_target_dst = gpu_target_dst*gpu_target_dstm + (x/y)*gpu_target_dstm_anti gpu_target_src_masked = gpu_target_src*gpu_target_srcm_blur gpu_target_dst_masked = gpu_target_dst*gpu_target_dstm_blur gpu_target_src_anti_masked = gpu_target_src*gpu_target_srcm_anti_blur gpu_target_dst_anti_masked = gpu_target_dst*gpu_target_dstm_anti_blur gpu_pred_src_src_masked = gpu_pred_src_src*gpu_target_srcm_blur gpu_pred_dst_dst_masked = gpu_pred_dst_dst*gpu_target_dstm_blur gpu_pred_src_src_anti_masked = gpu_pred_src_src*gpu_target_srcm_anti_blur gpu_pred_dst_dst_anti_masked = gpu_pred_dst_dst*gpu_target_dstm_anti_blur # Structural loss gpu_src_loss = tf.reduce_mean (5*nn.dssim(gpu_target_src_masked, gpu_pred_src_src_masked, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1]) gpu_src_loss += tf.reduce_mean (5*nn.dssim(gpu_target_src_masked, gpu_pred_src_src_masked, max_val=1.0, filter_size=int(resolution/23.2)), axis=[1]) gpu_dst_loss = tf.reduce_mean (5*nn.dssim(gpu_target_dst_masked, gpu_pred_dst_dst_masked, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1]) gpu_dst_loss += tf.reduce_mean (5*nn.dssim(gpu_target_dst_masked, gpu_pred_dst_dst_masked, max_val=1.0, filter_size=int(resolution/23.2) ), axis=[1]) # Pixel loss gpu_src_loss += tf.reduce_mean 
(10*tf.square(gpu_target_src_masked-gpu_pred_src_src_masked), axis=[1,2,3]) gpu_dst_loss += tf.reduce_mean (10*tf.square(gpu_target_dst_masked-gpu_pred_dst_dst_masked), axis=[1,2,3]) # Eyes+mouth prio loss gpu_src_loss += tf.reduce_mean (300*tf.abs (gpu_target_src*gpu_target_srcm_em-gpu_pred_src_src*gpu_target_srcm_em), axis=[1,2,3]) gpu_dst_loss += tf.reduce_mean (300*tf.abs (gpu_target_dst*gpu_target_dstm_em-gpu_pred_dst_dst*gpu_target_dstm_em), axis=[1,2,3]) # Mask loss gpu_src_loss += tf.reduce_mean ( 10*tf.square( gpu_target_srcm - gpu_pred_src_srcm ),axis=[1,2,3] ) gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dstm - gpu_pred_dst_dstm ),axis=[1,2,3] ) gpu_src_losses += [gpu_src_loss] gpu_dst_losses += [gpu_dst_loss] gpu_G_loss = gpu_src_loss + gpu_dst_loss # dst-dst background weak loss gpu_G_loss += tf.reduce_mean(0.1*tf.square(gpu_pred_dst_dst_anti_masked-gpu_target_dst_anti_masked),axis=[1,2,3] ) gpu_G_loss += 0.000001*nn.total_variation_mse(gpu_pred_dst_dst_anti_masked) if gan_power != 0: gpu_pred_src_src_d, gpu_pred_src_src_d2 = self.GAN(gpu_pred_src_src_masked) gpu_pred_dst_dst_d, gpu_pred_dst_dst_d2 = self.GAN(gpu_pred_dst_dst_masked) gpu_target_src_d, gpu_target_src_d2 = self.GAN(gpu_target_src_masked) gpu_target_dst_d, gpu_target_dst_d2 = self.GAN(gpu_target_dst_masked) gpu_GAN_loss = (DLossOnes (gpu_target_src_d) + DLossOnes (gpu_target_src_d2) + \ DLossZeros(gpu_pred_src_src_d) + DLossZeros(gpu_pred_src_src_d2) + \ DLossOnes (gpu_target_dst_d) + DLossOnes (gpu_target_dst_d2) + \ DLossZeros(gpu_pred_dst_dst_d) + DLossZeros(gpu_pred_dst_dst_d2) ) * (1.0 / 8) gpu_GAN_loss_gradients += [ nn.gradients (gpu_GAN_loss, self.GAN.get_weights() ) ] gpu_G_loss += (DLossOnes(gpu_pred_src_src_d) + DLossOnes(gpu_pred_src_src_d2) + \ DLossOnes(gpu_pred_dst_dst_d) + DLossOnes(gpu_pred_dst_dst_d2) ) * gan_power # Minimal src-src-bg rec with total_variation_mse to suppress random bright dots from gan gpu_G_loss += 
0.000001*nn.total_variation_mse(gpu_pred_src_src) gpu_G_loss += 0.02*tf.reduce_mean(tf.square(gpu_pred_src_src_anti_masked-gpu_target_src_anti_masked),axis=[1,2,3] ) gpu_G_loss_gradients += [ nn.gradients ( gpu_G_loss, self.G_weights ) ] # Average losses and gradients, and create optimizer update ops with tf.device(f'/CPU:0'): pred_src_src = nn.concat(gpu_pred_src_src_list, 0) pred_dst_dst = nn.concat(gpu_pred_dst_dst_list, 0) pred_src_dst = nn.concat(gpu_pred_src_dst_list, 0) pred_src_srcm = nn.concat(gpu_pred_src_srcm_list, 0) pred_dst_dstm = nn.concat(gpu_pred_dst_dstm_list, 0) pred_src_dstm = nn.concat(gpu_pred_src_dstm_list, 0) with tf.device (models_opt_device): src_loss = tf.concat(gpu_src_losses, 0) dst_loss = tf.concat(gpu_dst_losses, 0) train_op = self.src_dst_opt.get_update_op (nn.average_gv_list (gpu_G_loss_gradients)) if gan_power != 0: GAN_train_op = self.GAN_opt.get_update_op (nn.average_gv_list(gpu_GAN_loss_gradients) ) # Initializing training and view functions def train(warped_src, target_src, target_srcm, target_srcm_em, \ warped_dst, target_dst, target_dstm, target_dstm_em, ): s, d, _ = nn.tf_sess.run ([src_loss, dst_loss, train_op], feed_dict={self.warped_src :warped_src, self.target_src :target_src, self.target_srcm:target_srcm, self.target_srcm_em:target_srcm_em, self.warped_dst :warped_dst, self.target_dst :target_dst, self.target_dstm:target_dstm, self.target_dstm_em:target_dstm_em, }) return s, d self.train = train if gan_power != 0: def GAN_train(warped_src, target_src, target_srcm, target_srcm_em, \ warped_dst, target_dst, target_dstm, target_dstm_em, ): nn.tf_sess.run ([GAN_train_op], feed_dict={self.warped_src :warped_src, self.target_src :target_src, self.target_srcm:target_srcm, self.target_srcm_em:target_srcm_em, self.warped_dst :warped_dst, self.target_dst :target_dst, self.target_dstm:target_dstm, self.target_dstm_em:target_dstm_em}) self.GAN_train = GAN_train def AE_view(warped_src, warped_dst, morph_value): return nn.tf_sess.run 
( [pred_src_src, pred_dst_dst, pred_dst_dstm, pred_src_dst, pred_src_dstm], feed_dict={self.warped_src:warped_src, self.warped_dst:warped_dst, self.morph_value_t:[morph_value] }) self.AE_view = AE_view else: #Initializing merge function with tf.device( nn.tf_default_device_name if len(devices) != 0 else f'/CPU:0'): gpu_dst_code = self.encoder (self.warped_dst) gpu_dst_inter_src_code = self.inter_src (gpu_dst_code) gpu_dst_inter_dst_code = self.inter_dst (gpu_dst_code) inter_dims_slice = tf.cast(inter_dims*self.morph_value_t[0], tf.int32) gpu_src_dst_code = tf.concat( ( tf.slice(gpu_dst_inter_src_code, [0,0,0,0], [-1, inter_dims_slice , inter_res, inter_res]), tf.slice(gpu_dst_inter_dst_code, [0,inter_dims_slice,0,0], [-1,inter_dims-inter_dims_slice, inter_res,inter_res]) ), 1 ) gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code) _, gpu_pred_dst_dstm = self.decoder(gpu_dst_inter_dst_code) def AE_merge(warped_dst, morph_value): return nn.tf_sess.run ( [gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm], feed_dict={self.warped_dst:warped_dst, self.morph_value_t:[morph_value] }) self.AE_merge = AE_merge # Loading/initializing all models/optimizers weights for model, filename in io.progress_bar_generator(self.model_filename_list, "Initializing models"): do_init = self.is_first_run() if self.is_training and gan_power != 0 and model == self.GAN: if self.gan_model_changed: do_init = True if not do_init: do_init = not model.load_weights( self.get_strpath_storage_for_file(filename) ) if do_init: model.init_weights() ############### # initializing sample generators if self.is_training: training_data_src_path = self.training_data_src_path #if not self.pretrain else self.get_pretraining_data_path() training_data_dst_path = self.training_data_dst_path #if not self.pretrain else self.get_pretraining_data_path() random_ct_samples_path=training_data_dst_path if ct_mode is not None else None #and not self.pretrain cpu_count = multiprocessing.cpu_count() 
src_generators_count = cpu_count // 2 dst_generators_count = cpu_count // 2 if ct_mode is not None: src_generators_count = int(src_generators_count * 1.5) self.set_training_data_generators ([ SampleGeneratorFace(training_data_src_path, random_ct_samples_path=random_ct_samples_path, debug=self.is_debug(), batch_size=self.get_batch_size(), sample_process_options=SampleProcessor.Options(random_flip=self.random_src_flip), output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution}, {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution}, {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution}, {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.EYES_MOUTH, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution}, ], uniform_yaw_distribution=self.options['uniform_yaw'],# or self.pretrain, generators_count=src_generators_count ), SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(), sample_process_options=SampleProcessor.Options(random_flip=self.random_dst_flip), output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':face_type, 'data_format':nn.data_format, 
'resolution': resolution}, {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution}, {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution}, {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.EYES_MOUTH, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution}, ], uniform_yaw_distribution=self.options['uniform_yaw'],# or self.pretrain, generators_count=dst_generators_count ) ]) self.last_src_samples_loss = [] self.last_dst_samples_loss = []
def on_initialize(self):
    """Build the model graph, load weights and set up the sample generators.

    Steps, in order:
      1. Initialise ``nn`` with NCHW when at least one device is available
         and the model is not in debug mode, NHWC otherwise.
      2. Declare the network building blocks as local classes.
      3. Create the input placeholders, the encoder / inter / two decoders
         and, when training, the optimizers (plus the two patch
         discriminators if ``gan_power`` is non-zero).
      4. When training, build the per-device loss/gradient graph and expose
         ``self.src_dst_train``, ``self.D_src_dst_train`` (GAN only) and
         ``self.AE_view``; otherwise expose ``self.AE_merge``.
      5. Load or initialise the weights of every entry in
         ``self.model_filename_list``.
      6. When training, register the src/dst sample generators.
    """
    device_config = nn.getCurrentDeviceConfig()
    use_nchw = len(device_config.devices) != 0 and not self.is_debug()
    self.model_data_format = "NCHW" if use_nchw else "NHWC"
    nn.initialize(data_format=self.model_data_format)
    tf = nn.tf

    conv_kernel_initializer = nn.initializers.ca()

    class Downscale(nn.ModelBase):
        """Single conv that halves the spatial size (strided or subpixel)."""

        def __init__(self, in_ch, out_ch, kernel_size=5, dilations=1,
                     subpixel=True, use_activator=True, **kwargs):
            self.in_ch = in_ch
            self.out_ch = out_ch
            self.kernel_size = kernel_size
            self.dilations = dilations
            self.subpixel = subpixel
            self.use_activator = use_activator
            # Forward keyword arguments (e.g. name=...) the same way
            # Inter.__init__ does; the previous '*kwargs' only accepted
            # stray positionals and rejected every keyword.
            super().__init__(**kwargs)

        def on_build(self, *args, **kwargs):
            # Subpixel mode: stride-1 conv producing out_ch // 4 channels,
            # followed by space_to_depth(2) in forward().
            self.conv1 = nn.Conv2D(
                self.in_ch,
                self.out_ch // (4 if self.subpixel else 1),
                kernel_size=self.kernel_size,
                strides=1 if self.subpixel else 2,
                padding='SAME',
                dilations=self.dilations,
                kernel_initializer=conv_kernel_initializer)

        def forward(self, x):
            x = self.conv1(x)
            if self.subpixel:
                x = nn.space_to_depth(x, 2)
            if self.use_activator:
                x = tf.nn.leaky_relu(x, 0.1)
            return x

        def get_out_ch(self):
            """Return the number of channels produced by forward()."""
            if self.subpixel:
                # (out_ch // 4) conv channels, times 4 from space_to_depth.
                return (self.out_ch // 4) * 4
            # The strided conv emits exactly out_ch channels.
            return self.out_ch

    class DownscaleBlock(nn.ModelBase):
        """Chain of Downscale layers; channel width doubles up to ch * 8."""

        def on_build(self, in_ch, ch, n_downscales, kernel_size,
                     dilations=1, subpixel=True):
            self.downs = []
            last_ch = in_ch
            for i in range(n_downscales):
                cur_ch = ch * min(2**i, 8)
                self.downs.append(
                    Downscale(last_ch, cur_ch, kernel_size=kernel_size,
                              dilations=dilations, subpixel=subpixel))
                last_ch = self.downs[-1].get_out_ch()

        def forward(self, inp):
            x = inp
            for down in self.downs:
                x = down(x)
            return x

    class Upscale(nn.ModelBase):
        """Conv to out_ch * 4 channels, leaky ReLU, then depth_to_space(2)."""

        def on_build(self, in_ch, out_ch, kernel_size=3):
            self.conv1 = nn.Conv2D(
                in_ch, out_ch * 4, kernel_size=kernel_size, padding='SAME',
                kernel_initializer=conv_kernel_initializer)

        def forward(self, x):
            x = self.conv1(x)
            x = tf.nn.leaky_relu(x, 0.1)
            x = nn.depth_to_space(x, 2)
            return x

    class ResidualBlock(nn.ModelBase):
        """Two same-width convs with an additive skip connection."""

        def on_build(self, ch, kernel_size=3):
            self.conv1 = nn.Conv2D(
                ch, ch, kernel_size=kernel_size, padding='SAME',
                kernel_initializer=conv_kernel_initializer)
            self.conv2 = nn.Conv2D(
                ch, ch, kernel_size=kernel_size, padding='SAME',
                kernel_initializer=conv_kernel_initializer)

        def forward(self, inp):
            x = self.conv1(inp)
            x = tf.nn.leaky_relu(x, 0.2)
            x = self.conv2(x)
            x = tf.nn.leaky_relu(inp + x, 0.2)
            return x

    class UpdownResidualBlock(nn.ModelBase):
        """Upscale -> residual -> downscale, added back onto the input."""

        def on_build(self, ch, inner_ch, kernel_size=3):
            self.up = Upscale(ch, inner_ch, kernel_size=kernel_size)
            self.res = ResidualBlock(inner_ch, kernel_size=kernel_size)
            self.down = Downscale(inner_ch, ch, kernel_size=kernel_size,
                                  use_activator=False)

        def forward(self, inp):
            x = self.up(inp)
            x = upx = self.res(x)
            x = self.down(x)
            x = x + inp
            x = tf.nn.leaky_relu(x, 0.2)
            # Also hand back the upscaled intermediate activation.
            return x, upx

    class Encoder(nn.ModelBase):
        """Five strided (non-subpixel) downscales followed by a flatten."""

        def on_build(self, in_ch, e_ch):
            self.down1 = DownscaleBlock(in_ch, e_ch, n_downscales=5,
                                        kernel_size=5, dilations=1,
                                        subpixel=False)

        def forward(self, inp):
            return nn.flatten(self.down1(inp))

    class Inter(nn.ModelBase):
        """Dense bottleneck reshaped to a 4D tensor and upscaled once."""

        def __init__(self, in_ch, lowest_dense_res, ae_ch, ae_out_ch,
                     **kwargs):
            self.in_ch = in_ch
            self.lowest_dense_res = lowest_dense_res
            self.ae_ch = ae_ch
            self.ae_out_ch = ae_out_ch
            super().__init__(**kwargs)

        def on_build(self):
            res = self.lowest_dense_res
            self.dense1 = nn.Dense(self.in_ch, self.ae_ch)
            self.dense2 = nn.Dense(self.ae_ch, res * res * self.ae_out_ch)
            self.upscale1 = Upscale(self.ae_out_ch, self.ae_out_ch * 2)

        def forward(self, inp):
            x = self.dense1(inp)
            x = self.dense2(x)
            # Use the instance's own resolution rather than relying on a
            # same-named variable captured from the enclosing function.
            x = nn.reshape_4D(x, self.lowest_dense_res,
                              self.lowest_dense_res, self.ae_out_ch)
            x = self.upscale1(x)
            return x

        def get_out_ch(self):
            """Return the channel count produced by forward()."""
            # upscale1 is Upscale(ae_out_ch, ae_out_ch * 2), so the output
            # has ae_out_ch * 2 channels (previously reported ae_out_ch).
            return self.ae_out_ch * 2

    class Decoder(nn.ModelBase):
        """Image branch (upscale + residual x4) and mask branch (upscale x4)."""

        def on_build(self, in_ch, d_ch, d_mask_ch):
            self.upscale0 = Upscale(in_ch, d_ch * 8, kernel_size=3)
            self.upscale1 = Upscale(d_ch * 8, d_ch * 4, kernel_size=3)
            self.upscale2 = Upscale(d_ch * 4, d_ch * 2, kernel_size=3)
            self.upscale3 = Upscale(d_ch * 2, d_ch * 1, kernel_size=3)

            self.res0 = ResidualBlock(d_ch * 8, kernel_size=3)
            self.res1 = ResidualBlock(d_ch * 4, kernel_size=3)
            self.res2 = ResidualBlock(d_ch * 2, kernel_size=3)
            self.res3 = ResidualBlock(d_ch * 1, kernel_size=3)
            self.out_conv = nn.Conv2D(
                d_ch * 1, 3, kernel_size=1, padding='SAME',
                kernel_initializer=conv_kernel_initializer)

            self.upscalem0 = Upscale(in_ch, d_mask_ch * 8, kernel_size=3)
            self.upscalem1 = Upscale(d_mask_ch * 8, d_mask_ch * 4,
                                     kernel_size=3)
            self.upscalem2 = Upscale(d_mask_ch * 4, d_mask_ch * 2,
                                     kernel_size=3)
            self.upscalem3 = Upscale(d_mask_ch * 2, d_mask_ch * 1,
                                     kernel_size=3)
            self.out_convm = nn.Conv2D(
                d_mask_ch * 1, 1, kernel_size=1, padding='SAME',
                kernel_initializer=conv_kernel_initializer)

        def get_weights_ex(self, include_mask):
            """Return image-branch weights, plus mask-branch ones on request."""
            # Call internal get_weights in order to initialize inner logic
            self.get_weights()

            weights = (self.upscale0.get_weights()
                       + self.upscale1.get_weights()
                       + self.upscale2.get_weights()
                       + self.upscale3.get_weights()
                       + self.res0.get_weights()
                       + self.res1.get_weights()
                       + self.res2.get_weights()
                       + self.res3.get_weights()
                       + self.out_conv.get_weights())

            if include_mask:
                weights += (self.upscalem0.get_weights()
                            + self.upscalem1.get_weights()
                            + self.upscalem2.get_weights()
                            + self.upscalem3.get_weights()
                            + self.out_convm.get_weights())
            return weights

        def forward(self, inp):
            z = inp

            x = self.upscale0(z)
            x = self.res0(x)
            x = self.upscale1(x)
            x = self.res1(x)
            x = self.upscale2(x)
            x = self.res2(x)
            x = self.upscale3(x)
            x = self.res3(x)

            m = self.upscalem0(z)
            m = self.upscalem1(m)
            m = self.upscalem2(m)
            m = self.upscalem3(m)

            return (tf.nn.sigmoid(self.out_conv(x)),
                    tf.nn.sigmoid(self.out_convm(m)))

    device_config = nn.getCurrentDeviceConfig()
    devices = device_config.devices

    # Fixed settings of this model variant.
    self.resolution = resolution = 448
    self.learn_mask = learn_mask = True
    eyes_prio = True
    masked_training = True

    ae_dims = self.options['ae_dims']
    e_dims = self.options['e_dims']
    d_dims = self.options['d_dims']
    d_mask_dims = self.options['d_mask_dims']

    # Pretraining is hard-disabled here; the branches that test these flags
    # are kept so the control flow stays recognisable.
    self.pretrain = False
    self.pretrain_just_disabled = False

    if self.pretrain_just_disabled:
        self.set_iter(0)

    self.gan_power = gan_power = (
        self.options['gan_power'] if not self.pretrain else 0.0)

    # No device -> CPU; several devices -> always GPU; exactly one device
    # -> follow the user option.
    if len(devices) == 0:
        models_opt_on_gpu = False
    elif len(devices) > 1:
        models_opt_on_gpu = True
    else:
        models_opt_on_gpu = self.options['models_opt_on_gpu']

    models_opt_device = (
        '/GPU:0' if models_opt_on_gpu and self.is_training else '/CPU:0')
    optimizer_vars_on_cpu = models_opt_device == '/CPU:0'

    input_ch = 3
    output_ch = 3
    bgr_shape = nn.get4Dshape(resolution, resolution, input_ch)
    mask_shape = nn.get4Dshape(resolution, resolution, 1)
    lowest_dense_res = resolution // 32

    self.model_filename_list = []

    with tf.device('/CPU:0'):
        # Place holders on CPU
        self.warped_src = tf.placeholder(nn.floatx, bgr_shape)
        self.warped_dst = tf.placeholder(nn.floatx, bgr_shape)

        self.target_src = tf.placeholder(nn.floatx, bgr_shape)
        self.target_dst = tf.placeholder(nn.floatx, bgr_shape)

        self.target_srcm_all = tf.placeholder(nn.floatx, mask_shape)
        self.target_dstm_all = tf.placeholder(nn.floatx, mask_shape)

    # Initializing model classes
    with tf.device(models_opt_device):
        self.encoder = Encoder(in_ch=input_ch, e_ch=e_dims, name='encoder')
        encoder_out_ch = self.encoder.compute_output_channels(
            (nn.floatx, bgr_shape))

        self.inter = Inter(in_ch=encoder_out_ch,
                           lowest_dense_res=lowest_dense_res,
                           ae_ch=ae_dims,
                           ae_out_ch=ae_dims,
                           name='inter')
        inter_out_ch = self.inter.compute_output_channels(
            (nn.floatx, (None, encoder_out_ch)))

        self.decoder_src = Decoder(in_ch=inter_out_ch, d_ch=d_dims,
                                   d_mask_ch=d_mask_dims,
                                   name='decoder_src')
        self.decoder_dst = Decoder(in_ch=inter_out_ch, d_ch=d_dims,
                                   d_mask_ch=d_mask_dims,
                                   name='decoder_dst')

        self.model_filename_list += [[self.encoder, 'encoder.npy'],
                                     [self.inter, 'inter.npy'],
                                     [self.decoder_src, 'decoder_src.npy'],
                                     [self.decoder_dst, 'decoder_dst.npy']]

        if self.is_training:
            if gan_power != 0:
                self.D_src = nn.PatchDiscriminator(
                    patch_size=resolution // 16, in_ch=output_ch,
                    base_ch=256, name="D_src")
                self.D_dst = nn.PatchDiscriminator(
                    patch_size=resolution // 16, in_ch=output_ch,
                    base_ch=256, name="D_dst")
                self.model_filename_list += [[self.D_src, 'D_src.npy']]
                self.model_filename_list += [[self.D_dst, 'D_dst.npy']]

            # Initialize optimizers
            lr = 5e-5
            clipnorm = 1.0 if self.options['clipgrad'] else 0.0
            self.src_dst_opt = nn.RMSprop(lr=lr, clipnorm=clipnorm,
                                          name='src_dst_opt')
            self.model_filename_list += [
                (self.src_dst_opt, 'src_dst_opt.npy')]

            # Optimizer state is allocated for every weight, while
            # gradients are only taken w.r.t. src_dst_trainable_weights.
            self.src_dst_all_trainable_weights = (
                self.encoder.get_weights()
                + self.inter.get_weights()
                + self.decoder_src.get_weights()
                + self.decoder_dst.get_weights())
            self.src_dst_trainable_weights = (
                self.encoder.get_weights()
                + self.inter.get_weights()
                + self.decoder_src.get_weights_ex(learn_mask)
                + self.decoder_dst.get_weights_ex(learn_mask))

            self.src_dst_opt.initialize_variables(
                self.src_dst_all_trainable_weights,
                vars_on_cpu=optimizer_vars_on_cpu)

            if gan_power != 0:
                self.D_src_dst_opt = nn.RMSprop(lr=lr, clipnorm=clipnorm,
                                                name='D_src_dst_opt')
                self.D_src_dst_opt.initialize_variables(
                    self.D_src.get_weights() + self.D_dst.get_weights(),
                    vars_on_cpu=optimizer_vars_on_cpu)
                self.model_filename_list += [
                    (self.D_src_dst_opt, 'D_src_dst_opt.npy')]

    if self.is_training:
        # Adjust batch size for multiple GPU
        gpu_count = max(1, len(devices))
        bs_per_gpu = max(1, self.get_batch_size() // gpu_count)
        self.set_batch_size(gpu_count * bs_per_gpu)

        def DLoss(labels, logits):
            """Sigmoid cross-entropy averaged over axes 1, 2 and 3."""
            return tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(labels=labels,
                                                        logits=logits),
                axis=[1, 2, 3])

        # Compute losses per GPU
        gpu_pred_src_src_list = []
        gpu_pred_dst_dst_list = []
        gpu_pred_src_dst_list = []
        gpu_pred_src_srcm_list = []
        gpu_pred_dst_dstm_list = []
        gpu_pred_src_dstm_list = []

        gpu_src_losses = []
        gpu_dst_losses = []
        gpu_G_loss_gvs = []
        gpu_D_src_dst_loss_gvs = []

        for gpu_id in range(gpu_count):
            with tf.device(
                    f'/GPU:{gpu_id}' if len(devices) != 0 else '/CPU:0'):

                with tf.device('/CPU:0'):
                    # slice on CPU, otherwise all batch data will be
                    # transferred to GPU first
                    batch_slice = slice(gpu_id * bs_per_gpu,
                                        (gpu_id + 1) * bs_per_gpu)
                    gpu_warped_src = self.warped_src[batch_slice, :, :, :]
                    gpu_warped_dst = self.warped_dst[batch_slice, :, :, :]
                    gpu_target_src = self.target_src[batch_slice, :, :, :]
                    gpu_target_dst = self.target_dst[batch_slice, :, :, :]
                    gpu_target_srcm_all = self.target_srcm_all[
                        batch_slice, :, :, :]
                    gpu_target_dstm_all = self.target_dstm_all[
                        batch_slice, :, :, :]

                # process model tensors
                gpu_src_code = self.inter(self.encoder(gpu_warped_src))
                gpu_dst_code = self.inter(self.encoder(gpu_warped_dst))
                gpu_pred_src_src, gpu_pred_src_srcm = self.decoder_src(
                    gpu_src_code)
                gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder_dst(
                    gpu_dst_code)
                gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(
                    gpu_dst_code)

                gpu_pred_src_src_list.append(gpu_pred_src_src)
                gpu_pred_dst_dst_list.append(gpu_pred_dst_dst)
                gpu_pred_src_dst_list.append(gpu_pred_src_dst)

                gpu_pred_src_srcm_list.append(gpu_pred_src_srcm)
                gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm)
                gpu_pred_src_dstm_list.append(gpu_pred_src_dstm)

                # unpack masks from one combined mask: the [0, 1] range is
                # the face mask, the part above 1 is the eyes mask
                gpu_target_srcm = tf.clip_by_value(
                    gpu_target_srcm_all, 0, 1)
                gpu_target_dstm = tf.clip_by_value(
                    gpu_target_dstm_all, 0, 1)
                gpu_target_srcm_eyes = tf.clip_by_value(
                    gpu_target_srcm_all - 1, 0, 1)
                gpu_target_dstm_eyes = tf.clip_by_value(
                    gpu_target_dstm_all - 1, 0, 1)

                gpu_target_srcm_blur = nn.gaussian_blur(
                    gpu_target_srcm, max(1, resolution // 32))
                gpu_target_dstm_blur = nn.gaussian_blur(
                    gpu_target_dstm, max(1, resolution // 32))

                if masked_training:
                    gpu_target_src_masked_opt = (
                        gpu_target_src * gpu_target_srcm_blur)
                    gpu_target_dst_masked_opt = (
                        gpu_target_dst * gpu_target_dstm_blur)
                    gpu_pred_src_src_masked_opt = (
                        gpu_pred_src_src * gpu_target_srcm_blur)
                    gpu_pred_dst_dst_masked_opt = (
                        gpu_pred_dst_dst * gpu_target_dstm_blur)
                else:
                    gpu_target_src_masked_opt = gpu_target_src
                    gpu_target_dst_masked_opt = gpu_target_dst
                    gpu_pred_src_src_masked_opt = gpu_pred_src_src
                    gpu_pred_dst_dst_masked_opt = gpu_pred_dst_dst

                # src loss: DSSIM + pixel MSE (+ eyes L1, + mask MSE)
                gpu_src_loss = tf.reduce_mean(
                    10 * nn.dssim(gpu_target_src_masked_opt,
                                  gpu_pred_src_src_masked_opt,
                                  max_val=1.0,
                                  filter_size=int(resolution / 11.6)),
                    axis=[1])
                gpu_src_loss += tf.reduce_mean(
                    10 * tf.square(gpu_target_src_masked_opt
                                   - gpu_pred_src_src_masked_opt),
                    axis=[1, 2, 3])

                if eyes_prio:
                    gpu_src_loss += tf.reduce_mean(
                        300 * tf.abs(
                            gpu_target_src * gpu_target_srcm_eyes
                            - gpu_pred_src_src * gpu_target_srcm_eyes),
                        axis=[1, 2, 3])

                if learn_mask:
                    gpu_src_loss += tf.reduce_mean(
                        10 * tf.square(gpu_target_srcm - gpu_pred_src_srcm),
                        axis=[1, 2, 3])

                # dst loss: same terms as the src loss
                gpu_dst_loss = tf.reduce_mean(
                    10 * nn.dssim(gpu_target_dst_masked_opt,
                                  gpu_pred_dst_dst_masked_opt,
                                  max_val=1.0,
                                  filter_size=int(resolution / 11.6)),
                    axis=[1])
                gpu_dst_loss += tf.reduce_mean(
                    10 * tf.square(gpu_target_dst_masked_opt
                                   - gpu_pred_dst_dst_masked_opt),
                    axis=[1, 2, 3])

                if eyes_prio:
                    gpu_dst_loss += tf.reduce_mean(
                        300 * tf.abs(
                            gpu_target_dst * gpu_target_dstm_eyes
                            - gpu_pred_dst_dst * gpu_target_dstm_eyes),
                        axis=[1, 2, 3])

                if learn_mask:
                    gpu_dst_loss += tf.reduce_mean(
                        10 * tf.square(gpu_target_dstm - gpu_pred_dst_dstm),
                        axis=[1, 2, 3])

                gpu_src_losses += [gpu_src_loss]
                gpu_dst_losses += [gpu_dst_loss]

                gpu_G_loss = gpu_src_loss + gpu_dst_loss

                if gan_power != 0:
                    gpu_pred_src_src_d = self.D_src(
                        gpu_pred_src_src_masked_opt)
                    gpu_pred_src_src_d_ones = tf.ones_like(
                        gpu_pred_src_src_d)
                    gpu_pred_src_src_d_zeros = tf.zeros_like(
                        gpu_pred_src_src_d)

                    gpu_target_src_d = self.D_src(
                        gpu_target_src_masked_opt)
                    gpu_target_src_d_ones = tf.ones_like(gpu_target_src_d)

                    gpu_pred_dst_dst_d = self.D_dst(
                        gpu_pred_dst_dst_masked_opt)
                    gpu_pred_dst_dst_d_ones = tf.ones_like(
                        gpu_pred_dst_dst_d)
                    gpu_pred_dst_dst_d_zeros = tf.zeros_like(
                        gpu_pred_dst_dst_d)

                    gpu_target_dst_d = self.D_dst(
                        gpu_target_dst_masked_opt)
                    gpu_target_dst_d_ones = tf.ones_like(gpu_target_dst_d)

                    # Discriminators: targets labelled 1, predictions 0.
                    gpu_D_src_dst_loss = (
                        (DLoss(gpu_target_src_d_ones, gpu_target_src_d)
                         + DLoss(gpu_pred_src_src_d_zeros,
                                 gpu_pred_src_src_d)) * 0.5
                        + (DLoss(gpu_target_dst_d_ones, gpu_target_dst_d)
                           + DLoss(gpu_pred_dst_dst_d_zeros,
                                   gpu_pred_dst_dst_d)) * 0.5)

                    gpu_D_src_dst_loss_gvs += [
                        nn.gradients(
                            gpu_D_src_dst_loss,
                            self.D_src.get_weights()
                            + self.D_dst.get_weights())
                    ]

                    # Generator: predictions scored against label 1.
                    gpu_G_loss += gan_power * (
                        DLoss(gpu_pred_src_src_d_ones, gpu_pred_src_src_d)
                        + DLoss(gpu_pred_dst_dst_d_ones,
                                gpu_pred_dst_dst_d))

                gpu_G_loss_gvs += [
                    nn.gradients(gpu_G_loss, self.src_dst_trainable_weights)
                ]

        # Average losses and gradients, and create optimizer update ops
        with tf.device(models_opt_device):
            pred_src_src = nn.concat(gpu_pred_src_src_list, 0)
            pred_dst_dst = nn.concat(gpu_pred_dst_dst_list, 0)
            pred_src_dst = nn.concat(gpu_pred_src_dst_list, 0)
            pred_src_srcm = nn.concat(gpu_pred_src_srcm_list, 0)
            pred_dst_dstm = nn.concat(gpu_pred_dst_dstm_list, 0)
            pred_src_dstm = nn.concat(gpu_pred_src_dstm_list, 0)

            src_loss = nn.average_tensor_list(gpu_src_losses)
            dst_loss = nn.average_tensor_list(gpu_dst_losses)
            src_dst_loss_gv_op = self.src_dst_opt.get_update_op(
                nn.average_gv_list(gpu_G_loss_gvs))

            if gan_power != 0:
                src_D_src_dst_loss_gv_op = self.D_src_dst_opt.get_update_op(
                    nn.average_gv_list(gpu_D_src_dst_loss_gvs))

        def make_train_feed(warped_src, target_src, target_srcm_all,
                            warped_dst, target_dst, target_dstm_all):
            """Map one training batch onto the input placeholders."""
            return {
                self.warped_src: warped_src,
                self.target_src: target_src,
                self.target_srcm_all: target_srcm_all,
                self.warped_dst: warped_dst,
                self.target_dst: target_dst,
                self.target_dstm_all: target_dstm_all,
            }

        # Initializing training and view functions
        def src_dst_train(warped_src, target_src, target_srcm_all,
                          warped_dst, target_dst, target_dstm_all):
            """Run one generator step; return mean (src, dst) losses."""
            s, d, _ = nn.tf_sess.run(
                [src_loss, dst_loss, src_dst_loss_gv_op],
                feed_dict=make_train_feed(
                    warped_src, target_src, target_srcm_all,
                    warped_dst, target_dst, target_dstm_all))
            return np.mean(s), np.mean(d)

        self.src_dst_train = src_dst_train

        if gan_power != 0:
            def D_src_dst_train(warped_src, target_src, target_srcm_all,
                                warped_dst, target_dst, target_dstm_all):
                """Run one discriminator update step."""
                nn.tf_sess.run(
                    [src_D_src_dst_loss_gv_op],
                    feed_dict=make_train_feed(
                        warped_src, target_src, target_srcm_all,
                        warped_dst, target_dst, target_dstm_all))

            self.D_src_dst_train = D_src_dst_train

        if learn_mask:
            def AE_view(warped_src, warped_dst):
                """Return predictions and masks for the preview."""
                return nn.tf_sess.run(
                    [pred_src_src, pred_dst_dst, pred_dst_dstm,
                     pred_src_dst, pred_src_dstm],
                    feed_dict={self.warped_src: warped_src,
                               self.warped_dst: warped_dst})
        else:
            def AE_view(warped_src, warped_dst):
                """Return predictions (no masks) for the preview."""
                return nn.tf_sess.run(
                    [pred_src_src, pred_dst_dst, pred_src_dst],
                    feed_dict={self.warped_src: warped_src,
                               self.warped_dst: warped_dst})

        self.AE_view = AE_view
    else:
        # Initializing merge function
        with tf.device('/GPU:0' if len(devices) != 0 else '/CPU:0'):
            gpu_dst_code = self.inter(self.encoder(self.warped_dst))
            gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(
                gpu_dst_code)
            _, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)

        if learn_mask:
            def AE_merge(warped_dst):
                """Return the swapped face plus the dst and src-dst masks."""
                return nn.tf_sess.run(
                    [gpu_pred_src_dst, gpu_pred_dst_dstm,
                     gpu_pred_src_dstm],
                    feed_dict={self.warped_dst: warped_dst})
        else:
            def AE_merge(warped_dst):
                """Return the swapped face only."""
                return nn.tf_sess.run(
                    [gpu_pred_src_dst],
                    feed_dict={self.warped_dst: warped_dst})

        self.AE_merge = AE_merge

    # Loading/initializing all models/optimizers weights
    for model, filename in io.progress_bar_generator(
            self.model_filename_list, "Initializing models"):
        if self.pretrain_just_disabled:
            # Only the inter model is re-initialised in this case.
            do_init = False
            if model == self.inter:
                do_init = True
        else:
            do_init = self.is_first_run()

        if not do_init:
            # Fall back to fresh initialisation when loading fails.
            do_init = not model.load_weights(
                self.get_strpath_storage_for_file(filename))

        if do_init:
            model.init_weights()

    # initializing sample generators
    if self.is_training:
        t = SampleProcessor.Types
        face_type = t.FACE_TYPE_HEAD

        if not self.pretrain:
            training_data_src_path = self.training_data_src_path
            training_data_dst_path = self.training_data_dst_path
        else:
            training_data_src_path = self.get_pretraining_data_path()
            training_data_dst_path = self.get_pretraining_data_path()

        t_img_warped = t.IMG_WARPED_TRANSFORMED

        cpu_count = min(multiprocessing.cpu_count(), 8)
        # Keep at least one generator per side, even on a single-core host
        # where cpu_count // 2 would be 0.
        src_generators_count = max(1, cpu_count // 2)
        dst_generators_count = max(1, cpu_count // 2)

        def make_output_sample_types():
            """Return a fresh list: warped image, target image, target mask."""
            return [
                {'types': (t_img_warped, face_type, t.MODE_BGR),
                 'data_format': nn.data_format,
                 'resolution': resolution},
                {'types': (t.IMG_TRANSFORMED, face_type, t.MODE_BGR),
                 'data_format': nn.data_format,
                 'resolution': resolution},
                {'types': (t.IMG_TRANSFORMED, face_type,
                           t.MODE_FACE_MASK_ALL_EYES_HULL),
                 'data_format': nn.data_format,
                 'resolution': resolution},
            ]

        def make_generator(samples_path, generators_count):
            """Create the face sample generator for one side (src or dst)."""
            return SampleGeneratorFace(
                samples_path,
                debug=self.is_debug(),
                batch_size=self.get_batch_size(),
                sample_process_options=SampleProcessor.Options(
                    random_flip=False),
                output_sample_types=make_output_sample_types(),
                generators_count=generators_count)

        self.set_training_data_generators([
            make_generator(training_data_src_path, src_generators_count),
            make_generator(training_data_dst_path, dst_generators_count),
        ])

        if self.pretrain_just_disabled:
            self.update_sample_for_preview(force_new=True)
def on_initialize(self):
    """Build the TF graph, the train/view/merge callables and the sample generators.

    Steps performed, in order:
      1. Initialize `nn` with a data format chosen from the available devices.
      2. Read the architecture/loss settings from `self.options`.
      3. Create the input placeholders and the encoder/inter/decoder models
         ('df' or 'liae' layout, selected by substring test on `archi`).
      4. When training: create optimizers, build per-device losses and
         gradients, and attach `src_dst_train`, `D_train`, `D_src_dst_train`,
         `AE_get_latent`, `AE_view_src` and `AE_view` to `self`.
         Otherwise: build the merge graph and attach `AE_merge`.
      5. Load or initialize the weights of everything in
         `self.model_filename_list`.
      6. When training: create the src/dst sample generators, then run the
         `PRD` latent-code experiment at the end of this method.

    NOTE(review): the source of this method was whitespace-flattened; the
    nesting below is reconstructed from syntax and data dependencies. Places
    where the nesting could not be determined are marked individually.
    """
    device_config = nn.getCurrentDeviceConfig()
    devices = device_config.devices
    # NCHW only when at least one device is present and we are not in debug mode.
    self.model_data_format = "NCHW" if len(devices) != 0 and not self.is_debug() else "NHWC"
    nn.initialize(data_format=self.model_data_format)
    tf = nn.tf

    self.resolution = resolution = self.options['resolution']
    # Map the option string to the FaceType enum; an unknown key raises KeyError.
    self.face_type = {'h' : FaceType.HALF,
                      'mf' : FaceType.MID_FULL,
                      'f' : FaceType.FULL,
                      'wf' : FaceType.WHOLE_FACE}[ self.options['face_type'] ]

    eyes_prio = self.options['eyes_prio']
    archi = self.options['archi']
    is_hd = 'hd' in archi
    ae_dims = self.options['ae_dims']
    e_dims = self.options['e_dims']
    d_dims = self.options['d_dims']
    d_mask_dims = self.options['d_mask_dims']
    self.pretrain = self.options['pretrain']
    if self.pretrain_just_disabled:
        # Pretraining was just switched off: restart the iteration counter.
        self.set_iter(0)

    # GAN loss is forced off while pretraining.
    self.gan_power = gan_power = self.options['gan_power'] if not self.pretrain else 0.0

    masked_training = self.options['masked_training']
    ct_mode = self.options['ct_mode']
    if ct_mode == 'none':
        ct_mode = None

    # Models/optimizers go to GPU:0 only when a device exists, the option asks
    # for it, and we are training; otherwise everything is placed on the CPU.
    models_opt_on_gpu = False if len(devices) == 0 else self.options['models_opt_on_gpu']
    models_opt_device = '/GPU:0' if models_opt_on_gpu and self.is_training else '/CPU:0'
    optimizer_vars_on_cpu = models_opt_device=='/CPU:0'

    input_ch=3
    bgr_shape = nn.get4Dshape(resolution,resolution,input_ch)
    mask_shape = nn.get4Dshape(resolution,resolution,1)
    # Collects (model_or_optimizer, weights_filename) pairs for the load/init loop below.
    self.model_filename_list = []

    with tf.device ('/CPU:0'):
        #Place holders on CPU
        self.warped_src = tf.placeholder (nn.floatx, bgr_shape)
        self.warped_dst = tf.placeholder (nn.floatx, bgr_shape)
        # NOTE(review): width 256 is hard-coded here (and again in PRD(256) below)
        # rather than derived from ae_dims -- confirm it matches what
        # self.inter.fd / self.inter.dense1 expect for non-default ae_dims.
        self.src_code_in = tf.placeholder (nn.floatx, (None,256) )

        self.target_src = tf.placeholder (nn.floatx, bgr_shape)
        self.target_dst = tf.placeholder (nn.floatx, bgr_shape)

        self.target_srcm_all = tf.placeholder (nn.floatx, mask_shape)
        self.target_dstm_all = tf.placeholder (nn.floatx, mask_shape)

    # Initializing model classes
    model_archi = nn.DeepFakeArchi(resolution, mod='uhd' if 'uhd' in archi else None)

    with tf.device (models_opt_device):
        if 'df' in archi:
            # 'df' layout: one shared encoder + inter, separate src/dst decoders.
            self.encoder = model_archi.Encoder(in_ch=input_ch, e_ch=e_dims, is_hd=is_hd, name='encoder')
            encoder_out_ch = self.encoder.compute_output_channels ( (nn.floatx, bgr_shape))

            self.inter = model_archi.Inter (in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims, is_hd=is_hd, name='inter')
            inter_out_ch = self.inter.compute_output_channels ( (nn.floatx, (None,encoder_out_ch)))

            self.decoder_src = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, is_hd=is_hd, name='decoder_src')
            self.decoder_dst = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, is_hd=is_hd, name='decoder_dst')

            self.model_filename_list += [ [self.encoder, 'encoder.npy' ],
                                          [self.inter, 'inter.npy' ],
                                          [self.decoder_src, 'decoder_src.npy'],
                                          [self.decoder_dst, 'decoder_dst.npy'] ]

            if self.is_training:
                if self.options['true_face_power'] != 0:
                    # The code discriminator is only created on the 'df' path.
                    self.code_discriminator = nn.CodeDiscriminator(ae_dims, code_res=model_archi.Inter.get_code_res()*2, name='dis' )
                    self.model_filename_list += [ [self.code_discriminator, 'code_discriminator.npy'] ]

        elif 'liae' in archi:
            # 'liae' layout: shared encoder, two inters (AB and B), one shared decoder.
            self.encoder = model_archi.Encoder(in_ch=input_ch, e_ch=e_dims, is_hd=is_hd, name='encoder')
            encoder_out_ch = self.encoder.compute_output_channels ( (nn.floatx, bgr_shape))

            self.inter_AB = model_archi.Inter(in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims*2, is_hd=is_hd, name='inter_AB')
            self.inter_B = model_archi.Inter(in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims*2, is_hd=is_hd, name='inter_B')

            inter_AB_out_ch = self.inter_AB.compute_output_channels ( (nn.floatx, (None,encoder_out_ch)))
            inter_B_out_ch = self.inter_B.compute_output_channels ( (nn.floatx, (None,encoder_out_ch)))
            # The decoder consumes the channel-wise concatenation of both inter outputs.
            inters_out_ch = inter_AB_out_ch+inter_B_out_ch
            self.decoder = model_archi.Decoder(in_ch=inters_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, is_hd=is_hd, name='decoder')

            self.model_filename_list += [ [self.encoder, 'encoder.npy'],
                                          [self.inter_AB, 'inter_AB.npy'],
                                          [self.inter_B , 'inter_B.npy'],
                                          [self.decoder , 'decoder.npy'] ]

        if self.is_training:
            if gan_power != 0:
                self.D_src = nn.PatchDiscriminator(patch_size=resolution//16, in_ch=input_ch, name="D_src")
                self.D_dst = nn.PatchDiscriminator(patch_size=resolution//16, in_ch=input_ch, name="D_dst")
                self.model_filename_list += [ [self.D_src, 'D_src.npy'] ]
                self.model_filename_list += [ [self.D_dst, 'D_dst.npy'] ]

            # Initialize optimizers
            lr=5e-5
            # lr_dropout of 1.0 is used when the option is off or while pretraining.
            lr_dropout = 0.3 if self.options['lr_dropout'] and not self.pretrain else 1.0
            clipnorm = 1.0 if self.options['clipgrad'] else 0.0
            self.src_dst_opt = nn.RMSprop(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='src_dst_opt')
            self.model_filename_list += [ (self.src_dst_opt, 'src_dst_opt.npy') ]
            if 'df' in archi:
                self.src_dst_trainable_weights = self.encoder.get_weights() + self.inter.get_weights() + self.decoder_src.get_weights() + self.decoder_dst.get_weights()
            elif 'liae' in archi:
                self.src_dst_trainable_weights = self.encoder.get_weights() + self.inter_AB.get_weights() + self.inter_B.get_weights() + self.decoder.get_weights()

            self.src_dst_opt.initialize_variables (self.src_dst_trainable_weights, vars_on_cpu=optimizer_vars_on_cpu)

            if self.options['true_face_power'] != 0:
                # NOTE(review): self.code_discriminator exists only on the 'df'
                # path; with a 'liae' archi and true_face_power != 0 this would
                # raise AttributeError -- presumably the options layer prevents
                # that combination; confirm.
                self.D_code_opt = nn.RMSprop(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='D_code_opt')
                self.D_code_opt.initialize_variables ( self.code_discriminator.get_weights(), vars_on_cpu=optimizer_vars_on_cpu)
                self.model_filename_list += [ (self.D_code_opt, 'D_code_opt.npy') ]

            if gan_power != 0:
                self.D_src_dst_opt = nn.RMSprop(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='D_src_dst_opt')
                self.D_src_dst_opt.initialize_variables ( self.D_src.get_weights()+self.D_dst.get_weights(), vars_on_cpu=optimizer_vars_on_cpu)
                self.model_filename_list += [ (self.D_src_dst_opt, 'D_src_dst_opt.npy') ]

    if self.is_training:
        # Adjust batch size for multiple GPU
        # The batch size is rounded down to a multiple of gpu_count, with at
        # least one sample per device.
        gpu_count = max(1, len(devices) )
        bs_per_gpu = max(1, self.get_batch_size() // gpu_count)
        self.set_batch_size( gpu_count*bs_per_gpu)

        # Compute losses per GPU
        gpu_src_latent_code_list = []
        gpu_pred_src_src_list = []
        gpu_pred_dst_dst_list = []
        gpu_pred_src_dst_list = []
        gpu_pred_src_srcm_list = []
        gpu_pred_dst_dstm_list = []
        gpu_pred_src_dstm_list = []

        gpu_src_losses = []
        gpu_dst_losses = []
        gpu_G_loss_gvs = []
        gpu_D_code_loss_gvs = []
        gpu_D_src_dst_loss_gvs = []

        for gpu_id in range(gpu_count):
            with tf.device( f'/GPU:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ):

                with tf.device(f'/CPU:0'):
                    # slice on CPU, otherwise all batch data will be transferred to GPU first
                    batch_slice = slice( gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu )
                    gpu_warped_src = self.warped_src [batch_slice,:,:,:]
                    gpu_src_code_in = self.src_code_in[batch_slice,:]
                    gpu_warped_dst = self.warped_dst [batch_slice,:,:,:]
                    gpu_target_src = self.target_src [batch_slice,:,:,:]
                    gpu_target_dst = self.target_dst [batch_slice,:,:,:]
                    gpu_target_srcm_all = self.target_srcm_all[batch_slice,:,:,:]
                    gpu_target_dstm_all = self.target_dstm_all[batch_slice,:,:,:]

                # process model tensors
                if 'df' in archi:
                    # Latent code taken after the inter's first dense layer.
                    gpu_src_latent_code = self.inter.dense1(self.encoder(gpu_warped_src))
                    # Code injected from outside through the src_code_in placeholder.
                    gpu_src_in_code = self.inter.fd(gpu_src_code_in)

                    gpu_src_code = self.inter(self.encoder(gpu_warped_src))
                    gpu_dst_code = self.inter(self.encoder(gpu_warped_dst))
                    # NOTE(review): src->src is decoded from the *fed* code, not from
                    # the encoded warped_src (the commented line below is the
                    # encoder-driven alternative). Everything derived from
                    # gpu_pred_src_src (src_loss, GAN terms, AE_view) therefore
                    # depends on self.src_code_in, which src_dst_train,
                    # D_src_dst_train and AE_view do not feed -- confirm intended.
                    gpu_pred_src_src, gpu_pred_src_srcm = self.decoder_src(gpu_src_in_code)
                    #gpu_pred_src_src, gpu_pred_src_srcm = self.decoder_src(gpu_src_code)
                    gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)
                    gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code)

                elif 'liae' in archi:
                    gpu_src_code = self.encoder (gpu_warped_src)
                    gpu_src_inter_AB_code = self.inter_AB (gpu_src_code)
                    # src code is the AB code concatenated with itself.
                    gpu_src_code = tf.concat([gpu_src_inter_AB_code,gpu_src_inter_AB_code], nn.conv2d_ch_axis )
                    gpu_dst_code = self.encoder (gpu_warped_dst)
                    gpu_dst_inter_B_code = self.inter_B (gpu_dst_code)
                    gpu_dst_inter_AB_code = self.inter_AB (gpu_dst_code)
                    gpu_dst_code = tf.concat([gpu_dst_inter_B_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis )
                    # The swapped face uses dst's AB code in both halves.
                    gpu_src_dst_code = tf.concat([gpu_dst_inter_AB_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis )

                    gpu_pred_src_src, gpu_pred_src_srcm = self.decoder(gpu_src_code)
                    gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder(gpu_dst_code)
                    gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)

                # NOTE(review): gpu_src_latent_code is assigned only in the 'df'
                # branch above; on the 'liae' path this append looks like it
                # raises NameError -- confirm whether 'liae' is supported here.
                gpu_src_latent_code_list.append(gpu_src_latent_code)
                gpu_pred_src_src_list.append(gpu_pred_src_src)
                gpu_pred_dst_dst_list.append(gpu_pred_dst_dst)
                gpu_pred_src_dst_list.append(gpu_pred_src_dst)

                gpu_pred_src_srcm_list.append(gpu_pred_src_srcm)
                gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm)
                gpu_pred_src_dstm_list.append(gpu_pred_src_dstm)

                # unpack masks from one combined mask
                # Clipping to [0,1] yields the face mask; subtracting 1 first and
                # clipping yields the part of the combined mask above 1, used as
                # the eyes mask (presumably per FULL_FACE_EYES encoding -- confirm).
                gpu_target_srcm = tf.clip_by_value (gpu_target_srcm_all, 0, 1)
                gpu_target_dstm = tf.clip_by_value (gpu_target_dstm_all, 0, 1)
                gpu_target_srcm_eyes = tf.clip_by_value (gpu_target_srcm_all-1, 0, 1)
                gpu_target_dstm_eyes = tf.clip_by_value (gpu_target_dstm_all-1, 0, 1)

                gpu_target_srcm_blur = nn.gaussian_blur(gpu_target_srcm, max(1, resolution // 32) )
                gpu_target_dstm_blur = nn.gaussian_blur(gpu_target_dstm, max(1, resolution // 32) )

                gpu_target_dst_masked = gpu_target_dst*gpu_target_dstm_blur
                gpu_target_dst_anti_masked = gpu_target_dst*(1.0 - gpu_target_dstm_blur)

                # The *_opt tensors are masked only when masked_training is enabled.
                gpu_target_src_masked_opt = gpu_target_src*gpu_target_srcm_blur if masked_training else gpu_target_src
                gpu_target_dst_masked_opt = gpu_target_dst_masked if masked_training else gpu_target_dst

                gpu_pred_src_src_masked_opt = gpu_pred_src_src*gpu_target_srcm_blur if masked_training else gpu_pred_src_src
                gpu_pred_dst_dst_masked_opt = gpu_pred_dst_dst*gpu_target_dstm_blur if masked_training else gpu_pred_dst_dst

                gpu_psd_target_dst_masked = gpu_pred_src_dst*gpu_target_dstm_blur
                gpu_psd_target_dst_anti_masked = gpu_pred_src_dst*(1.0 - gpu_target_dstm_blur)

                # src loss = 10*DSSIM + 10*MSE (+ 300*L1 on eyes) + 10*MSE on the mask.
                gpu_src_loss = tf.reduce_mean ( 10*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
                gpu_src_loss += tf.reduce_mean ( 10*tf.square ( gpu_target_src_masked_opt - gpu_pred_src_src_masked_opt ), axis=[1,2,3])

                if eyes_prio:
                    gpu_src_loss += tf.reduce_mean ( 300*tf.abs ( gpu_target_src*gpu_target_srcm_eyes - gpu_pred_src_src*gpu_target_srcm_eyes ), axis=[1,2,3])

                gpu_src_loss += tf.reduce_mean ( 10*tf.square( gpu_target_srcm - gpu_pred_src_srcm ),axis=[1,2,3] )

                # Style terms compare the swapped prediction against dst and are
                # disabled while pretraining.
                face_style_power = self.options['face_style_power'] / 100.0
                if face_style_power != 0 and not self.pretrain:
                    gpu_src_loss += nn.style_loss(gpu_psd_target_dst_masked, gpu_target_dst_masked, gaussian_blur_radius=resolution//16, loss_weight=10000*face_style_power)

                bg_style_power = self.options['bg_style_power'] / 100.0
                if bg_style_power != 0 and not self.pretrain:
                    gpu_src_loss += tf.reduce_mean( (10*bg_style_power)*nn.dssim(gpu_psd_target_dst_anti_masked, gpu_target_dst_anti_masked, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
                    gpu_src_loss += tf.reduce_mean( (10*bg_style_power)*tf.square( gpu_psd_target_dst_anti_masked - gpu_target_dst_anti_masked), axis=[1,2,3] )

                # dst loss mirrors the src loss, without the style terms.
                gpu_dst_loss = tf.reduce_mean ( 10*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1])
                gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dst_masked_opt- gpu_pred_dst_dst_masked_opt ), axis=[1,2,3])

                if eyes_prio:
                    gpu_dst_loss += tf.reduce_mean ( 300*tf.abs ( gpu_target_dst*gpu_target_dstm_eyes - gpu_pred_dst_dst*gpu_target_dstm_eyes ), axis=[1,2,3])

                gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dstm - gpu_pred_dst_dstm ),axis=[1,2,3] )

                gpu_src_losses += [gpu_src_loss]
                gpu_dst_losses += [gpu_dst_loss]

                gpu_G_loss = gpu_src_loss + gpu_dst_loss

                def DLoss(labels,logits):
                    # Sigmoid cross-entropy averaged over axes 1..3.
                    return tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits), axis=[1,2,3])

                if self.options['true_face_power'] != 0:
                    gpu_src_code_d = self.code_discriminator( gpu_src_code )
                    gpu_src_code_d_ones = tf.ones_like (gpu_src_code_d)
                    gpu_src_code_d_zeros = tf.zeros_like(gpu_src_code_d)
                    gpu_dst_code_d = self.code_discriminator( gpu_dst_code )
                    # NOTE(review): gpu_dst_code_d_ones is never read; the D loss
                    # below labels gpu_dst_code_d with gpu_src_code_d_ones instead.
                    gpu_dst_code_d_ones = tf.ones_like(gpu_dst_code_d)

                    # Generator term: push src codes toward the "ones" label.
                    gpu_G_loss += self.options['true_face_power']*DLoss(gpu_src_code_d_ones, gpu_src_code_d)

                    # Discriminator term: dst codes labelled ones, src codes zeros.
                    gpu_D_code_loss = (DLoss(gpu_src_code_d_ones , gpu_dst_code_d) + \
                                       DLoss(gpu_src_code_d_zeros, gpu_src_code_d) ) * 0.5

                    gpu_D_code_loss_gvs += [ nn.gradients (gpu_D_code_loss, self.code_discriminator.get_weights() ) ]

                if gan_power != 0:
                    gpu_pred_src_src_d = self.D_src(gpu_pred_src_src_masked_opt)
                    gpu_pred_src_src_d_ones = tf.ones_like (gpu_pred_src_src_d)
                    gpu_pred_src_src_d_zeros = tf.zeros_like(gpu_pred_src_src_d)
                    gpu_target_src_d = self.D_src(gpu_target_src_masked_opt)
                    gpu_target_src_d_ones = tf.ones_like(gpu_target_src_d)

                    gpu_pred_dst_dst_d = self.D_dst(gpu_pred_dst_dst_masked_opt)
                    gpu_pred_dst_dst_d_ones = tf.ones_like (gpu_pred_dst_dst_d)
                    gpu_pred_dst_dst_d_zeros = tf.zeros_like(gpu_pred_dst_dst_d)
                    gpu_target_dst_d = self.D_dst(gpu_target_dst_masked_opt)
                    gpu_target_dst_d_ones = tf.ones_like(gpu_target_dst_d)

                    # Discriminators: targets labelled ones, predictions zeros.
                    gpu_D_src_dst_loss = (DLoss(gpu_target_src_d_ones , gpu_target_src_d) + \
                                          DLoss(gpu_pred_src_src_d_zeros, gpu_pred_src_src_d) ) * 0.5 + \
                                         (DLoss(gpu_target_dst_d_ones , gpu_target_dst_d) + \
                                          DLoss(gpu_pred_dst_dst_d_zeros, gpu_pred_dst_dst_d) ) * 0.5

                    gpu_D_src_dst_loss_gvs += [ nn.gradients (gpu_D_src_dst_loss, self.D_src.get_weights()+self.D_dst.get_weights() ) ]

                    # Generator: predictions should be classified as ones.
                    gpu_G_loss += gan_power*(DLoss(gpu_pred_src_src_d_ones, gpu_pred_src_src_d) + DLoss(gpu_pred_dst_dst_d_ones, gpu_pred_dst_dst_d))

                gpu_G_loss_gvs += [ nn.gradients ( gpu_G_loss, self.src_dst_trainable_weights ) ]

        # Average losses and gradients, and create optimizer update ops
        with tf.device (models_opt_device):
            # Per-device results are concatenated along axis 0.
            src_latent_code = nn.concat(gpu_src_latent_code_list, 0)
            pred_src_src = nn.concat(gpu_pred_src_src_list, 0)
            pred_dst_dst = nn.concat(gpu_pred_dst_dst_list, 0)
            pred_src_dst = nn.concat(gpu_pred_src_dst_list, 0)
            pred_src_srcm = nn.concat(gpu_pred_src_srcm_list, 0)
            pred_dst_dstm = nn.concat(gpu_pred_dst_dstm_list, 0)
            pred_src_dstm = nn.concat(gpu_pred_src_dstm_list, 0)

            src_loss = tf.concat(gpu_src_losses, 0)
            dst_loss = tf.concat(gpu_dst_losses, 0)
            src_dst_loss_gv_op = self.src_dst_opt.get_update_op (nn.average_gv_list (gpu_G_loss_gvs))

            if self.options['true_face_power'] != 0:
                D_loss_gv_op = self.D_code_opt.get_update_op (nn.average_gv_list(gpu_D_code_loss_gvs))

            if gan_power != 0:
                src_D_src_dst_loss_gv_op = self.D_src_dst_opt.get_update_op (nn.average_gv_list(gpu_D_src_dst_loss_gvs) )

        # Initializing training and view functions
        def src_dst_train(warped_src, target_src, target_srcm_all, \
                          warped_dst, target_dst, target_dstm_all):
            # Runs one generator update and returns the (src_loss, dst_loss) values.
            s, d, _ = nn.tf_sess.run ( [ src_loss, dst_loss, src_dst_loss_gv_op],
                                       feed_dict={self.warped_src :warped_src,
                                                  self.target_src :target_src,
                                                  self.target_srcm_all:target_srcm_all,
                                                  self.warped_dst :warped_dst,
                                                  self.target_dst :target_dst,
                                                  self.target_dstm_all:target_dstm_all,
                                                  })
            return s, d
        self.src_dst_train = src_dst_train

        if self.options['true_face_power'] != 0:
            def D_train(warped_src, warped_dst):
                # Runs one code-discriminator update.
                nn.tf_sess.run ([D_loss_gv_op], feed_dict={self.warped_src: warped_src, self.warped_dst: warped_dst})
            self.D_train = D_train

        if gan_power != 0:
            def D_src_dst_train(warped_src, target_src, target_srcm_all, \
                                warped_dst, target_dst, target_dstm_all):
                # Runs one update of the src/dst patch discriminators.
                nn.tf_sess.run ([src_D_src_dst_loss_gv_op], feed_dict={self.warped_src :warped_src,
                                                                       self.target_src :target_src,
                                                                       self.target_srcm_all:target_srcm_all,
                                                                       self.warped_dst :warped_dst,
                                                                       self.target_dst :target_dst,
                                                                       self.target_dstm_all:target_dstm_all})
            self.D_src_dst_train = D_src_dst_train

        def AE_get_latent(warped_src):
            # Evaluates the concatenated src latent codes for the given images.
            return nn.tf_sess.run ( src_latent_code, feed_dict={self.warped_src:warped_src})
        self.AE_get_latent = AE_get_latent

        def AE_view_src(warped_src, src_code_in):
            # Evaluates the src->src prediction for an externally supplied code.
            return nn.tf_sess.run ( pred_src_src, feed_dict={self.warped_src:warped_src, self.src_code_in:src_code_in })
        self.AE_view_src = AE_view_src

        def AE_view(warped_src, warped_dst):
            # NOTE(review): pred_src_src is fetched here without feeding
            # self.src_code_in (see the 'df' branch above) -- confirm.
            return nn.tf_sess.run ( [pred_src_src, pred_dst_dst, pred_dst_dstm, pred_src_dst, pred_src_dstm],
                                    feed_dict={self.warped_src:warped_src,
                                               self.warped_dst:warped_dst})
        self.AE_view = AE_view
    else:
        # Initializing merge function
        with tf.device( f'/GPU:0' if len(devices) != 0 else f'/CPU:0'):
            if 'df' in archi:
                gpu_dst_code = self.inter(self.encoder(self.warped_dst))
                gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code)
                _, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)

            elif 'liae' in archi:
                gpu_dst_code = self.encoder (self.warped_dst)
                gpu_dst_inter_B_code = self.inter_B (gpu_dst_code)
                gpu_dst_inter_AB_code = self.inter_AB (gpu_dst_code)
                gpu_dst_code = tf.concat([gpu_dst_inter_B_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis)
                gpu_src_dst_code = tf.concat([gpu_dst_inter_AB_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis)

                gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)
                _, gpu_pred_dst_dstm = self.decoder(gpu_dst_code)

        def AE_merge( warped_dst):
            # Returns [swapped face, dst mask, swapped-face mask] for the given dst images.
            return nn.tf_sess.run ( [gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm], feed_dict={self.warped_dst:warped_dst})

        self.AE_merge = AE_merge

    # Loading/initializing all models/optimizers weights
    for model, filename in io.progress_bar_generator(self.model_filename_list, "Initializing models"):
        if self.pretrain_just_disabled:
            # Right after pretraining is disabled only the inter (df) or
            # inter_AB (liae) model is re-initialized; the rest is loaded.
            do_init = False
            if 'df' in archi:
                if model == self.inter:
                    do_init = True
            elif 'liae' in archi:
                if model == self.inter_AB:
                    do_init = True
        else:
            do_init = self.is_first_run()

        if not do_init:
            # Fall back to fresh initialization when loading the weights fails.
            do_init = not model.load_weights( self.get_strpath_storage_for_file(filename) )

        if do_init:
            model.init_weights()

    # initializing sample generators
    if self.is_training:
        # While pretraining, both generators read from the pretraining data path.
        training_data_src_path = self.training_data_src_path if not self.pretrain else self.get_pretraining_data_path()
        training_data_dst_path = self.training_data_dst_path if not self.pretrain else self.get_pretraining_data_path()

        random_ct_samples_path=training_data_dst_path if ct_mode is not None and not self.pretrain else None

        # At most 8 CPUs are used, split evenly between the two generators;
        # the src generator gets 1.5x workers when a color-transfer mode is set.
        cpu_count = min(multiprocessing.cpu_count(), 8)
        src_generators_count = cpu_count // 2
        dst_generators_count = cpu_count // 2
        if ct_mode is not None:
            src_generators_count = int(src_generators_count * 1.5)

        # The src generator yields four outputs per batch (warped image, target
        # image, mask, pitch/yaw/roll); the dst generator yields three.
        self.set_training_data_generators ([
                SampleGeneratorFace(training_data_src_path, random_ct_samples_path=random_ct_samples_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
                    sample_process_options=SampleProcessor.Options(random_flip=self.random_flip),
                    output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':self.options['random_warp'], 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                            {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':False, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                            {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE_EYES, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                            {'sample_type': SampleProcessor.SampleType.PITCH_YAW_ROLL, 'resolution': resolution},
                                          ],
                    generators_count=src_generators_count ),

                SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
                    sample_process_options=SampleProcessor.Options(random_flip=self.random_flip),
                    output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':self.options['random_warp'], 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                            {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                            {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE_EYES, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                          ],
                    generators_count=dst_generators_count )
            ])

        self.last_src_samples_loss = []
        self.last_dst_samples_loss = []

        if self.pretrain_just_disabled:
            self.update_sample_for_preview(force_new=True)

        # ------------------------------------------------------------------
        # Latent-code experiment: trains a small dense network (PRD) to map a
        # src latent code plus a yaw difference to another sample's latent
        # code, then shows the decoded result in an OpenCV window.
        # NOTE(review): the nesting of this section under `if self.is_training:`
        # is reconstructed -- it only uses training-only objects (the sample
        # generators, AE_get_latent, AE_view_src); confirm against the original.
        # NOTE(review): this runs 1000 training steps and blocks on
        # cv2.waitKey / code.interact inside on_initialize -- looks like
        # leftover debugging code; confirm before shipping.
        # ------------------------------------------------------------------
        class PRD(nn.ModelBase):
            def on_build(self, ae_ch):
                # Input is the latent code plus one extra scalar (the yaw difference).
                self.dense1 = nn.Dense( ae_ch+1, 1024 )
                self.dense2 = nn.Dense( 1024, 2048 )
                self.dense3 = nn.Dense( 2048, 4096 )
                self.dense4 = nn.Dense( 4096, 4096 )
                self.dense5 = nn.Dense( 4096, ae_ch )

            def forward(self, inp, yaw_in):
                # Five stacked dense layers with no activation applied between them here.
                x = tf.concat( [inp, yaw_in], -1 )
                x = self.dense1(x)
                x = self.dense2(x)
                x = self.dense3(x)
                x = self.dense4(x)
                x = self.dense5(x)
                return x

        # NOTE(review): '/GPU:0' is hard-coded although `devices` may be empty,
        # and the extent of this device scope is reconstructed -- confirm.
        with tf.device( f'/GPU:0'):
            prd_model = PRD(256, name='PRD')
            prd_model.init_weights()

            prd_in = tf.placeholder (nn.floatx, (None,256) )
            prd_targ = tf.placeholder (nn.floatx, (None,256) )
            yaw_diff_in = tf.placeholder (nn.floatx, (None,1) )
            prd_out = prd_model(prd_in, yaw_diff_in)

            # L1 loss summed over the whole batch.
            loss = tf.reduce_sum ( tf.abs (prd_out - prd_targ) )
            loss_gvs = nn.gradients (loss, prd_model.get_weights() )

            prd_opt = nn.RMSprop(lr=5e-6, lr_dropout=0.3, name='prd_opt')
            prd_opt.initialize_variables(prd_model.get_weights())
            prd_opt.init_weights()
            loss_gv_op = prd_opt.get_update_op (loss_gvs)

        s_gen, _ = self.get_training_data_generators()
        bs = self.get_batch_size()

        for n in range(1000):
            warped_src, target_src, target_srcm_all, src_pyr = s_gen.generate_next()
            sl = self.AE_get_latent(target_src)

            prd_in_np = []
            prd_targ_np = []
            yaw_diff_in_np = []

            for i in range(bs):
                prd_in_np += [sl[i]]

                # Pick a different sample j from the same batch as the target.
                # NOTE(review): with bs == 1 there is no j != i, so this loop
                # would never terminate -- confirm bs > 1 is guaranteed.
                j = i
                while j == i:
                    j = np.random.randint(bs)
                prd_targ_np += [ sl[j] ]
                # Index 1 of the pitch/yaw/roll sample is used as the yaw value
                # (presumably ordered pitch, yaw, roll -- confirm).
                yaw_diff_in_np += [ np.float32( [ src_pyr[j][1]-src_pyr[i][1] ] ) ]

            prd_loss, _ = nn.tf_sess.run([loss, loss_gv_op], feed_dict={prd_in:prd_in_np, prd_targ:prd_targ_np, yaw_diff_in:yaw_diff_in_np} )
            print(f'{n} loss = {prd_loss}')

        # Visual check: shift every latent code by a fixed yaw difference of -0.4
        # and decode it next to the original target image.
        warped_src, target_src, target_srcm_all, src_pyr = s_gen.generate_next()
        sl = self.AE_get_latent(target_src)
        yaw_diff_in_np = np.float32( [ [-0.4] ] *bs )
        new_sl = nn.tf_sess.run(prd_out, feed_dict={prd_in:sl, yaw_diff_in:yaw_diff_in_np} )
        new_target_src = self.AE_view_src( target_src, new_sl )

        # Convert to NHWC and clamp to [0,1] for display.
        target_src = np.clip( nn.to_data_format( target_src ,"NHWC", self.model_data_format), 0.0, 1.0)
        new_target_src = np.clip( nn.to_data_format( new_target_src ,"NHWC", self.model_data_format), 0.0, 1.0)

        for i in range(bs):
            # Original and generated image side by side; waits for a key press per sample.
            screen = np.concatenate ( (target_src[i], new_target_src[i]), 1 )
            cv2.imshow("", (screen*255).astype(np.uint8) )
            cv2.waitKey(0)

        # Drops into an interactive console with the current globals and locals.
        import code
        code.interact(local=dict(globals(), **locals()))
def on_initialize(self): device_config = nn.getCurrentDeviceConfig() devices = device_config.devices self.model_data_format = "NCHW" if len( devices) != 0 and not self.is_debug() else "NHWC" nn.initialize(data_format=self.model_data_format) tf = nn.tf resolution = self.resolution = 96 self.face_type = FaceType.FULL ae_dims = 256 e_dims = 64 d_dims = 64 self.pretrain = False self.pretrain_just_disabled = False masked_training = True models_opt_on_gpu = len(devices) >= 1 and all( [dev.total_mem_gb >= 4 for dev in devices]) models_opt_device = '/GPU:0' if models_opt_on_gpu and self.is_training else '/CPU:0' optimizer_vars_on_cpu = models_opt_device == '/CPU:0' input_ch = 3 bgr_shape = nn.get4Dshape(resolution, resolution, input_ch) mask_shape = nn.get4Dshape(resolution, resolution, 1) self.model_filename_list = [] kernel_initializer = tf.initializers.glorot_uniform() class Upscale(nn.ModelBase): def on_build(self, in_ch, out_ch, kernel_size=3): self.conv1 = nn.Conv2D(in_ch, out_ch * 4, kernel_size=kernel_size, padding='SAME', kernel_initializer=kernel_initializer) def forward(self, x): x = self.conv1(x) x = tf.nn.leaky_relu(x, 0.1) x = nn.depth_to_space(x, 2) return x class ResidualBlock(nn.ModelBase): def on_build(self, ch, kernel_size=3): self.conv1 = nn.Conv2D(ch, ch, kernel_size=kernel_size, padding='SAME', kernel_initializer=kernel_initializer) self.conv2 = nn.Conv2D(ch, ch, kernel_size=kernel_size, padding='SAME', kernel_initializer=kernel_initializer) def forward(self, inp): x = self.conv1(inp) x = tf.nn.leaky_relu(x, 0.2) x = self.conv2(x) x = tf.nn.leaky_relu(inp + x, 0.2) return x class Encoder(nn.ModelBase): def on_build(self, in_ch, e_ch): self.down11 = nn.Conv2D(in_ch, e_ch, kernel_size=3, strides=1, padding='SAME', kernel_initializer=kernel_initializer) self.down12 = nn.Conv2D(e_ch, e_ch, kernel_size=3, strides=1, padding='SAME', kernel_initializer=kernel_initializer) self.down21 = nn.Conv2D(e_ch, e_ch * 2, kernel_size=3, strides=1, padding='SAME', 
kernel_initializer=kernel_initializer) self.down22 = nn.Conv2D(e_ch * 2, e_ch * 2, kernel_size=3, strides=1, padding='SAME', kernel_initializer=kernel_initializer) self.down31 = nn.Conv2D(e_ch * 2, e_ch * 4, kernel_size=3, strides=1, padding='SAME', kernel_initializer=kernel_initializer) self.down32 = nn.Conv2D(e_ch * 4, e_ch * 4, kernel_size=3, strides=1, padding='SAME', kernel_initializer=kernel_initializer) self.down33 = nn.Conv2D(e_ch * 4, e_ch * 4, kernel_size=3, strides=1, padding='SAME', kernel_initializer=kernel_initializer) self.down41 = nn.Conv2D(e_ch * 4, e_ch * 8, kernel_size=3, strides=1, padding='SAME', kernel_initializer=kernel_initializer) self.down42 = nn.Conv2D(e_ch * 8, e_ch * 8, kernel_size=3, strides=1, padding='SAME', kernel_initializer=kernel_initializer) self.down43 = nn.Conv2D(e_ch * 8, e_ch * 8, kernel_size=3, strides=1, padding='SAME', kernel_initializer=kernel_initializer) self.down51 = nn.Conv2D(e_ch * 8, e_ch * 8, kernel_size=3, strides=1, padding='SAME', kernel_initializer=kernel_initializer) self.down52 = nn.Conv2D(e_ch * 8, e_ch * 8, kernel_size=3, strides=1, padding='SAME', kernel_initializer=kernel_initializer) self.down53 = nn.Conv2D(e_ch * 8, e_ch * 8, kernel_size=3, strides=1, padding='SAME', kernel_initializer=kernel_initializer) def forward(self, inp): x = inp x = self.down11(x) x = self.down12(x) x = nn.max_pool(x) x = self.down21(x) x = self.down22(x) x = nn.max_pool(x) x = self.down31(x) x = self.down32(x) x = self.down33(x) x = nn.max_pool(x) x = self.down41(x) x = self.down42(x) x = self.down43(x) x = nn.max_pool(x) x = self.down51(x) x = self.down52(x) x = self.down53(x) x = nn.max_pool(x) x = nn.flatten(x) return x class Downscale(nn.ModelBase): def __init__(self, in_ch, out_ch, kernel_size=5, dilations=1, subpixel=True, use_activator=True, *kwargs): self.in_ch = in_ch self.out_ch = out_ch self.kernel_size = kernel_size self.dilations = dilations self.subpixel = subpixel self.use_activator = use_activator 
super().__init__(*kwargs) def on_build(self, *args, **kwargs): self.conv1 = nn.Conv2D(self.in_ch, self.out_ch // (4 if self.subpixel else 1), kernel_size=self.kernel_size, strides=1 if self.subpixel else 2, padding='SAME', dilations=self.dilations, kernel_initializer=kernel_initializer) def forward(self, x): x = self.conv1(x) if self.subpixel: x = nn.space_to_depth(x, 2) if self.use_activator: x = tf.nn.leaky_relu(x, 0.1) return x def get_out_ch(self): return (self.out_ch // 4) * 4 class DownscaleBlock(nn.ModelBase): def on_build(self, in_ch, ch, n_downscales, kernel_size, dilations=1, subpixel=True): self.downs = [] last_ch = in_ch for i in range(n_downscales): cur_ch = ch * (min(2**i, 8)) self.downs.append( Downscale(last_ch, cur_ch, kernel_size=kernel_size, dilations=dilations, subpixel=subpixel)) last_ch = self.downs[-1].get_out_ch() def forward(self, inp): x = inp for down in self.downs: x = down(x) return x class Upscale(nn.ModelBase): def on_build(self, in_ch, out_ch, kernel_size=3): self.conv1 = nn.Conv2D(in_ch, out_ch * 4, kernel_size=kernel_size, padding='SAME', kernel_initializer=kernel_initializer) def forward(self, x): x = self.conv1(x) x = tf.nn.leaky_relu(x, 0.1) x = nn.depth_to_space(x, 2) return x class Encoder(nn.ModelBase): def on_build(self, in_ch, e_ch): self.down1 = DownscaleBlock(in_ch, e_ch, n_downscales=4, kernel_size=5, dilations=1, subpixel=False) def forward(self, inp): x = nn.flatten(self.down1(inp)) return x class Branch(nn.ModelBase): def on_build(self, in_ch, ae_ch): self.dense1 = nn.Dense(in_ch, ae_ch) def forward(self, inp): x = self.dense1(inp) return x class Classifier(nn.ModelBase): def on_build(self, in_ch, n_classes): self.dense1 = nn.Dense(in_ch, 4096) self.dense2 = nn.Dense(4096, 4096) self.pitch_dense = nn.Dense(4096, n_classes) self.yaw_dense = nn.Dense(4096, n_classes) def forward(self, inp): x = inp x = self.dense1(x) x = self.dense2(x) return self.pitch_dense(x), self.yaw_dense(x) lowest_dense_res = resolution // 16 
class Inter(nn.ModelBase): def on_build(self, in_ch, ae_out_ch): self.ae_out_ch = ae_out_ch self.dense2 = nn.Dense( in_ch, lowest_dense_res * lowest_dense_res * ae_out_ch) self.upscale1 = Upscale(ae_out_ch, ae_out_ch) def forward(self, inp): x = inp x = self.dense2(x) x = nn.reshape_4D(x, lowest_dense_res, lowest_dense_res, self.ae_out_ch) x = self.upscale1(x) return x def get_out_ch(self): return self.ae_out_ch class Decoder(nn.ModelBase): def on_build(self, in_ch, d_ch, d_mask_ch): self.upscale0 = Upscale(in_ch, d_ch * 8, kernel_size=3) self.upscale1 = Upscale(d_ch * 8, d_ch * 4, kernel_size=3) self.upscale2 = Upscale(d_ch * 4, d_ch * 2, kernel_size=3) self.res0 = ResidualBlock(d_ch * 8, kernel_size=3) self.res1 = ResidualBlock(d_ch * 4, kernel_size=3) self.res2 = ResidualBlock(d_ch * 2, kernel_size=3) self.out_conv = nn.Conv2D( d_ch * 2, 3, kernel_size=1, padding='SAME', kernel_initializer=kernel_initializer) def forward(self, inp): z = inp x = self.upscale0(z) x = self.res0(x) x = self.upscale1(x) x = self.res1(x) x = self.upscale2(x) x = self.res2(x) return tf.nn.sigmoid(self.out_conv(x)) n_pyr_degs = self.n_pyr_degs = 3 n_pyr_classes = self.n_pyr_classes = 180 // self.n_pyr_degs with tf.device('/CPU:0'): #Place holders on CPU self.warped_src = tf.placeholder(nn.floatx, bgr_shape) self.target_src = tf.placeholder(nn.floatx, bgr_shape) self.target_dst = tf.placeholder(nn.floatx, bgr_shape) self.pitches_vector = tf.placeholder(nn.floatx, (None, n_pyr_classes)) self.yaws_vector = tf.placeholder(nn.floatx, (None, n_pyr_classes)) # Initializing model classes with tf.device(models_opt_device): self.encoder = Encoder(in_ch=input_ch, e_ch=e_dims, name='encoder') encoder_out_ch = self.encoder.compute_output_channels( (nn.floatx, bgr_shape)) self.bT = Branch(in_ch=encoder_out_ch, ae_ch=ae_dims, name='bT') self.bP = Branch(in_ch=encoder_out_ch, ae_ch=ae_dims, name='bP') self.bTC = Classifier(in_ch=ae_dims, n_classes=self.n_pyr_classes, name='bTC') self.bPC = 
Classifier(in_ch=ae_dims, n_classes=self.n_pyr_classes, name='bPC') self.inter = Inter(in_ch=ae_dims * 2, ae_out_ch=ae_dims * 2, name='inter') self.decoder = Decoder(in_ch=ae_dims * 2, d_ch=d_dims, d_mask_ch=d_dims, name='decoder') self.model_filename_list += [[self.encoder, 'encoder.npy'], [self.bT, 'bT.npy'], [self.bTC, 'bTC.npy'], [self.bP, 'bP.npy'], [self.bPC, 'bPC.npy'], [self.inter, 'inter.npy'], [self.decoder, 'decoder.npy']] if self.is_training: self.all_trainable_weights = self.encoder.get_weights() + \ self.bT.get_weights() +\ self.bTC.get_weights() +\ self.bP.get_weights() +\ self.bPC.get_weights() +\ self.inter.get_weights() +\ self.decoder.get_weights() # Initialize optimizers self.src_dst_opt = nn.RMSprop(lr=5e-5, name='src_dst_opt') self.src_dst_opt.initialize_variables( self.all_trainable_weights, vars_on_cpu=optimizer_vars_on_cpu) self.model_filename_list += [(self.src_dst_opt, 'src_dst_opt.npy')] if self.is_training: # Adjust batch size for multiple GPU gpu_count = max(1, len(devices)) bs_per_gpu = max(1, 32 // gpu_count) self.set_batch_size(gpu_count * bs_per_gpu) # Compute losses per GPU gpu_pred_src_list = [] gpu_pred_dst_list = [] gpu_A_losses = [] gpu_B_losses = [] gpu_C_losses = [] gpu_D_losses = [] gpu_A_loss_gvs = [] gpu_B_loss_gvs = [] gpu_C_loss_gvs = [] gpu_D_loss_gvs = [] for gpu_id in range(gpu_count): with tf.device( f'/GPU:{gpu_id}' if len(devices) != 0 else f'/CPU:0'): batch_slice = slice(gpu_id * bs_per_gpu, (gpu_id + 1) * bs_per_gpu) with tf.device(f'/CPU:0'): # slice on CPU, otherwise all batch data will be transfered to GPU first gpu_warped_src = self.warped_src[batch_slice, :, :, :] gpu_target_src = self.target_src[batch_slice, :, :, :] gpu_target_dst = self.target_dst[batch_slice, :, :, :] gpu_pitches_vector = self.pitches_vector[ batch_slice, :] gpu_yaws_vector = self.yaws_vector[batch_slice, :] # process model tensors gpu_src_enc_code = self.encoder(gpu_warped_src) gpu_dst_enc_code = self.encoder(gpu_target_dst) 
gpu_src_bT_code = self.bT(gpu_src_enc_code) gpu_src_bT_code_ng = tf.stop_gradient(gpu_src_bT_code) gpu_src_T_pitch, gpu_src_T_yaw = self.bTC(gpu_src_bT_code) gpu_dst_bT_code = self.bT(gpu_dst_enc_code) gpu_src_bP_code = self.bP(gpu_src_enc_code) gpu_src_P_pitch, gpu_src_P_yaw = self.bPC(gpu_src_bP_code) def crossentropy(target, output): output = tf.nn.softmax(output) output = tf.clip_by_value(output, 1e-7, 1 - 1e-7) return tf.reduce_sum(target * -tf.log(output), axis=-1, keepdims=False) def negative_crossentropy(n_classes, output): output = tf.nn.softmax(output) output = tf.clip_by_value(output, 1e-7, 1 - 1e-7) return (1.0 / n_classes) * tf.reduce_sum( tf.log(output), axis=-1, keepdims=False) gpu_src_bT_code_n = gpu_src_bT_code_ng + tf.random.normal( tf.shape(gpu_src_bT_code_ng)) gpu_src_bP_code_n = gpu_src_bP_code + tf.random.normal( tf.shape(gpu_src_bP_code)) gpu_pred_src = self.decoder( self.inter( tf.concat([gpu_src_bT_code_ng, gpu_src_bP_code], axis=-1))) gpu_pred_src_n = self.decoder( self.inter( tf.concat([gpu_src_bT_code_n, gpu_src_bP_code_n], axis=-1))) gpu_pred_dst = self.decoder( self.inter( tf.concat([gpu_dst_bT_code, gpu_src_bP_code], axis=-1))) gpu_A_loss = 1.0*crossentropy(gpu_pitches_vector, gpu_src_T_pitch ) + \ 1.0*crossentropy(gpu_yaws_vector, gpu_src_T_yaw ) gpu_B_loss = 0.1*crossentropy(gpu_pitches_vector, gpu_src_P_pitch ) + \ 0.1*crossentropy(gpu_yaws_vector, gpu_src_P_yaw ) gpu_C_loss = 0.1*negative_crossentropy( n_pyr_classes, gpu_src_P_pitch ) + \ 0.1*negative_crossentropy( n_pyr_classes, gpu_src_P_yaw ) gpu_D_loss = 0.0000001*(\ 0.5*tf.reduce_sum(tf.square(gpu_target_src-gpu_pred_src), axis=[1,2,3]) + \ 0.5*tf.reduce_sum(tf.square(gpu_target_src-gpu_pred_src_n), axis=[1,2,3]) ) gpu_pred_src_list.append(gpu_pred_src) gpu_pred_dst_list.append(gpu_pred_dst) gpu_A_losses += [gpu_A_loss] gpu_B_losses += [gpu_B_loss] gpu_C_losses += [gpu_C_loss] gpu_D_losses += [gpu_D_loss] A_weights = self.encoder.get_weights( ) + self.bT.get_weights() + 
self.bTC.get_weights() B_weights = self.bPC.get_weights() C_weights = self.encoder.get_weights( ) + self.bP.get_weights() D_weights = self.inter.get_weights( ) + self.decoder.get_weights() gpu_A_loss_gvs += [nn.gradients(gpu_A_loss, A_weights)] gpu_B_loss_gvs += [nn.gradients(gpu_B_loss, B_weights)] gpu_C_loss_gvs += [nn.gradients(gpu_C_loss, C_weights)] gpu_D_loss_gvs += [nn.gradients(gpu_D_loss, D_weights)] # Average losses and gradients, and create optimizer update ops with tf.device(models_opt_device): pred_src = nn.concat(gpu_pred_src_list, 0) pred_dst = nn.concat(gpu_pred_dst_list, 0) A_loss = nn.average_tensor_list(gpu_A_losses) B_loss = nn.average_tensor_list(gpu_B_losses) C_loss = nn.average_tensor_list(gpu_C_losses) D_loss = nn.average_tensor_list(gpu_D_losses) A_loss_gv = nn.average_gv_list(gpu_A_loss_gvs) B_loss_gv = nn.average_gv_list(gpu_B_loss_gvs) C_loss_gv = nn.average_gv_list(gpu_C_loss_gvs) D_loss_gv = nn.average_gv_list(gpu_D_loss_gvs) A_loss_gv_op = self.src_dst_opt.get_update_op(A_loss_gv) B_loss_gv_op = self.src_dst_opt.get_update_op(B_loss_gv) C_loss_gv_op = self.src_dst_opt.get_update_op(C_loss_gv) D_loss_gv_op = self.src_dst_opt.get_update_op(D_loss_gv) # Initializing training and view functions def A_train(warped_src, target_src, pitches_vector, yaws_vector): l, _ = nn.tf_sess.run( [A_loss, A_loss_gv_op], feed_dict={ self.warped_src: warped_src, self.target_src: target_src, self.pitches_vector: pitches_vector, self.yaws_vector: yaws_vector }) return np.mean(l) self.A_train = A_train def B_train(warped_src, target_src, pitches_vector, yaws_vector): l, _ = nn.tf_sess.run( [B_loss, B_loss_gv_op], feed_dict={ self.warped_src: warped_src, self.target_src: target_src, self.pitches_vector: pitches_vector, self.yaws_vector: yaws_vector }) return np.mean(l) self.B_train = B_train def C_train(warped_src, target_src, pitches_vector, yaws_vector): l, _ = nn.tf_sess.run( [C_loss, C_loss_gv_op], feed_dict={ self.warped_src: warped_src, self.target_src: 
target_src, self.pitches_vector: pitches_vector, self.yaws_vector: yaws_vector }) return np.mean(l) self.C_train = C_train def D_train(warped_src, target_src, pitches_vector, yaws_vector): l, _ = nn.tf_sess.run( [D_loss, D_loss_gv_op], feed_dict={ self.warped_src: warped_src, self.target_src: target_src, self.pitches_vector: pitches_vector, self.yaws_vector: yaws_vector }) return np.mean(l) self.D_train = D_train def AE_view(warped_src): return nn.tf_sess.run([pred_src], feed_dict={self.warped_src: warped_src}) self.AE_view = AE_view def AE_view2(warped_src, target_dst): return nn.tf_sess.run([pred_dst], feed_dict={ self.warped_src: warped_src, self.target_dst: target_dst }) self.AE_view2 = AE_view2 else: # Initializing merge function with tf.device(f'/GPU:0' if len(devices) != 0 else f'/CPU:0'): gpu_dst_code = self.inter(self.encoder(self.warped_dst)) gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src( gpu_dst_code) _, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code) def AE_merge(warped_dst): return nn.tf_sess.run( [gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm], feed_dict={self.warped_dst: warped_dst}) self.AE_merge = AE_merge # Loading/initializing all models/optimizers weights for model, filename in io.progress_bar_generator( self.model_filename_list, "Initializing models"): if self.pretrain_just_disabled: do_init = False if model == self.inter: do_init = True else: do_init = self.is_first_run() if not do_init: do_init = not model.load_weights( self.get_strpath_storage_for_file(filename)) if do_init and self.pretrained_model_path is not None: pretrained_filepath = self.pretrained_model_path / filename if pretrained_filepath.exists(): do_init = not model.load_weights(pretrained_filepath) if do_init: model.init_weights() # initializing sample generators if self.is_training: training_data_src_path = self.training_data_src_path if not self.pretrain else self.get_pretraining_data_path( ) training_data_dst_path = self.training_data_dst_path if not 
self.pretrain else self.get_pretraining_data_path( ) cpu_count = min(multiprocessing.cpu_count(), 8) src_generators_count = cpu_count // 2 dst_generators_count = cpu_count // 2 self.set_training_data_generators([ SampleGeneratorFace( training_data_src_path, debug=self.is_debug(), batch_size=self.get_batch_size(), sample_process_options=SampleProcessor.Options( random_flip=True if self.pretrain else False), output_sample_types=[ { 'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp': True, 'transform': True, 'channel_type': SampleProcessor.ChannelType.BGR, 'face_type': self.face_type, 'data_format': nn.data_format, 'resolution': resolution }, { 'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp': False, 'transform': True, 'channel_type': SampleProcessor.ChannelType.BGR, 'face_type': self.face_type, 'data_format': nn.data_format, 'resolution': resolution }, { 'sample_type': SampleProcessor.SampleType.PITCH_YAW_ROLL_SIGMOID, 'resolution': resolution }, ], generators_count=src_generators_count), SampleGeneratorFace( training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(), sample_process_options=SampleProcessor.Options( random_flip=True if self.pretrain else False), output_sample_types=[ { 'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp': True, 'transform': True, 'channel_type': SampleProcessor.ChannelType.BGR, 'face_type': self.face_type, 'data_format': nn.data_format, 'resolution': resolution }, { 'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp': False, 'transform': True, 'channel_type': SampleProcessor.ChannelType.BGR, 'face_type': self.face_type, 'data_format': nn.data_format, 'resolution': resolution }, { 'sample_type': SampleProcessor.SampleType.PITCH_YAW_ROLL_SIGMOID, 'resolution': resolution }, ], generators_count=dst_generators_count) ]) self.last_samples = None
def on_initialize(self):
    """Create the XSeg model and, when training, its graph and generators.

    The ``nn`` backend is initialized with a data format picked from the
    device list and the debug flag, then an ``XSegNet`` is created. When
    ``self.is_training`` is true the method additionally builds per-device
    loss and gradient ops, publishes ``self.train`` / ``self.view`` and
    registers three sample generators.
    """
    device_config = nn.getCurrentDeviceConfig()
    # is_debug() is only consulted when at least one device is available.
    use_nchw = len(device_config.devices) != 0 and not self.is_debug()
    self.model_data_format = "NCHW" if use_nchw else "NHWC"
    nn.initialize(data_format=self.model_data_format)
    tf = nn.tf

    # The device configuration is queried again once the backend is up.
    device_config = nn.getCurrentDeviceConfig()
    devices = device_config.devices

    resolution = 256
    self.resolution = resolution
    self.face_type = FaceType.WHOLE_FACE

    place_model_on_cpu = len(devices) == 0
    models_opt_device = '/GPU:0'
    if place_model_on_cpu:
        models_opt_device = '/CPU:0'

    # Not referenced again below; the calls are kept exactly as before.
    bgr_shape = nn.get4Dshape(resolution, resolution, 3)
    mask_shape = nn.get4Dshape(resolution, resolution, 1)

    # Initializing model classes
    first_run = self.is_first_run()
    weights_root = self.get_model_root_path()
    optimizer = nn.RMSprop(lr=0.0001, lr_dropout=0.3, name='opt')
    self.model = XSegNet(
        name='XSeg',
        resolution=resolution,
        load_weights=not first_run,
        weights_file_root=weights_root,
        training=True,
        place_model_on_cpu=place_model_on_cpu,
        optimizer=optimizer,
        data_format=nn.data_format,
    )

    # Everything past this point is only needed for training.
    if not self.is_training:
        return

    # Adjust batch size for multiple GPU
    gpu_count = max(1, len(devices))
    bs_per_gpu = max(1, self.get_batch_size() // gpu_count)
    self.set_batch_size(gpu_count * bs_per_gpu)

    # Compute losses per GPU
    tower_preds = []
    tower_losses = []
    tower_grads = []
    for gpu_id in range(gpu_count):
        tower_device = '/CPU:0' if place_model_on_cpu else f'/GPU:{gpu_id}'
        with tf.device(tower_device):
            with tf.device('/CPU:0'):
                # slice on CPU, otherwise all batch data will be
                # transferred to GPU first
                first_row = gpu_id * bs_per_gpu
                batch_slice = slice(first_row, first_row + bs_per_gpu)
                gpu_input_t = self.model.input_t[batch_slice, :, :, :]
                gpu_target_t = self.model.target_t[batch_slice, :, :, :]

            # process model tensors
            gpu_pred_logits_t, gpu_pred_t = self.model.flow(gpu_input_t)
            tower_preds.append(gpu_pred_t)

            cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(
                labels=gpu_target_t, logits=gpu_pred_logits_t)
            # One loss value per sample: mean over the three trailing axes.
            gpu_loss = tf.reduce_mean(cross_entropy, axis=[1, 2, 3])

            tower_losses.append(gpu_loss)
            tower_grads.append(
                nn.gradients(gpu_loss, self.model.get_weights()))

    # Average losses and gradients, and create optimizer update ops
    with tf.device(models_opt_device):
        pred = nn.concat(tower_preds, 0)
        loss = tf.reduce_mean(tower_losses)
        averaged_gv = nn.average_gv_list(tower_grads)
        loss_gv_op = self.model.opt.get_update_op(averaged_gv)

    # Initializing training and view functions
    def train(input_np, target_np):
        """Run one optimizer step and return the evaluated loss."""
        feed = {
            self.model.input_t: input_np,
            self.model.target_t: target_np,
        }
        loss_value, _ = nn.tf_sess.run([loss, loss_gv_op], feed_dict=feed)
        return loss_value

    self.train = train

    def view(input_np):
        """Evaluate the concatenated prediction tensor for ``input_np``."""
        return nn.tf_sess.run([pred],
                              feed_dict={self.model.input_t: input_np})

    self.view = view

    # initializing sample generators
    cpu_count = min(multiprocessing.cpu_count(), 8)
    src_dst_generators_count = cpu_count // 2
    src_generators_count = cpu_count // 2
    dst_generators_count = cpu_count // 2

    def make_plain_face_generator(samples_path, generators_count):
        """Build a generator of unwarped, untransformed BGR face images."""
        return SampleGeneratorFace(
            samples_path,
            debug=self.is_debug(),
            batch_size=self.get_batch_size(),
            sample_process_options=SampleProcessor.Options(
                random_flip=False),
            output_sample_types=[
                {
                    'sample_type': SampleProcessor.SampleType.FACE_IMAGE,
                    'warp': False,
                    'transform': False,
                    'channel_type': SampleProcessor.ChannelType.BGR,
                    'border_replicate': False,
                    'face_type': self.face_type,
                    'data_format': nn.data_format,
                    'resolution': resolution,
                },
            ],
            generators_count=generators_count,
            raise_on_no_data=False)

    srcdst_generator = SampleGeneratorFaceXSeg(
        [self.training_data_src_path, self.training_data_dst_path],
        debug=self.is_debug(),
        batch_size=self.get_batch_size(),
        resolution=resolution,
        face_type=self.face_type,
        generators_count=src_dst_generators_count,
        data_format=nn.data_format)
    src_generator = make_plain_face_generator(
        self.training_data_src_path, src_generators_count)
    dst_generator = make_plain_face_generator(
        self.training_data_dst_path, dst_generators_count)

    self.set_training_data_generators(
        [srcdst_generator, src_generator, dst_generator])
def on_initialize(self):
    """Build the staged encoder/inter/decoder model and everything around it.

    Steps performed, in order:

    1. Initialize the ``nn`` backend and force ``float32``.
    2. Declare the layer classes used by this model (local to this method).
    3. Read ``stage`` / ``start_stage_iter`` / ``target_stage_iter`` from
       ``self.options`` and set the fixed hyper-parameters.
    4. Create placeholders, the encoder, the inter block and two decoders,
       and (when training) the optimizer.
    5. When ``self.is_training``: build per-device losses and gradients and
       publish ``self.src_dst_train`` / ``self.AE_view``. Otherwise publish
       ``self.AE_merge``.
    6. Load or initialize the weights of every registered model.
    7. When training, register the src and dst sample generators.
    """
    nn.initialize()
    tf = nn.tf
    nn.set_floatx(tf.float32)

    conv_kernel_initializer = nn.initializers.ca

    class Downscale(nn.ModelBase):
        """Convolution followed by optional space-to-depth and leaky ReLU."""

        def __init__(self, in_ch, out_ch, kernel_size=5, dilations=1,
                     subpixel=True, use_activator=True, *args, **kwargs):
            self.in_ch = in_ch
            self.out_ch = out_ch
            self.kernel_size = kernel_size
            self.dilations = dilations
            self.subpixel = subpixel
            self.use_activator = use_activator
            # Extra positional arguments are still forwarded as before;
            # keyword arguments (previously rejected by the ``*kwargs``
            # signature) are now forwarded to the base class as well.
            super().__init__(*args, **kwargs)

        def on_build(self, *args, **kwargs):
            # With subpixel downscaling the conv keeps the spatial size and
            # produces out_ch // 4 channels; space_to_depth in forward()
            # then multiplies the channel count by 4.
            self.conv1 = nn.Conv2D(
                self.in_ch,
                self.out_ch // (4 if self.subpixel else 1),
                kernel_size=self.kernel_size,
                strides=1 if self.subpixel else 2,
                padding='SAME',
                dilations=self.dilations,
                kernel_initializer=conv_kernel_initializer)

        def forward(self, x):
            x = self.conv1(x)
            if self.subpixel:
                x = tf.nn.space_to_depth(x, 2)
            if self.use_activator:
                x = tf.nn.leaky_relu(x, 0.1)
            return x

        def get_out_ch(self):
            """Return ``out_ch`` rounded down to a multiple of 4."""
            return (self.out_ch // 4) * 4

    class DownscaleBlock(nn.ModelBase):
        """Chain of ``n_downscales`` Downscale layers with growing width."""

        def on_build(self, in_ch, ch, n_downscales, kernel_size,
                     dilations=1, subpixel=True):
            self.downs = []

            last_ch = in_ch
            for i in range(n_downscales):
                # Channel multiplier doubles per level and is capped at 8.
                cur_ch = ch * min(2 ** i, 8)
                self.downs.append(
                    Downscale(last_ch, cur_ch, kernel_size=kernel_size,
                              dilations=dilations, subpixel=subpixel))
                last_ch = self.downs[-1].get_out_ch()

        def forward(self, inp):
            x = inp
            for down in self.downs:
                x = down(x)
            return x

    class Upscale(nn.ModelBase):
        """Conv to ``out_ch * 4`` channels, leaky ReLU, depth-to-space x2."""

        def on_build(self, in_ch, out_ch, kernel_size=3):
            self.conv1 = nn.Conv2D(
                in_ch, out_ch * 4, kernel_size=kernel_size, padding='SAME',
                kernel_initializer=conv_kernel_initializer)

        def forward(self, x):
            x = self.conv1(x)
            x = tf.nn.leaky_relu(x, 0.1)
            x = tf.nn.depth_to_space(x, 2)
            return x

    class ResidualBlock(nn.ModelBase):
        """Two same-width convolutions with a skip connection."""

        def on_build(self, ch, kernel_size=3):
            self.conv1 = nn.Conv2D(
                ch, ch, kernel_size=kernel_size, padding='SAME',
                kernel_initializer=conv_kernel_initializer)
            self.conv2 = nn.Conv2D(
                ch, ch, kernel_size=kernel_size, padding='SAME',
                kernel_initializer=conv_kernel_initializer)

        def forward(self, inp):
            x = self.conv1(inp)
            x = tf.nn.leaky_relu(x, 0.1)
            x = self.conv2(x)
            x = inp + x
            x = tf.nn.leaky_relu(x, 0.1)
            return x

    class Encoder(nn.ModelBase):
        """Four-level DownscaleBlock whose output is flattened."""

        def on_build(self, in_ch, e_ch):
            self.down1 = DownscaleBlock(in_ch, e_ch, n_downscales=4,
                                        kernel_size=5)

        def forward(self, inp):
            return nn.flatten(self.down1(inp))

    class Inter(nn.ModelBase):
        """Two dense layers, reshape to a feature map, upscale and residual."""

        def __init__(self, in_ch, lowest_dense_res, ae_ch, ae_out_ch, d_ch,
                     **kwargs):
            self.in_ch = in_ch
            self.lowest_dense_res = lowest_dense_res
            self.ae_ch = ae_ch
            self.ae_out_ch = ae_out_ch
            self.d_ch = d_ch
            super().__init__(**kwargs)

        def on_build(self):
            dense_res = self.lowest_dense_res

            self.dense1 = nn.Dense(
                self.in_ch, self.ae_ch,
                kernel_initializer=tf.initializers.orthogonal)
            self.dense2 = nn.Dense(
                self.ae_ch, dense_res * dense_res * self.ae_out_ch,
                kernel_initializer=tf.initializers.orthogonal)
            self.upscale1 = Upscale(self.ae_out_ch, self.d_ch * 8)
            self.res1 = ResidualBlock(self.d_ch * 8)

        def forward(self, inp):
            x = self.dense1(inp)
            x = self.dense2(x)
            # Use the resolution stored on the instance rather than the
            # variable of the enclosing function, so the layer does not
            # depend on a closure that is assigned later.
            x = tf.reshape(x, (-1, self.lowest_dense_res,
                               self.lowest_dense_res, self.ae_out_ch))
            x = self.upscale1(x)
            x = self.res1(x)
            return x

        def get_out_ch(self):
            # NOTE(review): forward() ends with Upscale(ae_out_ch, d_ch*8)
            # and ResidualBlock(d_ch*8), so its output appears to have
            # d_ch*8 channels while this getter reports ae_out_ch. It is
            # not called in this method - confirm before relying on it.
            return self.ae_out_ch

    class Decoder(nn.ModelBase):
        """Staged image decoder plus a separate mask decoder."""

        def on_build(self, in_ch, d_ch):
            # Stage-0 image branch.
            self.upscale0_1 = Upscale(in_ch, d_ch * 1)
            self.upscale0_2 = Upscale(d_ch * 1, d_ch * 1)
            self.upscale0_3 = Upscale(d_ch * 1, d_ch * 1)
            self.out0 = nn.Conv2D(
                d_ch * 1, 3, kernel_size=1, padding='SAME',
                kernel_initializer=conv_kernel_initializer)

            # Stage-1 image branch; its inputs are concatenated with the
            # stage-0 activations, hence the doubled input widths.
            self.upscale1_1 = Upscale(in_ch, d_ch * 1)
            self.upscale1_2 = Upscale(d_ch * 2, d_ch * 1)
            self.upscale1_3 = Upscale(d_ch * 2, d_ch * 1)
            self.out1 = nn.Conv2D(
                d_ch * 2, 3, kernel_size=1, padding='SAME',
                kernel_initializer=conv_kernel_initializer)

            # Mask branch.
            self.upscalem1 = Upscale(in_ch, d_ch)
            self.upscalem2 = Upscale(d_ch, d_ch // 2)
            self.upscalem3 = Upscale(d_ch // 2, d_ch // 2)
            self.outm = nn.Conv2D(
                d_ch // 2, 1, kernel_size=1, padding='SAME',
                kernel_initializer=conv_kernel_initializer)

        def get_weights_ex(self, stage):
            """Return mask weights plus image weights of stages <= stage."""
            # Call internal get_weights in order to initialize inner logic
            self.get_weights()

            weights = self.upscalem1.get_weights() \
                + self.upscalem2.get_weights() \
                + self.upscalem3.get_weights() \
                + self.outm.get_weights()

            if stage >= 0:
                weights += self.upscale0_1.get_weights() \
                    + self.upscale0_2.get_weights() \
                    + self.upscale0_3.get_weights() \
                    + self.out0.get_weights()

            if stage >= 1:
                weights += self.upscale1_1.get_weights() \
                    + self.upscale1_2.get_weights() \
                    + self.upscale1_3.get_weights() \
                    + self.out1.get_weights()

            return weights

        def forward(self, inp, stage=0):
            z = inp

            x0_1 = self.upscale0_1(z)
            x0_2 = self.upscale0_2(x0_1)
            x0_3 = self.upscale0_3(x0_2)
            x = tf.nn.tanh(self.out0(x0_3))

            if stage >= 1:
                x1_1 = self.upscale1_1(z)
                x1_2 = self.upscale1_2(tf.concat([x0_1, x1_1], -1))
                x1_3 = self.upscale1_3(tf.concat([x0_2, x1_2], -1))
                x1 = tf.nn.tanh(self.out1(tf.concat([x0_3, x1_3], -1)))
                # The stage-1 output is added on top of the stage-0 image.
                x = x + x1

            y = self.upscalem1(z)
            y = self.upscalem2(y)
            y = self.upscalem3(y)

            return x / 2 + 0.5, tf.nn.sigmoid(self.outm(y))

    device_config = nn.getCurrentDeviceConfig()
    devices = device_config.devices

    self.stage = stage = self.options['stage']
    self.start_stage_iter = self.options.get('start_stage_iter', 0)
    self.target_stage_iter = self.options.get('target_stage_iter', 0)

    resolution = self.resolution = 192
    ae_dims = 128
    e_dims = 128
    d_dims = 64
    self.pretrain = False
    self.pretrain_just_disabled = False

    masked_training = True

    models_opt_on_gpu = len(devices) == 1 and devices[0].total_mem_gb >= 4
    models_opt_device = '/GPU:0' if models_opt_on_gpu and self.is_training else '/CPU:0'
    optimizer_vars_on_cpu = models_opt_device == '/CPU:0'

    input_nc = 3
    output_nc = 3
    bgr_shape = (resolution, resolution, output_nc)
    mask_shape = (resolution, resolution, 1)
    lowest_dense_res = resolution // 16

    self.model_filename_list = []

    with tf.device('/CPU:0'):
        # Place holders on CPU
        self.warped_src = tf.placeholder(tf.float32, (None,) + bgr_shape)
        self.warped_dst = tf.placeholder(tf.float32, (None,) + bgr_shape)

        self.target_src = tf.placeholder(tf.float32, (None,) + bgr_shape)
        self.target_dst = tf.placeholder(tf.float32, (None,) + bgr_shape)

        self.target_srcm = tf.placeholder(tf.float32, (None,) + mask_shape)
        self.target_dstm = tf.placeholder(tf.float32, (None,) + mask_shape)

    # Initializing model classes
    with tf.device(models_opt_device):
        self.encoder = Encoder(in_ch=input_nc, e_ch=e_dims, name='encoder')
        encoder_out_ch = self.encoder.compute_output_shape(
            (tf.float32, (None, resolution, resolution, input_nc)))[-1]

        self.inter = Inter(in_ch=encoder_out_ch,
                           lowest_dense_res=lowest_dense_res,
                           ae_ch=ae_dims, ae_out_ch=ae_dims, d_ch=d_dims,
                           name='inter')
        inter_out_ch = self.inter.compute_output_shape(
            (tf.float32, (None, encoder_out_ch)))[-1]

        self.decoder_src = Decoder(in_ch=inter_out_ch, d_ch=d_dims,
                                   name='decoder_src')
        self.decoder_dst = Decoder(in_ch=inter_out_ch, d_ch=d_dims,
                                   name='decoder_dst')

        self.model_filename_list += [
            [self.encoder, 'encoder.npy'],
            [self.inter, 'inter.npy'],
            [self.decoder_src, 'decoder_src.npy'],
            [self.decoder_dst, 'decoder_dst.npy'],
        ]

        if self.is_training:
            self.src_dst_all_weights = self.encoder.get_weights() \
                + self.inter.get_weights() \
                + self.decoder_src.get_weights() \
                + self.decoder_dst.get_weights()
            # Only the decoder parts selected by ``stage`` receive
            # gradients; the optimizer still tracks every weight.
            self.src_dst_trainable_weights = self.encoder.get_weights() \
                + self.inter.get_weights() \
                + self.decoder_src.get_weights_ex(stage) \
                + self.decoder_dst.get_weights_ex(stage)

            # Initialize optimizers
            self.src_dst_opt = nn.RMSprop(lr=2e-4, lr_dropout=0.3,
                                          name='src_dst_opt')
            self.src_dst_opt.initialize_variables(
                self.src_dst_all_weights,
                vars_on_cpu=optimizer_vars_on_cpu)
            self.model_filename_list += [
                (self.src_dst_opt, 'src_dst_opt.npy')]

    if self.is_training:
        # Adjust batch size for multiple GPU
        gpu_count = max(1, len(devices))
        bs_per_gpu = max(1, self.get_batch_size() // gpu_count)
        self.set_batch_size(gpu_count * bs_per_gpu)

        # Compute losses per GPU
        gpu_pred_src_src_list = []
        gpu_pred_dst_dst_list = []
        gpu_pred_src_dst_list = []
        gpu_pred_src_srcm_list = []
        gpu_pred_dst_dstm_list = []
        gpu_pred_src_dstm_list = []

        gpu_src_losses = []
        gpu_dst_losses = []
        gpu_src_dst_loss_gvs = []

        for gpu_id in range(gpu_count):
            with tf.device(f'/GPU:{gpu_id}' if len(devices) != 0 else '/CPU:0'):
                batch_slice = slice(gpu_id * bs_per_gpu,
                                    (gpu_id + 1) * bs_per_gpu)
                with tf.device('/CPU:0'):
                    # slice on CPU, otherwise all batch data will be
                    # transferred to GPU first
                    gpu_warped_src = self.warped_src[batch_slice, :, :, :]
                    gpu_warped_dst = self.warped_dst[batch_slice, :, :, :]
                    gpu_target_src = self.target_src[batch_slice, :, :, :]
                    gpu_target_dst = self.target_dst[batch_slice, :, :, :]
                    gpu_target_srcm = self.target_srcm[batch_slice, :, :, :]
                    gpu_target_dstm = self.target_dstm[batch_slice, :, :, :]

                # process model tensors
                gpu_src_code = self.inter(self.encoder(gpu_warped_src))
                gpu_dst_code = self.inter(self.encoder(gpu_warped_dst))
                gpu_pred_src_src, gpu_pred_src_srcm = self.decoder_src(
                    gpu_src_code, stage=stage)
                gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder_dst(
                    gpu_dst_code, stage=stage)
                gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(
                    gpu_dst_code, stage=stage)

                gpu_pred_src_src_list.append(gpu_pred_src_src)
                gpu_pred_dst_dst_list.append(gpu_pred_dst_dst)
                gpu_pred_src_dst_list.append(gpu_pred_src_dst)

                gpu_pred_src_srcm_list.append(gpu_pred_src_srcm)
                gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm)
                gpu_pred_src_dstm_list.append(gpu_pred_src_dstm)

                gpu_target_srcm_blur = nn.gaussian_blur(
                    gpu_target_srcm, max(1, resolution // 32))
                gpu_target_dstm_blur = nn.gaussian_blur(
                    gpu_target_dstm, max(1, resolution // 32))

                gpu_target_dst_masked = gpu_target_dst * gpu_target_dstm_blur

                # With masked training both target and prediction are
                # multiplied by the blurred mask before being compared.
                gpu_target_srcmasked_opt = gpu_target_src * gpu_target_srcm_blur if masked_training else gpu_target_src
                gpu_target_dst_masked_opt = gpu_target_dst_masked if masked_training else gpu_target_dst

                gpu_pred_src_src_masked_opt = gpu_pred_src_src * gpu_target_srcm_blur if masked_training else gpu_pred_src_src
                gpu_pred_dst_dst_masked_opt = gpu_pred_dst_dst * gpu_target_dstm_blur if masked_training else gpu_pred_dst_dst

                # Loss = 10 * DSSIM + 10 * image MSE + mask MSE, per sample.
                gpu_src_loss = tf.reduce_mean(
                    10 * nn.dssim(gpu_target_srcmasked_opt,
                                  gpu_pred_src_src_masked_opt,
                                  max_val=1.0,
                                  filter_size=int(resolution / 11.6)),
                    axis=[1])
                gpu_src_loss += tf.reduce_mean(
                    10 * tf.square(gpu_target_srcmasked_opt
                                   - gpu_pred_src_src_masked_opt),
                    axis=[1, 2, 3])
                gpu_src_loss += tf.reduce_mean(
                    tf.square(gpu_target_srcm - gpu_pred_src_srcm),
                    axis=[1, 2, 3])

                gpu_dst_loss = tf.reduce_mean(
                    10 * nn.dssim(gpu_target_dst_masked_opt,
                                  gpu_pred_dst_dst_masked_opt,
                                  max_val=1.0,
                                  filter_size=int(resolution / 11.6)),
                    axis=[1])
                gpu_dst_loss += tf.reduce_mean(
                    10 * tf.square(gpu_target_dst_masked_opt
                                   - gpu_pred_dst_dst_masked_opt),
                    axis=[1, 2, 3])
                gpu_dst_loss += tf.reduce_mean(
                    tf.square(gpu_target_dstm - gpu_pred_dst_dstm),
                    axis=[1, 2, 3])

                gpu_src_losses += [gpu_src_loss]
                gpu_dst_losses += [gpu_dst_loss]

                gpu_src_dst_loss = gpu_src_loss + gpu_dst_loss
                gpu_src_dst_loss_gvs += [
                    nn.gradients(gpu_src_dst_loss,
                                 self.src_dst_trainable_weights)]

        # Average losses and gradients, and create optimizer update ops
        with tf.device(models_opt_device):
            if gpu_count == 1:
                # Single device: use the per-device tensors directly.
                pred_src_src = gpu_pred_src_src_list[0]
                pred_dst_dst = gpu_pred_dst_dst_list[0]
                pred_src_dst = gpu_pred_src_dst_list[0]
                pred_src_srcm = gpu_pred_src_srcm_list[0]
                pred_dst_dstm = gpu_pred_dst_dstm_list[0]
                pred_src_dstm = gpu_pred_src_dstm_list[0]

                src_loss = gpu_src_losses[0]
                dst_loss = gpu_dst_losses[0]
                src_dst_loss_gv = gpu_src_dst_loss_gvs[0]
            else:
                pred_src_src = tf.concat(gpu_pred_src_src_list, 0)
                pred_dst_dst = tf.concat(gpu_pred_dst_dst_list, 0)
                pred_src_dst = tf.concat(gpu_pred_src_dst_list, 0)
                pred_src_srcm = tf.concat(gpu_pred_src_srcm_list, 0)
                pred_dst_dstm = tf.concat(gpu_pred_dst_dstm_list, 0)
                pred_src_dstm = tf.concat(gpu_pred_src_dstm_list, 0)

                src_loss = nn.average_tensor_list(gpu_src_losses)
                dst_loss = nn.average_tensor_list(gpu_dst_losses)
                src_dst_loss_gv = nn.average_gv_list(gpu_src_dst_loss_gvs)

            src_dst_loss_gv_op = self.src_dst_opt.get_update_op(
                src_dst_loss_gv)

        # Initializing training and view functions
        def src_dst_train(warped_src, target_src, target_srcm,
                          warped_dst, target_dst, target_dstm):
            """Run one optimizer step; return mean src and dst losses."""
            s, d, _ = nn.tf_sess.run(
                [src_loss, dst_loss, src_dst_loss_gv_op],
                feed_dict={
                    self.warped_src: warped_src,
                    self.target_src: target_src,
                    self.target_srcm: target_srcm,
                    self.warped_dst: warped_dst,
                    self.target_dst: target_dst,
                    self.target_dstm: target_dstm,
                })
            s = np.mean(s)
            d = np.mean(d)
            return s, d

        self.src_dst_train = src_dst_train

        def AE_view(warped_src, warped_dst):
            """Evaluate the preview tensors for a src/dst batch pair."""
            return nn.tf_sess.run(
                [pred_src_src, pred_dst_dst, pred_dst_dstm,
                 pred_src_dst, pred_src_dstm],
                feed_dict={self.warped_src: warped_src,
                           self.warped_dst: warped_dst})

        self.AE_view = AE_view
    else:
        # Initializing merge function
        with tf.device('/GPU:0' if len(devices) != 0 else '/CPU:0'):
            gpu_dst_code = self.inter(self.encoder(self.warped_dst))
            gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(
                gpu_dst_code, stage=stage)
            _, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code,
                                                    stage=stage)

        def AE_merge(warped_dst):
            """Evaluate the swapped face and both masks for ``warped_dst``."""
            return nn.tf_sess.run(
                [gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm],
                feed_dict={self.warped_dst: warped_dst})

        self.AE_merge = AE_merge

    # Loading/initializing all models/optimizers weights
    for model, filename in io.progress_bar_generator(
            self.model_filename_list, "Initializing models"):
        do_init = self.is_first_run()

        if self.pretrain_just_disabled:
            if model == self.inter:
                do_init = True

        if not do_init:
            # Fall back to fresh initialization when loading fails.
            do_init = not model.load_weights(
                self.get_strpath_storage_for_file(filename))

        if do_init:
            model.init_weights()

    # initializing sample generators
    if self.is_training:
        t = SampleProcessor.Types
        face_type = t.FACE_TYPE_FULL

        training_data_src_path = self.training_data_src_path if not self.pretrain else self.get_pretraining_data_path()
        training_data_dst_path = self.training_data_dst_path if not self.pretrain else self.get_pretraining_data_path()

        cpu_count = multiprocessing.cpu_count()
        src_generators_count = cpu_count // 2
        # dst takes the remainder so that every CPU is used.
        dst_generators_count = cpu_count - src_generators_count

        self.set_training_data_generators([
            SampleGeneratorFace(
                training_data_src_path,
                debug=self.is_debug(),
                batch_size=self.get_batch_size(),
                sample_process_options=SampleProcessor.Options(
                    random_flip=True if self.pretrain else False),
                output_sample_types=[
                    {'types': (t.IMG_WARPED_TRANSFORMED, face_type, t.MODE_BGR),
                     'resolution': resolution},
                    {'types': (t.IMG_TRANSFORMED, face_type, t.MODE_BGR),
                     'resolution': resolution},
                    {'types': (t.IMG_TRANSFORMED, face_type, t.MODE_FACE_MASK_ALL_HULL),
                     'resolution': resolution},
                ],
                generators_count=src_generators_count),

            SampleGeneratorFace(
                training_data_dst_path,
                debug=self.is_debug(),
                batch_size=self.get_batch_size(),
                sample_process_options=SampleProcessor.Options(
                    random_flip=True if self.pretrain else False),
                output_sample_types=[
                    {'types': (t.IMG_WARPED_TRANSFORMED, face_type, t.MODE_BGR),
                     'resolution': resolution},
                    {'types': (t.IMG_TRANSFORMED, face_type, t.MODE_BGR),
                     'resolution': resolution},
                    {'types': (t.IMG_TRANSFORMED, face_type, t.MODE_FACE_MASK_ALL_HULL),
                     'resolution': resolution},
                ],
                generators_count=dst_generators_count),
        ])

        self.last_samples = None
def on_initialize(self):
    """Build the TF graph, load or initialise weights and create sample generators.

    The method:
      * initialises ``nn`` with the NHWC data format and reads ``resolution``
        and ``face_type`` from ``self.options``;
      * defines three sub-networks (``detector``, ``extractor``, ``regressor``)
        and registers them, together with the optimizer when training, in
        ``self.model_filename_list``;
      * when training, builds per-device losses and gradients and exposes
        ``self.ae_train`` and ``self.AE_view``;
      * loads each registered model's weights (or initialises them) and, when
        training, installs the training data generator.
    """
    device_config = nn.getCurrentDeviceConfig()
    devices = device_config.devices
    # NCHW selection is disabled; the graph below assumes channels-last.
    self.model_data_format = "NHWC"  # "NCHW" if len(devices) != 0 and not self.is_debug() else "NHWC"
    nn.initialize(data_format=self.model_data_format)
    tf = nn.tf

    self.resolution = resolution = self.options['resolution']
    self.face_type = {'h': FaceType.HALF,
                      'mf': FaceType.MID_FULL,
                      'f': FaceType.FULL,
                      'wf': FaceType.WHOLE_FACE,
                      'head': FaceType.HEAD}[self.options['face_type']]

    models_opt_on_gpu = True  # False if len(devices) == 0 else self.options['models_opt_on_gpu']
    models_opt_device = '/GPU:0' if models_opt_on_gpu and self.is_training else '/CPU:0'
    optimizer_vars_on_cpu = models_opt_device == '/CPU:0'

    input_ch = 3
    bgr_shape = nn.get4Dshape(resolution, resolution, input_ch)
    mask_shape = nn.get4Dshape(resolution, resolution, 1)

    self.model_filename_list = []

    class BaseModel(nn.ModelBase):
        """Conv stack with three stride-2 stages; optional 1x1 output conv."""

        def on_build(self, in_ch, base_ch, out_ch=None):
            self.convs = [
                nn.Conv2D(in_ch, base_ch, kernel_size=7, strides=1, padding='SAME'),
                nn.Conv2D(base_ch, base_ch, kernel_size=3, strides=1, use_bias=False, padding='SAME'),
                nn.Conv2D(base_ch, base_ch*2, kernel_size=3, strides=2, use_bias=False, padding='SAME'),
                nn.Conv2D(base_ch*2, base_ch*2, kernel_size=3, strides=1, use_bias=False, padding='SAME'),
                nn.Conv2D(base_ch*2, base_ch*4, kernel_size=3, strides=2, use_bias=False, padding='SAME'),
                nn.Conv2D(base_ch*4, base_ch*4, kernel_size=3, strides=1, use_bias=False, padding='SAME'),
                nn.Conv2D(base_ch*4, base_ch*8, kernel_size=3, strides=2, use_bias=False, padding='SAME'),
                nn.Conv2D(base_ch*8, base_ch*8, kernel_size=3, strides=1, use_bias=False, padding='SAME'),
            ]

            # The first conv has no normalisation layer (it keeps its bias).
            self.frns = [
                None,
                nn.FRNorm2D(base_ch),
                nn.FRNorm2D(base_ch*2),
                nn.FRNorm2D(base_ch*2),
                nn.FRNorm2D(base_ch*4),
                nn.FRNorm2D(base_ch*4),
                nn.FRNorm2D(base_ch*8),
                nn.FRNorm2D(base_ch*8),
            ]

            self.tlus = [
                nn.TLU(base_ch),
                nn.TLU(base_ch),
                nn.TLU(base_ch*2),
                nn.TLU(base_ch*2),
                nn.TLU(base_ch*4),
                nn.TLU(base_ch*4),
                nn.TLU(base_ch*8),
                nn.TLU(base_ch*8),
            ]

            if out_ch is not None:
                self.out_conv = nn.Conv2D(base_ch*8, out_ch, kernel_size=1, strides=1, use_bias=False, padding='VALID')
            else:
                self.out_conv = None

        def forward(self, inp):
            x = inp
            for conv, frn, tlu in zip(self.convs, self.frns, self.tlus):
                x = conv(x)
                if frn is not None:
                    x = frn(x)
                x = tlu(x)

            if self.out_conv is not None:
                x = self.out_conv(x)
            return x

    class Regressor(nn.ModelBase):
        """Decoder: three depth_to_space(2) upscales, then a sigmoid output conv."""

        def on_build(self, lmrks_ch, base_ch, out_ch):
            # Convs followed by an upscale emit 4x channels, because
            # depth_to_space(x, 2) divides the channel count by 4.
            self.convs = [
                nn.Conv2D(base_ch*8+lmrks_ch, base_ch*8, kernel_size=3, strides=1, use_bias=False, padding='SAME'),
                nn.Conv2D(base_ch*8, base_ch*8*4, kernel_size=3, strides=1, use_bias=False, padding='SAME'),
                nn.Conv2D(base_ch*8, base_ch*4, kernel_size=3, strides=1, use_bias=False, padding='SAME'),
                nn.Conv2D(base_ch*4, base_ch*4*4, kernel_size=3, strides=1, use_bias=False, padding='SAME'),
                nn.Conv2D(base_ch*4, base_ch*2, kernel_size=3, strides=1, use_bias=False, padding='SAME'),
                nn.Conv2D(base_ch*2, base_ch*2*4, kernel_size=3, strides=1, use_bias=False, padding='SAME'),
                nn.Conv2D(base_ch*2, base_ch, kernel_size=3, strides=1, use_bias=False, padding='SAME'),
            ]

            self.frns = [
                nn.FRNorm2D(base_ch*8),
                nn.FRNorm2D(base_ch*8*4),
                nn.FRNorm2D(base_ch*4),
                nn.FRNorm2D(base_ch*4*4),
                nn.FRNorm2D(base_ch*2),
                nn.FRNorm2D(base_ch*2*4),
                nn.FRNorm2D(base_ch),
            ]

            self.tlus = [
                nn.TLU(base_ch*8),
                nn.TLU(base_ch*8*4),
                nn.TLU(base_ch*4),
                nn.TLU(base_ch*4*4),
                nn.TLU(base_ch*2),
                nn.TLU(base_ch*2*4),
                nn.TLU(base_ch),
            ]

            self.use_upscale = [
                False,
                True,
                False,
                True,
                False,
                True,
                False,
            ]

            self.out_conv = nn.Conv2D(base_ch, out_ch, kernel_size=3, strides=1, padding='SAME')

        def forward(self, inp):
            x = inp
            for conv, frn, tlu, upscale in zip(self.convs, self.frns, self.tlus, self.use_upscale):
                x = conv(x)
                x = frn(x)
                x = tlu(x)
                if upscale:
                    x = nn.depth_to_space(x, 2)

            x = self.out_conv(x)
            x = tf.nn.sigmoid(x)
            return x

    def get_coord(x, other_axis, axis_size):
        """Return the soft-argmax coordinate in [-1, 1] along one spatial axis.

        ``x`` is averaged over ``other_axis``, soft-maxed over axis 1 and
        used to weight ``axis_size`` evenly spaced positions in [-1, 1].
        Returns ``(coordinates, probabilities)``.
        """
        g_c_prob = tf.reduce_mean(x, axis=other_axis)  # B,W,NMAP
        g_c_prob = tf.nn.softmax(g_c_prob, axis=1)  # B,W,NMAP
        coord_pt = tf.to_float(tf.linspace(-1.0, 1.0, axis_size))  # W
        coord_pt = tf.reshape(coord_pt, [1, axis_size, 1])
        g_c = tf.reduce_sum(g_c_prob * coord_pt, axis=1)
        return g_c, g_c_prob

    def get_gaussian_maps(mu_x, mu_y, width, height, inv_std=10.0, mode='rot'):
        """
        Generate a [B, width, height, NMAPS] tensor of 2D blobs centred at
        (mu_x, mu_y), which are expected to be [B, NMAPS] tensors in [-1, 1].

        The first spatial axis (sized by ``width``) is the one ``mu_y`` is
        measured along; the second (sized by ``height``) belongs to ``mu_x``.
        ``inv_std`` is the fixed inverse standard deviation. ``mode`` selects
        the falloff: 'rot', 'flat' or 'ankush'.

        Raises ValueError for an unknown ``mode``.
        """
        y = tf.to_float(tf.linspace(-1.0, 1.0, width))
        x = tf.to_float(tf.linspace(-1.0, 1.0, height))

        if mode in ['rot', 'flat']:
            mu_y, mu_x = mu_y[..., None, None], mu_x[..., None, None]

            y = tf.reshape(y, [1, 1, width, 1])
            x = tf.reshape(x, [1, 1, 1, height])

            g_y = tf.square(y - mu_y)
            g_x = tf.square(x - mu_x)
            dist = (g_y + g_x) * inv_std**2

            if mode == 'rot':
                g_yx = tf.exp(-dist)
            else:
                g_yx = tf.exp(-tf.pow(dist + 1e-5, 0.25))

        elif mode == 'ankush':
            y = tf.reshape(y, [1, 1, width])
            x = tf.reshape(x, [1, 1, height])

            g_y = tf.exp(-tf.sqrt(1e-4 + tf.abs((mu_y[..., None] - y) * inv_std)))
            g_x = tf.exp(-tf.sqrt(1e-4 + tf.abs((mu_x[..., None] - x) * inv_std)))

            g_y = tf.expand_dims(g_y, axis=3)
            g_x = tf.expand_dims(g_x, axis=2)
            g_yx = tf.matmul(g_y, g_x)  # [B, NMAPS, width, height]

        else:
            raise ValueError('Unknown mode: ' + str(mode))

        # Move the map axis last to match the channels-last feature tensors.
        g_yx = tf.transpose(g_yx, perm=[0, 2, 3, 1])
        return g_yx

    with tf.device('/CPU:0'):
        # Placeholders on CPU
        self.warped_src = tf.placeholder(nn.floatx, bgr_shape)
        self.target_src = tf.placeholder(nn.floatx, bgr_shape)

    # Initializing model classes
    self.landmarks_count = 512
    self.n_ch = 32

    with tf.device(models_opt_device):
        self.detector = BaseModel(3, self.n_ch, out_ch=self.landmarks_count, name='Detector')
        self.extractor = BaseModel(3, self.n_ch, name='Extractor')
        self.regressor = Regressor(self.landmarks_count, self.n_ch, 3, name='Regressor')

        self.model_filename_list += [[self.detector, 'detector.npy'],
                                     [self.extractor, 'extractor.npy'],
                                     [self.regressor, 'regressor.npy']]

        if self.is_training:
            # Initialize optimizers
            lr = 5e-5
            lr_dropout = 0.3  # 0.3 if self.options['lr_dropout'] and not self.pretrain else 1.0
            clipnorm = 0.0  # 1.0 if self.options['clipgrad'] else 0.0
            self.model_opt = nn.RMSprop(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='model_opt')
            self.model_filename_list += [(self.model_opt, 'model_opt.npy')]

            self.model_trainable_weights = self.detector.get_weights() + self.extractor.get_weights() + self.regressor.get_weights()
            self.model_opt.initialize_variables(self.model_trainable_weights, vars_on_cpu=optimizer_vars_on_cpu)

    if self.is_training:
        # Adjust batch size for multiple GPU
        gpu_count = max(1, len(devices))
        bs_per_gpu = max(1, self.get_batch_size() // gpu_count)
        self.set_batch_size(gpu_count*bs_per_gpu)

        # Compute losses per GPU
        gpu_src_rec_list = []
        gauss_mu_list = []

        gpu_src_losses = []
        gpu_G_loss_gvs = []

        for gpu_id in range(gpu_count):
            with tf.device(f'/GPU:{gpu_id}' if len(devices) != 0 else '/CPU:0'):
                with tf.device('/CPU:0'):
                    # slice on CPU, otherwise all batch data will be transferred to GPU first
                    batch_slice = slice(gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu)
                    gpu_warped_src = self.warped_src[batch_slice, :, :, :]
                    gpu_target_src = self.target_src[batch_slice, :, :, :]

                # process model tensors
                gpu_src_feat = self.extractor(gpu_warped_src)
                gpu_src_heatmaps = self.detector(gpu_target_src)

                # Static spatial size of the heatmaps (axes 1 and 2).
                heatmap_shape = gpu_src_heatmaps.shape.as_list()
                hm_h, hm_w = heatmap_shape[1], heatmap_shape[2]

                gauss_y, _ = get_coord(gpu_src_heatmaps, 2, hm_h)
                gauss_x, _ = get_coord(gpu_src_heatmaps, 1, hm_w)

                gauss_mu = tf.stack((gauss_x, gauss_y), -1)

                # Repulsion term: for each landmark, the mean over all other
                # landmarks of (2.0 - Euclidean distance), averaged over landmarks.
                # NOTE(review): tf.sqrt has an unbounded gradient at 0, so two
                # coinciding landmarks may yield inf/NaN gradients - consider an epsilon.
                dist_loss = []
                for i in range(self.landmarks_count):
                    t = tf.concat((gauss_mu[:, 0:i], gauss_mu[:, i+1:]), axis=1)

                    diff = t - gauss_mu[:, i:i+1]
                    dist = tf.sqrt(diff[..., 0]**2 + diff[..., 1]**2)

                    dist_loss += [tf.reduce_mean(2.0 - dist, -1)]

                dist_loss = sum(dist_loss) / self.landmarks_count

                # The maps are concatenated with the extractor features, so they
                # must share the features' spatial size. Detector and extractor
                # use the same conv stack, so the heatmap size is used instead of
                # a hard-coded 16x16 (which is only correct at resolution 128).
                # The first size argument sizes the axis mu_y is measured along.
                gauss_xy = get_gaussian_maps(gauss_x, gauss_y, hm_h, hm_w)

                gpu_src_rec = self.regressor(tf.concat((gpu_src_feat, gauss_xy), -1))

                gpu_src_rec_list.append(gpu_src_rec)
                gauss_mu_list.append(gauss_mu)

                # Reconstruction loss: DSSIM plus squared error, each scaled by 10.
                gpu_src_loss = tf.reduce_mean(10*nn.dssim(gpu_target_src, gpu_src_rec, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
                gpu_src_loss += tf.reduce_mean(10*tf.square(gpu_target_src - gpu_src_rec), axis=[1, 2, 3])
                gpu_src_loss += dist_loss

                gpu_src_losses += [gpu_src_loss]

                gpu_G_loss_gvs += [nn.gradients(gpu_src_loss, self.model_trainable_weights)]

        # Average losses and gradients, and create optimizer update ops
        with tf.device(models_opt_device):
            src_rec = nn.concat(gpu_src_rec_list, 0)
            gauss_mu = nn.concat(gauss_mu_list, 0)
            src_loss = tf.concat(gpu_src_losses, 0)
            loss_gv_op = self.model_opt.get_update_op(nn.average_gv_list(gpu_G_loss_gvs))

        # Initializing training and view functions
        def ae_train(warped_src, target_src):
            s, _ = nn.tf_sess.run([src_loss, loss_gv_op],
                                  feed_dict={self.warped_src: warped_src,
                                             self.target_src: target_src})
            return s
        self.ae_train = ae_train

        def AE_view(warped_src, target_src):
            return nn.tf_sess.run([src_rec, gauss_mu],
                                  feed_dict={self.warped_src: warped_src,
                                             self.target_src: target_src})
        self.AE_view = AE_view
    else:
        # Initializing merge function
        # NOTE(review): this branch uses names that this method never defines
        # (`archi`, self.inter, self.encoder, self.decoder_src, self.decoder_dst,
        # self.warped_dst). It looks carried over from another model and would
        # presumably fail when is_training is False - confirm and replace.
        with tf.device('/GPU:0' if len(devices) != 0 else '/CPU:0'):
            if 'df' in archi:
                gpu_dst_code = self.inter(self.encoder(self.warped_dst))
                gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code)
                _, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)

        def AE_merge(warped_dst):
            return nn.tf_sess.run([gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm],
                                  feed_dict={self.warped_dst: warped_dst})

        self.AE_merge = AE_merge

    # Loading/initializing all models/optimizers weights
    for model, filename in io.progress_bar_generator(self.model_filename_list, "Initializing models"):
        do_init = self.is_first_run()

        if not do_init:
            # Fall back to fresh initialisation when the weights cannot be loaded.
            do_init = not model.load_weights(self.get_strpath_storage_for_file(filename))

        if do_init:
            model.init_weights()

    # initializing sample generators
    if self.is_training:
        training_data_src_path = self.training_data_src_path

        cpu_count = min(multiprocessing.cpu_count(), 8)
        src_generators_count = cpu_count // 2

        # NOTE(review): the generator yields batch_size*2 samples while the graph
        # above slices only gpu_count*bs_per_gpu rows - confirm this is intended.
        self.set_training_data_generators([
            SampleGeneratorFace(training_data_src_path, debug=self.is_debug(), batch_size=self.get_batch_size()*2,
                                sample_process_options=SampleProcessor.Options(random_flip=False),
                                output_sample_types=[
                                    {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp': True, 'transform': True, 'channel_type': SampleProcessor.ChannelType.BGR, 'face_type': self.face_type, 'data_format': nn.data_format, 'resolution': resolution},
                                    {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp': False, 'transform': True, 'channel_type': SampleProcessor.ChannelType.BGR, 'face_type': self.face_type, 'data_format': nn.data_format, 'resolution': resolution},
                                ],
                                generators_count=src_generators_count),
        ])