Example #1
File: Model.py  Project: rumax/DeepFaceLab
    def on_initialize(self):
        device_config = nn.getCurrentDeviceConfig()
        nn.initialize(data_format="NHWC")
        tf = nn.tf

        devices = device_config.devices

        self.resolution = resolution = 256  #self.options['resolution']
        #self.face_type = {'h'  : FaceType.HALF,
        #                  'mf' : FaceType.MID_FULL,
        #                  'f'  : FaceType.FULL,
        #                  'wf' : FaceType.WHOLE_FACE}[ self.options['face_type'] ]
        self.face_type = FaceType.FULL

        place_model_on_cpu = len(devices) == 0
        models_opt_device = '/CPU:0' if place_model_on_cpu else '/GPU:0'

        bgr_shape = nn.get4Dshape(resolution, resolution, 3)
        mask_shape = nn.get4Dshape(resolution, resolution, 1)

        # Initializing model classes
        self.model = TernausNet(f'{self.model_name}_FANSeg',
                                resolution,
                                FaceType.toString(self.face_type),
                                load_weights=not self.is_first_run(),
                                weights_file_root=self.get_model_root_path(),
                                training=True,
                                place_model_on_cpu=place_model_on_cpu)

        if self.is_training:
            # Adjust batch size for multiple GPU
            gpu_count = max(1, len(devices))
            bs_per_gpu = max(1, self.get_batch_size() // gpu_count)
            self.set_batch_size(gpu_count * bs_per_gpu)

            # Compute losses per GPU
            gpu_pred_list = []

            gpu_losses = []
            gpu_loss_gvs = []

            for gpu_id in range(gpu_count):
                with tf.device(
                        f'/GPU:{gpu_id}' if len(devices) != 0 else '/CPU:0'):

                    with tf.device('/CPU:0'):
                        # slice on CPU, otherwise all batch data will be transferred to the GPU first
                        batch_slice = slice(gpu_id * bs_per_gpu,
                                            (gpu_id + 1) * bs_per_gpu)
                        gpu_input_t = self.model.input_t[batch_slice, :, :, :]
                        gpu_target_t = self.model.target_t[
                            batch_slice, :, :, :]

                    # process model tensors
                    gpu_pred_logits_t, gpu_pred_t = self.model.net(
                        [gpu_input_t])
                    gpu_pred_list.append(gpu_pred_t)

                    gpu_loss = tf.reduce_mean(
                        tf.nn.sigmoid_cross_entropy_with_logits(
                            labels=gpu_target_t, logits=gpu_pred_logits_t),
                        axis=[1, 2, 3])
                    gpu_losses += [gpu_loss]

                    gpu_loss_gvs += [
                        nn.tf_gradients(gpu_loss, self.model.net_weights)
                    ]

            # Average losses and gradients, and create optimizer update ops
            with tf.device(models_opt_device):
                pred = nn.tf_concat(gpu_pred_list, 0)
                loss = tf.reduce_mean(gpu_losses)

                loss_gv_op = self.model.opt.get_update_op(
                    nn.tf_average_gv_list(gpu_loss_gvs))

            # Initializing training and view functions
            def train(input_np, target_np):
                l, _ = nn.tf_sess.run([loss, loss_gv_op],
                                      feed_dict={
                                          self.model.input_t: input_np,
                                          self.model.target_t: target_np
                                      })
                return l

            self.train = train

            def view(input_np):
                return nn.tf_sess.run([pred],
                                      feed_dict={self.model.input_t: input_np})

            self.view = view

            # initializing sample generators
            training_data_src_path = self.training_data_src_path
            training_data_dst_path = self.training_data_dst_path

            cpu_count = min(multiprocessing.cpu_count(), 8)
            src_generators_count = cpu_count // 2
            dst_generators_count = cpu_count // 2
            src_generators_count = int(src_generators_count * 1.5)

            src_generator = SampleGeneratorFace(
                training_data_src_path,
                random_ct_samples_path=training_data_src_path,
                debug=self.is_debug(),
                batch_size=self.get_batch_size(),
                sample_process_options=SampleProcessor.Options(
                    random_flip=True),
                output_sample_types=[
                    {
                        'sample_type': SampleProcessor.SampleType.FACE_IMAGE,
                        'ct_mode': 'lct',
                        'warp': True,
                        'transform': True,
                        'channel_type': SampleProcessor.ChannelType.BGR,
                        'face_type': self.face_type,
                        'motion_blur': (25, 5),
                        'gaussian_blur': (25, 5),
                        'data_format': nn.data_format,
                        'resolution': resolution
                    },
                    {
                        'sample_type': SampleProcessor.SampleType.FACE_MASK,
                        'warp': True,
                        'transform': True,
                        'channel_type': SampleProcessor.ChannelType.G,
                        'face_mask_type':
                        SampleProcessor.FaceMaskType.FULL_FACE,
                        'face_type': self.face_type,
                        'data_format': nn.data_format,
                        'resolution': resolution
                    },
                ],
                generators_count=src_generators_count)

            dst_generator = SampleGeneratorFace(
                training_data_dst_path,
                debug=self.is_debug(),
                batch_size=self.get_batch_size(),
                sample_process_options=SampleProcessor.Options(
                    random_flip=True),
                output_sample_types=[
                    {
                        'sample_type': SampleProcessor.SampleType.FACE_IMAGE,
                        'warp': False,
                        'transform': True,
                        'channel_type': SampleProcessor.ChannelType.BGR,
                        'face_type': self.face_type,
                        'motion_blur': (25, 5),
                        'gaussian_blur': (25, 5),
                        'data_format': nn.data_format,
                        'resolution': resolution
                    },
                ],
                generators_count=dst_generators_count,
                raise_on_no_data=False)
            if not dst_generator.is_initialized():
                io.log_info(
                    f"\nTo view the model on unseen faces, place any aligned faces in {training_data_dst_path}.\n"
                )

            self.set_training_data_generators([src_generator, dst_generator])
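
Example #1 (and the two that follow) uses the same multi-GPU pattern: the input batch is sliced per GPU on the CPU, each tower computes its own loss and gradients, and the gradients are averaged before a single optimizer update. Below is a minimal NumPy sketch of that averaging step; the [(grad, var), ...] list-of-pairs layout is an assumption about what nn.tf_average_gv_list consumes, inferred from how it is called above.

import numpy as np

def average_gv_list(gv_per_tower):
    # gv_per_tower: one [(grad, var), ...] list per GPU tower, with the same
    # variables in the same order; returns one list with the grads averaged
    # (assumed to mirror nn.tf_average_gv_list)
    averaged = []
    for pairs in zip(*gv_per_tower):
        grads = [g for g, _ in pairs]
        averaged.append((np.mean(grads, axis=0), pairs[0][1]))
    return averaged

# two towers, two "variables"
tower0 = [(np.array([1.0, 2.0]), 'w'), (np.array([0.5]), 'b')]
tower1 = [(np.array([3.0, 4.0]), 'w'), (np.array([1.5]), 'b')]
print(average_gv_list([tower0, tower1]))
# [(array([2., 3.]), 'w'), (array([1.]), 'b')]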
Example #2
    def on_initialize(self):
        device_config = nn.getCurrentDeviceConfig()
        self.model_data_format = "NCHW" if len(device_config.devices) != 0 and not self.is_debug()  else "NHWC"
        nn.initialize(data_format=self.model_data_format)
        tf = nn.tf

        conv_kernel_initializer = nn.initializers.ca()

        class Downscale(nn.ModelBase):
            def __init__(self, in_ch, out_ch, kernel_size=5, dilations=1, subpixel=True, use_activator=True, **kwargs ):
                self.in_ch = in_ch
                self.out_ch = out_ch
                self.kernel_size = kernel_size
                self.dilations = dilations
                self.subpixel = subpixel
                self.use_activator = use_activator
                super().__init__(**kwargs)

            def on_build(self, *args, **kwargs ):
                self.conv1 = nn.Conv2D( self.in_ch,
                                          self.out_ch // (4 if self.subpixel else 1),
                                          kernel_size=self.kernel_size,
                                          strides=1 if self.subpixel else 2,
                                          padding='SAME', dilations=self.dilations, kernel_initializer=conv_kernel_initializer )

            def forward(self, x):
                x = self.conv1(x)

                if self.subpixel:
                    x = nn.tf_space_to_depth(x, 2)

                if self.use_activator:
                    x = nn.tf_gelu(x)
                return x

            def get_out_ch(self):
                return (self.out_ch // 4) * 4

        class DownscaleBlock(nn.ModelBase):
            def on_build(self, in_ch, ch, n_downscales, kernel_size, dilations=1, subpixel=True):
                self.downs = []

                last_ch = in_ch
                for i in range(n_downscales):
                    cur_ch = ch*( min(2**i, 8)  )
                    self.downs.append ( Downscale(last_ch, cur_ch, kernel_size=kernel_size, dilations=dilations, subpixel=subpixel) )
                    last_ch = self.downs[-1].get_out_ch()

            def forward(self, inp):
                x = inp
                for down in self.downs:
                    x = down(x)
                return x

        class Upscale(nn.ModelBase):
            def on_build(self, in_ch, out_ch, kernel_size=3 ):
                self.conv1 = nn.Conv2D( in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME', kernel_initializer=conv_kernel_initializer)

            def forward(self, x):
                x = self.conv1(x)
                x = nn.tf_gelu(x)
                x = nn.tf_depth_to_space(x, 2)
                return x

        class ResidualBlock(nn.ModelBase):
            def on_build(self, ch, kernel_size=3 ):
                self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', kernel_initializer=conv_kernel_initializer)
                self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', kernel_initializer=conv_kernel_initializer)

            def forward(self, inp):
                x = self.conv1(inp)
                x = nn.tf_gelu(x)
                x = self.conv2(x)
                x = inp + x
                x = nn.tf_gelu(x)
                return x

        class Encoder(nn.ModelBase):
            def on_build(self, in_ch, e_ch):
                self.down1 = DownscaleBlock(in_ch, e_ch, n_downscales=4, kernel_size=5)
            def forward(self, inp):
                return nn.tf_flatten(self.down1(inp))

        class Inter(nn.ModelBase):
            def __init__(self, in_ch, lowest_dense_res, ae_ch, ae_out_ch, d_ch, **kwargs):
                self.in_ch, self.lowest_dense_res, self.ae_ch, self.ae_out_ch, self.d_ch = in_ch, lowest_dense_res, ae_ch, ae_out_ch, d_ch
                super().__init__(**kwargs)

            def on_build(self):
                in_ch, lowest_dense_res, ae_ch, ae_out_ch, d_ch = self.in_ch, self.lowest_dense_res, self.ae_ch, self.ae_out_ch, self.d_ch

                self.dense1 = nn.Dense( in_ch, ae_ch, kernel_initializer=tf.initializers.orthogonal )
                self.dense2 = nn.Dense( ae_ch, lowest_dense_res * lowest_dense_res * ae_out_ch, maxout_features=4, kernel_initializer=tf.initializers.orthogonal )
                self.upscale1 = Upscale(ae_out_ch, d_ch*8)
                self.res1 = ResidualBlock(d_ch*8)

            def forward(self, inp):
                x = self.dense1(inp)
                x = self.dense2(x)
                x = nn.tf_reshape_4D (x, self.lowest_dense_res, self.lowest_dense_res, self.ae_out_ch)
                x = self.upscale1(x)
                x = self.res1(x)
                return x

            def get_out_ch(self):
                return self.ae_out_ch

        class Decoder(nn.ModelBase):
            def on_build(self, in_ch, d_ch):
                self.upscale1 = Upscale(in_ch, d_ch*4)
                self.res1     = ResidualBlock(d_ch*4)
                self.upscale2 = Upscale(d_ch*4, d_ch*2)
                self.res2     = ResidualBlock(d_ch*2)
                self.upscale3 = Upscale(d_ch*2, d_ch*1)
                self.res3     = ResidualBlock(d_ch*1)

                self.upscalem1 = Upscale(in_ch, d_ch)
                self.upscalem2 = Upscale(d_ch, d_ch//2)
                self.upscalem3 = Upscale(d_ch//2, d_ch//2)

                self.out_conv = nn.Conv2D( d_ch*1, 3, kernel_size=1, padding='SAME', kernel_initializer=conv_kernel_initializer)
                self.out_convm = nn.Conv2D( d_ch//2, 1, kernel_size=1, padding='SAME', kernel_initializer=conv_kernel_initializer)

            def forward(self, inp):
                z = inp
                x = self.upscale1 (z)
                x = self.res1     (x)
                x = self.upscale2 (x)
                x = self.res2     (x)
                x = self.upscale3 (x)
                x = self.res3     (x)

                y = self.upscalem1 (z)
                y = self.upscalem2 (y)
                y = self.upscalem3 (y)

                return tf.nn.sigmoid(self.out_conv(x)), \
                       tf.nn.sigmoid(self.out_convm(y))

        device_config = nn.getCurrentDeviceConfig()
        devices = device_config.devices

        resolution = self.resolution = 96
        ae_dims = 128
        e_dims = 128
        d_dims = 64
        self.pretrain = False
        self.pretrain_just_disabled = False

        masked_training = True

        models_opt_on_gpu = len(devices) >= 1 and all([dev.total_mem_gb >= 2 for dev in devices])
        models_opt_device = '/GPU:0' if models_opt_on_gpu and self.is_training else '/CPU:0'
        optimizer_vars_on_cpu = models_opt_device=='/CPU:0'

        input_ch = 3
        output_ch = 3
        bgr_shape = nn.get4Dshape(resolution,resolution,input_ch)
        mask_shape = nn.get4Dshape(resolution,resolution,1)
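        # the encoder applies four factor-2 downscales, so the smallest feature map is resolution // 16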
        lowest_dense_res = resolution // 16

        self.model_filename_list = []


        with tf.device ('/CPU:0'):
            #Place holders on CPU
            self.warped_src = tf.placeholder (nn.tf_floatx, bgr_shape)
            self.warped_dst = tf.placeholder (nn.tf_floatx, bgr_shape)

            self.target_src = tf.placeholder (nn.tf_floatx, bgr_shape)
            self.target_dst = tf.placeholder (nn.tf_floatx, bgr_shape)

            self.target_srcm = tf.placeholder (nn.tf_floatx, mask_shape)
            self.target_dstm = tf.placeholder (nn.tf_floatx, mask_shape)

        # Initializing model classes
        with tf.device (models_opt_device):
            self.encoder = Encoder(in_ch=input_ch, e_ch=e_dims, name='encoder')
            encoder_out_ch = self.encoder.compute_output_channels ( (nn.tf_floatx, bgr_shape))

            self.inter = Inter (in_ch=encoder_out_ch, lowest_dense_res=lowest_dense_res, ae_ch=ae_dims, ae_out_ch=ae_dims, d_ch=d_dims, name='inter')
            inter_out_ch = self.inter.compute_output_channels ( (nn.tf_floatx, (None,encoder_out_ch)))

            self.decoder_src = Decoder(in_ch=inter_out_ch, d_ch=d_dims, name='decoder_src')
            self.decoder_dst = Decoder(in_ch=inter_out_ch, d_ch=d_dims, name='decoder_dst')

            self.model_filename_list += [ [self.encoder,     'encoder.npy'    ],
                                          [self.inter,       'inter.npy'      ],
                                          [self.decoder_src, 'decoder_src.npy'],
                                          [self.decoder_dst, 'decoder_dst.npy']  ]

            if self.is_training:
                self.src_dst_trainable_weights = self.encoder.get_weights() + self.inter.get_weights() + self.decoder_src.get_weights() + self.decoder_dst.get_weights()

                # Initialize optimizers
                self.src_dst_opt = nn.TFRMSpropOptimizer(lr=2e-4, lr_dropout=0.3, name='src_dst_opt')
                self.src_dst_opt.initialize_variables(self.src_dst_trainable_weights, vars_on_cpu=optimizer_vars_on_cpu )
                self.model_filename_list += [ (self.src_dst_opt, 'src_dst_opt.npy') ]

        if self.is_training:
            # Adjust batch size for multiple GPU
            gpu_count = max(1, len(devices) )
            bs_per_gpu = max(1, 4 // gpu_count)
            self.set_batch_size( gpu_count*bs_per_gpu)

            # Compute losses per GPU
            gpu_pred_src_src_list = []
            gpu_pred_dst_dst_list = []
            gpu_pred_src_dst_list = []
            gpu_pred_src_srcm_list = []
            gpu_pred_dst_dstm_list = []
            gpu_pred_src_dstm_list = []

            gpu_src_losses = []
            gpu_dst_losses = []
            gpu_src_dst_loss_gvs = []
            
            for gpu_id in range(gpu_count):
                with tf.device( f'/GPU:{gpu_id}' if len(devices) != 0 else '/CPU:0' ):
                    batch_slice = slice( gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu )
                    with tf.device('/CPU:0'):
                        # slice on CPU, otherwise all batch data will be transferred to the GPU first
                        gpu_warped_src   = self.warped_src [batch_slice,:,:,:]
                        gpu_warped_dst   = self.warped_dst [batch_slice,:,:,:]
                        gpu_target_src   = self.target_src [batch_slice,:,:,:]
                        gpu_target_dst   = self.target_dst [batch_slice,:,:,:]
                        gpu_target_srcm  = self.target_srcm[batch_slice,:,:,:]
                        gpu_target_dstm  = self.target_dstm[batch_slice,:,:,:]

                    # process model tensors
                    gpu_src_code     = self.inter(self.encoder(gpu_warped_src))
                    gpu_dst_code     = self.inter(self.encoder(gpu_warped_dst))
                    gpu_pred_src_src, gpu_pred_src_srcm = self.decoder_src(gpu_src_code)
                    gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)
                    gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code)

                    gpu_pred_src_src_list.append(gpu_pred_src_src)
                    gpu_pred_dst_dst_list.append(gpu_pred_dst_dst)
                    gpu_pred_src_dst_list.append(gpu_pred_src_dst)

                    gpu_pred_src_srcm_list.append(gpu_pred_src_srcm)
                    gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm)
                    gpu_pred_src_dstm_list.append(gpu_pred_src_dstm)

                    gpu_target_srcm_blur = nn.tf_gaussian_blur(gpu_target_srcm,  max(1, resolution // 32) )
                    gpu_target_dstm_blur = nn.tf_gaussian_blur(gpu_target_dstm,  max(1, resolution // 32) )
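                    # the blurred masks make the masked losses below fade out smoothly
                    # at the face boundary instead of cutting off at a hard edge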

                    gpu_target_dst_masked      = gpu_target_dst*gpu_target_dstm_blur
                    gpu_target_dst_anti_masked = gpu_target_dst*(1.0 - gpu_target_dstm_blur)

                    gpu_target_src_masked_opt  = gpu_target_src*gpu_target_srcm_blur if masked_training else gpu_target_src
                    gpu_target_dst_masked_opt = gpu_target_dst_masked if masked_training else gpu_target_dst

                    gpu_pred_src_src_masked_opt = gpu_pred_src_src*gpu_target_srcm_blur if masked_training else gpu_pred_src_src
                    gpu_pred_dst_dst_masked_opt = gpu_pred_dst_dst*gpu_target_dstm_blur if masked_training else gpu_pred_dst_dst

                    gpu_psd_target_dst_masked = gpu_pred_src_dst*gpu_target_dstm_blur
                    gpu_psd_target_dst_anti_masked = gpu_pred_src_dst*(1.0 - gpu_target_dstm_blur)

                    gpu_src_loss =  tf.reduce_mean ( 10*nn.tf_dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
                    gpu_src_loss += tf.reduce_mean ( 10*tf.square ( gpu_target_src_masked_opt - gpu_pred_src_src_masked_opt ), axis=[1,2,3])
                    gpu_src_loss += tf.reduce_mean ( 10*tf.square( gpu_target_srcm - gpu_pred_src_srcm ),axis=[1,2,3] )

                    gpu_dst_loss  = tf.reduce_mean ( 10*nn.tf_dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1])
                    gpu_dst_loss += tf.reduce_mean ( 10*tf.square(  gpu_target_dst_masked_opt- gpu_pred_dst_dst_masked_opt ), axis=[1,2,3])
                    gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dstm - gpu_pred_dst_dstm ),axis=[1,2,3] )

                    gpu_src_losses += [gpu_src_loss]
                    gpu_dst_losses += [gpu_dst_loss]

                    gpu_G_loss = gpu_src_loss + gpu_dst_loss
                    gpu_src_dst_loss_gvs += [ nn.tf_gradients ( gpu_G_loss, self.src_dst_trainable_weights ) ]


            # Average losses and gradients, and create optimizer update ops
            with tf.device (models_opt_device):
                pred_src_src  = nn.tf_concat(gpu_pred_src_src_list, 0)
                pred_dst_dst  = nn.tf_concat(gpu_pred_dst_dst_list, 0)
                pred_src_dst  = nn.tf_concat(gpu_pred_src_dst_list, 0)
                pred_src_srcm = nn.tf_concat(gpu_pred_src_srcm_list, 0)
                pred_dst_dstm = nn.tf_concat(gpu_pred_dst_dstm_list, 0)
                pred_src_dstm = nn.tf_concat(gpu_pred_src_dstm_list, 0)

                src_loss = nn.tf_average_tensor_list(gpu_src_losses)
                dst_loss = nn.tf_average_tensor_list(gpu_dst_losses)
                src_dst_loss_gv = nn.tf_average_gv_list (gpu_src_dst_loss_gvs)
                src_dst_loss_gv_op = self.src_dst_opt.get_update_op (src_dst_loss_gv)

            # Initializing training and view functions
            def src_dst_train(warped_src, target_src, target_srcm, \
                              warped_dst, target_dst, target_dstm):
                s, d, _ = nn.tf_sess.run ( [ src_loss, dst_loss, src_dst_loss_gv_op],
                                            feed_dict={self.warped_src :warped_src,
                                                       self.target_src :target_src,
                                                       self.target_srcm:target_srcm,
                                                       self.warped_dst :warped_dst,
                                                       self.target_dst :target_dst,
                                                       self.target_dstm:target_dstm,
                                                       })
                s = np.mean(s)
                d = np.mean(d)
                return s, d
            self.src_dst_train = src_dst_train

            def AE_view(warped_src, warped_dst):
                return nn.tf_sess.run ( [pred_src_src, pred_dst_dst, pred_dst_dstm, pred_src_dst, pred_src_dstm],
                                            feed_dict={self.warped_src:warped_src,
                                                    self.warped_dst:warped_dst})

            self.AE_view = AE_view
        else:
            # Initializing merge function
            with tf.device( '/GPU:0' if len(devices) != 0 else '/CPU:0'):
                gpu_dst_code     = self.inter(self.encoder(self.warped_dst))
                gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code)
                _, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)

            def AE_merge( warped_dst):
                return nn.tf_sess.run ( [gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm], feed_dict={self.warped_dst:warped_dst})

            self.AE_merge = AE_merge

        # Loading/initializing all models/optimizers weights
        for model, filename in io.progress_bar_generator(self.model_filename_list, "Initializing models"):
            if self.pretrain_just_disabled:
                do_init = False
                if model == self.inter:
                    do_init = True
            else:
                do_init = self.is_first_run()

            if not do_init:
                do_init = not model.load_weights( self.get_strpath_storage_for_file(filename) )

            if do_init and self.pretrained_model_path is not None:
                pretrained_filepath = self.pretrained_model_path / filename
                if pretrained_filepath.exists():
                    do_init = not model.load_weights(pretrained_filepath)

            if do_init:
                model.init_weights()

        # initializing sample generators
        if self.is_training:
            t = SampleProcessor.Types
            face_type = t.FACE_TYPE_FULL

            training_data_src_path = self.training_data_src_path if not self.pretrain else self.get_pretraining_data_path()
            training_data_dst_path = self.training_data_dst_path if not self.pretrain else self.get_pretraining_data_path()

            cpu_count = min(multiprocessing.cpu_count(), 8)
            src_generators_count = cpu_count // 2
            dst_generators_count = cpu_count // 2

            self.set_training_data_generators ([
                    SampleGeneratorFace(training_data_src_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
                        sample_process_options=SampleProcessor.Options(random_flip=self.pretrain),
                        output_sample_types = [ {'types' : (t.IMG_WARPED_TRANSFORMED, face_type, t.MODE_BGR), 'data_format':nn.data_format, 'resolution':resolution, },
                                                {'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_BGR),        'data_format':nn.data_format, 'resolution': resolution, },
                                                {'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_M),          'data_format':nn.data_format, 'resolution': resolution } ],
                        generators_count=src_generators_count ),

                    SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
                        sample_process_options=SampleProcessor.Options(random_flip=self.pretrain),
                        output_sample_types = [ {'types' : (t.IMG_WARPED_TRANSFORMED, face_type, t.MODE_BGR), 'data_format':nn.data_format, 'resolution':resolution},
                                                {'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_BGR),        'data_format':nn.data_format, 'resolution': resolution},
                                                {'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_M),          'data_format':nn.data_format, 'resolution': resolution} ],
                        generators_count=dst_generators_count )
                             ])

            self.last_samples = None
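
The Upscale blocks above grow channels 4x with a convolution and then rearrange them into a 2x larger spatial grid (a pixel shuffle), and the subpixel Downscale does the inverse. Here is a minimal NumPy sketch of the two rearrangements for NHWC tensors, written to match the semantics of tf.nn.depth_to_space / tf.nn.space_to_depth, which nn.tf_depth_to_space and nn.tf_space_to_depth presumably wrap:

import numpy as np

def depth_to_space_nhwc(x, block):
    # (N, H, W, C*block^2) -> (N, H*block, W*block, C), as in a pixel-shuffle upscale
    n, h, w, c = x.shape
    oc = c // (block * block)
    x = x.reshape(n, h, w, block, block, oc)
    x = x.transpose(0, 1, 3, 2, 4, 5)  # interleave the block rows/cols into H and W
    return x.reshape(n, h * block, w * block, oc)

def space_to_depth_nhwc(x, block):
    # inverse rearrangement: (N, H*block, W*block, C) -> (N, H, W, C*block^2)
    n, hb, wb, c = x.shape
    h, w = hb // block, wb // block
    x = x.reshape(n, h, block, w, block, c)
    x = x.transpose(0, 1, 3, 2, 4, 5)
    return x.reshape(n, h, w, c * block * block)

# round trip: the two rearrangements are exact inverses
x = np.random.randn(2, 8, 8, 12)
assert np.array_equal(space_to_depth_nhwc(depth_to_space_nhwc(x, 2), 2), x)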
Example #3
    def on_initialize(self):
        device_config = nn.getCurrentDeviceConfig()
        self.model_data_format = "NCHW" if len(device_config.devices) != 0 and not self.is_debug() else "NHWC"
        nn.initialize(floatx="float16" if self.options['use_float16'] else "float32",
                      data_format=self.model_data_format)
        tf = nn.tf

        conv_kernel_initializer = nn.initializers.ca()

        class Downscale(nn.ModelBase):
            def __init__(self, in_ch, out_ch, kernel_size=5, dilations=1, subpixel=True, use_activator=True, **kwargs ):
                self.in_ch = in_ch
                self.out_ch = out_ch
                self.kernel_size = kernel_size
                self.dilations = dilations
                self.subpixel = subpixel
                self.use_activator = use_activator
                super().__init__(**kwargs)

            def on_build(self, *args, **kwargs ):
                self.conv1 = nn.Conv2D( self.in_ch,
                                        self.out_ch // (4 if self.subpixel else 1),
                                        kernel_size=self.kernel_size,
                                        strides=1 if self.subpixel else 2,
                                        padding='SAME', dilations=self.dilations, kernel_initializer=conv_kernel_initializer)

            def forward(self, x):
                x = self.conv1(x)
                if self.subpixel:
                    x = nn.tf_space_to_depth(x, 2)
                if self.use_activator:
                    x = tf.nn.leaky_relu(x, 0.1)
                return x

            def get_out_ch(self):
                return (self.out_ch // 4) * 4

        class DownscaleBlock(nn.ModelBase):
            def on_build(self, in_ch, ch, n_downscales, kernel_size, dilations=1, subpixel=True):
                self.downs = []

                last_ch = in_ch
                for i in range(n_downscales):
                    cur_ch = ch*( min(2**i, 8)  )
                    self.downs.append ( Downscale(last_ch, cur_ch, kernel_size=kernel_size, dilations=dilations, subpixel=subpixel) )
                    last_ch = self.downs[-1].get_out_ch()

            def forward(self, inp):
                x = inp
                for down in self.downs:
                    x = down(x)
                return x

        class Upscale(nn.ModelBase):
            def on_build(self, in_ch, out_ch, kernel_size=3 ):
                self.conv1 = nn.Conv2D( in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME', kernel_initializer=conv_kernel_initializer)

            def forward(self, x):
                x = self.conv1(x)
                x = tf.nn.leaky_relu(x, 0.1)
                x = nn.tf_depth_to_space(x, 2)
                return x

        class ResidualBlock(nn.ModelBase):
            def on_build(self, ch, kernel_size=3 ):
                self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', kernel_initializer=conv_kernel_initializer)
                self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', kernel_initializer=conv_kernel_initializer)

            def forward(self, inp):
                x = self.conv1(inp)
                x = tf.nn.leaky_relu(x, 0.2)
                x = self.conv2(x)
                x = tf.nn.leaky_relu(inp + x, 0.2)
                return x

        class UpdownResidualBlock(nn.ModelBase):
            def on_build(self, ch, inner_ch, kernel_size=3 ):
                self.up   = Upscale (ch, inner_ch, kernel_size=kernel_size)
                self.res  = ResidualBlock (inner_ch, kernel_size=kernel_size)
                self.down = Downscale (inner_ch, ch, kernel_size=kernel_size, use_activator=False)

            def forward(self, inp):
                x = self.up(inp)
                x = upx = self.res(x)
                x = self.down(x)
                x = x + inp
                x = tf.nn.leaky_relu(x, 0.2)
                return x, upx

        class Encoder(nn.ModelBase):
            def on_build(self, in_ch, e_ch, is_hd):
                self.is_hd=is_hd
                if self.is_hd:
                    self.down1 = DownscaleBlock(in_ch, e_ch*2, n_downscales=4, kernel_size=3, dilations=1)
                    self.down2 = DownscaleBlock(in_ch, e_ch*2, n_downscales=4, kernel_size=5, dilations=1)
                    self.down3 = DownscaleBlock(in_ch, e_ch//2, n_downscales=4, kernel_size=5, dilations=2)
                    self.down4 = DownscaleBlock(in_ch, e_ch//2, n_downscales=4, kernel_size=7, dilations=2)
                else:
                    self.down1 = DownscaleBlock(in_ch, e_ch, n_downscales=4, kernel_size=5, dilations=1, subpixel=False)

            def forward(self, inp):
                if self.is_hd:
                    x = tf.concat([ nn.tf_flatten(self.down1(inp)),
                                    nn.tf_flatten(self.down2(inp)),
                                    nn.tf_flatten(self.down3(inp)),
                                    nn.tf_flatten(self.down4(inp)) ], -1 )
                else:
                    x = nn.tf_flatten(self.down1(inp))
                return x

        class Inter(nn.ModelBase):
            def __init__(self, in_ch, lowest_dense_res, ae_ch, ae_out_ch, **kwargs):
                self.in_ch, self.lowest_dense_res, self.ae_ch, self.ae_out_ch = in_ch, lowest_dense_res, ae_ch, ae_out_ch
                super().__init__(**kwargs)

            def on_build(self):
                in_ch, lowest_dense_res, ae_ch, ae_out_ch = self.in_ch, self.lowest_dense_res, self.ae_ch, self.ae_out_ch

                self.dense1 = nn.Dense( in_ch, ae_ch )
                self.dense2 = nn.Dense( ae_ch, lowest_dense_res * lowest_dense_res * ae_out_ch )
                self.upscale1 = Upscale(ae_out_ch, ae_out_ch)

            def forward(self, inp):
                x = self.dense1(inp)
                x = self.dense2(x)
                x = nn.tf_reshape_4D (x, self.lowest_dense_res, self.lowest_dense_res, self.ae_out_ch)
                x = self.upscale1(x)
                return x

            def get_out_ch(self):
                return self.ae_out_ch

        class Decoder(nn.ModelBase):
            def on_build(self, in_ch, d_ch, d_mask_ch, is_hd ):
                self.is_hd = is_hd

                self.upscale0 = Upscale(in_ch, d_ch*8, kernel_size=3)
                self.upscale1 = Upscale(d_ch*8, d_ch*4, kernel_size=3)
                self.upscale2 = Upscale(d_ch*4, d_ch*2, kernel_size=3)

                if is_hd:
                    self.res0 = UpdownResidualBlock(in_ch, d_ch*8, kernel_size=3)
                    self.res1 = UpdownResidualBlock(d_ch*8, d_ch*4, kernel_size=3)
                    self.res2 = UpdownResidualBlock(d_ch*4, d_ch*2, kernel_size=3)
                    self.res3 = UpdownResidualBlock(d_ch*2, d_ch, kernel_size=3)
                else:
                    self.res0 = ResidualBlock(d_ch*8, kernel_size=3)
                    self.res1 = ResidualBlock(d_ch*4, kernel_size=3)
                    self.res2 = ResidualBlock(d_ch*2, kernel_size=3)

                self.out_conv  = nn.Conv2D( d_ch*2, 3, kernel_size=1, padding='SAME', kernel_initializer=conv_kernel_initializer)

                self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3)
                self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3)
                self.upscalem2 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3)
                self.out_convm = nn.Conv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME', kernel_initializer=conv_kernel_initializer)

            def get_weights_ex(self, include_mask):
                # Call internal get_weights in order to initialize inner logic
                self.get_weights()

                weights = self.upscale0.get_weights() + self.upscale1.get_weights() + self.upscale2.get_weights() \
                          + self.res0.get_weights() + self.res1.get_weights() + self.res2.get_weights() + self.out_conv.get_weights()

                if include_mask:
                    weights += self.upscalem0.get_weights() + self.upscalem1.get_weights() + self.upscalem2.get_weights() \
                               + self.out_convm.get_weights()
                return weights


            def forward(self, inp):
                z = inp

                if self.is_hd:
                    x, upx = self.res0(z)
                    x = self.upscale0(x)
                    x = tf.nn.leaky_relu(x + upx, 0.2)
                    x, upx = self.res1(x)

                    x = self.upscale1(x)
                    x = tf.nn.leaky_relu(x + upx, 0.2)
                    x, upx = self.res2(x)

                    x = self.upscale2(x)
                    x = tf.nn.leaky_relu(x + upx, 0.2)
                    x, upx = self.res3(x)
                else:
                    x = self.upscale0(z)
                    x = self.res0(x)
                    x = self.upscale1(x)
                    x = self.res1(x)
                    x = self.upscale2(x)
                    x = self.res2(x)

                m = self.upscalem0(z)
                m = self.upscalem1(m)
                m = self.upscalem2(m)

                return tf.nn.sigmoid(self.out_conv(x)), \
                       tf.nn.sigmoid(self.out_convm(m))

        class CodeDiscriminator(nn.ModelBase):
            def on_build(self, in_ch, code_res, ch=256):
                n_downscales = 1 + code_res // 8

                self.convs = []
                prev_ch = in_ch
                for i in range(n_downscales):
                    cur_ch = ch * min( (2**i), 8 )
                    self.convs.append ( nn.Conv2D( prev_ch, cur_ch, kernel_size=4 if i == 0 else 3, strides=2, padding='SAME', kernel_initializer=conv_kernel_initializer) )
                    prev_ch = cur_ch

                self.out_conv =  nn.Conv2D( prev_ch, 1, kernel_size=1, padding='VALID', kernel_initializer=conv_kernel_initializer)

            def forward(self, x):
                for conv in self.convs:
                    x = tf.nn.leaky_relu( conv(x), 0.1 )
                return self.out_conv(x)

        device_config = nn.getCurrentDeviceConfig()
        devices = device_config.devices

        self.resolution = resolution = self.options['resolution']
        learn_mask = self.options['learn_mask']
        archi = self.options['archi']
        ae_dims = self.options['ae_dims']
        e_dims = self.options['e_dims']
        d_dims = self.options['d_dims']
        d_mask_dims = self.options['d_mask_dims']
        self.pretrain = self.options['pretrain']
        if self.pretrain_just_disabled:
            self.set_iter(0)

        self.gan_power = gan_power = self.options['gan_power'] if not self.pretrain else 0.0

        masked_training = True

        models_opt_on_gpu = False if len(devices) == 0 else self.options['models_opt_on_gpu']
        models_opt_device = '/GPU:0' if models_opt_on_gpu and self.is_training else '/CPU:0'
        optimizer_vars_on_cpu = models_opt_device=='/CPU:0'

        input_ch = 3
        output_ch = 3
        bgr_shape = nn.get4Dshape(resolution,resolution,input_ch)
        mask_shape = nn.get4Dshape(resolution,resolution,1)
        lowest_dense_res = resolution // 16

        self.model_filename_list = []


        with tf.device ('/CPU:0'):
            #Place holders on CPU
            self.warped_src = tf.placeholder (nn.tf_floatx, bgr_shape)
            self.warped_dst = tf.placeholder (nn.tf_floatx, bgr_shape)

            self.target_src = tf.placeholder (nn.tf_floatx, bgr_shape)
            self.target_dst = tf.placeholder (nn.tf_floatx, bgr_shape)

            self.target_srcm = tf.placeholder (nn.tf_floatx, mask_shape)
            self.target_dstm = tf.placeholder (nn.tf_floatx, mask_shape)

        # Initializing model classes
        with tf.device (models_opt_device):
            if 'df' in archi:
                self.encoder = Encoder(in_ch=input_ch, e_ch=e_dims, is_hd='hd' in archi, name='encoder')
                encoder_out_ch = self.encoder.compute_output_channels ( (nn.tf_floatx, bgr_shape))

                self.inter = Inter (in_ch=encoder_out_ch, lowest_dense_res=lowest_dense_res, ae_ch=ae_dims, ae_out_ch=ae_dims, name='inter')
                inter_out_ch = self.inter.compute_output_channels ( (nn.tf_floatx, (None,encoder_out_ch)))

                self.decoder_src = Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, is_hd='hd' in archi, name='decoder_src')
                self.decoder_dst = Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, is_hd='hd' in archi, name='decoder_dst')

                self.model_filename_list += [ [self.encoder,     'encoder.npy'    ],
                                              [self.inter,       'inter.npy'      ],
                                              [self.decoder_src, 'decoder_src.npy'],
                                              [self.decoder_dst, 'decoder_dst.npy']  ]

                if self.is_training:
                    if self.options['true_face_power'] != 0:
                        self.code_discriminator = CodeDiscriminator(ae_dims, code_res=lowest_dense_res*2, name='dis' )
                        self.model_filename_list += [ [self.code_discriminator, 'code_discriminator.npy'] ]

            elif 'liae' in archi:
                self.encoder = Encoder(in_ch=input_ch, e_ch=e_dims, is_hd='hd' in archi, name='encoder')
                encoder_out_ch = self.encoder.compute_output_channels ( (nn.tf_floatx, bgr_shape))

                self.inter_AB = Inter(in_ch=encoder_out_ch, lowest_dense_res=lowest_dense_res, ae_ch=ae_dims, ae_out_ch=ae_dims*2, name='inter_AB')
                self.inter_B  = Inter(in_ch=encoder_out_ch, lowest_dense_res=lowest_dense_res, ae_ch=ae_dims, ae_out_ch=ae_dims*2, name='inter_B')

                inter_AB_out_ch = self.inter_AB.compute_output_channels ( (nn.tf_floatx, (None,encoder_out_ch)))
                inter_B_out_ch = self.inter_B.compute_output_channels ( (nn.tf_floatx, (None,encoder_out_ch)))
                inters_out_ch = inter_AB_out_ch+inter_B_out_ch
                self.decoder = Decoder(in_ch=inters_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, is_hd='hd' in archi, name='decoder')

                self.model_filename_list += [ [self.encoder,  'encoder.npy'],
                                              [self.inter_AB, 'inter_AB.npy'],
                                              [self.inter_B , 'inter_B.npy'],
                                              [self.decoder , 'decoder.npy'] ]

            if self.is_training:
                if gan_power != 0:
                    self.D_src = nn.PatchDiscriminator(patch_size=resolution//16, in_ch=output_ch, base_ch=512, name="D_src")
                    self.D_dst = nn.PatchDiscriminator(patch_size=resolution//16, in_ch=output_ch, base_ch=512, name="D_dst")
                    self.model_filename_list += [ [self.D_src, 'D_src.npy'] ]
                    self.model_filename_list += [ [self.D_dst, 'D_dst.npy'] ]

                # Initialize optimizers
                lr=5e-5
                lr_dropout = 0.3 if self.options['lr_dropout'] else 1.0
                clipnorm = 1.0 if self.options['clipgrad'] else 0.0
                self.src_dst_opt = nn.TFRMSpropOptimizer(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='src_dst_opt')
                self.model_filename_list += [ (self.src_dst_opt, 'src_dst_opt.npy') ]
                if 'df' in archi:
                    self.src_dst_all_trainable_weights = self.encoder.get_weights() + self.inter.get_weights() + self.decoder_src.get_weights() + self.decoder_dst.get_weights()
                    self.src_dst_trainable_weights = self.encoder.get_weights() + self.inter.get_weights() + self.decoder_src.get_weights_ex(learn_mask) + self.decoder_dst.get_weights_ex(learn_mask)

                elif 'liae' in archi:
                    self.src_dst_all_trainable_weights = self.encoder.get_weights() + self.inter_AB.get_weights() + self.inter_B.get_weights() + self.decoder.get_weights()
                    self.src_dst_trainable_weights = self.encoder.get_weights() + self.inter_AB.get_weights() + self.inter_B.get_weights() + self.decoder.get_weights_ex(learn_mask)

                self.src_dst_opt.initialize_variables (self.src_dst_all_trainable_weights, vars_on_cpu=optimizer_vars_on_cpu)

                if self.options['true_face_power'] != 0:
                    self.D_code_opt = nn.TFRMSpropOptimizer(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='D_code_opt')
                    self.D_code_opt.initialize_variables ( self.code_discriminator.get_weights(), vars_on_cpu=optimizer_vars_on_cpu)
                    self.model_filename_list += [ (self.D_code_opt, 'D_code_opt.npy') ]

                if gan_power != 0:
                    self.D_src_dst_opt = nn.TFRMSpropOptimizer(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='D_src_dst_opt')
                    self.D_src_dst_opt.initialize_variables ( self.D_src.get_weights()+self.D_dst.get_weights(), vars_on_cpu=optimizer_vars_on_cpu)
                    self.model_filename_list += [ (self.D_src_dst_opt, 'D_src_dst_opt.npy') ]

        if self.is_training:
            # Adjust batch size for multiple GPU
            gpu_count = max(1, len(devices) )
            bs_per_gpu = max(1, self.get_batch_size() // gpu_count)
            self.set_batch_size( gpu_count*bs_per_gpu)


            # Compute losses per GPU
            gpu_pred_src_src_list = []
            gpu_pred_dst_dst_list = []
            gpu_pred_src_dst_list = []
            gpu_pred_src_srcm_list = []
            gpu_pred_dst_dstm_list = []
            gpu_pred_src_dstm_list = []

            gpu_src_losses = []
            gpu_dst_losses = []
            gpu_G_loss_gvs = []
            gpu_D_code_loss_gvs = []
            gpu_D_src_dst_loss_gvs = []
            for gpu_id in range(gpu_count):
                with tf.device( f'/GPU:{gpu_id}' if len(devices) != 0 else '/CPU:0' ):

                    with tf.device('/CPU:0'):
                        # slice on CPU, otherwise all batch data will be transferred to the GPU first
                        batch_slice = slice( gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu )
                        gpu_warped_src   = self.warped_src [batch_slice,:,:,:]
                        gpu_warped_dst   = self.warped_dst [batch_slice,:,:,:]
                        gpu_target_src   = self.target_src [batch_slice,:,:,:]
                        gpu_target_dst   = self.target_dst [batch_slice,:,:,:]
                        gpu_target_srcm  = self.target_srcm[batch_slice,:,:,:]
                        gpu_target_dstm  = self.target_dstm[batch_slice,:,:,:]

                    # process model tensors
                    if 'df' in archi:
                        gpu_src_code     = self.inter(self.encoder(gpu_warped_src))
                        gpu_dst_code     = self.inter(self.encoder(gpu_warped_dst))
                        gpu_pred_src_src, gpu_pred_src_srcm = self.decoder_src(gpu_src_code)
                        gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)
                        gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code)

                    elif 'liae' in archi:
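                        # LIAE shares one decoder: the src path stacks inter_AB with
                        # itself, the dst path stacks inter_B with inter_AB, and the
                        # src->dst swap reuses the dst inter_AB code twice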
                        gpu_src_code = self.encoder (gpu_warped_src)
                        gpu_src_inter_AB_code = self.inter_AB (gpu_src_code)
                        gpu_src_code = tf.concat([gpu_src_inter_AB_code,gpu_src_inter_AB_code], nn.conv2d_ch_axis  )
                        gpu_dst_code = self.encoder (gpu_warped_dst)
                        gpu_dst_inter_B_code = self.inter_B (gpu_dst_code)
                        gpu_dst_inter_AB_code = self.inter_AB (gpu_dst_code)
                        gpu_dst_code = tf.concat([gpu_dst_inter_B_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis )
                        gpu_src_dst_code = tf.concat([gpu_dst_inter_AB_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis )

                        gpu_pred_src_src, gpu_pred_src_srcm = self.decoder(gpu_src_code)
                        gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder(gpu_dst_code)
                        gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)

                    gpu_pred_src_src_list.append(gpu_pred_src_src)
                    gpu_pred_dst_dst_list.append(gpu_pred_dst_dst)
                    gpu_pred_src_dst_list.append(gpu_pred_src_dst)

                    gpu_pred_src_srcm_list.append(gpu_pred_src_srcm)
                    gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm)
                    gpu_pred_src_dstm_list.append(gpu_pred_src_dstm)

                    gpu_target_srcm_blur = nn.tf_gaussian_blur(gpu_target_srcm,  max(1, resolution // 32) )
                    gpu_target_dstm_blur = nn.tf_gaussian_blur(gpu_target_dstm,  max(1, resolution // 32) )

                    gpu_target_dst_masked      = gpu_target_dst*gpu_target_dstm_blur
                    gpu_target_dst_anti_masked = gpu_target_dst*(1.0 - gpu_target_dstm_blur)

                    gpu_target_src_masked_opt  = gpu_target_src*gpu_target_srcm_blur if masked_training else gpu_target_src
                    gpu_target_dst_masked_opt  = gpu_target_dst_masked if masked_training else gpu_target_dst

                    gpu_pred_src_src_masked_opt = gpu_pred_src_src*gpu_target_srcm_blur if masked_training else gpu_pred_src_src
                    gpu_pred_dst_dst_masked_opt = gpu_pred_dst_dst*gpu_target_dstm_blur if masked_training else gpu_pred_dst_dst

                    gpu_psd_target_dst_masked = gpu_pred_src_dst*gpu_target_dstm_blur
                    gpu_psd_target_dst_anti_masked = gpu_pred_src_dst*(1.0 - gpu_target_dstm_blur)

                    gpu_src_loss =  tf.reduce_mean ( 10*nn.tf_dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
                    gpu_src_loss += tf.reduce_mean ( 10*tf.square ( gpu_target_src_masked_opt - gpu_pred_src_src_masked_opt ), axis=[1,2,3])
                    if learn_mask:
                        gpu_src_loss += tf.reduce_mean ( 10*tf.square( gpu_target_srcm - gpu_pred_src_srcm ),axis=[1,2,3] )

                    face_style_power = self.options['face_style_power'] / 100.0
                    if face_style_power != 0 and not self.pretrain:
                        gpu_src_loss += nn.tf_style_loss(gpu_psd_target_dst_masked, gpu_target_dst_masked, gaussian_blur_radius=resolution//16, loss_weight=10000*face_style_power)

                    bg_style_power = self.options['bg_style_power'] / 100.0
                    if bg_style_power != 0 and not self.pretrain:
                        gpu_src_loss += tf.reduce_mean( (10*bg_style_power)*nn.tf_dssim(gpu_psd_target_dst_anti_masked, gpu_target_dst_anti_masked, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
                        gpu_src_loss += tf.reduce_mean( (10*bg_style_power)*tf.square( gpu_psd_target_dst_anti_masked - gpu_target_dst_anti_masked), axis=[1,2,3] )

                    gpu_dst_loss  = tf.reduce_mean ( 10*nn.tf_dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1])
                    gpu_dst_loss += tf.reduce_mean ( 10*tf.square(  gpu_target_dst_masked_opt- gpu_pred_dst_dst_masked_opt ), axis=[1,2,3])
                    if learn_mask:
                        gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dstm - gpu_pred_dst_dstm ),axis=[1,2,3] )

                    gpu_src_losses += [gpu_src_loss]
                    gpu_dst_losses += [gpu_dst_loss]

                    gpu_G_loss = gpu_src_loss + gpu_dst_loss

                    def DLoss(labels,logits):
                        return tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits), axis=[1,2,3])

                    if self.options['true_face_power'] != 0:
                        gpu_src_code_d = self.code_discriminator( gpu_src_code )
                        gpu_src_code_d_ones  = tf.ones_like (gpu_src_code_d)
                        gpu_src_code_d_zeros = tf.zeros_like(gpu_src_code_d)
                        gpu_dst_code_d = self.code_discriminator( gpu_dst_code )
                        gpu_dst_code_d_ones = tf.ones_like(gpu_dst_code_d)

                        gpu_G_loss += self.options['true_face_power']*DLoss(gpu_src_code_d_ones, gpu_src_code_d)

                        gpu_D_code_loss = (DLoss(gpu_src_code_d_ones , gpu_dst_code_d) + \
                                           DLoss(gpu_src_code_d_zeros, gpu_src_code_d) ) * 0.5

                        gpu_D_code_loss_gvs += [ nn.tf_gradients (gpu_D_code_loss, self.code_discriminator.get_weights() ) ]

                    if gan_power != 0:
                        gpu_pred_src_src_d       = self.D_src(gpu_pred_src_src_masked_opt)
                        gpu_pred_src_src_d_ones  = tf.ones_like (gpu_pred_src_src_d)
                        gpu_pred_src_src_d_zeros = tf.zeros_like(gpu_pred_src_src_d)
                        gpu_target_src_d         = self.D_src(gpu_target_src_masked_opt)
                        gpu_target_src_d_ones    = tf.ones_like(gpu_target_src_d)
                        gpu_pred_dst_dst_d       = self.D_dst(gpu_pred_dst_dst_masked_opt)
                        gpu_pred_dst_dst_d_ones  = tf.ones_like (gpu_pred_dst_dst_d)
                        gpu_pred_dst_dst_d_zeros = tf.zeros_like(gpu_pred_dst_dst_d)
                        gpu_target_dst_d         = self.D_dst(gpu_target_dst_masked_opt)
                        gpu_target_dst_d_ones    = tf.ones_like(gpu_target_dst_d)

                        gpu_D_src_dst_loss = (DLoss(gpu_target_src_d_ones   , gpu_target_src_d) + \
                                              DLoss(gpu_pred_src_src_d_zeros, gpu_pred_src_src_d) ) * 0.5 + \
                                             (DLoss(gpu_target_dst_d_ones   , gpu_target_dst_d) + \
                                              DLoss(gpu_pred_dst_dst_d_zeros, gpu_pred_dst_dst_d) ) * 0.5

                        gpu_D_src_dst_loss_gvs += [ nn.tf_gradients (gpu_D_src_dst_loss, self.D_src.get_weights()+self.D_dst.get_weights() ) ]

                        gpu_G_loss += gan_power*(DLoss(gpu_pred_src_src_d_ones, gpu_pred_src_src_d) + DLoss(gpu_pred_dst_dst_d_ones, gpu_pred_dst_dst_d))


                    gpu_G_loss_gvs += [ nn.tf_gradients ( gpu_G_loss, self.src_dst_trainable_weights ) ]


            # Average losses and gradients, and create optimizer update ops
            with tf.device (models_opt_device):
                pred_src_src  = nn.tf_concat(gpu_pred_src_src_list, 0)
                pred_dst_dst  = nn.tf_concat(gpu_pred_dst_dst_list, 0)
                pred_src_dst  = nn.tf_concat(gpu_pred_src_dst_list, 0)
                pred_src_srcm = nn.tf_concat(gpu_pred_src_srcm_list, 0)
                pred_dst_dstm = nn.tf_concat(gpu_pred_dst_dstm_list, 0)
                pred_src_dstm = nn.tf_concat(gpu_pred_src_dstm_list, 0)
                src_loss = nn.tf_average_tensor_list(gpu_src_losses)
                dst_loss = nn.tf_average_tensor_list(gpu_dst_losses)
                src_dst_loss_gv_op = self.src_dst_opt.get_update_op (nn.tf_average_gv_list (gpu_G_loss_gvs))

                if self.options['true_face_power'] != 0:
                    D_loss_gv_op = self.D_code_opt.get_update_op (nn.tf_average_gv_list(gpu_D_code_loss_gvs))

                if gan_power != 0:
                    src_D_src_dst_loss_gv_op = self.D_src_dst_opt.get_update_op (nn.tf_average_gv_list(gpu_D_src_dst_loss_gvs) )


            # Initializing training and view functions
            def src_dst_train(warped_src, target_src, target_srcm, \
                              warped_dst, target_dst, target_dstm):
                s, d, _ = nn.tf_sess.run ( [ src_loss, dst_loss, src_dst_loss_gv_op],
                                            feed_dict={self.warped_src :warped_src,
                                                       self.target_src :target_src,
                                                       self.target_srcm:target_srcm,
                                                       self.warped_dst :warped_dst,
                                                       self.target_dst :target_dst,
                                                       self.target_dstm:target_dstm,
                                                       })
                s = np.mean(s)
                d = np.mean(d)
                return s, d
            self.src_dst_train = src_dst_train

            if self.options['true_face_power'] != 0:
                def D_train(warped_src, warped_dst):
                    nn.tf_sess.run ([D_loss_gv_op], feed_dict={self.warped_src: warped_src, self.warped_dst: warped_dst})
                self.D_train = D_train

            if gan_power != 0:
                def D_src_dst_train(warped_src, target_src, target_srcm, \
                                    warped_dst, target_dst, target_dstm):
                    nn.tf_sess.run ([src_D_src_dst_loss_gv_op], feed_dict={self.warped_src :warped_src,
                                                                           self.target_src :target_src,
                                                                           self.target_srcm:target_srcm,
                                                                           self.warped_dst :warped_dst,
                                                                           self.target_dst :target_dst,
                                                                           self.target_dstm:target_dstm})
                self.D_src_dst_train = D_src_dst_train

            if learn_mask:
                def AE_view(warped_src, warped_dst):
                    return nn.tf_sess.run ( [pred_src_src, pred_dst_dst, pred_dst_dstm, pred_src_dst, pred_src_dstm],
                                             feed_dict={self.warped_src:warped_src,
                                                        self.warped_dst:warped_dst})
            else:
                def AE_view(warped_src, warped_dst):
                    return nn.tf_sess.run ( [pred_src_src, pred_dst_dst, pred_src_dst],
                                             feed_dict={self.warped_src:warped_src,
                                                        self.warped_dst:warped_dst})
            self.AE_view = AE_view
        else:
            # Initializing merge function
            with tf.device( '/GPU:0' if len(devices) != 0 else '/CPU:0'):
                if 'df' in archi:
                    gpu_dst_code     = self.inter(self.encoder(self.warped_dst))
                    gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code)
                    _, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)

                elif 'liae' in archi:
                    gpu_dst_code = self.encoder (self.warped_dst)
                    gpu_dst_inter_B_code = self.inter_B (gpu_dst_code)
                    gpu_dst_inter_AB_code = self.inter_AB (gpu_dst_code)
                    gpu_dst_code = tf.concat([gpu_dst_inter_B_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis)
                    gpu_src_dst_code = tf.concat([gpu_dst_inter_AB_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis)

                    gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)
                    _, gpu_pred_dst_dstm = self.decoder(gpu_dst_code)

            if learn_mask:
                def AE_merge( warped_dst):
                    return nn.tf_sess.run ( [gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm], feed_dict={self.warped_dst:warped_dst})
            else:
                def AE_merge( warped_dst):
                    return nn.tf_sess.run ( [gpu_pred_src_dst], feed_dict={self.warped_dst:warped_dst})

            self.AE_merge = AE_merge

        # Loading/initializing all models/optimizers weights
        for model, filename in io.progress_bar_generator(self.model_filename_list, "Initializing models"):
            if self.pretrain_just_disabled:
                do_init = False
                if 'df' in archi:
                    if model == self.inter:
                        do_init = True
                elif 'liae' in archi:
                    if model == self.inter_AB:
                        do_init = True
            else:
                do_init = self.is_first_run()

            if not do_init:
                do_init = not model.load_weights( self.get_strpath_storage_for_file(filename) )

            if do_init:
                model.init_weights()

        # initializing sample generators
        if self.is_training:
            t = SampleProcessor.Types
            if self.options['face_type'] == 'h':
                face_type = t.FACE_TYPE_HALF
            elif self.options['face_type'] == 'mf':
                face_type = t.FACE_TYPE_MID_FULL
            elif self.options['face_type'] == 'f':
                face_type = t.FACE_TYPE_FULL

            training_data_src_path = self.training_data_src_path if not self.pretrain else self.get_pretraining_data_path()
            training_data_dst_path = self.training_data_dst_path if not self.pretrain else self.get_pretraining_data_path()

            random_ct_samples_path=training_data_dst_path if self.options['ct_mode'] != 'none' and not self.pretrain else None

            t_img_warped = t.IMG_WARPED_TRANSFORMED if self.options['random_warp'] else t.IMG_TRANSFORMED

            cpu_count = min(multiprocessing.cpu_count(), 8)
            src_generators_count = cpu_count // 2
            dst_generators_count = cpu_count // 2
            if self.options['ct_mode'] != 'none':
                src_generators_count = int(src_generators_count * 1.5)

            self.set_training_data_generators ([
                    SampleGeneratorFace(training_data_src_path, random_ct_samples_path=random_ct_samples_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
                        sample_process_options=SampleProcessor.Options(random_flip=self.random_flip),
                        output_sample_types = [ {'types' : (t_img_warped, face_type, t.MODE_BGR),      'data_format':nn.data_format, 'resolution': resolution, 'ct_mode': self.options['ct_mode'] },
                                                {'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_BGR), 'data_format':nn.data_format, 'resolution': resolution, 'ct_mode': self.options['ct_mode'] },
                                                {'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_M),   'data_format':nn.data_format, 'resolution': resolution } ],
                        generators_count=src_generators_count ),

                    SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
                        sample_process_options=SampleProcessor.Options(random_flip=self.random_flip),
                        output_sample_types = [ {'types' : (t_img_warped, face_type, t.MODE_BGR),      'data_format':nn.data_format, 'resolution': resolution},
                                                {'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_BGR), 'data_format':nn.data_format, 'resolution': resolution},
                                                {'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_M),   'data_format':nn.data_format, 'resolution': resolution} ],
                        generators_count=dst_generators_count )
                             ])

            if self.pretrain_just_disabled:
                self.update_sample_for_preview(force_new=True)
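
The DLoss helper in Example #3 is the standard sigmoid-cross-entropy GAN objective: D_src/D_dst are trained to label real crops 1 and generated crops 0, while gan_power times the same loss with flipped labels is added to the generator loss. A minimal NumPy sketch of that label pattern (the shapes and the gan_power value are illustrative):

import numpy as np

def sigmoid_ce(labels, logits):
    # numerically stable form of the loss tf.nn.sigmoid_cross_entropy_with_logits computes
    return (np.maximum(logits, 0) - logits * labels
            + np.log1p(np.exp(-np.abs(logits))))

def discriminator_loss(real_logits, fake_logits):
    # real patches -> 1, generated patches -> 0, averaged as in gpu_D_src_dst_loss
    return 0.5 * (sigmoid_ce(np.ones_like(real_logits), real_logits).mean()
                  + sigmoid_ce(np.zeros_like(fake_logits), fake_logits).mean())

def generator_adv_loss(fake_logits, gan_power=0.1):
    # the generator is rewarded when the discriminator calls its output real
    return gan_power * sigmoid_ce(np.ones_like(fake_logits), fake_logits).mean()

real = np.random.randn(4, 6, 6, 1)   # PatchDiscriminator-style logit maps
fake = np.random.randn(4, 6, 6, 1)
print(discriminator_loss(real, fake), generator_adv_loss(fake))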