def on_initialize(self):
    """Build the model graph, restore weights and create the sample generators.

    Steps performed, in order:

    1. Initialise ``nn`` in NCHW layout and define the network building
       blocks (``Downscale``, ``Upscale``, ``ResidualBlock``, ``Encoder``,
       ``Inter``, ``Decoder``) as local classes closing over ``tf`` and
       ``lowest_dense_res``.
    2. Read the training options from ``self.options``. While
       ``self.pretrain`` is set, GAN power and random warp are forced off
       and both random flips are forced on.
    3. Create the input placeholders on the CPU and instantiate the
       encoder, the two inters (``inter_src`` / ``inter_dst``), the decoder
       and, when training, the optimizer(s) and the optional GAN
       discriminator.
    4. When training: build one tower per device, accumulate losses and
       gradients, and expose ``self.src_dst_train``, ``self.AE_view`` and
       (if ``gan_power != 0``) ``self.D_src_dst_train``.
       Otherwise: build the inference graph and expose ``self.AE_merge``.
    5. Load (or freshly initialise) the weights of every entry in
       ``self.model_filename_list``.
    6. When training: register the src and dst sample generators.

    The method returns nothing; all results are stored on ``self``.
    """
    device_config = nn.getCurrentDeviceConfig()
    devices = device_config.devices
    self.model_data_format = "NCHW"
    nn.initialize(data_format=self.model_data_format)
    tf = nn.tf

    self.resolution = resolution = self.options['resolution']
    # The encoder applies five stride-2 convolutions, hence the factor of 32.
    lowest_dense_res = self.lowest_dense_res = resolution // 32

    class Downscale(nn.ModelBase):
        """Stride-2 convolution followed by leaky ReLU (slope 0.1)."""

        def __init__(self, in_ch, out_ch, kernel_size=5, *args, **kwargs):
            # The original signature used `*kwargs`, which only collects
            # positional arguments; keyword arguments for the base class
            # (e.g. name=...) raised TypeError. Both kinds are forwarded now.
            self.in_ch = in_ch
            self.out_ch = out_ch
            self.kernel_size = kernel_size
            super().__init__(*args, **kwargs)

        def on_build(self, *args, **kwargs):
            self.conv1 = nn.Conv2D(self.in_ch, self.out_ch, kernel_size=self.kernel_size, strides=2, padding='SAME')

        def forward(self, x):
            x = self.conv1(x)
            x = tf.nn.leaky_relu(x, 0.1)
            return x

        def get_out_ch(self):
            return self.out_ch

    class Upscale(nn.ModelBase):
        """Convolution to 4x the channels, leaky ReLU, then depth_to_space by 2."""

        def on_build(self, in_ch, out_ch, kernel_size=3):
            self.conv1 = nn.Conv2D(in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME')

        def forward(self, x):
            x = self.conv1(x)
            x = tf.nn.leaky_relu(x, 0.1)
            x = nn.depth_to_space(x, 2)
            return x

    class ResidualBlock(nn.ModelBase):
        """Two same-width convolutions with an additive skip connection."""

        def on_build(self, ch, kernel_size=3):
            self.conv1 = nn.Conv2D(ch, ch, kernel_size=kernel_size, padding='SAME')
            self.conv2 = nn.Conv2D(ch, ch, kernel_size=kernel_size, padding='SAME')

        def forward(self, inp):
            x = self.conv1(inp)
            x = tf.nn.leaky_relu(x, 0.2)
            x = self.conv2(x)
            x = tf.nn.leaky_relu(inp+x, 0.2)
            return x

    class Encoder(nn.ModelBase):
        """Five Downscale stages, flatten, pixel norm, then a dense layer to ae_ch."""

        def on_build(self, in_ch, e_ch, ae_ch):
            self.down1 = Downscale(in_ch, e_ch, kernel_size=5)
            self.res1 = ResidualBlock(e_ch)
            self.down2 = Downscale(e_ch, e_ch*2, kernel_size=5)
            self.down3 = Downscale(e_ch*2, e_ch*4, kernel_size=5)
            self.down4 = Downscale(e_ch*4, e_ch*8, kernel_size=5)
            self.down5 = Downscale(e_ch*8, e_ch*8, kernel_size=5)
            self.res5 = ResidualBlock(e_ch*8)
            self.dense1 = nn.Dense(lowest_dense_res*lowest_dense_res*e_ch*8, ae_ch)

        def forward(self, inp):
            x = inp
            x = self.down1(x)
            x = self.res1(x)
            x = self.down2(x)
            x = self.down3(x)
            x = self.down4(x)
            x = self.down5(x)
            x = self.res5(x)
            x = nn.flatten(x)
            x = nn.pixel_norm(x, axes=-1)
            x = self.dense1(x)
            return x

    class Inter(nn.ModelBase):
        """Dense layer that expands the code and reshapes it to a 4D tensor."""

        def __init__(self, ae_ch, ae_out_ch, **kwargs):
            self.ae_ch, self.ae_out_ch = ae_ch, ae_out_ch
            super().__init__(**kwargs)

        def on_build(self):
            ae_ch, ae_out_ch = self.ae_ch, self.ae_out_ch
            self.dense2 = nn.Dense(ae_ch, lowest_dense_res * lowest_dense_res * ae_out_ch)

        def forward(self, inp):
            x = inp
            x = self.dense2(x)
            x = nn.reshape_4D(x, lowest_dense_res, lowest_dense_res, self.ae_out_ch)
            return x

        def get_out_ch(self):
            return self.ae_out_ch

    class Decoder(nn.ModelBase):
        """Decode a code tensor into an image and a single-channel mask."""

        def on_build(self, in_ch, d_ch, d_mask_ch):
            # Image branch.
            self.upscale0 = Upscale(in_ch, d_ch*8, kernel_size=3)
            self.upscale1 = Upscale(d_ch*8, d_ch*8, kernel_size=3)
            self.upscale2 = Upscale(d_ch*8, d_ch*4, kernel_size=3)
            self.upscale3 = Upscale(d_ch*4, d_ch*2, kernel_size=3)

            self.res0 = ResidualBlock(d_ch*8, kernel_size=3)
            self.res1 = ResidualBlock(d_ch*8, kernel_size=3)
            self.res2 = ResidualBlock(d_ch*4, kernel_size=3)
            self.res3 = ResidualBlock(d_ch*2, kernel_size=3)

            # Mask branch.
            self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3)
            self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*8, kernel_size=3)
            self.upscalem2 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3)
            self.upscalem3 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3)
            self.upscalem4 = Upscale(d_mask_ch*2, d_mask_ch*1, kernel_size=3)
            self.out_convm = nn.Conv2D(d_mask_ch*1, 1, kernel_size=1, padding='SAME')

            # Four 3-channel heads; forward() concatenates them on the
            # channel axis and applies depth_to_space by 2.
            self.out_conv = nn.Conv2D(d_ch*2, 3, kernel_size=1, padding='SAME')
            self.out_conv1 = nn.Conv2D(d_ch*2, 3, kernel_size=3, padding='SAME')
            self.out_conv2 = nn.Conv2D(d_ch*2, 3, kernel_size=3, padding='SAME')
            self.out_conv3 = nn.Conv2D(d_ch*2, 3, kernel_size=3, padding='SAME')

        def forward(self, inp):
            z = inp

            x = self.upscale0(z)
            x = self.res0(x)
            x = self.upscale1(x)
            x = self.res1(x)
            x = self.upscale2(x)
            x = self.res2(x)
            x = self.upscale3(x)
            x = self.res3(x)

            x = tf.nn.sigmoid(nn.depth_to_space(tf.concat((self.out_conv(x),
                                                           self.out_conv1(x),
                                                           self.out_conv2(x),
                                                           self.out_conv3(x)), nn.conv2d_ch_axis), 2))

            m = self.upscalem0(z)
            m = self.upscalem1(m)
            m = self.upscalem2(m)
            m = self.upscalem3(m)
            m = self.upscalem4(m)
            m = tf.nn.sigmoid(self.out_convm(m))

            return x, m

    self.face_type = {'wf': FaceType.WHOLE_FACE,
                      'head': FaceType.HEAD}[self.options['face_type']]

    # Drop the 'eyes_prio' option if present; only 'eyes_mouth_prio' is read.
    if 'eyes_prio' in self.options:
        self.options.pop('eyes_prio')

    eyes_mouth_prio = self.options['eyes_mouth_prio']

    ae_dims = self.ae_dims = self.options['ae_dims']
    e_dims = self.options['e_dims']
    d_dims = self.options['d_dims']
    d_mask_dims = self.options['d_mask_dims']
    morph_factor = self.options['morph_factor']

    pretrain = self.pretrain = self.options['pretrain']
    if self.pretrain_just_disabled:
        self.set_iter(0)

    # Pretraining overrides: no GAN, no random warp, both flips enabled.
    self.gan_power = gan_power = 0.0 if self.pretrain else self.options['gan_power']
    random_warp = False if self.pretrain else self.options['random_warp']
    random_src_flip = self.random_src_flip if not self.pretrain else True
    random_dst_flip = self.random_dst_flip if not self.pretrain else True

    if self.pretrain:
        self.options_show_override['gan_power'] = 0.0
        self.options_show_override['random_warp'] = False
        self.options_show_override['lr_dropout'] = 'n'
        self.options_show_override['uniform_yaw'] = True

    masked_training = self.options['masked_training']
    ct_mode = self.options['ct_mode']
    if ct_mode == 'none':
        ct_mode = None

    models_opt_on_gpu = False if len(devices) == 0 else self.options['models_opt_on_gpu']
    models_opt_device = nn.tf_default_device_name if models_opt_on_gpu and self.is_training else '/CPU:0'
    optimizer_vars_on_cpu = models_opt_device == '/CPU:0'

    input_ch = 3
    bgr_shape = self.bgr_shape = nn.get4Dshape(resolution, resolution, input_ch)
    mask_shape = nn.get4Dshape(resolution, resolution, 1)

    # Each entry is a (model, filename) pair consumed by the weight loader below.
    self.model_filename_list = []

    with tf.device('/CPU:0'):
        # Placeholders on CPU
        self.warped_src = tf.placeholder(nn.floatx, bgr_shape, name='warped_src')
        self.warped_dst = tf.placeholder(nn.floatx, bgr_shape, name='warped_dst')

        self.target_src = tf.placeholder(nn.floatx, bgr_shape, name='target_src')
        self.target_dst = tf.placeholder(nn.floatx, bgr_shape, name='target_dst')

        self.target_srcm = tf.placeholder(nn.floatx, mask_shape, name='target_srcm')
        self.target_srcm_em = tf.placeholder(nn.floatx, mask_shape, name='target_srcm_em')
        self.target_dstm = tf.placeholder(nn.floatx, mask_shape, name='target_dstm')
        self.target_dstm_em = tf.placeholder(nn.floatx, mask_shape, name='target_dstm_em')

        self.morph_value_t = tf.placeholder(nn.floatx, (1,), name='morph_value_t')

    # Initializing model classes
    with tf.device(models_opt_device):
        self.encoder = Encoder(in_ch=input_ch, e_ch=e_dims, ae_ch=ae_dims, name='encoder')
        self.inter_src = Inter(ae_ch=ae_dims, ae_out_ch=ae_dims, name='inter_src')
        self.inter_dst = Inter(ae_ch=ae_dims, ae_out_ch=ae_dims, name='inter_dst')
        self.decoder = Decoder(in_ch=ae_dims, d_ch=d_dims, d_mask_ch=d_mask_dims, name='decoder')

        self.model_filename_list += [[self.encoder, 'encoder.npy'],
                                     [self.inter_src, 'inter_src.npy'],
                                     [self.inter_dst, 'inter_dst.npy'],
                                     [self.decoder, 'decoder.npy']]

        if self.is_training:
            if gan_power != 0:
                self.GAN = nn.UNetPatchDiscriminator(patch_size=self.options['gan_patch_size'], in_ch=input_ch, base_ch=self.options['gan_dims'], name="GAN")
                self.model_filename_list += [[self.GAN, 'GAN.npy']]

            # Initialize optimizers
            lr = 5e-5
            lr_dropout = 0.3 if self.options['lr_dropout'] in ['y', 'cpu'] and not self.pretrain else 1.0
            clipnorm = 1.0 if self.options['clipgrad'] else 0.0

            self.all_weights = self.encoder.get_weights() + self.inter_src.get_weights() + self.inter_dst.get_weights() + self.decoder.get_weights()

            # During pretraining inter_src is excluded from the gradient list.
            if pretrain:
                self.trainable_weights = self.encoder.get_weights() + self.inter_dst.get_weights() + self.decoder.get_weights()
            else:
                self.trainable_weights = self.encoder.get_weights() + self.inter_src.get_weights() + self.inter_dst.get_weights() + self.decoder.get_weights()

            self.src_dst_opt = nn.AdaBelief(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='src_dst_opt')
            self.src_dst_opt.initialize_variables(self.all_weights, vars_on_cpu=optimizer_vars_on_cpu, lr_dropout_on_cpu=self.options['lr_dropout'] == 'cpu')
            self.model_filename_list += [(self.src_dst_opt, 'src_dst_opt.npy')]

            if gan_power != 0:
                self.GAN_opt = nn.AdaBelief(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='GAN_opt')
                self.GAN_opt.initialize_variables(self.GAN.get_weights(), vars_on_cpu=optimizer_vars_on_cpu, lr_dropout_on_cpu=self.options['lr_dropout'] == 'cpu')
                self.model_filename_list += [(self.GAN_opt, 'GAN_opt.npy')]

    if self.is_training:
        # Adjust batch size for multiple GPU
        gpu_count = max(1, len(devices))
        bs_per_gpu = max(1, self.get_batch_size() // gpu_count)
        self.set_batch_size(gpu_count*bs_per_gpu)

        # Compute losses per GPU
        gpu_pred_src_src_list = []
        gpu_pred_dst_dst_list = []
        gpu_pred_src_dst_list = []
        gpu_pred_src_srcm_list = []
        gpu_pred_dst_dstm_list = []
        gpu_pred_src_dstm_list = []

        gpu_src_losses = []
        gpu_dst_losses = []
        gpu_G_loss_gvs = []
        gpu_D_src_dst_loss_gvs = []

        # Discriminator loss helpers. They do not depend on the tower, so
        # they are defined once here instead of on every loop iteration.
        def DLossOnes(logits):
            return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(logits), logits=logits), axis=[1, 2, 3])

        def DLossZeros(logits):
            return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(logits), logits=logits), axis=[1, 2, 3])

        for gpu_id in range(gpu_count):
            with tf.device(f'/{devices[gpu_id].tf_dev_type}:{gpu_id}' if len(devices) != 0 else '/CPU:0'):
                with tf.device('/CPU:0'):
                    # slice on CPU, otherwise all batch data will be transferred to GPU first
                    batch_slice = slice(gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu)
                    gpu_warped_src = self.warped_src[batch_slice, :, :, :]
                    gpu_warped_dst = self.warped_dst[batch_slice, :, :, :]
                    gpu_target_src = self.target_src[batch_slice, :, :, :]
                    gpu_target_dst = self.target_dst[batch_slice, :, :, :]
                    gpu_target_srcm = self.target_srcm[batch_slice, :, :, :]
                    gpu_target_srcm_em = self.target_srcm_em[batch_slice, :, :, :]
                    gpu_target_dstm = self.target_dstm[batch_slice, :, :, :]
                    gpu_target_dstm_em = self.target_dstm_em[batch_slice, :, :, :]

                # process model tensors
                gpu_src_code = self.encoder(gpu_warped_src)
                gpu_dst_code = self.encoder(gpu_warped_dst)

                if pretrain:
                    gpu_src_inter_src_code = self.inter_src(gpu_src_code)
                    gpu_dst_inter_dst_code = self.inter_dst(gpu_dst_code)
                    # Randomly zero code channels via a per-channel binomial mask.
                    gpu_src_code = gpu_src_inter_src_code * nn.random_binomial([bs_per_gpu, gpu_src_inter_src_code.shape.as_list()[1], 1, 1], p=morph_factor)
                    gpu_dst_code = gpu_src_dst_code = gpu_dst_inter_dst_code * nn.random_binomial([bs_per_gpu, gpu_dst_inter_dst_code.shape.as_list()[1], 1, 1], p=0.25)
                else:
                    gpu_src_inter_src_code = self.inter_src(gpu_src_code)
                    gpu_src_inter_dst_code = self.inter_dst(gpu_src_code)
                    gpu_dst_inter_src_code = self.inter_src(gpu_dst_code)
                    gpu_dst_inter_dst_code = self.inter_dst(gpu_dst_code)

                    # Per channel, pick the inter_src code where the mask is 1
                    # and the inter_dst code where it is 0.
                    inter_rnd_binomial = nn.random_binomial([bs_per_gpu, gpu_src_inter_src_code.shape.as_list()[1], 1, 1], p=morph_factor)
                    gpu_src_code = gpu_src_inter_src_code * inter_rnd_binomial + gpu_src_inter_dst_code * (1-inter_rnd_binomial)
                    gpu_dst_code = gpu_dst_inter_dst_code

                    # Morphed code: the first ae_dims_slice channels come from
                    # inter_src, the remaining ones from inter_dst.
                    ae_dims_slice = tf.cast(ae_dims*self.morph_value_t[0], tf.int32)
                    gpu_src_dst_code = tf.concat((tf.slice(gpu_dst_inter_src_code, [0, 0, 0, 0], [-1, ae_dims_slice, lowest_dense_res, lowest_dense_res]),
                                                  tf.slice(gpu_dst_inter_dst_code, [0, ae_dims_slice, 0, 0], [-1, ae_dims-ae_dims_slice, lowest_dense_res, lowest_dense_res])), 1)

                gpu_pred_src_src, gpu_pred_src_srcm = self.decoder(gpu_src_code)
                gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder(gpu_dst_code)
                gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)

                gpu_pred_src_src_list.append(gpu_pred_src_src)
                gpu_pred_dst_dst_list.append(gpu_pred_dst_dst)
                gpu_pred_src_dst_list.append(gpu_pred_src_dst)

                gpu_pred_src_srcm_list.append(gpu_pred_src_srcm)
                gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm)
                gpu_pred_src_dstm_list.append(gpu_pred_src_dstm)

                # Blurred target masks: clip to [0, 0.5] and rescale to [0, 1].
                gpu_target_srcm_blur = nn.gaussian_blur(gpu_target_srcm, max(1, resolution // 32))
                gpu_target_srcm_blur = tf.clip_by_value(gpu_target_srcm_blur, 0, 0.5) * 2

                gpu_target_dstm_blur = nn.gaussian_blur(gpu_target_dstm, max(1, resolution // 32))
                gpu_target_dstm_blur = tf.clip_by_value(gpu_target_dstm_blur, 0, 0.5) * 2

                gpu_target_dst_anti_masked = gpu_target_dst*(1.0-gpu_target_dstm_blur)
                gpu_target_src_anti_masked = gpu_target_src*(1.0-gpu_target_srcm_blur)

                gpu_target_src_masked_opt = gpu_target_src*gpu_target_srcm_blur if masked_training else gpu_target_src
                gpu_target_dst_masked_opt = gpu_target_dst*gpu_target_dstm_blur if masked_training else gpu_target_dst

                gpu_pred_src_src_masked_opt = gpu_pred_src_src*gpu_target_srcm_blur if masked_training else gpu_pred_src_src
                gpu_pred_src_src_anti_masked = gpu_pred_src_src*(1.0-gpu_target_srcm_blur)
                gpu_pred_dst_dst_masked_opt = gpu_pred_dst_dst*gpu_target_dstm_blur if masked_training else gpu_pred_dst_dst
                gpu_pred_dst_dst_anti_masked = gpu_pred_dst_dst*(1.0-gpu_target_dstm_blur)

                # --- dst loss ---
                # NOTE(review): nesting in this section was reconstructed from a
                # flattened source; confirm that the mask and background terms
                # are meant to be unconditional.
                if resolution < 256:
                    gpu_dst_loss = tf.reduce_mean(10*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
                else:
                    gpu_dst_loss = tf.reduce_mean(5*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
                    gpu_dst_loss += tf.reduce_mean(5*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/23.2)), axis=[1])
                gpu_dst_loss += tf.reduce_mean(10*tf.square(gpu_target_dst_masked_opt - gpu_pred_dst_dst_masked_opt), axis=[1, 2, 3])

                if eyes_mouth_prio:
                    gpu_dst_loss += tf.reduce_mean(300*tf.abs(gpu_target_dst*gpu_target_dstm_em - gpu_pred_dst_dst*gpu_target_dstm_em), axis=[1, 2, 3])

                gpu_dst_loss += tf.reduce_mean(10*tf.square(gpu_target_dstm - gpu_pred_dst_dstm), axis=[1, 2, 3])
                gpu_dst_loss += 0.1*tf.reduce_mean(tf.square(gpu_pred_dst_dst_anti_masked-gpu_target_dst_anti_masked), axis=[1, 2, 3])

                gpu_dst_losses += [gpu_dst_loss]

                # --- src loss (reported as the dst loss while pretraining) ---
                if not pretrain:
                    if resolution < 256:
                        gpu_src_loss = tf.reduce_mean(10*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
                    else:
                        gpu_src_loss = tf.reduce_mean(5*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
                        gpu_src_loss += tf.reduce_mean(5*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/23.2)), axis=[1])
                    gpu_src_loss += tf.reduce_mean(10*tf.square(gpu_target_src_masked_opt - gpu_pred_src_src_masked_opt), axis=[1, 2, 3])

                    if eyes_mouth_prio:
                        gpu_src_loss += tf.reduce_mean(300*tf.abs(gpu_target_src*gpu_target_srcm_em - gpu_pred_src_src*gpu_target_srcm_em), axis=[1, 2, 3])

                    gpu_src_loss += tf.reduce_mean(10*tf.square(gpu_target_srcm - gpu_pred_src_srcm), axis=[1, 2, 3])
                else:
                    gpu_src_loss = gpu_dst_loss

                gpu_src_losses += [gpu_src_loss]

                if pretrain:
                    gpu_G_loss = gpu_dst_loss
                else:
                    gpu_G_loss = gpu_src_loss + gpu_dst_loss

                if gan_power != 0:
                    gpu_pred_src_src_d, gpu_pred_src_src_d2 = self.GAN(gpu_pred_src_src_masked_opt)
                    gpu_pred_dst_dst_d, gpu_pred_dst_dst_d2 = self.GAN(gpu_pred_dst_dst_masked_opt)
                    gpu_target_src_d, gpu_target_src_d2 = self.GAN(gpu_target_src_masked_opt)
                    gpu_target_dst_d, gpu_target_dst_d2 = self.GAN(gpu_target_dst_masked_opt)

                    # Discriminator: targets labelled 1, predictions labelled 0,
                    # averaged over the eight terms.
                    gpu_D_src_dst_loss = (DLossOnes(gpu_target_src_d) + DLossOnes(gpu_target_src_d2) +
                                          DLossZeros(gpu_pred_src_src_d) + DLossZeros(gpu_pred_src_src_d2) +
                                          DLossOnes(gpu_target_dst_d) + DLossOnes(gpu_target_dst_d2) +
                                          DLossZeros(gpu_pred_dst_dst_d) + DLossZeros(gpu_pred_dst_dst_d2)
                                          ) * (1.0 / 8)

                    gpu_D_src_dst_loss_gvs += [nn.gradients(gpu_D_src_dst_loss, self.GAN.get_weights())]

                    # Generator: predictions scored against label 1.
                    gpu_G_loss += (DLossOnes(gpu_pred_src_src_d) + DLossOnes(gpu_pred_src_src_d2) +
                                   DLossOnes(gpu_pred_dst_dst_d) + DLossOnes(gpu_pred_dst_dst_d2)
                                   ) * gan_power

                    if masked_training:
                        # Minimal src-src-bg rec with total_variation_mse to suppress random bright dots from gan
                        gpu_G_loss += 0.000001*nn.total_variation_mse(gpu_pred_src_src)
                        gpu_G_loss += 0.02*tf.reduce_mean(tf.square(gpu_pred_src_src_anti_masked-gpu_target_src_anti_masked), axis=[1, 2, 3])

                gpu_G_loss_gvs += [nn.gradients(gpu_G_loss, self.trainable_weights)]

        # Average losses and gradients, and create optimizer update ops
        with tf.device('/CPU:0'):
            pred_src_src = nn.concat(gpu_pred_src_src_list, 0)
            pred_dst_dst = nn.concat(gpu_pred_dst_dst_list, 0)
            pred_src_dst = nn.concat(gpu_pred_src_dst_list, 0)
            pred_src_srcm = nn.concat(gpu_pred_src_srcm_list, 0)
            pred_dst_dstm = nn.concat(gpu_pred_dst_dstm_list, 0)
            pred_src_dstm = nn.concat(gpu_pred_src_dstm_list, 0)

        with tf.device(models_opt_device):
            src_loss = tf.concat(gpu_src_losses, 0)
            dst_loss = tf.concat(gpu_dst_losses, 0)
            src_dst_loss_gv_op = self.src_dst_opt.get_update_op(nn.average_gv_list(gpu_G_loss_gvs))

            if gan_power != 0:
                src_D_src_dst_loss_gv_op = self.GAN_opt.get_update_op(nn.average_gv_list(gpu_D_src_dst_loss_gvs))

        # Initializing training and view functions
        def src_dst_train(warped_src, target_src, target_srcm, target_srcm_em,
                          warped_dst, target_dst, target_dstm, target_dstm_em):
            """Run one generator update and return the (src_loss, dst_loss) values."""
            s, d, _ = nn.tf_sess.run([src_loss, dst_loss, src_dst_loss_gv_op],
                                     feed_dict={self.warped_src: warped_src,
                                                self.target_src: target_src,
                                                self.target_srcm: target_srcm,
                                                self.target_srcm_em: target_srcm_em,
                                                self.warped_dst: warped_dst,
                                                self.target_dst: target_dst,
                                                self.target_dstm: target_dstm,
                                                self.target_dstm_em: target_dstm_em,
                                                })
            return s, d
        self.src_dst_train = src_dst_train

        if gan_power != 0:
            def D_src_dst_train(warped_src, target_src, target_srcm, target_srcm_em,
                                warped_dst, target_dst, target_dstm, target_dstm_em):
                """Run one discriminator update."""
                nn.tf_sess.run([src_D_src_dst_loss_gv_op],
                               feed_dict={self.warped_src: warped_src,
                                          self.target_src: target_src,
                                          self.target_srcm: target_srcm,
                                          self.target_srcm_em: target_srcm_em,
                                          self.warped_dst: warped_dst,
                                          self.target_dst: target_dst,
                                          self.target_dstm: target_dstm,
                                          self.target_dstm_em: target_dstm_em})
            self.D_src_dst_train = D_src_dst_train

        def AE_view(warped_src, warped_dst, morph_value):
            """Evaluate the preview outputs for the given inputs and morph value."""
            return nn.tf_sess.run([pred_src_src, pred_dst_dst, pred_dst_dstm, pred_src_dst, pred_src_dstm],
                                  feed_dict={self.warped_src: warped_src,
                                             self.warped_dst: warped_dst,
                                             self.morph_value_t: [morph_value]})
        self.AE_view = AE_view
    else:
        # Initializing merge function
        with tf.device(nn.tf_default_device_name if len(devices) != 0 else '/CPU:0'):
            gpu_dst_code = self.encoder(self.warped_dst)
            gpu_dst_inter_src_code = self.inter_src(gpu_dst_code)
            gpu_dst_inter_dst_code = self.inter_dst(gpu_dst_code)

            # Same channel-wise morph as in the training graph.
            ae_dims_slice = tf.cast(ae_dims*self.morph_value_t[0], tf.int32)
            gpu_src_dst_code = tf.concat((tf.slice(gpu_dst_inter_src_code, [0, 0, 0, 0], [-1, ae_dims_slice, lowest_dense_res, lowest_dense_res]),
                                          tf.slice(gpu_dst_inter_dst_code, [0, ae_dims_slice, 0, 0], [-1, ae_dims-ae_dims_slice, lowest_dense_res, lowest_dense_res])), 1)

            gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)
            _, gpu_pred_dst_dstm = self.decoder(gpu_dst_inter_dst_code)

        def AE_merge(warped_dst, morph_value):
            """Evaluate the merge outputs for a dst batch and morph value."""
            return nn.tf_sess.run([gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm],
                                  feed_dict={self.warped_dst: warped_dst,
                                             self.morph_value_t: [morph_value]})
        self.AE_merge = AE_merge

    # Loading/initializing all models/optimizers weights
    for model, filename in io.progress_bar_generator(self.model_filename_list, "Initializing models"):
        if self.pretrain_just_disabled:
            # Leaving pretraining: only the two inters are re-initialised.
            do_init = False
            if model == self.inter_src or model == self.inter_dst:
                do_init = True
        else:
            do_init = self.is_first_run()
            if self.is_training and gan_power != 0 and model == self.GAN:
                if self.gan_model_changed:
                    do_init = True

        # Fall back to fresh initialisation when loading fails.
        if not do_init:
            do_init = not model.load_weights(self.get_strpath_storage_for_file(filename))
        if do_init:
            model.init_weights()

    ###############

    # initializing sample generators
    if self.is_training:
        training_data_src_path = self.training_data_src_path if not self.pretrain else self.get_pretraining_data_path()
        training_data_dst_path = self.training_data_dst_path if not self.pretrain else self.get_pretraining_data_path()

        random_ct_samples_path = training_data_dst_path if ct_mode is not None and not self.pretrain else None

        cpu_count = min(multiprocessing.cpu_count(), 8)
        src_generators_count = cpu_count // 2
        dst_generators_count = cpu_count // 2
        if ct_mode is not None:
            src_generators_count = int(src_generators_count * 1.5)

        # Each generator yields, in order: warped image, target image,
        # full-face mask, eyes/mouth mask.
        self.set_training_data_generators([
            SampleGeneratorFace(training_data_src_path, random_ct_samples_path=random_ct_samples_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
                                sample_process_options=SampleProcessor.Options(random_flip=random_src_flip),
                                output_sample_types=[
                                    {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp': random_warp, 'transform': True, 'channel_type': SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type': self.face_type, 'data_format': nn.data_format, 'resolution': resolution},
                                    {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp': False, 'transform': True, 'channel_type': SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type': self.face_type, 'data_format': nn.data_format, 'resolution': resolution},
                                    {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp': False, 'transform': True, 'channel_type': SampleProcessor.ChannelType.G, 'face_mask_type': SampleProcessor.FaceMaskType.FULL_FACE, 'face_type': self.face_type, 'data_format': nn.data_format, 'resolution': resolution},
                                    {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp': False, 'transform': True, 'channel_type': SampleProcessor.ChannelType.G, 'face_mask_type': SampleProcessor.FaceMaskType.EYES_MOUTH, 'face_type': self.face_type, 'data_format': nn.data_format, 'resolution': resolution},
                                ],
                                uniform_yaw_distribution=self.options['uniform_yaw'] or self.pretrain,
                                generators_count=src_generators_count),

            SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
                                sample_process_options=SampleProcessor.Options(random_flip=random_dst_flip),
                                output_sample_types=[
                                    {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp': random_warp, 'transform': True, 'channel_type': SampleProcessor.ChannelType.BGR, 'face_type': self.face_type, 'data_format': nn.data_format, 'resolution': resolution},
                                    {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp': False, 'transform': True, 'channel_type': SampleProcessor.ChannelType.BGR, 'face_type': self.face_type, 'data_format': nn.data_format, 'resolution': resolution},
                                    {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp': False, 'transform': True, 'channel_type': SampleProcessor.ChannelType.G, 'face_mask_type': SampleProcessor.FaceMaskType.FULL_FACE, 'face_type': self.face_type, 'data_format': nn.data_format, 'resolution': resolution},
                                    {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp': False, 'transform': True, 'channel_type': SampleProcessor.ChannelType.G, 'face_mask_type': SampleProcessor.FaceMaskType.EYES_MOUTH, 'face_type': self.face_type, 'data_format': nn.data_format, 'resolution': resolution},
                                ],
                                uniform_yaw_distribution=self.options['uniform_yaw'] or self.pretrain,
                                generators_count=dst_generators_count),
        ])

        self.last_src_samples_loss = []
        self.last_dst_samples_loss = []

        if self.pretrain_just_disabled:
            self.update_sample_for_preview(force_new=True)
def on_initialize(self): device_config = nn.getCurrentDeviceConfig() devices = device_config.devices self.model_data_format = "NCHW" nn.initialize(data_format=self.model_data_format) tf = nn.tf input_ch=3 resolution = self.resolution = self.options['resolution'] e_dims = self.options['e_dims'] ae_dims = self.options['ae_dims'] inter_dims = self.inter_dims = self.options['inter_dims'] inter_res = self.inter_res = resolution // 32 d_dims = self.options['d_dims'] d_mask_dims = self.options['d_mask_dims'] face_type = self.face_type = {'f' : FaceType.FULL, 'wf' : FaceType.WHOLE_FACE, 'head' : FaceType.HEAD}[ self.options['face_type'] ] morph_factor = self.options['morph_factor'] gan_power = self.gan_power = self.options['gan_power'] random_warp = self.options['random_warp'] blur_out_mask = self.options['blur_out_mask'] ct_mode = self.options['ct_mode'] if ct_mode == 'none': ct_mode = None use_fp16 = False if self.is_exporting: use_fp16 = io.input_bool ("Export quantized?", False, help_message='Makes the exported model faster. 
If you have problems, disable this option.') conv_dtype = tf.float16 if use_fp16 else tf.float32 class Downscale(nn.ModelBase): def on_build(self, in_ch, out_ch, kernel_size=5 ): self.conv1 = nn.Conv2D( in_ch, out_ch, kernel_size=kernel_size, strides=2, padding='SAME', dtype=conv_dtype) def forward(self, x): return tf.nn.leaky_relu(self.conv1(x), 0.1) class Upscale(nn.ModelBase): def on_build(self, in_ch, out_ch, kernel_size=3 ): self.conv1 = nn.Conv2D(in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype) def forward(self, x): x = nn.depth_to_space(tf.nn.leaky_relu(self.conv1(x), 0.1), 2) return x class ResidualBlock(nn.ModelBase): def on_build(self, ch, kernel_size=3 ): self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype) self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype) def forward(self, inp): x = self.conv1(inp) x = tf.nn.leaky_relu(x, 0.2) x = self.conv2(x) x = tf.nn.leaky_relu(inp+x, 0.2) return x class Encoder(nn.ModelBase): def on_build(self): self.down1 = Downscale(input_ch, e_dims, kernel_size=5) self.res1 = ResidualBlock(e_dims) self.down2 = Downscale(e_dims, e_dims*2, kernel_size=5) self.down3 = Downscale(e_dims*2, e_dims*4, kernel_size=5) self.down4 = Downscale(e_dims*4, e_dims*8, kernel_size=5) self.down5 = Downscale(e_dims*8, e_dims*8, kernel_size=5) self.res5 = ResidualBlock(e_dims*8) self.dense1 = nn.Dense( (( resolution//(2**5) )**2) * e_dims*8, ae_dims ) def forward(self, x): if use_fp16: x = tf.cast(x, tf.float16) x = self.down1(x) x = self.res1(x) x = self.down2(x) x = self.down3(x) x = self.down4(x) x = self.down5(x) x = self.res5(x) if use_fp16: x = tf.cast(x, tf.float32) x = nn.pixel_norm(nn.flatten(x), axes=-1) x = self.dense1(x) return x class Inter(nn.ModelBase): def on_build(self): self.dense2 = nn.Dense(ae_dims, inter_res * inter_res * inter_dims) def forward(self, inp): x = inp x = self.dense2(x) x = nn.reshape_4D (x, inter_res, 
inter_res, inter_dims) return x class Decoder(nn.ModelBase): def on_build(self ): self.upscale0 = Upscale(inter_dims, d_dims*8, kernel_size=3) self.upscale1 = Upscale(d_dims*8, d_dims*8, kernel_size=3) self.upscale2 = Upscale(d_dims*8, d_dims*4, kernel_size=3) self.upscale3 = Upscale(d_dims*4, d_dims*2, kernel_size=3) self.res0 = ResidualBlock(d_dims*8, kernel_size=3) self.res1 = ResidualBlock(d_dims*8, kernel_size=3) self.res2 = ResidualBlock(d_dims*4, kernel_size=3) self.res3 = ResidualBlock(d_dims*2, kernel_size=3) self.upscalem0 = Upscale(inter_dims, d_mask_dims*8, kernel_size=3) self.upscalem1 = Upscale(d_mask_dims*8, d_mask_dims*8, kernel_size=3) self.upscalem2 = Upscale(d_mask_dims*8, d_mask_dims*4, kernel_size=3) self.upscalem3 = Upscale(d_mask_dims*4, d_mask_dims*2, kernel_size=3) self.upscalem4 = Upscale(d_mask_dims*2, d_mask_dims*1, kernel_size=3) self.out_convm = nn.Conv2D( d_mask_dims*1, 1, kernel_size=1, padding='SAME', dtype=conv_dtype) self.out_conv = nn.Conv2D( d_dims*2, 3, kernel_size=1, padding='SAME', dtype=conv_dtype) self.out_conv1 = nn.Conv2D( d_dims*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype) self.out_conv2 = nn.Conv2D( d_dims*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype) self.out_conv3 = nn.Conv2D( d_dims*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype) def forward(self, z): if use_fp16: z = tf.cast(z, tf.float16) x = self.upscale0(z) x = self.res0(x) x = self.upscale1(x) x = self.res1(x) x = self.upscale2(x) x = self.res2(x) x = self.upscale3(x) x = self.res3(x) x = tf.nn.sigmoid( nn.depth_to_space(tf.concat( (self.out_conv(x), self.out_conv1(x), self.out_conv2(x), self.out_conv3(x)), nn.conv2d_ch_axis), 2) ) m = self.upscalem0(z) m = self.upscalem1(m) m = self.upscalem2(m) m = self.upscalem3(m) m = self.upscalem4(m) m = tf.nn.sigmoid(self.out_convm(m)) if use_fp16: x = tf.cast(x, tf.float32) m = tf.cast(m, tf.float32) return x, m models_opt_on_gpu = False if len(devices) == 0 else 
self.options['models_opt_on_gpu'] models_opt_device = nn.tf_default_device_name if models_opt_on_gpu and self.is_training else '/CPU:0' optimizer_vars_on_cpu = models_opt_device=='/CPU:0' bgr_shape = self.bgr_shape = nn.get4Dshape(resolution,resolution,input_ch) mask_shape = nn.get4Dshape(resolution,resolution,1) self.model_filename_list = [] with tf.device ('/CPU:0'): #Place holders on CPU self.warped_src = tf.placeholder (nn.floatx, bgr_shape, name='warped_src') self.warped_dst = tf.placeholder (nn.floatx, bgr_shape, name='warped_dst') self.target_src = tf.placeholder (nn.floatx, bgr_shape, name='target_src') self.target_dst = tf.placeholder (nn.floatx, bgr_shape, name='target_dst') self.target_srcm = tf.placeholder (nn.floatx, mask_shape, name='target_srcm') self.target_srcm_em = tf.placeholder (nn.floatx, mask_shape, name='target_srcm_em') self.target_dstm = tf.placeholder (nn.floatx, mask_shape, name='target_dstm') self.target_dstm_em = tf.placeholder (nn.floatx, mask_shape, name='target_dstm_em') self.morph_value_t = tf.placeholder (nn.floatx, (1,), name='morph_value_t') # Initializing model classes with tf.device (models_opt_device): self.encoder = Encoder(name='encoder') self.inter_src = Inter(name='inter_src') self.inter_dst = Inter(name='inter_dst') self.decoder = Decoder(name='decoder') self.model_filename_list += [ [self.encoder, 'encoder.npy'], [self.inter_src, 'inter_src.npy'], [self.inter_dst , 'inter_dst.npy'], [self.decoder , 'decoder.npy'] ] if self.is_training: # Initialize optimizers clipnorm = 1.0 if self.options['clipgrad'] else 0.0 lr_dropout = 0.3 if self.options['lr_dropout'] in ['y','cpu'] else 1.0 self.G_weights = self.encoder.get_weights() + self.decoder.get_weights() #if random_warp: # self.G_weights += self.inter_src.get_weights() + self.inter_dst.get_weights() self.src_dst_opt = nn.AdaBelief(lr=5e-5, lr_dropout=lr_dropout, clipnorm=clipnorm, name='src_dst_opt') self.src_dst_opt.initialize_variables (self.G_weights, 
vars_on_cpu=optimizer_vars_on_cpu) self.model_filename_list += [ (self.src_dst_opt, 'src_dst_opt.npy') ] if gan_power != 0: self.GAN = nn.UNetPatchDiscriminator(patch_size=self.options['gan_patch_size'], in_ch=input_ch, base_ch=self.options['gan_dims'], name="GAN") self.GAN_opt = nn.AdaBelief(lr=5e-5, lr_dropout=lr_dropout, clipnorm=clipnorm, name='GAN_opt') self.GAN_opt.initialize_variables ( self.GAN.get_weights(), vars_on_cpu=optimizer_vars_on_cpu) self.model_filename_list += [ [self.GAN, 'GAN.npy'], [self.GAN_opt, 'GAN_opt.npy'] ] if self.is_training: # Adjust batch size for multiple GPU gpu_count = max(1, len(devices) ) bs_per_gpu = max(1, self.get_batch_size() // gpu_count) self.set_batch_size( gpu_count*bs_per_gpu) # Compute losses per GPU gpu_pred_src_src_list = [] gpu_pred_dst_dst_list = [] gpu_pred_src_dst_list = [] gpu_pred_src_srcm_list = [] gpu_pred_dst_dstm_list = [] gpu_pred_src_dstm_list = [] gpu_src_losses = [] gpu_dst_losses = [] gpu_G_loss_gradients = [] gpu_GAN_loss_gradients = [] def DLossOnes(logits): return tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(logits), logits=logits), axis=[1,2,3]) def DLossZeros(logits): return tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(logits), logits=logits), axis=[1,2,3]) for gpu_id in range(gpu_count): with tf.device( f'/{devices[gpu_id].tf_dev_type}:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ): with tf.device(f'/CPU:0'): # slice on CPU, otherwise all batch data will be transfered to GPU first batch_slice = slice( gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu ) gpu_warped_src = self.warped_src [batch_slice,:,:,:] gpu_warped_dst = self.warped_dst [batch_slice,:,:,:] gpu_target_src = self.target_src [batch_slice,:,:,:] gpu_target_dst = self.target_dst [batch_slice,:,:,:] gpu_target_srcm = self.target_srcm[batch_slice,:,:,:] gpu_target_srcm_em = self.target_srcm_em[batch_slice,:,:,:] gpu_target_dstm = self.target_dstm[batch_slice,:,:,:] 
gpu_target_dstm_em = self.target_dstm_em[batch_slice,:,:,:] # process model tensors gpu_src_code = self.encoder (gpu_warped_src) gpu_dst_code = self.encoder (gpu_warped_dst) gpu_src_inter_src_code, gpu_src_inter_dst_code = self.inter_src (gpu_src_code), self.inter_dst (gpu_src_code) gpu_dst_inter_src_code, gpu_dst_inter_dst_code = self.inter_src (gpu_dst_code), self.inter_dst (gpu_dst_code) inter_dims_bin = int(inter_dims*morph_factor) with tf.device(f'/CPU:0'): inter_rnd_binomial = tf.stack([tf.random.shuffle(tf.concat([tf.tile(tf.constant([1], tf.float32), ( inter_dims_bin, )), tf.tile(tf.constant([0], tf.float32), ( inter_dims-inter_dims_bin, ))], 0 )) for _ in range(bs_per_gpu)], 0) inter_rnd_binomial = tf.stop_gradient(inter_rnd_binomial[...,None,None]) gpu_src_code = gpu_src_inter_src_code * inter_rnd_binomial + gpu_src_inter_dst_code * (1-inter_rnd_binomial) gpu_dst_code = gpu_dst_inter_dst_code inter_dims_slice = tf.cast(inter_dims*self.morph_value_t[0], tf.int32) gpu_src_dst_code = tf.concat( (tf.slice(gpu_dst_inter_src_code, [0,0,0,0], [-1, inter_dims_slice , inter_res, inter_res]), tf.slice(gpu_dst_inter_dst_code, [0,inter_dims_slice,0,0], [-1,inter_dims-inter_dims_slice, inter_res,inter_res]) ), 1 ) gpu_pred_src_src, gpu_pred_src_srcm = self.decoder(gpu_src_code) gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder(gpu_dst_code) gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code) gpu_pred_src_src_list.append(gpu_pred_src_src), gpu_pred_src_srcm_list.append(gpu_pred_src_srcm) gpu_pred_dst_dst_list.append(gpu_pred_dst_dst), gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm) gpu_pred_src_dst_list.append(gpu_pred_src_dst), gpu_pred_src_dstm_list.append(gpu_pred_src_dstm) gpu_target_srcm_anti = 1-gpu_target_srcm gpu_target_dstm_anti = 1-gpu_target_dstm gpu_target_srcm_gblur = nn.gaussian_blur(gpu_target_srcm, resolution // 32) gpu_target_dstm_gblur = nn.gaussian_blur(gpu_target_dstm, resolution // 32) gpu_target_srcm_blur = 
tf.clip_by_value(gpu_target_srcm_gblur, 0, 0.5) * 2 gpu_target_dstm_blur = tf.clip_by_value(gpu_target_dstm_gblur, 0, 0.5) * 2 gpu_target_srcm_anti_blur = 1.0-gpu_target_srcm_blur gpu_target_dstm_anti_blur = 1.0-gpu_target_dstm_blur if blur_out_mask: sigma = resolution / 128 x = nn.gaussian_blur(gpu_target_src*gpu_target_srcm_anti, sigma) y = 1-nn.gaussian_blur(gpu_target_srcm, sigma) y = tf.where(tf.equal(y, 0), tf.ones_like(y), y) gpu_target_src = gpu_target_src*gpu_target_srcm + (x/y)*gpu_target_srcm_anti x = nn.gaussian_blur(gpu_target_dst*gpu_target_dstm_anti, sigma) y = 1-nn.gaussian_blur(gpu_target_dstm, sigma) y = tf.where(tf.equal(y, 0), tf.ones_like(y), y) gpu_target_dst = gpu_target_dst*gpu_target_dstm + (x/y)*gpu_target_dstm_anti gpu_target_src_masked = gpu_target_src*gpu_target_srcm_blur gpu_target_dst_masked = gpu_target_dst*gpu_target_dstm_blur gpu_target_src_anti_masked = gpu_target_src*gpu_target_srcm_anti_blur gpu_target_dst_anti_masked = gpu_target_dst*gpu_target_dstm_anti_blur gpu_pred_src_src_masked = gpu_pred_src_src*gpu_target_srcm_blur gpu_pred_dst_dst_masked = gpu_pred_dst_dst*gpu_target_dstm_blur gpu_pred_src_src_anti_masked = gpu_pred_src_src*gpu_target_srcm_anti_blur gpu_pred_dst_dst_anti_masked = gpu_pred_dst_dst*gpu_target_dstm_anti_blur # Structural loss gpu_src_loss = tf.reduce_mean (5*nn.dssim(gpu_target_src_masked, gpu_pred_src_src_masked, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1]) gpu_src_loss += tf.reduce_mean (5*nn.dssim(gpu_target_src_masked, gpu_pred_src_src_masked, max_val=1.0, filter_size=int(resolution/23.2)), axis=[1]) gpu_dst_loss = tf.reduce_mean (5*nn.dssim(gpu_target_dst_masked, gpu_pred_dst_dst_masked, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1]) gpu_dst_loss += tf.reduce_mean (5*nn.dssim(gpu_target_dst_masked, gpu_pred_dst_dst_masked, max_val=1.0, filter_size=int(resolution/23.2) ), axis=[1]) # Pixel loss gpu_src_loss += tf.reduce_mean 
(10*tf.square(gpu_target_src_masked-gpu_pred_src_src_masked), axis=[1,2,3]) gpu_dst_loss += tf.reduce_mean (10*tf.square(gpu_target_dst_masked-gpu_pred_dst_dst_masked), axis=[1,2,3]) # Eyes+mouth prio loss gpu_src_loss += tf.reduce_mean (300*tf.abs (gpu_target_src*gpu_target_srcm_em-gpu_pred_src_src*gpu_target_srcm_em), axis=[1,2,3]) gpu_dst_loss += tf.reduce_mean (300*tf.abs (gpu_target_dst*gpu_target_dstm_em-gpu_pred_dst_dst*gpu_target_dstm_em), axis=[1,2,3]) # Mask loss gpu_src_loss += tf.reduce_mean ( 10*tf.square( gpu_target_srcm - gpu_pred_src_srcm ),axis=[1,2,3] ) gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dstm - gpu_pred_dst_dstm ),axis=[1,2,3] ) gpu_src_losses += [gpu_src_loss] gpu_dst_losses += [gpu_dst_loss] gpu_G_loss = gpu_src_loss + gpu_dst_loss # dst-dst background weak loss gpu_G_loss += tf.reduce_mean(0.1*tf.square(gpu_pred_dst_dst_anti_masked-gpu_target_dst_anti_masked),axis=[1,2,3] ) gpu_G_loss += 0.000001*nn.total_variation_mse(gpu_pred_dst_dst_anti_masked) if gan_power != 0: gpu_pred_src_src_d, gpu_pred_src_src_d2 = self.GAN(gpu_pred_src_src_masked) gpu_pred_dst_dst_d, gpu_pred_dst_dst_d2 = self.GAN(gpu_pred_dst_dst_masked) gpu_target_src_d, gpu_target_src_d2 = self.GAN(gpu_target_src_masked) gpu_target_dst_d, gpu_target_dst_d2 = self.GAN(gpu_target_dst_masked) gpu_GAN_loss = (DLossOnes (gpu_target_src_d) + DLossOnes (gpu_target_src_d2) + \ DLossZeros(gpu_pred_src_src_d) + DLossZeros(gpu_pred_src_src_d2) + \ DLossOnes (gpu_target_dst_d) + DLossOnes (gpu_target_dst_d2) + \ DLossZeros(gpu_pred_dst_dst_d) + DLossZeros(gpu_pred_dst_dst_d2) ) * (1.0 / 8) gpu_GAN_loss_gradients += [ nn.gradients (gpu_GAN_loss, self.GAN.get_weights() ) ] gpu_G_loss += (DLossOnes(gpu_pred_src_src_d) + DLossOnes(gpu_pred_src_src_d2) + \ DLossOnes(gpu_pred_dst_dst_d) + DLossOnes(gpu_pred_dst_dst_d2) ) * gan_power # Minimal src-src-bg rec with total_variation_mse to suppress random bright dots from gan gpu_G_loss += 
0.000001*nn.total_variation_mse(gpu_pred_src_src) gpu_G_loss += 0.02*tf.reduce_mean(tf.square(gpu_pred_src_src_anti_masked-gpu_target_src_anti_masked),axis=[1,2,3] ) gpu_G_loss_gradients += [ nn.gradients ( gpu_G_loss, self.G_weights ) ] # Average losses and gradients, and create optimizer update ops with tf.device(f'/CPU:0'): pred_src_src = nn.concat(gpu_pred_src_src_list, 0) pred_dst_dst = nn.concat(gpu_pred_dst_dst_list, 0) pred_src_dst = nn.concat(gpu_pred_src_dst_list, 0) pred_src_srcm = nn.concat(gpu_pred_src_srcm_list, 0) pred_dst_dstm = nn.concat(gpu_pred_dst_dstm_list, 0) pred_src_dstm = nn.concat(gpu_pred_src_dstm_list, 0) with tf.device (models_opt_device): src_loss = tf.concat(gpu_src_losses, 0) dst_loss = tf.concat(gpu_dst_losses, 0) train_op = self.src_dst_opt.get_update_op (nn.average_gv_list (gpu_G_loss_gradients)) if gan_power != 0: GAN_train_op = self.GAN_opt.get_update_op (nn.average_gv_list(gpu_GAN_loss_gradients) ) # Initializing training and view functions def train(warped_src, target_src, target_srcm, target_srcm_em, \ warped_dst, target_dst, target_dstm, target_dstm_em, ): s, d, _ = nn.tf_sess.run ([src_loss, dst_loss, train_op], feed_dict={self.warped_src :warped_src, self.target_src :target_src, self.target_srcm:target_srcm, self.target_srcm_em:target_srcm_em, self.warped_dst :warped_dst, self.target_dst :target_dst, self.target_dstm:target_dstm, self.target_dstm_em:target_dstm_em, }) return s, d self.train = train if gan_power != 0: def GAN_train(warped_src, target_src, target_srcm, target_srcm_em, \ warped_dst, target_dst, target_dstm, target_dstm_em, ): nn.tf_sess.run ([GAN_train_op], feed_dict={self.warped_src :warped_src, self.target_src :target_src, self.target_srcm:target_srcm, self.target_srcm_em:target_srcm_em, self.warped_dst :warped_dst, self.target_dst :target_dst, self.target_dstm:target_dstm, self.target_dstm_em:target_dstm_em}) self.GAN_train = GAN_train def AE_view(warped_src, warped_dst, morph_value): return nn.tf_sess.run 
( [pred_src_src, pred_dst_dst, pred_dst_dstm, pred_src_dst, pred_src_dstm], feed_dict={self.warped_src:warped_src, self.warped_dst:warped_dst, self.morph_value_t:[morph_value] }) self.AE_view = AE_view else: #Initializing merge function with tf.device( nn.tf_default_device_name if len(devices) != 0 else f'/CPU:0'): gpu_dst_code = self.encoder (self.warped_dst) gpu_dst_inter_src_code = self.inter_src (gpu_dst_code) gpu_dst_inter_dst_code = self.inter_dst (gpu_dst_code) inter_dims_slice = tf.cast(inter_dims*self.morph_value_t[0], tf.int32) gpu_src_dst_code = tf.concat( ( tf.slice(gpu_dst_inter_src_code, [0,0,0,0], [-1, inter_dims_slice , inter_res, inter_res]), tf.slice(gpu_dst_inter_dst_code, [0,inter_dims_slice,0,0], [-1,inter_dims-inter_dims_slice, inter_res,inter_res]) ), 1 ) gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code) _, gpu_pred_dst_dstm = self.decoder(gpu_dst_inter_dst_code) def AE_merge(warped_dst, morph_value): return nn.tf_sess.run ( [gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm], feed_dict={self.warped_dst:warped_dst, self.morph_value_t:[morph_value] }) self.AE_merge = AE_merge # Loading/initializing all models/optimizers weights for model, filename in io.progress_bar_generator(self.model_filename_list, "Initializing models"): do_init = self.is_first_run() if self.is_training and gan_power != 0 and model == self.GAN: if self.gan_model_changed: do_init = True if not do_init: do_init = not model.load_weights( self.get_strpath_storage_for_file(filename) ) if do_init: model.init_weights() ############### # initializing sample generators if self.is_training: training_data_src_path = self.training_data_src_path #if not self.pretrain else self.get_pretraining_data_path() training_data_dst_path = self.training_data_dst_path #if not self.pretrain else self.get_pretraining_data_path() random_ct_samples_path=training_data_dst_path if ct_mode is not None else None #and not self.pretrain cpu_count = multiprocessing.cpu_count() 
src_generators_count = cpu_count // 2 dst_generators_count = cpu_count // 2 if ct_mode is not None: src_generators_count = int(src_generators_count * 1.5) self.set_training_data_generators ([ SampleGeneratorFace(training_data_src_path, random_ct_samples_path=random_ct_samples_path, debug=self.is_debug(), batch_size=self.get_batch_size(), sample_process_options=SampleProcessor.Options(random_flip=self.random_src_flip), output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution}, {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution}, {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution}, {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.EYES_MOUTH, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution}, ], uniform_yaw_distribution=self.options['uniform_yaw'],# or self.pretrain, generators_count=src_generators_count ), SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(), sample_process_options=SampleProcessor.Options(random_flip=self.random_dst_flip), output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':face_type, 'data_format':nn.data_format, 
'resolution': resolution}, {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution}, {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution}, {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.EYES_MOUTH, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution}, ], uniform_yaw_distribution=self.options['uniform_yaw'],# or self.pretrain, generators_count=dst_generators_count ) ]) self.last_src_samples_loss = [] self.last_dst_samples_loss = []