def on_initialize(self):
    """Build the FANSeg segmentation graph, its callables and sample generators.

    Always:
      * initializes ``nn`` with the ``NHWC`` data format;
      * sets ``self.resolution`` (fixed at 256) and ``self.face_type``
        (fixed at ``FaceType.FULL``);
      * creates ``self.model`` (a ``TernausNet``), loading weights unless
        ``self.is_first_run()`` is true.

    Only when ``self.is_training`` is true:
      * sets the batch size to ``gpu_count * max(1, batch_size // gpu_count)``;
      * builds one loss/gradient tower per device and an averaged update op;
      * stores ``self.train(input_np, target_np)`` and ``self.view(input_np)``;
      * registers the src/dst sample generators.
    """
    # NOTE(review): the original queried the device config before
    # nn.initialize() and discarded the result (it is re-read below).  The
    # call is kept in case it has side effects - confirm and drop if not.
    nn.getCurrentDeviceConfig()
    nn.initialize(data_format="NHWC")
    tf = nn.tf

    device_config = nn.getCurrentDeviceConfig()
    devices = device_config.devices

    # Resolution and face type are fixed for this model; they are not read
    # from self.options.
    self.resolution = resolution = 256
    self.face_type = FaceType.FULL

    place_model_on_cpu = len(devices) == 0
    models_opt_device = '/CPU:0' if place_model_on_cpu else '/GPU:0'

    # Initializing model classes
    self.model = TernausNet(f'{self.model_name}_FANSeg',
                            resolution,
                            FaceType.toString(self.face_type),
                            load_weights=not self.is_first_run(),
                            weights_file_root=self.get_model_root_path(),
                            training=True,
                            place_model_on_cpu=place_model_on_cpu)

    if self.is_training:
        # Adjust batch size for multiple GPU
        gpu_count = max(1, len(devices))
        bs_per_gpu = max(1, self.get_batch_size() // gpu_count)
        self.set_batch_size(gpu_count * bs_per_gpu)

        # Compute losses per GPU
        gpu_pred_list = []
        gpu_losses = []
        gpu_loss_gvs = []

        for gpu_id in range(gpu_count):
            with tf.device(f'/GPU:{gpu_id}' if len(devices) != 0 else '/CPU:0'):
                with tf.device('/CPU:0'):
                    # slice on CPU, otherwise all batch data will be
                    # transferred to GPU first
                    batch_slice = slice(gpu_id * bs_per_gpu, (gpu_id + 1) * bs_per_gpu)
                    gpu_input_t = self.model.input_t[batch_slice, :, :, :]
                    gpu_target_t = self.model.target_t[batch_slice, :, :, :]

                # process model tensors
                gpu_pred_logits_t, gpu_pred_t = self.model.net([gpu_input_t])
                gpu_pred_list.append(gpu_pred_t)

                # Per-sample loss: cross-entropy averaged over axes 1, 2, 3.
                gpu_loss = tf.reduce_mean(
                    tf.nn.sigmoid_cross_entropy_with_logits(labels=gpu_target_t,
                                                            logits=gpu_pred_logits_t),
                    axis=[1, 2, 3])
                gpu_losses.append(gpu_loss)
                gpu_loss_gvs.append(nn.tf_gradients(gpu_loss, self.model.net_weights))

        # Average losses and gradients, and create optimizer update ops
        with tf.device(models_opt_device):
            pred = nn.tf_concat(gpu_pred_list, 0)
            loss = tf.reduce_mean(gpu_losses)
            loss_gv_op = self.model.opt.get_update_op(nn.tf_average_gv_list(gpu_loss_gvs))

        # Initializing training and view functions
        def train(input_np, target_np):
            """Run one update step and return the fetched loss value."""
            l, _ = nn.tf_sess.run([loss, loss_gv_op],
                                  feed_dict={self.model.input_t: input_np,
                                             self.model.target_t: target_np})
            return l
        self.train = train

        def view(input_np):
            """Return the session result for ``[pred]`` on ``input_np``."""
            return nn.tf_sess.run([pred], feed_dict={self.model.input_t: input_np})
        self.view = view

        # initializing sample generators
        training_data_src_path = self.training_data_src_path
        training_data_dst_path = self.training_data_dst_path

        cpu_count = min(multiprocessing.cpu_count(), 8)
        # cpu_count // 2 is 0 on a single-core host; clamp so that each
        # generator is always asked for at least one worker.
        base_generators_count = max(1, cpu_count // 2)
        dst_generators_count = base_generators_count
        src_generators_count = int(base_generators_count * 1.5)

        src_generator = SampleGeneratorFace(
            training_data_src_path,
            random_ct_samples_path=training_data_src_path,
            debug=self.is_debug(),
            batch_size=self.get_batch_size(),
            sample_process_options=SampleProcessor.Options(random_flip=True),
            output_sample_types=[
                {
                    'sample_type': SampleProcessor.SampleType.FACE_IMAGE,
                    'ct_mode': 'lct',
                    'warp': True,
                    'transform': True,
                    'channel_type': SampleProcessor.ChannelType.BGR,
                    'face_type': self.face_type,
                    'motion_blur': (25, 5),
                    'gaussian_blur': (25, 5),
                    'data_format': nn.data_format,
                    'resolution': resolution,
                },
                {
                    'sample_type': SampleProcessor.SampleType.FACE_MASK,
                    'warp': True,
                    'transform': True,
                    'channel_type': SampleProcessor.ChannelType.G,
                    'face_mask_type': SampleProcessor.FaceMaskType.FULL_FACE,
                    'face_type': self.face_type,
                    'data_format': nn.data_format,
                    'resolution': resolution,
                },
            ],
            generators_count=src_generators_count)

        # The dst set is optional (raise_on_no_data=False); it only supplies
        # images, no masks.
        dst_generator = SampleGeneratorFace(
            training_data_dst_path,
            debug=self.is_debug(),
            batch_size=self.get_batch_size(),
            sample_process_options=SampleProcessor.Options(random_flip=True),
            output_sample_types=[
                {
                    'sample_type': SampleProcessor.SampleType.FACE_IMAGE,
                    'warp': False,
                    'transform': True,
                    'channel_type': SampleProcessor.ChannelType.BGR,
                    'face_type': self.face_type,
                    'motion_blur': (25, 5),
                    'gaussian_blur': (25, 5),
                    'data_format': nn.data_format,
                    'resolution': resolution,
                },
            ],
            generators_count=dst_generators_count,
            raise_on_no_data=False)

        if not dst_generator.is_initialized():
            io.log_info(
                f"\nTo view the model on unseen faces, place any aligned faces in {training_data_dst_path}.\n"
            )

        self.set_training_data_generators([src_generator, dst_generator])
def on_initialize(self):
    """Build the fixed 96px autoencoder graph, load its weights and set up training.

    Always:
      * initializes ``nn`` with ``NCHW`` when at least one device is
        available and the model is not in debug mode, otherwise ``NHWC``;
      * creates the input placeholders and the ``encoder``, ``inter``,
        ``decoder_src`` and ``decoder_dst`` models;
      * loads or initializes the weights of every entry in
        ``self.model_filename_list``.

    When ``self.is_training`` is true:
      * creates the ``src_dst_opt`` optimizer;
      * sets the batch size to ``gpu_count * max(1, 4 // gpu_count)``;
      * builds one loss/gradient tower per device and an averaged update op;
      * stores ``self.src_dst_train`` and ``self.AE_view``;
      * registers the src/dst sample generators and resets
        ``self.last_samples``.

    Otherwise:
      * stores ``self.AE_merge``, which runs the dst -> src prediction.
    """
    device_config = nn.getCurrentDeviceConfig()
    self.model_data_format = "NCHW" if len(device_config.devices) != 0 and not self.is_debug() else "NHWC"
    nn.initialize(data_format=self.model_data_format)
    tf = nn.tf

    conv_kernel_initializer = nn.initializers.ca()

    class Downscale(nn.ModelBase):
        """Single conv that halves resolution (space-to-depth or stride 2)."""

        def __init__(self, in_ch, out_ch, kernel_size=5, dilations=1, subpixel=True, use_activator=True,
                     *args, **kwargs):
            self.in_ch = in_ch
            self.out_ch = out_ch
            self.kernel_size = kernel_size
            self.dilations = dilations
            self.subpixel = subpixel
            self.use_activator = use_activator
            # The original declared ``*kwargs`` (var-positional), so keyword
            # options such as ``name=`` raised TypeError.  Extra positionals
            # are still forwarded, and keywords are now forwarded too.
            super().__init__(*args, **kwargs)

        def on_build(self, *args, **kwargs):
            # With subpixel downscaling the conv keeps the resolution and
            # emits out_ch // 4 channels; space-to-depth in forward() then
            # multiplies the channel count by 4.
            self.conv1 = nn.Conv2D(self.in_ch,
                                   self.out_ch // (4 if self.subpixel else 1),
                                   kernel_size=self.kernel_size,
                                   strides=1 if self.subpixel else 2,
                                   padding='SAME',
                                   dilations=self.dilations,
                                   kernel_initializer=conv_kernel_initializer)

        def forward(self, x):
            x = self.conv1(x)
            if self.subpixel:
                x = nn.tf_space_to_depth(x, 2)
            if self.use_activator:
                x = nn.tf_gelu(x)
            return x

        def get_out_ch(self):
            return (self.out_ch // 4) * 4

    class DownscaleBlock(nn.ModelBase):
        """Chain of ``n_downscales`` Downscale layers with growing width."""

        def on_build(self, in_ch, ch, n_downscales, kernel_size, dilations=1, subpixel=True):
            self.downs = []

            last_ch = in_ch
            for i in range(n_downscales):
                # Width doubles per level and is capped at 8 * ch.
                cur_ch = ch * min(2**i, 8)
                self.downs.append(Downscale(last_ch, cur_ch, kernel_size=kernel_size,
                                            dilations=dilations, subpixel=subpixel))
                last_ch = self.downs[-1].get_out_ch()

        def forward(self, inp):
            x = inp
            for down in self.downs:
                x = down(x)
            return x

    class Upscale(nn.ModelBase):
        """Conv to ``out_ch * 4`` channels, GELU, then depth-to-space by 2."""

        def on_build(self, in_ch, out_ch, kernel_size=3):
            self.conv1 = nn.Conv2D(in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME',
                                   kernel_initializer=conv_kernel_initializer)

        def forward(self, x):
            x = self.conv1(x)
            x = nn.tf_gelu(x)
            x = nn.tf_depth_to_space(x, 2)
            return x

    class ResidualBlock(nn.ModelBase):
        """Two same-width convs with a skip connection around them."""

        def on_build(self, ch, kernel_size=3):
            self.conv1 = nn.Conv2D(ch, ch, kernel_size=kernel_size, padding='SAME',
                                   kernel_initializer=conv_kernel_initializer)
            self.conv2 = nn.Conv2D(ch, ch, kernel_size=kernel_size, padding='SAME',
                                   kernel_initializer=conv_kernel_initializer)

        def forward(self, inp):
            x = self.conv1(inp)
            x = nn.tf_gelu(x)
            x = self.conv2(x)
            x = inp + x
            x = nn.tf_gelu(x)
            return x

    class Encoder(nn.ModelBase):
        """Four downscales followed by a flatten."""

        def on_build(self, in_ch, e_ch):
            self.down1 = DownscaleBlock(in_ch, e_ch, n_downscales=4, kernel_size=5)

        def forward(self, inp):
            return nn.tf_flatten(self.down1(inp))

    class Inter(nn.ModelBase):
        """Dense bottleneck, reshaped to 4D, then one upscale and one residual block."""

        def __init__(self, in_ch, lowest_dense_res, ae_ch, ae_out_ch, d_ch, **kwargs):
            self.in_ch, self.lowest_dense_res, self.ae_ch, self.ae_out_ch, self.d_ch = \
                in_ch, lowest_dense_res, ae_ch, ae_out_ch, d_ch
            super().__init__(**kwargs)

        def on_build(self):
            in_ch, lowest_dense_res, ae_ch, ae_out_ch, d_ch = \
                self.in_ch, self.lowest_dense_res, self.ae_ch, self.ae_out_ch, self.d_ch

            self.dense1 = nn.Dense(in_ch, ae_ch, kernel_initializer=tf.initializers.orthogonal)
            self.dense2 = nn.Dense(ae_ch, lowest_dense_res * lowest_dense_res * ae_out_ch,
                                   maxout_features=4, kernel_initializer=tf.initializers.orthogonal)
            self.upscale1 = Upscale(ae_out_ch, d_ch*8)
            self.res1 = ResidualBlock(d_ch*8)

        def forward(self, inp):
            x = self.dense1(inp)
            x = self.dense2(x)
            # Use the resolution this instance was built with; the original
            # read ``lowest_dense_res`` from the enclosing function's
            # closure, which only matched by coincidence of naming.
            x = nn.tf_reshape_4D(x, self.lowest_dense_res, self.lowest_dense_res, self.ae_out_ch)
            x = self.upscale1(x)
            x = self.res1(x)
            return x

        def get_out_ch(self):
            # NOTE(review): forward() ends with Upscale(ae_out_ch, d_ch*8) and
            # ResidualBlock(d_ch*8), so this value looks stale.  Left
            # unchanged because callers are not visible here; the code below
            # relies on compute_output_channels() instead.
            return self.ae_out_ch

    class Decoder(nn.ModelBase):
        """Three upscales for the image branch and three for the mask branch."""

        def on_build(self, in_ch, d_ch):
            self.upscale1 = Upscale(in_ch, d_ch*4)
            self.res1 = ResidualBlock(d_ch*4)
            self.upscale2 = Upscale(d_ch*4, d_ch*2)
            self.res2 = ResidualBlock(d_ch*2)
            self.upscale3 = Upscale(d_ch*2, d_ch*1)
            self.res3 = ResidualBlock(d_ch*1)

            self.upscalem1 = Upscale(in_ch, d_ch)
            self.upscalem2 = Upscale(d_ch, d_ch//2)
            self.upscalem3 = Upscale(d_ch//2, d_ch//2)

            self.out_conv = nn.Conv2D(d_ch*1, 3, kernel_size=1, padding='SAME',
                                      kernel_initializer=conv_kernel_initializer)
            self.out_convm = nn.Conv2D(d_ch//2, 1, kernel_size=1, padding='SAME',
                                       kernel_initializer=conv_kernel_initializer)

        def forward(self, inp):
            z = inp
            x = self.upscale1(z)
            x = self.res1(x)
            x = self.upscale2(x)
            x = self.res2(x)
            x = self.upscale3(x)
            x = self.res3(x)

            y = self.upscalem1(z)
            y = self.upscalem2(y)
            y = self.upscalem3(y)

            # Returns (3-channel image, 1-channel mask), both through sigmoid.
            return tf.nn.sigmoid(self.out_conv(x)), \
                   tf.nn.sigmoid(self.out_convm(y))

    # Re-read the device config after nn.initialize(), as the original did.
    device_config = nn.getCurrentDeviceConfig()
    devices = device_config.devices

    resolution = self.resolution = 96
    ae_dims = 128
    e_dims = 128
    d_dims = 64
    self.pretrain = False
    self.pretrain_just_disabled = False

    masked_training = True

    # Optimizer state goes on the GPU only when every device reports at
    # least 2 (total_mem_gb) and the model is being trained.
    models_opt_on_gpu = len(devices) >= 1 and all(dev.total_mem_gb >= 2 for dev in devices)
    models_opt_device = '/GPU:0' if models_opt_on_gpu and self.is_training else '/CPU:0'
    optimizer_vars_on_cpu = models_opt_device == '/CPU:0'

    input_ch = 3
    bgr_shape = nn.get4Dshape(resolution, resolution, input_ch)
    mask_shape = nn.get4Dshape(resolution, resolution, 1)

    # Four downscales in the encoder -> resolution // 16.
    lowest_dense_res = resolution // 16

    self.model_filename_list = []

    with tf.device('/CPU:0'):
        # Place holders on CPU
        self.warped_src = tf.placeholder(nn.tf_floatx, bgr_shape)
        self.warped_dst = tf.placeholder(nn.tf_floatx, bgr_shape)

        self.target_src = tf.placeholder(nn.tf_floatx, bgr_shape)
        self.target_dst = tf.placeholder(nn.tf_floatx, bgr_shape)

        self.target_srcm = tf.placeholder(nn.tf_floatx, mask_shape)
        self.target_dstm = tf.placeholder(nn.tf_floatx, mask_shape)

    # Initializing model classes
    with tf.device(models_opt_device):
        self.encoder = Encoder(in_ch=input_ch, e_ch=e_dims, name='encoder')
        encoder_out_ch = self.encoder.compute_output_channels((nn.tf_floatx, bgr_shape))

        self.inter = Inter(in_ch=encoder_out_ch, lowest_dense_res=lowest_dense_res,
                           ae_ch=ae_dims, ae_out_ch=ae_dims, d_ch=d_dims, name='inter')
        inter_out_ch = self.inter.compute_output_channels((nn.tf_floatx, (None, encoder_out_ch)))

        self.decoder_src = Decoder(in_ch=inter_out_ch, d_ch=d_dims, name='decoder_src')
        self.decoder_dst = Decoder(in_ch=inter_out_ch, d_ch=d_dims, name='decoder_dst')

        self.model_filename_list += [[self.encoder, 'encoder.npy'],
                                     [self.inter, 'inter.npy'],
                                     [self.decoder_src, 'decoder_src.npy'],
                                     [self.decoder_dst, 'decoder_dst.npy']]

        if self.is_training:
            self.src_dst_trainable_weights = self.encoder.get_weights() + self.inter.get_weights() \
                + self.decoder_src.get_weights() + self.decoder_dst.get_weights()

            # Initialize optimizers
            self.src_dst_opt = nn.TFRMSpropOptimizer(lr=2e-4, lr_dropout=0.3, name='src_dst_opt')
            self.src_dst_opt.initialize_variables(self.src_dst_trainable_weights,
                                                  vars_on_cpu=optimizer_vars_on_cpu)
            self.model_filename_list += [(self.src_dst_opt, 'src_dst_opt.npy')]

    if self.is_training:
        # Adjust batch size for multiple GPU (this model uses a fixed total of 4)
        gpu_count = max(1, len(devices))
        bs_per_gpu = max(1, 4 // gpu_count)
        self.set_batch_size(gpu_count * bs_per_gpu)

        # Compute losses per GPU
        gpu_pred_src_src_list = []
        gpu_pred_dst_dst_list = []
        gpu_pred_src_dst_list = []
        gpu_pred_dst_dstm_list = []
        gpu_pred_src_dstm_list = []

        gpu_src_losses = []
        gpu_dst_losses = []
        gpu_src_dst_loss_gvs = []

        for gpu_id in range(gpu_count):
            with tf.device(f'/GPU:{gpu_id}' if len(devices) != 0 else '/CPU:0'):
                batch_slice = slice(gpu_id * bs_per_gpu, (gpu_id + 1) * bs_per_gpu)
                with tf.device('/CPU:0'):
                    # slice on CPU, otherwise all batch data will be
                    # transferred to GPU first
                    gpu_warped_src = self.warped_src[batch_slice, :, :, :]
                    gpu_warped_dst = self.warped_dst[batch_slice, :, :, :]
                    gpu_target_src = self.target_src[batch_slice, :, :, :]
                    gpu_target_dst = self.target_dst[batch_slice, :, :, :]
                    gpu_target_srcm = self.target_srcm[batch_slice, :, :, :]
                    gpu_target_dstm = self.target_dstm[batch_slice, :, :, :]

                # process model tensors
                gpu_src_code = self.inter(self.encoder(gpu_warped_src))
                gpu_dst_code = self.inter(self.encoder(gpu_warped_dst))
                gpu_pred_src_src, gpu_pred_src_srcm = self.decoder_src(gpu_src_code)
                gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)
                gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code)

                gpu_pred_src_src_list.append(gpu_pred_src_src)
                gpu_pred_dst_dst_list.append(gpu_pred_dst_dst)
                gpu_pred_src_dst_list.append(gpu_pred_src_dst)

                gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm)
                gpu_pred_src_dstm_list.append(gpu_pred_src_dstm)

                gpu_target_srcm_blur = nn.tf_gaussian_blur(gpu_target_srcm, max(1, resolution // 32))
                gpu_target_dstm_blur = nn.tf_gaussian_blur(gpu_target_dstm, max(1, resolution // 32))

                gpu_target_dst_masked = gpu_target_dst * gpu_target_dstm_blur

                # With masked_training the image losses only see the region
                # selected by the blurred target masks.
                gpu_target_src_masked_opt = gpu_target_src * gpu_target_srcm_blur if masked_training else gpu_target_src
                gpu_target_dst_masked_opt = gpu_target_dst_masked if masked_training else gpu_target_dst

                gpu_pred_src_src_masked_opt = gpu_pred_src_src * gpu_target_srcm_blur if masked_training else gpu_pred_src_src
                gpu_pred_dst_dst_masked_opt = gpu_pred_dst_dst * gpu_target_dstm_blur if masked_training else gpu_pred_dst_dst

                # Loss = DSSIM term + squared-error term on the image
                #        + squared-error term on the mask.
                gpu_src_loss = tf.reduce_mean(
                    10 * nn.tf_dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt,
                                     max_val=1.0, filter_size=int(resolution/11.6)),
                    axis=[1])
                gpu_src_loss += tf.reduce_mean(
                    10 * tf.square(gpu_target_src_masked_opt - gpu_pred_src_src_masked_opt),
                    axis=[1, 2, 3])
                gpu_src_loss += tf.reduce_mean(
                    10 * tf.square(gpu_target_srcm - gpu_pred_src_srcm),
                    axis=[1, 2, 3])

                gpu_dst_loss = tf.reduce_mean(
                    10 * nn.tf_dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt,
                                     max_val=1.0, filter_size=int(resolution/11.6)),
                    axis=[1])
                gpu_dst_loss += tf.reduce_mean(
                    10 * tf.square(gpu_target_dst_masked_opt - gpu_pred_dst_dst_masked_opt),
                    axis=[1, 2, 3])
                gpu_dst_loss += tf.reduce_mean(
                    10 * tf.square(gpu_target_dstm - gpu_pred_dst_dstm),
                    axis=[1, 2, 3])

                gpu_src_losses += [gpu_src_loss]
                gpu_dst_losses += [gpu_dst_loss]

                gpu_G_loss = gpu_src_loss + gpu_dst_loss
                gpu_src_dst_loss_gvs += [nn.tf_gradients(gpu_G_loss, self.src_dst_trainable_weights)]

        # Average losses and gradients, and create optimizer update ops
        with tf.device(models_opt_device):
            pred_src_src = nn.tf_concat(gpu_pred_src_src_list, 0)
            pred_dst_dst = nn.tf_concat(gpu_pred_dst_dst_list, 0)
            pred_src_dst = nn.tf_concat(gpu_pred_src_dst_list, 0)
            pred_dst_dstm = nn.tf_concat(gpu_pred_dst_dstm_list, 0)
            pred_src_dstm = nn.tf_concat(gpu_pred_src_dstm_list, 0)

            src_loss = nn.tf_average_tensor_list(gpu_src_losses)
            dst_loss = nn.tf_average_tensor_list(gpu_dst_losses)
            src_dst_loss_gv = nn.tf_average_gv_list(gpu_src_dst_loss_gvs)
            src_dst_loss_gv_op = self.src_dst_opt.get_update_op(src_dst_loss_gv)

        # Initializing training and view functions
        def src_dst_train(warped_src, target_src, target_srcm,
                          warped_dst, target_dst, target_dstm):
            """Run one update step; return the mean src and dst losses."""
            s, d, _ = nn.tf_sess.run([src_loss, dst_loss, src_dst_loss_gv_op],
                                     feed_dict={self.warped_src: warped_src,
                                                self.target_src: target_src,
                                                self.target_srcm: target_srcm,
                                                self.warped_dst: warped_dst,
                                                self.target_dst: target_dst,
                                                self.target_dstm: target_dstm,
                                                })
            s = np.mean(s)
            d = np.mean(d)
            return s, d
        self.src_dst_train = src_dst_train

        def AE_view(warped_src, warped_dst):
            """Return [src->src, dst->dst, dst mask, dst->src, dst->src mask]."""
            return nn.tf_sess.run([pred_src_src, pred_dst_dst, pred_dst_dstm, pred_src_dst, pred_src_dstm],
                                  feed_dict={self.warped_src: warped_src,
                                             self.warped_dst: warped_dst})
        self.AE_view = AE_view
    else:
        # Initializing merge function
        with tf.device('/GPU:0' if len(devices) != 0 else '/CPU:0'):
            gpu_dst_code = self.inter(self.encoder(self.warped_dst))
            gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code)
            _, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)

        def AE_merge(warped_dst):
            """Return [dst->src image, dst mask, dst->src mask] for ``warped_dst``."""
            return nn.tf_sess.run([gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm],
                                  feed_dict={self.warped_dst: warped_dst})
        self.AE_merge = AE_merge

    # Loading/initializing all models/optimizers weights
    for model, filename in io.progress_bar_generator(self.model_filename_list, "Initializing models"):
        if self.pretrain_just_disabled:
            # Only the inter model is re-initialized in this case.
            do_init = False
            if model == self.inter:
                do_init = True
        else:
            do_init = self.is_first_run()

        if not do_init:
            # Fall back to initialization when the stored weights fail to load.
            do_init = not model.load_weights(self.get_strpath_storage_for_file(filename))

        if do_init and self.pretrained_model_path is not None:
            pretrained_filepath = self.pretrained_model_path / filename
            if pretrained_filepath.exists():
                do_init = not model.load_weights(pretrained_filepath)

        if do_init:
            model.init_weights()

    # initializing sample generators
    if self.is_training:
        t = SampleProcessor.Types
        face_type = t.FACE_TYPE_FULL

        training_data_src_path = self.training_data_src_path if not self.pretrain else self.get_pretraining_data_path()
        training_data_dst_path = self.training_data_dst_path if not self.pretrain else self.get_pretraining_data_path()

        cpu_count = min(multiprocessing.cpu_count(), 8)
        # cpu_count // 2 is 0 on a single-core host; clamp so that each
        # generator is always asked for at least one worker.
        src_generators_count = max(1, cpu_count // 2)
        dst_generators_count = max(1, cpu_count // 2)

        def make_output_sample_types():
            """Return a fresh [warped image, target image, target mask] spec list."""
            return [
                {'types': (t.IMG_WARPED_TRANSFORMED, face_type, t.MODE_BGR),
                 'data_format': nn.data_format, 'resolution': resolution},
                {'types': (t.IMG_TRANSFORMED, face_type, t.MODE_BGR),
                 'data_format': nn.data_format, 'resolution': resolution},
                {'types': (t.IMG_TRANSFORMED, face_type, t.MODE_M),
                 'data_format': nn.data_format, 'resolution': resolution},
            ]

        self.set_training_data_generators([
            SampleGeneratorFace(training_data_src_path,
                                debug=self.is_debug(),
                                batch_size=self.get_batch_size(),
                                sample_process_options=SampleProcessor.Options(
                                    random_flip=True if self.pretrain else False),
                                output_sample_types=make_output_sample_types(),
                                generators_count=src_generators_count),

            SampleGeneratorFace(training_data_dst_path,
                                debug=self.is_debug(),
                                batch_size=self.get_batch_size(),
                                sample_process_options=SampleProcessor.Options(
                                    random_flip=True if self.pretrain else False),
                                output_sample_types=make_output_sample_types(),
                                generators_count=dst_generators_count),
        ])

        self.last_samples = None
def on_initialize(self): device_config = nn.getCurrentDeviceConfig() self.model_data_format = "NCHW" if len(device_config.devices) != 0 and not self.is_debug() else "NHWC" nn.initialize(floatx="float16" if self.options['use_float16'] else "float32", data_format=self.model_data_format) tf = nn.tf conv_kernel_initializer = nn.initializers.ca() class Downscale(nn.ModelBase): def __init__(self, in_ch, out_ch, kernel_size=5, dilations=1, subpixel=True, use_activator=True, *kwargs ): self.in_ch = in_ch self.out_ch = out_ch self.kernel_size = kernel_size self.dilations = dilations self.subpixel = subpixel self.use_activator = use_activator super().__init__(*kwargs) def on_build(self, *args, **kwargs ): self.conv1 = nn.Conv2D( self.in_ch, self.out_ch // (4 if self.subpixel else 1), kernel_size=self.kernel_size, strides=1 if self.subpixel else 2, padding='SAME', dilations=self.dilations, kernel_initializer=conv_kernel_initializer) def forward(self, x): x = self.conv1(x) if self.subpixel: x = nn.tf_space_to_depth(x, 2) if self.use_activator: x = tf.nn.leaky_relu(x, 0.1) return x def get_out_ch(self): return (self.out_ch // 4) * 4 class DownscaleBlock(nn.ModelBase): def on_build(self, in_ch, ch, n_downscales, kernel_size, dilations=1, subpixel=True): self.downs = [] last_ch = in_ch for i in range(n_downscales): cur_ch = ch*( min(2**i, 8) ) self.downs.append ( Downscale(last_ch, cur_ch, kernel_size=kernel_size, dilations=dilations, subpixel=subpixel) ) last_ch = self.downs[-1].get_out_ch() def forward(self, inp): x = inp for down in self.downs: x = down(x) return x class Upscale(nn.ModelBase): def on_build(self, in_ch, out_ch, kernel_size=3 ): self.conv1 = nn.Conv2D( in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME', kernel_initializer=conv_kernel_initializer) def forward(self, x): x = self.conv1(x) x = tf.nn.leaky_relu(x, 0.1) x = nn.tf_depth_to_space(x, 2) return x class ResidualBlock(nn.ModelBase): def on_build(self, ch, kernel_size=3 ): self.conv1 = nn.Conv2D( ch, 
ch, kernel_size=kernel_size, padding='SAME', kernel_initializer=conv_kernel_initializer) self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', kernel_initializer=conv_kernel_initializer) def forward(self, inp): x = self.conv1(inp) x = tf.nn.leaky_relu(x, 0.2) x = self.conv2(x) x = tf.nn.leaky_relu(inp + x, 0.2) return x class UpdownResidualBlock(nn.ModelBase): def on_build(self, ch, inner_ch, kernel_size=3 ): self.up = Upscale (ch, inner_ch, kernel_size=kernel_size) self.res = ResidualBlock (inner_ch, kernel_size=kernel_size) self.down = Downscale (inner_ch, ch, kernel_size=kernel_size, use_activator=False) def forward(self, inp): x = self.up(inp) x = upx = self.res(x) x = self.down(x) x = x + inp x = tf.nn.leaky_relu(x, 0.2) return x, upx class Encoder(nn.ModelBase): def on_build(self, in_ch, e_ch, is_hd): self.is_hd=is_hd if self.is_hd: self.down1 = DownscaleBlock(in_ch, e_ch*2, n_downscales=4, kernel_size=3, dilations=1) self.down2 = DownscaleBlock(in_ch, e_ch*2, n_downscales=4, kernel_size=5, dilations=1) self.down3 = DownscaleBlock(in_ch, e_ch//2, n_downscales=4, kernel_size=5, dilations=2) self.down4 = DownscaleBlock(in_ch, e_ch//2, n_downscales=4, kernel_size=7, dilations=2) else: self.down1 = DownscaleBlock(in_ch, e_ch, n_downscales=4, kernel_size=5, dilations=1, subpixel=False) def forward(self, inp): if self.is_hd: x = tf.concat([ nn.tf_flatten(self.down1(inp)), nn.tf_flatten(self.down2(inp)), nn.tf_flatten(self.down3(inp)), nn.tf_flatten(self.down4(inp)) ], -1 ) else: x = nn.tf_flatten(self.down1(inp)) return x class Inter(nn.ModelBase): def __init__(self, in_ch, lowest_dense_res, ae_ch, ae_out_ch, **kwargs): self.in_ch, self.lowest_dense_res, self.ae_ch, self.ae_out_ch = in_ch, lowest_dense_res, ae_ch, ae_out_ch super().__init__(**kwargs) def on_build(self): in_ch, lowest_dense_res, ae_ch, ae_out_ch = self.in_ch, self.lowest_dense_res, self.ae_ch, self.ae_out_ch self.dense1 = nn.Dense( in_ch, ae_ch ) self.dense2 = nn.Dense( ae_ch, 
lowest_dense_res * lowest_dense_res * ae_out_ch ) self.upscale1 = Upscale(ae_out_ch, ae_out_ch) def forward(self, inp): x = self.dense1(inp) x = self.dense2(x) x = nn.tf_reshape_4D (x, lowest_dense_res, lowest_dense_res, self.ae_out_ch) x = self.upscale1(x) return x def get_out_ch(self): return self.ae_out_ch class Decoder(nn.ModelBase): def on_build(self, in_ch, d_ch, d_mask_ch, is_hd ): self.is_hd = is_hd self.upscale0 = Upscale(in_ch, d_ch*8, kernel_size=3) self.upscale1 = Upscale(d_ch*8, d_ch*4, kernel_size=3) self.upscale2 = Upscale(d_ch*4, d_ch*2, kernel_size=3) if is_hd: self.res0 = UpdownResidualBlock(in_ch, d_ch*8, kernel_size=3) self.res1 = UpdownResidualBlock(d_ch*8, d_ch*4, kernel_size=3) self.res2 = UpdownResidualBlock(d_ch*4, d_ch*2, kernel_size=3) self.res3 = UpdownResidualBlock(d_ch*2, d_ch, kernel_size=3) else: self.res0 = ResidualBlock(d_ch*8, kernel_size=3) self.res1 = ResidualBlock(d_ch*4, kernel_size=3) self.res2 = ResidualBlock(d_ch*2, kernel_size=3) self.out_conv = nn.Conv2D( d_ch*2, 3, kernel_size=1, padding='SAME', kernel_initializer=conv_kernel_initializer) self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3) self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3) self.upscalem2 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3) self.out_convm = nn.Conv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME', kernel_initializer=conv_kernel_initializer) def get_weights_ex(self, include_mask): # Call internal get_weights in order to initialize inner logic self.get_weights() weights = self.upscale0.get_weights() + self.upscale1.get_weights() + self.upscale2.get_weights() \ + self.res0.get_weights() + self.res1.get_weights() + self.res2.get_weights() + self.out_conv.get_weights() if include_mask: weights += self.upscalem0.get_weights() + self.upscalem1.get_weights() + self.upscalem2.get_weights() \ + self.out_convm.get_weights() return weights def forward(self, inp): z = inp if self.is_hd: x, upx = self.res0(z) x = self.upscale0(x) x 
= tf.nn.leaky_relu(x + upx, 0.2) x, upx = self.res1(x) x = self.upscale1(x) x = tf.nn.leaky_relu(x + upx, 0.2) x, upx = self.res2(x) x = self.upscale2(x) x = tf.nn.leaky_relu(x + upx, 0.2) x, upx = self.res3(x) else: x = self.upscale0(z) x = self.res0(x) x = self.upscale1(x) x = self.res1(x) x = self.upscale2(x) x = self.res2(x) m = self.upscalem0(z) m = self.upscalem1(m) m = self.upscalem2(m) return tf.nn.sigmoid(self.out_conv(x)), \ tf.nn.sigmoid(self.out_convm(m)) class CodeDiscriminator(nn.ModelBase): def on_build(self, in_ch, code_res, ch=256): n_downscales = 1 + code_res // 8 self.convs = [] prev_ch = in_ch for i in range(n_downscales): cur_ch = ch * min( (2**i), 8 ) self.convs.append ( nn.Conv2D( prev_ch, cur_ch, kernel_size=4 if i == 0 else 3, strides=2, padding='SAME', kernel_initializer=conv_kernel_initializer) ) prev_ch = cur_ch self.out_conv = nn.Conv2D( prev_ch, 1, kernel_size=1, padding='VALID', kernel_initializer=conv_kernel_initializer) def forward(self, x): for conv in self.convs: x = tf.nn.leaky_relu( conv(x), 0.1 ) return self.out_conv(x) device_config = nn.getCurrentDeviceConfig() devices = device_config.devices self.resolution = resolution = self.options['resolution'] learn_mask = self.options['learn_mask'] archi = self.options['archi'] ae_dims = self.options['ae_dims'] e_dims = self.options['e_dims'] d_dims = self.options['d_dims'] d_mask_dims = self.options['d_mask_dims'] self.pretrain = self.options['pretrain'] if self.pretrain_just_disabled: self.set_iter(0) self.gan_power = gan_power = self.options['gan_power'] if not self.pretrain else 0.0 masked_training = True models_opt_on_gpu = False if len(devices) == 0 else self.options['models_opt_on_gpu'] models_opt_device = '/GPU:0' if models_opt_on_gpu and self.is_training else '/CPU:0' optimizer_vars_on_cpu = models_opt_device=='/CPU:0' input_ch = 3 output_ch = 3 bgr_shape = nn.get4Dshape(resolution,resolution,input_ch) mask_shape = nn.get4Dshape(resolution,resolution,1) lowest_dense_res = 
resolution // 16 self.model_filename_list = [] with tf.device ('/CPU:0'): #Place holders on CPU self.warped_src = tf.placeholder (nn.tf_floatx, bgr_shape) self.warped_dst = tf.placeholder (nn.tf_floatx, bgr_shape) self.target_src = tf.placeholder (nn.tf_floatx, bgr_shape) self.target_dst = tf.placeholder (nn.tf_floatx, bgr_shape) self.target_srcm = tf.placeholder (nn.tf_floatx, mask_shape) self.target_dstm = tf.placeholder (nn.tf_floatx, mask_shape) # Initializing model classes with tf.device (models_opt_device): if 'df' in archi: self.encoder = Encoder(in_ch=input_ch, e_ch=e_dims, is_hd='hd' in archi, name='encoder') encoder_out_ch = self.encoder.compute_output_channels ( (nn.tf_floatx, bgr_shape)) self.inter = Inter (in_ch=encoder_out_ch, lowest_dense_res=lowest_dense_res, ae_ch=ae_dims, ae_out_ch=ae_dims, name='inter') inter_out_ch = self.inter.compute_output_channels ( (nn.tf_floatx, (None,encoder_out_ch))) self.decoder_src = Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, is_hd='hd' in archi, name='decoder_src') self.decoder_dst = Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, is_hd='hd' in archi, name='decoder_dst') self.model_filename_list += [ [self.encoder, 'encoder.npy' ], [self.inter, 'inter.npy' ], [self.decoder_src, 'decoder_src.npy'], [self.decoder_dst, 'decoder_dst.npy'] ] if self.is_training: if self.options['true_face_power'] != 0: self.code_discriminator = CodeDiscriminator(ae_dims, code_res=lowest_dense_res*2, name='dis' ) self.model_filename_list += [ [self.code_discriminator, 'code_discriminator.npy'] ] elif 'liae' in archi: self.encoder = Encoder(in_ch=input_ch, e_ch=e_dims, is_hd='hd' in archi, name='encoder') encoder_out_ch = self.encoder.compute_output_channels ( (nn.tf_floatx, bgr_shape)) self.inter_AB = Inter(in_ch=encoder_out_ch, lowest_dense_res=lowest_dense_res, ae_ch=ae_dims, ae_out_ch=ae_dims*2, name='inter_AB') self.inter_B = Inter(in_ch=encoder_out_ch, lowest_dense_res=lowest_dense_res, 
ae_ch=ae_dims, ae_out_ch=ae_dims*2, name='inter_B') inter_AB_out_ch = self.inter_AB.compute_output_channels ( (nn.tf_floatx, (None,encoder_out_ch))) inter_B_out_ch = self.inter_B.compute_output_channels ( (nn.tf_floatx, (None,encoder_out_ch))) inters_out_ch = inter_AB_out_ch+inter_B_out_ch self.decoder = Decoder(in_ch=inters_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, is_hd='hd' in archi, name='decoder') self.model_filename_list += [ [self.encoder, 'encoder.npy'], [self.inter_AB, 'inter_AB.npy'], [self.inter_B , 'inter_B.npy'], [self.decoder , 'decoder.npy'] ] if self.is_training: if gan_power != 0: self.D_src = nn.PatchDiscriminator(patch_size=resolution//16, in_ch=output_ch, base_ch=512, name="D_src") self.D_dst = nn.PatchDiscriminator(patch_size=resolution//16, in_ch=output_ch, base_ch=512, name="D_dst") self.model_filename_list += [ [self.D_src, 'D_src.npy'] ] self.model_filename_list += [ [self.D_dst, 'D_dst.npy'] ] # Initialize optimizers lr=5e-5 lr_dropout = 0.3 if self.options['lr_dropout'] else 1.0 clipnorm = 1.0 if self.options['clipgrad'] else 0.0 self.src_dst_opt = nn.TFRMSpropOptimizer(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='src_dst_opt') self.model_filename_list += [ (self.src_dst_opt, 'src_dst_opt.npy') ] if 'df' in archi: self.src_dst_all_trainable_weights = self.encoder.get_weights() + self.inter.get_weights() + self.decoder_src.get_weights() + self.decoder_dst.get_weights() self.src_dst_trainable_weights = self.encoder.get_weights() + self.inter.get_weights() + self.decoder_src.get_weights_ex(learn_mask) + self.decoder_dst.get_weights_ex(learn_mask) elif 'liae' in archi: self.src_dst_all_trainable_weights = self.encoder.get_weights() + self.inter_AB.get_weights() + self.inter_B.get_weights() + self.decoder.get_weights() self.src_dst_trainable_weights = self.encoder.get_weights() + self.inter_AB.get_weights() + self.inter_B.get_weights() + self.decoder.get_weights_ex(learn_mask) self.src_dst_opt.initialize_variables 
(self.src_dst_all_trainable_weights, vars_on_cpu=optimizer_vars_on_cpu) if self.options['true_face_power'] != 0: self.D_code_opt = nn.TFRMSpropOptimizer(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='D_code_opt') self.D_code_opt.initialize_variables ( self.code_discriminator.get_weights(), vars_on_cpu=optimizer_vars_on_cpu) self.model_filename_list += [ (self.D_code_opt, 'D_code_opt.npy') ] if gan_power != 0: self.D_src_dst_opt = nn.TFRMSpropOptimizer(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='D_src_dst_opt') self.D_src_dst_opt.initialize_variables ( self.D_src.get_weights()+self.D_dst.get_weights(), vars_on_cpu=optimizer_vars_on_cpu) self.model_filename_list += [ (self.D_src_dst_opt, 'D_src_dst_opt.npy') ] if self.is_training: # Adjust batch size for multiple GPU gpu_count = max(1, len(devices) ) bs_per_gpu = max(1, self.get_batch_size() // gpu_count) self.set_batch_size( gpu_count*bs_per_gpu) # Compute losses per GPU gpu_pred_src_src_list = [] gpu_pred_dst_dst_list = [] gpu_pred_src_dst_list = [] gpu_pred_src_srcm_list = [] gpu_pred_dst_dstm_list = [] gpu_pred_src_dstm_list = [] gpu_src_losses = [] gpu_dst_losses = [] gpu_G_loss_gvs = [] gpu_D_code_loss_gvs = [] gpu_D_src_dst_loss_gvs = [] for gpu_id in range(gpu_count): with tf.device( f'/GPU:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ): with tf.device(f'/CPU:0'): # slice on CPU, otherwise all batch data will be transfered to GPU first batch_slice = slice( gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu ) gpu_warped_src = self.warped_src [batch_slice,:,:,:] gpu_warped_dst = self.warped_dst [batch_slice,:,:,:] gpu_target_src = self.target_src [batch_slice,:,:,:] gpu_target_dst = self.target_dst [batch_slice,:,:,:] gpu_target_srcm = self.target_srcm[batch_slice,:,:,:] gpu_target_dstm = self.target_dstm[batch_slice,:,:,:] # process model tensors if 'df' in archi: gpu_src_code = self.inter(self.encoder(gpu_warped_src)) gpu_dst_code = self.inter(self.encoder(gpu_warped_dst)) gpu_pred_src_src, 
gpu_pred_src_srcm = self.decoder_src(gpu_src_code) gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code) gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code) elif 'liae' in archi: gpu_src_code = self.encoder (gpu_warped_src) gpu_src_inter_AB_code = self.inter_AB (gpu_src_code) gpu_src_code = tf.concat([gpu_src_inter_AB_code,gpu_src_inter_AB_code], nn.conv2d_ch_axis ) gpu_dst_code = self.encoder (gpu_warped_dst) gpu_dst_inter_B_code = self.inter_B (gpu_dst_code) gpu_dst_inter_AB_code = self.inter_AB (gpu_dst_code) gpu_dst_code = tf.concat([gpu_dst_inter_B_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis ) gpu_src_dst_code = tf.concat([gpu_dst_inter_AB_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis ) gpu_pred_src_src, gpu_pred_src_srcm = self.decoder(gpu_src_code) gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder(gpu_dst_code) gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code) gpu_pred_src_src_list.append(gpu_pred_src_src) gpu_pred_dst_dst_list.append(gpu_pred_dst_dst) gpu_pred_src_dst_list.append(gpu_pred_src_dst) gpu_pred_src_srcm_list.append(gpu_pred_src_srcm) gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm) gpu_pred_src_dstm_list.append(gpu_pred_src_dstm) gpu_target_srcm_blur = nn.tf_gaussian_blur(gpu_target_srcm, max(1, resolution // 32) ) gpu_target_dstm_blur = nn.tf_gaussian_blur(gpu_target_dstm, max(1, resolution // 32) ) gpu_target_dst_masked = gpu_target_dst*gpu_target_dstm_blur gpu_target_dst_anti_masked = gpu_target_dst*(1.0 - gpu_target_dstm_blur) gpu_target_src_masked_opt = gpu_target_src*gpu_target_srcm_blur if masked_training else gpu_target_src gpu_target_dst_masked_opt = gpu_target_dst_masked if masked_training else gpu_target_dst gpu_pred_src_src_masked_opt = gpu_pred_src_src*gpu_target_srcm_blur if masked_training else gpu_pred_src_src gpu_pred_dst_dst_masked_opt = gpu_pred_dst_dst*gpu_target_dstm_blur if masked_training else gpu_pred_dst_dst gpu_psd_target_dst_masked = 
gpu_pred_src_dst*gpu_target_dstm_blur gpu_psd_target_dst_anti_masked = gpu_pred_src_dst*(1.0 - gpu_target_dstm_blur) gpu_src_loss = tf.reduce_mean ( 10*nn.tf_dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1]) gpu_src_loss += tf.reduce_mean ( 10*tf.square ( gpu_target_src_masked_opt - gpu_pred_src_src_masked_opt ), axis=[1,2,3]) if learn_mask: gpu_src_loss += tf.reduce_mean ( 10*tf.square( gpu_target_srcm - gpu_pred_src_srcm ),axis=[1,2,3] ) face_style_power = self.options['face_style_power'] / 100.0 if face_style_power != 0 and not self.pretrain: gpu_src_loss += nn.tf_style_loss(gpu_psd_target_dst_masked, gpu_target_dst_masked, gaussian_blur_radius=resolution//16, loss_weight=10000*face_style_power) bg_style_power = self.options['bg_style_power'] / 100.0 if bg_style_power != 0 and not self.pretrain: gpu_src_loss += tf.reduce_mean( (10*bg_style_power)*nn.tf_dssim(gpu_psd_target_dst_anti_masked, gpu_target_dst_anti_masked, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1]) gpu_src_loss += tf.reduce_mean( (10*bg_style_power)*tf.square( gpu_psd_target_dst_anti_masked - gpu_target_dst_anti_masked), axis=[1,2,3] ) gpu_dst_loss = tf.reduce_mean ( 10*nn.tf_dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1]) gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dst_masked_opt- gpu_pred_dst_dst_masked_opt ), axis=[1,2,3]) if learn_mask: gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dstm - gpu_pred_dst_dstm ),axis=[1,2,3] ) gpu_src_losses += [gpu_src_loss] gpu_dst_losses += [gpu_dst_loss] gpu_G_loss = gpu_src_loss + gpu_dst_loss def DLoss(labels,logits): return tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits), axis=[1,2,3]) if self.options['true_face_power'] != 0: gpu_src_code_d = self.code_discriminator( gpu_src_code ) gpu_src_code_d_ones = tf.ones_like (gpu_src_code_d) 
gpu_src_code_d_zeros = tf.zeros_like(gpu_src_code_d) gpu_dst_code_d = self.code_discriminator( gpu_dst_code ) gpu_dst_code_d_ones = tf.ones_like(gpu_dst_code_d) gpu_G_loss += self.options['true_face_power']*DLoss(gpu_src_code_d_ones, gpu_src_code_d) gpu_D_code_loss = (DLoss(gpu_src_code_d_ones , gpu_dst_code_d) + \ DLoss(gpu_src_code_d_zeros, gpu_src_code_d) ) * 0.5 gpu_D_code_loss_gvs += [ nn.tf_gradients (gpu_D_code_loss, self.code_discriminator.get_weights() ) ] if gan_power != 0: gpu_pred_src_src_d = self.D_src(gpu_pred_src_src_masked_opt) gpu_pred_src_src_d_ones = tf.ones_like (gpu_pred_src_src_d) gpu_pred_src_src_d_zeros = tf.zeros_like(gpu_pred_src_src_d) gpu_target_src_d = self.D_src(gpu_target_src_masked_opt) gpu_target_src_d_ones = tf.ones_like(gpu_target_src_d) gpu_pred_dst_dst_d = self.D_dst(gpu_pred_dst_dst_masked_opt) gpu_pred_dst_dst_d_ones = tf.ones_like (gpu_pred_dst_dst_d) gpu_pred_dst_dst_d_zeros = tf.zeros_like(gpu_pred_dst_dst_d) gpu_target_dst_d = self.D_dst(gpu_target_dst_masked_opt) gpu_target_dst_d_ones = tf.ones_like(gpu_target_dst_d) gpu_D_src_dst_loss = (DLoss(gpu_target_src_d_ones , gpu_target_src_d) + \ DLoss(gpu_pred_src_src_d_zeros, gpu_pred_src_src_d) ) * 0.5 + \ (DLoss(gpu_target_dst_d_ones , gpu_target_dst_d) + \ DLoss(gpu_pred_dst_dst_d_zeros, gpu_pred_dst_dst_d) ) * 0.5 gpu_D_src_dst_loss_gvs += [ nn.tf_gradients (gpu_D_src_dst_loss, self.D_src.get_weights()+self.D_dst.get_weights() ) ] gpu_G_loss += gan_power*(DLoss(gpu_pred_src_src_d_ones, gpu_pred_src_src_d) + DLoss(gpu_pred_dst_dst_d_ones, gpu_pred_dst_dst_d)) gpu_G_loss_gvs += [ nn.tf_gradients ( gpu_G_loss, self.src_dst_trainable_weights ) ] # Average losses and gradients, and create optimizer update ops with tf.device (models_opt_device): pred_src_src = nn.tf_concat(gpu_pred_src_src_list, 0) pred_dst_dst = nn.tf_concat(gpu_pred_dst_dst_list, 0) pred_src_dst = nn.tf_concat(gpu_pred_src_dst_list, 0) pred_src_srcm = nn.tf_concat(gpu_pred_src_srcm_list, 0) pred_dst_dstm = 
nn.tf_concat(gpu_pred_dst_dstm_list, 0) pred_src_dstm = nn.tf_concat(gpu_pred_src_dstm_list, 0) src_loss = nn.tf_average_tensor_list(gpu_src_losses) dst_loss = nn.tf_average_tensor_list(gpu_dst_losses) src_dst_loss_gv_op = self.src_dst_opt.get_update_op (nn.tf_average_gv_list (gpu_G_loss_gvs)) if self.options['true_face_power'] != 0: D_loss_gv_op = self.D_code_opt.get_update_op (nn.tf_average_gv_list(gpu_D_code_loss_gvs)) if gan_power != 0: src_D_src_dst_loss_gv_op = self.D_src_dst_opt.get_update_op (nn.tf_average_gv_list(gpu_D_src_dst_loss_gvs) ) # Initializing training and view functions def src_dst_train(warped_src, target_src, target_srcm, \ warped_dst, target_dst, target_dstm): s, d, _ = nn.tf_sess.run ( [ src_loss, dst_loss, src_dst_loss_gv_op], feed_dict={self.warped_src :warped_src, self.target_src :target_src, self.target_srcm:target_srcm, self.warped_dst :warped_dst, self.target_dst :target_dst, self.target_dstm:target_dstm, }) s = np.mean(s) d = np.mean(d) return s, d self.src_dst_train = src_dst_train if self.options['true_face_power'] != 0: def D_train(warped_src, warped_dst): nn.tf_sess.run ([D_loss_gv_op], feed_dict={self.warped_src: warped_src, self.warped_dst: warped_dst}) self.D_train = D_train if gan_power != 0: def D_src_dst_train(warped_src, target_src, target_srcm, \ warped_dst, target_dst, target_dstm): nn.tf_sess.run ([src_D_src_dst_loss_gv_op], feed_dict={self.warped_src :warped_src, self.target_src :target_src, self.target_srcm:target_srcm, self.warped_dst :warped_dst, self.target_dst :target_dst, self.target_dstm:target_dstm}) self.D_src_dst_train = D_src_dst_train if learn_mask: def AE_view(warped_src, warped_dst): return nn.tf_sess.run ( [pred_src_src, pred_dst_dst, pred_dst_dstm, pred_src_dst, pred_src_dstm], feed_dict={self.warped_src:warped_src, self.warped_dst:warped_dst}) else: def AE_view(warped_src, warped_dst): return nn.tf_sess.run ( [pred_src_src, pred_dst_dst, pred_src_dst], feed_dict={self.warped_src:warped_src, 
self.warped_dst:warped_dst}) self.AE_view = AE_view else: # Initializing merge function with tf.device( f'/GPU:0' if len(devices) != 0 else f'/CPU:0'): if 'df' in archi: gpu_dst_code = self.inter(self.encoder(self.warped_dst)) gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code) _, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code) elif 'liae' in archi: gpu_dst_code = self.encoder (self.warped_dst) gpu_dst_inter_B_code = self.inter_B (gpu_dst_code) gpu_dst_inter_AB_code = self.inter_AB (gpu_dst_code) gpu_dst_code = tf.concat([gpu_dst_inter_B_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis) gpu_src_dst_code = tf.concat([gpu_dst_inter_AB_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis) gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code) _, gpu_pred_dst_dstm = self.decoder(gpu_dst_code) if learn_mask: def AE_merge( warped_dst): return nn.tf_sess.run ( [gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm], feed_dict={self.warped_dst:warped_dst}) else: def AE_merge( warped_dst): return nn.tf_sess.run ( [gpu_pred_src_dst], feed_dict={self.warped_dst:warped_dst}) self.AE_merge = AE_merge # Loading/initializing all models/optimizers weights for model, filename in io.progress_bar_generator(self.model_filename_list, "Initializing models"): if self.pretrain_just_disabled: do_init = False if 'df' in archi: if model == self.inter: do_init = True elif 'liae' in archi: if model == self.inter_AB: do_init = True else: do_init = self.is_first_run() if not do_init: do_init = not model.load_weights( self.get_strpath_storage_for_file(filename) ) if do_init: model.init_weights() # initializing sample generators if self.is_training: t = SampleProcessor.Types if self.options['face_type'] == 'h': face_type = t.FACE_TYPE_HALF elif self.options['face_type'] == 'mf': face_type = t.FACE_TYPE_MID_FULL elif self.options['face_type'] == 'f': face_type = t.FACE_TYPE_FULL training_data_src_path = self.training_data_src_path if not self.pretrain else 
self.get_pretraining_data_path() training_data_dst_path = self.training_data_dst_path if not self.pretrain else self.get_pretraining_data_path() random_ct_samples_path=training_data_dst_path if self.options['ct_mode'] != 'none' and not self.pretrain else None t_img_warped = t.IMG_WARPED_TRANSFORMED if self.options['random_warp'] else t.IMG_TRANSFORMED cpu_count = min(multiprocessing.cpu_count(), 8) src_generators_count = cpu_count // 2 dst_generators_count = cpu_count // 2 if self.options['ct_mode'] != 'none': src_generators_count = int(src_generators_count * 1.5) self.set_training_data_generators ([ SampleGeneratorFace(training_data_src_path, random_ct_samples_path=random_ct_samples_path, debug=self.is_debug(), batch_size=self.get_batch_size(), sample_process_options=SampleProcessor.Options(random_flip=self.random_flip), output_sample_types = [ {'types' : (t_img_warped, face_type, t.MODE_BGR), 'data_format':nn.data_format, 'resolution': resolution, 'ct_mode': self.options['ct_mode'] }, {'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_BGR), 'data_format':nn.data_format, 'resolution': resolution, 'ct_mode': self.options['ct_mode'] }, {'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_M), 'data_format':nn.data_format, 'resolution': resolution } ], generators_count=src_generators_count ), SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(), sample_process_options=SampleProcessor.Options(random_flip=self.random_flip), output_sample_types = [ {'types' : (t_img_warped, face_type, t.MODE_BGR), 'data_format':nn.data_format, 'resolution': resolution}, {'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_BGR), 'data_format':nn.data_format, 'resolution': resolution}, {'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_M), 'data_format':nn.data_format, 'resolution': resolution} ], generators_count=dst_generators_count ) ]) if self.pretrain_just_disabled: self.update_sample_for_preview(force_new=True)