def on_initialize(self):
    """Build the progressive-growing face-swap autoencoder for the current stage.

    Defines the stage-aware building blocks (EncBlock/DecBlock, FromRGB/ToRGB),
    assembles encoder + shared `inter` + per-identity decoders, creates the
    per-GPU loss/gradient graph and the train/view (training) or merge
    (inference) callables, then loads or initializes all model/optimizer
    weights and, when training, the sample generators.
    """
    nn.initialize()
    tf = nn.tf

    class EncBlock(nn.ModelBase):
        # Two convs + leaky-relu; every level except 0 downsamples via max-pool.
        # At level 0 the second conv is 4x4/'VALID', collapsing the remaining
        # spatial extent instead of pooling.
        def on_build(self, in_ch, out_ch, level):
            self.zero_level = level == 0
            self.conv1 = nn.Conv2D(in_ch, out_ch, kernel_size=3, padding='SAME')
            self.conv2 = nn.Conv2D(out_ch, out_ch,
                                   kernel_size=4 if self.zero_level else 3,
                                   padding='VALID' if self.zero_level else 'SAME')

        def forward(self, x):
            x = tf.nn.leaky_relu(self.conv1(x), 0.2)
            x = tf.nn.leaky_relu(self.conv2(x), 0.2)
            if not self.zero_level:
                x = nn.max_pool(x)
            return x

    class DecBlock(nn.ModelBase):
        # Mirror of EncBlock: upsample (except at level 0), then two convs.
        def on_build(self, in_ch, out_ch, level):
            self.zero_level = level == 0
            self.conv1 = nn.Conv2D(in_ch, out_ch,
                                   kernel_size=4 if self.zero_level else 3,
                                   padding=3 if self.zero_level else 'SAME')
            self.conv2 = nn.Conv2D(out_ch, out_ch, kernel_size=3, padding='SAME')

        def forward(self, x):
            if not self.zero_level:
                x = nn.upsample2d(x)
            x = tf.nn.leaky_relu(self.conv1(x), 0.2)
            x = tf.nn.leaky_relu(self.conv2(x), 0.2)
            return x

    class FromRGB(nn.ModelBase):
        # 1x1 conv projecting a 3-channel image into feature space.
        def on_build(self, out_ch):
            self.conv1 = nn.Conv2D(3, out_ch, kernel_size=1, padding='SAME')

        def forward(self, x):
            return tf.nn.leaky_relu(self.conv1(x), 0.2)

    class ToRGB(nn.ModelBase):
        # 1x1 convs producing (bgr, mask) predictions squashed to [0, 1].
        def on_build(self, in_ch):
            self.conv = nn.Conv2D(in_ch, 3, kernel_size=1, padding='SAME')
            self.convm = nn.Conv2D(in_ch, 1, kernel_size=1, padding='SAME')

        def forward(self, x):
            return tf.nn.sigmoid(self.conv(x)), tf.nn.sigmoid(self.convm(x))

    class Encoder(nn.ModelBase):
        def on_build(self, e_ch, levels):
            self.enc_blocks = {}
            self.from_rgbs = {}
            self.dense_norm = nn.DenseNorm()
            in_ch = e_ch
            out_ch = in_ch
            # Channel count doubles per level, capped at 512.
            for level in range(levels, -1, -1):
                self.max_ch = out_ch = np.clip(out_ch * 2, 0, 512)
                self.enc_blocks[level] = EncBlock(in_ch, out_ch, level)
                self.from_rgbs[level] = FromRGB(in_ch)
                in_ch = out_ch

        def forward(self, inp, stage):
            x = inp
            for level in range(stage, -1, -1):
                # BUGFIX: was `if stage in self.enc_blocks`, which tested a
                # loop-invariant key instead of the current level.
                if level in self.enc_blocks:
                    # FromRGB is applied only at the entry (= stage) level.
                    if level == stage:
                        x = self.from_rgbs[level](x)
                    x = self.enc_blocks[level](x)
            x = nn.flatten(x)
            x = self.dense_norm(x)
            x = nn.reshape_4D(x, 1, 1, self.max_ch)
            return x

        def get_stage_weights(self, stage):
            # Call internal get_weights in order to initialize inner logic.
            self.get_weights()
            weights = []
            for level in range(stage, -1, -1):
                # BUGFIX: was `if stage in self.enc_blocks` (see forward()).
                if level in self.enc_blocks:
                    if level == stage:
                        weights.append(self.from_rgbs[level].get_weights())
                    weights.append(self.enc_blocks[level].get_weights())
            if len(weights) == 0:
                return []
            elif len(weights) == 1:
                return weights[0]
            else:
                return sum(weights[1:], weights[0])

    class Decoder(nn.ModelBase):
        def on_build(self, d_ch, total_levels, levels_range):
            self.dec_blocks = {}
            self.to_rgbs = {}
            level_ch = {}
            ch = d_ch
            # Per-level widths, doubling toward the bottom, capped at 512.
            for level in range(total_levels, -2, -1):
                level_ch[level] = ch
                ch = np.clip(ch * 2, 0, 512)

            out_ch = level_ch[levels_range[1]]
            for level in range(levels_range[1], levels_range[0] - 1, -1):
                in_ch = level_ch[level - 1]
                self.dec_blocks[level] = DecBlock(in_ch, out_ch, level)
                self.to_rgbs[level] = ToRGB(out_ch)
                out_ch = in_ch

        def forward(self, inp, stage):
            x = inp
            for level in range(stage + 1):
                if level in self.dec_blocks:
                    x = self.dec_blocks[level](x)
                    # Only the topmost (= stage) level emits (bgr, mask).
                    if level == stage:
                        x = self.to_rgbs[level](x)
            return x

        def get_stage_weights(self, stage):
            # Call internal get_weights in order to initialize inner logic
            self.get_weights()
            weights = []
            for level in range(stage + 1):
                if level in self.dec_blocks:
                    weights.append(self.dec_blocks[level].get_weights())
                    if level == stage:
                        weights.append(self.to_rgbs[level].get_weights())
            if len(weights) == 0:
                return []
            elif len(weights) == 1:
                return weights[0]
            else:
                return sum(weights[1:], weights[0])

    device_config = nn.getCurrentDeviceConfig()
    devices = device_config.devices

    self.stage = stage = self.options['stage']
    self.start_stage_iter = self.options.get('start_stage_iter', 0)
    self.target_stage_iter = self.options.get('target_stage_iter', 0)

    # Stage s trains at resolution 2**(s+2): 4, 8, 16, ...
    stage_resolutions = [2**(i + 2) for i in range(self.stage_max + 1)]
    stage_resolution = stage_resolutions[stage]
    # BUGFIX: `resolution` was used below (blur size, dssim filter size,
    # sample-generator configs) but never defined -> NameError. It is the
    # current stage's resolution.
    resolution = stage_resolution

    ed_dims = 16  # base channel width for encoder/decoder

    self.pretrain = False
    self.pretrain_just_disabled = False

    masked_training = True

    models_opt_on_gpu = len(devices) == 1 and devices[0].total_mem_gb >= 4
    models_opt_device = '/GPU:0' if models_opt_on_gpu and self.is_training else '/CPU:0'
    optimizer_vars_on_cpu = models_opt_device == '/CPU:0'

    input_nc = 3
    output_nc = 3
    bgr_shape = (stage_resolution, stage_resolution, output_nc)
    mask_shape = (stage_resolution, stage_resolution, 1)

    self.model_filename_list = []

    with tf.device('/CPU:0'):
        #Place holders on CPU
        self.warped_src = tf.placeholder(tf.float32, (None, ) + bgr_shape)
        self.warped_dst = tf.placeholder(tf.float32, (None, ) + bgr_shape)
        self.target_src = tf.placeholder(tf.float32, (None, ) + bgr_shape)
        self.target_dst = tf.placeholder(tf.float32, (None, ) + bgr_shape)
        self.target_srcm = tf.placeholder(tf.float32, (None, ) + mask_shape)
        self.target_dstm = tf.placeholder(tf.float32, (None, ) + mask_shape)

    # Initializing model classes
    with tf.device(models_opt_device):
        self.encoder = Encoder(e_ch=ed_dims, levels=self.stage_max, name='encoder')
        # `inter` reuses the Decoder class for the shared low levels [0, 2];
        # identity-specific decoders own levels [3, stage_max].
        self.inter = Decoder(d_ch=ed_dims, total_levels=self.stage_max,
                             levels_range=[0, 2], name='inter')
        self.decoder_src = Decoder(d_ch=ed_dims, total_levels=self.stage_max,
                                   levels_range=[3, self.stage_max], name='decoder_src')
        self.decoder_dst = Decoder(d_ch=ed_dims, total_levels=self.stage_max,
                                   levels_range=[3, self.stage_max], name='decoder_dst')

        self.model_filename_list += [[self.encoder, 'encoder.npy'],
                                     [self.inter, 'inter.npy'],
                                     [self.decoder_src, 'decoder_src.npy'],
                                     [self.decoder_dst, 'decoder_dst.npy']]

        if self.is_training:
            # All weights go to the optimizer state; only the current
            # stage's sub-blocks receive gradients.
            self.src_dst_all_weights = self.encoder.get_weights() + self.inter.get_weights() \
                                     + self.decoder_src.get_weights() + self.decoder_dst.get_weights()
            self.src_dst_trainable_weights = self.encoder.get_stage_weights(stage) \
                                           + self.inter.get_stage_weights(stage) \
                                           + self.decoder_src.get_stage_weights(stage) \
                                           + self.decoder_dst.get_stage_weights(stage)

            # Initialize optimizers
            self.src_dst_opt = nn.RMSprop(lr=2e-4, lr_dropout=0.3, name='src_dst_opt')
            self.src_dst_opt.initialize_variables(self.src_dst_all_weights,
                                                  vars_on_cpu=optimizer_vars_on_cpu)
            self.model_filename_list += [(self.src_dst_opt, 'src_dst_opt.npy')]

    if self.is_training:
        # Adjust batch size for multiple GPU
        gpu_count = max(1, len(devices))
        bs_per_gpu = max(1, self.get_batch_size() // gpu_count)
        self.set_batch_size(gpu_count * bs_per_gpu)

        # Compute losses per GPU
        gpu_pred_src_src_list = []
        gpu_pred_dst_dst_list = []
        gpu_pred_src_dst_list = []
        gpu_pred_src_srcm_list = []
        gpu_pred_dst_dstm_list = []
        gpu_pred_src_dstm_list = []

        gpu_src_losses = []
        gpu_dst_losses = []
        gpu_src_dst_loss_gvs = []

        for gpu_id in range(gpu_count):
            with tf.device(f'/GPU:{gpu_id}' if len(devices) != 0 else f'/CPU:0'):
                batch_slice = slice(gpu_id * bs_per_gpu, (gpu_id + 1) * bs_per_gpu)
                with tf.device(f'/CPU:0'):
                    # slice on CPU, otherwise all batch data will be transfered to GPU first
                    gpu_warped_src = self.warped_src[batch_slice, :, :, :]
                    gpu_warped_dst = self.warped_dst[batch_slice, :, :, :]
                    gpu_target_src = self.target_src[batch_slice, :, :, :]
                    gpu_target_dst = self.target_dst[batch_slice, :, :, :]
                    gpu_target_srcm = self.target_srcm[batch_slice, :, :, :]
                    gpu_target_dstm = self.target_dstm[batch_slice, :, :, :]

                # process model tensors
                gpu_src_code = self.inter(self.encoder(gpu_warped_src, stage), stage)
                gpu_dst_code = self.inter(self.encoder(gpu_warped_dst, stage), stage)
                gpu_pred_src_src, gpu_pred_src_srcm = self.decoder_src(gpu_src_code, stage)
                gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code, stage)
                gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code, stage)
                # BUGFIX: removed leftover `import code;
                # code.interact(local=dict(globals(), **locals()))` debug
                # breakpoint that froze graph construction here.

                gpu_pred_src_src_list.append(gpu_pred_src_src)
                gpu_pred_dst_dst_list.append(gpu_pred_dst_dst)
                gpu_pred_src_dst_list.append(gpu_pred_src_dst)
                gpu_pred_src_srcm_list.append(gpu_pred_src_srcm)
                gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm)
                gpu_pred_src_dstm_list.append(gpu_pred_src_dstm)

                # Soft (blurred) masks so the loss boundary isn't a hard edge.
                gpu_target_srcm_blur = nn.gaussian_blur(gpu_target_srcm,
                                                        max(1, resolution // 32))
                gpu_target_dstm_blur = nn.gaussian_blur(gpu_target_dstm,
                                                        max(1, resolution // 32))

                gpu_target_dst_masked = gpu_target_dst * gpu_target_dstm_blur
                gpu_target_dst_anti_masked = gpu_target_dst * (1.0 - gpu_target_dstm_blur)

                gpu_target_srcmasked_opt = gpu_target_src * gpu_target_srcm_blur if masked_training else gpu_target_src
                gpu_target_dst_masked_opt = gpu_target_dst_masked if masked_training else gpu_target_dst

                gpu_pred_src_src_masked_opt = gpu_pred_src_src * gpu_target_srcm_blur if masked_training else gpu_pred_src_src
                gpu_pred_dst_dst_masked_opt = gpu_pred_dst_dst * gpu_target_dstm_blur if masked_training else gpu_pred_dst_dst

                gpu_psd_target_dst_masked = gpu_pred_src_dst * gpu_target_dstm_blur
                gpu_psd_target_dst_anti_masked = gpu_pred_src_dst * (1.0 - gpu_target_dstm_blur)

                # src reconstruction loss: DSSIM + L2 on image, L2 on mask.
                gpu_src_loss = tf.reduce_mean(
                    10 * nn.dssim(gpu_target_srcmasked_opt,
                                  gpu_pred_src_src_masked_opt,
                                  max_val=1.0,
                                  filter_size=int(resolution / 11.6)),
                    axis=[1])
                gpu_src_loss += tf.reduce_mean(
                    10 * tf.square(gpu_target_srcmasked_opt - gpu_pred_src_src_masked_opt),
                    axis=[1, 2, 3])
                gpu_src_loss += tf.reduce_mean(
                    tf.square(gpu_target_srcm - gpu_pred_src_srcm),
                    axis=[1, 2, 3])

                # dst reconstruction loss: same structure as src.
                gpu_dst_loss = tf.reduce_mean(
                    10 * nn.dssim(gpu_target_dst_masked_opt,
                                  gpu_pred_dst_dst_masked_opt,
                                  max_val=1.0,
                                  filter_size=int(resolution / 11.6)),
                    axis=[1])
                gpu_dst_loss += tf.reduce_mean(
                    10 * tf.square(gpu_target_dst_masked_opt - gpu_pred_dst_dst_masked_opt),
                    axis=[1, 2, 3])
                gpu_dst_loss += tf.reduce_mean(
                    tf.square(gpu_target_dstm - gpu_pred_dst_dstm),
                    axis=[1, 2, 3])

                gpu_src_losses += [gpu_src_loss]
                gpu_dst_losses += [gpu_dst_loss]

                gpu_src_dst_loss = gpu_src_loss + gpu_dst_loss
                gpu_src_dst_loss_gvs += [nn.gradients(gpu_src_dst_loss,
                                                      self.src_dst_trainable_weights)]

        # Average losses and gradients, and create optimizer update ops
        with tf.device(models_opt_device):
            if gpu_count == 1:
                pred_src_src = gpu_pred_src_src_list[0]
                pred_dst_dst = gpu_pred_dst_dst_list[0]
                pred_src_dst = gpu_pred_src_dst_list[0]
                pred_src_srcm = gpu_pred_src_srcm_list[0]
                pred_dst_dstm = gpu_pred_dst_dstm_list[0]
                pred_src_dstm = gpu_pred_src_dstm_list[0]

                src_loss = gpu_src_losses[0]
                dst_loss = gpu_dst_losses[0]
                src_dst_loss_gv = gpu_src_dst_loss_gvs[0]
            else:
                pred_src_src = tf.concat(gpu_pred_src_src_list, 0)
                pred_dst_dst = tf.concat(gpu_pred_dst_dst_list, 0)
                pred_src_dst = tf.concat(gpu_pred_src_dst_list, 0)
                pred_src_srcm = tf.concat(gpu_pred_src_srcm_list, 0)
                pred_dst_dstm = tf.concat(gpu_pred_dst_dstm_list, 0)
                pred_src_dstm = tf.concat(gpu_pred_src_dstm_list, 0)

                src_loss = nn.average_tensor_list(gpu_src_losses)
                dst_loss = nn.average_tensor_list(gpu_dst_losses)
                src_dst_loss_gv = nn.average_gv_list(gpu_src_dst_loss_gvs)

            src_dst_loss_gv_op = self.src_dst_opt.get_update_op(src_dst_loss_gv)

        # Initializing training and view functions
        def src_dst_train(warped_src, target_src, target_srcm, \
                          warped_dst, target_dst, target_dstm):
            s, d, _ = nn.tf_sess.run(
                [src_loss, dst_loss, src_dst_loss_gv_op],
                feed_dict={
                    self.warped_src: warped_src,
                    self.target_src: target_src,
                    self.target_srcm: target_srcm,
                    self.warped_dst: warped_dst,
                    self.target_dst: target_dst,
                    self.target_dstm: target_dstm,
                })
            s = np.mean(s)
            d = np.mean(d)
            return s, d

        self.src_dst_train = src_dst_train

        def AE_view(warped_src, warped_dst):
            return nn.tf_sess.run(
                [pred_src_src, pred_dst_dst, pred_dst_dstm,
                 pred_src_dst, pred_src_dstm],
                feed_dict={
                    self.warped_src: warped_src,
                    self.warped_dst: warped_dst
                })

        self.AE_view = AE_view
    else:
        # Initializing merge function
        with tf.device(f'/GPU:0' if len(devices) != 0 else f'/CPU:0'):
            # BUGFIX: Encoder.forward and the inter Decoder.forward require
            # the `stage` argument; it was missing here and raised TypeError.
            gpu_dst_code = self.inter(self.encoder(self.warped_dst, stage), stage)
            gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code, stage=stage)
            _, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code, stage=stage)

        def AE_merge(warped_dst):
            return nn.tf_sess.run(
                [gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm],
                feed_dict={self.warped_dst: warped_dst})

        self.AE_merge = AE_merge

    # Loading/initializing all models/optimizers weights
    for model, filename in io.progress_bar_generator(self.model_filename_list,
                                                     "Initializing models"):
        do_init = self.is_first_run()

        # When pretraining was just turned off, the shared `inter` is reset.
        if self.pretrain_just_disabled:
            if model == self.inter:
                do_init = True

        if not do_init:
            do_init = not model.load_weights(self.get_strpath_storage_for_file(filename))

        if do_init:
            model.init_weights()

    # initializing sample generators
    if self.is_training:
        t = SampleProcessor.Types
        face_type = t.FACE_TYPE_FULL

        training_data_src_path = self.training_data_src_path if not self.pretrain else self.get_pretraining_data_path()
        training_data_dst_path = self.training_data_dst_path if not self.pretrain else self.get_pretraining_data_path()

        cpu_count = multiprocessing.cpu_count()
        src_generators_count = cpu_count // 2
        dst_generators_count = cpu_count - src_generators_count

        self.set_training_data_generators([
            SampleGeneratorFace(
                training_data_src_path,
                debug=self.is_debug(),
                batch_size=self.get_batch_size(),
                sample_process_options=SampleProcessor.Options(
                    random_flip=True if self.pretrain else False),
                output_sample_types=[{
                    'types': (t.IMG_WARPED_TRANSFORMED, face_type, t.MODE_BGR),
                    'resolution': resolution,
                }, {
                    'types': (t.IMG_TRANSFORMED, face_type, t.MODE_BGR),
                    'resolution': resolution,
                }, {
                    'types': (t.IMG_TRANSFORMED, face_type, t.MODE_FACE_MASK_ALL_HULL),
                    'resolution': resolution
                }],
                generators_count=src_generators_count),
            SampleGeneratorFace(
                training_data_dst_path,
                debug=self.is_debug(),
                batch_size=self.get_batch_size(),
                sample_process_options=SampleProcessor.Options(
                    random_flip=True if self.pretrain else False),
                output_sample_types=[{
                    'types': (t.IMG_WARPED_TRANSFORMED, face_type, t.MODE_BGR),
                    'resolution': resolution
                }, {
                    'types': (t.IMG_TRANSFORMED, face_type, t.MODE_BGR),
                    'resolution': resolution
                }, {
                    'types': (t.IMG_TRANSFORMED, face_type, t.MODE_FACE_MASK_ALL_HULL),
                    'resolution': resolution
                }],
                generators_count=dst_generators_count)
        ])

        self.last_samples = None
def on_initialize(self):
    """Build the fixed-resolution (96px) 'quick' face-swap autoencoder.

    Uses nn.DeepFakeArchi(mod='quick') for Encoder/Inter/Decoder, builds the
    per-GPU loss/gradient graph and train/view (training) or merge
    (inference) session callables, then loads or initializes all weights and,
    when training, the sample generators.
    """
    device_config = nn.getCurrentDeviceConfig()
    devices = device_config.devices
    # NCHW on GPU for speed; NHWC when running on CPU or debugging.
    self.model_data_format = "NCHW" if len(
        devices) != 0 and not self.is_debug() else "NHWC"
    nn.initialize(data_format=self.model_data_format)
    tf = nn.tf

    resolution = self.resolution = 96
    self.face_type = FaceType.FULL
    ae_dims = 128  # latent (auto-encoder) width
    e_dims = 128   # encoder base channel count
    d_dims = 64    # decoder base channel count
    self.pretrain = False
    self.pretrain_just_disabled = False

    masked_training = True

    # Keep model variables on GPU only if every device has >= 4 GB.
    models_opt_on_gpu = len(devices) >= 1 and all(
        [dev.total_mem_gb >= 4 for dev in devices])
    models_opt_device = '/GPU:0' if models_opt_on_gpu and self.is_training else '/CPU:0'
    optimizer_vars_on_cpu = models_opt_device == '/CPU:0'

    input_ch = 3
    bgr_shape = nn.get4Dshape(resolution, resolution, input_ch)
    mask_shape = nn.get4Dshape(resolution, resolution, 1)

    self.model_filename_list = []

    model_archi = nn.DeepFakeArchi(resolution, mod='quick')

    with tf.device('/CPU:0'):
        #Place holders on CPU
        self.warped_src = tf.placeholder(nn.floatx, bgr_shape)
        self.warped_dst = tf.placeholder(nn.floatx, bgr_shape)

        self.target_src = tf.placeholder(nn.floatx, bgr_shape)
        self.target_dst = tf.placeholder(nn.floatx, bgr_shape)

        self.target_srcm = tf.placeholder(nn.floatx, mask_shape)
        self.target_dstm = tf.placeholder(nn.floatx, mask_shape)

    # Initializing model classes
    with tf.device(models_opt_device):
        self.encoder = model_archi.Encoder(in_ch=input_ch,
                                           e_ch=e_dims,
                                           name='encoder')
        # Probe output channel counts so downstream models can be sized.
        encoder_out_ch = self.encoder.compute_output_channels(
            (nn.floatx, bgr_shape))

        self.inter = model_archi.Inter(in_ch=encoder_out_ch,
                                       ae_ch=ae_dims,
                                       ae_out_ch=ae_dims,
                                       d_ch=d_dims,
                                       name='inter')
        inter_out_ch = self.inter.compute_output_channels(
            (nn.floatx, (None, encoder_out_ch)))

        self.decoder_src = model_archi.Decoder(in_ch=inter_out_ch,
                                               d_ch=d_dims,
                                               name='decoder_src')
        self.decoder_dst = model_archi.Decoder(in_ch=inter_out_ch,
                                               d_ch=d_dims,
                                               name='decoder_dst')

        self.model_filename_list += [[self.encoder, 'encoder.npy'],
                                     [self.inter, 'inter.npy'],
                                     [self.decoder_src, 'decoder_src.npy'],
                                     [self.decoder_dst, 'decoder_dst.npy']]

        if self.is_training:
            self.src_dst_trainable_weights = self.encoder.get_weights(
            ) + self.inter.get_weights() + self.decoder_src.get_weights(
            ) + self.decoder_dst.get_weights()

            # Initialize optimizers
            self.src_dst_opt = nn.RMSprop(lr=2e-4,
                                          lr_dropout=0.3,
                                          name='src_dst_opt')
            self.src_dst_opt.initialize_variables(
                self.src_dst_trainable_weights,
                vars_on_cpu=optimizer_vars_on_cpu)
            self.model_filename_list += [(self.src_dst_opt, 'src_dst_opt.npy')]

    if self.is_training:
        # Adjust batch size for multiple GPU
        gpu_count = max(1, len(devices))
        bs_per_gpu = max(1, 4 // gpu_count)
        self.set_batch_size(gpu_count * bs_per_gpu)

        # Compute losses per GPU
        gpu_pred_src_src_list = []
        gpu_pred_dst_dst_list = []
        gpu_pred_src_dst_list = []
        gpu_pred_src_srcm_list = []
        gpu_pred_dst_dstm_list = []
        gpu_pred_src_dstm_list = []

        gpu_src_losses = []
        gpu_dst_losses = []
        gpu_src_dst_loss_gvs = []

        for gpu_id in range(gpu_count):
            with tf.device(
                    f'/GPU:{gpu_id}' if len(devices) != 0 else f'/CPU:0'):
                batch_slice = slice(gpu_id * bs_per_gpu,
                                    (gpu_id + 1) * bs_per_gpu)
                with tf.device(f'/CPU:0'):
                    # slice on CPU, otherwise all batch data will be transfered to GPU first
                    gpu_warped_src = self.warped_src[batch_slice, :, :, :]
                    gpu_warped_dst = self.warped_dst[batch_slice, :, :, :]
                    gpu_target_src = self.target_src[batch_slice, :, :, :]
                    gpu_target_dst = self.target_dst[batch_slice, :, :, :]
                    gpu_target_srcm = self.target_srcm[batch_slice, :, :, :]
                    gpu_target_dstm = self.target_dstm[batch_slice, :, :, :]

                # process model tensors
                gpu_src_code = self.inter(self.encoder(gpu_warped_src))
                gpu_dst_code = self.inter(self.encoder(gpu_warped_dst))
                gpu_pred_src_src, gpu_pred_src_srcm = self.decoder_src(
                    gpu_src_code)
                gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder_dst(
                    gpu_dst_code)
                # src decoder on dst code = the actual face swap prediction.
                gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(
                    gpu_dst_code)

                gpu_pred_src_src_list.append(gpu_pred_src_src)
                gpu_pred_dst_dst_list.append(gpu_pred_dst_dst)
                gpu_pred_src_dst_list.append(gpu_pred_src_dst)

                gpu_pred_src_srcm_list.append(gpu_pred_src_srcm)
                gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm)
                gpu_pred_src_dstm_list.append(gpu_pred_src_dstm)

                # Soften mask edges so the loss boundary isn't a hard cut.
                gpu_target_srcm_blur = nn.gaussian_blur(
                    gpu_target_srcm, max(1, resolution // 32))
                gpu_target_dstm_blur = nn.gaussian_blur(
                    gpu_target_dstm, max(1, resolution // 32))

                gpu_target_dst_masked = gpu_target_dst * gpu_target_dstm_blur
                gpu_target_dst_anti_masked = gpu_target_dst * (
                    1.0 - gpu_target_dstm_blur)

                gpu_target_src_masked_opt = gpu_target_src * gpu_target_srcm_blur if masked_training else gpu_target_src
                gpu_target_dst_masked_opt = gpu_target_dst_masked if masked_training else gpu_target_dst

                gpu_pred_src_src_masked_opt = gpu_pred_src_src * gpu_target_srcm_blur if masked_training else gpu_pred_src_src
                gpu_pred_dst_dst_masked_opt = gpu_pred_dst_dst * gpu_target_dstm_blur if masked_training else gpu_pred_dst_dst

                gpu_psd_target_dst_masked = gpu_pred_src_dst * gpu_target_dstm_blur
                gpu_psd_target_dst_anti_masked = gpu_pred_src_dst * (
                    1.0 - gpu_target_dstm_blur)

                # src loss: DSSIM + L2 on image, L2 on mask (all weighted 10x).
                gpu_src_loss = tf.reduce_mean(
                    10 * nn.dssim(gpu_target_src_masked_opt,
                                  gpu_pred_src_src_masked_opt,
                                  max_val=1.0,
                                  filter_size=int(resolution / 11.6)),
                    axis=[1])
                gpu_src_loss += tf.reduce_mean(
                    10 * tf.square(gpu_target_src_masked_opt -
                                   gpu_pred_src_src_masked_opt),
                    axis=[1, 2, 3])
                gpu_src_loss += tf.reduce_mean(
                    10 * tf.square(gpu_target_srcm - gpu_pred_src_srcm),
                    axis=[1, 2, 3])

                # dst loss: same structure as src loss.
                gpu_dst_loss = tf.reduce_mean(
                    10 * nn.dssim(gpu_target_dst_masked_opt,
                                  gpu_pred_dst_dst_masked_opt,
                                  max_val=1.0,
                                  filter_size=int(resolution / 11.6)),
                    axis=[1])
                gpu_dst_loss += tf.reduce_mean(
                    10 * tf.square(gpu_target_dst_masked_opt -
                                   gpu_pred_dst_dst_masked_opt),
                    axis=[1, 2, 3])
                gpu_dst_loss += tf.reduce_mean(
                    10 * tf.square(gpu_target_dstm - gpu_pred_dst_dstm),
                    axis=[1, 2, 3])

                gpu_src_losses += [gpu_src_loss]
                gpu_dst_losses += [gpu_dst_loss]

                gpu_G_loss = gpu_src_loss + gpu_dst_loss
                gpu_src_dst_loss_gvs += [
                    nn.gradients(gpu_G_loss, self.src_dst_trainable_weights)
                ]

        # Average losses and gradients, and create optimizer update ops
        with tf.device(models_opt_device):
            pred_src_src = nn.concat(gpu_pred_src_src_list, 0)
            pred_dst_dst = nn.concat(gpu_pred_dst_dst_list, 0)
            pred_src_dst = nn.concat(gpu_pred_src_dst_list, 0)
            pred_src_srcm = nn.concat(gpu_pred_src_srcm_list, 0)
            pred_dst_dstm = nn.concat(gpu_pred_dst_dstm_list, 0)
            pred_src_dstm = nn.concat(gpu_pred_src_dstm_list, 0)

            src_loss = nn.average_tensor_list(gpu_src_losses)
            dst_loss = nn.average_tensor_list(gpu_dst_losses)
            src_dst_loss_gv = nn.average_gv_list(gpu_src_dst_loss_gvs)
            src_dst_loss_gv_op = self.src_dst_opt.get_update_op(
                src_dst_loss_gv)

        # Initializing training and view functions
        def src_dst_train(warped_src, target_src, target_srcm, \
                          warped_dst, target_dst, target_dstm):
            # Runs one optimizer step; returns scalar (src, dst) losses.
            s, d, _ = nn.tf_sess.run(
                [src_loss, dst_loss, src_dst_loss_gv_op],
                feed_dict={
                    self.warped_src: warped_src,
                    self.target_src: target_src,
                    self.target_srcm: target_srcm,
                    self.warped_dst: warped_dst,
                    self.target_dst: target_dst,
                    self.target_dstm: target_dstm,
                })
            s = np.mean(s)
            d = np.mean(d)
            return s, d

        self.src_dst_train = src_dst_train

        def AE_view(warped_src, warped_dst):
            # Preview predictions for the training UI.
            return nn.tf_sess.run([
                pred_src_src, pred_dst_dst, pred_dst_dstm, pred_src_dst,
                pred_src_dstm
            ],
                                  feed_dict={
                                      self.warped_src: warped_src,
                                      self.warped_dst: warped_dst
                                  })

        self.AE_view = AE_view
    else:
        # Initializing merge function
        with tf.device(f'/GPU:0' if len(devices) != 0 else f'/CPU:0'):
            gpu_dst_code = self.inter(self.encoder(self.warped_dst))
            gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(
                gpu_dst_code)
            _, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)

        def AE_merge(warped_dst):
            return nn.tf_sess.run(
                [gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm],
                feed_dict={self.warped_dst: warped_dst})

        self.AE_merge = AE_merge

    # Loading/initializing all models/optimizers weights
    for model, filename in io.progress_bar_generator(
            self.model_filename_list, "Initializing models"):
        # When pretraining was just switched off, force-reset `inter` only.
        if self.pretrain_just_disabled:
            do_init = False
            if model == self.inter:
                do_init = True
        else:
            do_init = self.is_first_run()

        if not do_init:
            do_init = not model.load_weights(
                self.get_strpath_storage_for_file(filename))

        # Fall back to a user-supplied pretrained model directory, if any.
        if do_init and self.pretrained_model_path is not None:
            pretrained_filepath = self.pretrained_model_path / filename
            if pretrained_filepath.exists():
                do_init = not model.load_weights(pretrained_filepath)

        if do_init:
            model.init_weights()

    # initializing sample generators
    if self.is_training:
        training_data_src_path = self.training_data_src_path if not self.pretrain else self.get_pretraining_data_path(
        )
        training_data_dst_path = self.training_data_dst_path if not self.pretrain else self.get_pretraining_data_path(
        )

        # Cap worker count at 8; split evenly between src and dst.
        cpu_count = min(multiprocessing.cpu_count(), 8)
        src_generators_count = cpu_count // 2
        dst_generators_count = cpu_count // 2

        self.set_training_data_generators([
            SampleGeneratorFace(
                training_data_src_path,
                debug=self.is_debug(),
                batch_size=self.get_batch_size(),
                sample_process_options=SampleProcessor.Options(
                    random_flip=True if self.pretrain else False),
                output_sample_types=[{
                    'sample_type': SampleProcessor.SampleType.FACE_IMAGE,
                    'warp': True,
                    'transform': True,
                    'channel_type': SampleProcessor.ChannelType.BGR,
                    'face_type': self.face_type,
                    'data_format': nn.data_format,
                    'resolution': resolution
                }, {
                    'sample_type': SampleProcessor.SampleType.FACE_IMAGE,
                    'warp': False,
                    'transform': True,
                    'channel_type': SampleProcessor.ChannelType.BGR,
                    'face_type': self.face_type,
                    'data_format': nn.data_format,
                    'resolution': resolution
                }, {
                    'sample_type': SampleProcessor.SampleType.FACE_MASK,
                    'warp': False,
                    'transform': True,
                    'channel_type': SampleProcessor.ChannelType.G,
                    'face_mask_type': SampleProcessor.FaceMaskType.FULL_FACE,
                    'face_type': self.face_type,
                    'data_format': nn.data_format,
                    'resolution': resolution
                }],
                generators_count=src_generators_count),
            SampleGeneratorFace(
                training_data_dst_path,
                debug=self.is_debug(),
                batch_size=self.get_batch_size(),
                sample_process_options=SampleProcessor.Options(
                    random_flip=True if self.pretrain else False),
                output_sample_types=[{
                    'sample_type': SampleProcessor.SampleType.FACE_IMAGE,
                    'warp': True,
                    'transform': True,
                    'channel_type': SampleProcessor.ChannelType.BGR,
                    'face_type': self.face_type,
                    'data_format': nn.data_format,
                    'resolution': resolution
                }, {
                    'sample_type': SampleProcessor.SampleType.FACE_IMAGE,
                    'warp': False,
                    'transform': True,
                    'channel_type': SampleProcessor.ChannelType.BGR,
                    'face_type': self.face_type,
                    'data_format': nn.data_format,
                    'resolution': resolution
                }, {
                    'sample_type': SampleProcessor.SampleType.FACE_MASK,
                    'warp': False,
                    'transform': True,
                    'channel_type': SampleProcessor.ChannelType.G,
                    'face_mask_type': SampleProcessor.FaceMaskType.FULL_FACE,
                    'face_type': self.face_type,
                    'data_format': nn.data_format,
                    'resolution': resolution
                }],
                generators_count=dst_generators_count)
        ])

        self.last_samples = None
def on_initialize(self): device_config = nn.getCurrentDeviceConfig() self.model_data_format = "NCHW" if len( device_config.devices) != 0 and not self.is_debug() else "NHWC" nn.initialize(data_format=self.model_data_format) tf = nn.tf conv_kernel_initializer = nn.initializers.ca() class Downscale(nn.ModelBase): def __init__(self, in_ch, out_ch, kernel_size=3, dilations=1, use_activator=True, *kwargs): self.in_ch = in_ch self.out_ch = out_ch self.kernel_size = kernel_size self.dilations = dilations self.use_activator = use_activator super().__init__(*kwargs) def on_build(self, *args, **kwargs): self.conv1 = nn.Conv2D( self.in_ch, self.out_ch, kernel_size=self.kernel_size, strides=2, padding='SAME', dilations=self.dilations, kernel_initializer=conv_kernel_initializer) def forward(self, x): x = self.conv1(x) if self.use_activator: x = tf.nn.leaky_relu(x, 0.1) return x class Upscale(nn.ModelBase): def on_build(self, in_ch, out_ch, kernel_size=3): self.conv1 = nn.Conv2D( in_ch, out_ch * 4, kernel_size=kernel_size, padding='SAME', kernel_initializer=conv_kernel_initializer) def forward(self, x): x = self.conv1(x) x = tf.nn.leaky_relu(x, 0.1) x = nn.depth_to_space(x, 2) return x class ResidualBlock(nn.ModelBase): def on_build(self, ch, mod=1, kernel_size=3): self.conv1 = nn.Conv2D( ch, ch * mod, kernel_size=kernel_size, padding='SAME', kernel_initializer=conv_kernel_initializer) self.conv2 = nn.Conv2D( ch * mod, ch, kernel_size=kernel_size, padding='SAME', kernel_initializer=conv_kernel_initializer) def forward(self, inp): x = self.conv1(inp) x = tf.nn.leaky_relu(x, 0.1) x = self.conv2(x) x = inp + x x = tf.nn.leaky_relu(x, 0.1) return x class Encoder(nn.ModelBase): def on_build(self, in_ch, e_ch): self.conv1 = Downscale(in_ch, e_ch) self.conv2 = Downscale(e_ch, e_ch * 2) self.conv3 = Downscale(e_ch * 2, e_ch * 4) self.conv4 = Downscale(e_ch * 4, e_ch * 8) self.conv5 = Downscale(e_ch * 8, e_ch * 16) self.conv6 = Downscale(e_ch * 16, e_ch * 32) self.conv7 = Downscale(e_ch * 
32, e_ch * 64) self.res1 = ResidualBlock(e_ch) self.res2 = ResidualBlock(e_ch * 2) self.res3 = ResidualBlock(e_ch * 4) self.res4 = ResidualBlock(e_ch * 8) self.res5 = ResidualBlock(e_ch * 16) self.res6 = ResidualBlock(e_ch * 32) self.res7 = ResidualBlock(e_ch * 64) def forward(self, inp): x = self.conv1(inp) x = self.res1(x) x = self.conv2(x) x = self.res2(x) x = self.conv3(x) x = self.res3(x) x = self.conv4(x) x = self.res4(x) x = self.conv5(x) x = self.res5(x) x = self.conv6(x) x = self.res6(x) x = self.conv7(x) x = self.res7(x) return x class Inter(nn.ModelBase): def __init__(self, in_ch, ae_ch, **kwargs): self.in_ch, self.ae_ch = in_ch, ae_ch super().__init__(**kwargs) def on_build(self): in_ch, ae_ch = self.in_ch, self.ae_ch self.dense_conv1 = nn.Conv2D( in_ch, 64, kernel_size=1, padding='SAME', kernel_initializer=conv_kernel_initializer) self.dense_conv2 = nn.Conv2D( 64, in_ch, kernel_size=1, padding='SAME', kernel_initializer=conv_kernel_initializer) self.conv7 = Upscale(in_ch, in_ch // 2) self.conv6 = Upscale(in_ch // 2, in_ch // 4) def forward(self, inp): x = inp x = self.dense_conv1(x) x = self.dense_conv2(x) x = self.conv7(x) x = self.conv6(x) return x class Decoder(nn.ModelBase): def on_build(self, in_ch): self.upscale6 = Upscale(in_ch, in_ch // 2) self.upscale5 = Upscale(in_ch // 2, in_ch // 4) self.upscale4 = Upscale(in_ch // 4, in_ch // 8) self.upscale3 = Upscale(in_ch // 8, in_ch // 16) self.upscale2 = Upscale(in_ch // 16, in_ch // 32) #self.upscale1 = Upscale(in_ch//32, in_ch//64) self.out_conv = nn.Conv2D( in_ch // 32, 3, kernel_size=1, padding='SAME', kernel_initializer=conv_kernel_initializer) self.res61 = ResidualBlock(in_ch // 2, mod=8) self.res62 = ResidualBlock(in_ch // 2, mod=8) self.res63 = ResidualBlock(in_ch // 2, mod=8) self.res51 = ResidualBlock(in_ch // 4, mod=8) self.res52 = ResidualBlock(in_ch // 4, mod=8) self.res53 = ResidualBlock(in_ch // 4, mod=8) self.res41 = ResidualBlock(in_ch // 8, mod=8) self.res42 = ResidualBlock(in_ch // 
8, mod=8) self.res43 = ResidualBlock(in_ch // 8, mod=8) self.res31 = ResidualBlock(in_ch // 16, mod=8) self.res32 = ResidualBlock(in_ch // 16, mod=8) self.res33 = ResidualBlock(in_ch // 16, mod=8) self.res21 = ResidualBlock(in_ch // 32, mod=8) self.res22 = ResidualBlock(in_ch // 32, mod=8) self.res23 = ResidualBlock(in_ch // 32, mod=8) m_ch = in_ch // 2 self.upscalem6 = Upscale(in_ch, m_ch // 2) self.upscalem5 = Upscale(m_ch // 2, m_ch // 4) self.upscalem4 = Upscale(m_ch // 4, m_ch // 8) self.upscalem3 = Upscale(m_ch // 8, m_ch // 16) self.upscalem2 = Upscale(m_ch // 16, m_ch // 32) #self.upscalem1 = Upscale(m_ch//32, m_ch//64) self.out_convm = nn.Conv2D( m_ch // 32, 1, kernel_size=1, padding='SAME', kernel_initializer=conv_kernel_initializer) def forward(self, inp): z = inp x = self.upscale6(z) x = self.res61(x) x = self.res62(x) x = self.res63(x) x = self.upscale5(x) x = self.res51(x) x = self.res52(x) x = self.res53(x) x = self.upscale4(x) x = self.res41(x) x = self.res42(x) x = self.res43(x) x = self.upscale3(x) x = self.res31(x) x = self.res32(x) x = self.res33(x) x = self.upscale2(x) x = self.res21(x) x = self.res22(x) x = self.res23(x) #x = self.upscale1 (x) y = self.upscalem6(z) y = self.upscalem5(y) y = self.upscalem4(y) y = self.upscalem3(y) y = self.upscalem2(y) #y = self.upscalem1 (y) return tf.nn.sigmoid(self.out_conv(x)), \ tf.nn.sigmoid(self.out_convm(y)) device_config = nn.getCurrentDeviceConfig() devices = device_config.devices resolution = self.resolution = 128 ae_dims = 128 e_dims = 16 self.pretrain = False self.pretrain_just_disabled = False masked_training = True models_opt_on_gpu = len(devices) >= 1 and all( [dev.total_mem_gb >= 4 for dev in devices]) models_opt_device = '/GPU:0' if models_opt_on_gpu and self.is_training else '/CPU:0' optimizer_vars_on_cpu = models_opt_device == '/CPU:0' input_ch = 3 output_ch = 3 bgr_shape = nn.get4Dshape(resolution, resolution, input_ch) mask_shape = nn.get4Dshape(resolution, resolution, 1) 
self.model_filename_list = [] with tf.device('/CPU:0'): #Place holders on CPU self.warped_src = tf.placeholder(nn.floatx, bgr_shape) self.warped_dst = tf.placeholder(nn.floatx, bgr_shape) self.target_src = tf.placeholder(nn.floatx, bgr_shape) self.target_dst = tf.placeholder(nn.floatx, bgr_shape) self.target_srcm = tf.placeholder(nn.floatx, mask_shape) self.target_dstm = tf.placeholder(nn.floatx, mask_shape) # Initializing model classes with tf.device(models_opt_device): self.encoder = Encoder(in_ch=input_ch, e_ch=e_dims, name='encoder') self.inter = Inter(in_ch=e_dims * 64, ae_ch=ae_dims, name='inter') self.decoder_src = Decoder(in_ch=e_dims * 16, name='decoder_src') self.decoder_dst = Decoder(in_ch=e_dims * 16, name='decoder_dst') self.model_filename_list += [[self.encoder, 'encoder.npy'], [self.inter, 'inter.npy'], [self.decoder_src, 'decoder_src.npy'], [self.decoder_dst, 'decoder_dst.npy']] if self.is_training: self.src_dst_trainable_weights = self.encoder.get_weights( ) + self.inter.get_weights() + self.decoder_src.get_weights( ) + self.decoder_dst.get_weights() # Initialize optimizers self.src_dst_opt = nn.RMSprop(lr=5e-5, name='src_dst_opt') self.src_dst_opt.initialize_variables( self.src_dst_trainable_weights, vars_on_cpu=optimizer_vars_on_cpu) self.model_filename_list += [(self.src_dst_opt, 'src_dst_opt.npy')] if self.is_training: # Adjust batch size for multiple GPU gpu_count = max(1, len(devices)) bs_per_gpu = max(1, 4 // gpu_count) self.set_batch_size(gpu_count * bs_per_gpu) # Compute losses per GPU gpu_pred_src_src_list = [] gpu_pred_dst_dst_list = [] gpu_pred_src_dst_list = [] gpu_pred_src_srcm_list = [] gpu_pred_dst_dstm_list = [] gpu_pred_src_dstm_list = [] gpu_src_losses = [] gpu_dst_losses = [] gpu_src_dst_loss_gvs = [] for gpu_id in range(gpu_count): with tf.device( f'/GPU:{gpu_id}' if len(devices) != 0 else f'/CPU:0'): batch_slice = slice(gpu_id * bs_per_gpu, (gpu_id + 1) * bs_per_gpu) with tf.device(f'/CPU:0'): # slice on CPU, otherwise all 
batch data will be transfered to GPU first gpu_warped_src = self.warped_src[batch_slice, :, :, :] gpu_warped_dst = self.warped_dst[batch_slice, :, :, :] gpu_target_src = self.target_src[batch_slice, :, :, :] gpu_target_dst = self.target_dst[batch_slice, :, :, :] gpu_target_srcm = self.target_srcm[ batch_slice, :, :, :] gpu_target_dstm = self.target_dstm[ batch_slice, :, :, :] # process model tensors gpu_src_code = self.inter(self.encoder(gpu_warped_src)) gpu_dst_code = self.inter(self.encoder(gpu_warped_dst)) gpu_pred_src_src, gpu_pred_src_srcm = self.decoder_src( gpu_src_code) gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder_dst( gpu_dst_code) gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src( gpu_dst_code) gpu_pred_src_src_list.append(gpu_pred_src_src) gpu_pred_dst_dst_list.append(gpu_pred_dst_dst) gpu_pred_src_dst_list.append(gpu_pred_src_dst) gpu_pred_src_srcm_list.append(gpu_pred_src_srcm) gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm) gpu_pred_src_dstm_list.append(gpu_pred_src_dstm) gpu_target_srcm_blur = nn.gaussian_blur( gpu_target_srcm, max(1, resolution // 32)) gpu_target_dstm_blur = nn.gaussian_blur( gpu_target_dstm, max(1, resolution // 32)) gpu_target_dst_masked = gpu_target_dst * gpu_target_dstm_blur gpu_target_dst_anti_masked = gpu_target_dst * ( 1.0 - gpu_target_dstm_blur) gpu_target_src_masked_opt = gpu_target_src * gpu_target_srcm_blur if masked_training else gpu_target_src gpu_target_dst_masked_opt = gpu_target_dst_masked if masked_training else gpu_target_dst gpu_pred_src_src_masked_opt = gpu_pred_src_src * gpu_target_srcm_blur if masked_training else gpu_pred_src_src gpu_pred_dst_dst_masked_opt = gpu_pred_dst_dst * gpu_target_dstm_blur if masked_training else gpu_pred_dst_dst gpu_psd_target_dst_masked = gpu_pred_src_dst * gpu_target_dstm_blur gpu_psd_target_dst_anti_masked = gpu_pred_src_dst * ( 1.0 - gpu_target_dstm_blur) gpu_src_loss = tf.reduce_mean( 10 * nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, 
filter_size=int(resolution / 11.6)), axis=[1]) gpu_src_loss += tf.reduce_mean( 10 * tf.square(gpu_target_src_masked_opt - gpu_pred_src_src_masked_opt), axis=[1, 2, 3]) gpu_src_loss += tf.reduce_mean( 10 * tf.square(gpu_target_srcm - gpu_pred_src_srcm), axis=[1, 2, 3]) gpu_dst_loss = tf.reduce_mean( 10 * nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution / 11.6)), axis=[1]) gpu_dst_loss += tf.reduce_mean( 10 * tf.square(gpu_target_dst_masked_opt - gpu_pred_dst_dst_masked_opt), axis=[1, 2, 3]) gpu_dst_loss += tf.reduce_mean( 10 * tf.square(gpu_target_dstm - gpu_pred_dst_dstm), axis=[1, 2, 3]) gpu_src_losses += [gpu_src_loss] gpu_dst_losses += [gpu_dst_loss] gpu_G_loss = gpu_src_loss + gpu_dst_loss gpu_src_dst_loss_gvs += [ nn.gradients(gpu_G_loss, self.src_dst_trainable_weights) ] # Average losses and gradients, and create optimizer update ops with tf.device(models_opt_device): pred_src_src = nn.concat(gpu_pred_src_src_list, 0) pred_dst_dst = nn.concat(gpu_pred_dst_dst_list, 0) pred_src_dst = nn.concat(gpu_pred_src_dst_list, 0) pred_src_srcm = nn.concat(gpu_pred_src_srcm_list, 0) pred_dst_dstm = nn.concat(gpu_pred_dst_dstm_list, 0) pred_src_dstm = nn.concat(gpu_pred_src_dstm_list, 0) src_loss = nn.average_tensor_list(gpu_src_losses) dst_loss = nn.average_tensor_list(gpu_dst_losses) src_dst_loss_gv = nn.average_gv_list(gpu_src_dst_loss_gvs) src_dst_loss_gv_op = self.src_dst_opt.get_update_op( src_dst_loss_gv) # Initializing training and view functions def src_dst_train(warped_src, target_src, target_srcm, \ warped_dst, target_dst, target_dstm): s, d, _ = nn.tf_sess.run( [src_loss, dst_loss, src_dst_loss_gv_op], feed_dict={ self.warped_src: warped_src, self.target_src: target_src, self.target_srcm: target_srcm, self.warped_dst: warped_dst, self.target_dst: target_dst, self.target_dstm: target_dstm, }) s = np.mean(s) d = np.mean(d) return s, d self.src_dst_train = src_dst_train def AE_view(warped_src, warped_dst): 
return nn.tf_sess.run([ pred_src_src, pred_dst_dst, pred_dst_dstm, pred_src_dst, pred_src_dstm ], feed_dict={ self.warped_src: warped_src, self.warped_dst: warped_dst }) self.AE_view = AE_view else: # Initializing merge function with tf.device(f'/GPU:0' if len(devices) != 0 else f'/CPU:0'): gpu_dst_code = self.inter(self.encoder(self.warped_dst)) gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src( gpu_dst_code) _, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code) def AE_merge(warped_dst): return nn.tf_sess.run( [gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm], feed_dict={self.warped_dst: warped_dst}) self.AE_merge = AE_merge # Loading/initializing all models/optimizers weights for model, filename in io.progress_bar_generator( self.model_filename_list, "Initializing models"): if self.pretrain_just_disabled: do_init = False if model == self.inter: do_init = True else: do_init = self.is_first_run() if not do_init: do_init = not model.load_weights( self.get_strpath_storage_for_file(filename)) if do_init and self.pretrained_model_path is not None: pretrained_filepath = self.pretrained_model_path / filename if pretrained_filepath.exists(): do_init = not model.load_weights(pretrained_filepath) if do_init: model.init_weights() # initializing sample generators if self.is_training: t = SampleProcessor.Types face_type = t.FACE_TYPE_FULL training_data_src_path = self.training_data_src_path if not self.pretrain else self.get_pretraining_data_path( ) training_data_dst_path = self.training_data_dst_path if not self.pretrain else self.get_pretraining_data_path( ) cpu_count = min(multiprocessing.cpu_count(), 8) src_generators_count = cpu_count // 2 dst_generators_count = cpu_count // 2 self.set_training_data_generators([ SampleGeneratorFace( training_data_src_path, debug=self.is_debug(), batch_size=self.get_batch_size(), sample_process_options=SampleProcessor.Options( random_flip=True if self.pretrain else False), output_sample_types=[{ 'types': 
(t.IMG_WARPED_TRANSFORMED, face_type, t.MODE_BGR), 'data_format': nn.data_format, 'resolution': resolution, }, { 'types': (t.IMG_TRANSFORMED, face_type, t.MODE_BGR), 'data_format': nn.data_format, 'resolution': resolution, }, { 'types': (t.IMG_TRANSFORMED, face_type, t.MODE_FACE_MASK_ALL_HULL), 'data_format': nn.data_format, 'resolution': resolution }], generators_count=src_generators_count), SampleGeneratorFace( training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(), sample_process_options=SampleProcessor.Options( random_flip=True if self.pretrain else False), output_sample_types=[{ 'types': (t.IMG_WARPED_TRANSFORMED, face_type, t.MODE_BGR), 'data_format': nn.data_format, 'resolution': resolution }, { 'types': (t.IMG_TRANSFORMED, face_type, t.MODE_BGR), 'data_format': nn.data_format, 'resolution': resolution }, { 'types': (t.IMG_TRANSFORMED, face_type, t.MODE_FACE_MASK_ALL_HULL), 'data_format': nn.data_format, 'resolution': resolution }], generators_count=dst_generators_count) ]) self.last_samples = None
def on_initialize(self):
    """Build the progressive-growing autoencoder training/merge graph.

    Constructs encoder/decoder networks whose active depth is controlled by
    ``self.options['stage']``; while a new stage is being grown, its output is
    alpha-blended with the previous (half-resolution) stage's output.  The
    blend factor is fed through the ``alpha_t`` placeholder and ramps from 0
    to 1 between ``start_stage_iter`` and ``target_stage_iter``.

    Side effects: creates TF placeholders and model instances on ``self``,
    registers ``self.src_dst_train`` / ``self.AE_view`` (training) or
    ``self.AE_merge`` (inference), loads or initializes weights, and sets up
    sample generators when training.
    """
    device_config = nn.getCurrentDeviceConfig()
    devices = device_config.devices
    # NCHW only when a GPU is present and we are not debugging on CPU.
    self.model_data_format = "NCHW" if len(devices) != 0 and not self.is_debug() else "NHWC"
    nn.initialize(data_format=self.model_data_format)
    tf = nn.tf

    class EncBlock(nn.ModelBase):
        # One encoder stage. At level 0 the second conv is 4x4 VALID, which
        # collapses the remaining spatial extent instead of max-pooling.
        def on_build(self, in_ch, out_ch, level):
            self.zero_level = level == 0
            self.conv1 = nn.Conv2D(in_ch, out_ch, kernel_size=3, padding='SAME')
            self.conv2 = nn.Conv2D(out_ch, out_ch,
                                   kernel_size=4 if self.zero_level else 3,
                                   padding='VALID' if self.zero_level else 'SAME')

        def forward(self, x):
            x = tf.nn.leaky_relu(self.conv1(x), 0.2)
            x = tf.nn.leaky_relu(self.conv2(x), 0.2)
            if not self.zero_level:
                x = nn.max_pool(x)
            return x

    class DecBlock(nn.ModelBase):
        # One decoder stage; mirrors EncBlock (upsample before convs except
        # at level 0, where a 4x4 conv with padding=3 re-expands the code).
        def on_build(self, in_ch, out_ch, level):
            self.zero_level = level == 0
            self.conv1 = nn.Conv2D(in_ch, out_ch,
                                   kernel_size=4 if self.zero_level else 3,
                                   padding=3 if self.zero_level else 'SAME')
            self.conv2 = nn.Conv2D(out_ch, out_ch, kernel_size=3, padding='SAME')

        def forward(self, x):
            if not self.zero_level:
                x = nn.upsample2d(x)
            x = tf.nn.leaky_relu(self.conv1(x), 0.2)
            x = tf.nn.leaky_relu(self.conv2(x), 0.2)
            return x

    class InterBlock(nn.ModelBase):
        # NOTE(review): dead code — never instantiated, and forward()
        # references self.conv1/self.conv2 which on_build never creates
        # (it only builds an argument-less nn.Dense). Kept for reference;
        # remove or finish before use.
        def on_build(self, in_ch, out_ch, level):
            self.zero_level = level == 0
            self.dense1 = nn.Dense()

        def forward(self, x):
            x = tf.nn.leaky_relu(self.conv1(x), 0.2)
            x = tf.nn.leaky_relu(self.conv2(x), 0.2)
            if not self.zero_level:
                x = nn.max_pool(x)
            return x

    class FromRGB(nn.ModelBase):
        # 1x1 conv lifting a 3-channel image into the feature space of a level.
        def on_build(self, out_ch):
            self.conv1 = nn.Conv2D(3, out_ch, kernel_size=1, padding='SAME')

        def forward(self, x):
            return tf.nn.leaky_relu(self.conv1(x), 0.2)

    class ToRGB(nn.ModelBase):
        # 1x1 convs projecting features to a BGR image and a 1-channel mask.
        def on_build(self, in_ch):
            self.conv = nn.Conv2D(in_ch, 3, kernel_size=1, padding='SAME')
            self.convm = nn.Conv2D(in_ch, 1, kernel_size=1, padding='SAME')

        def forward(self, x):
            return tf.nn.sigmoid(self.conv(x)), tf.nn.sigmoid(self.convm(x))

    ed_dims = 16
    ae_res = 4
    # Channel count per level: deepest level (-1 .. stage_max) doubles the
    # base width per step down, capped at 512.
    level_chs = {
        i - 1: v
        for i, v in enumerate([
            np.clip(ed_dims * (2**i), 0, 512)
            for i in range(self.stage_max + 2)
        ][::-1])
    }
    ae_ch = level_chs[0]

    class Encoder(nn.ModelBase):
        def on_build(self, e_ch, levels):
            self.enc_blocks = {}
            self.from_rgbs = {}
            self.dense_norm = nn.DenseNorm()
            for level in range(levels, -1, -1):
                self.from_rgbs[level] = FromRGB(level_chs[level])
                if level != 0:
                    self.enc_blocks[level] = EncBlock(level_chs[level],
                                                      level_chs[level - 1],
                                                      level)
            # Bottleneck: flatten -> 256 -> back to ae_res x ae_res x ae_ch.
            self.ae_dense1 = nn.Dense(ae_res * ae_res * ae_ch, 256)
            self.ae_dense2 = nn.Dense(256, ae_res * ae_res * ae_ch)

        def forward(self, stage, inp, prev_inp=None, alpha=None):
            """Encode `inp` at `stage`; while growing, blend the new stage's
            from-RGB path with the previous stage's input via `alpha`."""
            x = inp
            for level in range(stage, -1, -1):
                if stage in self.from_rgbs:
                    if level == stage:
                        x = self.from_rgbs[level](x)
                    elif level == stage - 1:
                        # Fade-in: mix features of the new top level with the
                        # previous-resolution input's from-RGB features.
                        x = x * alpha + self.from_rgbs[level](prev_inp) * (1 - alpha)
                if level != 0:
                    x = self.enc_blocks[level](x)
            x = nn.flatten(x)
            x = self.dense_norm(x)
            x = self.ae_dense1(x)
            x = self.ae_dense2(x)
            x = nn.reshape_4D(x, ae_res, ae_res, ae_ch)
            return x

        def get_stage_weights(self, stage):
            """Return only the weights that participate at `stage`."""
            # Call internal get_weights in order to initialize inner logic.
            self.get_weights()
            weights = []
            for level in range(stage, -1, -1):
                if stage in self.from_rgbs:
                    if level == stage or level == stage - 1:
                        weights.append(self.from_rgbs[level].get_weights())
                if level != 0:
                    weights.append(self.enc_blocks[level].get_weights())
            weights.append(self.ae_dense1.get_weights())
            weights.append(self.ae_dense2.get_weights())
            if len(weights) == 0:
                return []
            elif len(weights) == 1:
                return weights[0]
            else:
                return sum(weights[1:], weights[0])

    class Decoder(nn.ModelBase):
        def on_build(self, levels_range):
            self.dec_blocks = {}
            self.to_rgbs = {}
            for level in range(levels_range[0], levels_range[1] + 1):
                self.to_rgbs[level] = ToRGB(level_chs[level])
                if level != 0:
                    self.dec_blocks[level] = DecBlock(level_chs[level - 1],
                                                      level_chs[level],
                                                      level)

        def forward(self, stage, inp, alpha=None, inter=None):
            """Decode up to `stage`, alpha-blending with the upsampled output
            of the previous stage while the new stage fades in."""
            x = inp
            for level in range(stage + 1):
                if level in self.to_rgbs:
                    if level == stage and stage > 0:
                        # Previous-stage image/mask, upsampled to the new
                        # resolution, used as the fade-in partner.
                        prev_level = level - 1
                        prev_x, prev_xm = self.to_rgbs[prev_level](x)
                        prev_x = nn.upsample2d(prev_x)
                        prev_xm = nn.upsample2d(prev_xm)
                if level != 0:
                    x = self.dec_blocks[level](x)
                if level == stage:
                    x, xm = self.to_rgbs[level](x)
                    if stage > 0:
                        x = x * alpha + prev_x * (1 - alpha)
                        xm = xm * alpha + prev_xm * (1 - alpha)
                    return x, xm
            return x

        def get_stage_weights(self, stage):
            """Return only the weights that participate at `stage`."""
            # Call internal get_weights in order to initialize inner logic.
            self.get_weights()
            weights = []
            for level in range(stage + 1):
                if level in self.to_rgbs:
                    if level != 0:
                        weights.append(self.dec_blocks[level].get_weights())
                    if level == stage or level == stage - 1:
                        weights.append(self.to_rgbs[level].get_weights())
            if len(weights) == 0:
                return []
            elif len(weights) == 1:
                return weights[0]
            else:
                return sum(weights[1:], weights[0])

    device_config = nn.getCurrentDeviceConfig()
    devices = device_config.devices

    self.stage = stage = self.options['stage']
    self.start_stage_iter = self.options.get('start_stage_iter', 0)
    self.target_stage_iter = self.options.get('target_stage_iter', 0)
    resolution = self.options['resolution']
    # Stage i trains at 2**(i+2) px: 4, 8, 16, ...
    stage_resolutions = [2**(i + 2) for i in range(self.stage_max + 1)]
    stage_resolution = stage_resolutions[stage]
    prev_stage = stage - 1 if stage != 0 else stage
    prev_stage_resolution = stage_resolutions[stage - 1] if stage != 0 else stage_resolution

    self.pretrain = False
    self.pretrain_just_disabled = False
    masked_training = True

    # Keep model weights on GPU only in the single-GPU, >=4GB case.
    models_opt_on_gpu = len(devices) == 1 and devices[0].total_mem_gb >= 4
    models_opt_device = '/GPU:0' if models_opt_on_gpu and self.is_training else '/CPU:0'
    optimizer_vars_on_cpu = models_opt_device == '/CPU:0'

    input_nc = 3
    output_nc = 3
    prev_bgr_shape = nn.get4Dshape(prev_stage_resolution, prev_stage_resolution, output_nc)
    bgr_shape = nn.get4Dshape(stage_resolution, stage_resolution, output_nc)
    mask_shape = nn.get4Dshape(stage_resolution, stage_resolution, 1)

    self.model_filename_list = []

    with tf.device('/CPU:0'):
        # Place holders on CPU
        self.prev_warped_src = tf.placeholder(tf.float32, prev_bgr_shape)
        self.warped_src = tf.placeholder(tf.float32, bgr_shape)
        self.prev_warped_dst = tf.placeholder(tf.float32, prev_bgr_shape)
        self.warped_dst = tf.placeholder(tf.float32, bgr_shape)
        self.target_src = tf.placeholder(tf.float32, bgr_shape)
        self.target_dst = tf.placeholder(tf.float32, bgr_shape)
        self.target_srcm = tf.placeholder(tf.float32, mask_shape)
        self.target_dstm = tf.placeholder(tf.float32, mask_shape)
        # Per-sample fade-in factor, broadcastable over H/W/C.
        self.alpha_t = tf.placeholder(tf.float32, (None, 1, 1, 1))

    # Initializing model classes
    with tf.device(models_opt_device):
        self.encoder = Encoder(e_ch=ed_dims, levels=self.stage_max, name='encoder')
        self.decoder_src = Decoder(levels_range=[0, self.stage_max], name='decoder_src')
        self.decoder_dst = Decoder(levels_range=[0, self.stage_max], name='decoder_dst')

        self.model_filename_list += [
            [self.encoder, 'encoder.npy'],
            [self.decoder_src, 'decoder_src.npy'],
            [self.decoder_dst, 'decoder_dst.npy'],
        ]

        if self.is_training:
            # All weights are registered with the optimizer; only the current
            # stage's weights receive gradients.
            self.src_dst_all_weights = self.encoder.get_weights() \
                                       + self.decoder_src.get_weights() \
                                       + self.decoder_dst.get_weights()
            self.src_dst_trainable_weights = self.encoder.get_stage_weights(stage) \
                                             + self.decoder_src.get_stage_weights(stage) \
                                             + self.decoder_dst.get_stage_weights(stage)

            # Initialize optimizers
            self.src_dst_opt = nn.RMSprop(lr=2e-4, lr_dropout=1.0, name='src_dst_opt')
            self.src_dst_opt.initialize_variables(self.src_dst_all_weights,
                                                  vars_on_cpu=optimizer_vars_on_cpu)
            self.model_filename_list += [(self.src_dst_opt, 'src_dst_opt.npy')]

    if self.is_training:
        # Adjust batch size for multiple GPU
        gpu_count = max(1, len(devices))
        bs_per_gpu = max(1, self.get_batch_size() // gpu_count)
        self.set_batch_size(gpu_count * bs_per_gpu)

        # Compute losses per GPU
        gpu_pred_src_src_list = []
        gpu_pred_dst_dst_list = []
        gpu_pred_src_dst_list = []
        gpu_pred_src_srcm_list = []
        gpu_pred_dst_dstm_list = []
        gpu_pred_src_dstm_list = []
        gpu_src_losses = []
        gpu_dst_losses = []
        gpu_src_dst_loss_gvs = []

        for gpu_id in range(gpu_count):
            with tf.device(f'/GPU:{gpu_id}' if len(devices) != 0 else f'/CPU:0'):
                batch_slice = slice(gpu_id * bs_per_gpu, (gpu_id + 1) * bs_per_gpu)
                with tf.device(f'/CPU:0'):
                    # slice on CPU, otherwise all batch data will be
                    # transfered to GPU first
                    gpu_prev_warped_src = self.prev_warped_src[batch_slice, :, :, :]
                    gpu_warped_src = self.warped_src[batch_slice, :, :, :]
                    gpu_prev_warped_dst = self.prev_warped_dst[batch_slice, :, :, :]
                    gpu_warped_dst = self.warped_dst[batch_slice, :, :, :]
                    gpu_target_src = self.target_src[batch_slice, :, :, :]
                    gpu_target_dst = self.target_dst[batch_slice, :, :, :]
                    gpu_target_srcm = self.target_srcm[batch_slice, :, :, :]
                    gpu_target_dstm = self.target_dstm[batch_slice, :, :, :]
                    gpu_alpha_t = self.alpha_t[batch_slice, :, :, :]

                # process model tensors
                gpu_src_code = self.encoder(stage, gpu_warped_src,
                                            gpu_prev_warped_src, gpu_alpha_t)
                gpu_dst_code = self.encoder(stage, gpu_warped_dst,
                                            gpu_prev_warped_dst, gpu_alpha_t)
                gpu_pred_src_src, gpu_pred_src_srcm = self.decoder_src(stage, gpu_src_code, gpu_alpha_t)
                gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder_dst(stage, gpu_dst_code, gpu_alpha_t)
                # src decoder applied to dst code = the actual face swap.
                gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(stage, gpu_dst_code, gpu_alpha_t)

                gpu_pred_src_src_list.append(gpu_pred_src_src)
                gpu_pred_dst_dst_list.append(gpu_pred_dst_dst)
                gpu_pred_src_dst_list.append(gpu_pred_src_dst)
                gpu_pred_src_srcm_list.append(gpu_pred_src_srcm)
                gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm)
                gpu_pred_src_dstm_list.append(gpu_pred_src_dstm)

                # Soft masks so the loss focuses on the face region.
                gpu_target_srcm_blur = nn.gaussian_blur(gpu_target_srcm,
                                                        max(1, stage_resolution // 32))
                gpu_target_dstm_blur = nn.gaussian_blur(gpu_target_dstm,
                                                        max(1, stage_resolution // 32))

                gpu_target_dst_masked = gpu_target_dst * gpu_target_dstm_blur
                gpu_target_dst_anti_masked = gpu_target_dst * (1.0 - gpu_target_dstm_blur)

                gpu_target_srcmasked_opt = gpu_target_src * gpu_target_srcm_blur if masked_training else gpu_target_src
                gpu_target_dst_masked_opt = gpu_target_dst_masked if masked_training else gpu_target_dst

                gpu_pred_src_src_masked_opt = gpu_pred_src_src * gpu_target_srcm_blur if masked_training else gpu_pred_src_src
                gpu_pred_dst_dst_masked_opt = gpu_pred_dst_dst * gpu_target_dstm_blur if masked_training else gpu_pred_dst_dst

                gpu_psd_target_dst_masked = gpu_pred_src_dst * gpu_target_dstm_blur
                gpu_psd_target_dst_anti_masked = gpu_pred_src_dst * (1.0 - gpu_target_dstm_blur)

                # Pixel (L2) + mask losses always; DSSIM terms only once the
                # stage resolution is large enough for the filter sizes.
                gpu_src_loss = tf.reduce_mean(
                    10 * tf.square(gpu_target_srcmasked_opt - gpu_pred_src_src_masked_opt),
                    axis=[1, 2, 3])
                gpu_src_loss += tf.reduce_mean(
                    tf.square(gpu_target_srcm - gpu_pred_src_srcm),
                    axis=[1, 2, 3])
                if stage_resolution >= 16:
                    gpu_src_loss += tf.reduce_mean(
                        5 * nn.dssim(gpu_target_srcmasked_opt,
                                     gpu_pred_src_src_masked_opt,
                                     max_val=1.0,
                                     filter_size=int(stage_resolution / 11.6)),
                        axis=[1])
                if stage_resolution >= 32:
                    gpu_src_loss += tf.reduce_mean(
                        5 * nn.dssim(gpu_target_srcmasked_opt,
                                     gpu_pred_src_src_masked_opt,
                                     max_val=1.0,
                                     filter_size=int(stage_resolution / 23.2)),
                        axis=[1])

                gpu_dst_loss = tf.reduce_mean(
                    10 * tf.square(gpu_target_dst_masked_opt - gpu_pred_dst_dst_masked_opt),
                    axis=[1, 2, 3])
                gpu_dst_loss += tf.reduce_mean(
                    tf.square(gpu_target_dstm - gpu_pred_dst_dstm),
                    axis=[1, 2, 3])
                if stage_resolution >= 16:
                    gpu_dst_loss += tf.reduce_mean(
                        5 * nn.dssim(gpu_target_dst_masked_opt,
                                     gpu_pred_dst_dst_masked_opt,
                                     max_val=1.0,
                                     filter_size=int(stage_resolution / 11.6)),
                        axis=[1])
                if stage_resolution >= 32:
                    gpu_dst_loss += tf.reduce_mean(
                        5 * nn.dssim(gpu_target_dst_masked_opt,
                                     gpu_pred_dst_dst_masked_opt,
                                     max_val=1.0,
                                     filter_size=int(stage_resolution / 23.2)),
                        axis=[1])

                gpu_src_losses += [gpu_src_loss]
                gpu_dst_losses += [gpu_dst_loss]

                gpu_src_dst_loss = gpu_src_loss + gpu_dst_loss
                gpu_src_dst_loss_gvs += [nn.gradients(gpu_src_dst_loss,
                                                      self.src_dst_trainable_weights)]

        # Average losses and gradients, and create optimizer update ops
        with tf.device(models_opt_device):
            if gpu_count == 1:
                pred_src_src = gpu_pred_src_src_list[0]
                pred_dst_dst = gpu_pred_dst_dst_list[0]
                pred_src_dst = gpu_pred_src_dst_list[0]
                pred_src_srcm = gpu_pred_src_srcm_list[0]
                pred_dst_dstm = gpu_pred_dst_dstm_list[0]
                pred_src_dstm = gpu_pred_src_dstm_list[0]
                src_loss = gpu_src_losses[0]
                dst_loss = gpu_dst_losses[0]
                src_dst_loss_gv = gpu_src_dst_loss_gvs[0]
            else:
                pred_src_src = tf.concat(gpu_pred_src_src_list, 0)
                pred_dst_dst = tf.concat(gpu_pred_dst_dst_list, 0)
                pred_src_dst = tf.concat(gpu_pred_src_dst_list, 0)
                pred_src_srcm = tf.concat(gpu_pred_src_srcm_list, 0)
                pred_dst_dstm = tf.concat(gpu_pred_dst_dstm_list, 0)
                pred_src_dstm = tf.concat(gpu_pred_src_dstm_list, 0)
                src_loss = nn.average_tensor_list(gpu_src_losses)
                dst_loss = nn.average_tensor_list(gpu_dst_losses)
                src_dst_loss_gv = nn.average_gv_list(gpu_src_dst_loss_gvs)
            src_dst_loss_gv_op = self.src_dst_opt.get_update_op(src_dst_loss_gv)

        # Initializing training and view functions
        def get_alpha(batch_size):
            # Fade-in factor: linear ramp over the stage-transition window,
            # clipped to [0, 1]; always 0 at stage 0.
            alpha = 0
            if self.stage != 0:
                alpha = (self.iter - self.start_stage_iter) / (
                    self.target_stage_iter - self.start_stage_iter)
                alpha = np.clip(alpha, 0, 1)
            alpha = np.array([alpha], nn.floatx.as_numpy_dtype).reshape((1, 1, 1, 1))
            alpha = np.repeat(alpha, batch_size, 0)
            return alpha

        def src_dst_train(prev_warped_src, warped_src, target_src, target_srcm,
                          prev_warped_dst, warped_dst, target_dst, target_dstm):
            s, d, _ = nn.tf_sess.run(
                [src_loss, dst_loss, src_dst_loss_gv_op],
                feed_dict={
                    self.prev_warped_src: prev_warped_src,
                    self.warped_src: warped_src,
                    self.target_src: target_src,
                    self.target_srcm: target_srcm,
                    self.prev_warped_dst: prev_warped_dst,
                    self.warped_dst: warped_dst,
                    self.target_dst: target_dst,
                    self.target_dstm: target_dstm,
                    self.alpha_t: get_alpha(prev_warped_src.shape[0]),
                })
            s = np.mean(s)
            d = np.mean(d)
            return s, d
        self.src_dst_train = src_dst_train

        def AE_view(prev_warped_src, warped_src, prev_warped_dst, warped_dst):
            return nn.tf_sess.run(
                [pred_src_src, pred_dst_dst, pred_dst_dstm,
                 pred_src_dst, pred_src_dstm],
                feed_dict={
                    self.prev_warped_src: prev_warped_src,
                    self.warped_src: warped_src,
                    self.prev_warped_dst: prev_warped_dst,
                    self.warped_dst: warped_dst,
                    self.alpha_t: get_alpha(prev_warped_src.shape[0]),
                })
        self.AE_view = AE_view
    else:
        # Initializing merge function
        # NOTE(review): this path looks broken — self.inter is never created
        # above (the 'inter' network is removed from this model), and the
        # decoder calls do not match Decoder.forward(stage, inp, ...).
        # Left as-is pending a decision on the intended merge graph.
        with tf.device(f'/GPU:0' if len(devices) != 0 else f'/CPU:0'):
            gpu_dst_code = self.inter(self.encoder(self.warped_dst))
            gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code, stage=stage)
            _, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code, stage=stage)

        def AE_merge(warped_dst):
            return nn.tf_sess.run(
                [gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm],
                feed_dict={self.warped_dst: warped_dst})
        self.AE_merge = AE_merge

    # Loading/initializing all models/optimizers weights
    for model, filename in io.progress_bar_generator(self.model_filename_list,
                                                     "Initializing models"):
        do_init = self.is_first_run()
        if self.pretrain_just_disabled:
            # NOTE(review): self.inter does not exist in this model; this
            # branch is unreachable while pretrain_just_disabled is False.
            if model == self.inter:
                do_init = True
        if not do_init:
            do_init = not model.load_weights(
                self.get_strpath_storage_for_file(filename))
        if do_init:
            model.init_weights()

    # initializing sample generators
    if self.is_training:
        self.face_type = FaceType.FULL

        training_data_src_path = self.training_data_src_path if not self.pretrain else self.get_pretraining_data_path()
        training_data_dst_path = self.training_data_dst_path if not self.pretrain else self.get_pretraining_data_path()

        cpu_count = multiprocessing.cpu_count()
        src_generators_count = cpu_count // 2
        dst_generators_count = cpu_count - src_generators_count

        # Each generator yields five samples: warped image at the previous
        # and current stage resolutions, unwarped targets at both
        # resolutions, and the full-face mask at the current resolution.
        self.set_training_data_generators([
            SampleGeneratorFace(
                training_data_src_path,
                debug=self.is_debug(),
                batch_size=self.get_batch_size(),
                sample_process_options=SampleProcessor.Options(random_flip=False),
                output_sample_types=[
                    {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,
                     'warp': True, 'transform': True,
                     'channel_type': SampleProcessor.ChannelType.BGR,
                     'face_type': self.face_type,
                     'data_format': nn.data_format,
                     'resolution': resolution,
                     'nearest_resize_to': prev_stage_resolution},
                    {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,
                     'warp': True, 'transform': True,
                     'channel_type': SampleProcessor.ChannelType.BGR,
                     'face_type': self.face_type,
                     'data_format': nn.data_format,
                     'resolution': resolution,
                     'nearest_resize_to': stage_resolution},
                    {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,
                     'warp': False, 'transform': True,
                     'channel_type': SampleProcessor.ChannelType.BGR,
                     'face_type': self.face_type,
                     'data_format': nn.data_format,
                     'resolution': resolution,
                     'nearest_resize_to': prev_stage_resolution},
                    {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,
                     'warp': False, 'transform': True,
                     'channel_type': SampleProcessor.ChannelType.BGR,
                     'face_type': self.face_type,
                     'data_format': nn.data_format,
                     'resolution': resolution,
                     'nearest_resize_to': stage_resolution},
                    {'sample_type': SampleProcessor.SampleType.FACE_MASK,
                     'warp': False, 'transform': True,
                     'channel_type': SampleProcessor.ChannelType.G,
                     'face_mask_type': SampleProcessor.FaceMaskType.FULL_FACE,
                     'face_type': self.face_type,
                     'data_format': nn.data_format,
                     'resolution': resolution,
                     'nearest_resize_to': stage_resolution},
                ],
                generators_count=src_generators_count),
            SampleGeneratorFace(
                training_data_dst_path,
                debug=self.is_debug(),
                batch_size=self.get_batch_size(),
                sample_process_options=SampleProcessor.Options(random_flip=False),
                output_sample_types=[
                    {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,
                     'warp': True, 'transform': True,
                     'channel_type': SampleProcessor.ChannelType.BGR,
                     'face_type': self.face_type,
                     'data_format': nn.data_format,
                     'resolution': resolution,
                     'nearest_resize_to': prev_stage_resolution},
                    {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,
                     'warp': True, 'transform': True,
                     'channel_type': SampleProcessor.ChannelType.BGR,
                     'face_type': self.face_type,
                     'data_format': nn.data_format,
                     'resolution': resolution,
                     'nearest_resize_to': stage_resolution},
                    {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,
                     'warp': False, 'transform': True,
                     'channel_type': SampleProcessor.ChannelType.BGR,
                     'face_type': self.face_type,
                     'data_format': nn.data_format,
                     'resolution': resolution,
                     'nearest_resize_to': prev_stage_resolution},
                    {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,
                     'warp': False, 'transform': True,
                     'channel_type': SampleProcessor.ChannelType.BGR,
                     'face_type': self.face_type,
                     'data_format': nn.data_format,
                     'resolution': resolution,
                     'nearest_resize_to': stage_resolution},
                    {'sample_type': SampleProcessor.SampleType.FACE_MASK,
                     'warp': False, 'transform': True,
                     'channel_type': SampleProcessor.ChannelType.G,
                     'face_mask_type': SampleProcessor.FaceMaskType.FULL_FACE,
                     'face_type': self.face_type,
                     'data_format': nn.data_format,
                     'resolution': resolution,
                     'nearest_resize_to': stage_resolution},
                ],
                generators_count=dst_generators_count),
        ])
        self.last_samples = None
def on_initialize(self):
    """Build the complete TF1 (graph-mode) network for this 448-res face autoencoder.

    Constructs the nested layer classes, creates placeholders on CPU, instantiates
    encoder/inter/decoder (+ optional patch-GAN discriminators), builds the per-GPU
    loss graph, creates optimizer update ops, binds train/view/merge closures onto
    `self`, loads or initializes weights, and sets up the sample generators.
    Everything is built inside one TF1 graph; statement order matters because it
    determines variable creation and optimizer initialization.
    """
    device_config = nn.getCurrentDeviceConfig()
    # NCHW on GPU for speed; NHWC on CPU / when debugging.
    self.model_data_format = "NCHW" if len(
        device_config.devices) != 0 and not self.is_debug() else "NHWC"
    nn.initialize(data_format=self.model_data_format)
    tf = nn.tf

    conv_kernel_initializer = nn.initializers.ca()

    class Downscale(nn.ModelBase):
        """2x spatial downscale: either strided conv, or conv + space_to_depth ('subpixel')."""

        # NOTE(review): `*kwargs` here collects extra POSITIONAL args, not keywords —
        # probably intended to be `**kwargs`; works only because callers pass no extras
        # besides what `super().__init__(*kwargs)` forwards. Confirm before changing.
        def __init__(self,
                     in_ch,
                     out_ch,
                     kernel_size=5,
                     dilations=1,
                     subpixel=True,
                     use_activator=True,
                     *kwargs):
            self.in_ch = in_ch
            self.out_ch = out_ch
            self.kernel_size = kernel_size
            self.dilations = dilations
            self.subpixel = subpixel
            self.use_activator = use_activator
            super().__init__(*kwargs)

        def on_build(self, *args, **kwargs):
            # Subpixel mode: conv at stride 1 with out_ch//4 filters, then
            # space_to_depth(2) restores channel count while halving H and W.
            self.conv1 = nn.Conv2D(
                self.in_ch,
                self.out_ch // (4 if self.subpixel else 1),
                kernel_size=self.kernel_size,
                strides=1 if self.subpixel else 2,
                padding='SAME',
                dilations=self.dilations,
                kernel_initializer=conv_kernel_initializer)

        def forward(self, x):
            x = self.conv1(x)
            if self.subpixel:
                x = nn.space_to_depth(x, 2)
            if self.use_activator:
                x = tf.nn.leaky_relu(x, 0.1)
            return x

        def get_out_ch(self):
            # Effective output channels after the //4 * s2d round trip.
            return (self.out_ch // 4) * 4

    class DownscaleBlock(nn.ModelBase):
        """Chain of Downscale layers; channel count doubles per level, capped at ch*8."""

        def on_build(self,
                     in_ch,
                     ch,
                     n_downscales,
                     kernel_size,
                     dilations=1,
                     subpixel=True):
            self.downs = []
            last_ch = in_ch
            for i in range(n_downscales):
                cur_ch = ch * (min(2**i, 8))
                self.downs.append(
                    Downscale(last_ch,
                              cur_ch,
                              kernel_size=kernel_size,
                              dilations=dilations,
                              subpixel=subpixel))
                last_ch = self.downs[-1].get_out_ch()

        def forward(self, inp):
            x = inp
            for down in self.downs:
                x = down(x)
            return x

    class Upscale(nn.ModelBase):
        """2x spatial upscale via conv to out_ch*4 channels + depth_to_space(2)."""

        def on_build(self, in_ch, out_ch, kernel_size=3):
            self.conv1 = nn.Conv2D(
                in_ch,
                out_ch * 4,
                kernel_size=kernel_size,
                padding='SAME',
                kernel_initializer=conv_kernel_initializer)

        def forward(self, x):
            x = self.conv1(x)
            x = tf.nn.leaky_relu(x, 0.1)
            x = nn.depth_to_space(x, 2)
            return x

    class ResidualBlock(nn.ModelBase):
        """Two same-channel convs with a skip connection; leaky_relu(0.2) activations."""

        def on_build(self, ch, kernel_size=3):
            self.conv1 = nn.Conv2D(
                ch,
                ch,
                kernel_size=kernel_size,
                padding='SAME',
                kernel_initializer=conv_kernel_initializer)
            self.conv2 = nn.Conv2D(
                ch,
                ch,
                kernel_size=kernel_size,
                padding='SAME',
                kernel_initializer=conv_kernel_initializer)

        def forward(self, inp):
            x = self.conv1(inp)
            x = tf.nn.leaky_relu(x, 0.2)
            x = self.conv2(x)
            # Residual add happens before the final activation.
            x = tf.nn.leaky_relu(inp + x, 0.2)
            return x

    class UpdownResidualBlock(nn.ModelBase):
        """Upscale -> residual -> downscale, with an outer skip; also returns the
        intermediate upscaled tensor `upx`."""

        def on_build(self, ch, inner_ch, kernel_size=3):
            self.up = Upscale(ch, inner_ch, kernel_size=kernel_size)
            self.res = ResidualBlock(inner_ch, kernel_size=kernel_size)
            self.down = Downscale(inner_ch,
                                  ch,
                                  kernel_size=kernel_size,
                                  use_activator=False)

        def forward(self, inp):
            x = self.up(inp)
            x = upx = self.res(x)
            x = self.down(x)
            x = x + inp
            x = tf.nn.leaky_relu(x, 0.2)
            return x, upx

    class Encoder(nn.ModelBase):
        """5x strided-conv downscale stack, flattened to a vector."""

        def on_build(self, in_ch, e_ch):
            self.down1 = DownscaleBlock(in_ch,
                                        e_ch,
                                        n_downscales=5,
                                        kernel_size=5,
                                        dilations=1,
                                        subpixel=False)

        def forward(self, inp):
            x = nn.flatten(self.down1(inp))
            return x

    class Inter(nn.ModelBase):
        """Bottleneck: dense -> dense -> reshape to 4D -> one upscale."""

        def __init__(self, in_ch, lowest_dense_res, ae_ch, ae_out_ch,
                     **kwargs):
            self.in_ch, self.lowest_dense_res, self.ae_ch, self.ae_out_ch = in_ch, lowest_dense_res, ae_ch, ae_out_ch
            super().__init__(**kwargs)

        def on_build(self):
            in_ch, lowest_dense_res, ae_ch, ae_out_ch = self.in_ch, self.lowest_dense_res, self.ae_ch, self.ae_out_ch

            self.dense1 = nn.Dense(in_ch, ae_ch)
            self.dense2 = nn.Dense(
                ae_ch, lowest_dense_res * lowest_dense_res * ae_out_ch)
            self.upscale1 = Upscale(ae_out_ch, ae_out_ch * 2)

        def forward(self, inp):
            x = self.dense1(inp)
            x = self.dense2(x)
            # NOTE(review): `lowest_dense_res` here resolves to the enclosing
            # on_initialize() local via closure, not self.lowest_dense_res.
            # Same value in practice, but fragile — consider self.lowest_dense_res.
            x = nn.reshape_4D(x, lowest_dense_res, lowest_dense_res,
                              self.ae_out_ch)
            x = self.upscale1(x)
            return x

        def get_out_ch(self):
            # NOTE(review): upscale1 actually emits ae_out_ch*2 channels; callers
            # use compute_output_channels() instead, so this appears unused.
            return self.ae_out_ch

    class Decoder(nn.ModelBase):
        """4-level upscale+residual decoder producing a BGR image and a 1-ch mask."""

        def on_build(self, in_ch, d_ch, d_mask_ch):
            # Image branch.
            self.upscale0 = Upscale(in_ch, d_ch * 8, kernel_size=3)
            self.upscale1 = Upscale(d_ch * 8, d_ch * 4, kernel_size=3)
            self.upscale2 = Upscale(d_ch * 4, d_ch * 2, kernel_size=3)
            self.upscale3 = Upscale(d_ch * 2, d_ch * 1, kernel_size=3)
            self.res0 = ResidualBlock(d_ch * 8, kernel_size=3)
            self.res1 = ResidualBlock(d_ch * 4, kernel_size=3)
            self.res2 = ResidualBlock(d_ch * 2, kernel_size=3)
            self.res3 = ResidualBlock(d_ch * 1, kernel_size=3)
            self.out_conv = nn.Conv2D(
                d_ch * 1,
                3,
                kernel_size=1,
                padding='SAME',
                kernel_initializer=conv_kernel_initializer)
            # Mask branch (no residual blocks).
            self.upscalem0 = Upscale(in_ch, d_mask_ch * 8, kernel_size=3)
            self.upscalem1 = Upscale(d_mask_ch * 8, d_mask_ch * 4, kernel_size=3)
            self.upscalem2 = Upscale(d_mask_ch * 4, d_mask_ch * 2, kernel_size=3)
            self.upscalem3 = Upscale(d_mask_ch * 2, d_mask_ch * 1, kernel_size=3)
            self.out_convm = nn.Conv2D(
                d_mask_ch * 1,
                1,
                kernel_size=1,
                padding='SAME',
                kernel_initializer=conv_kernel_initializer)

        # Old 3-level variant kept for reference (dead code, superseded below).
        """
        def get_weights_ex(self, include_mask):
            # Call internal get_weights in order to initialize inner logic
            self.get_weights()

            weights = self.upscale0.get_weights() + self.upscale1.get_weights() + self.upscale2.get_weights() \
                      + self.res0.get_weights() + self.res1.get_weights() + self.res2.get_weights() + self.out_conv.get_weights()
            if include_mask:
                weights += self.upscalem0.get_weights() + self.upscalem1.get_weights() + self.upscalem2.get_weights() \
                           + self.out_convm.get_weights()
            return weights
        """

        def get_weights_ex(self, include_mask):
            """Return the trainable weights of the image branch, optionally plus
            the mask branch. Excludes nothing else — used to restrict which
            weights receive generator gradients."""
            # Call internal get_weights in order to initialize inner logic
            self.get_weights()

            weights = self.upscale0.get_weights() + self.upscale1.get_weights() + self.upscale2.get_weights() + self.upscale3.get_weights()\
                      + self.res0.get_weights() + self.res1.get_weights() + self.res2.get_weights() + self.res3.get_weights() + self.out_conv.get_weights()
            if include_mask:
                weights += self.upscalem0.get_weights() + self.upscalem1.get_weights() + self.upscalem2.get_weights() + self.upscalem3.get_weights() \
                           + self.out_convm.get_weights()
            return weights

        def forward(self, inp):
            z = inp

            x = self.upscale0(z)
            x = self.res0(x)
            x = self.upscale1(x)
            x = self.res1(x)
            x = self.upscale2(x)
            x = self.res2(x)
            x = self.upscale3(x)
            x = self.res3(x)

            m = self.upscalem0(z)
            m = self.upscalem1(m)
            m = self.upscalem2(m)
            m = self.upscalem3(m)
            # Sigmoid outputs: BGR in [0,1] and mask in [0,1].
            return tf.nn.sigmoid(self.out_conv(x)), \
                   tf.nn.sigmoid(self.out_convm(m))

    device_config = nn.getCurrentDeviceConfig()
    devices = device_config.devices

    # Fixed model hyperparameters for this variant.
    self.resolution = resolution = 448
    self.learn_mask = learn_mask = True
    eyes_prio = True  # extra L1 loss weight on the eyes region
    ae_dims = self.options['ae_dims']
    e_dims = self.options['e_dims']
    d_dims = self.options['d_dims']
    d_mask_dims = self.options['d_mask_dims']
    self.pretrain = False
    self.pretrain_just_disabled = False
    if self.pretrain_just_disabled:
        self.set_iter(0)

    # GAN disabled entirely while pretraining.
    self.gan_power = gan_power = self.options[
        'gan_power'] if not self.pretrain else 0.0

    masked_training = True

    # Single GPU: user choice; multi-GPU: always on GPU; no GPU: CPU.
    models_opt_on_gpu = False if len(devices) == 0 else True if len(
        devices) > 1 else self.options['models_opt_on_gpu']
    models_opt_device = '/GPU:0' if models_opt_on_gpu and self.is_training else '/CPU:0'
    optimizer_vars_on_cpu = models_opt_device == '/CPU:0'

    input_ch = 3
    output_ch = 3
    bgr_shape = nn.get4Dshape(resolution, resolution, input_ch)
    mask_shape = nn.get4Dshape(resolution, resolution, 1)
    lowest_dense_res = resolution // 32  # spatial size at the bottleneck

    self.model_filename_list = []

    with tf.device('/CPU:0'):
        # Place holders on CPU
        self.warped_src = tf.placeholder(nn.floatx, bgr_shape)
        self.warped_dst = tf.placeholder(nn.floatx, bgr_shape)

        self.target_src = tf.placeholder(nn.floatx, bgr_shape)
        self.target_dst = tf.placeholder(nn.floatx, bgr_shape)

        # "_all" masks carry face mask in [0,1] plus eyes encoded as +1
        # (unpacked below with clip_by_value).
        self.target_srcm_all = tf.placeholder(nn.floatx, mask_shape)
        self.target_dstm_all = tf.placeholder(nn.floatx, mask_shape)

    # Initializing model classes
    with tf.device(models_opt_device):
        self.encoder = Encoder(in_ch=input_ch, e_ch=e_dims, name='encoder')
        encoder_out_ch = self.encoder.compute_output_channels(
            (nn.floatx, bgr_shape))

        self.inter = Inter(in_ch=encoder_out_ch,
                           lowest_dense_res=lowest_dense_res,
                           ae_ch=ae_dims,
                           ae_out_ch=ae_dims,
                           name='inter')
        inter_out_ch = self.inter.compute_output_channels(
            (nn.floatx, (None, encoder_out_ch)))

        self.decoder_src = Decoder(in_ch=inter_out_ch,
                                   d_ch=d_dims,
                                   d_mask_ch=d_mask_dims,
                                   name='decoder_src')
        self.decoder_dst = Decoder(in_ch=inter_out_ch,
                                   d_ch=d_dims,
                                   d_mask_ch=d_mask_dims,
                                   name='decoder_dst')

        self.model_filename_list += [[self.encoder, 'encoder.npy'],
                                     [self.inter, 'inter.npy'],
                                     [self.decoder_src, 'decoder_src.npy'],
                                     [self.decoder_dst, 'decoder_dst.npy']]

        if self.is_training:
            if gan_power != 0:
                self.D_src = nn.PatchDiscriminator(patch_size=resolution // 16,
                                                   in_ch=output_ch,
                                                   base_ch=256,
                                                   name="D_src")
                self.D_dst = nn.PatchDiscriminator(patch_size=resolution // 16,
                                                   in_ch=output_ch,
                                                   base_ch=256,
                                                   name="D_dst")
                self.model_filename_list += [[self.D_src, 'D_src.npy']]
                self.model_filename_list += [[self.D_dst, 'D_dst.npy']]

            # Initialize optimizers
            lr = 5e-5
            clipnorm = 1.0 if self.options['clipgrad'] else 0.0
            self.src_dst_opt = nn.RMSprop(lr=lr,
                                          clipnorm=clipnorm,
                                          name='src_dst_opt')
            self.model_filename_list += [(self.src_dst_opt, 'src_dst_opt.npy')]

            # "all" weights get optimizer slots; the "_ex" subset (optionally
            # without mask branch) is what actually receives G gradients.
            self.src_dst_all_trainable_weights = self.encoder.get_weights(
            ) + self.inter.get_weights() + self.decoder_src.get_weights(
            ) + self.decoder_dst.get_weights()
            self.src_dst_trainable_weights = self.encoder.get_weights(
            ) + self.inter.get_weights() + self.decoder_src.get_weights_ex(
                learn_mask) + self.decoder_dst.get_weights_ex(learn_mask)

            self.src_dst_opt.initialize_variables(
                self.src_dst_all_trainable_weights,
                vars_on_cpu=optimizer_vars_on_cpu)

            if gan_power != 0:
                self.D_src_dst_opt = nn.RMSprop(lr=lr,
                                                clipnorm=clipnorm,
                                                name='D_src_dst_opt')
                self.D_src_dst_opt.initialize_variables(
                    self.D_src.get_weights() + self.D_dst.get_weights(),
                    vars_on_cpu=optimizer_vars_on_cpu)
                self.model_filename_list += [(self.D_src_dst_opt,
                                              'D_src_dst_opt.npy')]

    if self.is_training:
        # Adjust batch size for multiple GPU
        gpu_count = max(1, len(devices))
        bs_per_gpu = max(1, self.get_batch_size() // gpu_count)
        self.set_batch_size(gpu_count * bs_per_gpu)

        # Compute losses per GPU
        gpu_pred_src_src_list = []
        gpu_pred_dst_dst_list = []
        gpu_pred_src_dst_list = []
        gpu_pred_src_srcm_list = []
        gpu_pred_dst_dstm_list = []
        gpu_pred_src_dstm_list = []

        gpu_src_losses = []
        gpu_dst_losses = []
        gpu_G_loss_gvs = []
        gpu_D_code_loss_gvs = []  # NOTE(review): appended to nowhere — appears unused
        gpu_D_src_dst_loss_gvs = []
        for gpu_id in range(gpu_count):
            with tf.device(
                    f'/GPU:{gpu_id}' if len(devices) != 0 else f'/CPU:0'):

                with tf.device(f'/CPU:0'):
                    # slice on CPU, otherwise all batch data will be transfered to GPU first
                    batch_slice = slice(gpu_id * bs_per_gpu,
                                        (gpu_id + 1) * bs_per_gpu)
                    gpu_warped_src = self.warped_src[batch_slice, :, :, :]
                    gpu_warped_dst = self.warped_dst[batch_slice, :, :, :]
                    gpu_target_src = self.target_src[batch_slice, :, :, :]
                    gpu_target_dst = self.target_dst[batch_slice, :, :, :]
                    gpu_target_srcm_all = self.target_srcm_all[
                        batch_slice, :, :, :]
                    gpu_target_dstm_all = self.target_dstm_all[
                        batch_slice, :, :, :]

                # process model tensors
                gpu_src_code = self.inter(self.encoder(gpu_warped_src))
                gpu_dst_code = self.inter(self.encoder(gpu_warped_dst))
                gpu_pred_src_src, gpu_pred_src_srcm = self.decoder_src(
                    gpu_src_code)
                gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder_dst(
                    gpu_dst_code)
                # src-decoder on dst-code = the actual face swap prediction.
                gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(
                    gpu_dst_code)

                gpu_pred_src_src_list.append(gpu_pred_src_src)
                gpu_pred_dst_dst_list.append(gpu_pred_dst_dst)
                gpu_pred_src_dst_list.append(gpu_pred_src_dst)

                gpu_pred_src_srcm_list.append(gpu_pred_src_srcm)
                gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm)
                gpu_pred_src_dstm_list.append(gpu_pred_src_dstm)

                # unpack masks from one combined mask
                # face mask = values clipped to [0,1]; eyes mask = (value-1) clipped,
                # i.e. only pixels encoded above 1 survive.
                gpu_target_srcm = tf.clip_by_value(gpu_target_srcm_all, 0, 1)
                gpu_target_dstm = tf.clip_by_value(gpu_target_dstm_all, 0, 1)
                gpu_target_srcm_eyes = tf.clip_by_value(
                    gpu_target_srcm_all - 1, 0, 1)
                gpu_target_dstm_eyes = tf.clip_by_value(
                    gpu_target_dstm_all - 1, 0, 1)

                # Blurred masks soften the loss boundary at the face edge.
                gpu_target_srcm_blur = nn.gaussian_blur(
                    gpu_target_srcm, max(1, resolution // 32))
                gpu_target_dstm_blur = nn.gaussian_blur(
                    gpu_target_dstm, max(1, resolution // 32))

                gpu_target_dst_masked = gpu_target_dst * gpu_target_dstm_blur
                gpu_target_dst_anti_masked = gpu_target_dst * (
                    1.0 - gpu_target_dstm_blur)

                gpu_target_src_masked_opt = gpu_target_src * gpu_target_srcm_blur if masked_training else gpu_target_src
                gpu_target_dst_masked_opt = gpu_target_dst_masked if masked_training else gpu_target_dst

                gpu_pred_src_src_masked_opt = gpu_pred_src_src * gpu_target_srcm_blur if masked_training else gpu_pred_src_src
                gpu_pred_dst_dst_masked_opt = gpu_pred_dst_dst * gpu_target_dstm_blur if masked_training else gpu_pred_dst_dst

                gpu_psd_target_dst_masked = gpu_pred_src_dst * gpu_target_dstm_blur
                gpu_psd_target_dst_anti_masked = gpu_pred_src_dst * (
                    1.0 - gpu_target_dstm_blur)

                # src reconstruction loss: 10*DSSIM + 10*MSE (+ eyes L1 + mask MSE).
                gpu_src_loss = tf.reduce_mean(
                    10 * nn.dssim(gpu_target_src_masked_opt,
                                  gpu_pred_src_src_masked_opt,
                                  max_val=1.0,
                                  filter_size=int(resolution / 11.6)),
                    axis=[1])
                gpu_src_loss += tf.reduce_mean(
                    10 * tf.square(gpu_target_src_masked_opt -
                                   gpu_pred_src_src_masked_opt),
                    axis=[1, 2, 3])

                if eyes_prio:
                    gpu_src_loss += tf.reduce_mean(
                        300 * tf.abs(gpu_target_src * gpu_target_srcm_eyes -
                                     gpu_pred_src_src * gpu_target_srcm_eyes),
                        axis=[1, 2, 3])
                if learn_mask:
                    gpu_src_loss += tf.reduce_mean(
                        10 * tf.square(gpu_target_srcm - gpu_pred_src_srcm),
                        axis=[1, 2, 3])

                # dst reconstruction loss, same structure as src.
                gpu_dst_loss = tf.reduce_mean(
                    10 * nn.dssim(gpu_target_dst_masked_opt,
                                  gpu_pred_dst_dst_masked_opt,
                                  max_val=1.0,
                                  filter_size=int(resolution / 11.6)),
                    axis=[1])
                gpu_dst_loss += tf.reduce_mean(
                    10 * tf.square(gpu_target_dst_masked_opt -
                                   gpu_pred_dst_dst_masked_opt),
                    axis=[1, 2, 3])

                if eyes_prio:
                    gpu_dst_loss += tf.reduce_mean(
                        300 * tf.abs(gpu_target_dst * gpu_target_dstm_eyes -
                                     gpu_pred_dst_dst * gpu_target_dstm_eyes),
                        axis=[1, 2, 3])
                if learn_mask:
                    gpu_dst_loss += tf.reduce_mean(
                        10 * tf.square(gpu_target_dstm - gpu_pred_dst_dstm),
                        axis=[1, 2, 3])

                gpu_src_losses += [gpu_src_loss]
                gpu_dst_losses += [gpu_dst_loss]

                gpu_G_loss = gpu_src_loss + gpu_dst_loss

                def DLoss(labels, logits):
                    # Per-sample mean sigmoid cross-entropy over patch map.
                    return tf.reduce_mean(
                        tf.nn.sigmoid_cross_entropy_with_logits(
                            labels=labels, logits=logits),
                        axis=[1, 2, 3])

                if gan_power != 0:
                    # Standard non-saturating patch-GAN: D trained real-vs-fake,
                    # G additionally pushed to make D output "real".
                    gpu_pred_src_src_d = self.D_src(
                        gpu_pred_src_src_masked_opt)
                    gpu_pred_src_src_d_ones = tf.ones_like(
                        gpu_pred_src_src_d)
                    gpu_pred_src_src_d_zeros = tf.zeros_like(
                        gpu_pred_src_src_d)
                    gpu_target_src_d = self.D_src(
                        gpu_target_src_masked_opt)
                    gpu_target_src_d_ones = tf.ones_like(gpu_target_src_d)
                    gpu_pred_dst_dst_d = self.D_dst(
                        gpu_pred_dst_dst_masked_opt)
                    gpu_pred_dst_dst_d_ones = tf.ones_like(
                        gpu_pred_dst_dst_d)
                    gpu_pred_dst_dst_d_zeros = tf.zeros_like(
                        gpu_pred_dst_dst_d)
                    gpu_target_dst_d = self.D_dst(
                        gpu_target_dst_masked_opt)
                    gpu_target_dst_d_ones = tf.ones_like(gpu_target_dst_d)

                    gpu_D_src_dst_loss = (DLoss(gpu_target_src_d_ones   , gpu_target_src_d) + \
                                          DLoss(gpu_pred_src_src_d_zeros, gpu_pred_src_src_d) ) * 0.5 + \
                                         (DLoss(gpu_target_dst_d_ones   , gpu_target_dst_d) + \
                                          DLoss(gpu_pred_dst_dst_d_zeros, gpu_pred_dst_dst_d) ) * 0.5

                    gpu_D_src_dst_loss_gvs += [
                        nn.gradients(
                            gpu_D_src_dst_loss,
                            self.D_src.get_weights() +
                            self.D_dst.get_weights())
                    ]

                    gpu_G_loss += gan_power * (
                        DLoss(gpu_pred_src_src_d_ones, gpu_pred_src_src_d) +
                        DLoss(gpu_pred_dst_dst_d_ones, gpu_pred_dst_dst_d))

                gpu_G_loss_gvs += [
                    nn.gradients(gpu_G_loss, self.src_dst_trainable_weights)
                ]

        # Average losses and gradients, and create optimizer update ops
        with tf.device(models_opt_device):
            pred_src_src = nn.concat(gpu_pred_src_src_list, 0)
            pred_dst_dst = nn.concat(gpu_pred_dst_dst_list, 0)
            pred_src_dst = nn.concat(gpu_pred_src_dst_list, 0)
            pred_src_srcm = nn.concat(gpu_pred_src_srcm_list, 0)
            pred_dst_dstm = nn.concat(gpu_pred_dst_dstm_list, 0)
            pred_src_dstm = nn.concat(gpu_pred_src_dstm_list, 0)

            src_loss = nn.average_tensor_list(gpu_src_losses)
            dst_loss = nn.average_tensor_list(gpu_dst_losses)
            src_dst_loss_gv_op = self.src_dst_opt.get_update_op(
                nn.average_gv_list(gpu_G_loss_gvs))
            if gan_power != 0:
                src_D_src_dst_loss_gv_op = self.D_src_dst_opt.get_update_op(
                    nn.average_gv_list(gpu_D_src_dst_loss_gvs))

        # Initializing training and view functions
        def src_dst_train(warped_src, target_src, target_srcm_all, \
                          warped_dst, target_dst, target_dstm_all):
            """Run one generator step; returns (mean src loss, mean dst loss)."""
            s, d, _ = nn.tf_sess.run(
                [src_loss, dst_loss, src_dst_loss_gv_op],
                feed_dict={
                    self.warped_src: warped_src,
                    self.target_src: target_src,
                    self.target_srcm_all: target_srcm_all,
                    self.warped_dst: warped_dst,
                    self.target_dst: target_dst,
                    self.target_dstm_all: target_dstm_all,
                })
            s = np.mean(s)
            d = np.mean(d)
            return s, d

        self.src_dst_train = src_dst_train

        if gan_power != 0:

            def D_src_dst_train(warped_src, target_src, target_srcm_all, \
                                warped_dst, target_dst, target_dstm_all):
                """Run one discriminator step (no loss returned)."""
                nn.tf_sess.run(
                    [src_D_src_dst_loss_gv_op],
                    feed_dict={
                        self.warped_src: warped_src,
                        self.target_src: target_src,
                        self.target_srcm_all: target_srcm_all,
                        self.warped_dst: warped_dst,
                        self.target_dst: target_dst,
                        self.target_dstm_all: target_dstm_all
                    })

            self.D_src_dst_train = D_src_dst_train

        if learn_mask:

            def AE_view(warped_src, warped_dst):
                # Returns [src->src, dst->dst, dst mask, src->dst, src->dst mask].
                return nn.tf_sess.run([
                    pred_src_src, pred_dst_dst, pred_dst_dstm, pred_src_dst,
                    pred_src_dstm
                ],
                                      feed_dict={
                                          self.warped_src: warped_src,
                                          self.warped_dst: warped_dst
                                      })
        else:

            def AE_view(warped_src, warped_dst):
                return nn.tf_sess.run(
                    [pred_src_src, pred_dst_dst, pred_src_dst],
                    feed_dict={
                        self.warped_src: warped_src,
                        self.warped_dst: warped_dst
                    })

        self.AE_view = AE_view
    else:
        # Initializing merge function
        with tf.device(f'/GPU:0' if len(devices) != 0 else f'/CPU:0'):
            gpu_dst_code = self.inter(self.encoder(self.warped_dst))
            gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(
                gpu_dst_code)
            _, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)

        if learn_mask:

            def AE_merge(warped_dst):
                return nn.tf_sess.run([
                    gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm
                ],
                                      feed_dict={
                                          self.warped_dst: warped_dst
                                      })
        else:

            def AE_merge(warped_dst):
                return nn.tf_sess.run(
                    [gpu_pred_src_dst],
                    feed_dict={self.warped_dst: warped_dst})

        self.AE_merge = AE_merge

    # Loading/initializing all models/optimizers weights
    for model, filename in io.progress_bar_generator(
            self.model_filename_list, "Initializing models"):
        if self.pretrain_just_disabled:
            # Coming out of pretrain: reinit only the inter bottleneck.
            do_init = False
            if model == self.inter:
                do_init = True
        else:
            do_init = self.is_first_run()

        if not do_init:
            do_init = not model.load_weights(
                self.get_strpath_storage_for_file(filename))

        if do_init:
            model.init_weights()

    # initializing sample generators
    if self.is_training:
        t = SampleProcessor.Types
        face_type = t.FACE_TYPE_HEAD

        training_data_src_path = self.training_data_src_path if not self.pretrain else self.get_pretraining_data_path(
        )
        training_data_dst_path = self.training_data_dst_path if not self.pretrain else self.get_pretraining_data_path(
        )

        t_img_warped = t.IMG_WARPED_TRANSFORMED

        cpu_count = min(multiprocessing.cpu_count(), 8)
        src_generators_count = cpu_count // 2
        dst_generators_count = cpu_count // 2

        # Per side: [warped input, unwarped target, combined face+eyes mask].
        self.set_training_data_generators([
            SampleGeneratorFace(
                training_data_src_path,
                debug=self.is_debug(),
                batch_size=self.get_batch_size(),
                sample_process_options=SampleProcessor.Options(
                    random_flip=False),
                output_sample_types=[
                    {
                        'types': (t_img_warped, face_type, t.MODE_BGR),
                        'data_format': nn.data_format,
                        'resolution': resolution
                    },
                    {
                        'types': (t.IMG_TRANSFORMED, face_type, t.MODE_BGR),
                        'data_format': nn.data_format,
                        'resolution': resolution
                    },
                    {
                        'types': (t.IMG_TRANSFORMED, face_type,
                                  t.MODE_FACE_MASK_ALL_EYES_HULL),
                        'data_format': nn.data_format,
                        'resolution': resolution
                    },
                ],
                generators_count=src_generators_count),
            SampleGeneratorFace(
                training_data_dst_path,
                debug=self.is_debug(),
                batch_size=self.get_batch_size(),
                sample_process_options=SampleProcessor.Options(
                    random_flip=False),
                output_sample_types=[
                    {
                        'types': (t_img_warped, face_type, t.MODE_BGR),
                        'data_format': nn.data_format,
                        'resolution': resolution
                    },
                    {
                        'types': (t.IMG_TRANSFORMED, face_type, t.MODE_BGR),
                        'data_format': nn.data_format,
                        'resolution': resolution
                    },
                    {
                        'types': (t.IMG_TRANSFORMED, face_type,
                                  t.MODE_FACE_MASK_ALL_EYES_HULL),
                        'data_format': nn.data_format,
                        'resolution': resolution
                    },
                ],
                generators_count=dst_generators_count)
        ])

        if self.pretrain_just_disabled:
            self.update_sample_for_preview(force_new=True)
def on_initialize(self):
    """Build the TF1 graph for a 96-res pose-disentangling autoencoder.

    The encoder output is split into two branches: bT ("true" pose code, trained
    to predict pitch/yaw via classifier bTC) and bP (trained adversarially via
    bPC to NOT carry pose). Four separate losses (A..D) update different weight
    subsets through one shared RMSprop optimizer. Built in one TF1 graph;
    statement order matters for variable creation.
    """
    device_config = nn.getCurrentDeviceConfig()
    devices = device_config.devices

    self.model_data_format = "NCHW" if len(
        devices) != 0 and not self.is_debug() else "NHWC"
    nn.initialize(data_format=self.model_data_format)
    tf = nn.tf

    resolution = self.resolution = 96
    self.face_type = FaceType.FULL
    ae_dims = 256
    e_dims = 64
    d_dims = 64
    self.pretrain = False
    self.pretrain_just_disabled = False

    masked_training = True  # NOTE(review): assigned but never read in this method

    # GPU only if every device has at least 4GB.
    models_opt_on_gpu = len(devices) >= 1 and all(
        [dev.total_mem_gb >= 4 for dev in devices])
    models_opt_device = '/GPU:0' if models_opt_on_gpu and self.is_training else '/CPU:0'
    optimizer_vars_on_cpu = models_opt_device == '/CPU:0'

    input_ch = 3
    bgr_shape = nn.get4Dshape(resolution, resolution, input_ch)
    mask_shape = nn.get4Dshape(resolution, resolution, 1)

    self.model_filename_list = []

    kernel_initializer = tf.initializers.glorot_uniform()

    # NOTE(review): this Upscale and the VGG-style Encoder below are DEAD CODE —
    # both names are redefined further down before anything instantiates them.
    class Upscale(nn.ModelBase):
        def on_build(self, in_ch, out_ch, kernel_size=3):
            self.conv1 = nn.Conv2D(in_ch,
                                   out_ch * 4,
                                   kernel_size=kernel_size,
                                   padding='SAME',
                                   kernel_initializer=kernel_initializer)

        def forward(self, x):
            x = self.conv1(x)
            x = tf.nn.leaky_relu(x, 0.1)
            x = nn.depth_to_space(x, 2)
            return x

    class ResidualBlock(nn.ModelBase):
        """Two same-channel convs with a skip connection; leaky_relu(0.2)."""

        def on_build(self, ch, kernel_size=3):
            self.conv1 = nn.Conv2D(ch,
                                   ch,
                                   kernel_size=kernel_size,
                                   padding='SAME',
                                   kernel_initializer=kernel_initializer)
            self.conv2 = nn.Conv2D(ch,
                                   ch,
                                   kernel_size=kernel_size,
                                   padding='SAME',
                                   kernel_initializer=kernel_initializer)

        def forward(self, inp):
            x = self.conv1(inp)
            x = tf.nn.leaky_relu(x, 0.2)
            x = self.conv2(x)
            x = tf.nn.leaky_relu(inp + x, 0.2)
            return x

    class Encoder(nn.ModelBase):
        """VGG-style conv/max-pool encoder — shadowed by the later Encoder; unused."""

        def on_build(self, in_ch, e_ch):
            self.down11 = nn.Conv2D(in_ch,
                                    e_ch,
                                    kernel_size=3,
                                    strides=1,
                                    padding='SAME',
                                    kernel_initializer=kernel_initializer)
            self.down12 = nn.Conv2D(e_ch,
                                    e_ch,
                                    kernel_size=3,
                                    strides=1,
                                    padding='SAME',
                                    kernel_initializer=kernel_initializer)
            self.down21 = nn.Conv2D(e_ch,
                                    e_ch * 2,
                                    kernel_size=3,
                                    strides=1,
                                    padding='SAME',
                                    kernel_initializer=kernel_initializer)
            self.down22 = nn.Conv2D(e_ch * 2,
                                    e_ch * 2,
                                    kernel_size=3,
                                    strides=1,
                                    padding='SAME',
                                    kernel_initializer=kernel_initializer)
            self.down31 = nn.Conv2D(e_ch * 2,
                                    e_ch * 4,
                                    kernel_size=3,
                                    strides=1,
                                    padding='SAME',
                                    kernel_initializer=kernel_initializer)
            self.down32 = nn.Conv2D(e_ch * 4,
                                    e_ch * 4,
                                    kernel_size=3,
                                    strides=1,
                                    padding='SAME',
                                    kernel_initializer=kernel_initializer)
            self.down33 = nn.Conv2D(e_ch * 4,
                                    e_ch * 4,
                                    kernel_size=3,
                                    strides=1,
                                    padding='SAME',
                                    kernel_initializer=kernel_initializer)
            self.down41 = nn.Conv2D(e_ch * 4,
                                    e_ch * 8,
                                    kernel_size=3,
                                    strides=1,
                                    padding='SAME',
                                    kernel_initializer=kernel_initializer)
            self.down42 = nn.Conv2D(e_ch * 8,
                                    e_ch * 8,
                                    kernel_size=3,
                                    strides=1,
                                    padding='SAME',
                                    kernel_initializer=kernel_initializer)
            self.down43 = nn.Conv2D(e_ch * 8,
                                    e_ch * 8,
                                    kernel_size=3,
                                    strides=1,
                                    padding='SAME',
                                    kernel_initializer=kernel_initializer)
            self.down51 = nn.Conv2D(e_ch * 8,
                                    e_ch * 8,
                                    kernel_size=3,
                                    strides=1,
                                    padding='SAME',
                                    kernel_initializer=kernel_initializer)
            self.down52 = nn.Conv2D(e_ch * 8,
                                    e_ch * 8,
                                    kernel_size=3,
                                    strides=1,
                                    padding='SAME',
                                    kernel_initializer=kernel_initializer)
            self.down53 = nn.Conv2D(e_ch * 8,
                                    e_ch * 8,
                                    kernel_size=3,
                                    strides=1,
                                    padding='SAME',
                                    kernel_initializer=kernel_initializer)

        def forward(self, inp):
            x = inp
            x = self.down11(x)
            x = self.down12(x)
            x = nn.max_pool(x)
            x = self.down21(x)
            x = self.down22(x)
            x = nn.max_pool(x)
            x = self.down31(x)
            x = self.down32(x)
            x = self.down33(x)
            x = nn.max_pool(x)
            x = self.down41(x)
            x = self.down42(x)
            x = self.down43(x)
            x = nn.max_pool(x)
            x = self.down51(x)
            x = self.down52(x)
            x = self.down53(x)
            x = nn.max_pool(x)
            x = nn.flatten(x)
            return x

    class Downscale(nn.ModelBase):
        """2x downscale: strided conv, or conv + space_to_depth ('subpixel')."""

        # NOTE(review): `*kwargs` collects extra POSITIONAL args — probably
        # intended as `**kwargs`; harmless only because no extras are passed.
        def __init__(self,
                     in_ch,
                     out_ch,
                     kernel_size=5,
                     dilations=1,
                     subpixel=True,
                     use_activator=True,
                     *kwargs):
            self.in_ch = in_ch
            self.out_ch = out_ch
            self.kernel_size = kernel_size
            self.dilations = dilations
            self.subpixel = subpixel
            self.use_activator = use_activator
            super().__init__(*kwargs)

        def on_build(self, *args, **kwargs):
            self.conv1 = nn.Conv2D(self.in_ch,
                                   self.out_ch // (4 if self.subpixel else 1),
                                   kernel_size=self.kernel_size,
                                   strides=1 if self.subpixel else 2,
                                   padding='SAME',
                                   dilations=self.dilations,
                                   kernel_initializer=kernel_initializer)

        def forward(self, x):
            x = self.conv1(x)
            if self.subpixel:
                x = nn.space_to_depth(x, 2)
            if self.use_activator:
                x = tf.nn.leaky_relu(x, 0.1)
            return x

        def get_out_ch(self):
            return (self.out_ch // 4) * 4

    class DownscaleBlock(nn.ModelBase):
        """Chain of Downscale layers; channels double per level, capped at ch*8."""

        def on_build(self,
                     in_ch,
                     ch,
                     n_downscales,
                     kernel_size,
                     dilations=1,
                     subpixel=True):
            self.downs = []
            last_ch = in_ch
            for i in range(n_downscales):
                cur_ch = ch * (min(2**i, 8))
                self.downs.append(
                    Downscale(last_ch,
                              cur_ch,
                              kernel_size=kernel_size,
                              dilations=dilations,
                              subpixel=subpixel))
                last_ch = self.downs[-1].get_out_ch()

        def forward(self, inp):
            x = inp
            for down in self.downs:
                x = down(x)
            return x

    class Upscale(nn.ModelBase):
        """2x upscale via conv to out_ch*4 + depth_to_space(2). (This definition
        shadows the earlier Upscale and is the one actually used.)"""

        def on_build(self, in_ch, out_ch, kernel_size=3):
            self.conv1 = nn.Conv2D(in_ch,
                                   out_ch * 4,
                                   kernel_size=kernel_size,
                                   padding='SAME',
                                   kernel_initializer=kernel_initializer)

        def forward(self, x):
            x = self.conv1(x)
            x = tf.nn.leaky_relu(x, 0.1)
            x = nn.depth_to_space(x, 2)
            return x

    class Encoder(nn.ModelBase):
        """4x strided-conv downscale stack, flattened. (Shadows the VGG Encoder
        above and is the one actually used.)"""

        def on_build(self, in_ch, e_ch):
            self.down1 = DownscaleBlock(in_ch,
                                        e_ch,
                                        n_downscales=4,
                                        kernel_size=5,
                                        dilations=1,
                                        subpixel=False)

        def forward(self, inp):
            x = nn.flatten(self.down1(inp))
            return x

    class Branch(nn.ModelBase):
        """Single dense projection of the encoder code to ae_ch dims."""

        def on_build(self, in_ch, ae_ch):
            self.dense1 = nn.Dense(in_ch, ae_ch)

        def forward(self, inp):
            x = self.dense1(inp)
            return x

    class Classifier(nn.ModelBase):
        """Two 4096-dim dense layers feeding separate pitch and yaw logit heads."""

        def on_build(self, in_ch, n_classes):
            self.dense1 = nn.Dense(in_ch, 4096)
            self.dense2 = nn.Dense(4096, 4096)
            self.pitch_dense = nn.Dense(4096, n_classes)
            self.yaw_dense = nn.Dense(4096, n_classes)

        def forward(self, inp):
            x = inp
            x = self.dense1(x)
            x = self.dense2(x)
            # Returns (pitch_logits, yaw_logits).
            return self.pitch_dense(x), self.yaw_dense(x)

    lowest_dense_res = resolution // 16  # spatial size at the bottleneck

    class Inter(nn.ModelBase):
        """Bottleneck: dense -> reshape to 4D -> one upscale. Reads
        `lowest_dense_res` from the enclosing closure."""

        def on_build(self, in_ch, ae_out_ch):
            self.ae_out_ch = ae_out_ch
            self.dense2 = nn.Dense(
                in_ch, lowest_dense_res * lowest_dense_res * ae_out_ch)
            self.upscale1 = Upscale(ae_out_ch, ae_out_ch)

        def forward(self, inp):
            x = inp
            x = self.dense2(x)
            x = nn.reshape_4D(x, lowest_dense_res, lowest_dense_res,
                              self.ae_out_ch)
            x = self.upscale1(x)
            return x

        def get_out_ch(self):
            return self.ae_out_ch

    class Decoder(nn.ModelBase):
        """3-level upscale+residual decoder producing a sigmoid BGR image.
        NOTE(review): d_mask_ch is accepted but never used — no mask branch here."""

        def on_build(self, in_ch, d_ch, d_mask_ch):
            self.upscale0 = Upscale(in_ch, d_ch * 8, kernel_size=3)
            self.upscale1 = Upscale(d_ch * 8, d_ch * 4, kernel_size=3)
            self.upscale2 = Upscale(d_ch * 4, d_ch * 2, kernel_size=3)
            self.res0 = ResidualBlock(d_ch * 8, kernel_size=3)
            self.res1 = ResidualBlock(d_ch * 4, kernel_size=3)
            self.res2 = ResidualBlock(d_ch * 2, kernel_size=3)
            self.out_conv = nn.Conv2D(d_ch * 2,
                                      3,
                                      kernel_size=1,
                                      padding='SAME',
                                      kernel_initializer=kernel_initializer)

        def forward(self, inp):
            z = inp
            x = self.upscale0(z)
            x = self.res0(x)
            x = self.upscale1(x)
            x = self.res1(x)
            x = self.upscale2(x)
            x = self.res2(x)
            return tf.nn.sigmoid(self.out_conv(x))

    # Pose is quantized into 180/n_pyr_degs classes per angle.
    n_pyr_degs = self.n_pyr_degs = 3
    n_pyr_classes = self.n_pyr_classes = 180 // self.n_pyr_degs

    with tf.device('/CPU:0'):
        # Place holders on CPU
        self.warped_src = tf.placeholder(nn.floatx, bgr_shape)
        self.target_src = tf.placeholder(nn.floatx, bgr_shape)
        self.target_dst = tf.placeholder(nn.floatx, bgr_shape)
        # One-hot (presumably — TODO confirm against the generator) pose vectors.
        self.pitches_vector = tf.placeholder(nn.floatx,
                                             (None, n_pyr_classes))
        self.yaws_vector = tf.placeholder(nn.floatx, (None, n_pyr_classes))

    # Initializing model classes
    with tf.device(models_opt_device):
        self.encoder = Encoder(in_ch=input_ch, e_ch=e_dims, name='encoder')
        encoder_out_ch = self.encoder.compute_output_channels(
            (nn.floatx, bgr_shape))

        self.bT = Branch(in_ch=encoder_out_ch, ae_ch=ae_dims, name='bT')
        self.bP = Branch(in_ch=encoder_out_ch, ae_ch=ae_dims, name='bP')
        self.bTC = Classifier(in_ch=ae_dims,
                              n_classes=self.n_pyr_classes,
                              name='bTC')
        self.bPC = Classifier(in_ch=ae_dims,
                              n_classes=self.n_pyr_classes,
                              name='bPC')
        # inter/decoder consume the concatenation of both branch codes.
        self.inter = Inter(in_ch=ae_dims * 2,
                           ae_out_ch=ae_dims * 2,
                           name='inter')
        self.decoder = Decoder(in_ch=ae_dims * 2,
                               d_ch=d_dims,
                               d_mask_ch=d_dims,
                               name='decoder')

        self.model_filename_list += [[self.encoder, 'encoder.npy'],
                                     [self.bT, 'bT.npy'],
                                     [self.bTC, 'bTC.npy'],
                                     [self.bP, 'bP.npy'],
                                     [self.bPC, 'bPC.npy'],
                                     [self.inter, 'inter.npy'],
                                     [self.decoder, 'decoder.npy']]

        if self.is_training:
            self.all_trainable_weights = self.encoder.get_weights() + \
                                         self.bT.get_weights() +\
                                         self.bTC.get_weights() +\
                                         self.bP.get_weights() +\
                                         self.bPC.get_weights() +\
                                         self.inter.get_weights() +\
                                         self.decoder.get_weights()

            # Initialize optimizers
            # One shared optimizer; each of the A..D update ops below touches
            # only its own weight subset.
            self.src_dst_opt = nn.RMSprop(lr=5e-5, name='src_dst_opt')
            self.src_dst_opt.initialize_variables(
                self.all_trainable_weights,
                vars_on_cpu=optimizer_vars_on_cpu)
            self.model_filename_list += [(self.src_dst_opt,
                                          'src_dst_opt.npy')]

    if self.is_training:
        # Adjust batch size for multiple GPU
        gpu_count = max(1, len(devices))
        bs_per_gpu = max(1, 32 // gpu_count)  # fixed total batch of 32
        self.set_batch_size(gpu_count * bs_per_gpu)

        # Compute losses per GPU
        gpu_pred_src_list = []
        gpu_pred_dst_list = []
        gpu_A_losses = []
        gpu_B_losses = []
        gpu_C_losses = []
        gpu_D_losses = []
        gpu_A_loss_gvs = []
        gpu_B_loss_gvs = []
        gpu_C_loss_gvs = []
        gpu_D_loss_gvs = []
        for gpu_id in range(gpu_count):
            with tf.device(
                    f'/GPU:{gpu_id}' if len(devices) != 0 else f'/CPU:0'):
                batch_slice = slice(gpu_id * bs_per_gpu,
                                    (gpu_id + 1) * bs_per_gpu)
                with tf.device(f'/CPU:0'):
                    # slice on CPU, otherwise all batch data will be transfered to GPU first
                    gpu_warped_src = self.warped_src[batch_slice, :, :, :]
                    gpu_target_src = self.target_src[batch_slice, :, :, :]
                    gpu_target_dst = self.target_dst[batch_slice, :, :, :]
                    gpu_pitches_vector = self.pitches_vector[
                        batch_slice, :]
                    gpu_yaws_vector = self.yaws_vector[batch_slice, :]

                # process model tensors
                gpu_src_enc_code = self.encoder(gpu_warped_src)
                gpu_dst_enc_code = self.encoder(gpu_target_dst)

                gpu_src_bT_code = self.bT(gpu_src_enc_code)
                # Stop-gradient copy: lets the decoder train (loss D) without
                # pushing gradients back into encoder/bT.
                gpu_src_bT_code_ng = tf.stop_gradient(gpu_src_bT_code)
                gpu_src_T_pitch, gpu_src_T_yaw = self.bTC(gpu_src_bT_code)
                gpu_dst_bT_code = self.bT(gpu_dst_enc_code)

                gpu_src_bP_code = self.bP(gpu_src_enc_code)
                gpu_src_P_pitch, gpu_src_P_yaw = self.bPC(gpu_src_bP_code)

                def crossentropy(target, output):
                    # Softmax cross-entropy over logits, per sample.
                    output = tf.nn.softmax(output)
                    output = tf.clip_by_value(output, 1e-7, 1 - 1e-7)
                    return tf.reduce_sum(target * -tf.log(output),
                                         axis=-1,
                                         keepdims=False)

                def negative_crossentropy(n_classes, output):
                    # Pushes the predicted distribution toward uniform —
                    # used adversarially so bP carries no pose information.
                    output = tf.nn.softmax(output)
                    output = tf.clip_by_value(output, 1e-7, 1 - 1e-7)
                    return (1.0 / n_classes) * tf.reduce_sum(
                        tf.log(output), axis=-1, keepdims=False)

                # Noise-perturbed codes for the denoising reconstruction term.
                gpu_src_bT_code_n = gpu_src_bT_code_ng + tf.random.normal(
                    tf.shape(gpu_src_bT_code_ng))
                gpu_src_bP_code_n = gpu_src_bP_code + tf.random.normal(
                    tf.shape(gpu_src_bP_code))

                gpu_pred_src = self.decoder(
                    self.inter(
                        tf.concat([gpu_src_bT_code_ng, gpu_src_bP_code],
                                  axis=-1)))
                gpu_pred_src_n = self.decoder(
                    self.inter(
                        tf.concat([gpu_src_bT_code_n, gpu_src_bP_code_n],
                                  axis=-1)))
                # Swap: dst pose code + src identity code.
                gpu_pred_dst = self.decoder(
                    self.inter(
                        tf.concat([gpu_dst_bT_code, gpu_src_bP_code],
                                  axis=-1)))

                # A: make bT predict pose (trains encoder+bT+bTC).
                gpu_A_loss = 1.0*crossentropy(gpu_pitches_vector, gpu_src_T_pitch ) + \
                             1.0*crossentropy(gpu_yaws_vector,    gpu_src_T_yaw )
                # B: train bPC alone to read pose out of bP.
                gpu_B_loss = 0.1*crossentropy(gpu_pitches_vector, gpu_src_P_pitch ) + \
                             0.1*crossentropy(gpu_yaws_vector,    gpu_src_P_yaw )
                # C: adversarial — train encoder+bP to defeat bPC.
                gpu_C_loss = 0.1*negative_crossentropy( n_pyr_classes, gpu_src_P_pitch ) + \
                             0.1*negative_crossentropy( n_pyr_classes, gpu_src_P_yaw )
                # D: tiny-weight pixel reconstruction, trains inter+decoder only.
                gpu_D_loss = 0.0000001*(\
                             0.5*tf.reduce_sum(tf.square(gpu_target_src-gpu_pred_src), axis=[1,2,3]) + \
                             0.5*tf.reduce_sum(tf.square(gpu_target_src-gpu_pred_src_n), axis=[1,2,3]) )

                gpu_pred_src_list.append(gpu_pred_src)
                gpu_pred_dst_list.append(gpu_pred_dst)

                gpu_A_losses += [gpu_A_loss]
                gpu_B_losses += [gpu_B_loss]
                gpu_C_losses += [gpu_C_loss]
                gpu_D_losses += [gpu_D_loss]

                # Each loss updates a disjoint-purpose weight subset.
                A_weights = self.encoder.get_weights(
                ) + self.bT.get_weights() + self.bTC.get_weights()
                B_weights = self.bPC.get_weights()
                C_weights = self.encoder.get_weights(
                ) + self.bP.get_weights()
                D_weights = self.inter.get_weights(
                ) + self.decoder.get_weights()

                gpu_A_loss_gvs += [nn.gradients(gpu_A_loss, A_weights)]
                gpu_B_loss_gvs += [nn.gradients(gpu_B_loss, B_weights)]
                gpu_C_loss_gvs += [nn.gradients(gpu_C_loss, C_weights)]
                gpu_D_loss_gvs += [nn.gradients(gpu_D_loss, D_weights)]

        # Average losses and gradients, and create optimizer update ops
        with tf.device(models_opt_device):
            pred_src = nn.concat(gpu_pred_src_list, 0)
            pred_dst = nn.concat(gpu_pred_dst_list, 0)

            A_loss = nn.average_tensor_list(gpu_A_losses)
            B_loss = nn.average_tensor_list(gpu_B_losses)
            C_loss = nn.average_tensor_list(gpu_C_losses)
            D_loss = nn.average_tensor_list(gpu_D_losses)
            A_loss_gv = nn.average_gv_list(gpu_A_loss_gvs)
            B_loss_gv = nn.average_gv_list(gpu_B_loss_gvs)
            C_loss_gv = nn.average_gv_list(gpu_C_loss_gvs)
            D_loss_gv = nn.average_gv_list(gpu_D_loss_gvs)
            A_loss_gv_op = self.src_dst_opt.get_update_op(A_loss_gv)
            B_loss_gv_op = self.src_dst_opt.get_update_op(B_loss_gv)
            C_loss_gv_op = self.src_dst_opt.get_update_op(C_loss_gv)
            D_loss_gv_op = self.src_dst_opt.get_update_op(D_loss_gv)

        # Initializing training and view functions
        def A_train(warped_src, target_src, pitches_vector, yaws_vector):
            """One step of loss A (pose prediction); returns mean loss."""
            l, _ = nn.tf_sess.run(
                [A_loss, A_loss_gv_op],
                feed_dict={
                    self.warped_src: warped_src,
                    self.target_src: target_src,
                    self.pitches_vector: pitches_vector,
                    self.yaws_vector: yaws_vector
                })
            return np.mean(l)

        self.A_train = A_train

        def B_train(warped_src, target_src, pitches_vector, yaws_vector):
            """One step of loss B (bPC pose probe); returns mean loss."""
            l, _ = nn.tf_sess.run(
                [B_loss, B_loss_gv_op],
                feed_dict={
                    self.warped_src: warped_src,
                    self.target_src: target_src,
                    self.pitches_vector: pitches_vector,
                    self.yaws_vector: yaws_vector
                })
            return np.mean(l)

        self.B_train = B_train

        def C_train(warped_src, target_src, pitches_vector, yaws_vector):
            """One step of loss C (adversarial de-posing of bP); returns mean loss."""
            l, _ = nn.tf_sess.run(
                [C_loss, C_loss_gv_op],
                feed_dict={
                    self.warped_src: warped_src,
                    self.target_src: target_src,
                    self.pitches_vector: pitches_vector,
                    self.yaws_vector: yaws_vector
                })
            return np.mean(l)

        self.C_train = C_train

        def D_train(warped_src, target_src, pitches_vector, yaws_vector):
            """One step of loss D (reconstruction); returns mean loss."""
            l, _ = nn.tf_sess.run(
                [D_loss, D_loss_gv_op],
                feed_dict={
                    self.warped_src: warped_src,
                    self.target_src: target_src,
                    self.pitches_vector: pitches_vector,
                    self.yaws_vector: yaws_vector
                })
            return np.mean(l)

        self.D_train = D_train

        def AE_view(warped_src):
            # Reconstruction preview for src.
            return nn.tf_sess.run([pred_src],
                                  feed_dict={self.warped_src: warped_src})

        self.AE_view = AE_view

        def AE_view2(warped_src, target_dst):
            # Pose-swap preview: dst pose applied to src identity.
            return nn.tf_sess.run([pred_dst],
                                  feed_dict={
                                      self.warped_src: warped_src,
                                      self.target_dst: target_dst
                                  })

        self.AE_view2 = AE_view2
    else:
        # Initializing merge function
        # NOTE(review): this branch looks copy-pasted from another model and is
        # broken here — self.warped_dst, self.decoder_src and self.decoder_dst
        # are never created in this on_initialize. It would raise AttributeError
        # if ever reached with is_training False. Confirm before shipping.
        with tf.device(f'/GPU:0' if len(devices) != 0 else f'/CPU:0'):
            gpu_dst_code = self.inter(self.encoder(self.warped_dst))
            gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(
                gpu_dst_code)
            _, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)

        def AE_merge(warped_dst):
            return nn.tf_sess.run(
                [gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm],
                feed_dict={self.warped_dst: warped_dst})

        self.AE_merge = AE_merge

    # Loading/initializing all models/optimizers weights
    for model, filename in io.progress_bar_generator(
            self.model_filename_list, "Initializing models"):
        if self.pretrain_just_disabled:
            # Coming out of pretrain: reinit only the inter bottleneck.
            do_init = False
            if model == self.inter:
                do_init = True
        else:
            do_init = self.is_first_run()

        if not do_init:
            do_init = not model.load_weights(
                self.get_strpath_storage_for_file(filename))

        # Fall back to an external pretrained snapshot if available.
        if do_init and self.pretrained_model_path is not None:
            pretrained_filepath = self.pretrained_model_path / filename
            if pretrained_filepath.exists():
                do_init = not model.load_weights(pretrained_filepath)

        if do_init:
            model.init_weights()

    # initializing sample generators
    if self.is_training:
        training_data_src_path = self.training_data_src_path if not self.pretrain else self.get_pretraining_data_path(
        )
        training_data_dst_path = self.training_data_dst_path if not self.pretrain else self.get_pretraining_data_path(
        )

        cpu_count = min(multiprocessing.cpu_count(), 8)
        src_generators_count = cpu_count // 2
        dst_generators_count = cpu_count // 2

        # Per side: [warped input, unwarped target, pitch/yaw/roll sigmoid vector].
        self.set_training_data_generators([
            SampleGeneratorFace(
                training_data_src_path,
                debug=self.is_debug(),
                batch_size=self.get_batch_size(),
                sample_process_options=SampleProcessor.Options(
                    random_flip=True if self.pretrain else False),
                output_sample_types=[
                    {
                        'sample_type':
                        SampleProcessor.SampleType.FACE_IMAGE,
                        'warp': True,
                        'transform': True,
                        'channel_type': SampleProcessor.ChannelType.BGR,
                        'face_type': self.face_type,
                        'data_format': nn.data_format,
                        'resolution': resolution
                    },
                    {
                        'sample_type':
                        SampleProcessor.SampleType.FACE_IMAGE,
                        'warp': False,
                        'transform': True,
                        'channel_type': SampleProcessor.ChannelType.BGR,
                        'face_type': self.face_type,
                        'data_format': nn.data_format,
                        'resolution': resolution
                    },
                    {
                        'sample_type':
                        SampleProcessor.SampleType.PITCH_YAW_ROLL_SIGMOID,
                        'resolution': resolution
                    },
                ],
                generators_count=src_generators_count),
            SampleGeneratorFace(
                training_data_dst_path,
                debug=self.is_debug(),
                batch_size=self.get_batch_size(),
                sample_process_options=SampleProcessor.Options(
                    random_flip=True if self.pretrain else False),
                output_sample_types=[
                    {
                        'sample_type':
                        SampleProcessor.SampleType.FACE_IMAGE,
                        'warp': True,
                        'transform': True,
                        'channel_type': SampleProcessor.ChannelType.BGR,
                        'face_type': self.face_type,
                        'data_format': nn.data_format,
                        'resolution': resolution
                    },
                    {
                        'sample_type':
                        SampleProcessor.SampleType.FACE_IMAGE,
                        'warp': False,
                        'transform': True,
                        'channel_type': SampleProcessor.ChannelType.BGR,
                        'face_type': self.face_type,
                        'data_format': nn.data_format,
                        'resolution': resolution
                    },
                    {
                        'sample_type':
                        SampleProcessor.SampleType.PITCH_YAW_ROLL_SIGMOID,
                        'resolution': resolution
                    },
                ],
                generators_count=dst_generators_count)
        ])

        self.last_samples = None
def on_initialize(self):
    """Build the full TensorFlow graph for training or merging.

    Defines the network building blocks (Downscale/Upscale/Residual,
    Encoder/Inter/Decoder), creates CPU placeholders, wires per-GPU
    forward passes, losses and gradient-averaging, creates the optimizer
    update op, loads or initializes weights, and finally sets up the
    sample generators for training.

    NOTE(review): `nn`, `io`, `SampleProcessor` and `SampleGeneratorFace`
    are project-level helpers — assumed to follow the DeepFaceLab API;
    confirm against the `core.leras` implementation.
    """
    nn.initialize()
    tf = nn.tf
    nn.set_floatx(tf.float32)
    conv_kernel_initializer = nn.initializers.ca

    class Downscale(nn.ModelBase):
        """Downscaling conv block: strided conv, or conv + space_to_depth in subpixel mode."""

        def __init__(self, in_ch, out_ch, kernel_size=5, dilations=1,
                     subpixel=True, use_activator=True, **kwargs):
            # FIX: was `*kwargs` and `super().__init__(*kwargs)`, which silently
            # dropped any keyword arguments (e.g. name=...) instead of forwarding
            # them to nn.ModelBase.
            self.in_ch = in_ch
            self.out_ch = out_ch
            self.kernel_size = kernel_size
            self.dilations = dilations
            self.subpixel = subpixel
            self.use_activator = use_activator
            super().__init__(**kwargs)

        def on_build(self, *args, **kwargs):
            # Subpixel mode keeps stride 1 with out_ch/4 filters; the
            # space_to_depth(2) in forward() restores the channel count while
            # halving the spatial size.
            self.conv1 = nn.Conv2D(self.in_ch,
                                   self.out_ch // (4 if self.subpixel else 1),
                                   kernel_size=self.kernel_size,
                                   strides=1 if self.subpixel else 2,
                                   padding='SAME',
                                   dilations=self.dilations,
                                   kernel_initializer=conv_kernel_initializer)

        def forward(self, x):
            x = self.conv1(x)
            if self.subpixel:
                x = tf.nn.space_to_depth(x, 2)
            if self.use_activator:
                x = tf.nn.leaky_relu(x, 0.1)
            return x

        def get_out_ch(self):
            # Effective output channels after the subpixel round-trip
            # (always a multiple of 4).
            return (self.out_ch // 4) * 4

    class DownscaleBlock(nn.ModelBase):
        """Chain of n_downscales Downscale blocks, channels doubling per step (capped at x8)."""

        def on_build(self, in_ch, ch, n_downscales, kernel_size,
                     dilations=1, subpixel=True):
            self.downs = []
            last_ch = in_ch
            for i in range(n_downscales):
                cur_ch = ch * min(2**i, 8)
                self.downs.append(Downscale(last_ch, cur_ch,
                                            kernel_size=kernel_size,
                                            dilations=dilations,
                                            subpixel=subpixel))
                last_ch = self.downs[-1].get_out_ch()

        def forward(self, inp):
            x = inp
            for down in self.downs:
                x = down(x)
            return x

    class Upscale(nn.ModelBase):
        """x2 upscale: conv to 4*out_ch channels followed by depth_to_space."""

        def on_build(self, in_ch, out_ch, kernel_size=3):
            self.conv1 = nn.Conv2D(in_ch, out_ch*4,
                                   kernel_size=kernel_size, padding='SAME',
                                   kernel_initializer=conv_kernel_initializer)

        def forward(self, x):
            x = self.conv1(x)
            x = tf.nn.leaky_relu(x, 0.1)  # FIX: was duplicated `x = x = ...`
            x = tf.nn.depth_to_space(x, 2)
            return x

    class ResidualBlock(nn.ModelBase):
        """Two same-channel convs with an additive skip connection."""

        def on_build(self, ch, kernel_size=3):
            self.conv1 = nn.Conv2D(ch, ch, kernel_size=kernel_size, padding='SAME',
                                   kernel_initializer=conv_kernel_initializer)
            self.conv2 = nn.Conv2D(ch, ch, kernel_size=kernel_size, padding='SAME',
                                   kernel_initializer=conv_kernel_initializer)

        def forward(self, inp):
            x = self.conv1(inp)
            x = tf.nn.leaky_relu(x, 0.1)  # FIX: was duplicated `x = x = ...`
            x = self.conv2(x)
            x = inp + x
            x = tf.nn.leaky_relu(x, 0.1)  # FIX: was duplicated `x = x = ...`
            return x

    class Encoder(nn.ModelBase):
        """Four downscales, flattened to a feature vector."""

        def on_build(self, in_ch, e_ch):
            self.down1 = DownscaleBlock(in_ch, e_ch, n_downscales=4, kernel_size=5)

        def forward(self, inp):
            return nn.flatten(self.down1(inp))

    class Inter(nn.ModelBase):
        """Bottleneck: dense -> dense -> reshape to spatial -> upscale -> residual."""

        def __init__(self, in_ch, lowest_dense_res, ae_ch, ae_out_ch, d_ch, **kwargs):
            self.in_ch, self.lowest_dense_res, self.ae_ch, self.ae_out_ch, self.d_ch = \
                in_ch, lowest_dense_res, ae_ch, ae_out_ch, d_ch
            super().__init__(**kwargs)

        def on_build(self):
            in_ch, lowest_dense_res, ae_ch, ae_out_ch, d_ch = \
                self.in_ch, self.lowest_dense_res, self.ae_ch, self.ae_out_ch, self.d_ch
            self.dense1 = nn.Dense(in_ch, ae_ch,
                                   kernel_initializer=tf.initializers.orthogonal)
            self.dense2 = nn.Dense(ae_ch, lowest_dense_res * lowest_dense_res * ae_out_ch,
                                   kernel_initializer=tf.initializers.orthogonal)
            self.upscale1 = Upscale(ae_out_ch, d_ch*8)
            self.res1 = ResidualBlock(d_ch*8)

        def forward(self, inp):
            x = self.dense1(inp)
            x = self.dense2(x)
            # FIX: previously reshaped with the free variable `lowest_dense_res`
            # captured by closure from on_initialize; use the value stored on
            # the instance so the class is self-contained (same numeric value).
            x = tf.reshape(x, (-1, self.lowest_dense_res, self.lowest_dense_res,
                               self.ae_out_ch))
            x = self.upscale1(x)
            x = self.res1(x)
            return x

        def get_out_ch(self):
            return self.ae_out_ch

    class Decoder(nn.ModelBase):
        """Progressive decoder: stage-0 BGR branch, stage-1 refinement branch, mask branch."""

        def on_build(self, in_ch, d_ch):
            # stage-0 BGR branch
            self.upscale0_1 = Upscale(in_ch, d_ch*1)
            self.upscale0_2 = Upscale(d_ch*1, d_ch*1)
            self.upscale0_3 = Upscale(d_ch*1, d_ch*1)
            self.out0 = nn.Conv2D(d_ch*1, 3, kernel_size=1, padding='SAME',
                                  kernel_initializer=conv_kernel_initializer)
            # stage-1 refinement branch; consumes stage-0 features concatenated
            # with its own, hence the doubled input channels.
            self.upscale1_1 = Upscale(in_ch, d_ch*1)
            self.upscale1_2 = Upscale(d_ch*2, d_ch*1)
            self.upscale1_3 = Upscale(d_ch*2, d_ch*1)
            self.out1 = nn.Conv2D(d_ch*2, 3, kernel_size=1, padding='SAME',
                                  kernel_initializer=conv_kernel_initializer)
            # mask branch
            self.upscalem1 = Upscale(in_ch, d_ch)
            self.upscalem2 = Upscale(d_ch, d_ch//2)
            self.upscalem3 = Upscale(d_ch//2, d_ch//2)
            self.outm = nn.Conv2D(d_ch//2, 1, kernel_size=1, padding='SAME',
                                  kernel_initializer=conv_kernel_initializer)

        def get_weights_ex(self, stage):
            """Return the weights trainable at `stage` (mask branch always included)."""
            # Call internal get_weights in order to initialize inner logic
            self.get_weights()

            weights = self.upscalem1.get_weights() \
                      + self.upscalem2.get_weights() \
                      + self.upscalem3.get_weights() \
                      + self.outm.get_weights()
            if stage >= 0:
                weights += self.upscale0_1.get_weights() \
                           + self.upscale0_2.get_weights() \
                           + self.upscale0_3.get_weights() \
                           + self.out0.get_weights()
            if stage >= 1:
                weights += self.upscale1_1.get_weights() \
                           + self.upscale1_2.get_weights() \
                           + self.upscale1_3.get_weights() \
                           + self.out1.get_weights()
            return weights

        def forward(self, inp, stage=0):
            z = inp

            x0_1 = self.upscale0_1(z)
            x0_2 = self.upscale0_2(x0_1)
            x0_3 = self.upscale0_3(x0_2)
            x = tf.nn.tanh(self.out0(x0_3))

            if stage >= 1:
                # Stage-1 output is a residual added onto the stage-0 image.
                x1_1 = self.upscale1_1(z)
                x1_2 = self.upscale1_2(tf.concat([x0_1, x1_1], -1))
                x1_3 = self.upscale1_3(tf.concat([x0_2, x1_2], -1))
                x1 = tf.nn.tanh(self.out1(tf.concat([x0_3, x1_3], -1)))
                x = x + x1

            y = self.upscalem1(z)
            y = self.upscalem2(y)
            y = self.upscalem3(y)

            # BGR: tanh output remapped from [-1,1] to [0,1]; mask: sigmoid.
            return x / 2 + 0.5, tf.nn.sigmoid(self.outm(y))

    device_config = nn.getCurrentDeviceConfig()
    devices = device_config.devices

    self.stage = stage = self.options['stage']
    self.start_stage_iter = self.options.get('start_stage_iter', 0)
    self.target_stage_iter = self.options.get('target_stage_iter', 0)

    resolution = self.resolution = 192
    ae_dims = 128
    e_dims = 128
    d_dims = 64
    self.pretrain = False
    self.pretrain_just_disabled = False

    masked_training = True

    # Keep model variables on the GPU only when there is a single device with
    # enough memory; otherwise fall back to CPU-hosted variables.
    models_opt_on_gpu = len(devices) == 1 and devices[0].total_mem_gb >= 4
    models_opt_device = '/GPU:0' if models_opt_on_gpu and self.is_training else '/CPU:0'
    optimizer_vars_on_cpu = models_opt_device == '/CPU:0'

    input_nc = 3
    output_nc = 3
    bgr_shape = (resolution, resolution, output_nc)
    mask_shape = (resolution, resolution, 1)
    lowest_dense_res = resolution // 16

    self.model_filename_list = []

    with tf.device('/CPU:0'):
        # Place holders on CPU
        self.warped_src = tf.placeholder(tf.float32, (None,)+bgr_shape)
        self.warped_dst = tf.placeholder(tf.float32, (None,)+bgr_shape)

        self.target_src = tf.placeholder(tf.float32, (None,)+bgr_shape)
        self.target_dst = tf.placeholder(tf.float32, (None,)+bgr_shape)

        self.target_srcm = tf.placeholder(tf.float32, (None,)+mask_shape)
        self.target_dstm = tf.placeholder(tf.float32, (None,)+mask_shape)

    # Initializing model classes
    with tf.device(models_opt_device):
        self.encoder = Encoder(in_ch=input_nc, e_ch=e_dims, name='encoder')
        encoder_out_ch = self.encoder.compute_output_shape(
            (tf.float32, (None, resolution, resolution, input_nc)))[-1]

        self.inter = Inter(in_ch=encoder_out_ch,
                           lowest_dense_res=lowest_dense_res,
                           ae_ch=ae_dims, ae_out_ch=ae_dims, d_ch=d_dims,
                           name='inter')
        inter_out_ch = self.inter.compute_output_shape(
            (tf.float32, (None, encoder_out_ch)))[-1]

        self.decoder_src = Decoder(in_ch=inter_out_ch, d_ch=d_dims, name='decoder_src')
        self.decoder_dst = Decoder(in_ch=inter_out_ch, d_ch=d_dims, name='decoder_dst')

        self.model_filename_list += [[self.encoder, 'encoder.npy'],
                                     [self.inter, 'inter.npy'],
                                     [self.decoder_src, 'decoder_src.npy'],
                                     [self.decoder_dst, 'decoder_dst.npy']]

        if self.is_training:
            # All weights are registered with the optimizer, but only the
            # branches enabled at the current stage are actually trained.
            self.src_dst_all_weights = self.encoder.get_weights() + self.inter.get_weights() \
                                       + self.decoder_src.get_weights() + self.decoder_dst.get_weights()
            self.src_dst_trainable_weights = self.encoder.get_weights() + self.inter.get_weights() \
                                             + self.decoder_src.get_weights_ex(stage) \
                                             + self.decoder_dst.get_weights_ex(stage)

            # Initialize optimizers
            self.src_dst_opt = nn.RMSprop(lr=2e-4, lr_dropout=0.3, name='src_dst_opt')
            self.src_dst_opt.initialize_variables(self.src_dst_all_weights,
                                                  vars_on_cpu=optimizer_vars_on_cpu)
            self.model_filename_list += [(self.src_dst_opt, 'src_dst_opt.npy')]

    if self.is_training:
        # Adjust batch size for multiple GPU
        gpu_count = max(1, len(devices))
        bs_per_gpu = max(1, self.get_batch_size() // gpu_count)
        self.set_batch_size(gpu_count*bs_per_gpu)

        # Compute losses per GPU
        gpu_pred_src_src_list = []
        gpu_pred_dst_dst_list = []
        gpu_pred_src_dst_list = []
        gpu_pred_src_srcm_list = []
        gpu_pred_dst_dstm_list = []
        gpu_pred_src_dstm_list = []

        gpu_src_losses = []
        gpu_dst_losses = []
        gpu_src_dst_loss_gvs = []

        for gpu_id in range(gpu_count):
            with tf.device(f'/GPU:{gpu_id}' if len(devices) != 0 else f'/CPU:0'):
                batch_slice = slice(gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu)
                with tf.device(f'/CPU:0'):
                    # slice on CPU, otherwise all batch data will be transfered to GPU first
                    gpu_warped_src = self.warped_src[batch_slice, :, :, :]
                    gpu_warped_dst = self.warped_dst[batch_slice, :, :, :]
                    gpu_target_src = self.target_src[batch_slice, :, :, :]
                    gpu_target_dst = self.target_dst[batch_slice, :, :, :]
                    gpu_target_srcm = self.target_srcm[batch_slice, :, :, :]
                    gpu_target_dstm = self.target_dstm[batch_slice, :, :, :]

                # process model tensors
                gpu_src_code = self.inter(self.encoder(gpu_warped_src))
                gpu_dst_code = self.inter(self.encoder(gpu_warped_dst))
                gpu_pred_src_src, gpu_pred_src_srcm = self.decoder_src(gpu_src_code, stage=stage)
                gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code, stage=stage)
                gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code, stage=stage)

                gpu_pred_src_src_list.append(gpu_pred_src_src)
                gpu_pred_dst_dst_list.append(gpu_pred_dst_dst)
                gpu_pred_src_dst_list.append(gpu_pred_src_dst)
                gpu_pred_src_srcm_list.append(gpu_pred_src_srcm)
                gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm)
                gpu_pred_src_dstm_list.append(gpu_pred_src_dstm)

                # Blur the masks so the loss transition at mask borders is soft.
                gpu_target_srcm_blur = nn.gaussian_blur(gpu_target_srcm, max(1, resolution // 32))
                gpu_target_dstm_blur = nn.gaussian_blur(gpu_target_dstm, max(1, resolution // 32))

                gpu_target_dst_masked = gpu_target_dst*gpu_target_dstm_blur
                # NOTE(review): the three anti-masked/psd tensors below are
                # currently unused in the loss; kept for parity with variants
                # that penalize background leakage — confirm before removing.
                gpu_target_dst_anti_masked = gpu_target_dst*(1.0 - gpu_target_dstm_blur)

                gpu_target_srcmasked_opt = gpu_target_src*gpu_target_srcm_blur if masked_training else gpu_target_src
                gpu_target_dst_masked_opt = gpu_target_dst_masked if masked_training else gpu_target_dst

                gpu_pred_src_src_masked_opt = gpu_pred_src_src*gpu_target_srcm_blur if masked_training else gpu_pred_src_src
                gpu_pred_dst_dst_masked_opt = gpu_pred_dst_dst*gpu_target_dstm_blur if masked_training else gpu_pred_dst_dst

                gpu_psd_target_dst_masked = gpu_pred_src_dst*gpu_target_dstm_blur
                gpu_psd_target_dst_anti_masked = gpu_pred_src_dst*(1.0 - gpu_target_dstm_blur)

                # Per-sample loss: 10x DSSIM + 10x L2 on the (masked) face,
                # plus plain L2 on the predicted mask.
                gpu_src_loss = tf.reduce_mean(10*nn.dssim(gpu_target_srcmasked_opt, gpu_pred_src_src_masked_opt,
                                                          max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
                gpu_src_loss += tf.reduce_mean(10*tf.square(gpu_target_srcmasked_opt - gpu_pred_src_src_masked_opt), axis=[1, 2, 3])
                gpu_src_loss += tf.reduce_mean(tf.square(gpu_target_srcm - gpu_pred_src_srcm), axis=[1, 2, 3])

                gpu_dst_loss = tf.reduce_mean(10*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt,
                                                          max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
                gpu_dst_loss += tf.reduce_mean(10*tf.square(gpu_target_dst_masked_opt - gpu_pred_dst_dst_masked_opt), axis=[1, 2, 3])
                gpu_dst_loss += tf.reduce_mean(tf.square(gpu_target_dstm - gpu_pred_dst_dstm), axis=[1, 2, 3])

                gpu_src_losses += [gpu_src_loss]
                gpu_dst_losses += [gpu_dst_loss]

                gpu_src_dst_loss = gpu_src_loss + gpu_dst_loss
                gpu_src_dst_loss_gvs += [nn.gradients(gpu_src_dst_loss, self.src_dst_trainable_weights)]

        # Average losses and gradients, and create optimizer update ops
        with tf.device(models_opt_device):
            if gpu_count == 1:
                pred_src_src = gpu_pred_src_src_list[0]
                pred_dst_dst = gpu_pred_dst_dst_list[0]
                pred_src_dst = gpu_pred_src_dst_list[0]
                pred_src_srcm = gpu_pred_src_srcm_list[0]
                pred_dst_dstm = gpu_pred_dst_dstm_list[0]
                pred_src_dstm = gpu_pred_src_dstm_list[0]

                src_loss = gpu_src_losses[0]
                dst_loss = gpu_dst_losses[0]
                src_dst_loss_gv = gpu_src_dst_loss_gvs[0]
            else:
                pred_src_src = tf.concat(gpu_pred_src_src_list, 0)
                pred_dst_dst = tf.concat(gpu_pred_dst_dst_list, 0)
                pred_src_dst = tf.concat(gpu_pred_src_dst_list, 0)
                pred_src_srcm = tf.concat(gpu_pred_src_srcm_list, 0)
                pred_dst_dstm = tf.concat(gpu_pred_dst_dstm_list, 0)
                pred_src_dstm = tf.concat(gpu_pred_src_dstm_list, 0)

                src_loss = nn.average_tensor_list(gpu_src_losses)
                dst_loss = nn.average_tensor_list(gpu_dst_losses)
                src_dst_loss_gv = nn.average_gv_list(gpu_src_dst_loss_gvs)

            src_dst_loss_gv_op = self.src_dst_opt.get_update_op(src_dst_loss_gv)

        # Initializing training and view functions
        def src_dst_train(warped_src, target_src, target_srcm,
                          warped_dst, target_dst, target_dstm):
            s, d, _ = nn.tf_sess.run([src_loss, dst_loss, src_dst_loss_gv_op],
                                     feed_dict={self.warped_src: warped_src,
                                                self.target_src: target_src,
                                                self.target_srcm: target_srcm,
                                                self.warped_dst: warped_dst,
                                                self.target_dst: target_dst,
                                                self.target_dstm: target_dstm})
            return np.mean(s), np.mean(d)
        self.src_dst_train = src_dst_train

        def AE_view(warped_src, warped_dst):
            return nn.tf_sess.run([pred_src_src, pred_dst_dst, pred_dst_dstm,
                                   pred_src_dst, pred_src_dstm],
                                  feed_dict={self.warped_src: warped_src,
                                             self.warped_dst: warped_dst})
        self.AE_view = AE_view
    else:
        # Initializing merge function
        with tf.device(f'/GPU:0' if len(devices) != 0 else f'/CPU:0'):
            gpu_dst_code = self.inter(self.encoder(self.warped_dst))
            gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code, stage=stage)
            _, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code, stage=stage)

        def AE_merge(warped_dst):
            return nn.tf_sess.run([gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm],
                                  feed_dict={self.warped_dst: warped_dst})
        self.AE_merge = AE_merge

    # Loading/initializing all models/optimizers weights
    for model, filename in io.progress_bar_generator(self.model_filename_list, "Initializing models"):
        do_init = self.is_first_run()
        if self.pretrain_just_disabled:
            # Force re-init of the bottleneck when pretraining was just turned off.
            if model == self.inter:
                do_init = True

        if not do_init:
            do_init = not model.load_weights(self.get_strpath_storage_for_file(filename))

        if do_init:
            model.init_weights()

    # initializing sample generators
    if self.is_training:
        t = SampleProcessor.Types
        face_type = t.FACE_TYPE_FULL

        training_data_src_path = self.training_data_src_path if not self.pretrain else self.get_pretraining_data_path()
        training_data_dst_path = self.training_data_dst_path if not self.pretrain else self.get_pretraining_data_path()

        cpu_count = multiprocessing.cpu_count()
        src_generators_count = cpu_count // 2
        dst_generators_count = cpu_count - src_generators_count

        self.set_training_data_generators([
            SampleGeneratorFace(training_data_src_path,
                                debug=self.is_debug(),
                                batch_size=self.get_batch_size(),
                                sample_process_options=SampleProcessor.Options(random_flip=self.pretrain),
                                output_sample_types=[
                                    {'types': (t.IMG_WARPED_TRANSFORMED, face_type, t.MODE_BGR), 'resolution': resolution},
                                    {'types': (t.IMG_TRANSFORMED, face_type, t.MODE_BGR), 'resolution': resolution},
                                    {'types': (t.IMG_TRANSFORMED, face_type, t.MODE_FACE_MASK_ALL_HULL), 'resolution': resolution},
                                ],
                                generators_count=src_generators_count),
            SampleGeneratorFace(training_data_dst_path,
                                debug=self.is_debug(),
                                batch_size=self.get_batch_size(),
                                sample_process_options=SampleProcessor.Options(random_flip=self.pretrain),
                                output_sample_types=[
                                    {'types': (t.IMG_WARPED_TRANSFORMED, face_type, t.MODE_BGR), 'resolution': resolution},
                                    {'types': (t.IMG_TRANSFORMED, face_type, t.MODE_BGR), 'resolution': resolution},
                                    {'types': (t.IMG_TRANSFORMED, face_type, t.MODE_FACE_MASK_ALL_HULL), 'resolution': resolution},
                                ],
                                generators_count=dst_generators_count),
        ])

        self.last_samples = None