def on_initialize(self):
    """Build the graph, optimizers, weights and data generators for this model.

    Reads the model options (resolution, face type, architecture, layer dims,
    loss powers, ...) and creates either the 'df' networks (shared encoder and
    inter, separate src/dst decoders) or the 'liae' networks (shared encoder,
    inter_AB / inter_B, single decoder).

    Training mode (``self.is_training``) additionally builds:

    * the optional discriminators;
    * RMSprop optimizers;
    * the per-GPU loss/gradient graph and the averaged update ops.

    It then exposes ``self.src_dst_train`` and ``self.AE_view``, plus
    ``self.D_train`` (true-face power) and ``self.D_src_dst_train`` (GAN power)
    when those features are enabled. Inference mode exposes ``self.AE_merge``.

    Every ``[model, filename]`` entry of ``self.model_filename_list`` is then
    loaded from disk or freshly initialized. In training mode the src/dst
    sample generators are registered last.

    Raises:
        ValueError: if ``options['archi']`` is not of the form ``"<type>"`` or
            ``"<type>-<opts>"``, or if ``<type>`` contains neither ``'df'``
            nor ``'liae'``.
    """
    device_config = nn.getCurrentDeviceConfig()
    devices = device_config.devices
    # NCHW when at least one device is configured and not debugging, else NHWC.
    self.model_data_format = "NCHW" if len(devices) != 0 and not self.is_debug() else "NHWC"
    nn.initialize(data_format=self.model_data_format)
    tf = nn.tf

    self.resolution = resolution = self.options['resolution']
    self.face_type = {'h'    : FaceType.HALF,
                      'mf'   : FaceType.MID_FULL,
                      'f'    : FaceType.FULL,
                      'wf'   : FaceType.WHOLE_FACE,
                      'head' : FaceType.HEAD}[ self.options['face_type'] ]

    eyes_prio = self.options['eyes_prio']

    # The 'archi' option is "<type>" or "<type>-<opts>".
    archi = self.options['archi']
    archi_split = archi.split('-')
    if len(archi_split) == 2:
        archi_type, archi_opts = archi_split
    elif len(archi_split) == 1:
        archi_type, archi_opts = archi_split[0], None
    else:
        # Fail here instead of with an UnboundLocalError further down.
        raise ValueError(f"Unsupported 'archi' option: {archi!r}")

    if 'df' not in archi_type and 'liae' not in archi_type:
        # Every graph-building branch below requires one of these two types.
        raise ValueError(f"Unsupported architecture type {archi_type!r} in 'archi' option {archi!r}")

    ae_dims = self.options['ae_dims']
    e_dims = self.options['e_dims']
    d_dims = self.options['d_dims']
    d_mask_dims = self.options['d_mask_dims']
    self.pretrain = self.options['pretrain']
    if self.pretrain_just_disabled:
        self.set_iter(0)

    # GAN and random warp are forced off while pretraining.
    self.gan_power = gan_power = 0.0 if self.pretrain else self.options['gan_power']
    random_warp = False if self.pretrain else self.options['random_warp']

    if self.pretrain:
        # Record the values forced during pretraining
        # (presumably what is shown to the user instead of the stored options).
        self.options_show_override['gan_power'] = 0.0
        self.options_show_override['random_warp'] = False
        self.options_show_override['lr_dropout'] = 'n'
        self.options_show_override['face_style_power'] = 0.0
        self.options_show_override['bg_style_power'] = 0.0
        self.options_show_override['uniform_yaw'] = True

    # The code discriminator is only built for 'df' architectures. A non-zero
    # true_face_power is therefore ignored for other types instead of crashing
    # with an AttributeError on self.code_discriminator. The option is only
    # read in training mode, matching where it is used.
    true_face_power = self.options['true_face_power'] if self.is_training and 'df' in archi_type else 0.0

    masked_training = self.options['masked_training']
    # NOTE(review): the stored option is overridden by the external `dfl`
    # config ("masked_training" key, default "1"), so the value read above is
    # never used. Kept as-is to preserve behavior -- confirm this is intended.
    import dfl
    dfl.load_config()
    masked_training = dfl.get_config("masked_training", "1") == "1"

    ct_mode = self.options['ct_mode']
    if ct_mode == 'none':
        ct_mode = None

    models_opt_on_gpu = False if len(devices) == 0 else self.options['models_opt_on_gpu']
    models_opt_device = '/GPU:0' if models_opt_on_gpu and self.is_training else '/CPU:0'
    optimizer_vars_on_cpu = models_opt_device == '/CPU:0'

    input_ch = 3
    bgr_shape = nn.get4Dshape(resolution, resolution, input_ch)
    mask_shape = nn.get4Dshape(resolution, resolution, 1)
    self.model_filename_list = []

    with tf.device('/CPU:0'):
        # Placeholders live on the CPU.
        self.warped_src      = tf.placeholder(nn.floatx, bgr_shape)
        self.warped_dst      = tf.placeholder(nn.floatx, bgr_shape)

        self.target_src      = tf.placeholder(nn.floatx, bgr_shape)
        self.target_dst      = tf.placeholder(nn.floatx, bgr_shape)

        self.target_srcm_all = tf.placeholder(nn.floatx, mask_shape)
        self.target_dstm_all = tf.placeholder(nn.floatx, mask_shape)

    # Initializing model classes
    model_archi = nn.DeepFakeArchi(resolution, opts=archi_opts)

    with tf.device(models_opt_device):
        if 'df' in archi_type:
            self.encoder = model_archi.Encoder(in_ch=input_ch, e_ch=e_dims, name='encoder')
            encoder_out_ch = self.encoder.compute_output_channels((nn.floatx, bgr_shape))

            self.inter = model_archi.Inter(in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims, name='inter')
            inter_out_ch = self.inter.compute_output_channels((nn.floatx, (None, encoder_out_ch)))

            self.decoder_src = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, name='decoder_src')
            self.decoder_dst = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, name='decoder_dst')

            self.model_filename_list += [ [self.encoder,     'encoder.npy'    ],
                                          [self.inter,       'inter.npy'      ],
                                          [self.decoder_src, 'decoder_src.npy'],
                                          [self.decoder_dst, 'decoder_dst.npy'] ]

            if self.is_training:
                if true_face_power != 0:
                    self.code_discriminator = nn.CodeDiscriminator(ae_dims, code_res=model_archi.Inter.get_code_res()*2, name='dis')
                    self.model_filename_list += [ [self.code_discriminator, 'code_discriminator.npy'] ]

        elif 'liae' in archi_type:
            self.encoder = model_archi.Encoder(in_ch=input_ch, e_ch=e_dims, name='encoder')
            encoder_out_ch = self.encoder.compute_output_channels((nn.floatx, bgr_shape))

            self.inter_AB = model_archi.Inter(in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims*2, name='inter_AB')
            self.inter_B  = model_archi.Inter(in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims*2, name='inter_B')

            inter_AB_out_ch = self.inter_AB.compute_output_channels((nn.floatx, (None, encoder_out_ch)))
            inter_B_out_ch  = self.inter_B.compute_output_channels((nn.floatx, (None, encoder_out_ch)))
            # The decoder consumes the channel-wise concatenation of two inter outputs.
            inters_out_ch = inter_AB_out_ch + inter_B_out_ch
            self.decoder = model_archi.Decoder(in_ch=inters_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, name='decoder')

            self.model_filename_list += [ [self.encoder,  'encoder.npy' ],
                                          [self.inter_AB, 'inter_AB.npy'],
                                          [self.inter_B , 'inter_B.npy' ],
                                          [self.decoder , 'decoder.npy' ] ]

        if self.is_training:
            if gan_power != 0:
                self.D_src = nn.UNetPatchDiscriminator(patch_size=resolution//16, in_ch=input_ch, name="D_src")
                self.model_filename_list += [ [self.D_src, 'D_src_v2.npy'] ]

            # Initialize optimizers
            lr = 5e-5
            lr_dropout = 0.3 if self.options['lr_dropout'] in ['y','cpu'] and not self.pretrain else 1.0
            lr_dropout_on_cpu = self.options['lr_dropout'] == 'cpu'
            clipnorm = 1.0 if self.options['clipgrad'] else 0.0

            if 'df' in archi_type:
                self.src_dst_trainable_weights = self.encoder.get_weights() + self.inter.get_weights() + self.decoder_src.get_weights() + self.decoder_dst.get_weights()
            elif 'liae' in archi_type:
                self.src_dst_trainable_weights = self.encoder.get_weights() + self.inter_AB.get_weights() + self.inter_B.get_weights() + self.decoder.get_weights()

            self.src_dst_opt = nn.RMSprop(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='src_dst_opt')
            self.src_dst_opt.initialize_variables(self.src_dst_trainable_weights, vars_on_cpu=optimizer_vars_on_cpu, lr_dropout_on_cpu=lr_dropout_on_cpu)
            self.model_filename_list += [ (self.src_dst_opt, 'src_dst_opt.npy') ]

            if true_face_power != 0:
                self.D_code_opt = nn.RMSprop(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='D_code_opt')
                self.D_code_opt.initialize_variables(self.code_discriminator.get_weights(), vars_on_cpu=optimizer_vars_on_cpu, lr_dropout_on_cpu=lr_dropout_on_cpu)
                self.model_filename_list += [ (self.D_code_opt, 'D_code_opt.npy') ]

            if gan_power != 0:
                self.D_src_dst_opt = nn.RMSprop(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='D_src_dst_opt')
                self.D_src_dst_opt.initialize_variables(self.D_src.get_weights(), vars_on_cpu=optimizer_vars_on_cpu, lr_dropout_on_cpu=lr_dropout_on_cpu)
                self.model_filename_list += [ (self.D_src_dst_opt, 'D_src_v2_opt.npy') ]

    if self.is_training:
        # Adjust batch size for multiple GPU: round down to a multiple of gpu_count.
        gpu_count = max(1, len(devices))
        bs_per_gpu = max(1, self.get_batch_size() // gpu_count)
        self.set_batch_size(gpu_count*bs_per_gpu)

        # Loop-invariant settings for the per-GPU loss graph.
        face_style_power = self.options['face_style_power'] / 100.0
        bg_style_power = self.options['bg_style_power'] / 100.0
        ssim_filter_size = int(resolution/11.6)
        ssim_filter_size_small = int(resolution/23.2)

        def DLoss(labels, logits):
            # Per-sample sigmoid cross-entropy, averaged over axes 1..3.
            return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits), axis=[1,2,3])

        # Compute losses per GPU
        gpu_pred_src_src_list = []
        gpu_pred_dst_dst_list = []
        gpu_pred_src_dst_list = []
        gpu_pred_src_srcm_list = []
        gpu_pred_dst_dstm_list = []
        gpu_pred_src_dstm_list = []

        gpu_src_losses = []
        gpu_dst_losses = []
        gpu_G_loss_gvs = []
        gpu_D_code_loss_gvs = []
        gpu_D_src_dst_loss_gvs = []
        for gpu_id in range(gpu_count):
            with tf.device(f'/GPU:{gpu_id}' if len(devices) != 0 else '/CPU:0'):

                with tf.device('/CPU:0'):
                    # slice on CPU, otherwise all batch data will be transfered to GPU first
                    batch_slice = slice(gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu)
                    gpu_warped_src      = self.warped_src     [batch_slice,:,:,:]
                    gpu_warped_dst      = self.warped_dst     [batch_slice,:,:,:]
                    gpu_target_src      = self.target_src     [batch_slice,:,:,:]
                    gpu_target_dst      = self.target_dst     [batch_slice,:,:,:]
                    gpu_target_srcm_all = self.target_srcm_all[batch_slice,:,:,:]
                    gpu_target_dstm_all = self.target_dstm_all[batch_slice,:,:,:]

                # process model tensors
                if 'df' in archi_type:
                    gpu_src_code = self.inter(self.encoder(gpu_warped_src))
                    gpu_dst_code = self.inter(self.encoder(gpu_warped_dst))
                    gpu_pred_src_src, gpu_pred_src_srcm = self.decoder_src(gpu_src_code)
                    gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)
                    # src decoder applied to the dst code.
                    gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code)

                elif 'liae' in archi_type:
                    gpu_src_code = self.encoder(gpu_warped_src)
                    gpu_src_inter_AB_code = self.inter_AB(gpu_src_code)
                    gpu_src_code = tf.concat([gpu_src_inter_AB_code, gpu_src_inter_AB_code], nn.conv2d_ch_axis)
                    gpu_dst_code = self.encoder(gpu_warped_dst)
                    gpu_dst_inter_B_code = self.inter_B(gpu_dst_code)
                    gpu_dst_inter_AB_code = self.inter_AB(gpu_dst_code)
                    gpu_dst_code = tf.concat([gpu_dst_inter_B_code, gpu_dst_inter_AB_code], nn.conv2d_ch_axis)
                    gpu_src_dst_code = tf.concat([gpu_dst_inter_AB_code, gpu_dst_inter_AB_code], nn.conv2d_ch_axis)

                    gpu_pred_src_src, gpu_pred_src_srcm = self.decoder(gpu_src_code)
                    gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder(gpu_dst_code)
                    gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)

                gpu_pred_src_src_list.append(gpu_pred_src_src)
                gpu_pred_dst_dst_list.append(gpu_pred_dst_dst)
                gpu_pred_src_dst_list.append(gpu_pred_src_dst)

                gpu_pred_src_srcm_list.append(gpu_pred_src_srcm)
                gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm)
                gpu_pred_src_dstm_list.append(gpu_pred_src_dstm)

                # Unpack masks from one combined mask: the face part is clip(m, 0, 1),
                # the eyes part is whatever lies above 1, i.e. clip(m - 1, 0, 1).
                gpu_target_srcm      = tf.clip_by_value(gpu_target_srcm_all, 0, 1)
                gpu_target_dstm      = tf.clip_by_value(gpu_target_dstm_all, 0, 1)
                gpu_target_srcm_eyes = tf.clip_by_value(gpu_target_srcm_all-1, 0, 1)
                gpu_target_dstm_eyes = tf.clip_by_value(gpu_target_dstm_all-1, 0, 1)

                gpu_target_srcm_blur = nn.gaussian_blur(gpu_target_srcm, max(1, resolution // 32))
                gpu_target_srcm_blur = tf.clip_by_value(gpu_target_srcm_blur, 0, 0.5) * 2

                gpu_target_dstm_blur = nn.gaussian_blur(gpu_target_dstm, max(1, resolution // 32))
                gpu_target_dstm_style_blur = gpu_target_dstm_blur  # default style mask is 0.5 on boundary
                gpu_target_dstm_blur = tf.clip_by_value(gpu_target_dstm_blur, 0, 0.5) * 2

                gpu_target_dst_masked            = gpu_target_dst*gpu_target_dstm_blur
                gpu_target_dst_style_masked      = gpu_target_dst*gpu_target_dstm_style_blur
                gpu_target_dst_style_anti_masked = gpu_target_dst*(1.0 - gpu_target_dstm_style_blur)

                gpu_target_src_masked_opt   = gpu_target_src*gpu_target_srcm_blur if masked_training else gpu_target_src
                gpu_target_dst_masked_opt   = gpu_target_dst_masked if masked_training else gpu_target_dst

                gpu_pred_src_src_masked_opt = gpu_pred_src_src*gpu_target_srcm_blur if masked_training else gpu_pred_src_src
                gpu_pred_dst_dst_masked_opt = gpu_pred_dst_dst*gpu_target_dstm_blur if masked_training else gpu_pred_dst_dst

                gpu_psd_target_dst_style_masked      = gpu_pred_src_dst*gpu_target_dstm_style_blur
                gpu_psd_target_dst_style_anti_masked = gpu_pred_src_dst*(1.0 - gpu_target_dstm_style_blur)

                # --- src reconstruction loss ---
                if resolution < 256:
                    gpu_src_loss =  tf.reduce_mean(10*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=ssim_filter_size), axis=[1])
                else:
                    # Higher resolutions split the DSSIM weight across two filter sizes.
                    gpu_src_loss =  tf.reduce_mean(5*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=ssim_filter_size), axis=[1])
                    gpu_src_loss += tf.reduce_mean(5*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=ssim_filter_size_small), axis=[1])
                gpu_src_loss += tf.reduce_mean(10*tf.square(gpu_target_src_masked_opt - gpu_pred_src_src_masked_opt), axis=[1,2,3])

                if eyes_prio:
                    gpu_src_loss += tf.reduce_mean(300*tf.abs(gpu_target_src*gpu_target_srcm_eyes - gpu_pred_src_src*gpu_target_srcm_eyes), axis=[1,2,3])

                gpu_src_loss += tf.reduce_mean(10*tf.square(gpu_target_srcm - gpu_pred_src_srcm), axis=[1,2,3])

                if face_style_power != 0 and not self.pretrain:
                    gpu_src_loss += nn.style_loss(gpu_psd_target_dst_style_masked, gpu_target_dst_style_masked, gaussian_blur_radius=resolution//16, loss_weight=10000*face_style_power)

                if bg_style_power != 0 and not self.pretrain:
                    gpu_src_loss += tf.reduce_mean((10*bg_style_power)*nn.dssim(gpu_psd_target_dst_style_anti_masked, gpu_target_dst_style_anti_masked, max_val=1.0, filter_size=ssim_filter_size), axis=[1])
                    gpu_src_loss += tf.reduce_mean((10*bg_style_power)*tf.square(gpu_psd_target_dst_style_anti_masked - gpu_target_dst_style_anti_masked), axis=[1,2,3])

                # --- dst reconstruction loss ---
                if resolution < 256:
                    gpu_dst_loss =  tf.reduce_mean(10*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=ssim_filter_size), axis=[1])
                else:
                    gpu_dst_loss =  tf.reduce_mean(5*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=ssim_filter_size), axis=[1])
                    gpu_dst_loss += tf.reduce_mean(5*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=ssim_filter_size_small), axis=[1])
                gpu_dst_loss += tf.reduce_mean(10*tf.square(gpu_target_dst_masked_opt - gpu_pred_dst_dst_masked_opt), axis=[1,2,3])

                if eyes_prio:
                    gpu_dst_loss += tf.reduce_mean(300*tf.abs(gpu_target_dst*gpu_target_dstm_eyes - gpu_pred_dst_dst*gpu_target_dstm_eyes), axis=[1,2,3])

                gpu_dst_loss += tf.reduce_mean(10*tf.square(gpu_target_dstm - gpu_pred_dst_dstm), axis=[1,2,3])

                gpu_src_losses += [gpu_src_loss]
                gpu_dst_losses += [gpu_dst_loss]

                gpu_G_loss = gpu_src_loss + gpu_dst_loss

                if true_face_power != 0:
                    # Code discriminator: labels dst codes 1 and src codes 0; the
                    # generator is pushed to make src codes score as 1.
                    gpu_src_code_d = self.code_discriminator(gpu_src_code)
                    gpu_src_code_d_ones  = tf.ones_like (gpu_src_code_d)
                    gpu_src_code_d_zeros = tf.zeros_like(gpu_src_code_d)
                    gpu_dst_code_d = self.code_discriminator(gpu_dst_code)
                    gpu_dst_code_d_ones = tf.ones_like(gpu_dst_code_d)

                    gpu_G_loss += true_face_power*DLoss(gpu_src_code_d_ones, gpu_src_code_d)

                    gpu_D_code_loss = (DLoss(gpu_dst_code_d_ones , gpu_dst_code_d) + \
                                       DLoss(gpu_src_code_d_zeros, gpu_src_code_d) ) * 0.5

                    gpu_D_code_loss_gvs += [ nn.gradients(gpu_D_code_loss, self.code_discriminator.get_weights()) ]

                if gan_power != 0:
                    # D_src returns two outputs; both get real/fake terms.
                    gpu_pred_src_src_d, \
                    gpu_pred_src_src_d2 = self.D_src(gpu_pred_src_src_masked_opt)

                    gpu_pred_src_src_d_ones   = tf.ones_like (gpu_pred_src_src_d)
                    gpu_pred_src_src_d_zeros  = tf.zeros_like(gpu_pred_src_src_d)
                    gpu_pred_src_src_d2_ones  = tf.ones_like (gpu_pred_src_src_d2)
                    gpu_pred_src_src_d2_zeros = tf.zeros_like(gpu_pred_src_src_d2)

                    gpu_target_src_d, \
                    gpu_target_src_d2 = self.D_src(gpu_target_src_masked_opt)

                    gpu_target_src_d_ones  = tf.ones_like(gpu_target_src_d)
                    gpu_target_src_d2_ones = tf.ones_like(gpu_target_src_d2)

                    gpu_D_src_dst_loss = (DLoss(gpu_target_src_d_ones    , gpu_target_src_d) + \
                                          DLoss(gpu_pred_src_src_d_zeros , gpu_pred_src_src_d) ) * 0.5 + \
                                         (DLoss(gpu_target_src_d2_ones   , gpu_target_src_d2) + \
                                          DLoss(gpu_pred_src_src_d2_zeros, gpu_pred_src_src_d2) ) * 0.5

                    gpu_D_src_dst_loss_gvs += [ nn.gradients(gpu_D_src_dst_loss, self.D_src.get_weights()) ]

                    gpu_G_loss += gan_power*(DLoss(gpu_pred_src_src_d_ones,  gpu_pred_src_src_d) + \
                                             DLoss(gpu_pred_src_src_d2_ones, gpu_pred_src_src_d2))

                gpu_G_loss_gvs += [ nn.gradients(gpu_G_loss, self.src_dst_trainable_weights) ]

        # Average losses and gradients, and create optimizer update ops
        with tf.device(models_opt_device):
            pred_src_src  = nn.concat(gpu_pred_src_src_list, 0)
            pred_dst_dst  = nn.concat(gpu_pred_dst_dst_list, 0)
            pred_src_dst  = nn.concat(gpu_pred_src_dst_list, 0)
            pred_src_srcm = nn.concat(gpu_pred_src_srcm_list, 0)
            pred_dst_dstm = nn.concat(gpu_pred_dst_dstm_list, 0)
            pred_src_dstm = nn.concat(gpu_pred_src_dstm_list, 0)

            src_loss = tf.concat(gpu_src_losses, 0)
            dst_loss = tf.concat(gpu_dst_losses, 0)
            src_dst_loss_gv_op = self.src_dst_opt.get_update_op(nn.average_gv_list(gpu_G_loss_gvs))

            if true_face_power != 0:
                D_loss_gv_op = self.D_code_opt.get_update_op(nn.average_gv_list(gpu_D_code_loss_gvs))

            if gan_power != 0:
                src_D_src_dst_loss_gv_op = self.D_src_dst_opt.get_update_op(nn.average_gv_list(gpu_D_src_dst_loss_gvs))

        # Initializing training and view functions
        def src_dst_train(warped_src, target_src, target_srcm_all, \
                          warped_dst, target_dst, target_dstm_all):
            # One generator step; returns the per-sample src and dst losses.
            s, d, _ = nn.tf_sess.run([src_loss, dst_loss, src_dst_loss_gv_op],
                                     feed_dict={self.warped_src      : warped_src,
                                                self.target_src      : target_src,
                                                self.target_srcm_all : target_srcm_all,
                                                self.warped_dst      : warped_dst,
                                                self.target_dst      : target_dst,
                                                self.target_dstm_all : target_dstm_all,
                                                })
            return s, d
        self.src_dst_train = src_dst_train

        if true_face_power != 0:
            def D_train(warped_src, warped_dst):
                nn.tf_sess.run([D_loss_gv_op], feed_dict={self.warped_src: warped_src, self.warped_dst: warped_dst})
            self.D_train = D_train

        if gan_power != 0:
            def D_src_dst_train(warped_src, target_src, target_srcm_all, \
                                warped_dst, target_dst, target_dstm_all):
                nn.tf_sess.run([src_D_src_dst_loss_gv_op], feed_dict={self.warped_src      : warped_src,
                                                                       self.target_src      : target_src,
                                                                       self.target_srcm_all : target_srcm_all,
                                                                       self.warped_dst      : warped_dst,
                                                                       self.target_dst      : target_dst,
                                                                       self.target_dstm_all : target_dstm_all})
            self.D_src_dst_train = D_src_dst_train

        def AE_view(warped_src, warped_dst):
            return nn.tf_sess.run([pred_src_src, pred_dst_dst, pred_dst_dstm, pred_src_dst, pred_src_dstm],
                                  feed_dict={self.warped_src: warped_src,
                                             self.warped_dst: warped_dst})
        self.AE_view = AE_view
    else:
        # Initializing merge function
        with tf.device('/GPU:0' if len(devices) != 0 else '/CPU:0'):
            if 'df' in archi_type:
                gpu_dst_code = self.inter(self.encoder(self.warped_dst))
                gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code)
                _, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)

            elif 'liae' in archi_type:
                gpu_dst_code = self.encoder(self.warped_dst)
                gpu_dst_inter_B_code = self.inter_B(gpu_dst_code)
                gpu_dst_inter_AB_code = self.inter_AB(gpu_dst_code)
                gpu_dst_code = tf.concat([gpu_dst_inter_B_code, gpu_dst_inter_AB_code], nn.conv2d_ch_axis)
                gpu_src_dst_code = tf.concat([gpu_dst_inter_AB_code, gpu_dst_inter_AB_code], nn.conv2d_ch_axis)

                gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)
                _, gpu_pred_dst_dstm = self.decoder(gpu_dst_code)

        def AE_merge(warped_dst):
            return nn.tf_sess.run([gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm], feed_dict={self.warped_dst: warped_dst})

        self.AE_merge = AE_merge

    # Loading/initializing all models/optimizers weights
    for model, filename in io.progress_bar_generator(self.model_filename_list, "Initializing models"):
        if self.pretrain_just_disabled:
            # Right after pretraining, only the inter network(s) are re-initialized.
            do_init = False
            if 'df' in archi_type:
                if model == self.inter:
                    do_init = True
            elif 'liae' in archi_type:
                if model == self.inter_AB or model == self.inter_B:
                    do_init = True
        else:
            do_init = self.is_first_run()

        if not do_init:
            # Fall back to fresh init if the stored weights could not be loaded.
            do_init = not model.load_weights(self.get_strpath_storage_for_file(filename))

        if do_init:
            model.init_weights()

    # initializing sample generators
    if self.is_training:
        training_data_src_path = self.training_data_src_path if not self.pretrain else self.get_pretraining_data_path()
        training_data_dst_path = self.training_data_dst_path if not self.pretrain else self.get_pretraining_data_path()

        random_ct_samples_path = training_data_dst_path if ct_mode is not None and not self.pretrain else None

        cpu_count = min(multiprocessing.cpu_count(), 8)
        src_generators_count = cpu_count // 2
        dst_generators_count = cpu_count // 2
        if ct_mode is not None:
            src_generators_count = int(src_generators_count * 1.5)

        self.set_training_data_generators([
                SampleGeneratorFace(training_data_src_path, random_ct_samples_path=random_ct_samples_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
                    sample_process_options=SampleProcessor.Options(random_flip=self.random_flip),
                    output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                            {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp':False      , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                            {'sample_type': SampleProcessor.SampleType.FACE_MASK,  'warp':False      , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G,   'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE_EYES, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                          ],
                    uniform_yaw_distribution=self.options['uniform_yaw'] or self.pretrain,
                    generators_count=src_generators_count ),

                SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
                    sample_process_options=SampleProcessor.Options(random_flip=self.random_flip),
                    output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                            {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp':False      , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                            {'sample_type': SampleProcessor.SampleType.FACE_MASK,  'warp':False      , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G,   'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE_EYES, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                          ],
                    uniform_yaw_distribution=self.options['uniform_yaw'] or self.pretrain,
                    generators_count=dst_generators_count )
                             ])

        self.last_src_samples_loss = []
        self.last_dst_samples_loss = []

        if self.pretrain_just_disabled:
            self.update_sample_for_preview(force_new=True)
