class XSegModel(ModelBase): def __init__(self, *args, **kwargs): super().__init__(*args, force_model_class_name='XSeg', **kwargs) #override def on_initialize_options(self): ask_override = self.ask_override() if not self.is_first_run() and ask_override: if io.input_bool( f"Restart training?", False, help_message= "Reset model weights and start training from scratch."): self.set_iter(0) default_face_type = self.options[ 'face_type'] = self.load_or_def_option('face_type', 'wf') default_pretrain = self.options['pretrain'] = self.load_or_def_option( 'pretrain', False) if self.is_first_run(): self.options['face_type'] = io.input_str( "Face type", default_face_type, ['h', 'mf', 'f', 'wf', 'head'], help_message= "Half / mid face / full face / whole face / head. Choose the same as your deepfake model." ).lower() if self.is_first_run() or ask_override: self.ask_batch_size(4, range=[2, 16]) self.options['pretrain'] = io.input_bool("Enable pretraining mode", default_pretrain) if not self.is_exporting and ( self.options['pretrain'] and self.get_pretraining_data_path() is None): raise Exception("pretraining_data_path is not defined") self.pretrain_just_disabled = (default_pretrain == True and self.options['pretrain'] == False) #override def on_initialize(self): device_config = nn.getCurrentDeviceConfig() self.model_data_format = "NCHW" if self.is_exporting or (len( device_config.devices) != 0 and not self.is_debug()) else "NHWC" nn.initialize(data_format=self.model_data_format) tf = nn.tf device_config = nn.getCurrentDeviceConfig() devices = device_config.devices self.resolution = resolution = 256 self.face_type = { 'h': FaceType.HALF, 'mf': FaceType.MID_FULL, 'f': FaceType.FULL, 'wf': FaceType.WHOLE_FACE, 'head': FaceType.HEAD }[self.options['face_type']] place_model_on_cpu = len(devices) == 0 models_opt_device = '/CPU:0' if place_model_on_cpu else nn.tf_default_device_name bgr_shape = nn.get4Dshape(resolution, resolution, 3) mask_shape = nn.get4Dshape(resolution, resolution, 1) # Initializing model classes self.model = XSegNet(name='XSeg', resolution=resolution, load_weights=not self.is_first_run(), weights_file_root=self.get_model_root_path(), training=True, place_model_on_cpu=place_model_on_cpu, optimizer=nn.RMSprop(lr=0.0001, lr_dropout=0.3, name='opt'), data_format=nn.data_format) self.pretrain = self.options['pretrain'] if self.pretrain_just_disabled: self.set_iter(0) if self.is_training: # Adjust batch size for multiple GPU gpu_count = max(1, len(devices)) bs_per_gpu = max(1, self.get_batch_size() // gpu_count) self.set_batch_size(gpu_count * bs_per_gpu) # Compute losses per GPU gpu_pred_list = [] gpu_losses = [] gpu_loss_gvs = [] for gpu_id in range(gpu_count): with tf.device(f'/{devices[gpu_id].tf_dev_type}:{gpu_id}' if len(devices) != 0 else f'/CPU:0'): with tf.device(f'/CPU:0'): # slice on CPU, otherwise all batch data will be transfered to GPU first batch_slice = slice(gpu_id * bs_per_gpu, (gpu_id + 1) * bs_per_gpu) gpu_input_t = self.model.input_t[batch_slice, :, :, :] gpu_target_t = self.model.target_t[ batch_slice, :, :, :] # process model tensors gpu_pred_logits_t, gpu_pred_t = self.model.flow( gpu_input_t, pretrain=self.pretrain) gpu_pred_list.append(gpu_pred_t) if self.pretrain: # Structural loss gpu_loss = tf.reduce_mean( 5 * nn.dssim(gpu_target_t, gpu_pred_t, max_val=1.0, filter_size=int(resolution / 11.6)), axis=[1]) gpu_loss += tf.reduce_mean( 5 * nn.dssim(gpu_target_t, gpu_pred_t, max_val=1.0, filter_size=int(resolution / 23.2)), axis=[1]) # Pixel loss gpu_loss += tf.reduce_mean( 10 * tf.square(gpu_target_t - gpu_pred_t), axis=[1, 2, 3]) else: gpu_loss = tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits( labels=gpu_target_t, logits=gpu_pred_logits_t), axis=[1, 2, 3]) gpu_losses += [gpu_loss] gpu_loss_gvs += [ nn.gradients(gpu_loss, self.model.get_weights()) ] # Average losses and gradients, and create optimizer update ops #with tf.device(f'/CPU:0'): # Temporary fix. Unknown bug with training freeze starts from 2.4.0, but 2.3.1 was ok with tf.device(models_opt_device): pred = tf.concat(gpu_pred_list, 0) loss = tf.concat(gpu_losses, 0) loss_gv_op = self.model.opt.get_update_op( nn.average_gv_list(gpu_loss_gvs)) # Initializing training and view functions if self.pretrain: def train(input_np, target_np): l, _ = nn.tf_sess.run( [loss, loss_gv_op], feed_dict={ self.model.input_t: input_np, self.model.target_t: target_np }) return l else: def train(input_np, target_np): l, _ = nn.tf_sess.run( [loss, loss_gv_op], feed_dict={ self.model.input_t: input_np, self.model.target_t: target_np }) return l self.train = train def view(input_np): return nn.tf_sess.run([pred], feed_dict={self.model.input_t: input_np}) self.view = view # initializing sample generators cpu_count = min(multiprocessing.cpu_count(), 8) src_dst_generators_count = cpu_count // 2 src_generators_count = cpu_count // 2 dst_generators_count = cpu_count // 2 if self.pretrain: pretrain_gen = SampleGeneratorFace( self.get_pretraining_data_path(), debug=self.is_debug(), batch_size=self.get_batch_size(), sample_process_options=SampleProcessor.Options( random_flip=True), output_sample_types=[ { 'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp': True, 'transform': True, 'channel_type': SampleProcessor.ChannelType.BGR, 'face_type': self.face_type, 'data_format': nn.data_format, 'resolution': resolution }, { 'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp': True, 'transform': True, 'channel_type': SampleProcessor.ChannelType.G, 'face_type': self.face_type, 'data_format': nn.data_format, 'resolution': resolution }, ], uniform_yaw_distribution=False, generators_count=cpu_count) self.set_training_data_generators([pretrain_gen]) else: srcdst_generator = SampleGeneratorFaceXSeg( [self.training_data_src_path, self.training_data_dst_path], debug=self.is_debug(), batch_size=self.get_batch_size(), resolution=resolution, face_type=self.face_type, generators_count=src_dst_generators_count, data_format=nn.data_format) src_generator = SampleGeneratorFace( self.training_data_src_path, debug=self.is_debug(), batch_size=self.get_batch_size(), sample_process_options=SampleProcessor.Options( random_flip=False), output_sample_types=[ { 'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp': False, 'transform': False, 'channel_type': SampleProcessor.ChannelType.BGR, 'border_replicate': False, 'face_type': self.face_type, 'data_format': nn.data_format, 'resolution': resolution }, ], generators_count=src_generators_count, raise_on_no_data=False) dst_generator = SampleGeneratorFace( self.training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(), sample_process_options=SampleProcessor.Options( random_flip=False), output_sample_types=[ { 'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp': False, 'transform': False, 'channel_type': SampleProcessor.ChannelType.BGR, 'border_replicate': False, 'face_type': self.face_type, 'data_format': nn.data_format, 'resolution': resolution }, ], generators_count=dst_generators_count, raise_on_no_data=False) self.set_training_data_generators( [srcdst_generator, src_generator, dst_generator]) #override def get_model_filename_list(self): return self.model.model_filename_list #override def onSave(self): self.model.save_weights() #override def onTrainOneIter(self): image_np, target_np = self.generate_next_samples()[0] loss = self.train(image_np, target_np) return (('loss', np.mean(loss)), ) #override def onGetPreview(self, samples, for_history=False): n_samples = min(4, self.get_batch_size(), 800 // self.resolution) if self.pretrain: srcdst_samples, = samples image_np, mask_np = srcdst_samples else: srcdst_samples, src_samples, dst_samples = samples image_np, mask_np = srcdst_samples I, M, IM, = [ np.clip(nn.to_data_format(x, "NHWC", self.model_data_format), 0.0, 1.0) for x in ([image_np, mask_np] + self.view(image_np)) ] M, IM, = [np.repeat(x, (3, ), -1) for x in [M, IM]] green_bg = np.tile( np.array([0, 1, 0], dtype=np.float32)[None, None, ...], (self.resolution, self.resolution, 1)) result = [] st = [] for i in range(n_samples): if self.pretrain: ar = I[i], IM[i] else: ar = I[i] * M[i] + 0.5 * I[i] * (1 - M[i]) + 0.5 * green_bg * ( 1 - M[i]), IM[i], I[i] * IM[i] + 0.5 * I[i] * ( 1 - IM[i]) + 0.5 * green_bg * (1 - IM[i]) st.append(np.concatenate(ar, axis=1)) result += [ ('XSeg training faces', np.concatenate(st, axis=0)), ] if not self.pretrain and len(src_samples) != 0: src_np, = src_samples D, DM, = [ np.clip(nn.to_data_format(x, "NHWC", self.model_data_format), 0.0, 1.0) for x in ([src_np] + self.view(src_np)) ] DM, = [np.repeat(x, (3, ), -1) for x in [DM]] st = [] for i in range(n_samples): ar = D[i], DM[i], D[i] * DM[i] + 0.5 * D[i] * ( 1 - DM[i]) + 0.5 * green_bg * (1 - DM[i]) st.append(np.concatenate(ar, axis=1)) result += [ ('XSeg src faces', np.concatenate(st, axis=0)), ] if not self.pretrain and len(dst_samples) != 0: dst_np, = dst_samples D, DM, = [ np.clip(nn.to_data_format(x, "NHWC", self.model_data_format), 0.0, 1.0) for x in ([dst_np] + self.view(dst_np)) ] DM, = [np.repeat(x, (3, ), -1) for x in [DM]] st = [] for i in range(n_samples): ar = D[i], DM[i], D[i] * DM[i] + 0.5 * D[i] * ( 1 - DM[i]) + 0.5 * green_bg * (1 - DM[i]) st.append(np.concatenate(ar, axis=1)) result += [ ('XSeg dst faces', np.concatenate(st, axis=0)), ] return result def export_dfm(self): output_path = self.get_strpath_storage_for_file(f'model.onnx') io.log_info(f'Dumping .onnx to {output_path}') tf = nn.tf with tf.device(nn.tf_default_device_name): input_t = tf.placeholder( nn.floatx, (None, self.resolution, self.resolution, 3), name='in_face') input_t = tf.transpose(input_t, (0, 3, 1, 2)) _, pred_t = self.model.flow(input_t) pred_t = tf.transpose(pred_t, (0, 2, 3, 1)) tf.identity(pred_t, name='out_mask') output_graph_def = tf.graph_util.convert_variables_to_constants( nn.tf_sess, tf.get_default_graph().as_graph_def(), ['out_mask']) import tf2onnx with tf.device("/CPU:0"): model_proto, _ = tf2onnx.convert._convert_common( output_graph_def, name='XSeg', input_names=['in_face:0'], output_names=['out_mask:0'], opset=13, output_path=output_path)
class XSegModel(ModelBase): def __init__(self, *args, **kwargs): super().__init__(*args, force_model_class_name='XSeg', **kwargs) #override def on_initialize_options(self): self.set_batch_size(4) ask_override = self.ask_override() if not self.is_first_run() and ask_override: if io.input_bool( f"Restart training?", False, help_message= "Reset model weights and start training from scratch."): self.set_iter(0) default_face_type = self.options[ 'face_type'] = self.load_or_def_option('face_type', 'wf') if self.is_first_run(): self.options['face_type'] = io.input_str( "Face type", default_face_type, ['h', 'mf', 'f', 'wf', 'head'], help_message= "Half / mid face / full face / whole face / head. Choose the same as your deepfake model." ).lower() #override def on_initialize(self): device_config = nn.getCurrentDeviceConfig() self.model_data_format = "NCHW" if len( device_config.devices) != 0 and not self.is_debug() else "NHWC" nn.initialize(data_format=self.model_data_format) tf = nn.tf device_config = nn.getCurrentDeviceConfig() devices = device_config.devices self.resolution = resolution = 256 self.face_type = { 'h': FaceType.HALF, 'mf': FaceType.MID_FULL, 'f': FaceType.FULL, 'wf': FaceType.WHOLE_FACE, 'head': FaceType.HEAD }[self.options['face_type']] place_model_on_cpu = len(devices) == 0 models_opt_device = '/CPU:0' if place_model_on_cpu else '/GPU:0' bgr_shape = nn.get4Dshape(resolution, resolution, 3) mask_shape = nn.get4Dshape(resolution, resolution, 1) # Initializing model classes self.model = XSegNet(name='XSeg', resolution=resolution, load_weights=not self.is_first_run(), weights_file_root=self.get_model_root_path(), training=True, place_model_on_cpu=place_model_on_cpu, optimizer=nn.RMSprop(lr=0.0001, lr_dropout=0.3, name='opt'), data_format=nn.data_format) if self.is_training: # Adjust batch size for multiple GPU gpu_count = max(1, len(devices)) bs_per_gpu = max(1, self.get_batch_size() // gpu_count) self.set_batch_size(gpu_count * bs_per_gpu) # Compute losses per GPU gpu_pred_list = [] gpu_losses = [] gpu_loss_gvs = [] for gpu_id in range(gpu_count): with tf.device( f'/GPU:{gpu_id}' if len(devices) != 0 else f'/CPU:0'): with tf.device(f'/CPU:0'): # slice on CPU, otherwise all batch data will be transfered to GPU first batch_slice = slice(gpu_id * bs_per_gpu, (gpu_id + 1) * bs_per_gpu) gpu_input_t = self.model.input_t[batch_slice, :, :, :] gpu_target_t = self.model.target_t[ batch_slice, :, :, :] # process model tensors gpu_pred_logits_t, gpu_pred_t = self.model.flow( gpu_input_t) gpu_pred_list.append(gpu_pred_t) gpu_loss = tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits( labels=gpu_target_t, logits=gpu_pred_logits_t), axis=[1, 2, 3]) gpu_losses += [gpu_loss] gpu_loss_gvs += [ nn.gradients(gpu_loss, self.model.get_weights()) ] # Average losses and gradients, and create optimizer update ops with tf.device(models_opt_device): pred = nn.concat(gpu_pred_list, 0) loss = tf.reduce_mean(gpu_losses) loss_gv_op = self.model.opt.get_update_op( nn.average_gv_list(gpu_loss_gvs)) # Initializing training and view functions def train(input_np, target_np): l, _ = nn.tf_sess.run([loss, loss_gv_op], feed_dict={ self.model.input_t: input_np, self.model.target_t: target_np }) return l self.train = train def view(input_np): return nn.tf_sess.run([pred], feed_dict={self.model.input_t: input_np}) self.view = view # initializing sample generators cpu_count = min(multiprocessing.cpu_count(), 8) src_dst_generators_count = cpu_count // 2 src_generators_count = cpu_count // 2 dst_generators_count = cpu_count // 2 srcdst_generator = SampleGeneratorFaceXSeg( [self.training_data_src_path, self.training_data_dst_path], debug=self.is_debug(), batch_size=self.get_batch_size(), resolution=resolution, face_type=self.face_type, generators_count=src_dst_generators_count, data_format=nn.data_format) src_generator = SampleGeneratorFace( self.training_data_src_path, debug=self.is_debug(), batch_size=self.get_batch_size(), sample_process_options=SampleProcessor.Options( random_flip=False), output_sample_types=[ { 'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp': False, 'transform': False, 'channel_type': SampleProcessor.ChannelType.BGR, 'border_replicate': False, 'face_type': self.face_type, 'data_format': nn.data_format, 'resolution': resolution }, ], generators_count=src_generators_count, raise_on_no_data=False) dst_generator = SampleGeneratorFace( self.training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(), sample_process_options=SampleProcessor.Options( random_flip=False), output_sample_types=[ { 'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp': False, 'transform': False, 'channel_type': SampleProcessor.ChannelType.BGR, 'border_replicate': False, 'face_type': self.face_type, 'data_format': nn.data_format, 'resolution': resolution }, ], generators_count=dst_generators_count, raise_on_no_data=False) self.set_training_data_generators( [srcdst_generator, src_generator, dst_generator]) #override def get_model_filename_list(self): return self.model.model_filename_list #override def onSave(self): self.model.save_weights() #override def onTrainOneIter(self): image_np, mask_np = self.generate_next_samples()[0] loss = self.train(image_np, mask_np) return (('loss', loss), ) #override def onGetPreview(self, samples): n_samples = min(4, self.get_batch_size(), 800 // self.resolution) srcdst_samples, src_samples, dst_samples = samples image_np, mask_np = srcdst_samples I, M, IM, = [ np.clip(nn.to_data_format(x, "NHWC", self.model_data_format), 0.0, 1.0) for x in ([image_np, mask_np] + self.view(image_np)) ] M, IM, = [np.repeat(x, (3, ), -1) for x in [M, IM]] green_bg = np.tile( np.array([0, 1, 0], dtype=np.float32)[None, None, ...], (self.resolution, self.resolution, 1)) result = [] st = [] for i in range(n_samples): ar = I[i] * M[i] + 0.5 * I[i] * (1 - M[i]) + 0.5 * green_bg * ( 1 - M[i]), IM[i], I[i] * IM[i] + 0.5 * I[i] * ( 1 - IM[i]) + 0.5 * green_bg * (1 - IM[i]) st.append(np.concatenate(ar, axis=1)) result += [ ('XSeg training faces', np.concatenate(st, axis=0)), ] if len(src_samples) != 0: src_np, = src_samples D, DM, = [ np.clip(nn.to_data_format(x, "NHWC", self.model_data_format), 0.0, 1.0) for x in ([src_np] + self.view(src_np)) ] DM, = [np.repeat(x, (3, ), -1) for x in [DM]] st = [] for i in range(n_samples): ar = D[i], DM[i], D[i] * DM[i] + 0.5 * D[i] * ( 1 - DM[i]) + 0.5 * green_bg * (1 - DM[i]) st.append(np.concatenate(ar, axis=1)) result += [ ('XSeg src faces', np.concatenate(st, axis=0)), ] if len(dst_samples) != 0: dst_np, = dst_samples D, DM, = [ np.clip(nn.to_data_format(x, "NHWC", self.model_data_format), 0.0, 1.0) for x in ([dst_np] + self.view(dst_np)) ] DM, = [np.repeat(x, (3, ), -1) for x in [DM]] st = [] for i in range(n_samples): ar = D[i], DM[i], D[i] * DM[i] + 0.5 * D[i] * ( 1 - DM[i]) + 0.5 * green_bg * (1 - DM[i]) st.append(np.concatenate(ar, axis=1)) result += [ ('XSeg dst faces', np.concatenate(st, axis=0)), ] return result