def on_initialize(self): device_config = nn.getCurrentDeviceConfig() devices = device_config.devices self.model_data_format = "NCHW" if len(devices) != 0 and not self.is_debug() else "NHWC" nn.initialize(data_format=self.model_data_format) tf = nn.tf self.resolution = resolution = self.options['resolution'] self.face_type = {'h' : FaceType.HALF, 'mf' : FaceType.MID_FULL, 'f' : FaceType.FULL, 'wf' : FaceType.WHOLE_FACE}[ self.options['face_type'] ] eyes_prio = self.options['eyes_prio'] archi = self.options['archi'] is_hd = 'hd' in archi ae_dims = self.options['ae_dims'] e_dims = self.options['e_dims'] d_dims = self.options['d_dims'] d_mask_dims = self.options['d_mask_dims'] self.pretrain = self.options['pretrain'] if self.pretrain_just_disabled: self.set_iter(0) self.gan_power = gan_power = self.options['gan_power'] if not self.pretrain else 0.0 masked_training = self.options['masked_training'] ct_mode = self.options['ct_mode'] if ct_mode == 'none': ct_mode = None models_opt_on_gpu = False if len(devices) == 0 else self.options['models_opt_on_gpu'] models_opt_device = '/GPU:0' if models_opt_on_gpu and self.is_training else '/CPU:0' optimizer_vars_on_cpu = models_opt_device=='/CPU:0' input_ch=3 bgr_shape = nn.get4Dshape(resolution,resolution,input_ch) mask_shape = nn.get4Dshape(resolution,resolution,1) self.model_filename_list = [] with tf.device ('/CPU:0'): #Place holders on CPU self.warped_src = tf.placeholder (nn.floatx, bgr_shape) self.warped_dst = tf.placeholder (nn.floatx, bgr_shape) self.src_code_in = tf.placeholder (nn.floatx, (None,256) ) self.target_src = tf.placeholder (nn.floatx, bgr_shape) self.target_dst = tf.placeholder (nn.floatx, bgr_shape) self.target_srcm_all = tf.placeholder (nn.floatx, mask_shape) self.target_dstm_all = tf.placeholder (nn.floatx, mask_shape) # Initializing model classes model_archi = nn.DeepFakeArchi(resolution, mod='uhd' if 'uhd' in archi else None) with tf.device (models_opt_device): if 'df' in archi: self.encoder = 
model_archi.Encoder(in_ch=input_ch, e_ch=e_dims, is_hd=is_hd, name='encoder') encoder_out_ch = self.encoder.compute_output_channels ( (nn.floatx, bgr_shape)) self.inter = model_archi.Inter (in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims, is_hd=is_hd, name='inter') inter_out_ch = self.inter.compute_output_channels ( (nn.floatx, (None,encoder_out_ch))) self.decoder_src = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, is_hd=is_hd, name='decoder_src') self.decoder_dst = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, is_hd=is_hd, name='decoder_dst') self.model_filename_list += [ [self.encoder, 'encoder.npy' ], [self.inter, 'inter.npy' ], [self.decoder_src, 'decoder_src.npy'], [self.decoder_dst, 'decoder_dst.npy'] ] if self.is_training: if self.options['true_face_power'] != 0: self.code_discriminator = nn.CodeDiscriminator(ae_dims, code_res=model_archi.Inter.get_code_res()*2, name='dis' ) self.model_filename_list += [ [self.code_discriminator, 'code_discriminator.npy'] ] elif 'liae' in archi: self.encoder = model_archi.Encoder(in_ch=input_ch, e_ch=e_dims, is_hd=is_hd, name='encoder') encoder_out_ch = self.encoder.compute_output_channels ( (nn.floatx, bgr_shape)) self.inter_AB = model_archi.Inter(in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims*2, is_hd=is_hd, name='inter_AB') self.inter_B = model_archi.Inter(in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims*2, is_hd=is_hd, name='inter_B') inter_AB_out_ch = self.inter_AB.compute_output_channels ( (nn.floatx, (None,encoder_out_ch))) inter_B_out_ch = self.inter_B.compute_output_channels ( (nn.floatx, (None,encoder_out_ch))) inters_out_ch = inter_AB_out_ch+inter_B_out_ch self.decoder = model_archi.Decoder(in_ch=inters_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, is_hd=is_hd, name='decoder') self.model_filename_list += [ [self.encoder, 'encoder.npy'], [self.inter_AB, 'inter_AB.npy'], [self.inter_B , 'inter_B.npy'], [self.decoder , 'decoder.npy'] ] if 
self.is_training: if gan_power != 0: self.D_src = nn.PatchDiscriminator(patch_size=resolution//16, in_ch=input_ch, name="D_src") self.D_dst = nn.PatchDiscriminator(patch_size=resolution//16, in_ch=input_ch, name="D_dst") self.model_filename_list += [ [self.D_src, 'D_src.npy'] ] self.model_filename_list += [ [self.D_dst, 'D_dst.npy'] ] # Initialize optimizers lr=5e-5 lr_dropout = 0.3 if self.options['lr_dropout'] and not self.pretrain else 1.0 clipnorm = 1.0 if self.options['clipgrad'] else 0.0 self.src_dst_opt = nn.RMSprop(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='src_dst_opt') self.model_filename_list += [ (self.src_dst_opt, 'src_dst_opt.npy') ] if 'df' in archi: self.src_dst_trainable_weights = self.encoder.get_weights() + self.inter.get_weights() + self.decoder_src.get_weights() + self.decoder_dst.get_weights() elif 'liae' in archi: self.src_dst_trainable_weights = self.encoder.get_weights() + self.inter_AB.get_weights() + self.inter_B.get_weights() + self.decoder.get_weights() self.src_dst_opt.initialize_variables (self.src_dst_trainable_weights, vars_on_cpu=optimizer_vars_on_cpu) if self.options['true_face_power'] != 0: self.D_code_opt = nn.RMSprop(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='D_code_opt') self.D_code_opt.initialize_variables ( self.code_discriminator.get_weights(), vars_on_cpu=optimizer_vars_on_cpu) self.model_filename_list += [ (self.D_code_opt, 'D_code_opt.npy') ] if gan_power != 0: self.D_src_dst_opt = nn.RMSprop(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='D_src_dst_opt') self.D_src_dst_opt.initialize_variables ( self.D_src.get_weights()+self.D_dst.get_weights(), vars_on_cpu=optimizer_vars_on_cpu) self.model_filename_list += [ (self.D_src_dst_opt, 'D_src_dst_opt.npy') ] if self.is_training: # Adjust batch size for multiple GPU gpu_count = max(1, len(devices) ) bs_per_gpu = max(1, self.get_batch_size() // gpu_count) self.set_batch_size( gpu_count*bs_per_gpu) # Compute losses per GPU gpu_src_latent_code_list 
= [] gpu_pred_src_src_list = [] gpu_pred_dst_dst_list = [] gpu_pred_src_dst_list = [] gpu_pred_src_srcm_list = [] gpu_pred_dst_dstm_list = [] gpu_pred_src_dstm_list = [] gpu_src_losses = [] gpu_dst_losses = [] gpu_G_loss_gvs = [] gpu_D_code_loss_gvs = [] gpu_D_src_dst_loss_gvs = [] for gpu_id in range(gpu_count): with tf.device( f'/GPU:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ): with tf.device(f'/CPU:0'): # slice on CPU, otherwise all batch data will be transfered to GPU first batch_slice = slice( gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu ) gpu_warped_src = self.warped_src [batch_slice,:,:,:] gpu_src_code_in = self.src_code_in[batch_slice,:] gpu_warped_dst = self.warped_dst [batch_slice,:,:,:] gpu_target_src = self.target_src [batch_slice,:,:,:] gpu_target_dst = self.target_dst [batch_slice,:,:,:] gpu_target_srcm_all = self.target_srcm_all[batch_slice,:,:,:] gpu_target_dstm_all = self.target_dstm_all[batch_slice,:,:,:] # process model tensors if 'df' in archi: gpu_src_latent_code = self.inter.dense1(self.encoder(gpu_warped_src)) gpu_src_in_code = self.inter.fd(gpu_src_code_in) gpu_src_code = self.inter(self.encoder(gpu_warped_src)) gpu_dst_code = self.inter(self.encoder(gpu_warped_dst)) gpu_pred_src_src, gpu_pred_src_srcm = self.decoder_src(gpu_src_in_code) #gpu_pred_src_src, gpu_pred_src_srcm = self.decoder_src(gpu_src_code) gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code) gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code) elif 'liae' in archi: gpu_src_code = self.encoder (gpu_warped_src) gpu_src_inter_AB_code = self.inter_AB (gpu_src_code) gpu_src_code = tf.concat([gpu_src_inter_AB_code,gpu_src_inter_AB_code], nn.conv2d_ch_axis ) gpu_dst_code = self.encoder (gpu_warped_dst) gpu_dst_inter_B_code = self.inter_B (gpu_dst_code) gpu_dst_inter_AB_code = self.inter_AB (gpu_dst_code) gpu_dst_code = tf.concat([gpu_dst_inter_B_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis ) gpu_src_dst_code = 
tf.concat([gpu_dst_inter_AB_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis ) gpu_pred_src_src, gpu_pred_src_srcm = self.decoder(gpu_src_code) gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder(gpu_dst_code) gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code) gpu_src_latent_code_list.append(gpu_src_latent_code) gpu_pred_src_src_list.append(gpu_pred_src_src) gpu_pred_dst_dst_list.append(gpu_pred_dst_dst) gpu_pred_src_dst_list.append(gpu_pred_src_dst) gpu_pred_src_srcm_list.append(gpu_pred_src_srcm) gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm) gpu_pred_src_dstm_list.append(gpu_pred_src_dstm) # unpack masks from one combined mask gpu_target_srcm = tf.clip_by_value (gpu_target_srcm_all, 0, 1) gpu_target_dstm = tf.clip_by_value (gpu_target_dstm_all, 0, 1) gpu_target_srcm_eyes = tf.clip_by_value (gpu_target_srcm_all-1, 0, 1) gpu_target_dstm_eyes = tf.clip_by_value (gpu_target_dstm_all-1, 0, 1) gpu_target_srcm_blur = nn.gaussian_blur(gpu_target_srcm, max(1, resolution // 32) ) gpu_target_dstm_blur = nn.gaussian_blur(gpu_target_dstm, max(1, resolution // 32) ) gpu_target_dst_masked = gpu_target_dst*gpu_target_dstm_blur gpu_target_dst_anti_masked = gpu_target_dst*(1.0 - gpu_target_dstm_blur) gpu_target_src_masked_opt = gpu_target_src*gpu_target_srcm_blur if masked_training else gpu_target_src gpu_target_dst_masked_opt = gpu_target_dst_masked if masked_training else gpu_target_dst gpu_pred_src_src_masked_opt = gpu_pred_src_src*gpu_target_srcm_blur if masked_training else gpu_pred_src_src gpu_pred_dst_dst_masked_opt = gpu_pred_dst_dst*gpu_target_dstm_blur if masked_training else gpu_pred_dst_dst gpu_psd_target_dst_masked = gpu_pred_src_dst*gpu_target_dstm_blur gpu_psd_target_dst_anti_masked = gpu_pred_src_dst*(1.0 - gpu_target_dstm_blur) gpu_src_loss = tf.reduce_mean ( 10*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1]) gpu_src_loss += tf.reduce_mean ( 10*tf.square ( 
gpu_target_src_masked_opt - gpu_pred_src_src_masked_opt ), axis=[1,2,3]) if eyes_prio: gpu_src_loss += tf.reduce_mean ( 300*tf.abs ( gpu_target_src*gpu_target_srcm_eyes - gpu_pred_src_src*gpu_target_srcm_eyes ), axis=[1,2,3]) gpu_src_loss += tf.reduce_mean ( 10*tf.square( gpu_target_srcm - gpu_pred_src_srcm ),axis=[1,2,3] ) face_style_power = self.options['face_style_power'] / 100.0 if face_style_power != 0 and not self.pretrain: gpu_src_loss += nn.style_loss(gpu_psd_target_dst_masked, gpu_target_dst_masked, gaussian_blur_radius=resolution//16, loss_weight=10000*face_style_power) bg_style_power = self.options['bg_style_power'] / 100.0 if bg_style_power != 0 and not self.pretrain: gpu_src_loss += tf.reduce_mean( (10*bg_style_power)*nn.dssim(gpu_psd_target_dst_anti_masked, gpu_target_dst_anti_masked, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1]) gpu_src_loss += tf.reduce_mean( (10*bg_style_power)*tf.square( gpu_psd_target_dst_anti_masked - gpu_target_dst_anti_masked), axis=[1,2,3] ) gpu_dst_loss = tf.reduce_mean ( 10*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1]) gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dst_masked_opt- gpu_pred_dst_dst_masked_opt ), axis=[1,2,3]) if eyes_prio: gpu_dst_loss += tf.reduce_mean ( 300*tf.abs ( gpu_target_dst*gpu_target_dstm_eyes - gpu_pred_dst_dst*gpu_target_dstm_eyes ), axis=[1,2,3]) gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dstm - gpu_pred_dst_dstm ),axis=[1,2,3] ) gpu_src_losses += [gpu_src_loss] gpu_dst_losses += [gpu_dst_loss] gpu_G_loss = gpu_src_loss + gpu_dst_loss def DLoss(labels,logits): return tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits), axis=[1,2,3]) if self.options['true_face_power'] != 0: gpu_src_code_d = self.code_discriminator( gpu_src_code ) gpu_src_code_d_ones = tf.ones_like (gpu_src_code_d) gpu_src_code_d_zeros = tf.zeros_like(gpu_src_code_d) gpu_dst_code_d = 
self.code_discriminator( gpu_dst_code ) gpu_dst_code_d_ones = tf.ones_like(gpu_dst_code_d) gpu_G_loss += self.options['true_face_power']*DLoss(gpu_src_code_d_ones, gpu_src_code_d) gpu_D_code_loss = (DLoss(gpu_src_code_d_ones , gpu_dst_code_d) + \ DLoss(gpu_src_code_d_zeros, gpu_src_code_d) ) * 0.5 gpu_D_code_loss_gvs += [ nn.gradients (gpu_D_code_loss, self.code_discriminator.get_weights() ) ] if gan_power != 0: gpu_pred_src_src_d = self.D_src(gpu_pred_src_src_masked_opt) gpu_pred_src_src_d_ones = tf.ones_like (gpu_pred_src_src_d) gpu_pred_src_src_d_zeros = tf.zeros_like(gpu_pred_src_src_d) gpu_target_src_d = self.D_src(gpu_target_src_masked_opt) gpu_target_src_d_ones = tf.ones_like(gpu_target_src_d) gpu_pred_dst_dst_d = self.D_dst(gpu_pred_dst_dst_masked_opt) gpu_pred_dst_dst_d_ones = tf.ones_like (gpu_pred_dst_dst_d) gpu_pred_dst_dst_d_zeros = tf.zeros_like(gpu_pred_dst_dst_d) gpu_target_dst_d = self.D_dst(gpu_target_dst_masked_opt) gpu_target_dst_d_ones = tf.ones_like(gpu_target_dst_d) gpu_D_src_dst_loss = (DLoss(gpu_target_src_d_ones , gpu_target_src_d) + \ DLoss(gpu_pred_src_src_d_zeros, gpu_pred_src_src_d) ) * 0.5 + \ (DLoss(gpu_target_dst_d_ones , gpu_target_dst_d) + \ DLoss(gpu_pred_dst_dst_d_zeros, gpu_pred_dst_dst_d) ) * 0.5 gpu_D_src_dst_loss_gvs += [ nn.gradients (gpu_D_src_dst_loss, self.D_src.get_weights()+self.D_dst.get_weights() ) ] gpu_G_loss += gan_power*(DLoss(gpu_pred_src_src_d_ones, gpu_pred_src_src_d) + DLoss(gpu_pred_dst_dst_d_ones, gpu_pred_dst_dst_d)) gpu_G_loss_gvs += [ nn.gradients ( gpu_G_loss, self.src_dst_trainable_weights ) ] # Average losses and gradients, and create optimizer update ops with tf.device (models_opt_device): src_latent_code = nn.concat(gpu_src_latent_code_list, 0) pred_src_src = nn.concat(gpu_pred_src_src_list, 0) pred_dst_dst = nn.concat(gpu_pred_dst_dst_list, 0) pred_src_dst = nn.concat(gpu_pred_src_dst_list, 0) pred_src_srcm = nn.concat(gpu_pred_src_srcm_list, 0) pred_dst_dstm = nn.concat(gpu_pred_dst_dstm_list, 0) 
pred_src_dstm = nn.concat(gpu_pred_src_dstm_list, 0) src_loss = tf.concat(gpu_src_losses, 0) dst_loss = tf.concat(gpu_dst_losses, 0) src_dst_loss_gv_op = self.src_dst_opt.get_update_op (nn.average_gv_list (gpu_G_loss_gvs)) if self.options['true_face_power'] != 0: D_loss_gv_op = self.D_code_opt.get_update_op (nn.average_gv_list(gpu_D_code_loss_gvs)) if gan_power != 0: src_D_src_dst_loss_gv_op = self.D_src_dst_opt.get_update_op (nn.average_gv_list(gpu_D_src_dst_loss_gvs) ) # Initializing training and view functions def src_dst_train(warped_src, target_src, target_srcm_all, \ warped_dst, target_dst, target_dstm_all): s, d, _ = nn.tf_sess.run ( [ src_loss, dst_loss, src_dst_loss_gv_op], feed_dict={self.warped_src :warped_src, self.target_src :target_src, self.target_srcm_all:target_srcm_all, self.warped_dst :warped_dst, self.target_dst :target_dst, self.target_dstm_all:target_dstm_all, }) return s, d self.src_dst_train = src_dst_train if self.options['true_face_power'] != 0: def D_train(warped_src, warped_dst): nn.tf_sess.run ([D_loss_gv_op], feed_dict={self.warped_src: warped_src, self.warped_dst: warped_dst}) self.D_train = D_train if gan_power != 0: def D_src_dst_train(warped_src, target_src, target_srcm_all, \ warped_dst, target_dst, target_dstm_all): nn.tf_sess.run ([src_D_src_dst_loss_gv_op], feed_dict={self.warped_src :warped_src, self.target_src :target_src, self.target_srcm_all:target_srcm_all, self.warped_dst :warped_dst, self.target_dst :target_dst, self.target_dstm_all:target_dstm_all}) self.D_src_dst_train = D_src_dst_train def AE_get_latent(warped_src): return nn.tf_sess.run ( src_latent_code, feed_dict={self.warped_src:warped_src}) self.AE_get_latent = AE_get_latent def AE_view_src(warped_src, src_code_in): return nn.tf_sess.run ( pred_src_src, feed_dict={self.warped_src:warped_src, self.src_code_in:src_code_in }) self.AE_view_src = AE_view_src def AE_view(warped_src, warped_dst): return nn.tf_sess.run ( [pred_src_src, pred_dst_dst, pred_dst_dstm, 
pred_src_dst, pred_src_dstm], feed_dict={self.warped_src:warped_src, self.warped_dst:warped_dst}) self.AE_view = AE_view else: # Initializing merge function with tf.device( f'/GPU:0' if len(devices) != 0 else f'/CPU:0'): if 'df' in archi: gpu_dst_code = self.inter(self.encoder(self.warped_dst)) gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code) _, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code) elif 'liae' in archi: gpu_dst_code = self.encoder (self.warped_dst) gpu_dst_inter_B_code = self.inter_B (gpu_dst_code) gpu_dst_inter_AB_code = self.inter_AB (gpu_dst_code) gpu_dst_code = tf.concat([gpu_dst_inter_B_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis) gpu_src_dst_code = tf.concat([gpu_dst_inter_AB_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis) gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code) _, gpu_pred_dst_dstm = self.decoder(gpu_dst_code) def AE_merge( warped_dst): return nn.tf_sess.run ( [gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm], feed_dict={self.warped_dst:warped_dst}) self.AE_merge = AE_merge # Loading/initializing all models/optimizers weights for model, filename in io.progress_bar_generator(self.model_filename_list, "Initializing models"): if self.pretrain_just_disabled: do_init = False if 'df' in archi: if model == self.inter: do_init = True elif 'liae' in archi: if model == self.inter_AB: do_init = True else: do_init = self.is_first_run() if not do_init: do_init = not model.load_weights( self.get_strpath_storage_for_file(filename) ) if do_init: model.init_weights() # initializing sample generators if self.is_training: training_data_src_path = self.training_data_src_path if not self.pretrain else self.get_pretraining_data_path() training_data_dst_path = self.training_data_dst_path if not self.pretrain else self.get_pretraining_data_path() random_ct_samples_path=training_data_dst_path if ct_mode is not None and not self.pretrain else None cpu_count = min(multiprocessing.cpu_count(), 8) 
src_generators_count = cpu_count // 2 dst_generators_count = cpu_count // 2 if ct_mode is not None: src_generators_count = int(src_generators_count * 1.5) self.set_training_data_generators ([ SampleGeneratorFace(training_data_src_path, random_ct_samples_path=random_ct_samples_path, debug=self.is_debug(), batch_size=self.get_batch_size(), sample_process_options=SampleProcessor.Options(random_flip=self.random_flip), output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':self.options['random_warp'], 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':False, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE_EYES, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, {'sample_type': SampleProcessor.SampleType.PITCH_YAW_ROLL, 'resolution': resolution}, ], generators_count=src_generators_count ), SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(), sample_process_options=SampleProcessor.Options(random_flip=self.random_flip), output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':self.options['random_warp'], 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 
'data_format':nn.data_format, 'resolution': resolution}, {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE_EYES, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, ], generators_count=dst_generators_count ) ]) self.last_src_samples_loss = [] self.last_dst_samples_loss = [] if self.pretrain_just_disabled: self.update_sample_for_preview(force_new=True) class PRD(nn.ModelBase): def on_build(self, ae_ch): self.dense1 = nn.Dense( ae_ch+1, 1024 ) self.dense2 = nn.Dense( 1024, 2048 ) self.dense3 = nn.Dense( 2048, 4096 ) self.dense4 = nn.Dense( 4096, 4096 ) self.dense5 = nn.Dense( 4096, ae_ch ) def forward(self, inp, yaw_in): x = tf.concat( [inp, yaw_in], -1 ) x = self.dense1(x) x = self.dense2(x) x = self.dense3(x) x = self.dense4(x) x = self.dense5(x) return x with tf.device( f'/GPU:0'): prd_model = PRD(256, name='PRD') prd_model.init_weights() prd_in = tf.placeholder (nn.floatx, (None,256) ) prd_targ = tf.placeholder (nn.floatx, (None,256) ) yaw_diff_in = tf.placeholder (nn.floatx, (None,1) ) prd_out = prd_model(prd_in, yaw_diff_in) loss = tf.reduce_sum ( tf.abs (prd_out - prd_targ) ) loss_gvs = nn.gradients (loss, prd_model.get_weights() ) prd_opt = nn.RMSprop(lr=5e-6, lr_dropout=0.3, name='prd_opt') prd_opt.initialize_variables(prd_model.get_weights()) prd_opt.init_weights() loss_gv_op = prd_opt.get_update_op (loss_gvs) s_gen, _ = self.get_training_data_generators() bs = self.get_batch_size() for n in range(1000): warped_src, target_src, target_srcm_all, src_pyr = s_gen.generate_next() sl = self.AE_get_latent(target_src) prd_in_np = [] prd_targ_np = [] yaw_diff_in_np = [] for i in range(bs): prd_in_np += [sl[i]] j = i while j == i: j = np.random.randint(bs) prd_targ_np += [ sl[j] ] yaw_diff_in_np += [ np.float32( [ src_pyr[j][1]-src_pyr[i][1] ] ) ] prd_loss, _ = nn.tf_sess.run([loss, 
loss_gv_op], feed_dict={prd_in:prd_in_np, prd_targ:prd_targ_np, yaw_diff_in:yaw_diff_in_np} ) print(f'{n} loss = {prd_loss}') warped_src, target_src, target_srcm_all, src_pyr = s_gen.generate_next() sl = self.AE_get_latent(target_src) yaw_diff_in_np = np.float32( [ [-0.4] ] *bs ) new_sl = nn.tf_sess.run(prd_out, feed_dict={prd_in:sl, yaw_diff_in:yaw_diff_in_np} ) new_target_src = self.AE_view_src( target_src, new_sl ) target_src = np.clip( nn.to_data_format( target_src ,"NHWC", self.model_data_format), 0.0, 1.0) new_target_src = np.clip( nn.to_data_format( new_target_src ,"NHWC", self.model_data_format), 0.0, 1.0) for i in range(bs): screen = np.concatenate ( (target_src[i], new_target_src[i]), 1 ) cv2.imshow("", (screen*255).astype(np.uint8) ) cv2.waitKey(0) import code code.interact(local=dict(globals(), **locals()))